4. 如何添加新的评测算法¶
4.1. 攻击算法存储位置¶
文本模块 text
下的文件结构如下
├── EvalBox
│ ├── **Attack**
│ │ ├──attack.py
│ │ ├──a2t.py
│ │ ├──bae.py
│ │ ├──....py
├── Models
├── utils
├── test
├── Datasets
4.2. 扩展实例 —— TextFooler 算法¶
TextFooler
算法路径为:
~/AI-Testing/text/EvalBox/Attack/text_fooler.py
TextFooler
算法源代码:
import time
from typing import Any, Optional
import torch
from .attack import Attack
from ...Models.base import NLPVictimModel
from ...utils.constraints import (
MaxWordsPerturbed,
PartOfSpeech,
WordEmbeddingDistance,
SentenceEncoderBase,
MultilingualUSE,
InputColumnModification,
RepeatModification,
StopwordModification,
)
from ...utils.goal_functions import UntargetedClassification
from ...utils.search_methods import WordImportanceRanking
from ...utils.transformations import WordEmbeddingSubstitute
from ...utils.assets import fetch
from ...utils.strings import normalize_language, LANGUAGE
from ...utils.word_embeddings import ( # noqa: F401
GensimWordEmbedding,
CounterFittedEmbedding,
ChineseWord2Vec,
TencentAILabEmbedding,
)
from ...Models.TestModel.bert_amazon_zh import VictimBERTAmazonZH
from ...Models.TestModel.roberta_sst import VictimRoBERTaSST
class TextFooler(Attack):
    """TextFooler adversarial attack for text classification models.

    The attack is assembled from the following components:

    - goal function: ``UntargetedClassification`` on the victim model;
    - search method: ``WordImportanceRanking`` with ``wir_method="delete"``;
    - transformation: ``WordEmbeddingSubstitute``;
    - constraints: repeat/stopword/input-column modification, word
      embedding distance, part of speech (English only), multilingual
      universal sentence encoder, and a cap on the share of perturbed words.

    Only Chinese and English are supported; any other language raises
    ``ValueError`` while the attack parameters are being built.
    """

    __name__ = "TextFooler"

    def __init__(
        self,
        model: Optional[NLPVictimModel] = None,
        device: Optional[torch.device] = None,
        language: str = "zh",
        **kwargs: Any,
    ) -> None:
        """Create the attack and build all of its components.

        Args:
            model: Victim model to attack. When ``None``, a default victim
                is loaded: ``VictimBERTAmazonZH`` for Chinese, otherwise
                ``VictimRoBERTaSST``.
            device: Device handed over to the ``Attack`` base class.
            language: Language of the attacked texts; normalized with
                ``normalize_language``.
            **kwargs: Optional attack parameters. Allowed keys and their
                defaults: ``max_candidates`` (30), ``max_perturb_percent``
                (0.2), ``min_cos_sim`` (0.5), ``use_threshold``
                (0.840845057), ``verbose`` (0).

        Raises:
            AssertionError: If ``kwargs`` contains an unknown key.
            ValueError: If the language is neither Chinese nor English.
        """
        self.__initialized = False
        self.__valid_params = {
            "max_candidates",
            "max_perturb_percent",
            "min_cos_sim",
            "use_threshold",
            "verbose",
        }
        self._language = normalize_language(language)
        self._model = model
        # Compare against None explicitly: a user-supplied model that happens
        # to evaluate as falsy must not be silently replaced by the default.
        if self._model is None:
            start = time.time()
            if self._language == LANGUAGE.CHINESE:
                print("load default Chinese victim model...")
                self._model = VictimBERTAmazonZH()
            else:
                print("load default English victim model...")
                self._model = VictimRoBERTaSST()
            print(f"default model loaded in {time.time()-start:.2f} seconds.")
        self._device = device
        self._parse_params(**kwargs)

    def _parse_params(self, **kwargs: Any) -> None:
        """Validate ``kwargs`` and apply them to the attack.

        On the first call (from ``__init__``) every component is built from
        scratch. On later calls only the parameters present in ``kwargs``
        are updated in place on the existing components, so the stopwords,
        word embedding and sentence encoder are not loaded again and
        parameters that are not mentioned keep their current values.

        Raises:
            AssertionError: If ``kwargs`` contains an unknown key.
        """
        unknown = set(kwargs) - self.__valid_params
        assert not unknown, f"unknown parameters: {sorted(unknown)}"
        if self.__initialized:
            self.__reset_params(**kwargs)
            # Stop here: rebuilding everything would discard the in-place
            # update that was just made and reload all heavy resources.
            return
        self.__parse_params(**kwargs)
        self.__initialized = True

    def __parse_params(self, **kwargs: Any) -> None:
        """Build constraints, goal function, search method and transformation.

        The assembled components are passed to ``Attack.__init__``.

        Raises:
            ValueError: If the language is neither Chinese nor English.
        """
        verbose = kwargs.get("verbose", 0)
        print("start initializing attacking parameters...")
        start = time.time()
        print("loading stopwords and word embedding...")
        if self._language == LANGUAGE.CHINESE:
            stopwords = fetch("stopwords_zh")
            embedding = TencentAILabEmbedding()
        elif self._language == LANGUAGE.ENGLISH:
            stopwords = fetch("stopwords_en")
            embedding = CounterFittedEmbedding()
        else:
            raise ValueError(f"暂不支持语言 {self._language.name.lower()}")
        print(f"stopwords and word embedding loaded in {time.time()-start:.2f} seconds")

        constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]
        # NOTE(review): presumably restricts modification to the hypothesis
        # when the input has premise/hypothesis columns -- confirm against
        # InputColumnModification.
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)

        min_cos_sim = kwargs.get("min_cos_sim", 0.5)
        constraints.append(WordEmbeddingDistance(embedding, min_cos_sim=min_cos_sim))

        # The language was validated above, so this lookup cannot fail.
        _pos_tagger = {
            LANGUAGE.ENGLISH: "nltk",
            LANGUAGE.CHINESE: "jieba",
        }[self._language]
        # The part-of-speech constraint is only applied to English texts.
        if self._language == LANGUAGE.ENGLISH:
            constraints.append(
                PartOfSpeech(
                    language=self._language,
                    tagger=_pos_tagger,
                    allow_verb_noun_swap=True,
                )
            )

        use_threshold = kwargs.get("use_threshold", 0.840845057)
        start = time.time()
        print("loading universal sentence encoder...")
        use_constraint = MultilingualUSE(
            threshold=use_threshold,
            metric="angular",
            compare_against_original=False,
            window_size=15,
            skip_text_shorter_than_window=True,
        )
        print(f"universal sentence encoder loaded in {time.time()-start:.2f} seconds")
        constraints.append(use_constraint)

        max_perturb_percent = kwargs.get("max_perturb_percent", 0.2)
        constraints.append(MaxWordsPerturbed(max_percent=max_perturb_percent))

        goal_function = UntargetedClassification(self._model)
        search_method = WordImportanceRanking(wir_method="delete", verbose=verbose)

        max_candidates = kwargs.get("max_candidates", 30)
        transformation = WordEmbeddingSubstitute(
            embedding,
            max_candidates=max_candidates,
            verbose=verbose,
        )

        super().__init__(
            model=self._model,
            device=self._device,
            IsTargeted=False,
            goal_function=goal_function,
            constraints=constraints,
            transformation=transformation,
            search_method=search_method,
            language=self._language,
            verbose=verbose,
        )

    def __reset_params(self, **kwargs: Any) -> None:
        """Update, in place, the parameters that are present in ``kwargs``.

        Keys that are absent leave the corresponding component untouched;
        ``verbose`` keeps its current value when it is not given.
        """
        if "max_candidates" in kwargs:
            self.transformation.max_candidates = kwargs.get("max_candidates")
        if "max_perturb_percent" in kwargs:
            for c in self.constraints:
                if isinstance(c, MaxWordsPerturbed):
                    c.max_percent = kwargs.get("max_perturb_percent")
        if "min_cos_sim" in kwargs:
            for c in self.constraints:
                if isinstance(c, WordEmbeddingDistance):
                    c.min_cos_sim = kwargs.get("min_cos_sim")
        if "use_threshold" in kwargs:
            for c in self.constraints:
                if isinstance(c, (SentenceEncoderBase, MultilingualUSE)):
                    c.threshold = kwargs.get("use_threshold")
        self.verbose = kwargs.get("verbose", self.verbose)

    def prepare_data(self):
        """Not supported by this attack; always raises ``NotImplementedError``."""
        raise NotImplementedError("请勿调用此方法")

    def generate(self):
        """Not supported by this attack; always raises ``NotImplementedError``."""
        raise NotImplementedError("请勿调用此方法")
4.3. 扩展说明¶
用户需要实现个人攻击算法,并继承基础的
Attack
类。用户需要将待扩展的攻击算法对应文件,如
new_attack_method.py
,放置于以下路径中
~/AI-Testing/text/EvalBox/Attack/
(可选) 用户可以在
~/AI-Testing/text/const.py
中将自定义的攻击算法加入到字典ATTACK_RECIPES
以及列表RECOMMENDED_RECIPES
中,方便使用命令行工具text/cli.py
调用(不加入也能使用,但是会触发警告(warning))。(可选) 用户可以在
~/AI-Testing/text/EvalBox/Attack/default_config.yml
中将自定义的攻击算法的参数配置写入其中,方便后续修改。