예제 #1
0
    def forward(self, data_batch):
        """Run the encoder-decoder pipeline on one batch.

        Args:
            data_batch: dict with keys 'input', 'input_mask', 'target',
                'target_mask', 'target_max_len'.

        Returns:
            Tuple (loss, print_losses, tok_correct, seq_correct, tok_acc,
            seq_acc) as produced by ``seq_cross_entropy_loss``.
        """
        input_var = data_batch['input']
        input_mask = data_batch['input_mask']
        target_var = data_batch['target']
        target_mask = data_batch['target_mask']
        target_max_len = data_batch['target_max_len']

        # NOTE(review): the transpose suggests input_var is (seq, batch) and
        # the embedding expects (batch, seq) — confirm against the data loader.
        encoder_input = self.embedding(input_var.t())
        encoder_hidden, encoder_cell = self.encoder(encoder_input, input_mask)

        # Train: decode only as far as the gold target. Eval: the target
        # length is unknown, so decode up to the global cap.
        decoder_max_len = target_max_len if self.training else args.max_seq_len

        _, decoder_logits, _ = self.decoder(encoder_hidden,
                                            encoder_cell,
                                            max_len=decoder_max_len)

        # Clip to the shorter of prediction/target so the loss tensors align.
        loss_max_len = min(decoder_logits.shape[0], target_var.shape[0])
        loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc\
            = seq_cross_entropy_loss(decoder_logits, target_var, target_mask, loss_max_len)

        return loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc
예제 #2
0
    def forward(self, data_batch, msg_tau=1.0):
        """One speaker-listener game step: the speaker emits a message from
        the input; the listener reconstructs the target from that message.

        Args:
            data_batch: dict with keys 'input', 'input_mask', 'target',
                'target_mask', 'target_max_len'.
            msg_tau: temperature forwarded to the speaker (default 1.0).

        Returns:
            (loss, log_msg_prob, log_seq_prob, baseline, print_losses,
             seq_correct, tok_acc, seq_acc, seq_logits, spk_entropy)
        """
        input_var = data_batch['input']
        input_mask = data_batch['input_mask']
        target_var = data_batch['target']
        target_mask = data_batch['target_mask']
        target_max_len = data_batch['target_max_len']

        message, msg_logits, msg_mask = self.speaker(input_var,
                                                     input_mask,
                                                     tau=msg_tau)

        # NOTE(review): this is sum_t sum_v softmax(logits) * logits. If
        # msg_logits are already log-probabilities this equals the NEGATIVE
        # entropy per batch element — confirm the sign convention at the
        # caller before interpreting as an entropy bonus/penalty.
        spk_entropy = (F.softmax(msg_logits, dim=2) *
                       msg_logits).sum(dim=2).sum(dim=0)
        if self.training:
            # Select the logit of the sampled token at each step (message is
            # presumably one-hot over the vocab dim) and sum over time —
            # the message's score under the speaker. Only needed for the
            # policy-gradient update, so skipped at eval time.
            log_msg_prob = torch.sum(msg_logits * message, dim=2).sum(dim=0)
        else:
            log_msg_prob = 0.

        seq_logits = self.listener(message, msg_mask, target_max_len)
        if self.training:
            # Same trick for the listener: score of the gold target sequence
            # under the listener's output distribution.
            target_one_hot = F.one_hot(
                target_var, num_classes=self.voc_size).to(seq_logits.dtype)
            log_seq_prob = torch.sum(target_one_hot * seq_logits,
                                     dim=2).sum(dim=0)
        else:
            log_seq_prob = 0.

        # Clip to the shorter of prediction/target so the loss tensors align.
        loss_max_len = min(seq_logits.shape[0], target_var.shape[0])

        loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc\
            = seq_cross_entropy_loss(seq_logits, target_var, target_mask, loss_max_len)

        if self.training and self.msg_mode == 'SCST':
            # Self-critical baseline: a second, eval-mode rollout of the same
            # pair, scored identically. Index [3] picks seq_correct from the
            # seq_cross_entropy_loss tuple (see the unpack above).
            self.speaker.eval()
            self.listener.eval()
            msg, _, msg_mask = self.speaker(input_var, input_mask)
            s_logits = self.listener(msg, msg_mask)
            loss_max_len = min(s_logits.shape[0], target_var.shape[0])
            baseline = seq_cross_entropy_loss(s_logits, target_var,
                                              target_mask, loss_max_len)[3]
            # Restore training mode so the surrounding loop keeps training.
            self.speaker.train()
            self.listener.train()
        else:
            baseline = 0.

        return loss, log_msg_prob, log_seq_prob, baseline, print_losses, \
                seq_correct, tok_acc, seq_acc, seq_logits, spk_entropy
