Переопределение интерфейса взаимодействия с моделью в TextAttack Наследуемся от родительского класса ModelWrapper и переопределяем __call__: import requests import json import numpy as np from textattack.models.wrappers import ModelWrapper class APIMockModelRemoteBinary(ModelWrapper): def __init__(self, url): self.url = url def __call__(self, text_input_list): return self.send_requests(text_input_list) def send_requests(self, text_input_list): responses = [] for text in text_input_list: try: payload = {"text": text} response = requests.post( self.url, json=payload, timeout=5 ) response.raise_for_status() prediction_data = response.json() prediction = prediction_data.get("prediction") if prediction is None: raise ValueError("Ответ не содержит поле 'prediction'") # Преобразование в формат [[negative_prob, positive_prob]] responses.append([1 - prediction, prediction]) except requests.exceptions.RequestException as e: print(f"Ошибка запроса: {e}") # Возвращаем нейтральные вероятности в случае ошибки responses.append([0.5, 0.5]) except (ValueError, KeyError) as e: print(f"Ошибка обработки ответа: {e}") responses.append([0.5, 0.5]) return np.array(responses) Как использовать: from textattack.attack_recipes import TextFoolerJin2019, DeepWordBugGao2018 from textattack.datasets import Dataset from textattack.attack_results import SuccessfulAttackResult from textattack import Attacker, AttackArgs text = "Этот какнал просто потрясающий!" label = 1 # Создаем датасет с одним примером dataset = Dataset([(text, label)]) # URL вашего API url = "http://127.0.0.1:8000/predict" # Инициализация обертки модели model_wrapper = APIMockModelRemoteBinary(url) # Настройки атаки attack_args = AttackArgs( num_examples=1, # Сколько примеров атаковать checkpoint_interval=5, parallel=False, # Отключаем параллелизм при первом запуске disable_stdout=True # Отключаем вывод в консоль ) # Список атакующих рецептов attack_recipes = [ TextFoolerJin2019, DeepWordBugGao2018 ] for recipe in attack_recipes: print(f"\nЗапуск атаки: {recipe.__name__}") attack = recipe.build(model_wrapper) attacker = Attacker(attack, dataset, attack_args) attack_results = attacker.attack_dataset() # Анализ результатов for result in attack_results: if isinstance(result, SuccessfulAttackResult): print(f"Успешная атака: {result.perturbed_text()}") else: print("Атака не удалась") Таким образом, код позволяет выполнить атаку на модель машинного обучения, которая доступна через REST API. Наслаждаемся🍷 #MLSecOps
1 неделю назад