def forward(self,  # type: ignore
            task_index: torch.IntTensor,
            reverse: torch.ByteTensor,
            epoch_trained: torch.IntTensor,
            for_training: torch.ByteTensor,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            text_id: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """
    Run the encoder and the sentiment / private-domain / shared-domain heads.

    :param task_index: index of the task the batch belongs to; it is expanded
        to ``batch_size`` and used as the target of both domain losses
    :param reverse: forwarded to the encoder and to the shared-domain
        discriminator as its ``reverse`` argument
    :param epoch_trained: forwarded to the encoder
    :param for_training: forwarded to the encoder; when falsy (and a label is
        given) the class probabilities of the batch are appended to
        ``class_probabilities.txt``
    :param tokens: text-field tensors, forwarded to the encoder
    :param label: sentiment targets; losses, metrics and the evaluation dump
        are skipped when it is ``None``
    :param text_id: per-instance identifiers, forwarded to the encoder and
        written into the evaluation dump when available
    :return: a dict with ``logits`` (sentiment logits) and, when ``label`` is
        given, ``tokens``, ``stm_loss``, ``p_d_loss``, ``s_d_loss`` and the
        weighted total ``loss``
    """
    embeddeds = self._encoder(task_index, tokens, epoch_trained,
                              self._valid_discriminator, reverse,
                              for_training, text_id)
    batch_size = get_batch_size(embeddeds["embedded_text"])

    sentiment_logits = self._sentiment_discriminator(embeddeds["embedded_text"])
    p_domain_logits = self._p_domain_discriminator(embeddeds["private_embedding"])
    # TODO set reverse = true
    s_domain_logits = self._s_domain_discriminator(embeddeds["share_embedding"],
                                                   reverse=reverse)

    output_dict = {'logits': sentiment_logits}
    if label is None:
        # Nothing to score against: no losses, no metrics, no dump.
        return output_dict

    loss = self._loss(sentiment_logits, label)
    # One domain target per instance in the batch.
    task_index = task_index.expand(batch_size)
    p_domain_loss = self._domain_loss(p_domain_logits, task_index)
    s_domain_loss = self._domain_loss(s_domain_logits, task_index)

    output_dict["tokens"] = tokens
    output_dict['stm_loss'] = loss
    output_dict['p_d_loss'] = p_domain_loss
    output_dict['s_d_loss'] = s_domain_loss
    # TODO add share domain logits std loss
    output_dict['loss'] = loss + 0.06 * p_domain_loss + 0.04 * s_domain_loss

    for metric_name, metric in self.metrics.items():
        if "auc" in metric_name:
            # AUC-style metrics are fed the decoded labels, not raw logits.
            metric(self.decode(output_dict)["label"], label)
        else:
            metric(sentiment_logits, label)

    # NOTE(review): debug trace kept for behavioural parity; consider routing
    # it through the logging module instead of stdout.
    print("for training", for_training)
    if not for_training:
        def _joined(tensor) -> str:
            # Space-separated string form of a tensor's top-level entries.
            return " ".join(map(str, tensor.cpu().detach().numpy()))

        # Build the whole record before touching the file so that a failure
        # while formatting cannot leave a truncated entry behind.
        # ``text_id`` is optional; the ID field is left empty without it.
        line_ids = _joined(text_id) if text_id is not None else ""
        record = "".join([
            f"Task: {TASKS_NAME[task_index[0].detach()]}\nLine ID: ",
            line_ids,
            "\nProb: ",
            _joined(F.softmax(sentiment_logits, dim=-1)),
            "\nLabel: " + _joined(label) + "\n",
            "\n\n\n",
        ])
        with open("class_probabilities.txt", "a", encoding="utf8") as f:
            f.write(record)

    return output_dict
def forward(
        self,
        tokens_list: Dict[str, torch.LongTensor],
        positions_list: Dict[str, torch.LongTensor],
        sent_positions_list: Dict[str, torch.LongTensor],
        before_loc_start: torch.IntTensor = None,
        before_loc_end: torch.IntTensor = None,
        after_loc_start_list: torch.IntTensor = None,
        after_loc_end_list: torch.IntTensor = None,
        before_category: torch.IntTensor = None,
        after_category_list: torch.IntTensor = None,
        before_category_mask: torch.IntTensor = None,
        after_category_mask_list: torch.IntTensor = None
) -> Dict[str, torch.Tensor]:
    """
    Predict, step by step, the "before" location (first step only) and the
    "after" location (every step) as a three-way category plus a start/end
    span, accumulate the training loss and collect predicted and gold spans.

    Although the label arguments default to ``None``, the body dereferences
    all of them unconditionally (the ``after_*`` lists right after the masks
    are built, the ``before_*`` tensors in the first loop iteration), so they
    are effectively required.

    :param tokens_list: Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be
        passed directly to a ``TextFieldEmbedder``. This output is a
        dictionary mapping keys to ``TokenIndexer`` tensors. At its most
        basic, using a ``SingleIdTokenIndexer`` this is:
        ``{"tokens": Tensor(batch_size, num_tokens)}``. This dictionary will
        have the same keys as were used for the ``TokenIndexers`` when you
        created the ``TextField`` representing your sequence. The dictionary
        is designed to be passed directly to a ``TextFieldEmbedder``, which
        knows how to combine different word representations into a single
        vector per token in your input.
    :param positions_list: same as tokens_list; embedded with
        ``self._pos_field_embedder``
    :param sent_positions_list: same as tokens_list; embedded with
        ``self._sent_pos_field_embedder``
    :param before_loc_start: torch.IntTensor, required
        An integer ``IndexField`` representation of the before location start
    :param before_loc_end: torch.IntTensor, required
        An integer ``IndexField`` representation of the before location end
    :param after_loc_start_list: torch.IntTensor, required
        A ``ListField (IndexField)`` representation of the after location
        starts along the sequence of steps
    :param after_loc_end_list: torch.IntTensor, required
        A ``ListField (IndexField)`` representation of the after location
        ends along the sequence of steps
    :param before_category: torch.IntTensor, required
        An integer ``IndexField`` representation of the before location
        category
    :param after_category_list: torch.IntTensor, required
        A ``ListField (IndexField)`` representation of the after location
        categories along the sequence of steps
    :param before_category_mask: torch.IntTensor, required
        An integer ``IndexField`` representation of whether the before
        location is known or not (0/1)
    :param after_category_mask_list: torch.IntTensor, required
        A ``ListField (IndexField)`` representation of whether the after
        location is known or not for each step along the sequence of steps
    :return: An output dictionary consisting of:
        best_span: the merged predicted spans, flattened to one row
        true_span: the merged gold spans, flattened to one row
        loss: the sum of all category and span loss terms
    """
    # batchsize * listLength * paragraphSize * embeddingSize
    input_embedding_paragraph = self._text_field_embedder(tokens_list)
    input_pos_embedding_paragraph = self._pos_field_embedder(
        positions_list)
    input_sent_pos_embedding_paragraph = self._sent_pos_field_embedder(
        sent_positions_list)
    # Concatenate the three embeddings along the feature axis.
    # NOTE(review): the original comment gave the width as embeddingSize*2;
    # three tensors are concatenated here, so confirm the actual width.
    embedding_paragraph = torch.cat([
        input_embedding_paragraph,
        input_pos_embedding_paragraph, input_sent_pos_embedding_paragraph
    ], dim=-1)

    # batchsize * listLength * paragraphSize, this mask is shared with the text fields and sequence label fields
    para_mask = util.get_text_field_mask(tokens_list,
                                         num_wrapping_dims=1).float()

    # batchsize * listLength , this mask is shared with the index fields.
    # The max over the token axis is non-zero for a step that has at least
    # one unmasked token.
    para_index_mask, para_index_mask_indices = torch.max(para_mask, 2)

    # apply mask to update the index values, padded instances will be 0
    after_loc_start_list = (after_loc_start_list.float() *
                            para_index_mask.unsqueeze(2)).long()
    after_loc_end_list = (after_loc_end_list.float() *
                          para_index_mask.unsqueeze(2)).long()
    after_category_list = (after_category_list.float() *
                           para_index_mask.unsqueeze(2)).long()
    after_category_mask_list = (after_category_mask_list.float() *
                                para_index_mask.unsqueeze(2)).long()

    batch_size, list_size, paragraph_size, input_dim = embedding_paragraph.size(
    )

    # to store the values passed to next step; both are overwritten in the
    # first loop iteration before they are read
    tmp_category_probability = torch.zeros(batch_size, 3)
    tmp_start_probability = torch.zeros(batch_size, paragraph_size)

    loss = 0

    # store the predict logits for the whole lists; every step slot is
    # overwritten inside the loop, so the random initial content is never
    # used when list_size covers all slots
    category_predict_logits_after_list = torch.rand(
        batch_size, list_size, 3)
    best_span_after_list = torch.rand(batch_size, list_size, 2)

    for index in range(list_size):
        # get one slice of step for prediction
        embedding_paragraph_slice = embedding_paragraph[:, index, :, :].squeeze(
            1)
        para_mask_slice = para_mask[:, index, :].squeeze(1)
        para_lstm_mask_slice = para_mask_slice if self._mask_lstms else None
        para_index_mask_slice = para_index_mask[:, index]
        # NOTE(review): an argument-less squeeze() removes every size-1 axis,
        # which includes the batch axis when the batch holds one instance.
        after_category_mask_slice = after_category_mask_list[:, index, :].squeeze(
        )

        # bi-LSTM: generate the contextual embeddings for the current step
        # size: batchsize * paragraph_size * modeling_layer_hidden_size
        encoded_paragraph = self._dropout(
            self._modeling_layer(embedding_paragraph_slice,
                                 para_lstm_mask_slice))

        # max-pooling output for three category classification
        category_input, category_input_indices = torch.max(
            encoded_paragraph, 1)

        modeling_dim = encoded_paragraph.size(-1)
        span_start_input = encoded_paragraph

        # predict the initial before location state (first step only)
        if index == 0:
            # three category classification for initial before location
            category_predict_logits_before = self._category_before_predictor(
                category_input)
            # carried into the after-category classifier below
            tmp_category_probability = category_predict_logits_before
            '''Model the before_loc prediction'''
            # predict the initial before location start scores
            # shape: batchsize * paragraph_size
            span_start_logits_before = self._span_start_predictor_before(
                span_start_input).squeeze(-1)
            # shape: batchsize * paragraph_size
            span_start_probs_before = util.masked_softmax(
                span_start_logits_before, para_mask_slice)
            # carried into the after-start prediction below
            tmp_start_probability = span_start_probs_before

            # shape: batchsize * hiddensize
            span_start_representation_before = util.weighted_sum(
                encoded_paragraph, span_start_probs_before)

            # Shape: (batch_size, passage_length, modeling_dim)
            tiled_start_representation_before = span_start_representation_before.unsqueeze(
                1).expand(batch_size, paragraph_size, modeling_dim)

            # incorporate the original contextual embeddings and weighted sum vector from location start prediction
            # shape: batchsize * paragraph_size * 2hiddensize
            span_end_representation_before = torch.cat(
                [encoded_paragraph, tiled_start_representation_before],
                dim=-1)
            # Shape: (batch_size, passage_length, encoding_dim)
            encoded_span_end_before = self._dropout(
                self._span_end_encoder_before(
                    span_end_representation_before, para_lstm_mask_slice))

            # initial before location end prediction
            encoded_span_end_before = torch.cat(
                [encoded_paragraph, encoded_span_end_before], dim=-1)
            # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
            span_end_logits_before = self._span_end_predictor_before(
                encoded_span_end_before).squeeze(-1)
            span_end_probs_before = util.masked_softmax(
                span_end_logits_before, para_mask_slice)

            # best_span_bef = self._get_best_span(span_start_logits_bef, span_end_logits_bef)
            best_span_before, best_span_before_start, best_span_before_end, best_span_before_real = \
                self._get_best_span_single_extend(span_start_logits_before, span_end_logits_before,
                                                  category_predict_logits_before, before_category_mask)

            # compute the loss for initial bef location three-category classification
            # NOTE(review): nll_loss conventionally expects log-probabilities;
            # confirm whether `softmax` here already returns logs.
            before_null_pred = softmax(category_predict_logits_before)
            before_null_pred_values, before_null_pred_indices = torch.max(
                before_null_pred, 1)
            loss += nll_loss(before_null_pred, before_category.squeeze(-1))

            # compute the loss for initial bef location start/end prediction
            # NOTE(review): despite the `logpy_` prefix these values are
            # gathered from masked_softmax output — presumably probabilities,
            # not log-probabilities; confirm against util.masked_softmax.
            before_loc_start_pred = util.masked_softmax(
                span_start_logits_before, para_mask_slice)
            logpy_before_start = torch.gather(
                before_loc_start_pred, 1, before_loc_start).view(-1).float()
            before_category_mask = before_category_mask.float()
            # the category mask zeroes the span terms of masked instances
            loss += -(logpy_before_start * before_category_mask).mean()
            before_loc_end_pred = util.masked_softmax(
                span_end_logits_before, para_mask_slice)
            logpy_before_end = torch.gather(before_loc_end_pred, 1,
                                            before_loc_end).view(-1)
            loss += -(logpy_before_end * before_category_mask).mean()

            # get the real predicted location spans
            # convert category output (Null and Unk) into spans ((-2,-2) or (-1, -1))
            before_loc_start_real = self._get_real_spans_extend(
                before_loc_start, before_category, before_category_mask)
            before_loc_end_real = self._get_real_spans_extend(
                before_loc_end, before_category, before_category_mask)
            true_span_before = torch.stack(
                [before_loc_start_real, before_loc_end_real], dim=-1)
            true_span_before = true_span_before.squeeze(1)

        # input for (after location) three category classification:
        # the pooled paragraph vector plus the previous category logits
        category_input_after = torch.cat(
            (category_input, tmp_category_probability), dim=1)
        category_predict_logits_after = self._category_after_predictor(
            category_input_after)
        # becomes the "previous" category input of the next step
        tmp_category_probability = category_predict_logits_after

        # copy the predict logits for the index of the list
        category_predict_logits_after_tmp = category_predict_logits_after.unsqueeze(
            1)
        category_predict_logits_after_list[:, index, :] = category_predict_logits_after_tmp.data

        ''' Model the after_loc prediction '''
        # after location start prediction: takes contextual embeddings and weighted sum vector as input
        # shape: batchsize * hiddensize
        # NOTE(review): the weighted sum is taken over the pooled
        # `category_input`, not `encoded_paragraph` as in the other
        # weighted_sum calls — confirm this is intended.
        prev_start = util.weighted_sum(category_input,
                                       tmp_start_probability)
        tiled_prev_start = prev_start.unsqueeze(1).expand(
            batch_size, paragraph_size, modeling_dim)
        span_start_input_after = torch.cat(
            (span_start_input, tiled_prev_start), dim=2)
        encoded_start_input_after = self._dropout(
            self._span_start_encoder_after(span_start_input_after,
                                           para_lstm_mask_slice))
        span_start_input_after_cat = torch.cat(
            [encoded_paragraph, encoded_start_input_after], dim=-1)

        # predict the after location start
        span_start_logits_after = self._span_start_predictor_after(
            span_start_input_after_cat).squeeze(-1)
        # shape: batchsize * paragraph_size
        span_start_probs_after = util.masked_softmax(
            span_start_logits_after, para_mask_slice)
        # becomes the "previous" start distribution of the next step
        tmp_start_probability = span_start_probs_after

        # after location end prediction: takes contextual embeddings and weight sum vector as input
        # shape: batchsize * hiddensize
        span_start_representation_after = util.weighted_sum(
            encoded_paragraph, span_start_probs_after)

        # Tensor Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation_after = span_start_representation_after.unsqueeze(
            1).expand(batch_size, paragraph_size, modeling_dim)

        # shape: batchsize * paragraph_size * 2hiddensize
        span_end_representation_after = torch.cat(
            [encoded_paragraph, tiled_start_representation_after], dim=-1)
        # Tensor Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end_after = self._dropout(
            self._span_end_encoder_after(span_end_representation_after,
                                         para_lstm_mask_slice))
        encoded_span_end_after = torch.cat(
            [encoded_paragraph, encoded_span_end_after], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_logits_after = self._span_end_predictor_after(
            encoded_span_end_after).squeeze(-1)
        span_end_probs_after = util.masked_softmax(span_end_logits_after,
                                                   para_mask_slice)

        # get the best span for after location prediction
        best_span_after, best_span_after_start, best_span_after_end, best_span_after_real = \
            self._get_best_span_single_extend(span_start_logits_after, span_end_logits_after,
                                              category_predict_logits_after, after_category_mask_slice)

        # copy current best span to the list for final evaluation
        best_span_after_list[:, index, :] = best_span_after.data.view(
            batch_size, 1, 2)

        """ Compute the Loss for this slice """
        after_category_mask = after_category_mask_slice.float().squeeze(
            -1)  # batchsize
        after_category_slice = after_category_list[:, index, :]  # batchsize * 1
        after_loc_start_slice = after_loc_start_list[:, index, :]
        after_loc_end_slice = after_loc_end_list[:, index, :]

        # compute the loss for (after location) three category classification;
        # the step mask is tiled over the three categories so padded steps
        # are masked out of the softmax and of the loss term
        para_index_mask_slice_tiled = para_index_mask_slice.unsqueeze(
            1).expand(para_index_mask_slice.size(0), 3)
        after_category_pred = util.masked_softmax(
            category_predict_logits_after, para_index_mask_slice_tiled)
        logpy_after_category = torch.gather(after_category_pred, 1,
                                            after_category_slice).view(-1)
        loss += -(logpy_after_category * para_index_mask_slice).mean()

        # compute the loss for location start/end prediction
        after_loc_start_pred = util.masked_softmax(span_start_logits_after,
                                                   para_mask_slice)
        logpy_after_start = torch.gather(after_loc_start_pred, 1,
                                         after_loc_start_slice).view(-1)
        loss += -(logpy_after_start * after_category_mask).mean()
        after_loc_end_pred = util.masked_softmax(span_end_logits_after,
                                                 para_mask_slice)
        logpy_after_end = torch.gather(after_loc_end_pred, 1,
                                       after_loc_end_slice).view(-1)
        loss += -(logpy_after_end * after_category_mask).mean()

    # for evaluation (combine the all annotations)
    after_loc_start_real = self._get_real_spans_extend_list(
        after_loc_start_list, after_category_list,
        after_category_mask_list)
    after_loc_end_real = self._get_real_spans_extend_list(
        after_loc_end_list, after_category_list, after_category_mask_list)

    true_span_after = torch.stack(
        [after_loc_start_real, after_loc_end_real], dim=-1)
    true_span_after = true_span_after.squeeze(2)
    best_span_after_list = Variable(best_span_after_list)

    # flatten (batch, step) into one axis so each row is one step's span
    true_span_after = true_span_after.view(
        true_span_after.size(0) * true_span_after.size(1),
        true_span_after.size(2)).float()

    # step mask flattened the same way and tiled over the (start, end) pair
    para_index_mask_tiled = para_index_mask.view(-1, 1)
    para_index_mask_tiled = para_index_mask_tiled.expand(
        para_index_mask_tiled.size(0), 2)

    # NOTE(review): after_category_mask_list_tiled is computed but not read
    # again in this method.
    para_index_mask_tiled2 = para_index_mask.unsqueeze(2).expand(
        para_index_mask.size(0), para_index_mask.size(1), 2)
    after_category_mask_list_tiled = after_category_mask_list.expand(
        batch_size, list_size, 2)
    after_category_mask_list_tiled = after_category_mask_list_tiled * para_index_mask_tiled2.long(
    )

    # merge all the best spans predicted for the current batch, filter out the padded instances
    merged_sys_span, merged_gold_span = self._get_merged_spans(
        true_span_before, best_span_before, true_span_after,
        best_span_after_list, para_index_mask_tiled)

    output_dict = {}
    # both span tensors are flattened into a single row
    output_dict["best_span"] = merged_sys_span.view(
        1,
        merged_sys_span.size(0) * merged_sys_span.size(1))
    output_dict["true_span"] = merged_gold_span.view(
        1,
        merged_gold_span.size(0) * merged_gold_span.size(1))
    output_dict["loss"] = loss
    return output_dict
def forward(self,  # type: ignore
            task_index: torch.IntTensor,
            reverse: torch.ByteTensor,
            for_training: torch.ByteTensor,
            train_stage: torch.IntTensor,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """
    Compute the loss of the objective selected by ``train_stage``.

    Stage 0 trains the sentiment classifier, stage 1 the shared-domain
    discriminator and stage 2 the validity discriminator of the attended
    domain embedding. When ``for_training`` is falsy all three branches run
    in that order and the returned loss is the one assigned last.

    :param task_index: index of the task the batch belongs to; used to pick
        the per-task accuracy metric and as the domain target
    :param reverse: forwarded to the discriminators as ``reverse``
    :param for_training: falsy to run every branch regardless of the stage
    :param train_stage: single-element tensor; the original note lists
        ["share_senti", "share_classify", "share_classify_adversarial",
        "domain_valid", "domain_valid_adversarial"], but only the values
        0, 1 and 2 are handled here
    :param tokens: text-field tensors for the embedder
    :param label: sentiment targets; without them the sentiment branch
        produces no loss
    :return: ``{"loss": loss}``
    :raises ValueError: if no branch produced a loss (previously this
        surfaced as an ``UnboundLocalError``)
    """
    embedded_text = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens).float()
    embed_tokens = self._encoder(embedded_text, mask)
    batch_size = get_batch_size(embed_tokens)
    # bs * (25*4)
    seq_vec = self._seq_vec(embed_tokens, mask)
    # Create every helper tensor on the device the activations live on
    # instead of hard-coding CUDA, so the model also runs on CPU.
    device = seq_vec.device

    # TODO add linear layer
    domain_embeddings = self._domain_embeddings(
        torch.arange(self._de_dim, device=device))

    de_scores = F.softmax(
        self._de_attention(
            seq_vec,
            domain_embeddings.expand(batch_size, *domain_embeddings.size())),
        dim=1)
    # With probability 0.3 perturb the attention scores; the validity
    # discriminator is then trained to label the result 0 instead of 1.
    de_valid = False
    if np.random.rand() < 0.3:
        de_valid = True
        # Sampled on the CPU generator (as before), then moved.
        noise = 0.01 * torch.normal(
            mean=0.5, std=torch.empty(*de_scores.size()).fill_(1.0))
        de_scores = de_scores + noise.to(device)
    domain_embedding = torch.matmul(de_scores, domain_embeddings)
    domain_embedding = self._de_feedforward(domain_embedding)

    stage = int(train_stage)
    evaluating = not for_training
    loss = None

    # train sentiment classify
    if stage == 0 or evaluating:
        de_representation = torch.tanh(torch.add(domain_embedding, seq_vec))
        sentiment_logits = self._sentiment_discriminator(de_representation)
        if label is not None:
            loss = self._loss(sentiment_logits, label)
            self.metrics["{}_stm_acc".format(TASKS_NAME[task_index.cpu()])](
                sentiment_logits, label)

    if stage == 1 or evaluating:
        s_domain_logits = self._s_domain_discriminator(seq_vec, reverse=reverse)
        task_index = task_index.expand(batch_size)
        loss = self._domain_loss(s_domain_logits, task_index)
        self.metrics["s_domain_acc"](s_domain_logits, task_index)

    if stage == 2 or evaluating:
        valid_logits = self._valid_discriminator(domain_embedding,
                                                 reverse=reverse)
        if de_valid:
            valid_label = torch.zeros(batch_size, device=device)
        else:
            valid_label = torch.ones(batch_size, device=device)
        if self._label_smoothing is not None and self._label_smoothing > 0.0:
            loss = sequence_cross_entropy_with_logits(
                valid_logits,
                valid_label.unsqueeze(0),
                torch.tensor(1, device=device).unsqueeze(0),
                average="token",
                label_smoothing=self._label_smoothing)
        else:
            # One-hot target over the two validity classes. scatter_ needs
            # an int64 index on the same device as the tensor it fills.
            one_hot = torch.zeros(2, device=device).scatter_(
                0, valid_label.long(), 1.0)
            loss = self._valid_loss(valid_logits, one_hot)
        self.metrics["valid_acc"](valid_logits, valid_label)

    # TODO add orthogonal loss
    if loss is None:
        raise ValueError(
            "no loss was computed: train_stage={} requires a label or is "
            "not one of the handled stages (0, 1, 2)".format(stage))
    output_dict = {"loss": loss}
    return output_dict
def forward(
        self,  # type: ignore
        task_index: torch.IntTensor,
        reverse: torch.ByteTensor,
        epoch_trained: torch.IntTensor,
        for_training: torch.ByteTensor,
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    """
    Run the encoder and the sentiment, private-domain, shared-domain and
    validity heads.

    :param task_index: index of the task the batch belongs to; expanded to
        ``batch_size`` and used as the target of both domain losses
    :param reverse: forwarded to the encoder and, as ``reverse``, to the
        shared-domain and validity discriminators
    :param epoch_trained: forwarded to the encoder
    :param for_training: forwarded to the encoder
    :param tokens: text-field tensors, forwarded to the encoder
    :param label: sentiment targets; losses and metrics are skipped when it
        is ``None``
    :return: a dict with ``logits`` (sentiment logits) and, when ``label`` is
        given, ``stm_loss``, ``p_d_loss``, ``s_d_loss``, ``valid_loss`` and
        the total ``loss``. The total is sentiment + private-domain +
        0.005 * shared-domain; ``valid_loss`` is reported but not part of it.
    """
    embeddeds = self._encoder(task_index, tokens, epoch_trained,
                              self._valid_discriminator, reverse,
                              for_training)
    batch_size = get_batch_size(embeddeds["embedded_text"])

    sentiment_logits = self._sentiment_discriminator(
        embeddeds["embedded_text"])
    p_domain_logits = self._p_domain_discriminator(
        embeddeds["private_embedding"])

    # TODO set reverse = true
    s_domain_logits = self._s_domain_discriminator(
        embeddeds["share_embedding"], reverse=reverse)
    # TODO set reverse = true
    # TODO use share_embedding instead of domain_embedding
    valid_logits = self._valid_discriminator(embeddeds["domain_embedding"],
                                             reverse=reverse)

    valid_label = embeddeds['valid']

    logits = [
        sentiment_logits, p_domain_logits, s_domain_logits, valid_logits
    ]

    output_dict = {'logits': sentiment_logits}
    if label is None:
        # Targets only exist when a label is given, so there is nothing to
        # score the heads against.
        return output_dict

    loss = self._loss(sentiment_logits, label)
    # One domain target per instance in the batch.
    task_index = task_index.expand(batch_size)
    targets = [label, task_index, task_index, valid_label]
    p_domain_loss = self._domain_loss(p_domain_logits, task_index)
    s_domain_loss = self._domain_loss(s_domain_logits, task_index)
    # logging uses %-style placeholders; a "{}" placeholder makes the record
    # fail to format when it is emitted.
    logger.info(
        "Share domain logits standard variation is %s",
        torch.mean(torch.std(F.softmax(s_domain_logits, dim=-1), dim=-1)))

    # Keep helper tensors on the device of the logits instead of forcing
    # CUDA, so the model also runs on CPU.
    device = valid_logits.device
    if self._label_smoothing is not None and self._label_smoothing > 0.0:
        valid_loss = sequence_cross_entropy_with_logits(
            valid_logits,
            valid_label.unsqueeze(0).to(device),
            torch.tensor(1, device=device).unsqueeze(0),
            average="token",
            label_smoothing=self._label_smoothing)
    else:
        # One-hot target over the two validity classes. scatter_ needs a
        # 1-D int64 index on the same device as the tensor it fills.
        one_hot = torch.zeros(2, device=device).scatter_(
            0, valid_label.to(device).long().reshape(-1), 1.0)
        valid_loss = self._valid_loss(valid_logits, one_hot)

    output_dict['stm_loss'] = loss
    output_dict['p_d_loss'] = p_domain_loss
    output_dict['s_d_loss'] = s_domain_loss
    output_dict['valid_loss'] = valid_loss
    # TODO add share domain logits std loss
    # The "+ 0.005 * valid_loss" and shared-logit std terms are disabled.
    output_dict['loss'] = loss + p_domain_loss + 0.005 * s_domain_loss

    # Metrics are paired with heads by position: the iteration order of
    # self.metrics must match the order of `logits` / `targets`.
    for (metric, logit, target) in zip(self.metrics.values(), logits,
                                       targets):
        metric(logit, target)

    return output_dict