def loss(self, preds, lengths):
    # Accumulate the loss over every prediction offset t in `preds`.
    loss = 0
    for t, t_preds in preds.items():
        # Positions that no longer exist after shifting by t are masked out.
        mask = length_to_mask(lengths - t)
        out = torch.stack(t_preds, dim=-1)
        out = F.log_softmax(out, dim=-1)
        # Keep the log-probability at index 0 of the stacked dimension and zero out padding.
        out = out[..., 0] * mask
        loss += -out.mean()
    return loss
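Every example on this page calls a `length_to_mask` helper that is not shown. Below is a minimal sketch of what such a helper typically looks like, assuming `lengths` is a 1D LongTensor and the signature `(lengths, max_len=None, dtype=None)` inferred from the call sites below; it is an illustration, not the original implementation.

import torch

def length_to_mask(lengths, max_len=None, dtype=None):
    """Turn a 1D tensor of sequence lengths into a [batch_size x max_len] padding mask.

    Sketch only: signature inferred from the call sites in these examples.
    Positions < length are 1 (valid), the remaining positions are 0 (padding).
    """
    max_len = max_len or int(lengths.max().item())
    mask = torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]
    return mask if dtype is None else mask.to(dtype=dtype)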
def validate(valid_data, model, epoch, device, logger, summary_writer):
    loading_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    entropy_meter = AverageMeter()
    n_entropy_meter = AverageMeter()

    model.eval()
    start = time.time()
    with torch.no_grad():
        for batch in valid_data:
            tokens, length = batch.text
            labels = batch.label
            mask = length_to_mask(length)
            loading_time_meter.update(time.time() - start)

            pred_labels, ce_loss, rewards, actions, actions_log_prob, entropy, normalized_entropy = \
                model(tokens, mask, labels)
            entropy = entropy.mean()
            normalized_entropy = normalized_entropy.mean()

            accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()
            n = mask.shape[0]
            accuracy_meter.update(accuracy.item(), n)
            ce_loss_meter.update(ce_loss.item(), n)
            entropy_meter.update(entropy.item(), n)
            n_entropy_meter.update(normalized_entropy.item(), n)
            batch_time_meter.update(time.time() - start)
            start = time.time()

    logger.info(
        f"Valid: epoch: {epoch} ce_loss: {ce_loss_meter.avg:.4f} accuracy: {accuracy_meter.avg:.4f} "
        f"entropy: {entropy_meter.avg:.4f} n_entropy: {n_entropy_meter.avg:.4f} "
        f"loading_time: {loading_time_meter.avg:.4f} batch_time: {batch_time_meter.avg:.4f}"
    )

    summary_writer["valid"].add_scalar(tag="ce",
                                       scalar_value=ce_loss_meter.avg,
                                       global_step=global_step)
    summary_writer["valid"].add_scalar(tag="accuracy",
                                       scalar_value=accuracy_meter.avg,
                                       global_step=global_step)
    summary_writer["valid"].add_scalar(tag="n_entropy",
                                       scalar_value=n_entropy_meter.avg,
                                       global_step=global_step)

    model.train()
    return accuracy_meter.avg
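The metric tracking in these loops relies on an `AverageMeter` helper with `update(value, n)`, an `avg` attribute, and (in the training loop further down) a `value` attribute. A minimal sketch consistent with that usage, offered as an assumption rather than the original class:

class AverageMeter:
    """Running average of a scalar metric. Inferred from usage; not the original class."""

    def __init__(self):
        self.value = 0.0   # most recent value
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count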
Example #3
    def forward(self, ngram_ids, ngram_lengths):
        """
        :param ngram_ids:       shape is [batch_size x max_seq_length]
        :param ngram_lengths:   shape is [batch_size]
        """
        # shape is [batch_size x max_seq_length x embedding_size]
        ngrams_embedded = self.ngram_embeddings(ngram_ids)

        # shape is [batch_size x max_seq_length]
        mask = length_to_mask(ngram_lengths,
                              max_len=ngram_ids.shape[1],
                              dtype=torch.float)
        ngrams_embedded = ngrams_embedded * mask.unsqueeze(-1)

        bag_of_ngrams = torch.sum(ngrams_embedded,
                                  dim=1) / ngram_lengths.float().unsqueeze(-1)
        return bag_of_ngrams
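A quick, self-contained shape check of the masked mean-pooling performed above, using dummy tensors and an `nn.Embedding` standing in for `self.ngram_embeddings` (all sizes are hypothetical):

import torch
import torch.nn as nn

batch_size, max_seq_length, embedding_size, vocab = 2, 5, 8, 100
ngram_ids = torch.randint(0, vocab, (batch_size, max_seq_length))
ngram_lengths = torch.tensor([5, 3])

embeddings = nn.Embedding(vocab, embedding_size)   # stands in for self.ngram_embeddings
ngrams_embedded = embeddings(ngram_ids)            # [2 x 5 x 8]
mask = (torch.arange(max_seq_length)[None, :] < ngram_lengths[:, None]).float()
bag_of_ngrams = (ngrams_embedded * mask.unsqueeze(-1)).sum(dim=1) / ngram_lengths.float().unsqueeze(-1)
assert bag_of_ngrams.shape == (batch_size, embedding_size)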
Example #4
def forward(self, decoder_outputs, encoder_outputs, source_lengths):
    """
    Return attention scores.
    args:
    decoder_outputs: 1xBx* (a single decoder time step)
    encoder_outputs: TxBx*
    returns:
    attention scores: Bx1xT
    """
    # projected_encoder_outputs = self.score(encoder_outputs) \
    #                                 .permute(1, 2, 0) # batch first
    projected_encoder_outputs = encoder_outputs.permute(1, 2, 0)  # B x H x T
    decoder_outputs = decoder_outputs.transpose(0, 1)             # B x 1 x H
    scores = decoder_outputs.bmm(projected_encoder_outputs)       # B x 1 x T
    scores = scores.squeeze(1)                                    # B x T
    # source_lengths[0] is used as the max length, so the batch is assumed
    # to be sorted by decreasing length.
    mask = length_to_mask(source_lengths, source_lengths[0])
    if scores.is_cuda:
        mask = mask.cuda()
    # Mask out padded positions before the softmax. `mask == 0` works for both
    # byte and bool masks, unlike `1 - mask` on recent PyTorch versions.
    scores = scores.masked_fill(mask == 0, float('-inf'))
    scores = F.softmax(scores, dim=1)
    return scores.unsqueeze(1)
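A self-contained shape check of the dot-product attention above, using random tensors for a single decoder step; the masking mirrors the module, and all sizes are hypothetical:

import torch
import torch.nn.functional as F

T_enc, B, H = 6, 3, 4
encoder_outputs = torch.randn(T_enc, B, H)
decoder_outputs = torch.randn(1, B, H)          # one decoder time step
source_lengths = torch.tensor([6, 4, 2])        # sorted, so source_lengths[0] is the max

scores = decoder_outputs.transpose(0, 1).bmm(encoder_outputs.permute(1, 2, 0)).squeeze(1)  # B x T_enc
mask = torch.arange(T_enc)[None, :] < source_lengths[:, None]
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=1).unsqueeze(1)    # B x 1 x T_enc
assert attn.shape == (B, 1, T_enc)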
def test(test_data, model, device, logger):
    loading_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    entropy_meter = AverageMeter()
    n_entropy_meter = AverageMeter()

    model.eval()
    start = time.time()
    with torch.no_grad():
        for batch in test_data:
            tokens, length = batch.text
            labels = batch.label
            mask = length_to_mask(length)
            loading_time_meter.update(time.time() - start)

            pred_labels, ce_loss, rewards, actions, actions_log_prob, entropy, normalized_entropy = \
                model(tokens, mask, labels)
            entropy = entropy.mean()
            normalized_entropy = normalized_entropy.mean()

            accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()
            n = mask.shape[0]
            accuracy_meter.update(accuracy.item(), n)
            ce_loss_meter.update(ce_loss.item(), n)
            entropy_meter.update(entropy.item(), n)
            n_entropy_meter.update(normalized_entropy.item(), n)
            batch_time_meter.update(time.time() - start)
            start = time.time()

    logger.info(
        f"Test: ce_loss: {ce_loss_meter.avg:.4f} accuracy: {accuracy_meter.avg:.4f} "
        f"entropy: {entropy_meter.avg:.4f} n_entropy: {n_entropy_meter.avg:.4f} "
        f"loading_time: {loading_time_meter.avg:.4f} batch_time: {batch_time_meter.avg:.4f}"
    )
    logger.info("done")

    return accuracy_meter.avg
Example #6
def run(args,
        source_vocab,
        target_vocab,
        encoder,
        decoder,
        encoder_optim,
        decoder_optim,
        batch,
        mode,
        sample_prob=1):
    batch_id, (source, source_lens, target, target_lens) = batch
    if mode == "train":
        encoder.train()
        decoder.train()
        encoder_optim.zero_grad()
        decoder_optim.zero_grad()
    elif mode == "validate" or mode == "greedy":
        encoder.eval()
        decoder.eval()
    if args.cuda: source, target = source.cuda(), target.cuda()
    source, target = Variable(source), Variable(target)
    batch_size = source.size()[1]
    encoder_outputs, encoder_last_hidden = encoder(source, source_lens, None)
    max_target_len = max(target_lens)
    decoder_hidden = encoder_last_hidden
    target_slice = Variable(
        torch.zeros(batch_size).fill_(target_vocab.SOS).long())
    decoder_outputs = Variable(
        torch.zeros(args.global_max_target_len, batch_size,
                    target_vocab.vocab_size))  # preallocate
    pred_seq = torch.zeros_like(target.data)
    if args.cuda:
        target_slice = target_slice.cuda()
        decoder_outputs = decoder_outputs.cuda()
        pred_seq = pred_seq.cuda()
    for l in range(max_target_len):
        predictions, decoder_hidden, atten_scores = decoder(
            l, target_slice, encoder_outputs, source_lens, decoder_hidden)
        decoder_outputs[l] = predictions
        pred_words = predictions.data.max(1)[1]

        pred_seq[l] = pred_words
        if mode == "train" or mode == "validate":
            coin = random.random()
            if coin > sample_prob:
                target_slice = Variable(pred_words).long()
            else:
                target_slice = target[l]  # use teacher forcing
        elif mode == "greedy":
            target_slice = Variable(pred_words)  # use teacher forcing
        if args.cuda: target_slice = target_slice.cuda()
        # detach hidden states
        for h in decoder_hidden:
            h.detach_()
    mask = Variable(length_to_mask(target_lens)).transpose(0, 1).float()
    if args.cuda: mask = mask.cuda()

    loss = masked_cross_entropy_loss(decoder_outputs[:max_target_len], target,
                                     mask)
    if mode == "train": loss.backward()

    correct = torch.eq(target.data.float(),
                       pred_seq.float()) * mask.data.byte()
    correct = correct.float().sum()
    total = mask.data.float().sum()
    accuracy = correct / total

    current_lr = encoder_optim.param_groups[0]['lr']
    if mode == "validate" or mode == "greedy":
        if batch_id == 0 and mode == "greedy":
            i = random.randint(0, batch_size - 1)
            print("Given source sequence:\n {}".format(
                source_vocab.to_text(source.data[:source_lens[i], i])))
            print("target sequence is:\n {}".format(
                target_vocab.to_text(target.data[:target_lens[i], i])))
            print("greedily decoded sequence is:\n {}".format(
                target_vocab.to_text(pred_seq[:, i])))
        return correct, total, loss.item()
    elif mode == "train":
        if (batch_id + 1) % args.log_interval == 0:
            i = random.randint(0, batch_size - 1)
            print("Given source sequence:\n {}".format(
                source_vocab.to_text(source.data[:source_lens[i], i])))
            print("target sequence is:\n {}".format(
                target_vocab.to_text(target.data[:target_lens[i], i])))
            print("teacher forcing generated sequence is:\n {}".format(
                target_vocab.to_text(pred_seq[:target_lens[i], i])))

        nn.utils.clip_grad_norm(encoder.parameters(), args.clip_thresh)
        nn.utils.clip_grad_norm(decoder.parameters(), args.clip_thresh)
        encoder_optim.step()
        decoder_optim.step()
        return correct, total, loss.item()
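The `run` function depends on a `masked_cross_entropy_loss` helper that is not shown. Below is a plausible sketch, assuming logits of shape [T x B x vocab], targets of shape [T x B], and a float mask of shape [T x B] as implied by the call site; this is an inference, not the original code.

import torch
import torch.nn.functional as F

def masked_cross_entropy_loss(logits, target, mask):
    """Token-level cross entropy averaged over non-padded positions.

    Sketch only: the shapes [T x B x vocab] / [T x B] / [T x B] are inferred
    from the call in run() above.
    """
    T, B, V = logits.shape
    log_probs = F.log_softmax(logits.reshape(T * B, V), dim=-1)
    nll = -log_probs.gather(1, target.reshape(T * B, 1)).squeeze(1)  # per-token NLL
    nll = nll * mask.reshape(T * B)                                  # zero out padding
    return nll.sum() / mask.sum()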
Example #7
def _get_mask(length):
    return length_to_mask(length, dtype=torch.float)
def train(train_data, valid_data, model, optimizer, lr_scheduler, es, epoch,
          args, logger, summary_writer):
    loading_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    entropy_meter = AverageMeter()
    n_entropy_meter = AverageMeter()
    prob_ratio_meter = AverageMeter()

    device = args.gpu_id
    model.train()
    start = time.time()
    for batch_idx, batch in enumerate(train_data):
        tokens, length = batch.text
        labels = batch.label
        mask = length_to_mask(length)
        loading_time_meter.update(time.time() - start)

        pred_labels, ce_loss, rewards, actions, actions_log_prob, entropy, normalized_entropy = \
            model(tokens, mask, labels)

        ce_loss.backward()
        perform_env_optimizer_step(optimizer, model, args)
        for k in range(args.ppo_updates):
            if k == 0:
                new_normalized_entropy, new_actions_log_prob = normalized_entropy, actions_log_prob
            else:
                new_normalized_entropy, new_actions_log_prob = model.evaluate_actions(
                    tokens, mask, actions)
            prob_ratio = (new_actions_log_prob -
                          actions_log_prob.detach()).exp()
            clamped_prob_ratio = prob_ratio.clamp(1.0 - args.epsilon,
                                                  1.0 + args.epsilon)
            ppo_loss = torch.max(prob_ratio * rewards,
                                 clamped_prob_ratio * rewards).mean()
            loss = ppo_loss - args.entropy_weight * new_normalized_entropy.mean()
            loss.backward()
            perform_policy_optimizer_step(optimizer, model, args)

        entropy = entropy.mean()
        normalized_entropy = normalized_entropy.mean()
        n = mask.shape[0]
        accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()
        accuracy_meter.update(accuracy.item(), n)
        ce_loss_meter.update(ce_loss.item(), n)
        entropy_meter.update(entropy.item(), n)
        n_entropy_meter.update(normalized_entropy.item(), n)
        prob_ratio_meter.update(
            (1.0 - prob_ratio.detach()).abs().mean().item(), n)
        batch_time_meter.update(time.time() - start)

        global global_step
        summary_writer["train"].add_scalar(tag="ce",
                                           scalar_value=ce_loss.item(),
                                           global_step=global_step)
        summary_writer["train"].add_scalar(tag="accuracy",
                                           scalar_value=accuracy.item(),
                                           global_step=global_step)
        summary_writer["train"].add_scalar(
            tag="n_entropy",
            scalar_value=normalized_entropy.item(),
            global_step=global_step)
        summary_writer["train"].add_scalar(tag="prob_ratio",
                                           scalar_value=prob_ratio_meter.value,
                                           global_step=global_step)
        global_step += 1

        if (batch_idx + 1) % (len(train_data) // 3) == 0:
            logger.info(
                f"Train: epoch: {epoch} batch_idx: {batch_idx + 1} ce_loss: {ce_loss_meter.avg:.4f} "
                f"accuracy: {accuracy_meter.avg:.4f} entropy: {entropy_meter.avg:.4f} "
                f"n_entropy: {n_entropy_meter.avg:.4f} loading_time: {loading_time_meter.avg:.4f} "
                f"batch_time: {batch_time_meter.avg:.4f}")
            val_accuracy = validate(valid_data, model, epoch, device, logger,
                                    summary_writer)
            lr_scheduler["env"].step(val_accuracy)
            lr_scheduler["policy"].step(val_accuracy)
            es.step(val_accuracy)
            global best_model_path
            if es.is_converged:
                return
            if es.is_improved():
                logger.info("saving model...")
                best_model_path = f"{args.model_dir}/{epoch}-{batch_idx}.mdl"
                torch.save(
                    {
                        "epoch": epoch,
                        "batch_idx": batch_idx,
                        "state_dict": model.state_dict()
                    }, best_model_path)
            model.train()
        start = time.time()
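train() and validate() share a module-level `global_step` counter, a `best_model_path` variable, and a `summary_writer` dict with "train" and "valid" writers. One plausible way that state might be set up, using `torch.utils.tensorboard` (the log directories are hypothetical):

from torch.utils.tensorboard import SummaryWriter

# Module-level state assumed by train() and validate(); paths are hypothetical.
global_step = 0
best_model_path = None
summary_writer = {
    "train": SummaryWriter(log_dir="runs/train"),
    "valid": SummaryWriter(log_dir="runs/valid"),
}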