import os

import torch

import mutils  # project-local helper utilities (AverageMeter, accuracy, ...)


def train_lrp(train_loader, model, criterion, optimizer, epoch, ss_prob,
              word_map, print_freq, grad_clip):
    model.train()
    losses = mutils.AverageMeter()  # loss (per decoded word)
    top1accs = mutils.AverageMeter()  # top-1 accuracy
    rev_word_map = {v: k for k, v in word_map.items()}
    for i, (imgs, caps, all_caps, caplens) in enumerate(train_loader):
        imgs = imgs.cuda()
        caps = caps.cuda()
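        # Forward pass that returns both the standard per-step word scores and
        # an LRP-weighted variant of them, used for the extra loss term below.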
        predictions, weighted_predictions, max_length = model.forwardlrp_context(
            imgs, caps, caplens, rev_word_map)
        scores = predictions.contiguous().view(-1, predictions.size(2))
        targets = caps[:, 1:max_length + 1]
        targets = targets.contiguous().view(
            predictions.size(0) * predictions.size(1))
        loss_standard = criterion(scores, targets)

        # Apply the same cross-entropy to the LRP-weighted scores, then
        # optimize both losses jointly.
        weighted_scores = weighted_predictions.contiguous().view(
            -1, weighted_predictions.size(2))
        loss_lrp = criterion(weighted_scores, targets)
        loss = loss_lrp + loss_standard
        optimizer.zero_grad()
        loss.backward()
        if grad_clip:
            mutils.clip_gradient(optimizer, grad_clip=grad_clip)
        optimizer.step()
        top1 = mutils.accuracy(scores, targets, 1)
        losses.update(loss.item(), sum(caplens).float())
        top1accs.update(top1, sum(caplens).float())
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                      epoch, i, len(train_loader), loss=losses, top1=top1accs))
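

# Both trainers rely on a few small helpers from `mutils` that are not shown
# in this snippet. The sketches below are assumptions reconstructed from how
# they are called (AverageMeter exposing .val/.avg/.update(val, n),
# accuracy(scores, targets, k) returning a top-k percentage, and
# clip_gradient clamping gradients element-wise); the real `mutils`
# implementations may differ.
class AverageMeter:
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0.0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(scores, targets, k):
    """Top-k accuracy, in percent, over flattened word predictions."""
    batch_size = targets.size(0)
    _, ind = scores.topk(k, dim=1, largest=True, sorted=True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    return correct.float().sum().item() * (100.0 / batch_size)


def clip_gradient(optimizer, grad_clip):
    """Clamp every parameter gradient to [-grad_clip, grad_clip] in place."""
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
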
def trainciderlrp(train_loader, model, criterion, optimizer, epoch, ss_prob,
                  word_map, print_freq, grad_clip):
    model.train()
    losses = mutils.AverageMeter()  # loss (per decoded word)
    rewards = mutils.AverageMeter()
    rev_word_map = {v: k for k, v in word_map.items()}
    for i, (imgs, caps, all_caps, caplens) in enumerate(train_loader):
        imgs = imgs.cuda()
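        # Greedy decoding (no gradients) provides the self-critical baseline.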
        model.eval()
        with torch.no_grad():
            greedy_res, _, _ = model.sample(imgs, word_map, caplens)
        model.train()
        gen_result, sample_logprobs, max_length = model.sample_lrp(
            imgs,
            rev_word_map,
            word_map,
            caplens,
            opt={'sample_method': 'sample'})
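        # Per-token reward: metric score of the sampled caption minus the
        # greedy baseline (CIDEr only, given the weights passed below).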
        reward = mutils.get_self_critical_reward(greedy_res,
                                                 all_caps,
                                                 gen_result,
                                                 word_map,
                                                 cider_reward_weight=1.,
                                                 bleu_reward_weight=0)
        reward = torch.from_numpy(reward).float().cuda()
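        # Policy-gradient loss: reward-weighted negative log-probability of
        # the sampled words (see the RewardCriterion sketch further below).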
        loss = criterion(sample_logprobs, gen_result.data, reward)
        optimizer.zero_grad()
        loss.backward()
        if grad_clip:
            mutils.clip_gradient(optimizer, grad_clip=grad_clip)
        optimizer.step()
        losses.update(loss.item(), sum(caplens - 2).float())
        rewards.update(reward[:, 0].mean().item(), float(len(reward)))
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Reward {rewards.val:.3f} ({rewards.avg:.3f})\t'.format(
                      epoch,
                      i,
                      len(train_loader),
                      loss=losses,
                      rewards=rewards))
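            # Checkpoint every print_freq batches; the output path below is
            # specific to the original author's environment.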
            state = {
                'epoch': epoch,
                'loss': losses,
                'state_dict': model.state_dict(),
                'batch': i
            }
            filename = f'lrpcider_checkpoint_epoch{epoch}_batch_{i}.pth'
            torch.save(
                state,
                os.path.join(
                    '/home/sunjiamei/work/ImageCaptioning/ImgCaptioningPytorch/output/gridTD/vgg16/flickr30k/lrpciderfinetune/',
                    filename))
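

# The `criterion` used by trainciderlrp is called as
# criterion(sample_logprobs, gen_result.data, reward). A minimal sketch of
# such a self-critical policy-gradient loss is given below; it is modeled on
# common SCST implementations and is an assumption, not necessarily the exact
# criterion used with this code.
class RewardCriterion(torch.nn.Module):
    """Negative advantage-weighted log-likelihood of the sampled words,
    masked after the first end-of-sequence/padding token."""

    def forward(self, sample_logprobs, seq, reward):
        # sample_logprobs, seq, reward: all of shape (batch_size, seq_len)
        mask = (seq > 0).float()
        # Shift right so the step that emits the end token still contributes.
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        loss = -sample_logprobs * reward * mask
        return loss.sum() / mask.sum()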