Releases: erfanzar/EasyDeL
Small Bug Fixes
We're excited to announce EasyDeL v0.1.2, bringing improved compatibility and key bug fixes! 🎉
🔄 What's New?
- JAX 0.5.3 Support ✅
  - EasyDeL now fully supports JAX 0.5.3, and everything works smoothly with implicit auto behavior.
- Fixed Auto PyTree Issues 🌳
  - Resolved issues with implicit PyTree structures, ensuring better compatibility with JAX transformations.
- Fixed Auto CLI from Dataclass Bugs 🛠️
  - Automatic CLI generation from dataclass now works correctly without unexpected failures.
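For readers unfamiliar with the idea, the following is a minimal standard-library sketch of what generating a CLI from a dataclass can look like. It is not EasyDeL's implementation; `TrainArgs`, `build_parser`, and `parse_into` are hypothetical names used only for illustration, and the sketch handles plain defaults only.

```python
import argparse
from dataclasses import MISSING, dataclass, fields


@dataclass
class TrainArgs:  # hypothetical config, not an EasyDeL class
    model_name: str
    learning_rate: float = 1e-4
    num_epochs: int = 1


def build_parser(cls) -> argparse.ArgumentParser:
    """Create one --flag per dataclass field, typed from its annotation."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        # Fields without a default become required flags.
        # (default_factory fields are not handled in this sketch.)
        required = f.default is MISSING
        parser.add_argument(
            f"--{f.name}",
            type=f.type,
            required=required,
            default=None if required else f.default,
        )
    return parser


def parse_into(cls, argv):
    """Parse argv and build an instance of the dataclass from the result."""
    namespace = build_parser(cls).parse_args(argv)
    return cls(**vars(namespace))


args = parse_into(TrainArgs, ["--model_name", "demo", "--learning_rate", "3e-4"])
print(args)
```

Fields that carry a default stay optional on the command line, while fields without one must be supplied.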
📦 Upgrade Now
pip install --upgrade easydel
As always, feel free to open an issue or start a discussion if you encounter any problems. 🚀
Happy coding!
EasyDeL version 0.1.1: Vision-Language Models (VLMs) & Vision Models Update
EasyDeL Release Notes: Vision-Language Models (VLMs) & Vision Models Update (Pre-Train, Finetune and Inference)
First JAX Library to Support VLMs!
EasyDeL is now the first library to support Vision-Language Models (VLMs) in JAX, bringing cutting-edge multimodal AI capabilities to the ecosystem. This update significantly expands our vision model offerings while optimizing performance and usability across the board.
New Models Added
We’ve added support for the following vision and multimodal models, unlocking new capabilities in computer vision and vision-language tasks:
- Aya Vision – A high-performance vision model with strong generalization capabilities
- Cohere2 – Enhanced visual reasoning and feature extraction (LLM/VLM via AyaVision)
- LLaVA – A Vision-Language model for image-grounded understanding
- SigLip – Image-text pretraining with a sigmoid loss for strong visual representations
Architecture & Performance Enhancements
We've made several improvements to streamline the framework and improve efficiency:
- Unified Configuration Handling: Refactored configuration methods to ensure consistency across all modules, reducing redundant code and making customization easier.
- Lazy Imports for Faster Startup: Implemented lazy loading of dependencies, significantly reducing initialization time and improving integration flexibility.
- Extended VLM Support: Expanded Vision-Language Model (VLM) support throughout the vinference core and API server, enabling seamless inference and integration.
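As a rough illustration of the lazy-import idea mentioned above, the sketch below defers loading a module until one of its attributes is first accessed, using the standard library's `importlib.util.LazyLoader`. This follows the general recipe from the Python documentation and is not necessarily how EasyDeL implements it; `lazy_import` is a hypothetical helper name.

```python
import importlib.util
import sys


def lazy_import(name):
    """Return a module object whose code runs only on first attribute access."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        raise ModuleNotFoundError(name)
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # sets up the lazy module; real loading is deferred
    return module


colorsys = lazy_import("colorsys")          # cheap: nothing executed yet
hsv = colorsys.rgb_to_hsv(1.0, 0.0, 0.0)    # first access triggers the real import
print(hsv)
```

The benefit is that a package with many heavy optional dependencies only pays the import cost for the ones a given program actually touches.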
Technical Maintenance & Cleanup
- Removed deprecated code to enhance maintainability and keep the codebase clean.
- Improved internal documentation and structured error handling for more robust deployments.
To use EasyDeL for Vision-Language inference with models like LLaVA, follow this setup:
🔧 Installation
Ensure you have EasyDeL installed:
pip install "easydel[all]==0.1.1"
🚀 Running a Vision-Language Model
Here’s a minimal script to load and serve a VLM using EasyDeL:
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor


def main():
    sharding_axis_dims = (1, 1, -1, 1)  # DP, FSDP, TP, SP
    prefill_length = 8192 - 1024
    max_new_tokens = 1024
    max_length = max_new_tokens + prefill_length
    pretrained_model_name_or_path = "llava-hf/llava-1.5-7b-hf"
    dtype = jnp.bfloat16
    param_dtype = jnp.bfloat16
    partition_axis = ed.PartitionAxis()

    processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path)
    processor.padding_side = "left"

    model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
        pretrained_model_name_or_path,
        auto_shard_model=True,
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            freq_max_position_embeddings=max_length,
            mask_max_position_embeddings=max_length,
            kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
            gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
            attn_dtype=param_dtype,
            attn_mechanism=ed.AttentionMechanisms.AUTO,
        ),
        quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        param_dtype=param_dtype,
        dtype=dtype,
        partition_axis=partition_axis,
        precision=jax.lax.Precision.DEFAULT,
    )

    inference = ed.vInference(
        model=model,
        processor_class=processor,
        generation_config=ed.vInferenceConfig(
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            eos_token_id=model.generation_config.eos_token_id,
            streaming_chunks=32,
            num_return_sequences=1,
        ),
        inference_name="easydelvlm",
    )

    inference.precompile(
        ed.vInferencePreCompileConfig(
            batch_size=1,
            prefill_length=prefill_length,
            vision_included=True,
            vision_batch_size=1,
            vision_channels=3,
            vision_height=336,
            vision_width=336,
        )
    )

    ed.vInferenceApiServer(inference, max_workers=10).fire()


if __name__ == "__main__":
    main()
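The script above passes `sharding_axis_dims = (1, 1, -1, 1)`. Assuming the `-1` behaves like the `-1` in a NumPy reshape, that is, "use whatever devices are left over" (an assumption; these notes do not spell it out), the resolved mesh shape can be sketched in plain Python. `resolve_axis_dims` is a hypothetical helper, not an EasyDeL function.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 entry with the remaining device count."""
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split into {axis_dims}")
    if -1 not in axis_dims and known != num_devices:
        raise ValueError(f"{axis_dims} does not use exactly {num_devices} devices")
    return tuple(num_devices // known if d == -1 else d for d in axis_dims)


# With 8 devices, (DP, FSDP, TP, SP) = (1, 1, -1, 1) puts all devices on the TP axis.
resolved = resolve_axis_dims((1, 1, -1, 1), 8)
print(resolved)  # (1, 1, 8, 1)
```

Under the same assumption, the `(1, 1, 1, -1)` setting would place every device on the sequence-parallel axis instead.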
🌍 Example API Request
Once your model is running, you can query it using an OpenAI-compatible API format:
{
  "model": "easydelvlm",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "image",
          "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
        },
        {
          "type": "text",
          "text": "Describe this image in detail."
        }
      ]
    }
  ],
  "temperature": 1.0,
  "top_p": 1.0,
  "max_tokens": 16
}
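A request body like the one above can be built and sent from Python using only the standard library. The sketch below is illustrative: the base URL and the `/v1/chat/completions` path are assumptions based on the usual OpenAI convention, so check the server's startup output for the actual address and route. `build_chat_payload` and `send_chat_request` are hypothetical helper names.

```python
import json
import urllib.request


def build_chat_payload(model, image_url, prompt, max_tokens=16):
    """Assemble a chat request with one image part and one text part."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_url},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "temperature": 1.0,
        "top_p": 1.0,
        "max_tokens": max_tokens,
    }


def send_chat_request(base_url, payload):
    """POST the payload as JSON; the endpoint path is an assumed convention."""
    request = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


payload = build_chat_payload(
    model="easydelvlm",
    image_url="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
    prompt="Describe this image in detail.",
)
# Uncomment once the server is running (host and port are assumptions):
# print(send_chat_request("http://localhost:8000", payload))
```

The `model` field matches the `inference_name` given to `ed.vInference` in the script above.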
This release marks a major milestone in EasyDeL's evolution, ensuring that JAX users have access to state-of-the-art vision and multimodal AI models with optimized performance. 🚀
EasyDeL v0.1.0 - No Trade-Off – Unleashing Uncompromised Performance & Modular Magic
We’re pleased to introduce EasyDeL v0.1.0—a significant update that improves our framework’s performance, modularity, and integration capabilities. This release brings several important changes and enhancements, ensuring a smoother and more flexible experience for model training and inference.
Introduction
EasyDeL v0.1.0 marks a solid step forward in our journey. With a renewed focus on modularity, distributed training, and improved API integrations, this release aims to offer better support for your research and development needs without overpromising. We continue to work hard on making EasyDeL a dependable tool for deep learning and machine learning tasks.
New Core Components
NNX Flax API Integration
- What’s New: We have replaced the previous Linen-based implementation with the NNX Flax API.
- Benefits:
  - More efficient computation graphs and cleaner API design.
  - Enhanced flexibility for customization and future extensions.
vInference Engine & vInferenceApiServer
- vInference Engine: A new component designed to deliver reliable model inference with low latency.
- vInferenceApiServer: Provides an OpenAI-compatible interface to make model deployment straightforward.
- Key Points:
  - Better integration with production environments.
  - Improved logging and monitoring features.
Distributed Training and Scalability
Support for Ray and MultiSlice
- Enhanced Distribution: EasyDeL now supports distribution with Ray and MultiSlice, making it easier to scale training workloads across multiple nodes or GPUs.
- Impact:
  - More efficient resource utilization.
  - Reduced training times for larger models in distributed settings.
Expanded Trainer Suite
New and Enhanced Trainers
- GRPO Trainers: Introduced to help manage more advanced training scenarios.
- Reward Model Trainers: Added support for reinforcement learning and preference-based training.
- Bug Fixes: Important fixes have been applied to ORPO and DPO trainers to improve overall stability and reliability.
- Overall Improvements: Enhanced logging, improved error handling, and more configurable options have been integrated to make the training process more predictable and user-friendly.
Attention Mechanism and Performance Enhancements
Bug Fixes and Optimizations
- Attention Mechanisms: Resolved issues in Flash Attention (GPU/TPU) and Splash Attention (TPU) to ensure smoother operations.
- Performance: Fine-tuned kernel launch times, memory management, and synchronization across devices for a modest but valuable performance boost.
- Dynamic Quantization: Continued improvements in support for various quantization methods (NF4, A8BIT, A8Q, A4Q) offer a better balance between model size and inference speed.
Extended Model Support
New and Updated Models
- DeepSeekV3: We’ve added support for DeepSeekV3, keeping up with emerging model architectures.
- General Model Expansion: Additional new models have been integrated, ensuring that EasyDeL remains compatible with a wider range of model types.
Modularity and Hackability
A More Modular Codebase
- Improved Structure: The codebase has been refactored into clearer, well-organized modules and functions, making it easier for developers to navigate and customize.
- Customization: Whether modifying trainer behavior or integrating new models, the enhanced modular design allows changes without impacting overall system stability.
- Community Focus: We encourage developers and researchers to explore and extend the framework in ways that best suit their projects.
Additional Improvements & Bug Fixes
- Documentation Updates: In-line documentation and external resources have been refreshed to reflect these changes.
- Stability Enhancements: A number of bug fixes across trainers, attention mechanisms, and hardware-specific operations lead to a more reliable framework.
- Developer Experience: Enhanced error messages and detailed logging have been implemented to simplify troubleshooting and further development.
- API Consistency: Internal APIs have been standardized and better documented for smoother integration with external tools.
Looking Ahead
EasyDeL v0.1.0 sets a strong foundation for future improvements. Upcoming updates will continue to expand support for distributed training, integrate additional models, and further refine the user and developer experience.
Full Changelog: 0.0.80...0.1.0
EasyDeL version 0.0.80
EasyDeL 0.0.80 brings enhanced flexibility, expanded model support, and improved performance with the introduction of vInference and optimized GPU/TPU integration. This version offers a significant speed and performance boost, with benchmarks showing improvements of over 4.9%, making EasyDeL more dynamic and easier to work with.
New Features:
- Platform and Backend Flexibility: Users can now specify the platform (e.g., TRITON) and backend (e.g., GPU) to optimize their workflows.
- Expanded Model Support: We have added support for new models including olmo2, qwen2_moe, mamba2, and others, enhancing the tool's versatility.
- Enhanced Trainers: Trainers are now more customizable and hackable, providing greater flexibility for project-specific needs.
- New Trainer Types: Introduced sequence-to-sequence trainers and sequence classification trainers to support a wider range of training tasks.
- vInference Engine: A robust inference engine for LLMs with Long-Term Support (LTS), ensuring stability and reliability.
- vInferenceApiServer: A backend for the inference engine that is fully compatible with OpenAI APIs, facilitating easy integration.
- Optimized GPU Integration: Leverages custom, direct TRITON calls for improved GPU performance, speeding up processing times.
- Dynamic Quantization Support: Added support for quantization types NF4, A8BIT, A8Q, and A4Q, enabling efficiency and scalability.
Performance Improvements:
- EasyDeL 0.0.80 has been optimized for speed and performance, with benchmarks showing improvements of over 4.9% compared to previous versions.
- The tool is now more dynamic and easier to work with, enhancing the overall user experience.
This release is a significant step forward in making EasyDeL a more powerful and flexible tool for machine learning tasks. We look forward to your feedback and continued support.
Documentation:
Comprehensive documentation is available at https://easydel.readthedocs.io/en/latest/
Example Usage:
Load any of the 40+ available models with EasyDeL:
import easydel as ed
import jax
from jax import lax
from jax import numpy as jnp

sharding_axis_dims = (1, 1, 1, -1)  # sequence sharding for better inference and training
max_length = 2**15
pretrained_model_name_or_path = "AnyEasyModel"  # placeholder: substitute a real model id
dtype = jnp.float16

model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    input_shape=(len(jax.devices()), max_length),
    auto_shard_params=True,
    sharding_axis_dims=sharding_axis_dims,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        use_scan_mlp=False,
        attn_dtype=jnp.float16,
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
        kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.A8BIT,
        use_sharded_kv_caching=False,
        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
    ),
    quantization_method=ed.EasyDeLQuantizationMethods.NF4,
    quantization_block_size=256,
    platform=ed.EasyDeLPlatforms.TRITON,
    partition_axis=ed.PartitionAxis(),
    param_dtype=dtype,
    dtype=dtype,
    precision=lax.Precision("fastest"),
)
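The call above sets a `quantization_method` together with a `quantization_block_size`. As a general illustration of what block-wise quantization means, here is a plain-Python sketch of absmax 8-bit quantization with one scale per block. It is deliberately not NF4 and not EasyDeL's kernel; the function names and the tiny block size are hypothetical and chosen only to keep the example readable.

```python
def quantize_blockwise(values, block_size=4):
    """Split values into blocks and store (scale, int8 codes) per block."""
    blocks = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        # One scale per block maps the largest magnitude to 127.
        scale = max(abs(v) for v in block) / 127.0 or 1.0
        codes = [round(v / scale) for v in block]
        blocks.append((scale, codes))
    return blocks


def dequantize_blockwise(blocks):
    """Rebuild approximate float values from the per-block scales and codes."""
    return [scale * code for scale, codes in blocks for code in codes]


weights = [0.5, -1.0, 0.25, 0.0, 8.0, -4.0, 2.0, 1.0]
blocks = quantize_blockwise(weights, block_size=4)
restored = dequantize_blockwise(blocks)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(blocks)
print(max_error)
```

Because each block has its own scale, the large values in the second block do not degrade the precision of the small values in the first one, which is the usual motivation for a block size parameter.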
This release marks a significant advancement in making EasyDeL a more powerful and flexible tool for machine learning tasks. We look forward to your feedback and continued support.
Note
This might be the last release of EasyDeL that incorporates HF/Flax modules. In future versions, EasyDeL will transition to its own base
modules and may adopt Equinox or Flax NNX, provided that NNX meets sufficient performance standards. Users are encouraged to
provide feedback on this direction.
EasyDeL version 0.0.69
This release brings significant scalability improvements, new models, bug fixes, and usability enhancements to EasyDeL.
Highlights:
- Multi-host GPU Training: EasyDeL now scales seamlessly across multiple GPUs and hosts for demanding training workloads.
- New Models: Expand your NLP arsenal with the addition of Gemma2, OLMo, and Aya models.
- Improved KV Cache Quantization: Enjoy a substantial accuracy boost with enhanced KV cache quantization, achieving +21% accuracy compared to the previous version.
- Simplified Model Management: Load and save pretrained models effortlessly using the new model.from_pretrained and model.save_pretrained methods.
- Enhanced Generation Pipeline: The GenerationPipeLine now supports streaming token generation, ideal for real-time applications.
- Introducing the ApiEngine: Leverage the power of the new ApiEngine and engine_client for seamless integration with your applications.
Other Changes:
- Fixed GPU Flash Attention bugs for increased stability.
- Updated required jax version to >=0.4.28 for optimal performance. Versions 0.4.29 or higher are recommended if available.
- Streamlined the structure import process and resolved multi-host training issues.
Upgrade:
To upgrade to EasyDeL v0.0.69, use the following command:
pip install --upgrade easydel==0.0.69
EasyDeL - 0.0.67
- New Features
  - GenerationPipeLine was added for fast streaming and easy generation with JAX.
  - Int8Params is now used instead of LinearBitKernel.
  - Better GPU support.
  - EasyDeLState is now better and supports more general options.
  - Trainers now support .save_pretrained(to_torch) and training logging.
  - EasyDeLState supports to_8bit.
  - All of the models support to_8bit for params.
  - Imports are now 91x faster in EasyDeL version 0.0.67.
- Removed API
  - JAXServe is no longer available.
  - PyTorchServe is no longer available.
  - EasyServe is no longer available.
  - LinearBitKernel is no longer available.
  - EasyDeL partitioners are no longer available.
  - Llama/Mistral/Falcon/Mpt static converters or transforms are no longer available.
- Known Issues
  - The LoRA kernel sometimes crashes.
  - GenerationPipeLine has a compilation problem when more than 4 devices are available and 8_bit params are used.
  - Most of the features won't work on TPU-v3 or on GPUs with compute capability lower than 7.5.
  - Kaggle sessions crash after importing EasyDeL (Kaggle's latest environment is unstable; this is not related to EasyDeL). (Fixed in EasyDeL version 0.0.67)
Pallas Fusion: GPU Turbocharged 🚀
EasyDeL version 0.0.65
- New Features
  - Pallas Flash Attention on CPU/GPU/TPU via FJFormer, with bias support.
  - The ORPO Trainer has been added and is now in your bag.
  - WebSocket Serve Engine.
  - EasyDeL is now 30% faster on GPUs.
  - JAX-Triton is no longer needed to run GPU kernels.
  - You can now specify the backward kernel implementation for Pallas Attention.
  - You now have to import EasyDeL as easydel instead of EasyDel.
- New Models
  - OpenELM model series are now present.
  - DeepseekV2 model series are now present.
- Fixed Bugs
  - CUDNN FlashAttention bugs are now fixed.
  - 8-bit parameter quantization for the Llama3 model has been improved considerably.
  - Splash Attention bugs on TPUs are now fixed.
  - Dbrx model bugs are fixed.
  - DPOTrainer bugs are fixed (creating dataset).
- Known Bugs
  - Splash Attention won't work on TPUv3.
  - Pallas Attention won't work on TPUv3.
  - You need to install flash_attn in order to convert HF DeepseekV2 to EasyDeL (bug in DeepseekV2 implementation from original authors).
  - Some examples are outdated.
Full Changelog: 0.0.63...0.0.65
0.0.63
What's Changed
- Phi3 Model Added.
- Dbrx Model Added.
- Arctic Model Added.
- Lora Fine-Tuning Bugs Fixed.
- Vanilla Attention is Optimized.
- Sharded Vanilla is the default attention mechanism now.
Full Changelog: 0.0.61...0.0.63
EasyDeL-0.0.61 Dynamic Changes
What's Changed
- Add support for iterable dataset loading by @yhavinga in #138
- SFTTrainer bugs are fixed.
- Parameter quantization is now available for all of the models.
- AutoEasyDeLModelForCausalLM now supports load_in_8bit.
- Memory management is improved.
- The Gemma models generation issue is now fixed.
- Trainers are now 2~8% faster.
- The attention operation is improved.
- The Cohere model is now present.
- JAXServer is improved.
- Due to recent changes, many documentation examples have changed or will be changed soon.
Full Changelog: 0.0.60...0.0.61
EasyDeL Version 0.0.60
What's Changed
- SFTTrainer is now available.
- VideoCausalLanguageModelTrainer is now available.
- New models such as Grok-1, Qwen2Moe, Mamba, Rwkv, and Whisper are available.
- MoE models had some speed improvements.
- Training speed is now 18%~42% faster.
- Normal Attention is now faster by 12%~30% (#131).
- DPOTrainer bugs fixed.
- CausalLanguageModelTrainer is now more customizable.
- WANDB logging has improved.
- Performance Mode is added to Training Arguments.
- Model configs pass attributes to PretrainedConfig to prevent override… by @yhavinga in #122
- Ignore token label smooth z loss by @yhavinga in #123
- Time the whole train loop instead of only call to train step function by @yhavinga in #124
- Add save_total_limit argument to delete older checkpoints by @yhavinga in #127
- Add gradient norm logging, fix metric collection on multi-worker setup by @yhavinga in #135
Full Changelog: 0.0.55...0.0.60