예제 #3
0
def _speaker_learn_(model, data_batch, target, tgt_mask):
    """Supervised speaker step: score the speaker's message logits for the
    'correct' images against a fixed target message.

    Returns:
        (loss, tok_acc, seq_acc) from ``seq_cross_entropy_loss``.
    """
    imgs = data_batch['correct']['imgs']

    msg, logits, _ = model.speaker(imgs)

    # Align loss length to the shorter of the emitted/target sequences.
    max_len = min(msg.shape[0], target.shape[0])
    stats = seq_cross_entropy_loss(logits, target, tgt_mask, max_len)
    loss, _, _, _, tok_acc, seq_acc = stats

    return loss, tok_acc, seq_acc
예제 #4
0
    def forward(self, data_batch, msg_tau=1.0):
        """Generate a message for the 'correct' image and score it against
        the stored reference message.

        Args:
            data_batch: dict whose 'correct' entry holds 'imgs', 'message',
                and 'msg_mask'.
            msg_tau: temperature forwarded to the speaker (default 1.0).

        Returns:
            (loss, print_losses, seq_correct, tok_acc, seq_acc, message,
             log_msg_prob).
        """
        batch = data_batch['correct']
        imgs = batch['imgs']
        ref_msg = batch['message']
        ref_mask = batch['msg_mask']

        message, msg_logits, _ = self.speaker(imgs, tau=msg_tau)

        # Aggregate logits along dim 1 as the message's (unnormalized) score.
        log_msg_prob = msg_logits.sum(dim=1)

        # Align loss length to the shorter of emitted/reference sequences.
        max_len = min(message.shape[0], ref_msg.shape[0])
        stats = seq_cross_entropy_loss(msg_logits, ref_msg, ref_mask, max_len)
        loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc = stats

        return loss, print_losses, seq_correct, tok_acc, seq_acc, message, log_msg_prob
예제 #5
0
    def forward(self, data_batch):
        """Listener-only pass: feed the ground-truth tokens as a one-hot
        'message' (no speaker involved) and score the reconstruction.

        Returns:
            (loss, 0., print_losses, tok_correct, seq_correct, tok_acc,
             seq_acc) — the 0. keeps the arity consistent with variants
            that also return a message log-prob.
        """
        target_var = data_batch['target']
        target_mask = data_batch['target_mask']
        target_max_len = data_batch['target_max_len']

        # Lift raw token ids to a float one-hot sequence for the listener.
        msg = F.one_hot(data_batch['input'],
                        num_classes=args.msg_vocsize).to(torch.float32)
        msg_mask = data_batch['input_mask'].to(torch.float32).unsqueeze(1)

        listener_outputs = self.listener(msg, msg_mask, target_max_len)

        # Align loss length to the shorter of prediction/target sequences.
        max_len = min(listener_outputs.shape[0], target_var.shape[0])
        loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc = \
            seq_cross_entropy_loss(listener_outputs, target_var,
                                   target_mask, max_len)

        return loss, 0., print_losses, tok_correct, seq_correct, tok_acc, seq_acc
예제 #6
0
    def forward(self, data_batch):
        """Speaker pass on masked input, scored against the target message.

        Args:
            data_batch: dict with keys 'input', 'input_mask', 'target',
                'target_mask', 'target_max_len'.

        Returns:
            (loss, print_losses, seq_correct, tok_acc, seq_acc, message,
             log_msg_prob).
        """
        input_var = data_batch['input']
        input_mask = data_batch['input_mask']
        target_var = data_batch['target']
        target_mask = data_batch['target_mask']
        target_max_len = data_batch['target_max_len']

        message, msg_logits, msg_mask = self.speaker(input_var, input_mask)

        # Aggregate logits along dim 1 as the message's (unnormalized) score.
        log_msg_prob = torch.sum(msg_logits, dim=1)

        # Clip to the shorter of message/target so the loss tensors align.
        loss_max_len = min(message.shape[0], target_var.shape[0])
        loss, print_losses, tok_correct, seq_correct, tok_acc, seq_acc\
            = seq_cross_entropy_loss(msg_logits, target_var, target_mask, loss_max_len)

        return loss, print_losses, seq_correct, tok_acc, seq_acc, message, log_msg_prob