    def traditional_cider(self, fc_feats, att_feats, att_masks, data, loss,
                          gen_result, greedy_res, sample_logprobs, gen_masks):

        # use_gen_cider_scores == 0: use the differenced (self-critical) rewards
        if self.use_gen_cider_scores == 0:
            reward, cider_greedy = rewards.get_self_critical_reward(
                data, gen_result, greedy_res)
        else:  # use the original rewards
            reward, _, cider_greedy = \
                rewards.get_self_critical_reward(
                    data, gen_result, greedy_res,
                    return_gen_scores=True)
        self._loss['avg_reward'] = reward.mean()
        self._loss['cider_greedy'] = cider_greedy

        loss_cider = sample_logprobs * utils.var_wrapper(
            -reward.astype('float32'),
            cuda=torch.cuda.is_available()).unsqueeze(1) * (
                gen_masks[:, 1:].detach())

        loss_cider = loss_cider.sum() / \
                     gen_masks[:, 1:].data.float().sum()
        loss += self.cider_optimization * loss_cider
        self._loss['loss_cider'] = loss_cider.data[0]

        return loss
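
The reward used throughout these examples is the self-critical advantage: the sentence score (e.g. CIDEr-D) of the sampled caption minus the score of the greedy-decoded baseline, broadcast to every generated token. A minimal numpy sketch of that shape contract (the function name and signature below are illustrative, not the repo's get_self_critical_reward):

import numpy as np

def self_critical_reward_sketch(sample_scores, greedy_scores, seq_len):
    # sample_scores / greedy_scores: one sentence-level score per image,
    # for the sampled caption and for the greedy baseline respectively.
    # Returns a (batch, seq_len) array so the same advantage multiplies
    # the log-probability of every token of the sampled caption.
    advantage = (np.asarray(sample_scores, dtype=np.float32)
                 - np.asarray(greedy_scores, dtype=np.float32))
    return np.repeat(advantage[:, np.newaxis], seq_len, axis=1)
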
 def forward(self, fc_feats, att_feats, obj_label, rela_label, rela_sub, rela_obj, rela_n2r, geometry,
             adj1, adj2, labels, masks, att_masks, rela_masks,
             gts, gt_indices, sc_flag):
     out = {}
     if not sc_flag:
         loss = self.crit(self.model(fc_feats, att_feats, obj_label, rela_label, rela_sub, rela_obj, rela_n2r, geometry,
                                     adj1, adj2, rela_masks, labels, att_masks), labels[:, 1:], masks[:, 1:])
     else:
         self.model.eval()
         with torch.no_grad():
             greedy_res, _ = self.model(fc_feats, att_feats, obj_label, rela_label, rela_sub, rela_obj, rela_n2r, geometry,
                                        adj1, adj2, rela_masks, labels, att_masks,
                                        mode='sample')
         self.model.train()
         gen_result, sample_logprobs = self.model(fc_feats, att_feats, obj_label, rela_label, rela_sub, rela_obj, rela_n2r, geometry,
                                                  adj1, adj2, rela_masks, labels, att_masks,
                                                  opt={'sample_method': 'sample'}, mode='sample')
         gts = [gts[_] for _ in gt_indices.tolist()]
         reward = get_self_critical_reward(
             greedy_res, gts, gen_result, self.opt)
         reward = torch.from_numpy(reward).float().to(gen_result.device)
         loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
         out['reward'] = reward[:, 0].mean()
     out['loss'] = loss
     return out
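
self.rl_crit above is a masked REINFORCE-style criterion: it weights the log-probability of each sampled token by its reward and averages over the non-padded positions. A hedged sketch of such a criterion (a stand-in written for illustration, not the repo's utils.RewardCriterion):

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    # logprobs: (B, T) log-probabilities of the sampled tokens
    # seq:      (B, T) sampled token ids, 0 used as padding after the end token
    # reward:   (B, T) per-token advantage, e.g. CIDEr(sample) - CIDEr(greedy)
    def forward(self, logprobs, seq, reward):
        mask = (seq > 0).float()
        # shift the mask right by one position so the end-of-sequence token
        # also contributes a gradient
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        loss = -logprobs * reward * mask
        return loss.sum() / mask.sum()
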
Example #3
    def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
                sc_flag):
        out = {}
        if not sc_flag:
            loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            self.model.eval()
            with torch.no_grad():
                greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample')
            self.model.train()
            gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_method':'sample'}, mode='sample')

            gts = [gts[_] for _ in gt_indices.tolist()]
            reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
            reward = torch.from_numpy(reward).float().to(gen_result.device)
            loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
            out['reward'] = reward[:,0].mean()

        if self.opt.caption_model == 'aat':
            all_aat_loss = torch.stack(self.model.all_att_cost).t()
            if not sc_flag:
                mask_ = masks[:,:all_aat_loss.size()[1]]
            else:
                mask_ = (torch.cat((gen_result.new_ones(gen_result.size(0),1), gen_result), dim=1)>0)[:,:all_aat_loss.size()[1]]
            aat_loss = (all_aat_loss * mask_.float()).sum(1).mean()
            out['aat_loss'] = aat_loss
            out['att_step'] = self.model.all_att_step
            out['avg_att_time'] = (np.array(self.model.all_att_step).transpose() * mask_.cpu().numpy()).sum()/mask_.cpu().numpy().sum()
            out['loss_'] = loss.clone()
            loss += self.opt.aat_lambda * aat_loss

        out['loss'] = loss
        return out
Example #4
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    for epoch in range(opt["epochs"]):
        lr_scheduler.step()

        iteration = 0
        # Decide whether to start self-critical training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()

            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, _ = model(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(
                    seq_probs, seq_preds,
                    Variable(torch.from_numpy(reward).float().cuda()))

            loss.backward()
            utils.clip_gradient(optimizer, opt["grad_clip"])
            optimizer.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch != 0 and epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
    def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, sc_flag, ppo_flag, clipped_lambda, sc_lambda): ## Added ppo_flag and old_model for ppo 9/sep/2019
        out = {}
        ################ ADDED THIS SECTION for ppo 8/sep/2019
        if ppo_flag:
            gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max': 0},mode='sample')
            #print("######### SAMPLE LOGPROB#######",sample_logprobs.shape,sample_logprobs) ## REMOVE LATER
            #if self.old_sample_logprobs == None: ## Added this to control the intial null problem of the old policy
            #    self.old_sample_logprobs = sample_logprobs.clone() ## Added this on 11/Sep/2019

            #print('gen_result length:\n',gen_result)
            gts = [gts[_] for _ in gt_indices.tolist()]
            #reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
            reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, self.opt)
            #print("Reward given:", reward, len(reward))
            reward = torch.from_numpy(reward).float().to(gen_result.device)
            # loss = self.ppo_crit(sample_logprobs, self.old_sample_logprobs, gen_result.data,reward)  # The loss is the main part, the core of reinforce, I guess, coming from utils.RewardCriterion()
            ###### Added in 24/sep/2019 as a way of combining PPO-clip and scst#######
            loss_ppo = self.ppo_crit(sample_logprobs, self.old_sample_logprobs, gen_result.data,reward)  # The loss is the main part, the core of reinforce, I guess, coming from utils.RewardCriterion()
            loss_sc = self.rl_crit(sample_logprobs, gen_result.data, reward)
            self.old_sample_logprobs = sample_logprobs.clone()
            print("Using sc_lambda: {}\tclipped_lambda: {}".format(sc_lambda,clipped_lambda))
            loss = sc_lambda * loss_sc + clipped_lambda * loss_ppo
            #loss = sc_lambda * loss_sc + clipped_lambda * 1 #********* Replacing with a dummy value c = 1 - 13/oct/2019
            #loss = loss_ppo ## Activate for only Clipped-SC loss
            #########################################################################
            out['reward'] = reward[:, 0].mean()
        else: ##############################################
            if not sc_flag:
                loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
            else:
                gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                gts = [gts[_] for _ in gt_indices.tolist()]
                reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, self.opt)
                reward = torch.from_numpy(reward).float().to(gen_result.device)
                loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
                out['reward'] = reward[:,0].mean()
        out['loss'] = loss
        return out
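
The PPO branch above mixes the usual self-critical loss with a clipped surrogate term from self.ppo_crit, which compares the current token log-probabilities against those of the previous update. A hedged sketch of a PPO-clip criterion with that interface (class name, clip_eps, and the masking details are assumptions):

import torch
import torch.nn as nn

class PPOClipCriterionSketch(nn.Module):
    def __init__(self, clip_eps=0.2):
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self, logprobs, old_logprobs, seq, reward):
        # logprobs / old_logprobs: (B, T) token log-probs under the current
        # and the previous policy; seq: sampled ids; reward: per-token advantage.
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        ratio = torch.exp(logprobs - old_logprobs.detach())
        unclipped = ratio * reward
        clipped = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * reward
        loss = -torch.min(unclipped, clipped) * mask
        return loss.sum() / mask.sum()
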
Example #6
 def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
             sc_flag, struc_flag):
     opt = self.opt
     
     out = {}
     if struc_flag:
         if opt.structure_loss_weight < 1:
             lm_loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[..., 1:], masks[..., 1:])
         else:
             lm_loss = torch.tensor(0).type_as(fc_feats)
         if opt.structure_loss_weight > 0:
             gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                 opt={'sample_method':opt.train_sample_method,
                     'beam_size':opt.train_beam_size,
                     'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
                         or not 'margin' in opt.structure_loss_type,
                     'sample_n': opt.train_sample_n},
                 mode='sample')
             gts = [gts[_] for _ in gt_indices.tolist()]
             struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
         else:
             struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
                           'reward': torch.tensor(0).type_as(fc_feats)}
         loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
         out['lm_loss'] = lm_loss
         out['struc_loss'] = struc_loss['loss']
         out['reward'] = struc_loss['reward']
     elif not sc_flag:
         loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[..., 1:], masks[..., 1:])
     else:
         self.model.eval()
         with torch.no_grad():
             greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
                 mode='sample',
                 opt={'sample_method': opt.sc_sample_method,
                      'beam_size': opt.sc_beam_size})
         self.model.train()
         gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                 opt={'sample_method':opt.train_sample_method,
                     'beam_size':opt.train_beam_size,
                     'sample_n': opt.train_sample_n},
                 mode='sample')
         gts = [gts[_] for _ in gt_indices.tolist()]
         reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
         reward = torch.from_numpy(reward).float().to(gen_result.device)
         loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
         out['reward'] = reward[:,0].mean()
     out['loss'] = loss
     return out
 def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
             gt_indices, sc_flag):
     out = {}
     if not sc_flag:
         loss = self.crit(
             self.model(fc_feats, att_feats, labels, att_masks),
             labels[:, 1:], masks[:, 1:])
     else:
         gen_result, sample_logprobs = self.model(fc_feats,
                                                  att_feats,
                                                  att_masks,
                                                  opt={'sample_max': 0},
                                                  mode='sample')
         gts = [gts[_] for _ in gt_indices.tolist()]
         reward = get_self_critical_reward(self.model, fc_feats, att_feats,
                                           att_masks, gts, gen_result,
                                           self.opt)
         reward = torch.from_numpy(reward).float().to(gen_result.device)
         loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
         out['reward'] = reward[:, 0].mean()
     out['loss'] = loss
     return out
Example #8
 def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
             gt_indices, sc_flag, itera):
     out = {}
     if not sc_flag:
         c = 'LabelSmoothing' if self.ls > 0 else 'CrossEntropy'
         # print(f'----------sc_flag:{sc_flag}, crit:{c}------------')
         loss = self.crit(
             self.model(fc_feats, att_feats, labels, att_masks),
             labels[:, 1:], masks[:, 1:])
     else:
         # print(f'----------sc_flag:{sc_flag}, crit:Reward------------')
         self.model.eval()
         with torch.no_grad():  # greedy search?
             greedy_res, _ = self.model(fc_feats,
                                        att_feats,
                                        att_masks,
                                        mode='sample')
         # print(f'===greedy_res:{greedy_res.shape}===')
         self.model.train()
         gen_result, sample_logprobs = self.model(
             fc_feats,
             att_feats,
             att_masks,
             opt={'sample_method': 'sample'},
             mode='sample')
         # print(f'===gen_result:{gen_result.shape}===')
         # print(f'===sample_logprobs:{sample_logprobs.shape}===')
         # print(f'===gts:{len(gts), gts}===')
         gts = [gts[_] for _ in gt_indices.tolist()]  # ground-truth samples, each with 5 captions
         # print(f'===gts:{len(gts), gts}===') # gts : list
         reward = get_self_critical_reward(greedy_res, gts, gen_result,
                                           self.opt, itera)
         reward = torch.from_numpy(reward).float().to(gen_result.device)
         loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
         out['reward'] = reward[:, 0].mean()
     out['loss'] = loss
     return out
def train(dataset,
          loader,
          model,
          rem,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    writer = SummaryWriter('./runs/video_caption22')
    model.load_state_dict(
        torch.load(
            '/home/diml/video-caption.pytorch/save/RECON222_model_200.pth'))
    rem.load_state_dict(
        torch.load(
            '/home/diml/video-caption.pytorch/save/RECON222_module_200.pth'))
    #model.load_state_dict(torch.load('/home/diml/video-caption.pytorch/save/new_model_200.pth'))
    #model = nn.DataParallel(model)
    model.train()
    rem.train()

    vocab = dataset.get_vocab()

    for epoch in trange(opt["epochs"]):
        t_loss = [0, 0, 0]
        # =============================================================================
        #         model.eval()
        #         ev.demov(model,crit, dataset, dataset.get_vocab(),opt)
        # =============================================================================

        lr_scheduler.step()
        iteration = 0

        # Decide whether to start self-critical training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for idx, data in enumerate(loader):
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            labels2 = data['labels2'].cuda()
            masks2 = data['masks2'].cuda()
            masks = data['masks'].cuda()
            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, seq_preds, hn, de_hn = model(
                    fc_feats, labels, 'train')
                loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:])
                fake_en_hn = rem(de_hn, seq_probs)
                f_seq_probs, f_seq_preds, hn, de_hn = model(fc_feats,
                                                            labels2,
                                                            'train',
                                                            h=fake_en_hn)
                loss_R = crit(f_seq_probs, labels2[:, 1:], masks2[:, 1:])
                loss = loss_R + loss_C
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            t_loss[0] += loss.item()
            if not sc_flag:
                # loss_C / loss_R are only computed in the cross-entropy branch
                t_loss[1] += loss_C.item()
                t_loss[2] += loss_R.item()
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1
            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))
        writer.add_scalar('training total loss', t_loss[0] / 140, epoch + 200)
        writer.add_scalar('training Caption loss', t_loss[1] / 140,
                          epoch + 200)
        writer.add_scalar('training Reconstruction loss', t_loss[2] / 140,
                          epoch + 200)
        if epoch % opt["save_checkpoint_every"] == 0:

            model_path = os.path.join(opt["checkpoint_path"],
                                      'RECON222_model_%d.pth' % (epoch + 200))
            rem_path = os.path.join(opt["checkpoint_path"],
                                    'RECON222_module_%d.pth' % (epoch + 200))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'RECON222_model_score.txt')
            torch.save(model.state_dict(), model_path)
            torch.save(rem.state_dict(), rem_path)
            print("model saved to %s" % (model_path))

            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))

        with torch.no_grad():
            _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt)
            _, f_seq_preds, __, ___ = model(fc_feats,
                                            mode='inference',
                                            h=fake_en_hn,
                                            opt=opt)
            origin = utils.decode_sequence(vocab, seq_preds)[0]
            revision = utils.decode_sequence(vocab, f_seq_preds)[0]
            with open('./results/training_versus.txt', 'a') as f:
                f.write("epoch is %d \n" % epoch)
                origin = "origin caption: " + origin + "\n"
                revision = "revision caption: " + revision + "\n"
                f.write(origin)
                f.write(revision)
Example #10
    def forward(self, fc_feats, att_feats, att_masks, seq, masks, data):
        if self.caption_loss_weight > 0 and not self.cider_optimization:
            loss_cap = self.caption_generator(fc_feats, att_feats, att_masks,
                                              seq, masks)
        else:
            loss_cap = Variable(torch.cuda.FloatTensor([0]))
        if self.vse_loss_weight > 0:
            loss_vse = self.vse(fc_feats,
                                att_feats,
                                seq,
                                masks,
                                only_one_retrieval=self.only_one_retrieval)
        else:
            loss_vse = Variable(torch.cuda.FloatTensor([0]))

        loss = self.caption_loss_weight * loss_cap + self.vse_loss_weight * loss_vse

        if self.retrieval_reward_weight > 0:
            if True:
                _seqs, _sampleLogProbs = self.caption_generator.sample(
                    fc_feats, att_feats, att_masks, {
                        'sample_max': 0,
                        'temperature': 1
                    })
                gen_result, sample_logprobs = _seqs, _sampleLogProbs
                _masks = torch.cat([
                    Variable(
                        _seqs.data.new(_seqs.size(0), 2).fill_(1).float()),
                    (_seqs > 0).float()[:, :-1]
                ], 1)

                gen_masks = _masks

                _seqs = torch.cat([
                    Variable(
                        _seqs.data.new(
                            _seqs.size(0),
                            1).fill_(self.caption_generator.vocab_size + 1)),
                    _seqs
                ], 1)

                if True:
                    retrieval_loss = self.vse(
                        fc_feats,
                        att_feats,
                        _seqs,
                        _masks,
                        True,
                        only_one_retrieval=self.only_one_retrieval)
                    if self.reinforce_baseline_type == 'greedy':
                        _seqs_greedy, _sampleLogProbs_greedy = self.caption_generator.sample(
                            *utils.var_wrapper(
                                [fc_feats, att_feats, att_masks],
                                volatile=True),
                            opt={
                                'sample_max': 1,
                                'temperature': 1
                            })
                        greedy_res = _seqs_greedy
                        # Do we need weights here???
                        if True:  #not self.use_word_weights:
                            _masks_greedy = torch.cat([
                                Variable(
                                    _seqs_greedy.data.new(_seqs.size(0),
                                                          2).fill_(1).float()),
                                (_seqs_greedy > 0).float()[:, :-1]
                            ], 1)
                        else:
                            _masks_greedy = self.get_word_weights_mask(
                                _seqs_greedy)

                        _seqs_greedy = torch.cat([
                            Variable(
                                _seqs_greedy.data.new(_seqs_greedy.size(0), 1).
                                fill_(self.caption_generator.vocab_size + 1)),
                            _seqs_greedy
                        ], 1)

                        baseline = self.vse(
                            fc_feats,
                            att_feats,
                            _seqs_greedy,
                            _masks_greedy,
                            True,
                            only_one_retrieval=self.only_one_retrieval)
                    elif self.reinforce_baseline_type == 'gt':
                        baseline = self.vse(
                            fc_feats,
                            att_feats,
                            seq,
                            masks,
                            True,
                            only_one_retrieval=self.only_one_retrieval)
                    else:
                        baseline = 0

                sc_loss = _sampleLogProbs * (
                    utils.var_wrapper(retrieval_loss) -
                    utils.var_wrapper(baseline)).detach().unsqueeze(1) * (
                        _masks[:, 1:].detach().float())
                sc_loss = sc_loss.sum() / _masks[:, 1:].data.float().sum()

                loss += self.retrieval_reward_weight * sc_loss

                self._loss['retrieval_sc_loss'] = sc_loss.data[0]

                self._loss['retrieval_loss'] = retrieval_loss.sum().data[0]
                self._loss['retrieval_loss_greedy'] = baseline.sum(
                ).data[0] if isinstance(baseline, Variable) else baseline

        if self.cider_optimization:
            if 'gen_result' not in locals():
                gen_result, sample_logprobs = self.caption_generator.sample(
                    fc_feats, att_feats, att_masks, opt={'sample_max': 0})
                gen_masks = torch.cat([
                    Variable(
                        gen_result.data.new(gen_result.size(0),
                                            2).fill_(1).float()),
                    (gen_result > 0).float()[:, :-1]
                ], 1)
            if 'greedy_res' not in locals():
                greedy_res, _ = self.caption_generator.sample(
                    *utils.var_wrapper([fc_feats, att_feats, att_masks],
                                       volatile=True),
                    opt={'sample_max': 1})
            reward, cider_greedy = rewards.get_self_critical_reward(
                data, gen_result, greedy_res)
            self._loss['avg_reward'] = reward.mean()
            self._loss['cider_greedy'] = cider_greedy
            loss_cap = sample_logprobs * utils.var_wrapper(-reward.astype(
                'float32')).unsqueeze(1) * (gen_masks[:, 1:].detach())
            loss_cap = loss_cap.sum() / gen_masks[:, 1:].data.float().sum()

            loss += self.caption_loss_weight * loss_cap

        self._loss['loss_cap'] = loss_cap.item()
        self._loss['loss_vse'] = loss_vse.item()
        self._loss['loss'] = loss.item()

        return loss
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Decide whether to start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # epoch += 1
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        # if (iteration % opt.save_checkpoint_every == 0):
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats
                f = open('train_log_%s.txt' % opt.id, 'a')
                f.write(
                    'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}'
                    .format(epoch, str(datetime.now()),
                            str(total_loss / times), str(val_loss),
                            str(current_score)))
                f.write('\n')
                f.close()
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                # print(str(infos['best_val_score']))
                print("model saved to {}".format(checkpoint_path))
                if opt.save_history_ckpt:
                    checkpoint_path = os.path.join(
                        opt.checkpoint_path, 'model-%d.pth' % (iteration))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    utils.pickle_dump(infos, f)
                if opt.save_history_ckpt:
                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'infos_' + opt.id + '-%d.pkl' % (iteration)),
                            'wb') as f:
                        cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    utils.pickle_dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
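
The stepwise learning-rate schedule inlined above multiplies the base rate by learning_rate_decay_rate every learning_rate_decay_every epochs once learning_rate_decay_start has passed. A small worked sketch of the same formula (the numeric values are illustrative only):

def decayed_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    # Mirrors the schedule computed inline in the train loop above.
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr

# e.g. decayed_lr(5e-4, 6, 0, 3, 0.8) == 5e-4 * 0.8 ** 2 == 3.2e-4
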
Example #12
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    #model_D = Discriminator(opt)
    #model_D.load_state_dict(torch.load('save/model_D.pth'))
    #model_D.cuda()
    #criterion_D = nn.CrossEntropyLoss(size_average=True)

    model_E = Distance(opt)
    model_E.load_state_dict(
        torch.load('save/model_E_NCE/model_E_10epoch.pth'))
    model_E.cuda()
    criterion_E = nn.CosineEmbeddingLoss(margin=0, size_average=True)
    #criterion_E = nn.CosineSimilarity()

    logger = Logger(opt)

    update_lr_flag = True
    # Assure in training mode
    model.train()
    #model_D.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_G = optim.Adam(model.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    #optimizer_D = optim.Adam(model_D.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_G.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer_G = update_lr(
                opt, epoch, model, optimizer_G)

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        #print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        #tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [data['fc_feats'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        #fc_feats, att_feats, labels, masks = tmp
        fc_feats, labels, masks = tmp

        ############################################################################################################
        ############################################ REINFORCE TRAINING ############################################
        ############################################################################################################
        if 1:  #iteration % opt.D_scheduling != 0:
            optimizer_G.zero_grad()
            if not sc_flag:
                loss = crit(model(fc_feats, labels), labels[:, 1:], masks[:,
                                                                          1:])
            else:
                gen_result, sample_logprobs = model.sample(
                    fc_feats, {'sample_max': 0})
                #reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
                sc_reward = get_self_critical_reward(model, fc_feats, data,
                                                     gen_result, logger)
                #gan_reward = get_gan_reward(model, model_D, criterion_D, fc_feats, data, logger)                 # Criterion_D = nn.XEloss()
                distance_loss_reward1 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=False)  # criterion_E = nn.CosEmbedLoss()
                distance_loss_reward2 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=True)  # criterion_E = nn.CosEmbedLoss()
                #cosine_reward = get_distance_reward(model, model_E, criterion_E, fc_feats, data, logger)         # criterion_E = nn.CosSim()
                reward = distance_loss_reward1 + distance_loss_reward2
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))
            loss.backward()
            utils.clip_gradient(optimizer_G, opt.grad_clip)
            optimizer_G.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            end = time.time()

            if not sc_flag:
                log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start)
                logger.write(log)
            else:
                log = "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration,  epoch, np.mean(reward[:,0]), end - start)
                logger.write(log)

        ######################################################################################################
        ############################################ GAN TRAINING ############################################
        ######################################################################################################
        else:  #elif iteration % opt.D_scheduling == 0: # gan training
            model_D.zero_grad()
            optimizer_D.zero_grad()

            fc_feats_temp = Variable(fc_feats.data.cpu(), volatile=True).cuda()
            labels = Variable(labels.data.cpu()).cuda()

            sample_res, sample_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 0})  #640, 16
            greedy_res, greedy_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 1})  #640, 16
            gt_res = labels  # 640, 18

            sample_res_embed = model.embed(Variable(sample_res))
            greedy_res_embed = model.embed(Variable(greedy_res))
            gt_res_embed = model.embed(gt_res)

            f_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            r_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            f_label.data.fill_(0)
            r_label.data.fill_(1)

            f_D_output = model_D(sample_res_embed.detach(), fc_feats.detach())
            f_loss = criterion_D(f_D_output, f_label.long())
            f_loss.backward()

            r_D_output = model_D(gt_res_embed.detach(), fc_feats.detach())
            r_loss = criterion_D(r_D_output, r_label.long())
            r_loss.backward()

            D_loss = f_loss + r_loss
            optimizer_D.step()
            torch.cuda.synchronize()

            log = 'iter {} (epoch {}),  Discriminator loss : {}'.format(
                iteration, epoch,
                D_loss.data.cpu().numpy()[0])
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer_G.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
    def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
                gt_indices, sc_flag, struc_flag, drop_worst_flag):
        opt = self.opt

        out = {}

        reduction = 'none' if drop_worst_flag else 'mean'
        if struc_flag:
            if opt.structure_loss_weight < 1:
                lm_loss = self.crit(self.model(fc_feats, att_feats, labels,
                                               att_masks),
                                    labels[:, 1:],
                                    masks[:, 1:],
                                    reduction=reduction)
            else:
                lm_loss = torch.tensor(0).type_as(fc_feats)
            gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                opt={'sample_max':0,
                    'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
                        or not 'margin' in opt.structure_loss_type,
                    'sample_n': opt.structure_sample_n},
                mode='sample')
            gts = [gts[_] for _ in gt_indices.tolist()]
            struc_loss = self.struc_crit(sample_logprobs,
                                         gen_result,
                                         gts,
                                         reduction=reduction)
            loss = (1 - opt.structure_loss_weight
                    ) * lm_loss + opt.structure_loss_weight * struc_loss
            out['lm_loss'] = lm_loss
            out['struc_loss'] = struc_loss
        elif not sc_flag:
            loss = self.crit(self.model(fc_feats, att_feats, labels,
                                        att_masks),
                             labels[:, 1:],
                             masks[:, 1:],
                             reduction=reduction)
        else:
            self.model.eval()
            with torch.no_grad():
                greedy_res, _ = self.model(fc_feats,
                                           att_feats,
                                           att_masks,
                                           mode='sample')
                if self.retrieval_reward_weight > 0:
                    _seqs_greedy, _sampleLogProbs_greedy = greedy_res, _
                    _masks_greedy = torch.cat([
                        _seqs_greedy.data.new(_seqs_greedy.size(0),
                                              2).fill_(1).float(),
                        (_seqs_greedy > 0).float()[:, :-1]
                    ], 1)
                    _seqs_greedy = torch.cat([
                        _seqs_greedy.data.new(
                            _seqs_greedy.size(0),
                            1).fill_(self.model.vocab_size + 1), _seqs_greedy
                    ], 1)

                    baseline = self.vse(fc_feats,
                                        att_feats,
                                        att_masks,
                                        _seqs_greedy,
                                        _masks_greedy,
                                        True,
                                        only_one_retrieval='off')
            self.model.train()
            gen_result, sample_logprobs = self.model(fc_feats,
                                                     att_feats,
                                                     att_masks,
                                                     opt={'sample_max': 0},
                                                     mode='sample')
            gts = [gts[_] for _ in gt_indices.tolist()]
            reward = get_self_critical_reward(greedy_res, gts, gen_result,
                                              self.opt)
            reward = torch.from_numpy(reward).float().to(gen_result.device)
            out['reward'] = reward[:, 0].mean()

            if self.retrieval_reward_weight > 0:
                _seqs, _sampleLogProbs = gen_result, sample_logprobs
                _masks = torch.cat([
                    _seqs.data.new(_seqs.size(0), 2).fill_(1).float(),
                    (_seqs > 0).float()[:, :-1]
                ], 1)

                gen_masks = _masks

                _seqs = torch.cat([
                    _seqs.data.new(_seqs.size(0),
                                   1).fill_(self.model.vocab_size + 1), _seqs
                ], 1)

                retrieval_loss = self.vse(fc_feats,
                                          att_feats,
                                          att_masks,
                                          _seqs,
                                          _masks,
                                          True,
                                          only_one_retrieval='off')

                reward -= self.retrieval_reward_weight * (
                    retrieval_loss - baseline).unsqueeze(1)
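                # the greedy caption's retrieval loss above acts as the baseline,
                # so this retrieval term is itself self-critical: a sampled caption
                # that retrieves better than the greedy one raises the reward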

                out['retrieval_loss'] = retrieval_loss.sum()
                out['retrieval_loss_greedy'] = baseline.sum()

                print(out['retrieval_loss'].item(),
                      out['retrieval_loss_greedy'].item())

            loss = self.rl_crit(sample_logprobs,
                                gen_result.data,
                                reward,
                                reduction=reduction)

        out['loss'] = loss
        return out
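
# A minimal sketch (not the repo's exact helper) of the self-critical reward used
# as get_self_critical_reward above: score the sampled captions and the greedy
# captions against the references and use their difference as the per-sequence
# reward. cider_score is a hypothetical callable standing in for the CIDEr-D
# scorer that init_scorer()/init_cider_scorer() set up; the real helper differs
# in details such as tokenization and optional BLEU mixing.
import numpy as np

def self_critical_reward_sketch(sample_caps, greedy_caps, references, cider_score):
    # sample_caps, greedy_caps: decoded caption strings, one per image
    # references: list of reference-caption lists, one list per image
    sample_scores = np.array([cider_score(c, refs)
                              for c, refs in zip(sample_caps, references)])
    greedy_scores = np.array([cider_score(c, refs)
                              for c, refs in zip(greedy_caps, references)])
    # greedy decoding is the baseline: positive reward only when sampling wins
    reward = sample_scores - greedy_scores
    # broadcast the sequence-level reward over timesteps so it can weight the
    # per-token log-probabilities in the REINFORCE criterion
    seq_len = max(len(c.split()) for c in sample_caps)
    return np.repeat(reward[:, np.newaxis], seq_len, axis=1)
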
Example #14
def train(opt):
    assert opt.annfile is not None and len(opt.annfile) > 0

    print('Checkpoint path is ' + opt.checkpoint_path)
    print('This program is using GPU ' +
          str(os.environ['CUDA_VISIBLE_DEVICES']))
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if opt.load_best:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')
        with open(info_path) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    if opt.learning_rate_decay_start is None:
        opt.learning_rate_decay_start = infos.get(
            'opt', None).learning_rate_decay_start
    # if opt.load_best:
    #     opt.self_critical_after = epoch
    elif opt.learning_rate_decay_start == -1 and opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        opt.learning_rate_decay_start = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_ave_model = infos.get('best_val_score_ave_model', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion(opt.XE_eps)
    rl_crit = utils.RewardCriterion()

    # build_optimizer
    optimizer = build_optimizer(model, opt)

    # Load the optimizer
    if opt.load_opti and vars(opt).get(
            'start_from',
            None) is not None and opt.load_best == 0 and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # initialize the running average of parameters
    avg_param = deepcopy(list(p.data for p in model.parameters()))

    # make evaluation using original model
    best_val_score, histories, infos = eva_original_model(
        best_val_score, crit, epoch, histories, infos, iteration, loader,
        loss_history, lr_history, model, opt, optimizer, ss_prob_history,
        tb_summary_writer, val_result_history)

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                if opt.lr_decay == 'exp':
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                elif opt.lr_decay == 'cosine':
                    lr_epoch = min((epoch - opt.learning_rate_decay_start),
                                   opt.lr_max_epoch)
                    cosine_decay = 0.5 * (
                        1 + math.cos(math.pi * lr_epoch / opt.lr_max_epoch))
                    decay_factor = (1 - opt.lr_cosine_decay_base
                                    ) * cosine_decay + opt.lr_cosine_decay_base
                    opt.current_lr = opt.learning_rate * decay_factor
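                    # decay_factor anneals from 1.0 down to
                    # opt.lr_cosine_decay_base over opt.lr_max_epoch epochs,
                    # then stays at that floor value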
            else:
                opt.current_lr = opt.learning_rate

            lr = [opt.current_lr]
            if opt.att_normalize_method is not None and '6' in opt.att_normalize_method:
                lr = [opt.current_lr, opt.lr_ratio * opt.current_lr]

            utils.set_lr(optimizer, lr)
            print('learning rate is: ' + str(lr))

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Update the iteration
        iteration += 1

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            output = dp_model(fc_feats, att_feats, labels, att_masks)
            # calculate loss
            loss = crit(output[0], labels[:, 1:], masks[:, 1:])

            # add some middle variable histogram
            if iteration % (4 * opt.losses_log_every) == 0:
                outputs = [
                    _.data.cpu().numpy() if _ is not None else None
                    for _ in output
                ]
                variables_histogram(data, iteration, outputs,
                                    tb_summary_writer, opt)

        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_max_norm)
        # add_summary_value(tb_summary_writer, 'grad_L2_norm', grad_norm, iteration)

        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        # compute the running average of parameters
        for p, avg_p in zip(model.parameters(), avg_param):
            avg_p.mul_(opt.beta).add_((1.0 - opt.beta), p.data)

        if iteration % 10 == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if opt.tensorboard_weights_grads and (iteration %
                                              (8 * opt.losses_log_every) == 0):
            # add weights histogram to tensorboard summary
            for name, param in model.named_parameters():
                if (opt.tensorboard_parameters_name is None or any(
                        p_name in name
                        for p_name in opt.tensorboard_parameters_name
                )) and param.grad is not None:
                    tb_summary_writer.add_histogram(
                        'Weights_' + name.replace('.', '/'), param, iteration)
                    tb_summary_writer.add_histogram(
                        'Grads_' + name.replace('.', '/'), param.grad,
                        iteration)

        if opt.tensorboard_buffers and (iteration %
                                        (opt.losses_log_every) == 0):
            for name, buffer in model.named_buffers():
                if (opt.tensorboard_buffers_name is None or any(
                        p_name in name
                        for p_name in opt.tensorboard_buffers_name
                )) and buffer is not None:
                    add_summary_value(tb_summary_writer,
                                      name.replace('.',
                                                   '/'), buffer, iteration)

        if opt.distance_sensitive_coefficient and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The coefficient in intra_att_att_lstm is as follows:')
            print(
                model.core.intra_att_att_lstm.coefficient.data.cpu().tolist())
            print('The coefficient in intra_att_lang_lstm is as follows:')
            print(
                model.core.intra_att_lang_lstm.coefficient.data.cpu().tolist())
        if opt.distance_sensitive_bias and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The bias in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.bias.data.cpu().tolist())
            print('The bias in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.bias.data.cpu().tolist())

        # make evaluation using original model
        if (iteration % opt.save_checkpoint_every == 0):
            best_val_score, histories, infos = eva_original_model(
                best_val_score, crit, epoch, histories, infos, iteration,
                loader, loss_history, lr_history, model, opt, optimizer,
                ss_prob_history, tb_summary_writer, val_result_history)

        # make evaluation with the averaged parameters model
        if iteration > opt.ave_threshold and (iteration %
                                              opt.save_checkpoint_every == 0):
            best_val_score_ave_model, infos = eva_ave_model(
                avg_param, best_val_score_ave_model, crit, infos, iteration,
                loader, model, opt, tb_summary_writer)

        # # Stop if reaching max epochs
        # if epoch >= opt.max_epochs and opt.max_epochs != -1:
        #     break

        if iteration >= opt.max_iter:
            break
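
# A minimal sketch of the REINFORCE criterion used as rl_crit in these training
# loops: mask the sampled sequence up to and including its first end token,
# weight each token's log-probability by the (baseline-subtracted) reward, and
# average. The repo's utils.RewardCriterion may differ in details.
import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    def forward(self, sample_logprobs, gen_result, reward):
        # sample_logprobs: (batch, seq) log-probs of the sampled tokens
        # gen_result:      (batch, seq) sampled token ids, 0-padded after EOS
        # reward:          (batch, seq) advantage for every token position
        mask = (gen_result > 0).float()
        # shift right so the step that produced the first 0 (EOS) still counts
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        loss = -sample_logprobs * reward * mask
        return loss.sum() / mask.sum()
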
def train(train_loader,
          val_loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    # lowest val loss
    best_loss = None
    for epoch in range(opt.epochs):
        lr_scheduler.step()

        iteration = 0
        # If start self crit training
        if opt.self_crit_after != -1 and epoch >= opt.self_crit_after:
            sc_flag = True
            init_cider_scorer(opt.cached_tokens)
        else:
            sc_flag = False

        for data in train_loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()
            if not sc_flag:
                seq_probs, predicts = model(fc_feats, labels)
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(fc_feats, vars(opt))
                # print(gen_result)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  gen_result)
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda()))

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.3f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        # lowest val loss

        if epoch % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_%d.pth' % (epoch))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to %s" % (checkpoint_path))
            val_loss = val(val_loader, model, crit)
            print("Val loss is: %.6f" % (val_loss))
            model.train()
            if best_loss is None or val_loss < best_loss:
                print("(epoch %d), now lowest val loss is %.6f" %
                      (epoch, val_loss))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("best model saved to %s" % (checkpoint_path))
                best_loss = val_loss
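
# A minimal sketch of the per-element gradient clipping used as
# utils.clip_gradient above: clamp every gradient entry to
# [-grad_clip, grad_clip]. (A global-norm alternative is
# torch.nn.utils.clip_grad_norm_, as in the commented-out line earlier.)
def clip_gradient_sketch(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
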
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0})
            reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
            loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
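
# A minimal sketch of the masked cross-entropy used as
# utils.LanguageModelCriterion: gather the log-probability of each target word
# and average over non-padded positions. The repo's version may add a small
# eps term (cf. opt.XE_eps above) or label smoothing.
import torch
import torch.nn as nn

class LanguageModelCriterionSketch(nn.Module):
    def forward(self, log_probs, target, mask):
        # log_probs: (batch, seq, vocab) log-softmax outputs of the decoder
        # target:    (batch, seq) ground-truth word ids, i.e. labels[:, 1:]
        # mask:      (batch, seq) 1.0 for real words, 0.0 for padding
        target = target[:, :log_probs.size(1)].long()
        mask = mask[:, :log_probs.size(1)].float()
        nll = -log_probs.gather(2, target.unsqueeze(2)).squeeze(2)
        return (nll * mask).sum() / mask.sum()
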
Example #17
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    target_actor = models.setup(opt).cuda()

    ####################### Critic pretrain #####################################################################
    ##### Critic with state as input
    # if opt.critic_model == 'state_critic':
    #     critic_model = CriticModel(opt)
    # else:
    critic_model = AttCriticModel(opt)
    target_critic = AttCriticModel(opt)
    if vars(opt).get('start_from_critic', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from_critic
                             ), " %s must be a path" % opt.start_from_critic
        print(
            os.path.join(opt.start_from_critic,
                         opt.critic_model + '_model.pth'))
        critic_model.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
        target_critic.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
    critic_model = critic_model.cuda()
    target_critic = target_critic.cuda()
    critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt)
    dp_model.eval()
    critic_iter = 0
    init_scorer(opt.cached_tokens)
    critic_model.train()
    error_sum = 0
    loss_vector_sum = 0
    while opt.pretrain_critic == 1:
        if critic_iter > opt.pretrain_critic_steps:
            print('****************Finished critic training!')
            break
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        critic_model.train()
        critic_optimizer.zero_grad()
        assert opt.critic_model == 'att_critic_vocab'
        # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data)
        crit_loss, reward, std = target_critic_loss_fun_mask(
            fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data,
            target_critic, target_actor)
        crit_loss.backward()
        critic_optimizer.step()
        #TODO update target.
        for cp, tp in zip(critic_model.parameters(),
                          target_critic.parameters()):
            tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
        crit_train_loss = crit_loss.item()
        torch.cuda.synchronize()
        end = time.time()
        error_sum += crit_train_loss**0.5 - std
        if (critic_iter % opt.losses_log_every == 0):
            print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \
                .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start))
            print(opt.checkpoint_path)
            opt.importance_sampling = 1
            critic_model.eval()
            _, _, _, _ = get_rf_loss(dp_model,
                                     fc_feats,
                                     att_feats,
                                     att_masks,
                                     data,
                                     opt,
                                     loader,
                                     critic_model,
                                     test_critic=True)

        critic_iter += 1

        # make evaluation on validation set, and save model
        if (critic_iter % opt.save_checkpoint_every == 0):
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           opt.critic_model + '_model.pth')
            torch.save(critic_model.state_dict(), checkpoint_path)

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 5000:
            loader.reset_iterator('train')
            continue
        dp_model.train()
        critic_model.eval()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
            elif opt.rl_type == 'arsm':
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader)
                #print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'rf4':
                loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats,
                                            att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'importance_sampling':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
            elif opt.rl_type == 'importance_sampling_critic':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    target_actor, fc_feats, att_feats, att_masks, data, opt,
                    loader, target_critic)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks,
                                   data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda * arm_baseline)
            elif opt.rl_type == 'arsm_baseline_critic':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader, critic_model)
                reward, std = get_reward(data, gen_result, opt, critic=True)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - arm_baseline)
            elif opt.rl_type == 'arsm_critic':
                #print(opt.critic_model)
                tic = time.time()
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader, critic_model)
                #print('arm_loss time', str(time.time()-tic))
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'critic_vocab_sum':
                assert opt.critic_model == 'att_critic_vocab'
                tic = time.time()
                gen_result, sample_logprobs_total = dp_model(
                    fc_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    total_probs=True,
                    mode='sample')  #batch, seq, vocab
                #print('generation time', time.time()-tic)
                gen_result_pad = torch.cat([
                    gen_result.new_zeros(
                        gen_result.size(0), 1, dtype=torch.long), gen_result
                ], 1)
                tic = time.time()
                critic_value = critic_model(gen_result_pad, fc_feats,
                                            att_feats, True, opt,
                                            att_masks)  #batch, seq, vocab
                #print('critic time', time.time() - tic)
                probs = torch.sum(
                    F.softmax(sample_logprobs_total, 2) *
                    critic_value.detach(), 2)
                mask = (gen_result > 0).float()
                mask = torch.cat(
                    [mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
                loss = -torch.sum(probs * mask) / torch.sum(mask)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    critic_value, gen_result, sample_logprobs = critic_model(
                        dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value[:, :-1].data)
                elif opt.critic_model == 'att_critic':
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    gen_result_pad = torch.cat([
                        gen_result.new_zeros(gen_result.size(0),
                                             1,
                                             dtype=torch.long), gen_result
                    ], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats,
                                                att_feats, True, opt,
                                                att_masks).squeeze(2)

                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value.data)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(
                dp_model(fc_feats, att_feats, labels, att_masks),
                labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
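        # first_order / second_order are exponential moving averages of the
        # gradient and its element-wise square, so second_order - first_order**2
        # approximates the per-element variance of the policy gradient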
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order -
                                        first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic
        if 'critic' in opt.rl_type:
            dp_model.eval()
            critic_model.train()
            utils.set_lr(critic_optimizer, opt.critic_learning_rate)
            critic_optimizer.zero_grad()
            assert opt.critic_model == 'att_critic_vocab'
            crit_loss, reward, std = target_critic_loss_fun_mask(
                fc_feats,
                att_feats,
                att_masks,
                dp_model,
                critic_model,
                opt,
                data,
                target_critic,
                target_actor,
                gen_result=gen_result,
                sample_logprobs_total=sample_logprobs_total,
                reward=reward)
            crit_loss.backward()
            critic_optimizer.step()
            for cp, tp in zip(critic_model.parameters(),
                              target_critic.parameters()):
                tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
            for cp, tp in zip(dp_model.parameters(),
                              target_actor.parameters()):
                tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data)
            crit_train_loss = crit_loss.item()
            error_sum += crit_train_loss**0.5 - std
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            elif 'critic' in opt.rl_type:
                print(
                    "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \
                    .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start))
                print(opt.checkpoint_path)
                critic_model.eval()
                _, _, _, _ = get_rf_loss(dp_model,
                                         fc_feats,
                                         att_feats,
                                         att_masks,
                                         data,
                                         opt,
                                         loader,
                                         critic_model,
                                         test_critic=True)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance,
                                  iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward)
            critic_loss_history[
                iteration] = crit_train_loss if 'critic' in opt.rl_type else 0
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               opt.critic_model + '_model.pth')
                torch.save(critic_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                # histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
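
# A minimal sketch of the soft (Polyak) target-network update applied above to
# target_critic and target_actor: each target parameter takes a small step of
# size gamma towards its online counterpart, which keeps the critic's bootstrap
# targets slowly moving and stable.
def soft_update_sketch(online_model, target_model, gamma):
    for online_p, target_p in zip(online_model.parameters(),
                                  target_model.parameters()):
        target_p.data.add_(gamma * (online_p.data - target_p.data))
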
Example #18
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        data_time = time.time() - start

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_freq == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start, data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start, data_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path) # MODIFIED (ADDED)

            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best-i{}-score{}.pth'.format(iteration, best_val_score))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
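
Each of the self-critical branches in these examples relies on get_self_critical_reward to turn CIDEr scores into a per-token training signal. Below is a minimal sketch of what such a helper typically computes, assuming one sampled and one greedy caption per image, a scorer object exposing compute_score(gts, res) that returns per-caption CIDEr values, and an array_to_str helper; these names are assumptions for illustration, not the exact implementation imported above.

import numpy as np

def array_to_str(arr):
    # render an index sequence as a space-separated string, stopping at the 0 pad/end token
    words = []
    for ix in arr:
        if ix == 0:
            break
        words.append(str(ix))
    return ' '.join(words)

def get_self_critical_reward_sketch(greedy_res, data_gts, gen_result, scorer):
    # score sampled and greedy captions against the same ground-truth references
    batch_size = gen_result.shape[0]
    res = {i: [array_to_str(gen_result[i])] for i in range(batch_size)}
    res.update({batch_size + i: [array_to_str(greedy_res[i])] for i in range(batch_size)})
    gts = {i: [array_to_str(g) for g in data_gts[i]] for i in range(batch_size)}
    gts.update({batch_size + i: gts[i] for i in range(batch_size)})

    _, scores = scorer.compute_score(gts, res)            # assumed: (mean, per-caption scores)
    scores = np.asarray(scores)
    rewards = scores[:batch_size] - scores[batch_size:]   # sampled CIDEr minus greedy baseline

    # broadcast the sequence-level advantage to every time step, giving the
    # [batch, seq_len] reward array the loops above pass to rl_crit
    return np.repeat(rewards[:, np.newaxis], gen_result.shape[1], axis=1)

The greedy caption acts as the baseline of the policy gradient, so only samples that beat the model's own greedy decode receive a positive reward.
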
Example #19
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    #pretrained_dict = torch.load(opt.model)
    #model.load_state_dict(pretrained_dict, strict=False)

    num_params = get_n_params(model)
    print('number of parameters:', num_params)

    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp
        batchsize = fc_feats.size(0)
        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            wordact, reconstruct = dp_model(fc_feats, att_feats, labels)
            #loss_dist = F.binary_cross_entropy(dist, dist_label.cpu().float())
            fc_feats_max, _ = att_feats.max(1)
            loss_rec = F.mse_loss(reconstruct.cpu(), fc_feats_max.cpu())
            mask = masks[:, 1:].contiguous()
            wordact = wordact[:, :, :-1]
            wordact_t = wordact.permute(0, 2, 1).contiguous()
            wordact_t = wordact_t.view(
                wordact_t.size(0) * wordact_t.size(1), -1)
            labels = labels.contiguous().view(-1, 6 * 30).cpu()
            wordclass_v = labels[:, 1:]
            wordclass_t = wordclass_v.contiguous().view(\
               wordclass_v.size(0) * wordclass_v.size(1), 1)
            maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
            loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
               wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))
            loss = 5 * loss_xe + loss_rec
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration >= 60000 and iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            #histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best.pth'
                infos_fname = 'model-best.pkl'
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
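
The reward produced this way is consumed by utils.RewardCriterion together with the sampled log-probabilities and the sampled sequence. A minimal sketch of such a criterion follows, assuming the reward is already broadcast to one value per token; the actual class in utils may differ in details.

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    # Masked policy-gradient loss: -log p(w_t) * advantage, averaged over generated tokens.
    def forward(self, logprobs, seq, reward):
        if logprobs.dim() == 3:
            # full vocabulary distributions were passed in: pick out the sampled words
            logprobs = logprobs.gather(2, seq.unsqueeze(2)).squeeze(2)
        mask = (seq > 0).float()
        # always count the first step, so the token right after <bos> still receives a gradient
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        loss = -logprobs * reward * mask
        return loss.sum() / mask.sum()
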
def eval_split(model, crit, loader, eval_kwargs={}):
    eval_att = eval_kwargs.get('eval_att',False)
    gt_grd_eval = eval_kwargs.get('gt_grd_eval',False)
    eval_scan = eval_kwargs.get('eval_scan',False)
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    grd_output = defaultdict(list)

    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size

        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'],data['box_feats']]
            tmp = [_.cuda() if _ is not None else _ for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, box_feats = tmp

            with torch.no_grad():
                loss = crit(model(fc_feats, att_feats, labels, att_masks)[0], labels[:,1:], masks[:,1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        if not gt_grd_eval:
            # forward the model to also get generated samples for each image
            # Only keep one feature for each image, in case of duplicate samples
            tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 
                data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
                data['box_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
                data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None]
            tmp = [_.cuda() if _ is not None else _ for _ in tmp]
            fc_feats, att_feats, box_feats, att_masks = tmp
        else:
            tmp = [data['fc_feats'], 
                data['att_feats'],
                data['box_feats'],
                data['att_masks'] if data['att_masks'] is not None else None]
            tmp = [_.cuda() if _ is not None else _ for _ in tmp]
            fc_feats, att_feats, box_feats, att_masks = tmp

        # forward the model to also get generated samples for each image
        with torch.no_grad():
            if eval_att:
                if not gt_grd_eval:
                    assert eval_kwargs['beam_size']==1,  'only support beam_size is 1'
                    seq, _, att_weights = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')
                    seq=seq.detach()
                    att_weights=att_weights.detach()
                    att_ind = torch.max(att_weights, dim=2)[1]
                else:
                    if not eval_scan:
                        #== This snippet is used for evaluating the grounding accuracy of the caption model on gt sentences. ==#
                        _, att_weights=model(fc_feats, att_feats, labels, att_masks)
                        seq = labels[:,1:]
                        att_weights=att_weights.detach()
                        att_ind = torch.max(att_weights, dim=2)[1]
                        data['infos'] = data['infos']*5
                    else:
                        # pdb.set_trace()
                        #==== This snippet is used for evaluating the grounding accuracy of the SCAN model on gt sentences. ====#
                        gts = data['gts']
                        reward, att_weights, noun_mask= get_self_critical_reward(model, fc_feats, att_feats, att_masks, gts, labels[:,1:], eval_kwargs)
                        seq =  labels[:,1:]
                        att_weights=att_weights.detach()
                        att_ind = torch.max(att_weights, dim=2)[1]
                        data['infos'] = data['infos']*5

                for i in range(seq.size(0)):
                    tmp_result = {'clss':[], 'idx_in_sent':[], 'bbox':[]}
                    num_sent = 0 # does not really matter which reference to use
                    for j in range(seq.size(1)):
                        if seq[i,j].item() != 0:
                            lemma = loader.wtol[loader.ix_to_word[str(seq[i,j].item())]]
                            if lemma in loader.lemma_det_dict:
                                tmp_result['bbox'].append(box_feats[i, att_ind[i, j], :4].tolist())
                                tmp_result['clss'].append(loader.itod[loader.lemma_det_dict[lemma]])
                                tmp_result['idx_in_sent'].append(j) # redundant, for the sake of output format
                        else:
                            break
                    grd_output[str(data['infos'][i]['id'])].append(tmp_result)
            else:
                seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data

        
        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %(ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None
    if lang_eval == 1:
        if not gt_grd_eval:
            lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split)

    if eval_att:
        # write attention results to file
        attn_file = 'att_results/attn-gen-sent-results-'+split+'-'+eval_kwargs['id']+'.json'
        with open(attn_file, 'w') as f:
            json.dump({'results':grd_output, 'eval_mode':'gen', 'external_data':{'used':True, 'details':'Object detector pre-trained on Visual Genome on object detection task.'}}, f)

        # offline eval
        evaluator = FlickrGrdEval(reference_file=eval_kwargs['reference'], submission_file=attn_file,
                              split_file=eval_kwargs['split_file'], val_split=[split],
                              iou_thresh=0.5)

        print('\nResults Summary (generated sent):')
        print('Printing attention accuracy on generated sentences...')
        if not gt_grd_eval:
            prec_all, recall_all, f1_all = evaluator.grd_eval(mode='all')
            prec_loc, recall_loc, f1_loc = evaluator.grd_eval(mode='loc')
        else:
            grd_accu = evaluator.gt_grd_eval()
        print('\n')


    # Switch back to training mode
    model.train()
    return loss_sum/loss_evals, predictions, lang_stats
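
A short usage example for this eval_split; dp_model, crit and loader are the objects built in the surrounding train() functions, and the dataset path and run id below are placeholders.

eval_kwargs = {
    'split': 'val',
    'dataset': 'data/dataset_flickr30k.json',   # hypothetical path
    'id': 'demo_run',                           # hypothetical run id
    'language_eval': 1,
    'beam_size': 1,
    'eval_att': False,
}
val_loss, predictions, lang_stats = eval_split(dp_model, crit, loader, eval_kwargs)
print('val loss = {:.3f}'.format(val_loss))
if lang_stats is not None:
    print('CIDEr = {:.3f}'.format(lang_stats['CIDEr']))
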
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size1", "rnn_size2",
                "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    # loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()
    # model.set_mode('train')

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        model.train()
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train+val')
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['num_bbox'],
            data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_).float(), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, num_bbox, labels, masks = tmp
        labels = labels.long()

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, num_bbox, labels),
                        labels[:, 1:], masks[:, 1:])
            # loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                       num_bbox,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats, att_feats,
                                              num_bbox, data, gen_result)
            loss = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f} lr={}" \
                 .format(iteration, epoch, train_loss, end - start, opt.current_lr ))
        else:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f} lr={}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start, opt.current_lr ))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'val_ref_path': opt.val_ref_path,
                'raw_val_anno_path': opt.raw_val_anno_path
            }
            eval_kwargs.update(vars(opt))
            # predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            best_flag = False
            if True:  # if true
                # if best_val_score is None or current_score > best_val_score:
                # 	best_val_score = current_score
                # 	best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
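
Every loop above adjusts the learning rate through utils.set_lr and clips gradients through utils.clip_gradient before optimizer.step(). A minimal sketch of those two helpers, assuming element-wise clamping; the real utils module might instead clip by global norm.

def set_lr(optimizer, lr):
    # overwrite the learning rate of every parameter group in place
    for group in optimizer.param_groups:
        group['lr'] = lr

def clip_gradient(optimizer, grad_clip):
    # clamp each gradient element to [-grad_clip, grad_clip] before the update
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
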
Example #22
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader
    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()

    model = models.setup(opt)
    model.cuda()
    # model = DataParallel(model)

    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda()
    #    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
        print('finetune mode')
        cnn_optimizer = optim.Adam([\
            {'params': module.parameters()} for module in list(cnn_model._modules.values())[5:]\
            ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, crit, loader, eval_kwargs, True)
    epoch_start = time.time()
    while True:
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
                #model.module.ss_prob = opt.ss_prob
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
            else:
                sc_flag = False

            # Update the training stage of cnn
            for p in cnn_model.parameters():
                p.requires_grad = True
            # Fix the first few layers:
            for module in list(cnn_model._modules.values())[:5]:
                for p in module.parameters():
                    p.requires_grad = False
            cnn_model.train()
            update_lr_flag = False

        cnn_model.apply(utils.set_bn_fix)
        cnn_model.apply(utils.set_bn_eval)

        start = time.time()
        torch.cuda.synchronize()
        data = loader.get_batch('train')
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:

            multilabels = [
                data['detection_infos'][i]['label']
                for i in range(len(data['detection_infos']))
            ]

            tmp = [
                data['labels'], data['masks'],
                np.array(multilabels, dtype=np.int16)
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            labels, masks, multilabels = tmp
            images = data[
                'images']  # cannot be turned into a single tensor since the images have different sizes
            _fc_feats_2048 = []
            _fc_feats_81 = []
            _att_feats = []
            for i in range(loader.batch_size):
                x = Variable(torch.from_numpy(images[i]),
                             requires_grad=False).cuda()
                x = x.unsqueeze(0)
                att_feats, fc_feats_81 = cnn_model(x)
                fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
                att_feats = F.adaptive_avg_pool2d(att_feats,
                                                  [14, 14]).squeeze().permute(
                                                      1, 2, 0)  #(0, 2, 3, 1)
                _fc_feats_2048.append(fc_feats_2048)
                _fc_feats_81.append(fc_feats_81)
                _att_feats.append(att_feats)
            _fc_feats_2048 = torch.stack(_fc_feats_2048)
            _fc_feats_81 = torch.stack(_fc_feats_81)
            _att_feats = torch.stack(_att_feats)
            att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
                                                           _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
                                                           _att_feats.size()[1:]))
            fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:]))
            fc_feats_81 = _fc_feats_81
            #
            cnn_optimizer.zero_grad()
        else:

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

        if not sc_flag:
            loss1 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss = 0.8 * loss1 + 0.2 * loss2.float()
        else:
            gen_result, sample_logprobs = model.sample(fc_feats_2048,
                                                       att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats_2048, att_feats,
                                              data, gen_result)
            loss1 = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss3 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3)
        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        mle_loss = loss1.data[0]
        multilabel_loss = loss2.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start))

        if sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start))
        iteration += 1
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if (iteration % opt.save_checkpoint_every == 0):
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': True
            }
            eval_kwargs.update(vars(opt))

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, True)
            else:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, False)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("cnn model saved to {}".format(cnn_checkpoint_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                      'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("cnn model saved to {}".format(cnn_checkpoint_path))

                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("epoch: " + str(epoch) + " during: " +
                  str(time.time() - epoch_start))
            epoch_start = time.time()

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
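
The cross-entropy branches in all of these examples go through utils.LanguageModelCriterion(outputs, labels[:, 1:], masks[:, 1:]). A minimal sketch of such a masked criterion, assuming the model returns per-token log-probabilities; the real class may handle shapes slightly differently.

import torch.nn as nn

class LanguageModelCriterionSketch(nn.Module):
    # Masked sequence cross-entropy: -log p(target_t), averaged over unmasked tokens.
    def forward(self, logprobs, target, mask):
        # truncate targets and masks to the length the model actually produced
        target = target[:, :logprobs.size(1)]
        mask = mask[:, :logprobs.size(1)].float()
        nll = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return nll.sum() / mask.sum()
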
	def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
				sc_flag,box_inds):
		out = {}
		if not sc_flag:
			if self.opt.att_supervise:
				outputs, attn_weights=self.model(fc_feats, att_feats, labels, att_masks)
				loss1 = self.crit(outputs, labels[:,1:], masks[:,1:])

				if self.opt.use_gt_box:
					box_inds = box_inds[:,1:]
					if self.opt.att_sup_crit == 'KL' or self.opt.att_sup_crit == 'ExtendNLL':
						sup_mask = (box_inds != 1e-8* torch.ones(box_inds.size(-1)).type_as(box_inds)).any(dim=-1).view(-1)
					else:
						sup_mask =  (box_inds>=0).view(-1)
				else:
					_, grd_weights,noun_mask= get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, labels[:,1:].detach(), vars(self.opt))
					sup_mask =  (noun_mask==1).cuda().view(-1)

				attn_weights = torch.log(torch.clamp(attn_weights,min=self.min_value)).view(-1,attn_weights.size(-1))[sup_mask]

				if self.opt.use_gt_box:
					if self.opt.att_sup_crit == 'KL':
						# Todo
						grd_target = F.softmax(box_inds/0.5,dim=-1).view(-1, box_inds.size(-1))[sup_mask]
						loss2 = self.kl_crit(attn_weights, grd_target)
					elif self.opt.att_sup_crit == 'NLL':
						grd_target = box_inds.reshape(-1)[sup_mask].long()
						loss2 = self.nll(attn_weights,grd_target)
					elif self.opt.att_sup_crit == 'ExtendNLL':
						grd_target = box_inds.reshape(-1, box_inds.size(-1))[sup_mask]
						loss2 = self.extendnll(attn_weights, grd_target)
				else:
					if self.opt.att_sup_crit == 'KL':
						grd_target = torch.clamp(grd_weights[:,:17,:],min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask]
						loss2 = self.kl_crit(attn_weights, grd_target)
					elif self.opt.att_sup_crit == 'NLL':
						grd_target = torch.max(grd_weights[:,:17,:],dim=2)[1].view(-1)[sup_mask]
						loss2 = self.nll(attn_weights,grd_target)
					elif self.opt.att_sup_crit == 'ExtendNLL':
						# grd_target = torch.clamp(grd_weights[:,:17,:],min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask]
						# loss2 = self.extendnll(attn_weights, grd_target)
						raise NotImplementedError
				
				loss=loss1+self.opt.att_supervise_weight*loss2
			else:
				outputs=self.model(fc_feats, att_feats, labels, att_masks)[0]
				loss = self.crit(outputs, labels[:,1:], masks[:,1:])
		else:
			if self.opt.att_supervise:
				gen_result, sample_logprobs, attn_weights = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
			else:
				gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
			gts = [gts[_] for _ in gt_indices.tolist()]

			if self.opt.att_supervise:
				reward, grd_weights, noun_mask= get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt))
			else:
				reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt))
			reward = torch.from_numpy(reward).float().to(gen_result.device)

			if self.opt.att_supervise:
				loss1=self.rl_crit(sample_logprobs, gen_result.data, reward)
				sup_mask =  (noun_mask==1).cuda().view(-1)
				attn_weights = torch.log(torch.clamp(attn_weights,min=self.min_value)).view(-1,attn_weights.size(-1))[sup_mask]
				if self.opt.att_sup_crit == 'KL':
					grd_target = torch.clamp(grd_weights,min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask]
					loss2 = self.kl_crit(attn_weights, grd_target)
				elif self.opt.att_sup_crit == 'NLL':
					grd_target = torch.max(grd_weights,dim=2)[1].view(-1)[sup_mask]
					loss2 = self.nll(attn_weights,grd_target)
				elif self.opt.att_sup_crit == 'ExtendNLL':
					# grd_target = torch.clamp(grd_weights,min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask]
					# loss2 = self.extendnll(attn_weights, grd_target)
					raise NotImplementedError

				loss=loss1+self.opt.att_supervise_weight*loss2
			else:
				loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
			out['reward'] = reward[:,0].mean()
		out['loss'] = loss
		return out
Example #24
	def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
				sc_flag,box_inds, epoch, sents_mask):
		out = {}
		# pdb.set_trace()
		if not sc_flag:
			if self.opt.cexe and epoch >= self.opt.cexe_after:
				if self.opt.sup_nde:
					outputs, outputs_adjust, outputs_nde=self.model(fc_feats, att_feats, labels, att_masks, sents_mask[:,1:])
				else:
					outputs, outputs_adjust=self.model(fc_feats, att_feats, labels, att_masks, sents_mask[:,1:])

				# For now, we only consider visual words.
				adjust_mask = sents_mask[:,1:] == 1
				adjust_mask_expand = adjust_mask.unsqueeze(dim=2).expand(outputs.shape)
				masked_outputs = torch.masked_select(outputs,adjust_mask_expand).view(-1,outputs.shape[2])
				masked_outputs_adjust = torch.masked_select(outputs_adjust,adjust_mask_expand).view(-1,outputs.shape[2])
				# masked_outputs_nde = torch.masked_select(outputs_nde,adjust_mask_expand).view(-1,outputs.shape[2])
				loss1 = self.crit(outputs, labels[:,1:], masks[:,1:])
				if self.opt.sup_tie  and self.opt.sup_nde:
					loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean')
					loss3 = self.nll(masked_outputs_adjust, torch.masked_select(labels[:,1:], adjust_mask))
					loss4 = self.crit(outputs_nde, labels[:,1:], masks[:,1:])
					loss = loss1 + self.opt.cexe_weight * loss2 + self.opt.tie_weight * loss3 + self.opt.nde_weight * loss4
				elif self.opt.sup_tie:
					loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean')
					loss3 = self.nll(masked_outputs_adjust, torch.masked_select(labels[:,1:], adjust_mask))
					loss = loss1 + self.opt.cexe_weight * loss2 + self.opt.tie_weight * loss3
				elif self.opt.sup_nde:
					loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean')
					loss4 = self.crit(outputs_nde, labels[:,1:], masks[:,1:])
					loss = loss1 + self.opt.cexe_weight * loss2  + self.opt.nde_weight * loss4
				else:
					loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean')
					loss = loss1 + self.opt.cexe_weight * loss2
			else:
				outputs=self.model(fc_feats, att_feats, labels, att_masks)[0]
				loss = self.crit(outputs, labels[:,1:], masks[:,1:])
		else:
			if self.opt.cec:
				gen_result, sample_logprobs, outputs, outputs_tie = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
			else:
				gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
			gts = [gts[_] for _ in gt_indices.tolist()]

			reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt))
			reward = torch.from_numpy(reward).float().to(gen_result.device)

			if self.opt.cec:
				loss1 = self.rl_crit(sample_logprobs, gen_result.data, reward)
				sents_mask = make_sents_mask(gen_result, self.opt.vocab)
				adjust_mask = sents_mask == 1
				adjust_mask_expand = adjust_mask.unsqueeze(dim=2).expand(outputs.shape)
				masked_outputs = torch.masked_select(outputs,adjust_mask_expand).view(-1,outputs.shape[2])
				masked_outputs_adjust = torch.masked_select(outputs_tie,adjust_mask_expand).view(-1,outputs.shape[2])
				batch_div = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='none').sum(dim=1)
				masked_reward = torch.masked_select(reward, adjust_mask)
				masked_reward_positive = (masked_reward>0).float()
				loss2 = (batch_div * masked_reward * masked_reward_positive).mean()
				loss = loss1 + self.opt.cec_weight * loss2
			else:
				loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
			out['reward'] = reward[:,0].mean()
		out['loss'] = loss
		return out
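In the RL branch of Example #24, make_sents_mask has to flag which positions of the sampled caption count as visual words before the reward-weighted KL term is applied. Its implementation is not reproduced on this page; the sketch below is only a plausible reading, and the names vocab and visual_words are illustrative assumptions (a token-id-string-to-word mapping and a predefined set of visual words), not identifiers taken from the repository.

import torch

def make_sents_mask_sketch(gen_result, vocab, visual_words):
    # gen_result: (B, T) sampled token ids
    # vocab: dict mapping token-id strings to words
    # visual_words: set of words whose positions should receive the extra loss
    mask = torch.zeros_like(gen_result)
    for b in range(gen_result.size(0)):
        for t in range(gen_result.size(1)):
            if vocab.get(str(gen_result[b, t].item()), '') in visual_words:
                mask[b, t] = 1
    return mask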
Example #25
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    if opt['visdom']:
        viz = visdom.Visdom(env='train')
        loss_win = viz.line(np.arange(1), opts={'title': 'loss'})

    for epoch in range(opt["epochs"]):
        lr_scheduler.step()

        iteration = 0
        # If start self crit training
        # print(opt["self_crit_after"])
        if opt["self_crit_after"] != -1 and epoch >= opt[
                "self_crit_after"]:  #每多少次保存一下
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        # print(model)

        for data in loader:
            # print(data)
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            # voice_feats = data['voice_feats'].cuda()
            if opt['with_hand'] == 1:
                hand_feats = data['hand_feats'].cuda()
                hand_pro = data['hand_pro'].cuda()
            labels = data['labels'].cuda()
            masks = data['masks'].cuda()
            #print(sc_flag)
            optimizer.zero_grad()
            if not sc_flag:
                # seq_probs, _ = model(fc_feats, voice_feats, hand_feats, labels, 'train')
                if opt['with_hand'] == 1:
                    seq_probs, _ = model(fc_feats, hand_feats, hand_pro,
                                         labels, 'train')
                else:
                    seq_probs, _ = model.forward2(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            # TODO: the else branch below has not been updated for the voice and sign-language inputs
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("?iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
                if opt['visdom']:
                    viz.line(Y=np.array([train_loss]),
                             X=np.array([epoch]),
                             win=loss_win,
                             update='append')
            else:
                print("??iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            # print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
Example #26
    def __init__(self, opt):
        super(AttModel, self).__init__()
        self.image_crop_size = opt.image_crop_size
        self.vocab_size = opt.vocab_size
        self.detect_size = opt.detect_size
        self.input_encoding_size = opt.input_encoding_size
        #self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size
        self.att_feat_size = opt.att_feat_size
        self.att_hid_size = opt.att_hid_size
        self.finetune_cnn = opt.finetune_cnn
        self.cbs = opt.cbs
        self.cbs_mode = opt.cbs_mode
        self.seq_per_img = 5
        if opt.cnn_backend == 'vgg16':
            self.stride = 16
        else:
            self.stride = 32

        self.att_size = int(opt.image_crop_size / self.stride)
        self.tiny_value = 1e-8

        self.pool_feat_size = self.att_feat_size + 300 * 2
        self.ss_prob = 0.0  # Schedule sampling probability
        self.min_value = -1e8
        opt.beta = 1
        self.beta = opt.beta
        if opt.cnn_backend == 'res101':
            self.cnn = resnet(opt,
                              _num_layers=101,
                              _fixed_block=opt.fixed_block,
                              pretrained=True)
        elif opt.cnn_backend == 'res152':
            self.cnn = resnet(opt,
                              _num_layers=152,
                              _fixed_block=opt.fixed_block,
                              pretrained=True)
        elif opt.cnn_backend == 'vgg16':
            self.cnn = vgg16(opt, pretrained=True)

        self.det_fc = nn.Sequential(nn.Embedding(self.detect_size + 1, 300),
                                    nn.ReLU(), nn.Dropout())

        self.loc_fc = nn.Sequential(nn.Linear(5, 300), nn.ReLU(), nn.Dropout())

        self.embed = nn.Sequential(
            nn.Embedding(self.vocab_size + self.detect_size + 1,
                         self.input_encoding_size), nn.ReLU(),
            nn.Dropout(self.drop_prob_lm))

        self.fc_embed = nn.Sequential(
            nn.Linear(self.fc_feat_size, self.rnn_size), nn.ReLU(),
            nn.Dropout(self.drop_prob_lm))

        self.att_embed = nn.Sequential(
            nn.Linear(self.att_feat_size, self.rnn_size), nn.ReLU(),
            nn.Dropout(self.drop_prob_lm))

        self.pool_embed = nn.Sequential(
            nn.Linear(self.pool_feat_size, self.rnn_size), nn.ReLU(),
            nn.Dropout(self.drop_prob_lm))

        self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
        self.ctx2pool = nn.Linear(self.rnn_size, self.att_hid_size)

        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
        self.roi_align = RoIAlignAvg(1, 1, 1.0 / self.stride)

        #self.grid_size = 1
        #self.roi_crop = _RoICrop()
        self.critLM = utils.LMCriterion(opt)
        self.critBN = utils.BNCriterion(opt)
        self.critFG = utils.FGCriterion(opt)

        if opt.self_critical:
            print("load reward function...")
            self.get_self_critical_reward = get_self_critical_reward(opt)
            self.critRL = utils.RewardCriterion(opt)

        # initialize the glove weight for the labels.
        self.det_fc[0].weight.data.copy_(opt.glove_clss)
        for p in self.det_fc[0].parameters():
            p.requires_grad = False
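The last lines of this constructor copy GloVe vectors into the detection-label embedding and freeze it by zeroing requires_grad. Assuming opt.glove_clss is a (detect_size + 1) x 300 FloatTensor, the embedding layer itself could equivalently be built with the one-liner below.

import torch.nn as nn

# same effect as copying the weights and switching off requires_grad by hand
det_embed = nn.Embedding.from_pretrained(opt.glove_clss, freeze=True)

Example #27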
def train(dataset,
          loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    writer = SummaryWriter('./runs/video_caption_basic')
    model.load_state_dict(
        torch.load('/home/diml/video-caption.pytorch/save/new_model_200.pth'))
    #model = nn.DataParallel(model)
    model.train()
    vocab = dataset.get_vocab()

    for epoch in trange(300):
        t_loss = 0
        # =============================================================================
        #         model.eval()
        #         ev.demov(model,crit, dataset, dataset.get_vocab(),opt)
        # =============================================================================

        lr_scheduler.step()
        iteration = 0

        # If start self crit training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for idx, data in enumerate(loader):
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            masks = data['masks'].cuda()
            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, seq_preds, hn, de_hn = model(
                    fc_feats, labels, 'train')
                loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:])

                loss = loss_C
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            t_loss += loss.item()
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1
            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch + 201, np.mean(reward[:, 0])))
        writer.add_scalar('training total loss', t_loss / 140, epoch + 200)
        if epoch % opt["save_checkpoint_every"] == 0:

            model_path = os.path.join(opt["checkpoint_path"],
                                      'new_model_%d.pth' % (epoch + 200))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'Rnew_model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))

            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))

        with torch.no_grad():
            _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt)
            print(utils.decode_sequence(vocab, seq_preds)[0])
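The loop above finishes each epoch by decoding the last batch's predictions with utils.decode_sequence. Implementations vary slightly between these repositories; a minimal sketch, assuming vocab maps token-id strings to words and id 0 marks end-of-sequence/padding, is:

def decode_sequence_sketch(vocab, seq):
    # vocab: dict mapping token-id strings to words; seq: (B, T) tensor of ids
    sentences = []
    for row in seq:
        words = []
        for ix in row.tolist():
            if ix == 0:  # stop at the end/pad token
                break
            words.append(vocab[str(ix)])
        sentences.append(' '.join(words))
    return sentences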
Example #28
    def train(self, data, loader, iteration, epoch, nmt_epoch):
        nmt_dec_state = None
        nmt_dec_state_zh = None
        torch.cuda.synchronize()
        self.optim.zero_grad()

        tmp = [
            data['fc_feats'], data['attri_feats'], data['att_feats'],
            data['labels'], data['masks'], data['att_masks'],
            data['nmt'] if self.nmt_train_flag else None
        ]
        tmp = [
            _ if _ is None else
            (Variable(torch.from_numpy(_), requires_grad=False).cuda()
             if utils.under_0_4() else torch.from_numpy(_).cuda()) for _ in tmp
        ]
        fc_feats, attri_feats, att_feats, labels, masks, att_masks, nmt_batch = tmp

        if self.i2t_train_flag:
            if self.update_i2t_lr_flag:
                self.optim.update_LearningRate(
                    'i2t', epoch)  # Assign the learning rate
                self.optim.update_ScheduledSampling_prob(
                    self.opt, epoch,
                    self.dp_i2t_model)  # Assign the scheduled sampling prob
                if self.opt.self_critical_after != -1 and epoch >= self.opt.self_critical_after:
                    # If start self critical training
                    self.sc_flag = True
                    init_scorer(self.opt.cached_tokens)
                else:
                    self.sc_flag = False
                self.update_i2t_lr_flag = False

            if not self.sc_flag:
                i2t_outputs = self.dp_i2t_model(fc_feats, attri_feats,
                                                att_feats, labels, att_masks)
                i2t_loss = self.i2t_crit(i2t_outputs, labels[:, 1:], masks[:,
                                                                           1:])
            else:
                gen_result, sample_logprobs = self.dp_i2t_model(
                    fc_feats,
                    attri_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    mode='sample')
                reward = get_self_critical_reward(self.dp_i2t_model, fc_feats,
                                                  attri_feats, att_feats,
                                                  att_masks, data, gen_result,
                                                  self.opt)
                i2t_loss = self.i2t_rl_crit(
                    sample_logprobs, gen_result.data,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))

                self.i2t_avg_reward = np.mean(reward[:, 0])
            self.i2t_train_loss = i2t_loss.data[0] if utils.under_0_4(
            ) else i2t_loss.item()
            i2t_loss.backward(retain_graph=True)

        if self.nmt_train_flag:
            if self.update_nmt_lr_flag:
                self.optim.update_LearningRate(
                    'nmt', nmt_epoch)  # Assign the learning rate
            outputs, attn, dec_state, upper_bounds = self.dp_nmt_model(
                nmt_batch.src, nmt_batch.tgt, nmt_batch.lengths, nmt_dec_state)
            nmt_loss = self.nmt_crit(loader, nmt_batch, outputs, attn)

            if nmt_dec_state is not None: nmt_dec_state.detach()
            if nmt_dec_state_zh is not None: nmt_dec_state_zh.detach()

            self.nmt_crit.report_stats.n_src_words += nmt_batch.lengths.data.sum(
            )
            self.nmt_train_ppl = self.nmt_crit.report_stats.ppl()
            self.nmt_train_acc = self.nmt_crit.report_stats.accuracy()
            # Minimize the word embedding weights
            # wemb_weight_loss = self.weight_trans(self.i2t_model.embed, self.nmt_encoder.embeddings.word_lut)
            # self.wemb_loss = wemb_weight_loss.data[0]

            nmt_loss.backward(retain_graph=True)
        # if self.nmt_train_flag: wemb_weight_loss.backward(retain_graph=True)
        self.optim.step()
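This trainer backpropagates the captioning loss and the NMT loss separately and only calls self.optim.step() once, so gradients from both objectives accumulate in any shared parameters; retain_graph=True keeps the intermediate buffers alive in case a later backward pass needs them. Stripped of the captioning-specific machinery, the pattern is just (generic module, illustrative only):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8)

optimizer.zero_grad()
loss_a = model(x).pow(2).mean()   # first objective
loss_b = model(x).abs().mean()    # second objective (separate forward pass)
loss_a.backward()
loss_b.backward()                 # gradients add onto the ones from loss_a
optimizer.step()                  # single update with the summed gradients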
Example #29
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model




    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    #TODO: change this to a flag
    crit = utils.LanguageModelCriterion_binary()
    rl_crit = utils.RewardCriterion_binary()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
                # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 10000:
            loader.reset_iterator('train')
            continue
        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:], dp_model.depth,
                        dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size)
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'arm':
                loss = dp_model.get_arm_loss_binary_fast(fc_feats, att_feats, att_masks, opt, data, loader)
                #print(loss)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'rf4':
                loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2,2])
            elif opt.rl_type =='mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                         opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.999 * first_order + 0.001 * gradient
        second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start


        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_binary.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k,v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth')
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time_history'] = time_history
                # histories['variance'] = 0
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
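Example #29 additionally estimates the variance of the policy gradient by keeping exponential moving averages of the flattened gradient and of its elementwise square, then logging mean |E[g^2] - E[g]^2|. Pulled out of the training loop, that estimator is:

import torch

def update_gradient_variance(model, first_order, second_order, decay=0.999):
    # concatenate all parameter gradients into one vector, as in the loop above
    gradient = torch.cat([p.grad.detach().view(-1)
                          for p in model.parameters() if p.grad is not None])
    first_order = decay * first_order + (1 - decay) * gradient
    second_order = decay * second_order + (1 - decay) * gradient.pow(2)
    variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
    return first_order, second_order, variance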
Example #30
def train(opt):
    logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log'))
    print = logger.info

    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Print out the option variables
    print("*" * 20)
    for k, v in opt.__dict__.items():
        print("%r: %r" % (k, v))
    print("*" * 20)

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f:
            infos = json.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    start_time = time.time()
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if 0 <= opt.learning_rate_decay_start < epoch:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if 0 <= opt.scheduled_sampling_start < epoch:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer()
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        batch_data = loader.get_batch('train')
        torch.cuda.synchronize()

        tmp = [
            batch_data['fc_feats'], batch_data['att_feats'],
            batch_data['labels'], batch_data['masks'], batch_data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            outputs = dp_model(fc_feats, att_feats, labels, att_masks)
            loss = crit(outputs, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, batch_data,
                                              gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        if batch_data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print train loss or avg reward
        if iteration % opt.losses_print_every == 0:
            if not sc_flag:
                print(
                    "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format(
                        iteration, epoch, loss.item(),
                        time.time() - start_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time = {:.3f}".
                      format(iteration, epoch, np.mean(reward[:, 0]),
                             time.time() - start_time))
            start_time = time.time()

        # make evaluation on validation set, and save model
        if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\
                or (opt.save_checkpoint_every <= 0 and update_lr_flag):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.simple_eval_split(
                dp_model, loader, eval_kwargs)

            # Save model if is improving on validation result
            if not os.path.exists(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = vars(opt)
            infos['vocab'] = loader.get_vocab()

            with open(os.path.join(opt.checkpoint_path, 'infos.json'),
                      'w') as f:
                json.dump(infos, f, sort_keys=True, indent=4)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.json'),
                          'w') as f:
                    json.dump(infos, f, sort_keys=True, indent=4)

            # Stop if reaching max epochs
            if opt.max_epochs != -1 and epoch >= opt.max_epochs:
                break
Example #31
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories.pkl')):
            with open(os.path.join(opt.start_from, 'histories.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        fc_feats = fc_feats.unsqueeze(1).expand(*((
            fc_feats.size(0),
            opt.seq_per_img,
        ) + fc_feats.size()[1:])).contiguous().view(
            *((fc_feats.size(0) * opt.seq_per_img, ) + fc_feats.size()[1:]))
        att_feats = att_feats.unsqueeze(1).expand(*((
            att_feats.size(0),
            opt.seq_per_img,
        ) + att_feats.size()[1:])).contiguous().view(
            *((att_feats.size(0) * opt.seq_per_img, ) + att_feats.size()[1:]))

        optimizer.zero_grad()
        outputs = model(fc_feats, att_feats, labels)
        if opt.caption_model == 'stack_cap':
            loss_coarse = crit(outputs[0], labels[:, 1:], masks[:, 1:])
            loss_fine_0 = crit(outputs[1], labels[:, 1:], masks[:, 1:])
            loss_fine_1 = crit(outputs[-1], labels[:, 1:], masks[:, 1:])
            loss = loss_fine_1 + loss_coarse + loss_fine_0
        else:
            if not sc_flag:
                loss = crit(outputs, labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(
                    fc_feats, att_feats, {'sample_max': 0})
                reward = get_self_critical_reward(model, fc_feats, att_feats,
                                                  data, gen_result)
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if opt.caption_model == 'stack_cap':
            print("{}|I:{}/E:{}|Tloss_0:{:.3f}/Tloss_1:{:.3f}/Tloss_2:{:.3f}|T/B={:.3f}" \
                .format(opt.caption_model, iteration, epoch, loss_coarse.data[0], loss_fine_0.data[0], loss_fine_1.data[0], end - start))
        else:
            if not sc_flag:
                print("{}|I:{}/E:{}|Train_loss:{:.3f}|T/B={:.3f}".format(
                    opt.caption_model, iteration, epoch, loss.data[0],
                    end - start))
            else:
                print("{}|I:{}/E:{}|Avg_reward:{:.3f}|T/B={:.3f}".format(
                    opt.caption_model, iteration, epoch, np.mean(reward[:, 0]),
                    end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                if opt.caption_model == 'stack_cap':
                    add_summary_value(tf_summary_writer, 'train_loss_coarse',
                                      loss_coarse.data[0], iteration)
                    add_summary_value(tf_summary_writer, 'train_loss_fine_0',
                                      loss_fine_0.data[0], iteration)
                    add_summary_value(tf_summary_writer, 'train_loss_fine_1',
                                      loss_fine_1.data[0], iteration)
                else:
                    add_summary_value(tf_summary_writer, 'train_loss',
                                      loss.data[0], iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                opt, model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
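Example #31 tiles fc_feats and att_feats so that each image's features are repeated once per ground-truth caption (opt.seq_per_img) using the unsqueeze/expand/contiguous/view chain above. With a newer PyTorch the same reshaping can be written with repeat_interleave; the small check below (random features, seq_per_img chosen arbitrarily) confirms the two are equivalent:

import torch

fc_feats = torch.randn(2, 10)   # (batch, feat_dim)
seq_per_img = 5

expanded = fc_feats.unsqueeze(1).expand(
    fc_feats.size(0), seq_per_img, fc_feats.size(1)
).contiguous().view(fc_feats.size(0) * seq_per_img, fc_feats.size(1))

tiled = fc_feats.repeat_interleave(seq_per_img, dim=0)
assert torch.equal(expanded, tiled)  # each row repeated seq_per_img times in place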