3. How to Add a New Evaluation Algorithm

3.1. Where Attack Algorithms Are Stored

The file structure under the text module is as follows:

├── EvalBox
│   ├── Attack
│   │   ├── attack.py
│   │   ├── a2t.py
│   │   ├── bae.py
│   │   ├── ....py
├── Models
├── utils
├── test
├── Datasets

3.2. Extension Example: the TextFooler Algorithm

The TextFooler algorithm is located at:

~/AI-Testing/text/EvalBox/Attack/text_fooler.py

TextFooler algorithm source code:

import time
from typing import Any, Optional

import torch

from .attack import Attack
from ...Models.base import NLPVictimModel
from ...utils.constraints import (
  MaxWordsPerturbed,
  PartOfSpeech,
  WordEmbeddingDistance,
  SentenceEncoderBase,
  MultilingualUSE,
  InputColumnModification,
  RepeatModification,
  StopwordModification,
)
from ...utils.goal_functions import UntargetedClassification
from ...utils.search_methods import WordImportanceRanking
from ...utils.transformations import WordEmbeddingSubstitute
from ...utils.assets import fetch
from ...utils.strings import normalize_language, LANGUAGE
from ...utils.word_embeddings import (  # noqa: F401
  GensimWordEmbedding,
  CounterFittedEmbedding,
  ChineseWord2Vec,
  TencentAILabEmbedding,
)
from ...Models.TestModel.bert_amazon_zh import VictimBERTAmazonZH
from ...Models.TestModel.roberta_sst import VictimRoBERTaSST


class TextFooler(Attack):
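  """TextFooler attack (Jin et al., 2020, "Is BERT Really Robust? A Strong
  Baseline for Natural Language Attack on Text Classification and Entailment").

  Greedily replaces the most important words with embedding-space synonyms,
  subject to part-of-speech, embedding-distance and sentence-encoder
  similarity constraints.
  """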

  __name__ = "TextFooler"

  def __init__(
    self,
    model: Optional[NLPVictimModel] = None,
    device: Optional[torch.device] = None,
    language: str = "zh",
    **kwargs: Any,
  ) -> None:
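    """
    model: victim model to attack; if None, a default Chinese or English
        victim model is loaded according to `language`.
    device: torch device the attack runs on.
    language: language of the victim model and data, "zh" by default.
    kwargs: attack hyper-parameters; must be a subset of __valid_params.
    """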
    self.__initialized = False
    self.__valid_params = {
      "max_candidates",
      "max_perturb_percent",
      "min_cos_sim",
      "use_threshold",
      "verbose",
    }
    self._language = normalize_language(language)
    self._model = model
    if not self._model:
      start = time.time()
      if self._language == LANGUAGE.CHINESE:
        print("load default Chinese victim model...")
        self._model = VictimBERTAmazonZH()
      else:
        print("load default English victim model...")
        self._model = VictimRoBERTaSST()
      print(f"default model loaded in {time.time()-start:.2f} seconds.")
    self._device = device
    self._parse_params(**kwargs)

  def _parse_params(self, **kwargs):
    assert set(kwargs).issubset(self.__valid_params)
    if self.__initialized:
      # components already built: update their parameters in place
      # instead of rebuilding (and reloading models) from scratch
      self.__reset_params(**kwargs)
      return
    self.__parse_params(**kwargs)
    self.__initialized = True

  def __parse_params(self, **kwargs):
    verbose = kwargs.get("verbose", 0)
    print("start initializing attacking parameters...")
    start = time.time()
    print("loading stopwords and word embedding...")
    if self._language == LANGUAGE.CHINESE:
      stopwords = fetch("stopwords_zh")
      embedding = TencentAILabEmbedding()
    elif self._language == LANGUAGE.ENGLISH:
      stopwords = fetch("stopwords_en")
      embedding = CounterFittedEmbedding()
    else:
      raise ValueError(f"暂不支持语言 {self._language.name.lower()}")
    print(f"stopwords and word embedding loaded in {time.time()-start:.2f} seconds")

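    # baseline constraints: a word may not be modified twice,
    # and stopwords may not be modified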
    constraints = [RepeatModification(), StopwordModification(stopwords=stopwords)]

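    # for entailment-style (premise, hypothesis) inputs, only the premise
    # may be perturbed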
    input_column_modification = InputColumnModification(
      ["premise", "hypothesis"], {"premise"}
    )
    constraints.append(input_column_modification)

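    # substitutes must stay close to the original word in embedding space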
    min_cos_sim = kwargs.get("min_cos_sim", 0.5)
    constraints.append(WordEmbeddingDistance(embedding, min_cos_sim=min_cos_sim))

    _pos_tagger = {
      LANGUAGE.ENGLISH: "nltk",
      LANGUAGE.CHINESE: "jieba",
    }[self._language]
    if self._language == LANGUAGE.ENGLISH:
      constraints.append(
        PartOfSpeech(
          language=self._language,
          tagger=_pos_tagger,
          allow_verb_noun_swap=True,
        )
      )

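    # sentence-level semantic similarity via the multilingual Universal
    # Sentence Encoder; the default threshold follows the reference
    # TextFooler recipe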
    use_threshold = kwargs.get("use_threshold", 0.840845057)
    start = time.time()
    print("loading universal sentence encoder...")
    use_constraint = MultilingualUSE(
      threshold=use_threshold,
      metric="angular",
      compare_against_original=False,
      window_size=15,
      skip_text_shorter_than_window=True,
    )
    print(f"universal sentence encoder loaded in {time.time()-start:.2f} seconds")
    constraints.append(use_constraint)

    max_perturb_percent = kwargs.get("max_perturb_percent", 0.2)
    constraints.append(MaxWordsPerturbed(max_percent=max_perturb_percent))

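    # goal: flip the predicted label (untargeted); words are tried in order
    # of deletion-based word importance ranking (WIR)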
    goal_function = UntargetedClassification(self._model)
    search_method = WordImportanceRanking(wir_method="delete", verbose=verbose)

    max_candidates = kwargs.get("max_candidates", 30)
    transformation = WordEmbeddingSubstitute(
      embedding,
      max_candidates=max_candidates,
      verbose=verbose,
    )

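    # assemble the attack from the four standard components:
    # goal function + constraints + transformation + search method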
    super().__init__(
      model=self._model,
      device=self._device,
      IsTargeted=False,
      goal_function=goal_function,
      constraints=constraints,
      transformation=transformation,
      search_method=search_method,
      language=self._language,
      verbose=verbose,
    )

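  # update hyper-parameters of already-built components in place, without
  # reloading embeddings or the sentence encoder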
  def __reset_params(self, **kwargs: Any) -> None:
    if "max_candidates" in kwargs:
      self.transformation.max_candidates = kwargs.get("max_candidates")
    if "max_perturb_percent" in kwargs:
      for c in self.constraints:
        if isinstance(c, MaxWordsPerturbed):
          c.max_percent = kwargs.get("max_perturb_percent")
    if "min_cos_sim" in kwargs:
      for c in self.constraints:
        if isinstance(c, WordEmbeddingDistance):
          c.min_cos_sim = kwargs.get("min_cos_sim")
    if "use_threshold" in kwargs:
      for c in self.constraints:
        if isinstance(c, (SentenceEncoderBase, MultilingualUSE)):
          c.threshold = kwargs.get("use_threshold")
    self.verbose = kwargs.get("verbose", self.verbose)

  def prepare_data(self):
    raise NotImplementedError("请勿调用此方法")

  def generate(self):
    raise NotImplementedError("请勿调用此方法")

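For reference, a minimal usage sketch follows. prepare_data and generate deliberately raise NotImplementedError here, because the attack is driven by the base Attack class and the command-line tool text/cli.py; only construction and in-place parameter updates are shown. The import path assumes ~/AI-Testing/text is on PYTHONPATH, and all parameter values are illustrative.

# a minimal sketch, assuming ~/AI-Testing/text is on PYTHONPATH so that
# EvalBox is importable as a top-level package
from EvalBox.Attack.text_fooler import TextFooler

# build the attack with the default Chinese victim model
attack = TextFooler(language="zh", max_candidates=30, verbose=1)

# later, tighten two constraints; the existing components are updated
# in place instead of being rebuilt
attack._parse_params(min_cos_sim=0.6, use_threshold=0.9)
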
3.3. Extension Notes

  1. Users need to implement their own attack algorithm as a subclass of the base Attack class (a skeleton sketch is given after this list)

  2. Users need to place the file containing the attack algorithm to be added, e.g. new_attack_method.py, in the following directory:

~/AI-Testing/text/EvalBox/Attack/

  3. (Optional) Users can register the custom attack algorithm in the dict ATTACK_RECIPES and the list RECOMMENDED_RECIPES in ~/AI-Testing/text/const.py, so that it can be invoked conveniently via the command-line tool text/cli.py. The algorithm can still be used without registration, but a warning is triggered. (A registration sketch is given after this list.)

  4. (Optional) Users can write the parameter configuration of the custom attack algorithm into ~/AI-Testing/text/EvalBox/Attack/default_config.yml, so it is easy to modify later. (A configuration sketch is given after this list.)
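
To make step 1 concrete, the skeleton below follows the same pattern as TextFooler. It is a minimal sketch, not a definitive template: NewAttackMethod and its parameter set are placeholders, and for illustration it simply reuses the building blocks already shown in text_fooler.py (goal function, constraints, transformation, search method); a real attack would substitute its own components.

# ~/AI-Testing/text/EvalBox/Attack/new_attack_method.py -- a minimal sketch;
# all names and parameters here are placeholders
from typing import Any, Optional

import torch

from .attack import Attack
from ...Models.base import NLPVictimModel
from ...Models.TestModel.bert_amazon_zh import VictimBERTAmazonZH
from ...utils.assets import fetch
from ...utils.constraints import RepeatModification, StopwordModification
from ...utils.goal_functions import UntargetedClassification
from ...utils.search_methods import WordImportanceRanking
from ...utils.strings import normalize_language
from ...utils.transformations import WordEmbeddingSubstitute
from ...utils.word_embeddings import TencentAILabEmbedding


class NewAttackMethod(Attack):

  __name__ = "NewAttackMethod"

  def __init__(
    self,
    model: Optional[NLPVictimModel] = None,
    device: Optional[torch.device] = None,
    language: str = "zh",
    **kwargs: Any,
  ) -> None:
    self.__valid_params = {"max_candidates", "verbose"}  # hypothetical parameters
    self._language = normalize_language(language)
    # fall back to a default victim model, as TextFooler does
    self._model = model if model is not None else VictimBERTAmazonZH()
    self._device = device
    self._parse_params(**kwargs)

  def _parse_params(self, **kwargs):
    assert set(kwargs).issubset(self.__valid_params)
    verbose = kwargs.get("verbose", 0)
    # components reused from text_fooler.py purely for illustration;
    # replace them with the ones your attack actually needs
    stopwords = fetch("stopwords_zh")
    embedding = TencentAILabEmbedding()
    super().__init__(
      model=self._model,
      device=self._device,
      IsTargeted=False,
      goal_function=UntargetedClassification(self._model),
      constraints=[RepeatModification(), StopwordModification(stopwords=stopwords)],
      transformation=WordEmbeddingSubstitute(
        embedding,
        max_candidates=kwargs.get("max_candidates", 30),
        verbose=verbose,
      ),
      search_method=WordImportanceRanking(wir_method="delete", verbose=verbose),
      language=self._language,
      verbose=verbose,
    )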
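
For step 3, the registration might look like the lines below. The exact shapes of ATTACK_RECIPES and RECOMMENDED_RECIPES are not shown in this document, so the assumption here (a name-to-class dict plus a list of names) must be checked against the actual text/const.py.

# in ~/AI-Testing/text/const.py (shapes of both containers are assumed)
from EvalBox.Attack.new_attack_method import NewAttackMethod  # hypothetical module

ATTACK_RECIPES["new_attack_method"] = NewAttackMethod
RECOMMENDED_RECIPES.append("new_attack_method")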
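
For step 4, a configuration entry might look like this. The file's actual schema is not reproduced in this document, so the layout below (one top-level key per attack holding its keyword arguments, mirroring __valid_params) is an assumption.

# in ~/AI-Testing/text/EvalBox/Attack/default_config.yml (schema assumed)
new_attack_method:
  max_candidates: 30
  verbose: 0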