def forward(self, data_batch):
    input_var = data_batch['input']
    input_mask = data_batch['input_mask']
    target_var = data_batch['target']
    target_mask = data_batch['target_mask']
    target_max_len = data_batch['target_max_len']
    batch_size = input_var.shape[1]

    # Embed the input tokens and encode the sequence.
    encoder_input = self.embedding(input_var.t())
    encoder_hidden, encoder_cell = self.encoder(encoder_input, input_mask)

    # Decode up to the longest target during training; otherwise decode
    # up to the global maximum sequence length.
    if self.training:
        decoder_max_len = target_max_len
    else:
        decoder_max_len = args.max_seq_len

    _, decoder_logits, _ = self.decoder(encoder_hidden, encoder_cell,
                                        max_len=decoder_max_len)

    loss_max_len = min(decoder_logits.shape[0], target_var.shape[0])
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc \
        = seq_cross_entropy_loss(decoder_logits, target_var, target_mask, loss_max_len)

    return loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc
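# A minimal sketch (not part of the original repo) of how this seq2seq
# forward might be driven from a training loop. The tensor shapes are
# assumptions inferred from the code above: time-major (max_len, batch)
# LongTensors for 'input'/'target' with float masks of matching shape.
def _example_seq2seq_step(model, optimizer, device='cpu'):
    import torch
    batch = {
        'input': torch.randint(0, 10, (6, 32), device=device),   # (L_in, B) token ids
        'input_mask': torch.ones(6, 32, device=device),          # 1 = real token, 0 = pad
        'target': torch.randint(0, 10, (8, 32), device=device),  # (L_tgt, B) token ids
        'target_mask': torch.ones(8, 32, device=device),
        'target_max_len': 8,
    }
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc = model(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return print_losses, tok_acc, seq_acc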
def forward(self, data_batch, msg_tau=1.0):
    input_var = data_batch['input']
    input_mask = data_batch['input_mask']
    target_var = data_batch['target']
    target_mask = data_batch['target_mask']
    target_max_len = data_batch['target_max_len']

    # Speaker turns the input sequence into a message (Gumbel-softmax
    # sampled with temperature msg_tau).
    message, msg_logits, msg_mask = self.speaker(input_var, input_mask, tau=msg_tau)

    # sum_t sum_v softmax(logits) * logits: an entropy-style statistic of the
    # speaker's message distribution (equals negative entropy when msg_logits
    # are log-probabilities).
    spk_entropy = (F.softmax(msg_logits, dim=2) * msg_logits).sum(dim=2).sum(dim=0)

    if self.training:
        # Sum over time of the logits selected by the sampled one-hot message
        # (the message log-probability when msg_logits are log-probs).
        log_msg_prob = torch.sum(msg_logits * message, dim=2).sum(dim=0)
    else:
        log_msg_prob = 0.

    # Listener reconstructs the target sequence from the message.
    seq_logits = self.listener(message, msg_mask, target_max_len)

    if self.training:
        # Same selection for the target sequence under the listener's outputs.
        target_one_hot = F.one_hot(
            target_var, num_classes=self.voc_size).to(seq_logits.dtype)
        log_seq_prob = torch.sum(target_one_hot * seq_logits, dim=2).sum(dim=0)
    else:
        log_seq_prob = 0.

    loss_max_len = min(seq_logits.shape[0], target_var.shape[0])
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc \
        = seq_cross_entropy_loss(seq_logits, target_var, target_mask, loss_max_len)

    if self.training and self.msg_mode == 'SCST':
        # Self-critical (SCST) baseline: re-run both agents in eval mode and
        # use that rollout's per-sequence correctness as the reward baseline.
        self.speaker.eval()
        self.listener.eval()
        msg, _, msg_mask = self.speaker(input_var, input_mask)
        s_logits = self.listener(msg, msg_mask)
        loss_max_len = min(s_logits.shape[0], target_var.shape[0])
        baseline = seq_cross_entropy_loss(
            s_logits, target_var, target_mask, loss_max_len)[3]
        self.speaker.train()
        self.listener.train()
    else:
        baseline = 0.

    return loss, log_msg_prob, log_seq_prob, baseline, print_losses, \
        seq_correct, tok_acc, seq_acc, seq_logits, spk_entropy
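# A hedged sketch of how the quantities returned above could feed a
# REINFORCE/SCST-style speaker update. The reward definition, the entropy
# coefficient, and the way the terms are combined are assumptions for
# illustration, not the repo's exact training recipe.
def _example_scst_update(model, data_batch, optimizer, entropy_coef=0.01):
    loss, log_msg_prob, log_seq_prob, baseline, print_losses, \
        seq_correct, tok_acc, seq_acc, seq_logits, spk_entropy = model(data_batch)

    if model.msg_mode == 'SCST':
        # Reward the sampled message by how much better the listener did
        # than the eval-mode (baseline) rollout.
        advantage = (seq_correct - baseline).detach()
        speaker_loss = -(advantage * log_msg_prob).mean()
    else:
        speaker_loss = 0.

    # spk_entropy is sum p * logit (negative entropy if the logits are
    # log-probs), so adding it with a positive coefficient encourages a
    # higher-entropy speaker distribution -- an assumption of this sketch.
    total = loss + speaker_loss + entropy_coef * spk_entropy.mean()
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return print_losses, tok_acc, seq_acc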
def _speaker_learn_(model, data_batch, target, tgt_mask):
    # Supervised step for the speaker alone: fit its message logits to a
    # given target message for the correct images.
    input_var = data_batch['correct']['imgs']
    message, msg_logits, _ = model.speaker(input_var)
    loss_max_len = min(message.shape[0], target.shape[0])
    loss, _, _, _, tok_acc, seq_acc \
        = seq_cross_entropy_loss(msg_logits, target, tgt_mask, loss_max_len)
    return loss, tok_acc, seq_acc
def forward(self, data_batch, msg_tau=1.0):
    correct_data = data_batch['correct']
    input_var = correct_data['imgs']
    target_var = correct_data['message']
    target_mask = correct_data['msg_mask']

    message, msg_logits, _ = self.speaker(input_var, tau=msg_tau)
    log_msg_prob = torch.sum(msg_logits, dim=1)

    loss_max_len = min(message.shape[0], target_var.shape[0])
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc \
        = seq_cross_entropy_loss(msg_logits, target_var, target_mask, loss_max_len)

    return loss, print_losses, seq_correct, tok_acc, seq_acc, message, log_msg_prob
def forward(self, data_batch):
    input_var = data_batch['input']
    msg_mask = data_batch['input_mask']
    target_var = data_batch['target']
    target_mask = data_batch['target_mask']
    target_max_len = data_batch['target_max_len']

    # Treat the input tokens as a fixed one-hot message and let the listener
    # alone decode the target sequence from it.
    msg = F.one_hot(input_var, num_classes=args.msg_vocsize).to(torch.float32)
    msg_mask = msg_mask.to(torch.float32).unsqueeze(1)

    listener_outputs = self.listener(msg, msg_mask, target_max_len)

    loss_max_len = min(listener_outputs.shape[0], target_var.shape[0])
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc \
        = seq_cross_entropy_loss(listener_outputs, target_var, target_mask, loss_max_len)

    return loss, 0., print_losses, tok_correct, seq_correct, tok_acc, seq_acc
def forward(self, data_batch):
    input_var = data_batch['input']
    input_mask = data_batch['input_mask']
    target_var = data_batch['target']
    target_mask = data_batch['target_mask']
    target_max_len = data_batch['target_max_len']
    batch_size = input_var.shape[1]

    # Speaker-only forward: score the generated message logits against the
    # reference target sequence.
    message, msg_logits, msg_mask = self.speaker(input_var, input_mask)
    log_msg_prob = torch.sum(msg_logits, dim=1)

    loss_max_len = min(message.shape[0], target_var.shape[0])
    loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc \
        = seq_cross_entropy_loss(msg_logits, target_var, target_mask, loss_max_len)

    return loss, print_losses, seq_correct, tok_acc, seq_acc, message, log_msg_prob