def forward(self,
            task_index: torch.IntTensor,
            tokens: Dict[str, torch.LongTensor],
            epoch_trained: torch.IntTensor,
            valid_discriminator: Discriminator,
            reverse: torch.ByteTensor,
            for_training: torch.ByteTensor) -> Dict[str, torch.Tensor]:
    embedded_text_input = self._text_field_embedder(tokens)
    tokens_mask = util.get_text_field_mask(tokens)
    batch_size = get_batch_size(tokens)

    # TODO: the perturbation branch is currently disabled -- `np.random.rand() < -1` is always False.
    if np.random.rand() < -1 and for_training.all():
        logger.info("Domain Embedding with Perturbation")
        domain_embeddings = self._domain_embeddings(torch.arange(0, len(TASKS_NAME)).cuda())
        domain_embedding = get_perturbation_domain_embedding(domain_embeddings, task_index, epoch_trained)
        # domain_embedding = FGSM(self._domain_embeddings, task_index, valid_discriminator)
        output_dict = {"valid": torch.tensor(0)}
    else:
        logger.info("Domain Embedding without Perturbation")
        domain_embedding = self._domain_embeddings(task_index)
        output_dict = {"valid": torch.tensor(1)}

    output_dict["domain_embedding"] = domain_embedding
    embedded_text_input = self._input_dropout(embedded_text_input)

    if self._with_domain_embedding:
        # Prepend the domain embedding as an extra pseudo-token and extend the mask to cover it.
        domain_embedding = domain_embedding.expand(batch_size, 1, -1)
        embedded_text_input = torch.cat((domain_embedding, embedded_text_input), 1)
        tokens_mask = torch.cat([tokens_mask.new_ones(batch_size, 1), tokens_mask], 1)

    # Shared encoder: take the final encoder states as one sentence representation.
    shared_encoded_text = self._shared_encoder(embedded_text_input, tokens_mask)
    # shared_encoded_text = self._seq2vec(shared_encoded_text, tokens_mask)
    shared_encoded_text = get_final_encoder_states(shared_encoded_text, tokens_mask, bidirectional=True)
    output_dict["share_embedding"] = shared_encoded_text

    # Private (task-specific) encoder: same treatment, producing a second sentence representation.
    private_encoded_text = self._private_encoder(embedded_text_input, tokens_mask)
    # private_encoded_text = self._seq2vec(private_encoded_text)
    private_encoded_text = get_final_encoder_states(private_encoded_text, tokens_mask, bidirectional=True)
    output_dict["private_embedding"] = private_encoded_text

    # Concatenate the shared and private representations for downstream use.
    embedded_text = torch.cat([shared_encoded_text, private_encoded_text], -1)
    output_dict["embedded_text"] = embedded_text
    return output_dict
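
# The helper below is an illustrative sketch, not part of the model: it shows, with made-up
# shapes, how the `_with_domain_embedding` branch above prepends one domain vector as a
# pseudo-token and extends the padding mask to match. All names here are demo-only assumptions.
def _demo_domain_embedding_prepend() -> None:
    batch_size, seq_len, embed_dim = 2, 5, 8
    embedded_text = torch.randn(batch_size, seq_len, embed_dim)        # (batch, seq, dim)
    tokens_mask = torch.ones(batch_size, seq_len, dtype=torch.long)    # 1 = real token, 0 = padding
    domain_embedding = torch.randn(embed_dim)                          # one vector for the current task
    domain_embedding = domain_embedding.expand(batch_size, 1, -1)      # broadcast to (batch, 1, dim)
    embedded_text = torch.cat((domain_embedding, embedded_text), 1)    # (batch, seq + 1, dim)
    tokens_mask = torch.cat([tokens_mask.new_ones(batch_size, 1), tokens_mask], 1)  # (batch, seq + 1)
    assert embedded_text.shape == (batch_size, seq_len + 1, embed_dim)
    assert tokens_mask.shape == (batch_size, seq_len + 1)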
def reset(self, reset: torch.ByteTensor) -> None:
    """
    Parameters
    ----------
    reset : ``torch.ByteTensor``
        A tensor of shape ``(batch_size,)`` indicating whether the state (e.g. list of
        previously seen entities) for the corresponding batch element should be reset.
    """
    if (len(reset) != len(self._remaining)) and not reset.all():
        raise RuntimeError('Changing the batch size without resetting all internal states is '
                           'undefined.')
    # If everything is being reset, then we treat it as if the Module has just been initialized.
    # This simplifies the case where the batch_size has been changed.
    if reset.all():
        batch_size = reset.shape[0]
        self._remaining = [dict() for _ in range(batch_size)]
    # Otherwise, only reset the internal state for the indicated batch elements.
    else:
        for i, should_reset in enumerate(reset):
            if should_reset:
                self._remaining[i] = {}
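
# The helper below is an illustrative usage sketch, not the repository's API: `_Tracker` is a
# hypothetical stand-in that keeps per-batch-element state the same way `self._remaining` is
# used above, to show the full-reset vs. selective-reset calling pattern of reset().
def _demo_reset_usage() -> None:
    class _Tracker:
        def __init__(self) -> None:
            self._remaining: list = []

        def reset(self, reset: torch.ByteTensor) -> None:
            if len(reset) != len(self._remaining) and not reset.all():
                raise RuntimeError('Changing the batch size without resetting all internal '
                                   'states is undefined.')
            if reset.all():
                self._remaining = [dict() for _ in range(reset.shape[0])]
            else:
                for i, should_reset in enumerate(reset):
                    if should_reset:
                        self._remaining[i] = {}

    tracker = _Tracker()
    tracker.reset(torch.ones(4, dtype=torch.uint8))                # full reset: (re)initializes 4 slots
    tracker._remaining[0]["seen"] = {"entity_1"}                   # accumulate some state for element 0
    tracker.reset(torch.tensor([1, 0, 0, 0], dtype=torch.uint8))   # selective reset: only element 0 cleared
    assert tracker._remaining[0] == {}
    assert len(tracker._remaining) == 4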