10,2 тыс подписчиков
🌟 Zamba2-Instruct: две гибридные SLM на 2.7 и 1.2 млрд. параметров.
Zamba2-Instruct - семейство инструктивных моделей на архитектуре Mamba2+Transformers для NLP-задач.
В семействе 2 модели:
Высокая производительность семейства по сравнению с релевантными Transformers-only моделями достигается за счет конкатенации эмбедингов модели с входными данными для блока внимания и использование LoRA projection matrices к общему MLP-слою.
Модели файнтюнились (SFT+DPO) на instruct-ориентированных наборах данных (ultrachat_200k, Infinity-Instruct, ultrafeedback_binarized, orca_dpo_pairs и OpenHermesPreferences).
Тесты Zamba2-Instruct продемонстрировали внушительную скорость генерации текста и эффективное использование памяти, обходя MT-bench более крупные по количеству параметров модели/ (Zamba2-Instruct-2.7B превзошла Mistral-7B-Instruct-v0.1, а Zamba2-Instruct-1.2B - Gemma2-2B-Instruct)
⚠️ Для запуска на СPU укажите use_mamba_kernels=False при загрузке модели с помощью AutoModelForCausalLM.from_pretrained.
▶️Локальная установка и инференс Zamba2-2.7B-Instruct:
# Clone repo
git clone https://github.com/Zyphra/transformers_zamba2.git
cd transformers_zamba2
# Install the repository & accelerate:
pip install -e .
pip install accelerate
# Inference:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B-instruct")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-2.7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16)
user_turn_1 = "user_prompt1."
assistant_turn_1 = "assistant_prompt."
user_turn_2 = "user_prompt2."
sample = [{'role': 'user', 'content': user_turn_1}, {'role': 'assistant', 'content': assistant_turn_1}, {'role': 'user', 'content': user_turn_2}]
chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=150, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
print((tokenizer.decode(outputs[0])))
📌Лицензирование : Apache 2.0 License.
#AI #ML #SLM #Zamba2 #Instruct
1 минута
20 октября 2024