Example #1
import torch
from pathlib import Path
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu


# Imports above are shared by the examples in this section.
def train(epoch, encoder, decoder, enc_optim, dec_optim, cross_entropy_loss,
          train_loader, word_dict, lambda_kld, log_interval):
    encoder.train()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    for batch_idx, (img, cap, mod) in enumerate(train_loader):
        img, cap, mod = Variable(img).cuda(), Variable(cap).cuda(), Variable(
            mod).cuda()

        enc_optim.zero_grad()
        dec_optim.zero_grad()

        enc_features = encoder(img, mod)
        preds, alphas, h = decoder(enc_features, cap)

        # drop the <start> token, then flatten predictions and targets
        # over the time dimension for the token-level loss
        targets = cap[:, 1:]

        targets = pack_padded_sequence(targets,
                                       [len(tar) - 1 for tar in targets],
                                       batch_first=True)[0]
        preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds],
                                     batch_first=True)[0]

        # Captioning cross-entropy loss
        captioning_loss = cross_entropy_loss(preds, targets)

        # Doubly stochastic attention regularization (Show, Attend and Tell):
        # encourage the attention weights to sum to 1 across timesteps
        att_regularization = ((1 - alphas.sum(1))**2).mean()

        # Total loss
        loss = captioning_loss + lambda_kld * att_regularization

        loss.backward()
        enc_optim.step()
        dec_optim.step()

        total_caption_length = calculate_caption_lengths(word_dict, cap)
        acc1 = accuracy(preds, targets, 1)
        acc5 = accuracy(preds, targets, 5)
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print('Train Batch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx,
                      len(train_loader),
                      loss=losses,
                      top1=top1,
                      top5=top5))
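These snippets lean on a few helpers they do not define. Below is a minimal sketch of AverageMeter, accuracy, and calculate_caption_lengths, modeled on the common PyTorch ImageNet-example utilities; the originals in the source repositories may differ (the '<eos>' key in particular is an assumption).

# Hypothetical reconstructions of the undefined helpers used above.
class AverageMeter:
    """Tracks the most recent value and a running weighted average."""
    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(preds, targets, k):
    """Top-k accuracy (in percent) over packed predictions of shape (N, vocab)."""
    batch_size = targets.size(0)
    _, idx = preds.topk(k, dim=1)
    correct = idx.eq(targets.view(-1, 1).expand_as(idx))
    return correct.float().sum().item() * (100.0 / batch_size)


def calculate_caption_lengths(word_dict, captions):
    """Counts tokens that are not <start>/<eos>/<pad>, summed over the batch."""
    specials = {word_dict['<start>'], word_dict['<eos>'], word_dict['<pad>']}
    return sum(1 for caption in captions.tolist()
               for token in caption if token not in specials)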
Example #2
def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader,
          word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (imgs, captions) in enumerate(data_loader):
        imgs, captions = Variable(imgs).cuda(), Variable(captions).cuda()
        img_features = encoder(imgs)
        optimizer.zero_grad()
        preds, alphas = decoder(img_features, captions)
        targets = captions[:, 1:]

        targets = pack_padded_sequence(targets,
                                       [len(tar) - 1 for tar in targets],
                                       batch_first=True)[0]
        preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds],
                                     batch_first=True)[0]

        att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()

        loss = cross_entropy_loss(preds, targets)
        loss += att_regularization
        loss.backward()
        optimizer.step()

        total_caption_length = calculate_caption_lengths(word_dict, captions)
        acc1 = accuracy(preds, targets, 1)
        acc5 = accuracy(preds, targets, 5)
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print('Train Batch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx,
                      len(data_loader),
                      loss=losses,
                      top1=top1,
                      top5=top5))
        torch.cuda.empty_cache()
    writer.add_scalar('train_loss', losses.avg, epoch)
    writer.add_scalar('train_top1_acc', top1.avg, epoch)
    writer.add_scalar('train_top5_acc', top5.avg, epoch)
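Since this variant logs epoch-level averages to TensorBoard, a minimal driver loop might look like the sketch below; the constructor calls and num_epochs are assumed to exist elsewhere.

from torch.utils.tensorboard import SummaryWriter

# Hypothetical driver; encoder/decoder/optimizer/loaders are built elsewhere.
writer = SummaryWriter(log_dir='runs/captioning')
for epoch in range(1, num_epochs + 1):
    train(epoch, encoder, decoder, optimizer, cross_entropy_loss,
          train_loader, word_dict, alpha_c, log_interval, writer)
    validate(epoch, encoder, decoder, cross_entropy_loss,
             val_loader, word_dict, alpha_c, log_interval, writer)
writer.close()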
Example #3
def validate(epoch, encoder, decoder, cross_entropy_loss, data_loader,
             word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # used for calculating BLEU scores
    references = []
    hypotheses = []
    with torch.no_grad():
        for batch_idx, (imgs, captions,
                        all_captions) in enumerate(data_loader):
            imgs, captions = Variable(imgs).cuda(), Variable(captions).cuda()
            img_features = encoder(imgs)
            preds, alphas = decoder(img_features, captions)
            targets = captions[:, 1:]

            targets = pack_padded_sequence(targets,
                                           [len(tar) - 1 for tar in targets],
                                           batch_first=True)[0]
            packed_preds = pack_padded_sequence(
                preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

            att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()

            loss = cross_entropy_loss(packed_preds, targets)
            loss += att_regularization

            total_caption_length = calculate_caption_lengths(
                word_dict, captions)
            acc1 = accuracy(packed_preds, targets, 1)
            acc5 = accuracy(packed_preds, targets, 5)
            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            for cap_set in all_captions.tolist():
                caps = []
                for caption in cap_set:
                    cap = [
                        word_idx for word_idx in caption
                        if word_idx != word_dict['<start>']
                        and word_idx != word_dict['<pad>']
                    ]
                    caps.append(cap)
                references.append(caps)

            word_idxs = torch.max(preds, dim=2)[1]
            for idxs in word_idxs.tolist():
                hypotheses.append([
                    idx for idx in idxs if idx != word_dict['<start>']
                    and idx != word_dict['<pad>']
                ])

            if batch_idx % log_interval == 0:
                print('Validation Batch: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                          batch_idx,
                          len(data_loader),
                          loss=losses,
                          top1=top1,
                          top5=top5))
        writer.add_scalar('val_loss', losses.avg, epoch)
        writer.add_scalar('val_top1_acc', top1.avg, epoch)
        writer.add_scalar('val_top5_acc', top5.avg, epoch)

        bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(references,
                             hypotheses,
                             weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(references, hypotheses)

        writer.add_scalar('val_bleu1', bleu_1, epoch)
        writer.add_scalar('val_bleu2', bleu_2, epoch)
        writer.add_scalar('val_bleu3', bleu_3, epoch)
        writer.add_scalar('val_bleu4', bleu_4, epoch)
        print('Validation Epoch: {}\t'
              'BLEU-1 ({})\t'
              'BLEU-2 ({})\t'
              'BLEU-3 ({})\t'
              'BLEU-4 ({})\t'.format(epoch, bleu_1, bleu_2, bleu_3, bleu_4))
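NLTK's corpus_bleu pairs each hypothesis with a list of reference token sequences, which is exactly the structure the loop above builds. A toy illustration with made-up token indices:

# Toy example of the references/hypotheses structure expected by corpus_bleu.
references = [[[4, 8, 15, 3], [4, 16, 23, 3]],  # two references for image 1
              [[42, 7, 3]]]                     # one reference for image 2
hypotheses = [[4, 8, 15, 3],                    # prediction for image 1
              [42, 9, 3]]                       # prediction for image 2
print(corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0)))  # BLEU-1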
Example #4
def test(epoch, encoder, decoder, cross_entropy_loss, data_loader, word_dict,
         alpha_c, log_interval, writer, saver):
    encoder.eval()
    decoder.eval()

    # used for calculating BLEU scores
    references = []
    hypotheses = []
    with torch.no_grad():
        for batch_idx, (imgs, captions,
                        all_captions) in enumerate(data_loader):
            imgs, captions = Variable(imgs).cuda(), Variable(captions).cuda()
            img_features = encoder(imgs)
            preds, alphas = decoder(img_features, captions)
            targets = captions[:, 1:]

            targets = pack_padded_sequence(targets,
                                           [len(tar) - 1 for tar in targets],
                                           batch_first=True)[0]
            packed_preds = pack_padded_sequence(
                preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

            total_caption_length = calculate_caption_lengths(
                word_dict, captions)

            for cap_set in all_captions.tolist():
                caps = []
                for caption in cap_set:
                    cap = [
                        word_idx for word_idx in caption
                        if word_idx != word_dict['<start>']
                        and word_idx != word_dict['<pad>']
                    ]
                    caps.append(cap)
                references.append(caps)

            word_idxs = torch.max(preds, dim=2)[1]
            for idxs in word_idxs.tolist():
                hypotheses.append([
                    idx for idx in idxs if idx != word_dict['<start>']
                    and idx != word_dict['<pad>']
                ])

            if batch_idx % log_interval == 0:
                print('[%d/%d]' % (batch_idx, len(data_loader)))

        bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(references,
                             hypotheses,
                             weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(references, hypotheses)

        writer.add_scalar('test_bleu1', bleu_1, epoch)
        writer.add_scalar('test_bleu2', bleu_2, epoch)
        writer.add_scalar('test_bleu3', bleu_3, epoch)
        writer.add_scalar('test_bleu4', bleu_4, epoch)
        saver.save_print_msg('Test Epoch: {}\t'
                             'BLEU-1 ({})\t'
                             'BLEU-2 ({})\t'
                             'BLEU-3 ({})\t'
                             'BLEU-4 ({})\t'.format(epoch, bleu_1, bleu_2,
                                                    bleu_3, bleu_4))
Example #5
def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader,
          word_dict, alpha_c, log_interval, writer, saver, val_loader, args):
    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (imgs, captions) in enumerate(data_loader):
        imgs, captions = Variable(imgs).cuda(), Variable(captions).cuda()
        img_features = encoder(imgs)
        optimizer.zero_grad()
        preds, alphas = decoder(img_features, captions)

        targets = captions[:, 1:]

        targets = pack_padded_sequence(targets,
                                       [len(tar) - 1 for tar in targets],
                                       batch_first=True)[0]
        preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds],
                                     batch_first=True)[0]

        att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()
        entropy_reg = att_entropy(alphas)
        # negated so that minimizing the loss maximizes temporal entropy,
        # i.e. attention is spread evenly and each image is treated equally
        time_reg = -time_entropy(alphas)

        loss = cross_entropy_loss(preds, targets)
        loss += att_regularization
        loss += args.lambda_ient * entropy_reg
        loss += args.lambda_tent * time_reg
        loss.backward()
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 10)
        optimizer.step()

        total_caption_length = calculate_caption_lengths(word_dict, captions)
        acc1 = accuracy(preds, targets, 1)
        acc5 = accuracy(preds, targets, 5)
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print(
                '(Epoch: %03d, Iter: %06d, Time: %s), Loss: %.2f, I_ent: %.2f, T_ent: %.2f, top1: %.2f, top5: %.2f'
                % (
                    epoch,
                    batch_idx,
                    saver.used_time,
                    losses.avg,
                    entropy_reg,
                    time_reg,
                    top1.val,
                    top5.val,
                ))
    writer.add_scalar('train_loss', losses.avg, epoch)
    writer.add_scalar('train_top1_acc', top1.avg, epoch)
    writer.add_scalar('train_top5_acc', top5.avg, epoch)
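att_entropy and time_entropy are not defined in this snippet. One plausible reading, assuming alphas has shape (batch, timesteps, regions) and sums to 1 over the last dimension, is sketched below; the originals may differ.

# Hypothetical reconstructions of the undefined attention regularizers.
def att_entropy(alphas, eps=1e-8):
    """Mean per-step entropy of attention over regions (low = peaky attention)."""
    return -(alphas * (alphas + eps).log()).sum(-1).mean()


def time_entropy(alphas, eps=1e-8):
    """Entropy of attention mass over timesteps; maximizing it spreads
    attention evenly across steps (each image treated equally)."""
    mass = alphas.sum(-1)                          # (batch, timesteps)
    p = mass / (mass.sum(-1, keepdim=True) + eps)  # normalize over timesteps
    return -(p * (p + eps).log()).sum(-1).mean()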
Example #6
            if debug:
                print(f"preds = {preds.shape} alphas = {alphas.shape}")
        
            targets = captions[:, 1:] # removing the start token 

            targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
            packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

            att_regularization = args.alpha_c * ((1 - alphas.sum(1))**2).mean()

            loss = cross_entropy_loss(packed_preds, targets)
            loss += att_regularization
            loss.backward()
            optimizer.step()

            total_caption_length = calculate_caption_lengths(word_dict, captions)
            losses.update(loss.item(), total_caption_length)
            if batch_idx % args.log_interval == 0:
                print('Train Batch: [{0}/{1}]\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                        batch_idx, len(train_loader), loss=losses))
            if debug:
                break
        
        # x = get_scores(model, train_loader)
        y = get_scores(model, val_loader, word_dict, idx_dict, device, debug)
        z = get_scores(model, test_loader, word_dict, idx_dict, device, debug)
        torch.save(model.state_dict(), Path(args.result_dir) / f"{epoch}.pth")
        print(f"epoch = {epoch} Val : {y} Test : {z}")