Найти тему
10,2 тыс подписчиков

🌟 Zamba2-Instruct: две гибридные SLM на 2.7 и 1.2 млрд. параметров.


Zamba2-Instruct - семейство инструктивных моделей на архитектуре Mamba2+Transformers для NLP-задач.

В семействе 2 модели:


Высокая производительность семейства по сравнению с релевантными Transformers-only моделями достигается за счет конкатенации эмбедингов модели с входными данными для блока внимания и использование LoRA projection matrices к общему MLP-слою.

Модели файнтюнились (SFT+DPO) на instruct-ориентированных наборах данных (ultrachat_200k, Infinity-Instruct, ultrafeedback_binarized, orca_dpo_pairs и OpenHermesPreferences).

Тесты Zamba2-Instruct продемонстрировали внушительную скорость генерации текста и эффективное использование памяти, обходя MT-bench более крупные по количеству параметров модели/ (Zamba2-Instruct-2.7B превзошла Mistral-7B-Instruct-v0.1, а Zamba2-Instruct-1.2B - Gemma2-2B-Instruct)

⚠️ Для запуска на СPU укажите use_mamba_kernels=False при загрузке модели с помощью AutoModelForCausalLM.from_pretrained.

▶️Локальная установка и инференс Zamba2-2.7B-Instruct:

# Clone repo
git clone https://github.com/Zyphra/transformers_zamba2.git
cd transformers_zamba2

# Install the repository & accelerate:
pip install -e .
pip install accelerate

# Inference:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B-instruct")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-2.7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16)

user_turn_1 = "user_prompt1."
assistant_turn_1 = "assistant_prompt."
user_turn_2 = "user_prompt2."
sample = [{'role': 'user', 'content': user_turn_1}, {'role': 'assistant', 'content': assistant_turn_1}, {'role': 'user', 'content': user_turn_2}]
chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)

input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=150, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
print((tokenizer.decode(outputs[0])))

📌Лицензирование : Apache 2.0 License.

🖥GitHub


#AI #ML #SLM #Zamba2 #Instruct
1 минута