Example #1
    def __init__(
        self,
        vocab: Vocabulary,
        predictor: Model,
        adversary: Model,
        bias_direction: BiasDirectionWrapper,
        predictor_output_key: str,
        **kwargs,
    ):
        super().__init__(vocab, **kwargs)

        self.predictor = predictor
        self.adversary = adversary

        # want to keep adversary label hook during evaluation
        embedding_layer = find_embedding_layer(self.predictor)
        self.bias_direction = bias_direction
        self.predetermined_bias_direction = self.bias_direction(
            embedding_layer)
        self._adversary_label_hook = _AdversaryLabelHook(
            self.predetermined_bias_direction)
        embedding_layer.register_forward_hook(self._adversary_label_hook)

        self.vocab = self.predictor.vocab
        self._regularizer = self.predictor._regularizer

        self.predictor_output_key = predictor_output_key
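The `_AdversaryLabelHook` registered above is defined elsewhere; purely as an illustration, here is a minimal standalone sketch (not the AllenNLP implementation) of a forward hook that captures the projection of each embedding onto a predetermined bias direction, which is the kind of signal the adversary is trained to predict. The class and variable names below are hypothetical.

import torch

# Hypothetical sketch: a forward hook that, like _AdversaryLabelHook, records the
# scalar projection of every embedding onto a predetermined bias direction.
class ProjectionHook:
    def __init__(self, bias_direction: torch.Tensor):
        # bias_direction is assumed to be a unit vector of shape (embedding_dim,)
        self.bias_direction = bias_direction
        self.adversary_label = None

    def __call__(self, module, inputs, output):
        # scalar projection of each embedding in the batch onto the bias direction
        self.adversary_label = torch.matmul(output.detach(), self.bias_direction)

embedding = torch.nn.Embedding(100, 8)
hook = ProjectionHook(torch.nn.functional.normalize(torch.randn(8), dim=0))
embedding.register_forward_hook(hook)
embedding(torch.tensor([[1, 2, 3]]))  # hook.adversary_label now has shape (1, 3)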
Example #2
    def _construct_embedding_matrix(self) -> torch.Tensor:
        """
        For HotFlip, we need a word embedding matrix to search over. The below is necessary for
        models such as ELMo, character-level models, or for models that use a projection layer
        after their word embeddings.

        We run all of the tokens from the vocabulary through the TextFieldEmbedder, and save the
        final output embedding. We then group all of those output embeddings into an "embedding
        matrix".
        """
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        self.embedding_layer = embedding_layer
        if isinstance(embedding_layer, (Embedding, torch.nn.modules.sparse.Embedding)):
            # If we're using something that already has only an embedding matrix, we can just
            # use that and bypass the rest of this method.
            return embedding_layer.weight

        # We take the top `self.max_tokens` as candidates for hotflip.  Because we have to
        # construct a new vector for each of these, we can't always afford to use the whole vocab,
        # for both runtime and memory considerations.
        all_tokens = list(self.vocab._token_to_index[self.namespace])[: self.max_tokens]
        max_index = self.vocab.get_token_index(all_tokens[-1], self.namespace)
        self.invalid_replacement_indices = [
            i for i in self.invalid_replacement_indices if i < max_index
        ]

        inputs = self._make_embedder_input(all_tokens)

        # pass all tokens through the fake matrix and create an embedding out of it.
        embedding_matrix = embedding_layer(inputs).squeeze()

        return embedding_matrix
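For context, a minimal standalone sketch of how a HotFlip-style attacker typically uses such an embedding matrix: score every candidate token with a first-order Taylor approximation of the loss change and take the best swap. The tensors below are random stand-ins, not the values built above.

import torch

# Minimal sketch of the first-order HotFlip swap that such an embedding matrix enables.
embedding_matrix = torch.randn(1000, 64)  # (vocab_size, embedding_dim), stand-in
current_embedding = embedding_matrix[42]  # embedding of the token currently in place
grad = torch.randn(64)                    # d(loss)/d(embedding) at that position, stand-in

# First-order estimate of the loss change for swapping to every candidate token:
# (e_candidate - e_current) . grad
scores = (embedding_matrix - current_embedding) @ grad
best_candidate = int(scores.argmax())     # token id that maximizes the estimated loss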
Example #3
    def test_get_gradients_when_requires_grad_is_false(self):
        inputs = {
            "sentence": "I always write unit tests",
        }

        archive = load_archive(
            self.FIXTURES_ROOT
            / "basic_classifier"
            / "embedding_with_trainable_is_false"
            / "model.tar.gz"
        )
        predictor = Predictor.from_archive(archive)

        # ensure that requires_grad is initially False on the embedding layer
        embedding_layer = util.find_embedding_layer(predictor._model)
        assert not embedding_layer.weight.requires_grad
        instance = predictor._json_to_instance(inputs)
        outputs = predictor._model.forward_on_instance(instance)
        labeled_instances = predictor.predictions_to_labeled_instances(instance, outputs)
        # ensure that gradients are always present, despite requires_grad being false on the embedding layer
        for instance in labeled_instances:
            grads = predictor.get_gradients([instance])[0]
            assert bool(grads)
        # ensure that no side effects remain
        assert not embedding_layer.weight.requires_grad
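A minimal standalone sketch (plain PyTorch, not the AllenNLP `get_gradients` implementation) of the pattern this test exercises: temporarily enable `requires_grad` on a frozen embedding so the backward hook fires, then restore the flag so no side effects remain.

import torch

embedding = torch.nn.Embedding(10, 4)
embedding.weight.requires_grad = False    # frozen, as in the fixture above

captured = []
original_flag = embedding.weight.requires_grad
embedding.weight.requires_grad = True     # temporarily enable so gradients flow
handle = embedding.register_backward_hook(
    lambda module, grad_in, grad_out: captured.append(grad_out[0]))
try:
    loss = embedding(torch.tensor([[1, 2, 3]])).sum()
    loss.backward()
finally:
    handle.remove()
    embedding.weight.requires_grad = original_flag  # no side effects remain

assert captured and not embedding.weight.requires_grad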
Example #4
    def _register_hooks(self, alpha: int, embeddings_list: List, token_offsets: List):
        """
        Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
        for one term in the Integrated Gradients sum.

        We store the embedding output into the embeddings_list when alpha is zero.  This is used
        later to element-wise multiply the input by the averaged gradients.
        """

        def forward_hook(module, inputs, output):
            # Save the input for later use. Only do so on first call.
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach())

            # Scale the embedding by alpha
            output.mul_(alpha)

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                token_offsets.append(offsets)

        # Register the hooks
        handles = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handles.append(embedding_layer.register_forward_hook(forward_hook))
        text_field_embedder = util.find_text_field_embedder(self.predictor._model)
        handles.append(text_field_embedder.register_forward_hook(get_token_offsets))
        return handles
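For context, a self-contained sketch in plain PyTorch (not the AllenNLP interpreter) of the Integrated Gradients sum these hooks support: average the gradients taken at several scaling factors `alpha` and multiply element-wise by the original embeddings. The toy model and the zero baseline are assumptions.

import torch

embedding = torch.nn.Embedding(20, 8)
classifier = torch.nn.Linear(8, 2)
token_ids = torch.tensor([[4, 7, 9]])

steps = 10
accumulated = torch.zeros(1, 3, 8)
for i in range(1, steps + 1):
    alpha = i / steps
    embedded = embedding(token_ids)
    scaled = alpha * embedded             # interpolate from an all-zeros baseline
    loss = classifier(scaled).sum()
    (grad,) = torch.autograd.grad(loss, scaled)
    accumulated += grad

# Element-wise product of the averaged gradients and the original embedding values.
attribution = (accumulated / steps) * embedding(token_ids).detach()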
Example #5
    def _register_forward_hook(self, alpha: int, embeddings_list: List):

        def forward_hook(module, inputs, output):
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach().numpy())

            output.mul_(alpha)

        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handle = embedding_layer.register_forward_hook(forward_hook)
        return handle
Example #6
    def _register_embedding_value_hooks(self, embedding_values):
        def forward_hook_layers(module, input, output):
            embedding_values.append(output)

        forward_hooks = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        # workaround, otherwise forward hook is not called on embedding
        embedding_layer.weight.requires_grad = True
        forward_hooks.append(
            embedding_layer.register_forward_hook(forward_hook_layers))
        return forward_hooks
Example #7
    def _register_embedding_gradient_hooks(self, embedding_gradients):
        def backward_hook_layers(module, grad_in, grad_out):
            embedding_gradients.append(grad_out[0])

        backward_hooks = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        # workaround, otherwise the backward hook is not called on the embedding
        embedding_layer.weight.requires_grad = True
        backward_hooks.append(
            embedding_layer.register_backward_hook(backward_hook_layers))
        return backward_hooks
Example #8
    def _register_embedding_value_hook(self, alpha: float, embedding_values):
        def forward_hook(module, inputs, output):
            # Save the input for later use. Only do so on first call.
            if alpha == 0:
                embedding_values.append(output.squeeze(0).clone().detach())

            # Scale the embedding by alpha
            output.mul_(alpha)

        embedding_layer = util.find_embedding_layer(self.predictor._model)
        embedding_layer.weight.requires_grad = True
        return embedding_layer.register_forward_hook(forward_hook)
Example #9
def register_embedding_hook(model):
    embedding_layer = util.find_embedding_layer(model)

    # grad_in/grad_out/inputs are tuples, outputs is a tensor
    def fw_hook_layers(EMBEDDING, inputs, outputs):
        ram_append('EMBEDDING_HOOK.fw', outputs)

    def bw_hook_layers(EMBEDDING, grad_in, grad_out):
        ram_append('EMBEDDING_HOOK.bw', grad_out[0])

    fw_hook = embedding_layer.register_forward_hook(fw_hook_layers)
    bw_hook = embedding_layer.register_backward_hook(bw_hook_layers)
    return [fw_hook, bw_hook]
Example #10
    def _register_forward_hook(self, embeddings_list: List):
        """
        Finds all of the TextFieldEmbedders, and registers a forward hook onto them. When forward()
        is called, embeddings_list is filled with the embedding values. This is necessary because
        our normalization scheme multiplies the gradient by the embedding value.
        """
        def forward_hook(module, inputs, output):
            embeddings_list.append(output.squeeze(0).clone().detach().numpy())

        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handle = embedding_layer.register_forward_hook(forward_hook)

        return handle
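A self-contained sketch (plain NumPy, not the AllenNLP interpreter) of the normalization scheme the docstring refers to: take the dot product of each token's gradient and embedding, then L1-normalize so the saliency scores sum to one. The arrays below are random stand-ins for the hooked values.

import numpy

# grads and embeddings stand in for the values captured by the backward and forward hooks;
# both have shape (sequence_length, embedding_dim).
grads = numpy.random.randn(5, 8)
embeddings = numpy.random.randn(5, 8)

emb_grad = numpy.sum(grads * embeddings, axis=1)            # dot product per token
norm = numpy.linalg.norm(emb_grad, ord=1)
saliency = [float(numpy.abs(e) / norm) for e in emb_grad]   # one score per token, sums to ~1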
Example #11
    def attribute_kwargs(self,
                         captum_inputs: Tuple,
                         mask_features_by_token: bool = False) -> Dict:
        """
        Args:
            captum_inputs (Tuple): result of CaptumCompatible.instances_to_captum_inputs.
            mask_features_by_token (bool, optional): For Captum methods that require a feature mask,
                                                     define each token as a feature if True. If False,
                                                     define each scalar in the embedding dimension as a
                                                     feature (e.g., default behavior in LIME).
                                                     Defaults to False.

        Returns:
            Dict: key-word arguments to be given to the attribute method of the
                  relevant Captum Attribution sub-class.
        """
        inputs, target, additional = captum_inputs
        vocab = self.predictor._model.vocab

        # Manually check for distilbert.
        if isinstance(self.predictor._model,
                      DistilBertForSequenceClassification):
            embedding = self.predictor._model.embeddings
        else:
            embedding = util.find_embedding_layer(self.predictor._model)

        pad_idx = vocab.get_token_index(vocab._padding_token)
        pad_idx = torch.LongTensor([[pad_idx]]).to(inputs[0].device)
        pad_idxs = tuple(
            pad_idx.expand(tensor.size()[:2]) for tensor in inputs)
        baselines = tuple(embedding(idx) for idx in pad_idxs)

        attr_kwargs = {
            'inputs': inputs,
            'target': target,
            'baselines': baselines,
            'additional_forward_args': additional
        }

        # For methods that require a feature mask, define each token as one feature
        if mask_features_by_token:
            # see: https://captum.ai/api/lime.html for the definition of a feature mask
            input_tensor = inputs[0]
            bs, seq_len, emb_dim = input_tensor.shape
            feature_mask = torch.tensor(list(range(bs * seq_len))).reshape([bs, seq_len, 1])
            feature_mask = feature_mask.expand(-1, -1, emb_dim)
            attr_kwargs['feature_mask'] = feature_mask  # (bs, seq_len, emb_dim)

        return attr_kwargs
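As a standalone illustration of the feature mask built above (the shapes are made up), the snippet below shows that with `mask_features_by_token=True` every scalar in a token's embedding shares a single feature id, so Captum perturbs whole tokens rather than individual embedding dimensions.

import torch

bs, seq_len, emb_dim = 2, 3, 4
feature_mask = torch.tensor(list(range(bs * seq_len))).reshape([bs, seq_len, 1])
feature_mask = feature_mask.expand(-1, -1, emb_dim)

assert feature_mask.shape == (bs, seq_len, emb_dim)
assert int(feature_mask.unique().numel()) == bs * seq_len  # one feature id per token position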
Example #12
    def _construct_embedding_matrix(self) -> Embedding:
        """
        For HotFlip, we need a word embedding matrix to search over. The below is necessary for
        models such as ELMo, character-level models, or for models that use a projection layer
        after their word embeddings.

        We run all of the tokens from the vocabulary through the TextFieldEmbedder, and save the
        final output embedding. We then group all of those output embeddings into an "embedding
        matrix".
        """
        # Gets all tokens in the vocab and their corresponding IDs
        all_tokens = self.vocab._token_to_index[self.namespace]
        all_indices = list(self.vocab._index_to_token[self.namespace].keys())
        all_inputs = {"tokens": torch.LongTensor(all_indices).unsqueeze(0)}

        # A bit of a hack; this will only work with some dataset readers, but it'll do for now.
        indexers = self.predictor._dataset_reader._token_indexers  # type: ignore
        for token_indexer in indexers.values():
            # handle when a model uses character-level inputs, e.g., a CharCNN
            if isinstance(token_indexer, TokenCharactersIndexer):
                tokens = [Token(x) for x in all_tokens]
                max_token_length = max(len(x) for x in all_tokens)
                indexed_tokens = token_indexer.tokens_to_indices(
                    tokens, self.vocab, "token_characters")
                padded_tokens = token_indexer.as_padded_tensor(
                    indexed_tokens, {"token_characters": len(tokens)},
                    {"num_token_characters": max_token_length})
                all_inputs['token_characters'] = torch.LongTensor(
                    padded_tokens['token_characters']).unsqueeze(0)
            # for ELMo models
            if isinstance(token_indexer, ELMoTokenCharactersIndexer):
                elmo_tokens = []
                for token in all_tokens:
                    elmo_indexed_token = token_indexer.tokens_to_indices(
                        [Token(text=token)], self.vocab,
                        "sentence")["sentence"]
                    elmo_tokens.append(elmo_indexed_token[0])
                all_inputs["elmo"] = torch.LongTensor(elmo_tokens).unsqueeze(0)

        embedding_layer = util.find_embedding_layer(self.predictor._model)
        if isinstance(embedding_layer, torch.nn.modules.sparse.Embedding):
            embedding_matrix = embedding_layer.weight
        else:
            # pass all tokens through the fake matrix and create an embedding out of it.
            embedding_matrix = embedding_layer(all_inputs).squeeze()

        return Embedding(
            num_embeddings=self.vocab.get_vocab_size(self.namespace),
            embedding_dim=embedding_matrix.shape[1],
            weight=embedding_matrix,
            trainable=False,
        )
Example #13
    def _register_embedding_gradient_hooks(self, embedding_gradients):
        """
        Registers a backward hook on the embedding layer of the model.  Used to save the gradients
        of the embeddings for use in get_gradients()

        When there are multiple inputs (e.g., a passage and question), the hook
        will be called multiple times. We append all the embeddings gradients
        to a list.

        We additionally add a hook on the _forward_ pass of the model's `TextFieldEmbedder` to save
        token offsets, if there are any.  Having token offsets means that you're using a mismatched
        token indexer, so we need to aggregate the gradients across wordpieces in a token.  We do
        that by averaging the gradients over the wordpieces in each token.
        """

        def hook_layers(module, grad_in, grad_out):
            grads = grad_out[0]
            if self._token_offsets:
                # If you have a mismatched indexer with multiple TextFields, it's quite possible
                # that the order we deal with the gradients is wrong.  We'll just take items from
                # the list one at a time, and try to aggregate the gradients.  If we got the order
                # wrong, we should crash, so you'll know about it.  If you get an error because of
                # that, open an issue on github, and we'll see what we can do.  The intersection of
                # multiple TextFields and mismatched indexers is pretty small (currently empty, that
                # I know of), so we'll ignore this corner case until it's needed.
                offsets = self._token_offsets.pop(0)
                span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
                span_mask = span_mask.unsqueeze(-1)
                span_grads *= span_mask  # zero out paddings

                span_grads_sum = span_grads.sum(2)
                span_grads_len = span_mask.sum(2)
                # Shape: (batch_size, num_orig_tokens, embedding_size)
                grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)

                # All the places where the span length is zero, write in zeros.
                grads[(span_grads_len == 0).expand(grads.shape)] = 0

            embedding_gradients.append(grads)

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                self._token_offsets.append(offsets)

        hooks = []
        text_field_embedder = util.find_text_field_embedder(self._model)
        hooks.append(text_field_embedder.register_forward_hook(get_token_offsets))
        embedding_layer = util.find_embedding_layer(self._model)
        hooks.append(embedding_layer.register_backward_hook(hook_layers))
        return hooks
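A minimal standalone sketch (plain PyTorch, without `util.batched_span_select`) of the aggregation performed in the hook above: average the wordpiece gradients that fall inside each original token's inclusive offset span. The tensors are toy stand-ins.

import torch

wordpiece_grads = torch.randn(7, 4)               # (num_wordpieces, embedding_dim)
offsets = torch.tensor([[0, 0], [1, 2], [3, 5]])  # inclusive (start, end) per original token

token_grads = torch.stack([
    wordpiece_grads[start : end + 1].mean(dim=0)  # average the wordpieces in this token
    for start, end in offsets.tolist()
])                                                # (num_orig_tokens, embedding_dim)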
Example #14
    def __init__(self, vocab: Vocabulary, base_model: Model,
                 bias_mitigator: Lazy[BiasMitigatorWrapper], **kwargs):
        super().__init__(vocab, **kwargs)

        self.base_model = base_model
        # want to keep bias mitigation hook during test time
        embedding_layer = find_embedding_layer(self.base_model)

        self.bias_mitigator = bias_mitigator.construct(
            embedding_layer=embedding_layer)
        embedding_layer.register_forward_hook(self.bias_mitigator)

        self.vocab = self.base_model.vocab
        self._regularizer = self.base_model._regularizer
Example #15
    def __init__(
        self,
        classifier: Model,  # TransactionsClassifier
        reader: TransactionsDatasetReader,
        num_steps: int = 10,
        epsilon: float = 0.01,
        device: int = -1,
    ) -> None:
        super().__init__(classifier=classifier, reader=reader, device=device)
        self.classifier = self.classifier.train()
        self.num_steps = num_steps
        self.epsilon = epsilon

        self.emb_layer = util.find_embedding_layer(self.classifier).weight
Example #16
    def get_interpretable_layer(self) -> torch.nn.Module:
        """
        Returns the input/embedding layer of the model.
        If the predictor wraps around a non-AllenNLP model,
        this function should be overridden to specify the correct input/embedding layer.
        For the cases where the input layer _is_ an embedding layer, this should be the
        layer 0 of the embedder.
        """
        try:
            return util.find_embedding_layer(self._model)
        except RuntimeError:
            raise RuntimeError(
                "If the model does not use `TextFieldEmbedder`, please override "
                "`get_interpretable_layer` in your predictor to specify the embedding layer."
            )
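Purely as a hypothetical example of the override mentioned in the docstring, a predictor wrapping a raw HuggingFace model might expose its input embeddings like this; the `self._model.transformer` attribute is an assumption about how the wrapped model is stored.

import torch
from allennlp.predictors import Predictor


class WrappedTransformerPredictor(Predictor):
    """Hypothetical predictor whose underlying model is a raw HuggingFace module."""

    def get_interpretable_layer(self) -> torch.nn.Module:
        # HuggingFace models expose their input embeddings via get_input_embeddings();
        # `self._model.transformer` is an assumed attribute, not an AllenNLP convention.
        return self._model.transformer.get_input_embeddings()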
Example #17
    def _register_forward_hook(self, stdev: float):
        """
        Register a forward hook on the embedding layer which adds random noise to every embedding.
        Used for one term in the SmoothGrad sum.
        """
        def forward_hook(module, inputs, output):  # pylint: disable=unused-argument
            # Random noise = N(0, stdev * (max-min))
            scale = output.detach().max() - output.detach().min()
            noise = torch.randn(output.shape).to(output.device) * stdev * scale

            # Add the random noise
            output.add_(noise)

        # Register the hook
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handle = embedding_layer.register_forward_hook(forward_hook)
        return handle
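For context, a self-contained sketch in plain PyTorch (not the AllenNLP interpreter) of the SmoothGrad average this hook supports: add Gaussian noise scaled by the embedding range several times and average the resulting gradients. The toy model and sample count are assumptions.

import torch

embedding = torch.nn.Embedding(20, 8)
classifier = torch.nn.Linear(8, 2)
token_ids = torch.tensor([[4, 7, 9]])
stdev, num_samples = 0.01, 10

total_grad = torch.zeros(1, 3, 8)
for _ in range(num_samples):
    embedded = embedding(token_ids)
    scale = embedded.detach().max() - embedded.detach().min()
    noisy = embedded + torch.randn_like(embedded) * stdev * scale  # N(0, stdev * (max - min))
    loss = classifier(noisy).sum()
    (grad,) = torch.autograd.grad(loss, embedded)
    total_grad += grad

smooth_grad = total_grad / num_samples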
Example #18
    def _register_embedding_gradient_hooks(self, embedding_gradients):
        """
        Registers a backward hook on the
        :class:`~allennlp.modules.text_field_embedder.basic_text_field_embbedder.BasicTextFieldEmbedder`
        class. Used to save the gradients of the embeddings for use in get_gradients()

        When there are multiple inputs (e.g., a passage and question), the hook
        will be called multiple times. We append all the embeddings gradients
        to a list.
        """

        def hook_layers(module, grad_in, grad_out):
            embedding_gradients.append(grad_out[0])

        backward_hooks = []
        embedding_layer = util.find_embedding_layer(self._model)
        backward_hooks.append(embedding_layer.register_backward_hook(hook_layers))
        return backward_hooks
Example #19
    def _register_forward_hook(self, alpha: int, embeddings_list: List):
        """
        Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
        for one term in the Integrated Gradients sum.

        We store the embedding output into the embeddings_list when alpha is zero.  This is used
        later to element-wise multiply the input by the averaged gradients.
        """
        def forward_hook(module, inputs, output):  # pylint: disable=unused-argument
            # Save the input for later use. Only do so on first call.
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach().numpy())

            # Scale the embedding by alpha
            output.mul_(alpha)

        # Register the hook
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handle = embedding_layer.register_forward_hook(forward_hook)
        return handle
Example #20
    def __init__(
        self,
        classifier: Model,  # TransactionsClassifier
        reader: TransactionsDatasetReader,
        num_steps: int = 10,
        epsilon: float = 0.01,
        position: Position = Position.END,
        num_tokens_to_add: int = 2,
        total_amount: float = 5000,
        device: int = -1,
    ) -> None:
        super().__init__(classifier=classifier, reader=reader, device=device)
        self.classifier = self.classifier.train()
        self.num_steps = num_steps
        self.epsilon = epsilon

        self.emb_layer = util.find_embedding_layer(self.classifier).weight

        self.position = position
        self.num_tokens_to_add = num_tokens_to_add
        self.total_amount = total_amount
Example #21
    def __init__(
        self,
        classifier: Model,  # TransactionsClassifier
        lm: Model,
        lm_threshold: float,
        reader: TransactionsDatasetReader,
        num_steps: int = 10,
        epsilon: float = 0.01,
        device: int = -1,
    ) -> None:
        super().__init__(classifier=classifier, reader=reader, device=device)
        self.classifier = self.classifier.train()
        self.lm = lm
        if self.device >= 0 and torch.cuda.is_available():
            self.lm.cuda(self.device)
        self.lm_threshold = lm_threshold
        self.num_steps = num_steps
        self.epsilon = epsilon

        self.emb_layer = util.find_embedding_layer(self.classifier).weight
Example #22
    def _register_hooks(self, embeddings_list: List, token_offsets: List):
        """
        Finds all of the TextFieldEmbedders, and registers a forward hook onto them. When forward()
        is called, embeddings_list is filled with the embedding values. This is necessary because
        our normalization scheme multiplies the gradient by the embedding value.
        """
        def forward_hook(module, inputs, output):
            embeddings_list.append(output.squeeze(0).clone().detach())

        def get_token_offsets(module, inputs, outputs):
            offsets = util.get_token_offsets_from_text_field_inputs(inputs)
            if offsets is not None:
                token_offsets.append(offsets)

        # Register the hooks
        handles = []
        embedding_layer = util.find_embedding_layer(self.predictor._model)
        handles.append(embedding_layer.register_forward_hook(forward_hook))
        text_field_embedder = util.find_text_field_embedder(
            self.predictor._model)
        handles.append(
            text_field_embedder.register_forward_hook(get_token_offsets))
        return handles
Example #23
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        # peak_cpu_usage = common_util.peak_memory_mb()
        # logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in common_util.peak_gpu_memory().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self._pytorch_model.train()

        if isinstance(
                self.adv_policy,
            (adv_utils.HotFlipPolicy, adv_utils.RandomNeighbourPolicy)):
            hooks = adv_utils.register_embedding_hook(self.model)
            embedding_matrix = util.find_embedding_layer(self.model).weight

        # Get tqdm for the training batches
        batch_generator = iter(self.data_loader)

        logger.info("Training")

        num_training_batches = len(self.data_loader)

        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
        # progress is shown
        if self._master:
            batch_generator_tqdm = Tqdm.tqdm(batch_generator,
                                             total=num_training_batches)
        else:
            batch_generator_tqdm = batch_generator

        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        histogram_parameters = set(
            self.model.get_parameters_for_histogram_tensorboard_logging())

        cumulative_batch_group_size = 0
        done_early = False
        for batch in batch_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done,
                                             torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing training early! "
                        "This implies that there is an imbalance in your training "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced.")
                    break

            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            # normal samples
            loss = self.batch_loss(batch, for_training=True)
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
            if self._opt_level is not None:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            train_loss += loss.item()

            # adversarial samples
            if self.adv_policy.adv_iteration > 0:
                adv_utils.set_adv_mode(True)
                adv_fields = batch[self.adv_policy.adv_field]
                token_key = adv_utils.guess_token_key_from_field(adv_fields)
                raw_tokens = adv_fields['tokens'][token_key].cuda()
                adv_tokens = raw_tokens.clone()
                for adv_idx in range(self.adv_policy.adv_iteration):
                    if isinstance(self.adv_policy, adv_utils.HotFlipPolicy):
                        fw, bw = adv_utils.read_embedding_hook(
                            self.adv_policy.forward_order)
                        adv_tokens = adv_utils.hotflip(
                            raw_tokens=raw_tokens,
                            adv_tokens=adv_tokens,
                            embeds=fw,
                            grads=bw,
                            embedding_matrix=embedding_matrix,
                            searcher=self.adv_policy.searcher,
                            replace_num=self.adv_policy.replace_num,
                        )
                    elif isinstance(self.adv_policy,
                                    adv_utils.RandomNeighbourPolicy):
                        adv_tokens = adv_utils.random_swap(
                            raw_tokens=raw_tokens,
                            adv_tokens=adv_tokens,
                            searcher=self.adv_policy.searcher,
                            replace_num=self.adv_policy.replace_num,
                        )
                    elif isinstance(self.adv_policy,
                                    adv_utils.DoItYourselfPolicy):
                        adv_utils.send("step", self.adv_policy.step)
                    else:
                        raise Exception
                    adv_fields['tokens'][token_key] = adv_tokens
                    loss = self.batch_loss(
                        batch,
                        for_training=True) / self.adv_policy.adv_iteration
                    if torch.isnan(loss):
                        raise ValueError("nan loss encountered")
                    if self._opt_level is not None:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    train_loss += loss.item()
                adv_utils.set_adv_mode(False)
            batch_grad_norm = self.rescale_gradients()

            torch.cuda.empty_cache()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            if self._tensorboard.should_log_histograms_this_batch(
            ) and self._master:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {
                    name: param.detach().cpu().clone()
                    for name, param in self.model.named_parameters()
                }
                self.optimizer.step()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1))
                    param_norm = torch.norm(param.view(-1)).cpu()
                    self._tensorboard.add_train_scalar(
                        "gradient_update/" + name,
                        update_norm / (param_norm + 1e-7))
            else:
                self.optimizer.step()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

            # Update the description with the latest metrics
            metrics = training_util.get_metrics(
                self.model,
                train_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=[self.cuda_device],
            )

            # Updating tqdm only for the master as the trainers wouldn't have one
            if self._master:
                description = training_util.description_from_metrics(metrics)
                batch_generator_tqdm.set_description(description,
                                                     refresh=False)

            if self._master:
                for callback in self._batch_callbacks:
                    callback(
                        self,
                        epoch,
                        batches_this_epoch,
                        is_training=True,
                    )

            # Log parameter values to Tensorboard (only from the master)
            if self._tensorboard.should_log_this_batch() and self._master:
                self._tensorboard.log_parameter_and_gradient_statistics(
                    self.model, batch_grad_norm)
                self._tensorboard.log_learning_rates(self.model,
                                                     self.optimizer)

                self._tensorboard.add_train_scalar("loss/loss_train",
                                                   metrics["loss"])
                self._tensorboard.log_metrics(
                    {"epoch_metrics/" + k: v
                     for k, v in metrics.items()})

            if self._tensorboard.should_log_histograms_this_batch(
            ) and self._master:
                self._tensorboard.log_histograms(self.model,
                                                 histogram_parameters)

            if self._log_batch_size_period:
                batch_group_size = training_util.get_batch_size(batch)
                cumulative_batch_group_size += batch_group_size
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_group_size / batches_this_epoch
                    logger.info(
                        f"current batch size: {batch_group_size} mean batch size: {average}"
                    )
                    self._tensorboard.add_train_scalar("current_batch_size",
                                                       batch_group_size)
                    self._tensorboard.add_train_scalar("mean_batch_size",
                                                       average)

            # Save model if needed.
            if (self._model_save_interval is not None and
                (time.time() - last_save_time > self._model_save_interval)
                    and self._master):
                last_save_time = time.time()
                self._save_checkpoint("{0}.{1}".format(
                    epoch, training_util.time_to_str(int(last_save_time))))
        if self._distributed and not done_early:
            logger.warning(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
            )
            # Indicate that we're done so that any workers that have remaining data stop the epoch early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        # Let all workers finish their epoch before computing
        # the final statistics for the epoch.
        if self._distributed:
            dist.barrier()

        if isinstance(
                self.adv_policy,
            (adv_utils.HotFlipPolicy, adv_utils.RandomNeighbourPolicy)):
            for hook in hooks:
                hook.remove()

        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            batches_this_epoch,
            reset=True,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )
        # metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
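Stepping back from the full trainer, here is a plain-PyTorch sketch (the repo's `adv_utils` is not used) of the embedding-hook lifecycle it relies on for HotFlip: register forward and backward hooks once, read the captured embeddings and gradients after each `loss.backward()`, and remove the handles when the epoch ends.

import torch

embedding = torch.nn.Embedding(20, 8)
classifier = torch.nn.Linear(8, 2)
embedding.weight.requires_grad = True

captured = {"fw": [], "bw": []}
handles = [
    embedding.register_forward_hook(lambda m, i, o: captured["fw"].append(o.detach())),
    embedding.register_backward_hook(lambda m, gi, go: captured["bw"].append(go[0].detach())),
]

loss = classifier(embedding(torch.tensor([[1, 2, 3]]))).sum()
loss.backward()
embeds, grads = captured["fw"][-1], captured["bw"][-1]  # inputs to a HotFlip-style swap

for handle in handles:
    handle.remove()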