
EasyDeL version 0.1.1: Vision-Language Models (VLMs) & Vision Models Update

Released by @erfanzar on 18 Mar 18:06

EasyDeL Release Notes: Vision-Language Models (VLMs) & Vision Models Update (Pre-Training, Fine-Tuning, and Inference)

First JAX Library to Support VLMs!

EasyDeL is now the first library to support Vision-Language Models (VLMs) in JAX, bringing cutting-edge multimodal AI capabilities to the ecosystem. This update significantly expands our vision model offerings while optimizing performance and usability across the board.

New Models Added

We’ve added support for the following vision and multimodal models, unlocking new capabilities in computer vision and vision-language tasks:

  • Aya Vision – Cohere's multimodal vision-language model, built on a Cohere2 language backbone
  • Cohere2 – the Cohere2 language model, usable standalone as an LLM or as a VLM backbone via AyaVision
  • LLaVA – a vision-language model for image-grounded understanding
  • SigLIP – a vision encoder pretrained with a sigmoid contrastive image-text loss, yielding state-of-the-art visual representations

Architecture & Performance Enhancements

We've made several improvements to streamline the framework and improve efficiency:

  • Unified Configuration Handling: Refactored configuration methods to ensure consistency across all modules, reducing redundant code and making customization easier.
  • Lazy Imports for Faster Startup: Dependencies are now loaded lazily, significantly reducing initialization time and improving integration flexibility (see the sketch after this list).
  • Extended VLM Support: Expanded Vision-Language Model (VLM) support throughout the vInference core and API server, enabling seamless inference and integration.
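
The lazy-import bullet above is easiest to see in code. Below is a minimal sketch of the standard pattern (PEP 562 module-level __getattr__) placed in a package's __init__.py; the submodule names are hypothetical placeholders, not EasyDeL's actual internal layout:

# Illustrative lazy-import sketch (PEP 562), not EasyDeL's exact internals.
import importlib

# Map public attribute names to the heavy submodules that define them.
# These names are HYPOTHETICAL, for illustration only.
_LAZY_SUBMODULES = {
    "trainers": ".trainers",
    "serving": ".serving",
}

def __getattr__(name):
    # Import the submodule only on first access, so `import easydel`
    # stays fast and optional dependencies load on demand.
    if name in _LAZY_SUBMODULES:
        module = importlib.import_module(_LAZY_SUBMODULES[name], __package__)
        globals()[name] = module  # cache: later lookups bypass this hook
        return module
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")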

Technical Maintenance & Cleanup

  • Removed deprecated code paths, keeping the codebase clean and maintainable.
  • Improved internal documentation and structured error handling for more robust deployments.

To use EasyDeL for vision-language inference with models like LLaVA, follow this setup:

🔧 Installation

Ensure you have EasyDeL installed:

pip install "easydel[all]==0.1.1"
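
To verify the install, you can print the package version and confirm JAX sees your accelerators (this assumes easydel exposes a __version__ attribute, as most packages do):

import easydel as ed
import jax

print(ed.__version__)  # expected: 0.1.1 (assumes a __version__ attribute)
print(jax.devices())   # lists the TPU/GPU/CPU devices JAX will run on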

🚀 Running a Vision-Language Model

Here’s a minimal script to load and serve a VLM using EasyDeL:

import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor

def main():
    sharding_axis_dims = (1, 1, -1, 1)  # (DP, FSDP, TP, SP); -1 assigns all remaining devices to the TP axis
    prefill_length = 8192 - 1024  # token budget reserved for the prompt (prefill)
    max_new_tokens = 1024  # token budget for generation
    max_length = max_new_tokens + prefill_length  # total sequence / KV-cache length (8192)
    pretrained_model_name_or_path = "llava-hf/llava-1.5-7b-hf"

    dtype = jnp.bfloat16
    param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()

    processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path)
    processor.padding_side = "left"  # left-pad prompts so generation starts at a common boundary

    model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
            attn_dtype=param_dtype,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=processor,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
        inference_name="easydelvlm",
    )

    # Pre-compile prefill and decode (plus the vision path) for the fixed
    # shapes below, so no JIT compilation happens at request time.
    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
            vision_included=True,  # also compile the vision-tower path
            vision_batch_size=1,
            vision_channels=3,
            vision_height=336,  # LLaVA-1.5 uses a 336x336 CLIP vision encoder
            vision_width=336,
        )
    )
    
    ed.vInferenceApiServer(inference, max_workers=10).fire()  # serve over an OpenAI-compatible HTTP API

if __name__ == "__main__":
    main()

🌍 Example API Request

Once your model is running, you can query it using an OpenAI-compatible API format:

{
  "model": "easydelvlm",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "image",
          "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
        },
        {
          "type": "text",
          "text": "Describe this image in detail."
        }
      ]
    }
  ],
  "temperature": 1.0,
  "top_p": 1.0,
  "max_tokens": 16
}
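
For instance, you could POST that payload from Python with the requests library. This is a hedged sketch: the host, port, and /v1/chat/completions path are assumptions based on the server being OpenAI-compatible, so substitute whatever address vInferenceApiServer reports when it starts:

# Client sketch for the running server. The base URL and the
# /v1/chat/completions route are ASSUMPTIONS (conventional for
# OpenAI-compatible servers); use the address your server prints.
import requests

payload = {
    "model": "easydelvlm",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
                },
                {"type": "text", "text": "Describe this image in detail."},
            ],
        }
    ],
    "temperature": 1.0,
    "top_p": 1.0,
    "max_tokens": 16,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
response.raise_for_status()
print(response.json())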

This release marks a major milestone in EasyDeL's evolution, ensuring that JAX users have access to state-of-the-art vision and multimodal AI models with optimized performance. 🚀