Example #1
def train(args):
    """Trains model for args.nepochs (default = 30)"""

    t_start = time.time()
    train_data = coco_loader(args.coco_root,
                             split='train',
                             ncap_per_img=args.ncap_per_img)
    print('[DEBUG] Loading train data ... %f secs' % (time.time() - t_start))

    train_data_loader = DataLoader(dataset=train_data, num_workers=args.nthreads,\
      batch_size=args.batchsize, shuffle=True, drop_last=True)

    #Load pre-trained imgcnn
    model_imgcnn = Vgg16Feats()
    model_imgcnn.cuda()
    model_imgcnn.train(True)

    #Convcap model
    model_convcap = convcap(train_data.numwords,
                            args.num_layers,
                            is_attention=args.attention)
    model_convcap.cuda()
    model_convcap.train(True)

    optimizer = optim.RMSprop(model_convcap.parameters(),
                              lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.lr_step_size,
                                    gamma=.1)
    img_optimizer = None

    batchsize = args.batchsize
    ncap_per_img = args.ncap_per_img
    batchsize_cap = batchsize * ncap_per_img
    max_tokens = train_data.max_tokens
    nbatches = np.int_(np.floor((len(train_data.ids) * 1.) / batchsize))
    bestscore = .0

    for epoch in range(args.epochs):
        loss_train = 0.

        if (epoch == args.finetune_after):
            img_optimizer = optim.RMSprop(model_imgcnn.parameters(), lr=1e-5)
            img_scheduler = lr_scheduler.StepLR(img_optimizer,
                                                step_size=args.lr_step_size,
                                                gamma=.1)

        scheduler.step()
        if (img_optimizer):
            img_scheduler.step()

        #One epoch of train
        batch_idx = 0
        # imgs - the batch of images
        # captions - all reference sentences of the image (a single caption per image when not training)
        # wordclass - matrix: each row corresponds to one caption of the image,
        #                     each column to a word position in the sentence,
        #                     and each value is the index of that word in the dictionary
        # mask - vector containing 1s up to the last word that should be attended to
        # img_id - unique id of the image (unused here)
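        # Assumed per-batch shapes, inferred from the .view() calls below
        # (not verified against coco_loader):
        #   imgs      -> (batchsize, 3, 224, 224)
        #   wordclass -> (batchsize, ncap_per_img, max_tokens)
        #   mask      -> (batchsize, ncap_per_img, max_tokens)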
        for batch_idx, (imgs, captions, wordclass, mask, _) in \
          tqdm(enumerate(train_data_loader), total=nbatches):

            imgs = imgs.view(batchsize, 3, 224, 224)
            wordclass = wordclass.view(batchsize_cap, max_tokens)
            mask = mask.view(batchsize_cap, max_tokens)

            imgs_v = Variable(imgs).cuda()
            wordclass_v = Variable(wordclass).cuda()

            optimizer.zero_grad()
            if (img_optimizer):
                img_optimizer.zero_grad()

            # imgsfeats: convolutional feature maps from VGG-16
            # imgsfc7: fc7-layer features from VGG-16
            imgsfeats, imgsfc7 = model_imgcnn(imgs_v)
            imgsfeats, imgsfc7 = repeat_img_per_cap(imgsfeats, imgsfc7,
                                                    ncap_per_img)
            _, _, feat_h, feat_w = imgsfeats.size()

            if (args.attention == True):
                wordact, attn = model_convcap(imgsfeats, imgsfc7, wordclass_v)
                attn = attn.view(batchsize_cap, max_tokens, feat_h, feat_w)
            else:
                wordact, _ = model_convcap(imgsfeats, imgsfc7, wordclass_v)

            wordact = wordact[:, :, :-1]
            wordclass_v = wordclass_v[:, 1:]
            mask = mask[:, 1:].contiguous()

            wordact_t = wordact.permute(0, 2, 1).contiguous().view(\
              batchsize_cap*(max_tokens-1), -1)
            wordclass_t = wordclass_v.contiguous().view(\
              batchsize_cap*(max_tokens-1), 1)

            maskids = torch.nonzero(mask.view(-1)).numpy().reshape(-1)

            if (args.attention == True):
                #Cross-entropy loss and attention loss of Show, Attend and Tell
                loss = F.cross_entropy(wordact_t[maskids, ...], \
                  wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])) \
                  + (torch.sum(torch.pow(1. - torch.sum(attn, 1), 2)))\
                  /(batchsize_cap*feat_h*feat_w)
            else:
                loss = F.cross_entropy(wordact_t[maskids, ...], \
                  wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))
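            # The added attention term is the doubly-stochastic regularizer from
            # "Show, Attend and Tell": for each spatial location it penalizes
            # (1 - sum over time steps of the attention weight)^2, encouraging
            # the model to attend to every image region roughly once per caption.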

            loss_train = loss_train + loss.item()

            loss.backward()

            optimizer.step()
            if (img_optimizer):
                img_optimizer.step()

        loss_train = (loss_train * 1.) / (batch_idx + 1)
        print('[DEBUG] Training epoch %d has loss %f' % (epoch, loss_train))

        modelfn = osp.join(args.model_dir, 'model.pth')

        if (img_optimizer):
            img_optimizer_dict = img_optimizer.state_dict()
        else:
            img_optimizer_dict = None

        torch.save(
            {
                'epoch': epoch,
                'state_dict': model_convcap.state_dict(),
                'img_state_dict': model_imgcnn.state_dict(),
                'optimizer': optimizer.state_dict(),
                'img_optimizer': img_optimizer_dict,
            }, modelfn)

        #Run on validation and obtain score
        scores = test(args,
                      'val',
                      model_convcap=model_convcap,
                      model_imgcnn=model_imgcnn)
        score = scores[0][args.score_select]

        if (score > bestscore):
            bestscore = score
            print('[DEBUG] Saving model at epoch %d with %s score of %f'\
              % (epoch, args.score_select, score))
            bestmodelfn = osp.join(args.model_dir, 'bestmodel.pth')
            os.system('cp %s %s' % (modelfn, bestmodelfn))
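
# repeat_img_per_cap is called in the training loop above but is not defined in
# this file. A minimal sketch of what it presumably does (tile each image's
# features once per caption so the batch dimension becomes
# batchsize * ncap_per_img); this is an assumption, not the original helper:
def repeat_img_per_cap(imgsfeats, imgsfc7, ncap_per_img):
    b, f_dim, f_h, f_w = imgsfeats.size()
    imgsfeats = imgsfeats.unsqueeze(1).expand(
        b, ncap_per_img, f_dim, f_h, f_w).contiguous().view(
            b * ncap_per_img, f_dim, f_h, f_w)
    b, f_dim = imgsfc7.size()
    imgsfc7 = imgsfc7.unsqueeze(1).expand(
        b, ncap_per_img, f_dim).contiguous().view(b * ncap_per_img, f_dim)
    return imgsfeats, imgsfc7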
Example #2
def captionme(args, modelfn):
  """Caption images in args.image_dir using checkpoint modelfn"""

  imgs, imgs_fn = load_images(args.image_dir)

  #For trained model released with the code
  batchsize = 1
  max_tokens = 15
  num_layers = 3 
  is_attention = True 
  worddict_tmp = pickle.load(open('data/wordlist.p', 'rb'))
  wordlist = [l for l in iter(worddict_tmp.keys()) if l != '</S>']
  wordlist = ['EOS'] + sorted(wordlist)
  numwords = len(wordlist)

  model_imgcnn = Vgg16Feats()
  model_imgcnn.cuda() 

  model_convcap = convcap(numwords, num_layers, is_attention = is_attention)
  model_convcap.cuda()

  print('[DEBUG] Loading checkpoint %s' % modelfn)
  checkpoint = torch.load(modelfn)
  model_convcap.load_state_dict(checkpoint['state_dict'])
  model_imgcnn.load_state_dict(checkpoint['img_state_dict'])

  model_imgcnn.train(False) 
  model_convcap.train(False)

  pred_captions = []
  for batch_idx, (img_fn) in \
    tqdm(enumerate(imgs_fn), total=len(imgs_fn)):
    
    img = imgs[batch_idx, ...].view(batchsize, 3, 224, 224)

    img_v = Variable(img.cuda())
    imgfeats, imgfc7 = model_imgcnn(img_v)

    b, f_dim, f_h, f_w = imgfeats.size()
    imgfeats = imgfeats.unsqueeze(1).expand(\
      b, args.beam_size, f_dim, f_h, f_w)
    imgfeats = imgfeats.contiguous().view(\
      b*args.beam_size, f_dim, f_h, f_w)

    b, f_dim = imgfc7.size()
    imgfc7 = imgfc7.unsqueeze(1).expand(\
      b, args.beam_size, f_dim)
    imgfc7 = imgfc7.contiguous().view(\
      b*args.beam_size, f_dim)

    beam_searcher = beamsearch(args.beam_size, batchsize, max_tokens)
  
    wordclass_feed = np.zeros((args.beam_size*batchsize, max_tokens), dtype='int64')
    wordclass_feed[:,0] = wordlist.index('<S>') 
    outcaps = np.empty((batchsize, 0)).tolist()

    for j in range(max_tokens-1):
      wordclass = Variable(torch.from_numpy(wordclass_feed)).cuda()

      wordact, attn = model_convcap(imgfeats, imgfc7, wordclass)
      wordact = wordact[:,:,:-1]
      wordact_j = wordact[..., j]

      beam_indices, wordclass_indices = beam_searcher.expand_beam(wordact_j)  

      if len(beam_indices) == 0 or j == (max_tokens-2): # Beam search is over.
        generated_captions = beam_searcher.get_results()
        for k in range(batchsize):
            g = generated_captions[:, k]
            outcaps[k] = [wordlist[x] for x in g]
      else:
        wordclass_feed = wordclass_feed[beam_indices]
        imgfc7 = imgfc7.index_select(0, Variable(torch.cuda.LongTensor(beam_indices)))
        imgfeats = imgfeats.index_select(0, Variable(torch.cuda.LongTensor(beam_indices)))
        for i, wordclass_idx in enumerate(wordclass_indices):
          wordclass_feed[i, j+1] = wordclass_idx

    for j in range(batchsize):
      num_words = len(outcaps[j]) 
      if 'EOS' in outcaps[j]:
        num_words = outcaps[j].index('EOS')
      outcap = ' '.join(outcaps[j][:num_words])
      pred_captions.append({'img_fn': img_fn, 'caption': outcap})

  return pred_captions
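
# A hedged usage sketch for captionme. The argument names mirror the fields the
# function reads (args.image_dir, args.beam_size); the checkpoint path and CLI
# defaults below are illustrative, not taken from the original script:
if __name__ == '__main__':
  import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument('--image_dir', type=str, default='images/')
  parser.add_argument('--beam_size', type=int, default=3)
  args = parser.parse_args()
  captions = captionme(args, 'data/bestmodel.pth')  # hypothetical checkpoint path
  for c in captions:
    print('%s: %s' % (c['img_fn'], c['caption']))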
Example #3
def test(args, split, modelfn=None, model_convcap=None, model_imgcnn=None):
    """Runs test on split=val/test with checkpoint file modelfn or loaded model_*"""

    t_start = time.time()
    data = coco_loader(args.coco_root, split=split, ncap_per_img=1)
    print('[DEBUG] Loading %s data ... %f secs' %
          (split, time.time() - t_start))

    data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
      batch_size=args.batchsize, shuffle=False, drop_last=True)

    batchsize = args.batchsize
    max_tokens = data.max_tokens
    num_batches = np.int_(np.floor((len(data.ids) * 1.) / batchsize))
    print('[DEBUG] Running inference on %s with %d batches' %
          (split, num_batches))

    if (modelfn is not None):
        model_imgcnn = Vgg16Feats()
        model_imgcnn.cuda()

        model_convcap = convcap(data.numwords,
                                args.num_layers,
                                is_attention=args.attention)
        model_convcap.cuda()

        print('[DEBUG] Loading checkpoint %s' % modelfn)
        checkpoint = torch.load(modelfn)
        model_convcap.load_state_dict(checkpoint['state_dict'])
        model_imgcnn.load_state_dict(checkpoint['img_state_dict'])
    else:
        # model_convcap and model_imgcnn were passed in by the caller; use them as-is
        pass

    model_imgcnn.train(False)
    model_convcap.train(False)

    pred_captions = []
    #Test epoch
    for batch_idx, (imgs, _, _, _, img_ids) in \
      tqdm(enumerate(data_loader), total=num_batches):

        imgs = imgs.view(batchsize, 3, 224, 224)

        imgs_v = Variable(imgs.cuda())
        imgsfeats, imgsfc7 = model_imgcnn(imgs_v)
        _, featdim, feat_h, feat_w = imgsfeats.size()

        wordclass_feed = np.zeros((batchsize, max_tokens), dtype='int64')
        wordclass_feed[:, 0] = data.wordlist.index('<S>')

        outcaps = np.empty((batchsize, 0)).tolist()

        for j in range(max_tokens - 1):
            wordclass = Variable(torch.from_numpy(wordclass_feed)).cuda()

            wordact, _ = model_convcap(imgsfeats, imgsfc7, wordclass)

            wordact = wordact[:, :, :-1]
            wordact_t = wordact.permute(0, 2, 1).contiguous().view(
                batchsize * (max_tokens - 1), -1)

            wordprobs = F.softmax(wordact_t, dim=1).cpu().data.numpy()
            wordids = np.argmax(wordprobs, axis=1)
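            # wordact_t is laid out caption-major: the logits for batch item k at
            # time step j sit in row j + k * (max_tokens - 1), which is the index
            # used below to read out the argmax word ids.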

            for k in range(batchsize):
                word = data.wordlist[wordids[j + k * (max_tokens - 1)]]
                outcaps[k].append(word)
                if (j < max_tokens - 1):
                    wordclass_feed[k,
                                   j + 1] = wordids[j + k * (max_tokens - 1)]

        for j in range(batchsize):
            num_words = len(outcaps[j])
            if 'EOS' in outcaps[j]:
                num_words = outcaps[j].index('EOS')
            outcap = ' '.join(outcaps[j][:num_words])
            pred_captions.append({'image_id': img_ids[j], 'caption': outcap})

    scores = language_eval(pred_captions, args.model_dir, split)

    model_imgcnn.train(True)
    model_convcap.train(True)

    return scores
Example #4
def test_beam(args, split, modelfn=None):
    """Sample generation with beam-search"""

    t_start = time.time()
    data = coco_loader(args.coco_root, split=split, ncap_per_img=1)
    print('[DEBUG] Loading %s data ... %f secs' %
          (split, time.time() - t_start))

    data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
      batch_size=args.batchsize, shuffle=False, drop_last=True)

    batchsize = args.batchsize
    max_tokens = data.max_tokens
    num_batches = np.int_(np.floor((len(data.ids) * 1.) / batchsize))
    print('[DEBUG] Running test (w/ beam search) on %d batches' % num_batches)

    model_imgcnn = Vgg16Feats()
    model_imgcnn.cuda()

    model_convcap = convcap(data.numwords,
                            args.num_layers,
                            is_attention=args.attention)
    model_convcap.cuda()

    print('[DEBUG] Loading checkpoint %s' % modelfn)
    checkpoint = torch.load(modelfn)
    model_convcap.load_state_dict(checkpoint['state_dict'])
    model_imgcnn.load_state_dict(checkpoint['img_state_dict'])

    model_imgcnn.train(False)
    model_convcap.train(False)

    pred_captions = []
    for batch_idx, (imgs, _, _, _, img_ids) in \
      tqdm(enumerate(data_loader), total=num_batches):

        imgs = imgs.view(batchsize, 3, 224, 224)

        imgs_v = Variable(imgs.cuda())
        imgsfeats, imgsfc7 = model_imgcnn(imgs_v)

        b, f_dim, f_h, f_w = imgsfeats.size()
        imgsfeats = imgsfeats.unsqueeze(1).expand(\
          b, args.beam_size, f_dim, f_h, f_w)
        imgsfeats = imgsfeats.contiguous().view(\
          b*args.beam_size, f_dim, f_h, f_w)

        beam_searcher = beamsearch(args.beam_size, batchsize, max_tokens)

        wordclass_feed = np.zeros((args.beam_size * batchsize, max_tokens),
                                  dtype='int64')
        wordclass_feed[:, 0] = data.wordlist.index('<S>')
        imgsfc7 = repeat_img(args, imgsfc7)
        outcaps = np.empty((batchsize, 0)).tolist()

        for j in range(max_tokens - 1):
            wordclass = Variable(torch.from_numpy(wordclass_feed)).cuda()

            wordact, _ = model_convcap(imgsfeats, imgsfc7, wordclass)
            wordact = wordact[:, :, :-1]
            wordact_j = wordact[..., j]

            beam_indices, wordclass_indices = beam_searcher.expand_beam(
                wordact_j)

            if len(beam_indices) == 0 or j == (max_tokens -
                                               2):  # Beam search is over.
                generated_captions = beam_searcher.get_results()
                for k in range(batchsize):
                    g = generated_captions[:, k]
                    outcaps[k] = [data.wordlist[x] for x in g]
            else:
                wordclass_feed = wordclass_feed[beam_indices]
                imgsfc7 = imgsfc7.index_select(
                    0, Variable(torch.cuda.LongTensor(beam_indices)))
                imgsfeats = imgsfeats.index_select(
                    0, Variable(torch.cuda.LongTensor(beam_indices)))
                for i, wordclass_idx in enumerate(wordclass_indices):
                    wordclass_feed[i, j + 1] = wordclass_idx

        for j in range(batchsize):
            num_words = len(outcaps[j])
            if 'EOS' in outcaps[j]:
                num_words = outcaps[j].index('EOS')
            outcap = ' '.join(outcaps[j][:num_words])
            pred_captions.append({'image_id': img_ids[j], 'caption': outcap})

    scores = language_eval(pred_captions, args.model_dir, split)

    model_imgcnn.train(True)
    model_convcap.train(True)

    return scores
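
# repeat_img is used above to tile the fc7 features beam_size times per image,
# but it is not defined in this file. A plausible sketch, assuming it mirrors
# the expand/view pattern applied to imgsfeats above:
def repeat_img(args, imgsfc7):
    b, f_dim = imgsfc7.size()
    imgsfc7 = imgsfc7.unsqueeze(1).expand(b, args.beam_size, f_dim)
    return imgsfc7.contiguous().view(b * args.beam_size, f_dim)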
Example #5
def eval_split(model, crit, loader, eval_kwargs={}):
    """Runs evaluation on eval_kwargs['split'] and returns (loss, predictions, lang_stats)"""

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'test')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)

    # Make sure in the evaluation mode
    model.eval()
    with torch.no_grad():
        model_imgcnn = Vgg16Feats()
        model_imgcnn.cuda()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    while True:

        data = loader.get_batch(split)
        n = n + loader.batch_size

        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [
                torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp
            ]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        # Keep only one feature per image, in case samples were duplicated
        # (e.g., one copy per caption)

        tmp = [
            data['fc_feats'][np.arange(loader.batch_size) * 1],
            data['att_feats'][np.arange(loader.batch_size) * 1],
            data['att_masks'][np.arange(loader.batch_size) *
                              1] if data['att_masks'] is not None else None
        ]
        tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, att_masks = tmp
        # forward the model to also get generated samples for each image
        num = 6
        length = 30
        max_tokens = num * length
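        # Decoding below is paragraph-style: num (=6) sentences of length (=30)
        # tokens are generated back to back (max_tokens = 180), and index 8667 is
        # presumably the hard-coded start-of-sentence token for this vocabulary,
        # re-inserted every `length` steps.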
        with torch.no_grad():
            wordclass_feed = np.zeros((fc_feats.size(0), max_tokens),
                                      dtype='int64')
            outcaps = np.empty((fc_feats.size(0), 0)).tolist()
            wordclass_feed[:, 0] = 8667
            #sent = 0
            for j in range(max_tokens - 1):
                if j > 0 and j % length == 0:
                    wordclass_feed[:, j] = 8667
                wordclass_feed_used = np.reshape(wordclass_feed,
                                                 (fc_feats.size(0), num, length))
                wordclass = Variable(
                    torch.from_numpy(wordclass_feed_used)).cuda()
                wordact, _ = model(fc_feats, att_feats, wordclass, length, num)
                wordact = wordact[:, :, :-1]
                wordact = wordact.transpose(2, 1)
                wordact_t = wordact.contiguous().view(
                    fc_feats.size(0) * (max_tokens - 1), -1)
                wordprobs = F.softmax(wordact_t, dim=1).cpu().data.numpy()
                wordids = np.argmax(wordprobs, axis=1)

                for k in range(fc_feats.size(0)):
                    word = wordids[j + k * (max_tokens - 1)]
                    outcaps[k].append(word)
                    if (j < max_tokens - 1):
                        wordclass_feed[k, j + 1] = wordids[j + k *
                                                           (max_tokens - 1)]
            seq = torch.tensor(outcaps)
        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                print('\n'.join([
                    utils.decode_sequence(loader.get_vocab(),
                                          _['seq'].unsqueeze(0))[0]
                    for _ in model.done_beams[i]
                ]))
                print('--' * 10)
        #print (prob)
        prob = 1
        sents = utils.decode_sequence(loader.get_vocab(), seq, prob)

        for k, sent in enumerate(sents):
            sent = sent.replace('<start>', '').replace(' .', '.').replace(
                'UNK', '').replace('<pause>',
                                   '').replace('  ', ' ').replace('..', '')
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(
                    eval_kwargs['image_root'],
                    data['infos'][k]['file_path']) + '" vis/imgs/img' + str(
                        len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None

    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, eval_kwargs['id'],
                                   split)

    # Switch back to training mode
    model.train()
    return loss_sum / loss_evals, predictions, lang_stats
Example #6
def train(args):
    """Trains model for args.nepochs (default = 30)"""

    t_start = time.time()
    train_data = coco_loader(args.coco_root,
                             split='train',
                             ncap_per_img=args.ncap_per_img)
    print('[DEBUG] Loading train data ... %f secs' % (time.time() - t_start))

    train_data_loader = DataLoader(dataset=train_data, num_workers=args.nthreads,\
      batch_size=args.batchsize, shuffle=True, drop_last=True)

    lang_model = Seq2Seq(train_data.numwords)
    lang_model = lang_model.cuda()
    lang_model.load_state_dict(
        torch.load('log_model/bestmodel.pth')['lang_state_dict'])
    lang_model.train()
    #Load pre-trained imgcnn
    model_imgcnn = Vgg16Feats()
    model_imgcnn.cuda()
    model_imgcnn.train(True)
    model_imgcnn.load_state_dict(
        torch.load('log_reg/bestmodel.pth')['img_state_dict'])
    #Convcap model
    model_convcap = convcap(train_data.numwords,
                            args.num_layers,
                            is_attention=args.attention)
    model_convcap.cuda()
    model_convcap.load_state_dict(
        torch.load('log_reg/bestmodel.pth')['state_dict'])
    model_convcap.train(True)

    optimizer = optim.RMSprop(model_convcap.parameters(),
                              lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.lr_step_size,
                                    gamma=.1)
    img_optimizer = None

    batchsize = args.batchsize
    ncap_per_img = args.ncap_per_img
    batchsize_cap = batchsize * ncap_per_img
    max_tokens = train_data.max_tokens
    nbatches = np.int_(np.floor((len(train_data.ids) * 1.) / batchsize))
    bestscore = .0

    for epoch in range(args.epochs):
        loss_train = 0.

        if (epoch == args.finetune_after):
            img_optimizer = optim.RMSprop(model_imgcnn.parameters(), lr=1e-5)
            img_scheduler = lr_scheduler.StepLR(img_optimizer,
                                                step_size=args.lr_step_size,
                                                gamma=.1)

        scheduler.step()
        if (img_optimizer):
            img_scheduler.step()
        it = 0
        #One epoch of train
        for batch_idx, (imgs, captions, wordclass, mask, _) in \
          tqdm(enumerate(train_data_loader), total=nbatches):
            it = it + 1
            imgs = imgs.view(batchsize, 3, 224, 224)
            wordclass = wordclass.view(batchsize_cap, max_tokens).cuda()
            mask = mask.view(batchsize_cap, max_tokens)

            captions = utils.decode_sequence(train_data.wordlist, wordclass,
                                             None)
            captions_all = list(captions)

            imgs_v = Variable(imgs).cuda()
            wordclass_v = Variable(wordclass).cuda()

            optimizer.zero_grad()
            if (img_optimizer):
                img_optimizer.zero_grad()

            imgsfeats, imgsfc7 = model_imgcnn(imgs_v)
            imgsfeats, imgsfc7 = repeat_img_per_cap(imgsfeats, imgsfc7,
                                                    ncap_per_img)
            _, _, feat_h, feat_w = imgsfeats.size()

            if (args.attention == True):
                wordact, attn = model_convcap(imgsfeats, imgsfc7, wordclass_v)
                attn = attn.view(batchsize_cap, max_tokens, feat_h, feat_w)
            else:
                wordact, _ = model_convcap(imgsfeats, imgsfc7, wordclass_v)

            wordact = wordact[:, :, :-1]
            wordclass_v = wordclass_v[:, 1:]
            mask = mask[:, 1:].contiguous()

            wordact_t = wordact.permute(0, 2, 1).contiguous().view(\
              batchsize_cap*(max_tokens-1), -1)
            wordclass_t = wordclass_v.contiguous().view(\
              batchsize_cap*(max_tokens-1), 1)

            maskids = torch.nonzero(mask.view(-1)).numpy().reshape(-1)

            if (args.attention == True):
                #Cross-entropy loss and attention loss of Show, Attend and Tell
                loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
                  wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])) \
                  + (torch.sum(torch.pow(1. - torch.sum(attn, 1), 2)))\
                  /(batchsize_cap*feat_h*feat_w)
            else:
                loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
                  wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))

            wordact = lang_model(wordclass_v.transpose(1, 0),
                                 wordclass_v.transpose(1, 0), imgs)
            wordact = wordact.transpose(1, 0)[:, :-1, :]
            wordclass_v = wordclass_v[:, 1:]

            wordact_t = wordact.contiguous().view(\
              batchsize_cap*wordact.size(1), -1)

            wordclass_t = wordclass_v.contiguous().view(\
              batchsize_cap*wordclass_v.size(1), 1)

            loss_xe_lang = F.cross_entropy(wordact_t[...], \
                wordclass_t[...].contiguous().view(-1))

            with torch.no_grad():
                outcap, sampled_ids, sample_logprobs, x_all_language, outputs = lang_model.sample(
                    wordclass.transpose(1, 0), wordclass.transpose(1, 0),
                    imgsfeats.transpose(1, 0), train_data.wordlist)

            logprobs_input, _ = model_convcap(imgsfeats, imgsfc7,
                                              sampled_ids.long().cuda())
            log_probs = F.log_softmax(
                logprobs_input.transpose(2, 1)[:, :-1, :], -1)

            sample_logprobs_true = log_probs.gather(
                2, sampled_ids[:, 1:].cuda().long().unsqueeze(2))
            with torch.no_grad():
                reward = get_self_critical_reward(batchsize_cap, lang_model,
                                                  wordclass.transpose(1, 0),
                                                  imgsfeats.transpose(1, 0),
                                                  outcap, captions_all,
                                                  train_data.wordlist, 16)

            loss_rl1 = rl_crit(
                torch.exp(sample_logprobs_true.squeeze()) /
                torch.exp(sample_logprobs[:, 1:]).cuda().detach(),
                sampled_ids[:, 1:].cpu(),
                torch.from_numpy(reward).float().cuda())
            #loss_rl2 = rl_crit(sample_logprobs[:,1:].cuda(), sampled_ids[:, 1:].cpu(), torch.from_numpy(reward).float().cuda())

            loss = 0.0 * loss_xe + loss_rl1  # + loss_xe_lang + loss_rl2
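            # loss_rl1 appears to be a self-critical-style policy-gradient term:
            # its first argument is an importance weight (the convcap probability
            # of each sampled word divided by the language model's sampling
            # probability), weighted by the reward from get_self_critical_reward.
            # With the 0.0 factor on loss_xe above, training here is driven purely
            # by this RL objective; the commented-out terms are alternatives.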

            if it % 500 == 0:
                modelfn = osp.join(args.model_dir, 'model.pth')
                # Save the current weights before evaluating, so the best-model
                # copy below refers to the checkpoint that was just scored.
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model_convcap.state_dict(),
                        'img_state_dict': model_imgcnn.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lang_state_dict': lang_model.state_dict()
                    }, modelfn)

                scores = test(args,
                              'val',
                              model_convcap=model_convcap,
                              model_imgcnn=model_imgcnn)
                score = scores[0][args.score_select]
                if (score > bestscore):
                    bestscore = score
                    print('[DEBUG] Saving model at epoch %d with %s score of %f'\
                       % (epoch, args.score_select, score))
                    bestmodelfn = osp.join(args.model_dir, 'bestmodel.pth')
                    os.system('cp %s %s' % (modelfn, bestmodelfn))

            loss_train = loss_train + loss.item()

            loss.backward()

            optimizer.step()
            if (img_optimizer):
                img_optimizer.step()

        loss_train = (loss_train * 1.) / (batch_idx + 1)
        print('[DEBUG] Training epoch %d has loss %f' % (epoch, loss_train))

        modelfn = osp.join(args.model_dir, 'model.pth')

        if (img_optimizer):
            img_optimizer_dict = img_optimizer.state_dict()
        else:
            img_optimizer_dict = None

        torch.save(
            {
                'epoch': epoch,
                'state_dict': model_convcap.state_dict(),
                'img_state_dict': model_imgcnn.state_dict(),
                'optimizer': optimizer.state_dict(),
                'img_optimizer': img_optimizer_dict,
                'lang_state_dict': lang_model.state_dict()
            }, modelfn)

        #Run on validation and obtain score
        scores = test(args,
                      'val',
                      model_convcap=model_convcap,
                      model_imgcnn=model_imgcnn)
        score = scores[0][args.score_select]

        if (score > bestscore):
            bestscore = score
            print('[DEBUG] Saving model at epoch %d with %s score of %f'\
              % (epoch, args.score_select, score))
            bestmodelfn = osp.join(args.model_dir, 'bestmodel.pth')
            os.system('cp %s %s' % (modelfn, bestmodelfn))
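
# rl_crit and get_self_critical_reward are used above but not defined in this
# file. A minimal sketch of a reward criterion in the spirit of self-critical
# sequence training, assuming it applies a reward-weighted loss to the per-word
# scores (here, probability ratios) with padding masked out. This is an
# illustrative stand-in, not the original rl_crit:
import torch
import torch.nn as nn

class RewardCriterion(nn.Module):

    def forward(self, scores, seq, reward):
        # scores: (batch, seq_len) per-word scores (e.g., probability ratios)
        # seq:    (batch, seq_len) sampled word indices (0 treated as padding)
        # reward: (batch, seq_len) per-word reward signal
        scores = scores.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        mask = (seq > 0).float().view(-1).to(scores.device)
        loss = -scores * reward * mask
        return torch.sum(loss) / torch.clamp(torch.sum(mask), min=1.0)

# rl_crit = RewardCriterion()  # hypothetical wiring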