Code Example #1
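Shared setup for an image-captioning trainer: it seeds PyTorch, derives the attention and box-feature settings from the options, builds the data loaders (plus an optional COCO evaluation loader), and, when `opt.start_from` is given, restores `infos`/`histories` from `infos-best.pkl`/`histories-best.pkl` after checking that the saved model options match the command line. Like the other excerpts on this page, it assumes its repository's usual imports (`os`, `torch`, `cPickle`, `misc.utils`, a `DataLoader`); Example #8 below shows a typical import block.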
def init(opt):
    infos = {}
    histories = {}

    if opt.seed > 0: torch.manual_seed(opt.seed)
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5  # box means [prob, x,y,w,h]
    loader = DataLoader(opt)  # Load image captioning dataset
    coco_loader = DataLoader_COCO(opt) if opt.coco_eval_flag else None
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    if opt.start_from is not None and len(opt.start_from) > 0:
        print('Start from: {}, Infos: {}'.format(
            opt.start_from, os.path.join(opt.start_from, 'infos-best.pkl')))
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos-best.pkl'), 'rb') as f:  # pickle files need binary mode
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories-best.pkl')):
            with open(os.path.join(opt.start_from, 'histories-best.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    return opt, loader, coco_loader, infos, histories
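Every example on this page gates attention features through `utils.if_use_att`, which is not itself shown here. Judging from the call sites (it receives only `opt.caption_model` and returns a boolean), a minimal sketch could look like the following; the exact list of non-attention models is an assumption:

def if_use_att(caption_model):
    # Assumed rule: only plain CNN-to-RNN decoders skip attention features;
    # every attention-based model (e.g. 'topdown') loads them.
    return caption_model not in ('show_tell', 'all_img', 'fc')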
Code Example #2
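A smoke test for the data pipeline: it fills in a `topdown` configuration by hand, constructs the `DataLoader`, and pulls one batch each from the train and val splits.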
def main():
    import opts    
    import misc.utils as utils        
    opt = opts.parse_opt()
    opt.caption_model = 'topdown'
    opt.batch_size = 10
    opt.id = 'topdown'
    opt.learning_rate = 5e-4
    opt.learning_rate_decay_start = 0
    opt.scheduled_sampling_start = 0
    opt.save_checkpoint_every = 25  # 11500
    opt.val_images_use = 5000
    opt.max_epochs = 40
    opt.start_from = None
    opt.input_json = 'data/meta_coco_en.json'
    opt.input_label_h5 = 'data/label_coco_en.h5'
    opt.input_image_h5 = 'data/coco_image_512.h5'
    opt.use_att = utils.if_use_att(opt.caption_model)
    opt.ccg = False
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length    
    data = loader.get_batch('train')  # fetch one training batch

    data = loader.get_batch('val')  # fetch one validation batch
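The dict returned by `get_batch` is what the training loops further down this page index with keys such as `fc_feats`, `att_feats`, `labels`, `masks`, and `bounds`.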
Code Example #3
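An inspection routine: it restores a model from `opt.start_from`, then runs `eval_utils.eval_split` on a dedicated 'show' split with beam size 5 and all beams printed.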
def show(opt):

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt, is_show=True)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "! Command line argument and saved model disagree on '%s' " % checkme

    model = models.setup(opt)
    model.cuda()
    model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model-best.pth')))

    crit = utils.LanguageModelCriterion()

    # eval model
    eval_kwargs = {}
    eval_kwargs.update(vars(opt))
    eval_kwargs.update({'split': 'show',
                        'dataset': opt.input_json,
                        'language_eval': 0,
                        'beam_size': 5,
                        'print_all_beam': True})
    val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)
Code Example #4
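Entry point for a referring-expression pipeline on top of Mask R-CNN: it derives the `dataset_splitBy` paths, seeds the RNGs, builds a `CycleLoader` from preprocessed JSON/HDF5, optionally restores `infos` from a checkpoint, and passes a ResNet-101 backbone to `train_net` together with a pretrained Mask R-CNN checkpoint.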
def main(args):
  opt = vars(args)
  
  # initialize
  opt['dataset_splitBy'] = opt['dataset'] + '_' + opt['splitBy']
  checkpoint_dir = osp.join(opt['checkpoint_path'], opt['dataset_splitBy'])
  if not osp.isdir(checkpoint_dir): os.makedirs(checkpoint_dir)

  # set random seed
  torch.manual_seed(opt['seed'])
  random.seed(opt['seed'])
  
  # set up loader
  data_json = osp.join('cache/prepro', opt['dataset_splitBy'], 'data.json')
  data_h5 = osp.join('cache/prepro', opt['dataset_splitBy'], 'data.h5')
  
  loader = CycleLoader(data_json, data_h5) ####
  
  # set up model
  opt['vocab_size'] = loader.vocab_size
  #opt['fc7_dim']   = loader.fc7_dim
  #opt['pool5_dim'] = loader.pool5_dim
  #opt['num_atts']  = loader.num_atts
  #model = JointMatching(opt)
  opt['C4_feat_dim'] = 1024
  opt['use_att'] = utils.if_use_att(opt['caption_model'])
  opt['seq_length'] = loader.label_length
  
  
  #### can change to restore opt from info.pkl
  #infos = {}
  #histories = {}
  if opt['start_from'] is not None:
    # open old infos and check if models are compatible
    with open(os.path.join(opt['dataset_splitBy'], opt['start_from'], 'infos-best.pkl'), 'rb') as f:
      infos = cPickle.load(f)
      saved_model_opt = infos['opt']
      need_be_same = ['caption_model', 'rnn_type', 'rnn_size', 'num_layers']
      for checkme in need_be_same:
        assert vars(saved_model_opt)[checkme] == opt[checkme], "Command line argument and saved model disagree on '%s'" % checkme

    #if os.path.isfile(os.path.join(opt['dataset_splitBy'], opt['start_from'], 'histories.pkl')):
    #  with open(os.path.join(opt['dataset_splitBy'], opt['start_from'], 'histories.pkl')) as f:
    #    histories = cPickle.load(f)
  
  
  
  net = resnetv1(opt, batch_size=opt['batch_size'], num_layers=101) #### determine batch size in opt.py
  
  # output directory where the models are saved
  output_dir = osp.join(opt['dataset_splitBy'], 'output_{}'.format(opt['output_postfix']))
  print('Output will be saved to `{:s}`'.format(output_dir))

  # tensorboard directory where the summaries are saved during training
  tb_dir = osp.join(opt['dataset_splitBy'], 'tb_{}'.format(opt['output_postfix']))
  print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

  # also add the validation set, but with no flipping images
  orgflip = cfg.TRAIN.USE_FLIPPED
  cfg.TRAIN.USE_FLIPPED = False
  #_, valroidb = combined_roidb('coco_2014_minival')
  #_, valroidb = combined_roidb('refcoco_test')
  #print('{:d} validation roidb entries'.format(len(valroidb)))
  cfg.TRAIN.USE_FLIPPED = orgflip
  
  if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)
  if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)
  
  
  #train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
  train_net(net, loader, output_dir, tb_dir,
            pretrained_model='pyutils/mask-faster-rcnn/output/res101/coco_2014_train_minus_refer_valtest+coco_2014_valminusminival/notime/res101_mask_rcnn_iter_1250000.pth',
            max_iters=args.max_iters)
Code Example #5
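Self-critical fine-tuning of a caption decoder: a pretrained VSE model scores sampled captions, and the policy-gradient reward combines a CIDEr-based term (`cs_reward`) with weighted discriminability penalties (`hd_reward`, `bd_reward`) before the usual checkpoint/evaluation cycle.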
def train(opt):

    """
    :param   caption decoder
    :param   VSE model : image encoder + caption encoder
    """

    """   loading VSE model    """
    # Construct the model
    vse = VSE(opt)
    opt.best = os.path.join('./vse/model_best.pth.tar')
    print("=> loading best checkpoint '{}'".format(opt.best))
    checkpoint = torch.load(opt.best)
    vse.load_state_dict(checkpoint['model'])
    vse.val_start()

    """   loading caption model    """
    opt.use_att = utils.if_use_att(opt.caption_model)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    save_path = os.path.join(opt.checkpoint_path,'CSGD')

    if not os.path.exists(save_path):
        os.makedirs(save_path, 0o777)  # 0777 is a Python 2 octal literal; 0o777 works in both

    infos = {}
    histories = {}

    RL_trainmodel = os.path.join('RL_%s' % opt.caption_model)
    
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        start_from_path = os.path.join(opt.start_from, 'CSGD')
        with open(os.path.join(start_from_path, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(start_from_path, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(start_from_path, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    with open(os.path.join(RL_trainmodel, 'MLE', 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
        infos_XE = cPickle.load(f)
        opt.learning_rate = infos_XE['opt'].current_lr
    
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    print(loader.iterators)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup_pro(opt)

    if vars(opt).get('start_from', None) is not None:

        start_from_path = os.path.join(opt.start_from, 'CSGD')
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from), "%s must be a path" % opt.start_from
        assert os.path.isfile(os.path.join(start_from_path, "infos_" + opt.id + ".pkl")), \
            "infos.pkl file does not exist in path %s" % opt.start_from
        assert os.path.isfile(os.path.join(start_from_path, "optimizer.pth")), \
            "optimizer.pth file does not exist in path %s" % opt.start_from

        model_path = os.path.join(start_from_path,'model.pth')
        optimizer_path = os.path.join(start_from_path,'optimizer.pth')

    else:
        model_path = os.path.join(RL_trainmodel,'MLE', 'model-best.pth')
        optimizer_path = os.path.join(RL_trainmodel,'MLE','optimizer-best.pth')

    model.load_state_dict(torch.load(model_path))
    print("model load from {}".format(model_path))  
  
    model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
  
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer.load_state_dict(torch.load(optimizer_path))
    print("optimizer load from {}".format(optimizer_path))   

    all_cider = 0 # for computing the average CIDEr score
    all_dis = 0 # for computing the discriminability percentage

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            model.ss_prob = 0.25
            print('learning_rate: %s' %str(opt.current_lr))
            update_lr_flag = False

            # start self critical training
            sc_flag = True

        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['knn_fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['knn_att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]]

        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, knn_fc_feats, knn_att_feats = tmp

        optimizer.zero_grad()

        gen_result, sample_logprobs = model.sample_score(fc_feats, att_feats, loader, {'sample_max': 0})
        gen_result_baseline, sample_b_logprobs = model.sample_score(fc_feats, att_feats, loader, {'sample_max': 0})
        bd_reward, sample_loss = get_bd_reward(vse, model, fc_feats, att_feats, data, gen_result,gen_result_baseline, loader)
        hd_reward = get_hd_reward(vse, model, fc_feats, knn_fc_feats, data, gen_result,gen_result_baseline, loader)

        cs_reward, m_cider = get_cs_reward(model, fc_feats, att_feats, data, gen_result, gen_result_baseline, loader)
        reward = cs_reward - opt.hdr_w * hd_reward - opt.bdr_w * bd_reward
        loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False))
        dis_number = (sample_loss < 0.4).float()
        dis_number = dis_number.data.cpu().numpy().sum()
        all_dis += dis_number
        all_cider += m_cider

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]

        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), hdr = {:.3f}, bdr = {:.3f}, csr = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, np.mean(hd_reward[:,0]), np.mean(bd_reward[:,0]), np.mean(cs_reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):

            loss_history[iteration] = np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = evalpro_utils.eval_split(model, crit, loader, eval_kwargs)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                save_path1 = os.path.join(save_path, 'model.pth')
                if not os.path.exists(os.path.dirname(save_path1)):
                    os.makedirs(os.path.dirname(save_path1))
                torch.save(model.state_dict(), save_path1)
                print("model saved to {}".format(save_path1))

                optimizer_path1 = os.path.join(save_path, 'optimizer.pth')
                if not os.path.exists(os.path.dirname(optimizer_path1)):
                    os.makedirs(os.path.dirname(optimizer_path1))
                torch.save(optimizer.state_dict(), optimizer_path1)
                print("optimizer saved to {}".format(optimizer_path1))

                all_dis = all_dis / opt.save_checkpoint_every
                print("all_dis:%f" %all_dis)
                infos['all_dis'] = all_dis

                all_cider = all_cider / opt.save_checkpoint_every
                print("all_cider:%f" %all_cider)
                infos['all_cider'] = all_cider

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(save_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(save_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    save_path2 = os.path.join(save_path, 'model-best.pth')
                    torch.save(model.state_dict(), save_path2)
                    optimizer_path2 = os.path.join(save_path, 'optimizer-best.pth')
                    torch.save(optimizer.state_dict(), optimizer_path2)
                    print("model saved to {}".format(save_path2))
                    with open(os.path.join(save_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)
                    with open(os.path.join(save_path,'histories_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
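The `utils.RewardCriterion` used above is not reproduced on this page. A minimal sketch of the policy-gradient loss it is typically given (per-token sample log-probabilities, the sampled sequence, and a per-token reward) could look like this; treat the exact masking convention as an assumption:

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    """REINFORCE-style loss: -log p(word) * reward, averaged over real tokens."""

    def forward(self, logprobs, seq, reward):
        # keep every sampled token plus the first end-of-sequence token
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        loss = -logprobs.reshape(-1) * reward.reshape(-1) * mask.reshape(-1)
        return torch.sum(loss) / torch.sum(mask)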
Code Example #6
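A complete cross-entropy/self-critical trainer with gradient accumulation: checkpoints go to a timestamped directory, summaries are written through `tf.compat.v1`, and the input pipeline optionally carries image and sentence scene graphs (`isg_data`/`ssg_data`).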
def train(opt):
    if vars(opt).get('start_from', None) is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    # opt.use_att = False
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    opt.vocab_size = loader.vocab_size
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
    except Exception:  # avoid a bare except that would also swallow KeyboardInterrupt
        print('Set tensorboard error!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.checkpoint_path, 'histories.pkl')):
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    print('### Model summary below ###\n{}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameters: {}'.format(model_params))

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = None
        att_feats = None
        att_masks = None
        ssg_data = None
        rela_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                tmp = [
                    data['fc_feats'], data['labels'], data['masks'],
                    data['att_feats'], data['att_masks'],
                    data['isg_rela_matrix'], data['isg_rela_masks'],
                    data['isg_obj'], data['isg_obj_masks'], data['isg_attr'],
                    data['isg_attr_masks'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks']
                ]

                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, labels, masks, att_feats, att_masks, \
                isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \
                ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp

                # image graph domain
                isg_data = {}
                isg_data['att_feats'] = att_feats
                isg_data['att_masks'] = att_masks

                isg_data['isg_rela_matrix'] = isg_rela_matrix
                isg_data['isg_rela_masks'] = isg_rela_masks
                isg_data['isg_obj'] = isg_obj
                isg_data['isg_obj_masks'] = isg_obj_masks
                isg_data['isg_attr'] = isg_attr
                isg_data['isg_attr_masks'] = isg_attr_masks
                # text graph domain
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
            else:
                tmp = [
                    data['fc_feats'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks'], data['labels'], data['masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr

                isg_data = None
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        if not sc_flag:
            # loss = crit(dp_model(model_zh,model_en,itow_zh,itow, fc_feats, labels, isg_data, ssg_data), labels[:, 1:], masks[:, 1:])
            loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        # if (iteration %10 == 0) and (iteration != 0):
        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            # eval model
            if use_rela:
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils.eval_split(model_zh,model_en,itow_zh,itow, dp_model, crit, loader, eval_kwargs)
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                          'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
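Each trainer on this page steps the learning rate through `utils.set_lr` with the same step-decay rule. The helper is not listed here; a minimal sketch of the rule and the PyTorch-level update, with illustrative numbers, is:

def set_lr(optimizer, lr):
    # write the new rate into every parameter group
    for group in optimizer.param_groups:
        group['lr'] = lr

# Step decay as computed in the loops above:
#   lr = base_lr * decay_rate ** ((epoch - decay_start) // decay_every)
# e.g. base_lr=5e-4, decay_rate=0.8, decay_every=3, decay_start=0:
#   epoch 10 -> frac = 10 // 3 = 3 -> lr = 5e-4 * 0.8**3 = 2.56e-4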
Code Example #7
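A plain cross-entropy trainer that also tracks the minimum validation loss and, once `max_epochs` is reached, runs a final evaluation and plots the training-loss curve with matplotlib.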
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    train_loss_list = []
    min_val_loss = float('inf')  # track the lowest validation loss across the whole run
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))
        train_loss_list.append(train_loss)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        # (min_val_loss is initialized once before the loop; resetting it here
        # every iteration would defeat the tracking)
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, temp_min = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)
            if (temp_min < min_val_loss):
                min_val_loss = temp_min
            # Write validation result into summary
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }
            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, temp_min = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)
            if (temp_min < min_val_loss):
                min_val_loss = temp_min
            # Write validation result into summary
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }
            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

            plt.title('Training loss')
            plt.xlabel('iteration')
            plt.ylabel('loss')
            plt.plot(np.arange(1, len(train_loss_list) + 1), train_loss_list)
            savefilename = 'lab3-1.jpg'
            plt.savefig(savefilename)
            plt.close()

            print("min loss is :", min_val_loss)
            break
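`utils.LanguageModelCriterion`, the `crit` in every loop above, is likewise not reproduced on this page. It is a masked token-level negative log-likelihood; a minimal sketch consistent with the `crit(outputs, labels[:, 1:], masks[:, 1:])` call sites (where `outputs` holds per-word log-probabilities of shape batch x time x vocab) might be:

import torch
import torch.nn as nn

class LanguageModelCriterion(nn.Module):
    """Masked NLL: average -log p(gold word) over real (unpadded) tokens."""

    def forward(self, logprobs, target, mask):
        # truncate targets/masks to the generated length
        target = target[:, :logprobs.size(1)]
        mask = mask[:, :logprobs.size(1)].float()
        # pick out the log-probability assigned to each gold word
        nll = -logprobs.gather(2, target.long().unsqueeze(2)).squeeze(2)
        return torch.sum(nll * mask) / torch.sum(mask)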
Code Example #8
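Script-style (module-level) setup: it parses options, builds the loader, and begins restoring `infos` from `opt.start_from`; the excerpt breaks off mid-statement.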
import time
import os
from six.moves import cPickle

import opts
import models
from dataloader import *
import eval_utils
import misc.utils as utils
from misc.rewards import init_scorer, get_self_critical_reward

opt = opts.parse_opt()

# Setting up model
opt.use_att = utils.if_use_att(opt.caption_model)
if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

loader = DataLoader(opt)
opt.vocab_size = loader.vocab_size
opt.seq_length = loader.seq_length

print(opt.checkpoint_path)

infos = {}
histories = {}
if opt.start_from is not None:
    # open old infos and check if models are compatible
    print(os.getcwd())
    with open(
            os.path.join(os.getcwd(), opt.start_from,
Code Example #9
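Trainer for a two-sentence visual-paraphrase (DCVP) model: the RNGs are fully seeded for reproducibility, batches come from `dataloader_pair`, stdout can be redirected to a log file, and flags select training on the first sentence, the second sentence, or both; self-critical training is deliberately left unimplemented.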
def train(opt):
    import random
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(0)
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    from dataloader_pair import DataLoader

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    if opt.log_to_file:
        if os.path.exists(os.path.join(opt.checkpoint_path, 'log')):
            suffix = time.strftime("%Y-%m-%d %X", time.localtime())
            print('Warning !!! %s already exists ! use suffix ! ' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(
                os.path.join(opt.checkpoint_path, 'log' + suffix), "w")
        else:
            print('logging to file %s' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(os.path.join(opt.checkpoint_path, 'log'), "w")

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if os.path.isfile(opt.start_from):
            with open(os.path.join(opt.infos), 'rb') as f:
                infos = cPickle.load(f)
                saved_model_opt = infos['opt']
                need_be_same = [
                    "caption_model", "rnn_type", "rnn_size", "num_layers"
                ]
                for checkme in need_be_same:
                    assert vars(saved_model_opt)[checkme] == vars(
                        opt
                    )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        else:
            if opt.load_best != 0:
                print('loading best info')
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl'), 'rb') as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
            else:
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl'), 'rb') as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                try:
                    histories = cPickle.load(f)
                except Exception:
                    print('load history error!')
                    histories = {}

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    start_epoch = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    if opt.caption_model == 'att2in2p':
        optimized = [
            'logit2', 'ctx2att2', 'core2', 'prev_sent_emb', 'prev_sent_wrap'
        ]
        optimized_param = []
        optimized_param1 = []

        for name, param in model.named_parameters():
            second = False
            for n in optimized:
                if n in name:
                    print('second', name)
                    optimized_param.append(param)
                    second = True
            if 'embed' in name:
                print('all', name)
                optimized_param1.append(param)
                optimized_param.append(param)
            elif not second:
                print('first', name)
                optimized_param1.append(param)

    while True:
        if opt.val_only:
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            print('start evaluating')
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)
            exit(0)
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['pair_fc_feats'], data['pair_att_feats'], data['pair_labels'],
            data['pair_masks'], data['pair_att_masks']
        ]

        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        masks = masks.float()

        optimizer.zero_grad()

        if not sc_flag:
            if opt.onlysecond:
                # only using the second sentence from a visual paraphrase pair. opt.caption_model should be a one-stage decoding model
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 1, :], att_masks),
                    labels[:, 1, 1:], masks[:, 1, 1:])
                loss1 = loss2 = loss / 2
            elif opt.first:
                # using the first sentence
                tmp = [
                    data['first_fc_feats'], data['first_att_feats'],
                    data['first_labels'], data['first_masks'],
                    data['first_att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, labels, masks, att_masks = tmp
                masks = masks.float()
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, :], att_masks),
                    labels[:, 1:], masks[:, 1:])
                loss1 = loss2 = loss / 2
            elif opt.onlyfirst:
                # only using the first sentence from a visual paraphrase pair
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 0, :], att_masks),
                    labels[:, 0, 1:], masks[:, 0, 1:])
                loss1 = loss2 = loss / 2
            else:
                # proposed DCVP model, opt.caption_model should be att2inp
                output1, output2 = dp_model(fc_feats, att_feats, labels,
                                            att_masks, masks[:, 0, 1:])
                loss1 = crit(output1, labels[:, 0, 1:], masks[:, 0, 1:])
                loss2 = crit(output2, labels[:, 1, 1:], masks[:, 1, 1:])
                loss = loss1 + loss2

        else:
            raise NotImplementedError
            # Our DCVP model does not support self-critical sequence training
            # We found that RL(SCST) with CIDEr reward will improve conventional metrics (BLEU, CIDEr, etc.)
            # but harm diversity and descriptiveness
            # Please refer to the paper for the details

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, loss1 = {:.3f}, loss2 = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, loss.item(), loss1.item(), loss2.item(), end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        sys.stdout.flush()
        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # checkpoint unconditionally
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)
                checkpoint_path = os.path.join(
                    opt.checkpoint_path, 'model' + str(iteration) + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '_' + str(iteration) + '.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
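Every training loop in these examples clips gradients with utils.clip_gradient(optimizer, opt.grad_clip) before optimizer.step(). The helper itself is not shown in the excerpts; below is a minimal sketch, assuming element-wise value clipping as in common image-captioning codebases (the actual misc/utils.py may differ):

def clip_gradient(optimizer, grad_clip):
    # Clamp every gradient element into [-grad_clip, grad_clip] in place.
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)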
Code example #11
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    def ids_to_sents(ids):
        return utils.decode_sequence(loader, ids)

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    def load_infos(dirname=opt.start_from, suffix=''):
        # open old infos and check if models are compatible
        with open(os.path.join(dirname,
                               'infos_{}{}.pkl'.format(opt.id, suffix))) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s'" % checkme
        return infos

    def load_histories(dirname=opt.start_from, suffix=''):
        histories = {}
        path = os.path.join(dirname, 'histories_{}{}.pkl'.format(opt.id, suffix))
        if os.path.isfile(path):
            with open(path) as f:
                histories = cPickle.load(f)
        return histories

    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos = load_infos()
        histories = load_histories()

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    crit_ce = utils.LanguageModelCriterion()
    crit_mb = mBLEU(4)

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def eval_model():
        model.eval()

        eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))
        val_loss, predictions, lang_stats = eval_utils.eval_split(
            model, crit_ce, loader, eval_kwargs)

        # Write validation result into summary
        if tf is not None:
            add_summary_value(tf_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in lang_stats.items():
                add_summary_value(tf_summary_writer, k, v, iteration)
            tf_summary_writer.flush()

        model.train()

        return val_loss, predictions, lang_stats

    eval_model()

    opt.current_teach_mask_prefix_length = opt.teach_mask_prefix_length

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Assign the teach mask prefix length
            if epoch > opt.teach_mask_prefix_length_increase_start:
                frac = (epoch - opt.teach_mask_prefix_length_increase_start
                        ) // opt.teach_mask_prefix_length_increase_every
                opt.current_teach_mask_prefix_length = opt.teach_mask_prefix_length + frac * opt.teach_mask_prefix_length_increase_steps
            update_lr_flag = False

        verbose = (iteration % opt.verbose_iters == 0)

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        if iteration % opt.print_iters == 0:
            print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        teach_mask = utils.make_teach_mask(labels.size(1), opt)
        enable_ce = (opt.bleu_w != 1)
        enable_mb = (opt.bleu_w != 0)
        if enable_ce:
            enable_xe = (opt.xe_w != 0)
            enable_pg = (opt.pg_w != 0)
            if enable_xe:
                logits = model(
                    fc_feats,
                    att_feats,
                    labels,
                    teach_mask=(teach_mask if opt.teach_ce
                                and not opt.teach_all_input else None))
                if opt.teach_ce:
                    decode_length = logits.shape[1] + 1
                    teach_mask = teach_mask[:decode_length]
                    onehot = utils.to_onehot(labels[:, :decode_length],
                                             logits.shape[-1],
                                             dtype=torch.float)
                    probs = torch.exp(logits)
                    probs = torch.cat([onehot[:, :1], probs], 1)
                    probs = utils.mask_probs(probs, onehot, teach_mask)
                    if verbose:
                        verbose_probs = probs
                        verbose_probs.retain_grad()
                    logits = torch.log(1. - (1. - 1e-6) * (1. - probs))[:, 1:]
                loss_xe = crit_ce(logits, labels[:, 1:], masks[:, 1:])
            else:
                loss_xe = 0.
            if enable_pg:
                ids_sample, logprobs_sample = model.sample(
                    fc_feats, att_feats, opt={'sample_max': 0})
                ids_greedy, logprobs_greedy = model.sample(
                    fc_feats, att_feats, opt={'sample_max': 1})
                seq_sample = utils.tolist(ids_sample)
                seq_greedy = utils.tolist(ids_greedy)
                seq_target = utils.tolist(labels[:, 1:])
                rewards = [
                    sentence_bleu([t], s, smooth=True) -
                    sentence_bleu([t], g, smooth=True)
                    for s, g, t in zip(seq_sample, seq_greedy, seq_target)
                ]
                rewards = torch.tensor(rewards, device='cuda')
                mask_sample = torch.ne(ids_sample,
                                       torch.tensor(0, device='cuda')).float()
                loss_pg = (rewards *
                           (logprobs_sample * mask_sample).sum(1)).mean()
            else:
                loss_pg = 0.
            loss_ce = opt.xe_w * loss_xe + opt.pg_w * loss_pg
        else:
            loss_ce = 0.
        if enable_mb:
            logits = model(
                fc_feats,
                att_feats,
                labels,
                teach_mask=(teach_mask if not opt.teach_all_input else None))
            decode_length = logits.shape[1] + 1
            teach_mask = teach_mask[:decode_length]
            onehot = utils.to_onehot(labels[:, :decode_length],
                                     logits.shape[-1],
                                     dtype=torch.float)
            probs = torch.exp(logits)
            probs = torch.cat([onehot[:, :1], probs], 1)  # pad bos
            probs = utils.mask_probs(probs, onehot, teach_mask)
            if verbose:
                verbose_probs = probs
                verbose_probs.retain_grad()
            mask = masks[:, :decode_length]
            mask = torch.cat([mask[:, :1], mask], 1)
            loss_mb = crit_mb(probs,
                              labels[:, :decode_length],
                              mask,
                              min_fn=opt.min_fn,
                              min_c=opt.min_c,
                              verbose=verbose)
        else:
            loss_mb = 0.
        loss = loss_ce * (1 - opt.bleu_w) + loss_mb * opt.bleu_w
        loss.backward()
        utils.clip_gradient(
            optimizer,
            opt.grad_clip)  #TODO: examine clip method and record grad

        if verbose and 'verbose_probs' in locals():
            max_grads, max_ids = verbose_probs.grad.topk(opt.verbose_topk,
                                                         -1,
                                                         largest=False)
            max_probs = torch.gather(verbose_probs, -1, max_ids)
            max_sents = ids_to_sents(max_ids[:, :, 0])
            for sample_i in range(min(opt.samples, verbose_probs.shape[0])):
                l = len(max_sents[sample_i]) + 1
                print('max:\n{}'.format(max_sents[sample_i]))
                print('max probs:\n{}'.format(max_probs[sample_i][:l]))
                print('max grads:\n{}'.format(max_grads[sample_i][:l]))

        optimizer.step()
        train_loss = float(loss)
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_iters == 0:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            val_loss, predictions, lang_stats = eval_model()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            if True:  # checkpoint unconditionally
                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                def save_model(suffix=''):
                    model_path = os.path.join(opt.checkpoint_path,
                                              'model{}.pth'.format(suffix))
                    torch.save(model.state_dict(), model_path)
                    print("model saved to {}".format(model_path))
                    optimizer_path = os.path.join(
                        opt.checkpoint_path, 'optimizer{}.pth'.format(suffix))
                    torch.save(optimizer.state_dict(), optimizer_path)

                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'infos_{}{}.pkl'.format(opt.id, suffix)),
                            'wb') as f:
                        cPickle.dump(infos, f)
                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'histories_{}{}.pkl'.format(opt.id, suffix)),
                            'wb') as f:
                        cPickle.dump(histories, f)

                save_model()
                save_model(".iter{}".format(iteration))

                if best_flag:
                    save_model(".best")

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
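Each loop above recomputes the learning rate with the same stepwise schedule: once epoch passes learning_rate_decay_start, the base rate is multiplied by learning_rate_decay_rate every learning_rate_decay_every epochs. A small self-contained illustration of the arithmetic (the default values below are made up for the example):

def decayed_lr(epoch, base_lr=5e-4, decay_start=0, decay_every=3,
               decay_rate=0.8):
    # Mirrors the in-loop logic: frac = (epoch - start) // every,
    # lr = base_lr * rate ** frac; no decay before decay_start.
    if epoch > decay_start and decay_start >= 0:
        frac = (epoch - decay_start) // decay_every
        return base_lr * decay_rate ** frac
    return base_lr

print(decayed_lr(1))  # 5e-4 (frac = 0)
print(decayed_lr(7))  # 5e-4 * 0.8**2 = 3.2e-4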
Code example #12
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader
    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()

    model = models.setup(opt)
    model.cuda()
    # model = DataParallel(model)

    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), "%s must be a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda()
    #    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
        print('finetune mode')
        cnn_optimizer = optim.Adam([\
            {'params': module.parameters()} for module in cnn_model._modules.values()[5:]\
            ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, crit, loader, eval_kwargs, True)
    epoch_start = time.time()
    while True:
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
                #model.module.ss_prob = opt.ss_prob
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
            else:
                sc_flag = False

            # Update the training stage of cnn
            for p in cnn_model.parameters():
                p.requires_grad = True
            # Fix the first few layers:
            for module in cnn_model._modules.values()[:5]:
                for p in module.parameters():
                    p.requires_grad = False
            cnn_model.train()
            update_lr_flag = False

        cnn_model.apply(utils.set_bn_fix)
        cnn_model.apply(utils.set_bn_eval)

        start = time.time()
        torch.cuda.synchronize()
        data = loader.get_batch('train')
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:

            multilabels = [
                data['detection_infos'][i]['label']
                for i in range(len(data['detection_infos']))
            ]

            tmp = [
                data['labels'], data['masks'],
                np.array(multilabels, dtype=np.int16)
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            labels, masks, multilabels = tmp
            # images cannot be stacked into one tensor: they have different sizes.
            images = data['images']
            _fc_feats_2048 = []
            _fc_feats_81 = []
            _att_feats = []
            for i in range(loader.batch_size):
                x = Variable(torch.from_numpy(images[i]),
                             requires_grad=False).cuda()
                x = x.unsqueeze(0)
                att_feats, fc_feats_81 = cnn_model(x)
                fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
                att_feats = F.adaptive_avg_pool2d(att_feats,
                                                  [14, 14]).squeeze().permute(
                                                      1, 2, 0)  #(0, 2, 3, 1)
                _fc_feats_2048.append(fc_feats_2048)
                _fc_feats_81.append(fc_feats_81)
                _att_feats.append(att_feats)
            _fc_feats_2048 = torch.stack(_fc_feats_2048)
            _fc_feats_81 = torch.stack(_fc_feats_81)
            _att_feats = torch.stack(_att_feats)
            att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
                                                           _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
                                                           _att_feats.size()[1:]))
            fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:]))
            fc_feats_81 = _fc_feats_81
            #
            cnn_optimizer.zero_grad()
        else:

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

        if not sc_flag:
            # NOTE: fc_feats_2048, fc_feats_81 and multilabels are only defined in
            # the CNN-finetuning branch above, so this loss assumes finetuning is on.
            loss1 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss = 0.8 * loss1 + 0.2 * loss2.float()
        else:
            gen_result, sample_logprobs = model.sample(fc_feats_2048,
                                                       att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats_2048, att_feats,
                                              data, gen_result)
            loss1 = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss3 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3)
        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()  # .item() replaces the deprecated loss.data[0]
        mle_loss = loss1.item()
        multilabel_loss = loss2.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start))

        if sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start))
        iteration += 1
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if (iteration % opt.save_checkpoint_every == 0):
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': True
            }
            eval_kwargs.update(vars(opt))

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, True)
            else:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, False)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("cnn model saved to {}".format(cnn_checkpoint_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                      'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("cnn model saved to {}".format(cnn_checkpoint_path))

                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("epoch: " + str(epoch) + " during: " +
                  str(time.time() - epoch_start))
            epoch_start = time.time()

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
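Code example #12 calls cnn_model.apply(utils.set_bn_fix) and cnn_model.apply(utils.set_bn_eval) every iteration so that finetuning the CNN does not disturb its BatchNorm statistics. A plausible sketch of those two callbacks, assuming they target BatchNorm2d modules (the project's actual misc/utils.py may differ):

import torch.nn as nn

def set_bn_fix(m):
    # Freeze the affine (weight/bias) parameters of BatchNorm layers.
    if isinstance(m, nn.BatchNorm2d):
        for p in m.parameters():
            p.requires_grad = False

def set_bn_eval(m):
    # Use stored running mean/var instead of batch statistics.
    if isinstance(m, nn.BatchNorm2d):
        m.eval()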
Code example #13
def main():
    import opts
    import misc.utils as utils
    opt = opts.parse_opt()
    opt.caption_model = 'topdown'
    opt.batch_size = 10  #512#32*4*4
    opt.id = 'topdown'
    opt.learning_rate = 5e-4
    opt.learning_rate_decay_start = 0
    opt.scheduled_sampling_start = 0
    opt.save_checkpoint_every = 5000  #450#5000#11500
    opt.val_images_use = 5000
    opt.max_epochs = 50  #30
    opt.start_from = 'save/rt'  #"save" #None
    opt.language_eval = 1
    opt.input_json = 'data/meta_coco_en.json'
    opt.input_label_h5 = 'data/label_coco_en.h5'
    #    opt.input_json='data/coco_ccg.json' #'data/meta_coco_en.json'
    #    opt.input_label_h5='data/coco_ccg_label.h5' #'data/label_coco_en.h5'
    #    opt.input_fc_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_fc'
    #    opt.input_att_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_att'
    opt.finetune_cnn_after = 0
    opt.ccg = False
    opt.input_image_h5 = 'data/coco_image_512.h5'

    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader  # just-in-time generated features
    loader = DataLoader(opt)

    #    from dataloader_fixcnn import DataLoader # load pre-processed features
    #    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    import models
    model = models.setup(opt)
    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model.cuda()

    data = loader.get_batch('train')
    images = data['images']

    #    _fc_feats_2048 = []
    #    _fc_feats_81 = []
    #    _att_feats = []
    #    for i in range(loader.batch_size):
    #        x = Variable(torch.from_numpy(images[i]), volatile=True).cuda()
    #        x = x.unsqueeze(0)
    #        att_feats, fc_feats_81 = cnn_model(x)
    #        fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
    #        att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(1, 2, 0)#(0, 2, 3, 1)
    #        _fc_feats_2048.append(fc_feats_2048)
    #        _fc_feats_81.append(fc_feats_81)
    #        _att_feats.append(att_feats)
    #    _fc_feats_2048 = torch.stack(_fc_feats_2048)
    #    _fc_feats_81 = torch.stack(_fc_feats_81)
    #    _att_feats = torch.stack(_att_feats)
    #    att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
    #                                                   _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
    #                                                   _att_feats.size()[1:]))
    #    fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
    #                                                  _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
    #                                                  _fc_feats_2048.size()[1:]))
    #    fc_feats_81 = _fc_feats_81
    #
    #    att_feats = Variable(att_feats, requires_grad=False).cuda()
    #    Variable(fc_feats_81)

    crit = utils.LanguageModelCriterion()
    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_split(cnn_model, model, crit,
                                                   loader, eval_kwargs, True)
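All of these scripts score captions with utils.LanguageModelCriterion, invoked above as crit(log_probs, labels[:, 1:], masks[:, 1:]). A minimal sketch consistent with that call, taking (N, T, V) log-probabilities plus (N, T) integer targets and masks; the real class may differ in details:

import torch
import torch.nn as nn

class LanguageModelCriterion(nn.Module):
    def forward(self, input, target, mask):
        # Truncate targets/masks to the decoded length.
        target = target[:, :input.size(1)]
        mask = mask[:, :input.size(1)].float()
        # Masked negative log-likelihood of each gold token, averaged
        # over the number of real (unmasked) tokens.
        nll = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return torch.sum(nll) / torch.sum(mask)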
Code example #14
File: train.py  Project: vplab-github/sacic
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from_path is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from_path,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from_path,
                             'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from_path,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print("Trainable Params: " + str(trainable_params))
    print("Total Params: " + str(total_params))
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_model.train()
    if (opt.use_obj_mcl_loss == 1):
        mcl_crit = utils.MultiLabelClassification()
    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from_path', None) is not None and os.path.isfile(
            os.path.join(opt.start_from_path, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from_path, 'optimizer.pth')))

    time_epoch_start = time.time()
    data_time_sum = 0
    batch_time_sum = 0
    while True:
        if epoch_done:
            torch.cuda.synchronize()
            time_epoch_end = time.time()
            time_elapsed = (time_epoch_end - time_epoch_start)
            print('[DEBUG] Epoch Time: ' + str(time_elapsed))
            print('[DEBUG] Sum Data Time: ' + str(data_time_sum))
            print('[DEBUG] Sum Batch Time: ' + str(batch_time_sum))
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)
        data_time_sum += time.time() - start
        torch.cuda.synchronize()
        start = time.time()

        if (opt.use_obj_mcl_loss == 0):
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp
        else:
            if opt.use_obj_att and opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['obj_att_feats'],
                    data['seg_feat_feats'], data['labels'], data['masks'],
                    data['obj_labels'], data['att_masks'],
                    data['obj_att_masks'], data['seg_feat_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, obj_att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, obj_att_masks, seg_feat_masks = tmp
            elif not opt.use_obj_att and opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'],
                    data['seg_feat_feats'], data['labels'], data['masks'],
                    data['obj_labels'], data['att_masks'],
                    data['seg_feat_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, seg_feat_masks = tmp
            elif not opt.use_obj_att and not opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['obj_labels'], data['att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, labels, masks, obj_labels, att_masks = tmp
            elif opt.use_obj_att and not opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['obj_att_feats'],
                    data['labels'], data['masks'], data['obj_labels'],
                    data['att_masks'], data['obj_att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, obj_att_feats, labels, masks, obj_labels, att_masks, obj_att_masks = tmp

        optimizer.zero_grad()
        if (opt.use_obj_mcl_loss == 0):
            if not sc_flag:
                loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                            labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
        else:
            if opt.use_obj_att and opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(
                        fc_feats, [att_feats, obj_att_feats, seg_feat_feats],
                        labels, [att_masks, obj_att_masks, seg_feat_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = 0.1*caption_loss + obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif not opt.use_obj_att and opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats,
                                           [att_feats, seg_feat_feats], labels,
                                           [att_masks, seg_feat_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif not opt.use_obj_att and not opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats, att_feats, labels,
                                           att_masks)
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif opt.use_obj_att and not opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats,
                                           [att_feats, obj_att_feats], labels,
                                           [att_masks, obj_att_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = 0.1 * caption_loss + obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        batch_time_sum += end - start
        if not sc_flag:
            if (opt.use_obj_mcl_loss == 1):
                print("iter {} (epoch {}), train_loss = {:.3f}, caption_loss = {:.3f}, object_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, caption_loss.item(), obj_loss.item(), end - start))
            else:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if (opt.use_obj_mcl_loss == 1):
                add_summary_value(tb_summary_writer, 'obj_loss',
                                  obj_loss.item(), iteration)
                add_summary_value(tb_summary_writer, 'caption_loss',
                                  caption_loss.item(), iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            orig_batch_size = opt.batch_size
            opt.batch_size = 1
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            loader.batch_size = eval_kwargs.get('batch_size', 1)
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            opt.batch_size = orig_batch_size
            loader.batch_size = orig_batch_size

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in lang_stats.items():
                add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # checkpoint unconditionally
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
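
A note on resuming: the infos/histories pickles dumped above pair with a load step at startup. A minimal sketch of that load, assuming the same file-name scheme as the dumps above (plain `pickle` standing in for the `cPickle` used in these examples):

import os
import pickle  # cPickle on Python 2, as used in these examples

def load_checkpoint_state(checkpoint_path, run_id):
    # Restore the bookkeeping dicts dumped by the training loop above.
    infos, histories = {}, {}
    infos_file = os.path.join(checkpoint_path, 'infos_' + run_id + '.pkl')
    hist_file = os.path.join(checkpoint_path, 'histories_' + run_id + '.pkl')
    if os.path.isfile(infos_file):
        with open(infos_file, 'rb') as f:
            infos = pickle.load(f)
    if os.path.isfile(hist_file):
        with open(hist_file, 'rb') as f:
            histories = pickle.load(f)
    return infos, histories
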
Code example #15
def eval_cap(opt):
    opt.start_from = None
    opt.start_from_gan = 'data/save_gan_rcsls_s_g_naacl/20201121_185728_sep_self_att_sep_gan_only'
    opt.checkpoint_path_gan = opt.start_from_gan
    opt.init_path_zh = 'data/save_for_finetune/20201116_203753_gtssg_sep_self_att_sep_vv1_RCSLS_submap_v2_add_wordmap_add_global_naacl_self_gate_p2/model-best.pth'
    opt.input_isg_dir = "data/coco_graph_extract_ft_isg_joint_rcsls_submap_global_naacl_self_gate_finetune"
    opt.input_ssg_dir = "data/coco_graph_extract_ft_ssg_joint_rcsls_submap_global_naacl_self_gate_finetune"
    opt.caption_model_to_replace = 'up_gtssg_sep_self_att_sep_vv1_RCSLS_submap_v2_add_wordmap_add_global_naacl_self_gate_p2'

    opt.gpu = 0

    opt.batch_size = 50
    opt.beam_size = 5
    opt.dump_path = 1
    opt.caption_model = 'sep_self_att_sep_gan_only'
    opt.input_json = 'data/coco_cn/cocobu_gan_ssg.json'
    opt.input_json_isg = 'data/coco_cn/cocobu_gan_isg.json'
    opt.input_label_h5 = 'data/coco_cn/cocobu_gan_isg_label.h5'
    opt.ssg_dict_path = 'data/aic_process/ALL_11683_v3_COCOCN_spice_sg_dict_t5.npz_revise.npz'
    opt.rela_dict_dir = 'data/rela_dict.npy'
    opt.input_fc_dir = 'data/cocobu_fc'
    opt.input_att_dir = 'data/cocobu_att'
    opt.input_box_dir = 'data/cocotalk_box'
    opt.input_label_h5 = 'data/cocobu_label.h5'  # overrides the coco_cn label path set above

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
    opt.split = 'test'
    loader = DataLoader_GAN(opt)
    loader_i2t = DataLoader_UP(opt)

    opt.vocab_size = loader.vocab_size
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        try:
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl')) as f:
                infos = cPickle.load(f)
                saved_model_opt = infos['opt']
                need_be_same = [
                    "caption_model", "rnn_type", "rnn_size", "num_layers"
                ]
                for checkme in need_be_same:
                    assert vars(saved_model_opt)[checkme] == vars(
                        opt
                    )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

            if os.path.isfile(
                    os.path.join(opt.checkpoint_path, 'histories.pkl')):
                with open(os.path.join(opt.checkpoint_path,
                                       'histories.pkl')) as f:
                    histories = cPickle.load(f)
        except Exception:
            print("Cannot load infos.pkl")

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # opt.caption_model = 'up_gtssg_sep_self_att_sep'
    opt.caption_model = opt.caption_model_to_replace
    model = models.setup(opt).cuda()
    model.eval()

    opt.start_from = opt.start_from_gan
    opt.checkpoint_path = opt.start_from_gan
    opt.id = opt.checkpoint_path.split('/')[-1]
    start_from = opt.start_from_gan

    with open(os.path.join(opt.checkpoint_path, 'infos.pkl')) as f:
        infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        saved_model_opt.start_from_gan = start_from
        saved_model_opt.checkpoint_path_gan = start_from

    netG_A_obj = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_A_obj').cuda().eval()
    netG_A_rel = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_A_rel').cuda().eval()
    netG_A_atr = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_A_atr').cuda().eval()

    netG_B_obj = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_B_obj').cuda().eval()
    netG_B_rel = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_B_rel').cuda().eval()
    netG_B_atr = GAN_init_G(saved_model_opt,
                            Generator(saved_model_opt),
                            type='netG_B_atr').cuda().eval()

    val_loss, cache_path = eval_utils_gan.eval_split_gan_v2(
        opt, model, netG_A_obj, netG_A_rel, netG_A_atr, netG_B_obj, netG_B_rel,
        netG_B_atr, loader, loader_i2t, opt.split)
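
The six generator constructions above differ only in their `type` tag. A loop-based sketch using the same GAN_init_G / Generator API (names taken from the code above, not a separately verified interface) removes the duplication:

def build_generators(saved_model_opt):
    # One generator per direction (A->B, B->A) and per graph component.
    nets = {}
    for direction in ('A', 'B'):
        for part in ('obj', 'rel', 'atr'):
            tag = 'netG_{}_{}'.format(direction, part)
            nets[tag] = GAN_init_G(saved_model_opt, Generator(saved_model_opt),
                                   type=tag).cuda().eval()
    return nets
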
Code example #16
def train(opt):
    assert opt.annfile is not None and len(opt.annfile) > 0

    print('Checkpoint path is ' + opt.checkpoint_path)
    print('This program is using GPU ' +
          str(os.environ['CUDA_VISIBLE_DEVICES']))
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if opt.load_best:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')
        with open(info_path) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    if opt.learning_rate_decay_start is None:
        opt.learning_rate_decay_start = infos.get(
            'opt', None).learning_rate_decay_start
    # if opt.load_best:
    #     opt.self_critical_after = epoch
    elif opt.learning_rate_decay_start == -1 and opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        opt.learning_rate_decay_start = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_ave_model = infos.get('best_val_score_ave_model', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion(opt.XE_eps)
    rl_crit = utils.RewardCriterion()

    # build_optimizer
    optimizer = build_optimizer(model, opt)

    # Load the optimizer
    if opt.load_opti and vars(opt).get(
            'start_from',
            None) is not None and opt.load_best == 0 and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # initialize the running average of parameters
    avg_param = deepcopy(list(p.data for p in model.parameters()))

    # Evaluate once with the original (non-averaged) model before training
    best_val_score, histories, infos = eva_original_model(
        best_val_score, crit, epoch, histories, infos, iteration, loader,
        loss_history, lr_history, model, opt, optimizer, ss_prob_history,
        tb_summary_writer, val_result_history)

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                if opt.lr_decay == 'exp':
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                elif opt.lr_decay == 'cosine':
                    lr_epoch = min((epoch - opt.learning_rate_decay_start),
                                   opt.lr_max_epoch)
                    cosine_decay = 0.5 * (
                        1 + math.cos(math.pi * lr_epoch / opt.lr_max_epoch))
                    decay_factor = (1 - opt.lr_cosine_decay_base
                                    ) * cosine_decay + opt.lr_cosine_decay_base
                    opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate

            lr = [opt.current_lr]
            if opt.att_normalize_method is not None and '6' in opt.att_normalize_method:
                lr = [opt.current_lr, opt.lr_ratio * opt.current_lr]

            utils.set_lr(optimizer, lr)
            print('learning rate is: ' + str(lr))

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Update the iteration
        iteration += 1

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            output = dp_model(fc_feats, att_feats, labels, att_masks)
            # calculate loss
            loss = crit(output[0], labels[:, 1:], masks[:, 1:])

            # add some middle variable histogram
            if iteration % (4 * opt.losses_log_every) == 0:
                outputs = [
                    _.data.cpu().numpy() if _ is not None else None
                    for _ in output
                ]
                variables_histogram(data, iteration, outputs,
                                    tb_summary_writer, opt)

        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_max_norm)
        # add_summary_value(tb_summary_writer, 'grad_L2_norm', grad_norm, iteration)

        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        # compute the running average of parameters
        for p, avg_p in zip(model.parameters(), avg_param):
            avg_p.mul_(opt.beta).add_(p.data, alpha=1.0 - opt.beta)  # EMA update

        if iteration % 10 == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if opt.tensorboard_weights_grads and (iteration %
                                              (8 * opt.losses_log_every) == 0):
            # add weights histogram to tensorboard summary
            for name, param in model.named_parameters():
                if (opt.tensorboard_parameters_name is None or sum([
                        p_name in name
                        for p_name in opt.tensorboard_parameters_name
                ]) > 0) and param.grad is not None:
                    tb_summary_writer.add_histogram(
                        'Weights_' + name.replace('.', '/'), param, iteration)
                    tb_summary_writer.add_histogram(
                        'Grads_' + name.replace('.', '/'), param.grad,
                        iteration)

        if opt.tensorboard_buffers and (iteration %
                                        (opt.losses_log_every) == 0):
            for name, buffer in model.named_buffers():
                if (opt.tensorboard_buffers_name is None or sum([
                        p_name in name
                        for p_name in opt.tensorboard_buffers_name
                ]) > 0) and buffer is not None:
                    add_summary_value(tb_summary_writer,
                                      name.replace('.',
                                                   '/'), buffer, iteration)

        if opt.distance_sensitive_coefficient and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The coefficient in intra_att_att_lstm is as follows:')
            print(
                model.core.intra_att_att_lstm.coefficient.data.cpu().tolist())
            print('The coefficient in intra_att_lang_lstm is as follows:')
            print(
                model.core.intra_att_lang_lstm.coefficient.data.cpu().tolist())
        if opt.distance_sensitive_bias and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The bias in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.bias.data.cpu().tolist())
            print('The bias in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.bias.data.cpu().tolist())

        # Evaluate with the original model and checkpoint
        if (iteration % opt.save_checkpoint_every == 0):
            best_val_score, histories, infos = eva_original_model(
                best_val_score, crit, epoch, histories, infos, iteration,
                loader, loss_history, lr_history, model, opt, optimizer,
                ss_prob_history, tb_summary_writer, val_result_history)

        # Evaluate with the parameter-averaged model
        if iteration > opt.ave_threshold and (iteration %
                                              opt.save_checkpoint_every == 0):
            best_val_score_ave_model, infos = eva_ave_model(
                avg_param, best_val_score_ave_model, crit, infos, iteration,
                loader, model, opt, tb_summary_writer)

        # # Stop if reaching max epochs
        # if epoch >= opt.max_epochs and opt.max_epochs != -1:
        #     break

        if iteration >= opt.max_iter:
            break
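
`eva_ave_model` is defined elsewhere; it presumably scores the model under the running parameter average maintained above. A minimal sketch of the swap-in/swap-out such an evaluation would need, assuming the `avg_param` list from this example:

def swap_in_average(model, avg_param):
    # Back up the live weights, then copy the running average in.
    backup = [p.data.clone() for p in model.parameters()]
    for p, avg_p in zip(model.parameters(), avg_param):
        p.data.copy_(avg_p)
    return backup

def swap_back(model, backup):
    # Restore the live weights after evaluation so training resumes unchanged.
    for p, b in zip(model.parameters(), backup):
        p.data.copy_(b)
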
Code example #17
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from_S is not None:
        if os.path.isfile(
                os.path.join(opt.start_from_S,
                             'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from_S,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Build the two captioners and wrap them in CommNetModel
    model1 = ShowTellModel(opt)
    model2 = ShowTellModel(opt)
    model1.load_state_dict(
        torch.load(os.path.join(opt.start_from_T, 'model.pth')))
    model2.load_state_dict(
        torch.load(os.path.join(opt.start_from_S, 'model.pth')))
    model1.cuda()
    model2.cuda()

    model = CommNetModel(opt, model1, model2)
    model.cuda()
    logger = Logger(opt)

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    # if vars(opt).get('start_from_S', None) is not None and os.path.isfile(os.path.join(opt.start_from_S,"optimizer.pth")):
    #     temp = torch.load(os.path.join(opt.start_from_S, 'optimizer.pth'))
    #     optimizer.load_state_dict(temp)

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer = update_lr(
                opt, epoch, model, optimizer)

        # Load data from train split (0)
        data = loader.get_batch('train', seq_per_img=opt.seq_per_img)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

        if not sc_flag:
            loss_1 = crit(
                model(fc_feats, att_feats, labels)[0], labels[:, 1:],
                masks[:, 1:])
            loss_2 = crit(
                model(fc_feats, att_feats, labels)[1], labels[:, 1:],
                masks[:, 1:])
            loss_1.backward()
            loss_2.backward()
            loss = loss_1 + loss_2
        else:
            gen_result_1, sample_logprobs_1 = model.model1.sample(
                fc_feats, att_feats, {'sample_max': 0})
            gen_result_2, sample_logprobs_2 = model.model2.sample(
                fc_feats, att_feats, {'sample_max': 0})

            reward_1 = get_self_critical_reward_forCommNet(
                model, fc_feats, att_feats, data, gen_result_1, logger)
            reward_2 = get_self_critical_reward_forCommNet(
                model, fc_feats, att_feats, data, gen_result_2, logger)

            loss_1 = rl_crit(
                sample_logprobs_1, gen_result_1,
                Variable(torch.from_numpy(reward_1).float().cuda(),
                         requires_grad=False))
            loss_2 = rl_crit(
                sample_logprobs_2, gen_result_2,
                Variable(torch.from_numpy(reward_2).float().cuda(),
                         requires_grad=False))

            loss_1.backward()
            loss_2.backward()
            loss = loss_1 + loss_2

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]  # pre-0.4 PyTorch; use loss.item() on newer versions
        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start)
            logger.write(log)
        else:
            log = "iter {} (epoch {}), S_avg_reward = {:.3f}, T_avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration,  epoch, np.mean(reward_1[:,0]), np.mean(reward_2[:,0]), end - start)
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward_S',
                                      np.mean(reward_1[:, 0]), iteration)
                    add_summary_value(tf_summary_writer, 'avg_reward_T',
                                      np.mean(reward_2[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward_1[:, 0] + reward_2[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            #val_loss, predictions, lang_stats = eval_utils.eval_split(model.model1, crit, loader, logger, eval_kwargs)
            val_loss, predictions, lang_stats = eval_utils_forCommNet.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save a checkpoint; best_flag marks a new best score
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model1.pth')
                torch.save(model.model1.state_dict(), checkpoint_path)
                print("model1 saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model2.pth')
                torch.save(model.model2.state_dict(), checkpoint_path)
                print("model2 saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous bookkeeping information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model1-best.pth')
                    torch.save(model.model1.state_dict(), checkpoint_path)
                    print("model1 saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model2-best.pth')
                    torch.save(model.model2.state_dict(), checkpoint_path)
                    print("model2 saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
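
`get_self_critical_reward_forCommNet` is assumed to follow the usual self-critical recipe: the CIDEr score of a greedy decode serves as the baseline subtracted from each sampled caption's score. A minimal sketch over precomputed per-image scores:

import numpy as np

def self_critical_reward(sample_scores, greedy_scores, seq_length):
    # Both score arrays have shape (batch,); greedy decoding is the baseline.
    reward = sample_scores - greedy_scores
    # Broadcast the sequence-level reward to every generated token.
    return np.repeat(reward[:, np.newaxis], seq_length, axis=1)
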
Code example #18
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        data_time = time.time() - start

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_freq == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start, data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start, data_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)  # rolling checkpoint before evaluation

            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True:  # always save a checkpoint; best_flag marks a new best score
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous bookkeeping information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best-i{}-score{}.pth'.format(iteration, best_val_score))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
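
`utils.set_lr` and `utils.clip_gradient`, called throughout these examples, are small helpers. Minimal sketches consistent with how they are invoked here (assumptions, not the library's verbatim source):

def set_lr(optimizer, lr):
    # Apply one learning rate to every parameter group.
    for group in optimizer.param_groups:
        group['lr'] = lr

def clip_gradient(optimizer, grad_clip):
    # Element-wise value clipping of gradients to [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
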
Code example #19
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'),'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'),'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    if opt.multi_gpu:
        model=nn.DataParallel(model)
    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        
        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]  # pre-0.4 PyTorch; use loss.item() on newer versions
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True:  # always save a checkpoint; best_flag marks a new best score
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous bookkeeping information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
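
Once decay has started, the schedule above reduces to lr = base * rate ** floor((epoch - start) / every). A standalone sketch of the same arithmetic:

def decayed_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    # Mirrors the in-loop schedule above; decay_start < 0 disables decay.
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * (decay_rate ** frac)
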
Code example #20
def train(opt):
    start_from = vars(opt).get('start_from', None)
    start_from_p = vars(opt).get('start_from_en', None)
    opt.checkpoint_path, opt.id = model_start(start_from, 0)
    opt.checkpoint_path_p, opt.id_p = model_start(start_from_p, 1)

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    # opt.use_att = False
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    vocab_size = loader.vocab_size
    vocab_size_p = loader.vocab_size_p

    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
        tb_summary_writer_p = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path_p)
    except Exception:
        print('Failed to set up the tensorboard writers!')
        pdb.set_trace()

    opt.p_flag = 0  # 0: primary-language model, 1: paired-language model
    opt.vocab_size = vocab_size
    (loader, iteration, epoch, val_result_history, loss_history, lr_history,
     ss_prob_history, best_val_score, infos, histories, update_lr_flag, model,
     dp_model, parameters, crit, rl_crit, optimizer, accumulate_iter,
     train_loss, reward, train_loss_kl, train_loss_all) = load_info(
         loader, start_from, opt.checkpoint_path, opt.p_flag)

    opt.p_flag = 1
    opt.vocab_size = vocab_size_p
    (loader, iteration_p, epoch_p, val_result_history_p, loss_history_p,
     lr_history_p, ss_prob_history_p, best_val_score_p, infos_p, histories_p,
     update_lr_flag_p, model_p, dp_model_p, parameters_p, crit_p, rl_crit_p,
     optimizer_p, accumulate_iter_p, train_loss_p, reward_p, train_loss_kl,
     train_loss_all) = load_info(
         loader, start_from_p, opt.checkpoint_path_p, opt.p_flag)


    while True:

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        opt.p_flag = 0  # 0: primary-language model, 1: paired-language model
        loss, att_obj, att_rela, att_attr, update_lr_flag, sc_flag = pre_model(
            update_lr_flag, epoch, optimizer, model, data, dp_model, crit,
            rl_crit, opt.p_flag)

        opt.p_flag = 1
        loss_p, att_obj_p, att_rela_p, att_attr_p, update_lr_flag_p, sc_flag_p = pre_model(
            update_lr_flag_p, epoch_p, optimizer_p, model_p, data, dp_model_p,
            crit_p, rl_crit_p, opt.p_flag)

        att_obj = F.softmax(att_obj, dim=1)
        att_obj_p = F.softmax(att_obj_p, dim=1)
        att_rela = F.softmax(att_rela, dim=1)
        att_rela_p = F.softmax(att_rela_p, dim=1)
        att_attr = F.softmax(att_attr, dim=1)
        att_attr_p = F.softmax(att_attr_p, dim=1)

        # Attention alignment: exponentiated KL divergence between the paired
        # models' attention distributions (objects, relations, attributes).
        loss_kl = torch.exp(F.kl_div(att_obj.log(), att_obj_p, reduction='sum')) \
                  + torch.exp(F.kl_div(att_rela.log(), att_rela_p, reduction='sum')) \
                  + torch.exp(F.kl_div(att_attr.log(), att_attr_p, reduction='sum'))
        # print(loss_kl)

        accumulate_iter = accumulate_iter + 1
        accumulate_iter_p = accumulate_iter_p + 1
        loss_all = loss + loss_p + loss_kl
        loss = loss / opt.accumulate_number
        loss_p = loss_p / opt.accumulate_number
        loss_kl = loss_kl / opt.accumulate_number
        loss_all = loss_all / opt.accumulate_number
        loss_all.backward()

        opt.p_flag = 0
        # print ('iteration of model 1 is {}'.format(iteration))
        (update_lr_flag, epoch, optimizer, model, dp_model, accumulate_iter,
         iteration, loss, sc_flag, start, reward, tb_summary_writer,
         loss_history, lr_history, ss_prob_history, use_rela,
         val_result_history, best_val_score, loader, infos, histories,
         train_loss, train_loss_kl, train_loss_all, loss_all, loss_kl
         ) = save_model(
            opt.input_json, accumulate_iter, optimizer, iteration, loss,
            sc_flag, epoch, start, reward, data, tb_summary_writer, model,
            loss_history, lr_history, ss_prob_history, use_rela, dp_model,
            val_result_history, best_val_score, crit, loader, infos, histories,
            train_loss, train_loss_kl, train_loss_all, opt.id,
            opt.checkpoint_path, opt.p_flag, loss_all, loss_kl, update_lr_flag)

        opt.p_flag = 1
        # print('iteration of model 2 is {}'.format(iteration_p))
        (update_lr_flag_p, epoch_p, optimizer_p, model_p, dp_model_p,
         accumulate_iter_p, iteration_p, loss_p, sc_flag_p, start, reward_p,
         tb_summary_writer_p, loss_history_p, lr_history_p, ss_prob_history_p,
         use_rela, val_result_history_p, best_val_score_p, loader, infos_p,
         histories_p, train_loss_p, train_loss_kl, train_loss_all, loss_all,
         loss_kl
         ) = save_model(
            opt.input_json_en, accumulate_iter_p, optimizer_p, iteration_p,
            loss_p, sc_flag_p, epoch_p, start, reward_p, data,
            tb_summary_writer_p, model_p, loss_history_p, lr_history_p,
            ss_prob_history_p, use_rela, dp_model_p, val_result_history_p,
            best_val_score_p, crit_p, loader, infos_p, histories_p,
            train_loss_p, train_loss_kl, train_loss_all, opt.id_p,
            opt.checkpoint_path_p, opt.p_flag, loss_all, loss_kl,
            update_lr_flag_p)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
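
Dividing each loss by `opt.accumulate_number` before `backward()` is the standard gradient-accumulation pattern; this example spreads the stepping logic across `pre_model` and `save_model`. A minimal standalone sketch of the core idea:

def accumulation_step(loss, optimizer, accumulate_iter, accumulate_number):
    # Scale so the summed gradients match one large-batch step.
    (loss / accumulate_number).backward()
    if accumulate_iter % accumulate_number == 0:
        optimizer.step()
        optimizer.zero_grad()
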
Code example #21
def train(opt):
    if vars(opt).get('start_from', None) is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Write YAML file
    with io.open(opt.checkpoint_path + '/opts.yaml', 'w',
                 encoding='utf8') as outfile:
        yaml.dump(opt, outfile, default_flow_style=False, allow_unicode=True)

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_GAN(opt)
    loader_i2t = DataLoader_UP(opt)

    opt.vocab_size = loader.vocab_size
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)
    except Exception:
        print('Failed to set up the tensorboard writer!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        try:
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl')) as f:
                infos = cPickle.load(f)
                saved_model_opt = infos['opt']
                need_be_same = [
                    "caption_model", "rnn_type", "rnn_size", "num_layers"
                ]
                for checkme in need_be_same:
                    assert vars(saved_model_opt)[checkme] == vars(
                        opt
                    )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

            if os.path.isfile(
                    os.path.join(opt.checkpoint_path, 'histories.pkl')):
                with open(os.path.join(opt.checkpoint_path,
                                       'histories.pkl')) as f:
                    histories = cPickle.load(f)
        except Exception:
            print("Cannot load infos.pkl")

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # opt.caption_model = 'up_gtssg_sep_self_att_sep'
    opt.caption_model = opt.caption_model_to_replace
    model = models.setup(opt).cuda()
    print('### Model summary below ###\n{}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('Trainable parameters: {}'.format(model_params))

    model.eval()

    train_loss = 0
    update_lr_flag = True
    fake_A_pool_obj = utils.ImagePool(opt.pool_size)
    fake_A_pool_rel = utils.ImagePool(opt.pool_size)
    fake_A_pool_atr = utils.ImagePool(opt.pool_size)

    fake_B_pool_obj = utils.ImagePool(opt.pool_size)
    fake_B_pool_rel = utils.ImagePool(opt.pool_size)
    fake_B_pool_atr = utils.ImagePool(opt.pool_size)

    netD_A_obj = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_obj').cuda().train()
    netD_A_rel = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_rel').cuda().train()
    netD_A_atr = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_atr').cuda().train()

    netD_B_obj = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_obj').cuda().train()
    netD_B_rel = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_rel').cuda().train()
    netD_B_atr = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_atr').cuda().train()

    netG_A_obj = GAN_init_G(opt, Generator(opt),
                            type='netG_A_obj').cuda().train()
    netG_A_rel = GAN_init_G(opt, Generator(opt),
                            type='netG_A_rel').cuda().train()
    netG_A_atr = GAN_init_G(opt, Generator(opt),
                            type='netG_A_atr').cuda().train()

    netG_B_obj = GAN_init_G(opt, Generator(opt),
                            type='netG_B_obj').cuda().train()
    netG_B_rel = GAN_init_G(opt, Generator(opt),
                            type='netG_B_rel').cuda().train()
    netG_B_atr = GAN_init_G(opt, Generator(opt),
                            type='netG_B_atr').cuda().train()

    optimizer_G = utils.build_optimizer(
        itertools.chain(netG_A_obj.parameters(), netG_B_obj.parameters(),
                        netG_A_rel.parameters(), netG_B_rel.parameters(),
                        netG_A_atr.parameters(), netG_B_atr.parameters()), opt)
    optimizer_D = utils.build_optimizer(
        itertools.chain(netD_A_obj.parameters(), netD_B_obj.parameters(),
                        netD_A_rel.parameters(), netD_B_rel.parameters(),
                        netD_A_atr.parameters(), netD_B_atr.parameters()), opt)

    criterionGAN = GANLoss(opt.gan_mode).cuda()  # define GAN loss.
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizers = []
    optimizers.append(optimizer_G)
    optimizers.append(optimizer_D)
    schedulers = [get_scheduler(opt, optimizer) for optimizer in optimizers]
    current_lr = optimizers[0].param_groups[0]['lr']
    train_num = 0

    while True:
        if update_lr_flag:
            # Assign the learning rate (floored at 1e-4)
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            opt.current_lr = max(opt.current_lr, 1e-4)
            for optimizer_i in optimizers:
                utils.set_lr(optimizer_i, opt.current_lr)
            update_lr_flag = False
        """
        Show the percentage of data loader
        """
        if train_num > loader.max_index:
            train_num = 0
        train_num = train_num + 1
        train_precentage = float(train_num) * 100 / float(loader.max_index)
        """
        Start training
        """
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['isg_feats'][:, 0, :], data['isg_feats'][:, 1, :],
            data['isg_feats'][:, 2, :], data['ssg_feats'][:, 0, :],
            data['ssg_feats'][:, 1, :], data['ssg_feats'][:, 2, :]
        ]

        tmp = [
            _ if _ is None else torch.from_numpy(_).float().cuda() for _ in tmp
        ]

        real_A_obj, real_A_rel, real_A_atr, real_B_obj, real_B_rel, real_B_atr = tmp

        iteration += 1

        fake_B_rel = netG_A_rel(real_A_rel)
        rec_A_rel = netG_B_rel(fake_B_rel)
        idt_B_rel = netG_B_rel(real_A_rel)

        fake_A_rel = netG_B_rel(real_B_rel)
        rec_B_rel = netG_A_rel(fake_A_rel)
        idt_A_rel = netG_A_rel(real_B_rel)

        # Obj
        fake_B_obj = netG_A_obj(real_A_obj)
        rec_A_obj = netG_B_obj(fake_B_obj)
        idt_B_obj = netG_B_obj(real_A_obj)

        fake_A_obj = netG_B_obj(real_B_obj)
        rec_B_obj = netG_A_obj(fake_A_obj)
        idt_A_obj = netG_A_obj(real_B_obj)

        # Atr
        fake_B_atr = netG_A_atr(real_A_atr)
        rec_A_atr = netG_B_atr(fake_B_atr)
        idt_B_atr = netG_B_atr(real_A_atr)

        fake_A_atr = netG_B_atr(real_B_atr)
        rec_B_atr = netG_A_atr(fake_A_atr)
        idt_A_atr = netG_A_atr(real_B_atr)
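        # The three blocks above are CycleGAN-style forward passes, repeated
        # for each scene-graph channel (obj / rel / atr):
        #   fake_B = G_A(real_A), rec_A = G_B(fake_B)  -> cycle A -> B -> A
        #   fake_A = G_B(real_B), rec_B = G_A(fake_A)  -> cycle B -> A -> B
        #   idt_*  feeds each generator a sample already from its output
        #   domain, for the identity-mapping regularizer.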

        domain_A = [
            real_A_obj, real_A_rel, real_A_atr, fake_A_obj, fake_A_rel,
            fake_A_atr, rec_A_obj, rec_A_rel, rec_A_atr, idt_A_obj, idt_A_rel,
            idt_A_atr
        ]
        domain_B = [
            real_B_obj, real_B_rel, real_B_atr, fake_B_obj, fake_B_rel,
            fake_B_atr, rec_B_obj, rec_B_rel, rec_B_atr, idt_B_obj, idt_B_rel,
            idt_B_atr
        ]
        # G_A and G_B
        utils.set_requires_grad([
            netD_A_obj, netD_A_rel, netD_A_atr, netD_B_obj, netD_B_rel,
            netD_B_atr
        ], False)  # Ds require no gradients when optimizing Gs
        optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero

        loss_G = cycle_GAN_backward_G(opt, criterionGAN, criterionCycle,
                                      criterionIdt, netG_A_obj, netG_A_rel,
                                      netG_A_atr, netG_B_obj, netG_B_rel,
                                      netG_B_atr, netD_A_obj, netD_A_rel,
                                      netD_A_atr, netD_B_obj, netD_B_rel,
                                      netD_B_atr, domain_A, domain_B)
        loss_G.backward()
        optimizer_G.step()

        # D_A and D_B
        utils.set_requires_grad([
            netD_A_obj, netD_A_rel, netD_A_atr, netD_B_obj, netD_B_rel,
            netD_B_atr
        ], True)
        optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        loss_D_A = cycle_GAN_backward_D(opt, fake_B_pool_obj, fake_B_pool_rel,
                                        fake_B_pool_atr, netD_A_obj,
                                        netD_A_rel, netD_A_atr, criterionGAN,
                                        real_B_obj, real_B_rel, real_B_atr,
                                        fake_B_obj, fake_B_rel, fake_B_atr)
        loss_D_A.backward()
        loss_D_B = cycle_GAN_backward_D(opt, fake_A_pool_obj, fake_A_pool_rel,
                                        fake_A_pool_atr, netD_B_obj,
                                        netD_B_rel, netD_B_atr, criterionGAN,
                                        real_A_obj, real_A_rel, real_A_atr,
                                        fake_A_obj, fake_A_rel, fake_A_atr)
        loss_D_B.backward()
        optimizer_D.step()  # update D_A and D_B's weights

        end = time.time()
        train_loss_G = loss_G.item()
        train_loss_D_A = loss_D_A.item()
        train_loss_D_B = loss_D_B.item()
        print(
            "{}/{:.1f}/{}/{}|train_loss={:.3f}|train_loss_G={:.3f}|train_loss_D_A={:.3f}|train_loss_D_B={:.3f}|time/batch = {:.3f}"
            .format(opt.id, train_percentage, iteration, epoch, train_loss,
                    train_loss_G, train_loss_D_A, train_loss_D_B, end - start))
        torch.cuda.synchronize()

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_G', train_loss_G,
                              iteration)
            add_summary_value(tb_summary_writer, 'train_loss_D_A',
                              train_loss_D_A, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_D_B',
                              train_loss_D_B, iteration)

            # add hyperparameters
            add_summary_value(tb_summary_writer, 'beam_size', opt.beam_size,
                              iteration)
            add_summary_value(tb_summary_writer, 'lambdaA', opt.lambda_A,
                              iteration)
            add_summary_value(tb_summary_writer, 'lambdaB', opt.lambda_B,
                              iteration)
            add_summary_value(tb_summary_writer, 'pool_size', opt.pool_size,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_type', opt.gan_type,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_d_type', opt.gan_d_type,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_g_type', opt.gan_g_type,
                              iteration)

        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            val_loss = eval_utils_gan.eval_split_gan(opt, model, netG_A_obj,
                                                     netG_A_rel, netG_A_atr,
                                                     loader, loader_i2t)
            val_loss = val_loss.item()
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            current_score = -val_loss
            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            checkpoint_path = os.path.join(opt.checkpoint_path, 'model_D.pth')
            torch.save(
                {
                    'epoch': epoch,
                    'netD_A_atr': netD_A_atr.state_dict(),
                    'netD_A_obj': netD_A_obj.state_dict(),
                    'netD_A_rel': netD_A_rel.state_dict(),
                    'netD_B_atr': netD_B_atr.state_dict(),
                    'netD_B_obj': netD_B_obj.state_dict(),
                    'netD_B_rel': netD_B_rel.state_dict()
                }, checkpoint_path)

            checkpoint_path = os.path.join(opt.checkpoint_path, 'model_G.pth')
            torch.save(
                {
                    'epoch': epoch,
                    'netG_A_atr': netG_A_atr.state_dict(),
                    'netG_A_obj': netG_A_obj.state_dict(),
                    'netG_A_rel': netG_A_rel.state_dict(),
                    'netG_B_atr': netG_B_atr.state_dict(),
                    'netG_B_obj': netG_B_obj.state_dict(),
                    'netG_B_rel': netG_B_rel.state_dict()
                }, checkpoint_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_D-best.pth')
                torch.save(
                    {
                        'epoch': epoch,
                        'netD_A_atr': netD_A_atr.state_dict(),
                        'netD_A_obj': netD_A_obj.state_dict(),
                        'netD_A_rel': netD_A_rel.state_dict(),
                        'netD_B_atr': netD_B_atr.state_dict(),
                        'netD_B_obj': netD_B_obj.state_dict(),
                        'netD_B_rel': netD_B_rel.state_dict()
                    }, checkpoint_path)

                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_G-best.pth')
                torch.save(
                    {
                        'epoch': epoch,
                        'netG_A_atr': netG_A_atr.state_dict(),
                        'netG_A_obj': netG_A_obj.state_dict(),
                        'netG_A_rel': netG_A_rel.state_dict(),
                        'netG_B_atr': netG_B_atr.state_dict(),
                        'netG_B_obj': netG_B_obj.state_dict(),
                        'netG_B_rel': netG_B_rel.state_dict()
                    }, checkpoint_path)

                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            # current_lr = update_learning_rate(schedulers, optimizers)
            epoch += 1
            update_lr_flag = True
            # make evaluation on validation set, and save model
            # lang_stats_isg = eval_utils_gan.eval_split_i2t(opt, model, netG_A_obj, netG_A_rel, netG_A_atr, loader, loader_i2t)
            lang_stats_isg = eval_utils_gan.eval_split_g2t(
                opt, model, netG_A_obj, netG_A_rel, netG_A_atr, loader,
                loader_i2t)

            if lang_stats_isg is not None:
                for k, v in lang_stats_isg.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
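
Example #21 routes every generated batch through utils.ImagePool before the discriminator update, but that class is not part of this listing. Below is a minimal sketch of what such a pool typically does, in the spirit of the CycleGAN history buffer; the class name, the 0.5 swap probability, and the query() interface are assumptions, not the repository's actual code. The pool keeps up to pool_size past generator outputs and sometimes answers with a stored sample instead of the fresh one, so the discriminator sees a mix of old and new fakes and does not overfit to the latest generator.

import random

import torch


class ImagePoolSketch:
    """History buffer of generated samples (assumed behavior of
    utils.ImagePool above)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.items = []

    def query(self, batch):
        if self.pool_size == 0:  # pool disabled: pass batch through
            return batch
        out = []
        for sample in batch:
            sample = sample.unsqueeze(0)
            if len(self.items) < self.pool_size:
                self.items.append(sample)  # fill the pool first
                out.append(sample)
            elif random.random() > 0.5:  # swap a stored sample in
                idx = random.randrange(self.pool_size)
                out.append(self.items[idx].clone())
                self.items[idx] = sample
            else:  # or return the fresh one unchanged
                out.append(sample)
        return torch.cat(out, 0)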
Code example #22
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Stop after a fixed number of epochs (hardcoded to 8 here)
        if epoch >= 8:
            break
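
The same epoch-based exponential learning-rate decay appears in almost every example above. Factored out, it is a three-line computation; this helper is a sketch that assumes the opt fields the loops read (learning_rate and learning_rate_decay_start/_every/_rate) and is not part of any of the repositories.

def decayed_learning_rate(epoch, opt):
    """Exponential decay by whole decay periods, as in the loops above."""
    if opt.learning_rate_decay_start >= 0 and epoch > opt.learning_rate_decay_start:
        # number of completed decay periods
        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        return opt.learning_rate * (opt.learning_rate_decay_rate ** frac)
    return opt.learning_rate

For instance, with learning_rate=5e-4, decay_start=0, decay_every=3 and decay_rate=0.8, epoch 7 has completed two decay periods, giving 5e-4 * 0.8**2 = 3.2e-4.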
Code example #23
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    back_model = models.setup(opt, reverse=True)
    back_model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()
    back_model.train()

    crit = utils.LanguageModelCriterion()
    all_param = chain(model.parameters(), back_model.parameters())
    optimizer = optim.Adam(all_param,
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        reverse_labels = np.flip(data['labels'], 1).copy()
        reverse_masks = np.flip(data['masks'], 1).copy()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'],
            reverse_labels, data['masks'], reverse_masks
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, reverse_labels, masks, reverse_masks = tmp
        optimizer.zero_grad()
        out, states = model(fc_feats, att_feats, labels)
        back_out, back_states = back_model(fc_feats, att_feats, reverse_labels)
        idx = [i for i in range(back_states.size()[1] - 1, -1, -1)]
        idx = torch.LongTensor(idx)
        idx = Variable(idx).cuda()
        invert_backstates = back_states.index_select(1, idx)

        loss = crit(out, labels[:, 1:], masks[:, 1:])

        back_loss = crit(back_out, reverse_labels[:, :-1],
                         reverse_masks[:, :-1])

        invert_backstates = invert_backstates.detach()
        l2_loss = ((states - invert_backstates)**2).mean()

        all_loss = loss + 1.5 * l2_loss + back_loss

        all_loss.backward()
        #back_loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_l2_loss = l2_loss.data[0]
        train_loss = loss.data[0]
        train_all_loss = all_loss.data[0]
        train_back_loss = back_loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, l2_loss = {:.3f}, back_loss = {:.3f}, all_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, train_l2_loss, train_back_loss, train_all_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'l2_loss', train_l2_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'all_loss',
                                  train_all_loss, iteration)
                add_summary_value(tf_summary_writer, 'back_loss',
                                  train_back_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save the latest checkpoint
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
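
Example #23 reverses the backward decoder's hidden states along the time axis by building a descending index and calling index_select. On any PyTorch version with torch.flip, that is a one-liner; the snippet below (tensor shapes assumed for illustration) just demonstrates the equivalence.

import torch

states = torch.randn(4, 7, 512)  # (batch, time, hidden) -- assumed shape
idx = torch.arange(states.size(1) - 1, -1, -1)
by_index_select = states.index_select(1, idx)  # what the loop above builds
by_flip = torch.flip(states, dims=[1])  # the modern equivalent
assert torch.equal(by_index_select, by_flip)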
Code example #24
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.maxlen_sen
    opt.inc_seg = loader.inc_seg
    opt.seg_ix = loader.seg_ix
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    score_list = []
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = {}
    score_splits = ['val', 'test']
    score_type = ['Bleu_4', 'METEOR', 'CIDEr']
    for split_i in score_splits:
        for score_item in score_type:
            if split_i not in best_val_score.keys():
                best_val_score[split_i] = {}
            best_val_score[split_i][score_item] = 0.0
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', best_val_score)

    model = models.setup(opt)
    device_ids = [0, 1]

    torch.cuda.set_device(device_ids[0])
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.module.train()
    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.module.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    #optimizer = nn.DataParallel(optimizer, device_ids=device_ids)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.module.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['labels'], data['x_phrase_mask_0'], data['x_phrase_mask_1'], \
               data['label_masks'], data['salicy_seg'], data['seg_mask']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, seq, phrase_mask_0, phrase_mask_1, masks, salicy_seg, seg_mask = tmp

        optimizer.zero_grad()
        remove_len = 2
        outputs, alphas = model.module(fc_feats, seq, phrase_mask_0,
                                       phrase_mask_1, masks, seg_mask,
                                       remove_len)
        loss = crit(outputs, seq[remove_len:, :].permute(1, 0),
                    masks[remove_len:, :].permute(1, 0))
        alphas = alphas.permute(1, 0, 2)
        if not opt.salicy_hard:
            if opt.salicy_loss_type == 'l2':
                salicy_loss = (((((salicy_seg * seg_mask[:, :, None] -
                                   alphas * seg_mask[:, :, None])**2).sum(0)
                                 ).sum(-1))**(0.5)).mean()
            elif opt.salicy_loss_type == 'kl':
                # alphas: len_sen, batch_size, num_frame
                salicy_loss = kullback_leibler2(
                    alphas * seg_mask[:, :, None],
                    salicy_seg * seg_mask[:, :, None])
                salicy_loss = (((salicy_loss *
                                 seg_mask[:, :, None]).sum(-1)).sum(0)).mean()
        else:
            # salicy_seg: len_sen, batch_size, num_frame
            salicy_loss = -torch.log((alphas * salicy_seg).sum(-1) + 1e-8)
            # salicy_loss: len_sen, batch_size
            salicy_loss = ((salicy_loss * seg_mask).sum(0)).mean()
        loss = loss + opt.salicy_alpha * salicy_loss
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.module.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.module.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.dataset,
                'remove_len': remove_len
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, score_list_i = eval_utils.eval_split(
                model.module, crit, loader, eval_kwargs)
            score_list.append(score_list_i)
            np.savetxt('./save/train_valid_test.txt', score_list, fmt='%.3f')
            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k in lang_stats.keys():
                    for v in lang_stats[k].keys():
                        add_summary_value(tf_summary_writer, k + v,
                                          lang_stats[k][v], iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['val']['CIDEr']
            else:
                current_score = -val_loss
            best_flag = {}
            for split_i in score_splits:
                for score_item in score_type:
                    if split_i not in best_flag.keys():
                        best_flag[split_i] = {}
                    best_flag[split_i][score_item] = False
            if True:  # always save the latest checkpoint
                for split_i in score_splits:
                    for score_item in score_type:
                        if best_val_score is None or lang_stats[split_i][
                                score_item] > best_val_score[split_i][
                                    score_item]:
                            best_val_score[split_i][score_item] = lang_stats[
                                split_i][score_item]
                            best_flag[split_i][score_item] = True

                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.module.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                for split_i in score_splits:
                    for score_item in score_type:
                        if best_flag[split_i][score_item]:
                            checkpoint_path = os.path.join(
                                opt.checkpoint_path, 'model-best_' + split_i +
                                '_' + score_item + '.pth')
                            torch.save(model.module.state_dict(),
                                       checkpoint_path)
                            print("model saved to {}".format(checkpoint_path))
                            with open(
                                    os.path.join(
                                        opt.checkpoint_path,
                                        'infos_' + split_i + '_' + score_item +
                                        '_' + opt.id + '-best.pkl'),
                                    'wb') as f:
                                cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
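
In example #24 the hard-supervision branch (opt.salicy_hard) maximizes the attention mass that the decoder places on the annotated salient frames: the per-word loss is the negative log of that mass, masked by seg_mask and averaged over the batch. Restated as a standalone function with the same tensor shapes as the loop above (a sketch, not repository code):

import torch


def hard_saliency_loss(alphas, salicy_seg, seg_mask, eps=1e-8):
    """alphas, salicy_seg: (len_sen, batch, num_frame); seg_mask: (len_sen, batch)."""
    per_word = -torch.log((alphas * salicy_seg).sum(-1) + eps)  # (len_sen, batch)
    return (per_word * seg_mask).sum(0).mean()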
Code example #25
0
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_fc = utils.if_use_fc(opt.caption_model)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = load_info(opt)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Define and load model, optimizer, critics
    decoder = setup(opt).train().cuda()
    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing).cuda()
    else:
        crit = utils.LanguageModelCriterion().cuda()
    # crit = utils.LanguageModelCriterion().cuda()
    rl_crit = utils.RewardCriterion().cuda()
    if opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(decoder.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(decoder.parameters(), opt)
    # optimizer = utils.build_optimizer(decoder.parameters(), opt)
    models = {'decoder': decoder}
    optimizers = {'decoder': optimizer}
    save_nets_structure(models, opt)
    load_checkpoint(models, optimizers, opt)
    print('opt', opt)

    epoch_done = True
    sc_flag = False
    while True:
        if epoch_done:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                decoder.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            epoch_done = False

        # 1. fetch a batch of data from train split
        data = loader.get_batch('train')
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['tags'],
            data['masks'], data['att_masks'], data['verbs']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, tags, masks, att_masks, weak_relas = tmp
        vrg_data = {key: data['vrg_data'][key] if data['vrg_data'][key] is None \
            else torch.from_numpy(data['vrg_data'][key]).cuda() for key in data['vrg_data']}

        # 2. Forward model and compute loss
        torch.cuda.synchronize()
        optimizer.zero_grad()
        if not sc_flag:
            out = decoder(vrg_data, fc_feats, att_feats, labels, weak_relas,
                          att_masks)
            loss_words = crit(out[0], labels[:, 1:], masks[:, 1:])
            loss_tags = crit(out[1], tags[:, 1:], masks[:, 1:])
            loss = loss_words + loss_tags * 0.15
        else:
            gen_result, sample_logprobs, core_args = decoder(
                vrg_data,
                fc_feats,
                att_feats,
                weak_relas,
                att_masks,
                opt={
                    'sample_max': 0,
                    'return_core_args': True
                },
                mode='sample')
            reward = get_self_critical_reward(decoder, core_args, vrg_data,
                                              fc_feats, att_feats, weak_relas,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # 3. Update model
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        # Write the training loss summary
        if (iteration % opt.log_loss_every == 0):
            # logging log
            logger.info("{} ({}), loss: {:.3f}".format(iteration, epoch,
                                                       train_loss))
            tb.add_values('loss', {'train': train_loss}, iteration)

        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Make evaluation and save checkpoint
        if (opt.save_checkpoint_every > 0
                and iteration % opt.save_checkpoint_every
                == 0) or (opt.save_checkpoint_every == -1 and epoch_done):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'expand_features': False
            }
            eval_kwargs.update(vars(opt))
            predictions, lang_stats = eval_utils.eval_split(
                decoder, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                assert 'CIDEr' in lang_stats, 'Error: cider should be in eval list'
                optimizer.scheduler_step(-lang_stats['CIDEr'])

            # log val results
            if lang_stats is not None:
                logger.info("Scores: {}".format(lang_stats))
                tb.add_values('scores', lang_stats, epoch)
            val_result_history[epoch] = {
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            current_score = 0 if lang_stats is None else lang_stats['CIDEr']
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            infos['val_result_history'] = val_result_history

            save_checkpoint(models, optimizers, infos, best_flag, opt)

        # Stop if reaching max epochs
        if epoch > opt.max_epochs and opt.max_epochs != -1:
            break
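
Example #25's self-critical branch scores sampled captions against a baseline (get_self_critical_reward) and hands the per-token reward to utils.RewardCriterion, which is not shown in this listing. The sketch below assumes the usual SCST formulation, with logprobs, seq and reward all of shape (batch, seq_len): a policy-gradient loss -logp * reward, masked so that tokens after the first end-of-sequence token (index 0) are ignored.

import torch
import torch.nn as nn


class RewardCriterionSketch(nn.Module):
    """Assumed behavior of utils.RewardCriterion above."""

    def forward(self, logprobs, seq, reward):
        # a token is valid up to and including the first 0 in seq
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        mask = mask.reshape(-1)
        logprobs = logprobs.reshape(-1)
        reward = reward.reshape(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()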
Code example #26
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size1", "rnn_size2",
                "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    # loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()
    # model.set_mode('train')

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        model.train()
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train+val')
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['num_bbox'],
            data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_).float(), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, num_bbox, labels, masks = tmp
        labels = labels.long()

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, num_bbox, labels),
                        labels[:, 1:], masks[:, 1:])
            # loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                       num_bbox,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats, att_feats,
                                              num_bbox, data, gen_result)
            loss = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f} lr={}" \
                 .format(iteration, epoch, train_loss, end - start, opt.current_lr ))
        else:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f} lr={}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start, opt.current_lr ))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'val_ref_path': opt.val_ref_path,
                'raw_val_anno_path': opt.raw_val_anno_path
            }
            eval_kwargs.update(vars(opt))
            # predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            best_flag = False
            if True:  # best-score gating is disabled below; always save
                # if best_val_score is None or current_score > best_val_score:
                # 	best_val_score = current_score
                # 	best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
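
The example above checkpoints training state as two pickles per run: infos_<id>.pkl (iteration, epoch, loader state, best score, vocab) and histories_<id>.pkl (the logged curves). As a minimal, self-contained sketch of that load/save round trip, assuming Python 3 and purely illustrative names:

import os
import pickle

def save_state(checkpoint_path, run_id, infos, histories):
    # pickle files must be opened in binary mode
    with open(os.path.join(checkpoint_path, 'infos_' + run_id + '.pkl'), 'wb') as f:
        pickle.dump(infos, f)
    with open(os.path.join(checkpoint_path, 'histories_' + run_id + '.pkl'), 'wb') as f:
        pickle.dump(histories, f)

def load_state(checkpoint_path, run_id):
    # missing files fall back to empty dicts, matching the training loops here
    infos, histories = {}, {}
    infos_path = os.path.join(checkpoint_path, 'infos_' + run_id + '.pkl')
    if os.path.isfile(infos_path):
        with open(infos_path, 'rb') as f:
            infos = pickle.load(f)
    histories_path = os.path.join(checkpoint_path, 'histories_' + run_id + '.pkl')
    if os.path.isfile(histories_path):
        with open(histories_path, 'rb') as f:
            histories = pickle.load(f)
    return infos, histories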
Code example #27
File: train.py Project: HakuSean/caption-transfer
def train(opt):
    opt['source'].use_att = utils.if_use_att(opt['source'].caption_model)
    loader = {}
    loader['source'] = DataLoader(opt['source'])
    opt['source'].vocab_size = loader[
        'source'].vocab_size  # vocab and seq_length both follow the source dataset
    opt['source'].seq_length = loader['source'].seq_length
    loader['target'] = DataLoader(opt['target'],
                                  target=True,
                                  opt_extra=opt['source'])

    tf_summary_writer = tf and tf.summary.FileWriter(
        opt['source'].checkpoint_path)

    infos = {}
    histories = {}
    if opt['source'].start_from is not None:  # the state of previous training session
        # open old infos and check if models are compatible
        with open(
                os.path.join(opt['source'].start_from,
                             'infos_' + opt['source'].id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']['source']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:  # the key RNN-model arguments must stay the same
                assert vars(saved_model_opt)[checkme] == vars(
                    opt['source']
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt['source'].start_from,
                             'histories_' + opt['source'].id + '.pkl')):
            with open(
                    os.path.join(opt['source'].start_from, 'histories_' +
                                 opt['source'].id + '.pkl')) as f:
                histories = cPickle.load(
                    f
                )  # this histories file is reused by the following training session

    iteration = infos.get('iter', 0)  # resume the iteration count (defaults to 0)
    epoch = infos.get('epoch', 0)  # resume the epoch count (defaults to 0)

    val_result_history = histories.get('val_result_history',
                                       {})  # if not defined, return {}
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader['source'].iterator = infos.get(
        'iterator',
        loader['source'].iterator)  # if not defined, return loader.iterator
    loader['source'].split_ix = infos.get('split_ix',
                                          loader['source'].split_ix)
    if opt['source'].load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(
        opt['source']
    )  # opt selects the model to build (the caption model here); the definitions live in __init__.py
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()  # masked cross-entropy over caption tokens

    optimizer = optim.Adam(model.parameters(),
                           lr=opt['source'].learning_rate,
                           weight_decay=opt['source'].weight_decay)

    # Load the optimizer state when resuming rather than training from scratch
    if vars(opt['source']).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt['source'].start_from,
                                    'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt['source'].learning_rate_decay_start and opt[
                    'source'].learning_rate_decay_start >= 0:
                frac = (epoch - opt['source'].learning_rate_decay_start
                        ) // opt['source'].learning_rate_decay_every
                decay_factor = opt['source'].learning_rate_decay_rate**frac
                opt['source'].current_lr = opt[
                    'source'].learning_rate * decay_factor
                utils.set_lr(optimizer,
                             opt['source'].current_lr)  # set the decayed rate
            else:
                opt['source'].current_lr = opt['source'].learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt['source'].scheduled_sampling_start and opt[
                    'source'].scheduled_sampling_start >= 0:
                frac = (epoch - opt['source'].scheduled_sampling_start
                        ) // opt['source'].scheduled_sampling_increase_every
                opt['source'].ss_prob = min(
                    opt['source'].scheduled_sampling_increase_prob * frac,
                    opt['source'].scheduled_sampling_max_prob)
                model.ss_prob = opt['source'].ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader['source'].get_batch()
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt['source'].grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt['source'].losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt['source'].current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt['source'].current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set, and save the model
        if (iteration % opt['source'].save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'dataset': opt['target'].input_target}
            eval_kwargs.update(
                vars(opt['target'])
            )  # eval_kwargs ends up as a copy of opt['target'] plus the extra key set above
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader['target'], eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt['target'].language_eval == 1:
                current_score = lang_stats['CIDEr']  # only based on CIDEr.
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:  # choose the best model based on the validation score
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(
                    opt['source'].checkpoint_path,
                    'model' + str(iteration) + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt['source'].checkpoint_path,
                    'optimizer' + str(iteration) + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterator'] = loader['source'].iterator
                infos['split_ix'] = loader['source'].split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader['source'].get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(
                            opt['source'].checkpoint_path, 'infos_' +
                            opt['source'].id + str(iteration) + '.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt['source'].checkpoint_path, 'histories_' +
                            opt['source'].id + str(iteration) + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(
                        opt['source'].checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(
                                opt['source'].checkpoint_path,
                                'infos_' + opt['source'].id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt['source'].max_epochs and opt['source'].max_epochs != -1:
            break
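
Both quantities recomputed under update_lr_flag above are pure step functions of the epoch. A minimal sketch of the two schedules (function names are illustrative, not from the repo):

def step_decay_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    # piecewise-constant decay: base_lr * rate ** floor((epoch - start) / every)
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * (decay_rate ** frac)

def scheduled_sampling_prob(epoch, ss_start, increase_every, increase_prob, max_prob):
    # stepwise ramp of the sampling probability, capped at max_prob
    if ss_start < 0 or epoch <= ss_start:
        return 0.0
    frac = (epoch - ss_start) // increase_every
    return min(increase_prob * frac, max_prob)

# e.g. with base_lr = 5e-4, decay starting at epoch 0, halving every 3 epochs:
# step_decay_lr(5e-4, 7, 0, 3, 0.5) == 5e-4 * 0.5 ** 2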
Code example #28
File: train.py Project: cxqj/ECHR
def train(opt):
    exclude_opt = [
        'training_mode', 'tap_epochs', 'cg_epochs', 'tapcg_epochs', 'lr',
        'learning_rate_decay_start', 'learning_rate_decay_every',
        'learning_rate_decay_rate', 'self_critical_after',
        'save_checkpoint_every', 'id', "pretrain", "pretrain_path", "debug",
        "save_all_checkpoint", "min_epoch_when_save"
    ]

    save_folder, logger, tf_writer = build_floder_and_create_logger(opt)
    saved_info = {'best': {}, 'last': {}, 'history': {}}
    is_continue = opt.start_from is not None

    if is_continue:
        infos_path = os.path.join(save_folder, 'info.pkl')
        with open(infos_path) as f:
            logger.info('load info from {}'.format(infos_path))
            saved_info = cPickle.load(f)
            pre_opt = saved_info[opt.start_from_mode]['opt']
            if vars(opt).get("no_exclude_opt", False):
                exclude_opt = []
            for opt_name in vars(pre_opt).keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: vars(pre_opt).get(opt_name)})
                if vars(pre_opt).get(opt_name) != vars(opt).get(opt_name):
                    print('change opt: {} from {} to {}'.format(
                        opt_name,
                        vars(pre_opt).get(opt_name),
                        vars(opt).get(opt_name)))

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.CG_vocab_size = loader.vocab_size
    opt.CG_seq_length = loader.seq_length

    # init training option
    epoch = saved_info[opt.start_from_mode].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode].get('best_val_score', 0)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    loader.iterators = saved_info[opt.start_from_mode].get(
        'iterators', loader.iterators)
    loader.split_ix = saved_info[opt.start_from_mode].get(
        'split_ix', loader.split_ix)
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.m_batch = vars(opt).get('m_batch', 1)

    # create a tap_model, fusion_model, cg_model

    tap_model = models.setup_tap(opt)
    lm_model = CaptionGenerator(opt)
    cg_model = lm_model

    if is_continue:
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(
                os.path.join(save_folder,
                             'model_iter_{}.pth'.format(iteration)))
        assert model_pth['iteration'] == iteration
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        tap_model.load_state_dict(model_pth['tap_model'])
        cg_model.load_state_dict(model_pth['cg_model'])

    elif opt.pretrain:
        print('pretrain {} from {}'.format(opt.pretrain, opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path)
        if opt.pretrain == 'tap':
            tap_model.load_state_dict(model_pth['tap_model'])
        elif opt.pretrain == 'cg':
            cg_model.load_state_dict(model_pth['cg_model'])
        elif opt.pretrain == 'tap_cg':
            tap_model.load_state_dict(model_pth['tap_model'])
            cg_model.load_state_dict(model_pth['cg_model'])
        else:
            assert 1 == 0, 'opt.pretrain error'

    tap_model.cuda()
    tap_model.train()  # Assure in training mode

    tap_crit = utils.TAPModelCriterion()

    tap_optimizer = optim.Adam(tap_model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)

    cg_model.cuda()
    cg_model.train()
    cg_optimizer = optim.Adam(cg_model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)
    cg_crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    allmodels = [tap_model, cg_model]
    optimizers = [tap_optimizer, cg_optimizer]

    if is_continue:
        tap_optimizer.load_state_dict(model_pth['tap_optimizer'])
        cg_optimizer.load_state_dict(model_pth['cg_optimizer'])

    update_lr_flag = True
    loss_sum = np.zeros(5)
    bad_video_num = 0
    best_epoch = epoch
    start = time.time()

    print_opt(opt, allmodels, logger)
    logger.info('\nStart training')

    # set a var indicating what to train in the current iteration: "tap", "cg" or "tap_cg"
    flag_training_whats = get_training_list(opt, logger)

    # Iteration begin
    while True:
        if update_lr_flag:
            if (epoch > opt.learning_rate_decay_start
                    and opt.learning_rate_decay_start >= 0):
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            for optimizer in optimizers:
                utils.set_lr(optimizer, opt.current_lr)
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(None)
            else:
                sc_flag = False
            update_lr_flag = False

        flag_training_what = flag_training_whats[epoch]
        if opt.training_mode == "alter2":
            flag_training_what = flag_training_whats[iteration]

        # get data
        data = loader.get_batch('train')

        if opt.debug:
            print('vid:', data['vid'])
            print('info:', data['infos'])

        torch.cuda.synchronize()

        if (data["proposal_num"] <= 0) or (data['fc_feats'].shape[0] <= 1):
            bad_video_num += 1  # print('vid:{} has no good proposal.'.format(data['vid']))
            continue

        ind_select_list, soi_select_list, cg_select_list, sampled_ids = data[
            'ind_select_list'], data['soi_select_list'], data[
                'cg_select_list'], data['sampled_ids']

        if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg':
            ind_select_list = data['gts_ind_select_list']
            soi_select_list = data['gts_soi_select_list']
            cg_select_list = data['gts_cg_select_list']

        tmp = [
            data['fc_feats'], data['att_feats'], data['lda_feats'],
            data['tap_labels'], data['tap_masks_for_loss'],
            data['cg_labels'][cg_select_list],
            data['cg_masks'][cg_select_list], data['w1']
        ]

        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]

        c3d_feats, att_feats, lda_feats, tap_labels, tap_masks_for_loss, cg_labels, cg_masks, w1 = tmp

        if (iteration - 1) % opt.m_batch == 0:
            tap_optimizer.zero_grad()
            cg_optimizer.zero_grad()

        tap_feats, pred_proposals = tap_model(c3d_feats)
        tap_loss = tap_crit(pred_proposals, tap_masks_for_loss, tap_labels, w1)

        loss_sum[0] = loss_sum[0] + tap_loss.item()

        # Backward Propagation
        if flag_training_what == 'tap':
            tap_loss.backward()
            utils.clip_gradient(tap_optimizer, opt.grad_clip)
            if iteration % opt.m_batch == 0:
                tap_optimizer.step()
        else:
            if not sc_flag:
                pred_captions = cg_model(tap_feats,
                                         c3d_feats,
                                         lda_feats,
                                         cg_labels,
                                         ind_select_list,
                                         soi_select_list,
                                         mode='train')
                cg_loss = cg_crit(pred_captions, cg_labels[:, 1:],
                                  cg_masks[:, 1:])

            else:
                gen_result, sample_logprobs, greedy_res = cg_model(
                    tap_feats,
                    c3d_feats,
                    lda_feats,
                    cg_labels,
                    ind_select_list,
                    soi_select_list,
                    mode='train_rl')
                sentence_info = data['sentences_batch'] if (
                    flag_training_what != 'cg'
                    and flag_training_what != 'gt_tap_cg'
                ) else data['gts_sentences_batch']

                reward = get_self_critical_reward2(
                    greedy_res, (data['vid'], sentence_info),
                    gen_result,
                    vocab=loader.get_vocab(),
                    opt=opt)
                cg_loss = rl_crit(sample_logprobs, gen_result,
                                  torch.from_numpy(reward).float().cuda())

            loss_sum[1] = loss_sum[1] + cg_loss.item()

            if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg' or flag_training_what == 'LP_cg':
                cg_loss.backward()

                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    cg_optimizer.step()
                if flag_training_what == 'gt_tap_cg':
                    utils.clip_gradient(tap_optimizer, opt.grad_clip)
                    if iteration % opt.m_batch == 0:
                        tap_optimizer.step()
            elif flag_training_what == 'tap_cg':
                total_loss = opt.lambda1 * tap_loss + opt.lambda2 * cg_loss
                total_loss.backward()
                utils.clip_gradient(tap_optimizer, opt.grad_clip)
                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    tap_optimizer.step()
                    cg_optimizer.step()

                loss_sum[2] = loss_sum[2] + total_loss.item()

        torch.cuda.synchronize()

        # Updating epoch num
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print losses, Add to summary
        if iteration % opt.losses_log_every == 0:
            end = time.time()
            losses = np.round(loss_sum / opt.losses_log_every, 3)
            logger.info(
                "iter {} (epoch {}, lr {}), avg_iter_loss({}) = {}, time/batch = {:.3f}, bad_vid = {:.3f}" \
                    .format(iteration, epoch, opt.current_lr, flag_training_what, losses,
                            (end - start) / opt.losses_log_every,
                            bad_video_num))

            tf_writer.add_scalar('lr', opt.current_lr, iteration)
            tf_writer.add_scalar('train_tap_loss', losses[0], iteration)
            tf_writer.add_scalar('train_tap_prop_loss', losses[3], iteration)
            tf_writer.add_scalar('train_tap_bound_loss', losses[4], iteration)
            tf_writer.add_scalar('train_cg_loss', losses[1], iteration)
            tf_writer.add_scalar('train_total_loss', losses[2], iteration)
            if sc_flag and (not flag_training_what == 'tap'):
                tf_writer.add_scalar('avg_reward', np.mean(reward[:, 0]),
                                     iteration)
            loss_history[iteration] = losses
            lr_history[iteration] = opt.current_lr
            loss_sum = np.zeros(5)
            start = time.time()
            bad_video_num = 0

        # Evaluation, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save):
            eval_kwargs = {
                'split': 'val',
                'val_all_metrics': 0,
                'topN': 100,
            }

            eval_kwargs.update(vars(opt))

            # eval_kwargs['num_vids_eval'] = int(491)
            eval_kwargs['topN'] = 100

            eval_kwargs2 = {
                'split': 'val',
                'val_all_metrics': 1,
                'num_vids_eval': 4917,
            }
            eval_kwargs2.update(vars(opt))

            if not opt.num_vids_eval:
                eval_kwargs['num_vids_eval'] = 4917
                eval_kwargs2['num_vids_eval'] = 4917

            crits = [tap_crit, cg_crit]
            pred_json_path_T = os.path.join(save_folder, 'pred_sent',
                                            'pred_num{}_iter{}.json')

            # if 'alter' in opt.training_mode:
            if flag_training_what == 'tap':
                eval_kwargs['topN'] = 1000
                predictions, eval_score, val_loss = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                            iteration),
                    eval_kwargs,
                    flag_eval_what='tap')
            else:
                if not vars(opt).get('fast_eval_cg', False):
                    predictions, eval_score, val_loss = eval_utils.eval_split(
                        allmodels,
                        crits,
                        loader,
                        pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                                iteration),
                        eval_kwargs,
                        flag_eval_what='tap_cg')

                predictions2, eval_score2, val_loss2 = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs2['num_vids_eval'],
                                            iteration),
                    eval_kwargs2,
                    flag_eval_what='cg')

                if vars(opt).get('fast_eval_cg', False) or vars(opt).get('fast_eval_cg_top10', False):
                    eval_score = eval_score2
                    val_loss = val_loss2
                    predictions = predictions2

            # else:
            #    predictions, eval_score, val_loss = eval_utils.eval_split(allmodels, crits, loader, pred_json_path,
            #                                                              eval_kwargs,
            #                                                              flag_eval_what=flag_training_what)

            f_f1 = lambda x, y: 2 * x * y / (x + y)
            f1 = f_f1(eval_score['Recall'], eval_score['Precision']).mean()
            if flag_training_what != 'tap':  # if training the captioner, use mean METEOR as the final score
                current_score = np.array(eval_score['METEOR']).mean() * 100
            else:  # if only training tap, use the F1 of precision and recall as the final score
                current_score = f1

            for model in allmodels:
                for name, param in model.named_parameters():
                    tf_writer.add_histogram(name,
                                            param.clone().cpu().data.numpy(),
                                            iteration,
                                            bins=10)
                    if param.grad is not None:
                        tf_writer.add_histogram(
                            name + '_grad',
                            param.grad.clone().cpu().data.numpy(),
                            iteration,
                            bins=10)

            tf_writer.add_scalar('val_tap_loss', val_loss[0], iteration)
            tf_writer.add_scalar('val_cg_loss', val_loss[1], iteration)
            tf_writer.add_scalar('val_tap_prop_loss', val_loss[3], iteration)
            tf_writer.add_scalar('val_tap_bound_loss', val_loss[4], iteration)
            tf_writer.add_scalar('val_total_loss', val_loss[2], iteration)
            tf_writer.add_scalar('val_score', current_score, iteration)
            if flag_training_what != 'tap':
                tf_writer.add_scalar('val_score_gt_METEOR',
                                     np.array(eval_score2['METEOR']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_Bleu_4',
                                     np.array(eval_score2['Bleu_4']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_CIDEr',
                                     np.array(eval_score2['CIDEr']).mean(),
                                     iteration)
            tf_writer.add_scalar('val_recall', eval_score['Recall'].mean(),
                                 iteration)
            tf_writer.add_scalar('val_precision',
                                 eval_score['Precision'].mean(), iteration)
            tf_writer.add_scalar('f1', f1, iteration)

            val_result_history[iteration] = {
                'val_loss': val_loss,
                'eval_score': eval_score
            }

            if flag_training_what == 'tap':
                logger.info(
                    'Validation result at iter {}, score (f1/meteor): {},\n all: {}'
                    .format(iteration, current_score, eval_score))
            else:
                mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score.items()
                }
                gt_mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score2.items()
                }

                metrics = ['Bleu_4', 'CIDEr', 'METEOR', 'ROUGE_L']
                gt_avg_score = np.array([
                    v for metric, v in gt_mean_score.items()
                    if metric in metrics
                ]).sum()
                logger.info(
                    'Validation result at iter {}, score (f1/meteor): {},\n all: {}\n mean: {} \n\n gt: {} \n mean: {}\n avg_score: {}'
                    .format(iteration, current_score, eval_score, mean_score,
                            eval_score2, gt_mean_score, gt_avg_score))

            # Save model .pth
            saved_pth = {
                'iteration': iteration,
                'cg_model': cg_model.state_dict(),
                'tap_model': tap_model.state_dict(),
                'cg_optimizer': cg_optimizer.state_dict(),
                'tap_optimizer': tap_optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model.pth')
            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to checkpoint file {}.'.format(
                iteration, checkpoint_path))

            # save info.pkl
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {
                    'opt': opt,
                    'iter': iteration,
                    'epoch': epoch,
                    'iterators': loader.iterators,
                    'flag_training_what': flag_training_what,
                    'split_ix': loader.split_ix,
                    'best_val_score': best_val_score,
                    'vocab': loader.get_vocab(),
                }

                best_checkpoint_path = os.path.join(save_folder,
                                                    'model-best.pth')
                torch.save(saved_pth, best_checkpoint_path)
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': opt,
                'iter': iteration,
                'epoch': epoch,
                'iterators': loader.iterators,
                'flag_training_what': flag_training_what,
                'split_ix': loader.split_ix,
                'best_val_score': best_val_score,
                'vocab': loader.get_vocab(),
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.pkl'), 'wb') as f:
                cPickle.dump(saved_info, f)
                logger.info('Save info to info.pkl')

            # Stop criterion
            if epoch >= len(flag_training_whats):
                tf_writer.close()
                break
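
When sc_flag is set above, the reward passed to rl_crit is the sampled caption's score minus the greedy baseline's, and the criterion turns it into a REINFORCE loss over the sampled tokens. A hedged sketch of such a criterion (the common pattern in these codebases, not necessarily the exact utils.RewardCriterion used here):

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    # REINFORCE with a baseline: -(logprob * advantage), averaged over valid tokens
    def forward(self, sample_logprobs, gen_result, reward):
        sample_logprobs = sample_logprobs.view(-1)
        reward = reward.view(-1)
        # keep tokens up to and including the first end-of-sequence token (id 0)
        mask = (gen_result > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1).view(-1)
        return -(sample_logprobs * reward * mask).sum() / mask.sum()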
Code example #29
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    ac = 0

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + format(int(opt.start_from),'04') + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from),'04') + '.pkl')):
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from),'04') + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    #dp_model = torch.nn.DataParallel(model)
    #dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    for name, param in model.named_parameters():
        print(name)

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    CE_ac = utils.CE_ac()

    optim_para = model.parameters()
    optimizer = utils.build_optimizer(optim_para, opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from),'04')+'.pth')):
        optimizer.load_state_dict(torch.load(os.path.join(
            opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from),'04')+'.pth')))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1,1])
    sim_lambda = opt.sim_lambda
    reset_optimizer_index = 1
    while True:
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after and reset_optimizer_index:
            opt.learning_rate_decay_start = opt.self_critical_after
            opt.learning_rate_decay_rate = opt.learning_rate_decay_rate_rl
            opt.learning_rate = opt.learning_rate_rl
            reset_optimizer_index = 0


        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['labels'], data['masks'], data['mods']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        labels, masks, mods = tmp

        tmp = [data['att_feats'], data['att_masks'], data['attr_feats'], data['attr_masks'],data['rela_feats'], data['rela_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        att_feats, att_masks, attr_feats, attr_masks, rela_feats, rela_masks = tmp

        rs_data = {}
        rs_data['att_feats'] = att_feats
        rs_data['att_masks'] = att_masks
        rs_data['attr_feats'] = attr_feats
        rs_data['attr_masks'] = attr_masks
        rs_data['rela_feats'] = rela_feats
        rs_data['rela_masks'] = rela_masks

        if not sc_flag:
            logits, cw_logits = dp_model(rs_data, labels)
            ac = CE_ac(logits, labels[:, 1:], masks[:, 1:])
            print('ac: {0}'.format(ac))
            loss_lan = crit(logits, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, cw_logits = dp_model(
                rs_data, opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, rs_data, data, gen_result, opt)
            loss_lan = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:])
        ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:])
        print('ac2: {0}'.format(ac2))
        if epoch < opt.step2_train_after:
            loss = loss_lan + sim_lambda*loss_cw
        else:
            loss = loss_lan

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item()*opt.accumulate_number
            train_loss_lan = loss_lan.item()
            train_loss_cw = loss_cw.item()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            print('lr:{0}'.format(opt.current_lr))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_lan', train_loss_lan, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_cw', train_loss_cw, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            add_summary_value(tb_summary_writer, 'ac', ac, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set, and save the model
        if (iteration % opt.save_checkpoint_every == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # eval model
            eval_kwargs = {'split': 'test',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            #val_loss, predictions, lang_stats = eval_utils_rs3.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            # if lang_stats is not None:
            #     for k,v in lang_stats.items():
            #         add_summary_value(tb_summary_writer, k, v, iteration)
            # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it is improving on the validation result
            # if opt.language_eval == 1:
            #     current_score = lang_stats['CIDEr']
            # else:
            #     current_score = - val_loss
            current_score = 0

            best_flag = False
            if True: # if true
                save_id = iteration // opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model'+opt.id+format(int(save_id),'04')+'.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer'+opt.id+format(int(save_id),'04')+'.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+format(int(save_id),'04')+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+format(int(save_id),'04')+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
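
The accumulate_iter / accumulate_number bookkeeping in the example above is ordinary gradient accumulation: every micro-batch contributes loss / N to the gradients, and the optimizer steps once per N micro-batches, emulating an N-times-larger batch. A self-contained sketch with generic stand-in names (model, criterion, and batches are assumptions, not from the repo):

import torch

def train_with_accumulation(model, optimizer, criterion, batches, accumulate_number, grad_clip=0.1):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches, start=1):
        # scale so that the summed gradients average over the virtual batch
        loss = criterion(model(inputs), targets) / accumulate_number
        loss.backward()
        if i % accumulate_number == 0:
            # value clipping, analogous to utils.clip_gradient in these scripts
            torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()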
Code example #30
File: train.py Project: vgilad/DiscCaptioning-1
def train(opt):
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print(
                'Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.'
            )

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start
                        ) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob,
                                             iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight,
                                             iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # evaluate on the validation set, and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # Load the retrieval model for evaluation
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v,
                                                 iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                    iteration)
                #tf_summary_writer.add_image('images', utils.make_summary_image(), iteration)
                #utils.make_html(opt.id, iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                    best_val_score_vse = current_score_vse
                    best_flag_vse = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['best_val_score_vse'] = best_val_score_vse
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '-%d.pkl' % (iteration)),
                        'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)
                if best_flag_vse:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model_vse-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_vse_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
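
Example #30 tracks two running bests at once, one for the caption score and one for the VSE retrieval score, and writes a separate '-best' snapshot for each. The bookkeeping reduces to the small sketch below (names and scores are illustrative):

def update_best(current, best):
    # first evaluation: best is None, so any score becomes the new best
    if best is None or current > best:
        return current, True
    return best, False

best_cap, best_vse = None, None
for score_cap, score_vse in [(0.90, 0.40), (1.10, 0.30), (1.05, 0.50)]:
    best_cap, save_cap = update_best(score_cap, best_cap)
    best_vse, save_vse = update_best(score_vse, best_vse)
    # write 'model-best.pth' when save_cap, 'model_vse-best.pth' when save_vse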
Code example #31
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    logger = Logger(opt)

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer_G = optim.Adam(model.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_G.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer_G = update_lr(
                opt, epoch, model, optimizer_G)

        # Load data from train split (0)
        data = loader.get_batch('train', seq_per_img=opt.seq_per_img)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer_G.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                        masks[:, 1:])
            loss.backward()
        else:
            # The self-critical (RL) branch is omitted in this snippet;
            # without it, `loss` below would be undefined once sc_flag is set.
            raise NotImplementedError('self-critical training not included here')

        utils.clip_gradient(optimizer_G, opt.grad_clip)
        optimizer_G.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()

        log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, loss.data.cpu().numpy()[0], end - start)
        logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save the latest checkpoint; best_flag gates the -best copy
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer_G.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
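
Both training loops here call update_lr(opt, epoch, model, optimizer_G) at each epoch boundary, but the helper itself is not shown. A plausible reconstruction under the usual option names of such training scripts (learning_rate_decay_start/_every/_rate, scheduled_sampling_*, self_critical_after); all of these names are assumptions here:

def update_lr(opt, epoch, model, optimizer):
    # Decay the learning rate on an epoch schedule.
    if opt.learning_rate_decay_start != -1 and epoch > opt.learning_rate_decay_start:
        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        opt.current_lr = opt.learning_rate * (opt.learning_rate_decay_rate ** frac)
    else:
        opt.current_lr = opt.learning_rate
    for group in optimizer.param_groups:
        group['lr'] = opt.current_lr

    # Anneal the scheduled-sampling probability.
    if opt.scheduled_sampling_start != -1 and epoch > opt.scheduled_sampling_start:
        frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
        model.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                            opt.scheduled_sampling_max_prob)

    # Switch to self-critical (REINFORCE) training after a warm-up phase.
    sc_flag = opt.self_critical_after != -1 and epoch >= opt.self_critical_after

    # The caller re-arms the flag only when an epoch boundary is crossed.
    return opt, sc_flag, False, model, optimizer
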
Code example #32
File: train.py  Project: cfh3c/caption.selfcritic.gan
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    #model_D = Discriminator(opt)
    #model_D.load_state_dict(torch.load('save/model_D.pth'))
    #model_D.cuda()
    #criterion_D = nn.CrossEntropyLoss(size_average=True)

    model_E = Distance(opt)
    model_E.load_state_dict(
        torch.load('save/model_E_NCE/model_E_10epoch.pth'))
    model_E.cuda()
    criterion_E = nn.CosineEmbeddingLoss(margin=0, size_average=True)
    #criterion_E = nn.CosineSimilarity()

    logger = Logger(opt)

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()
    #model_D.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_G = optim.Adam(model.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    #optimizer_D = optim.Adam(model_D.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_G.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer_G = update_lr(
                opt, epoch, model, optimizer_G)

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        #print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        #tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [data['fc_feats'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        #fc_feats, att_feats, labels, masks = tmp
        fc_feats, labels, masks = tmp

        ############################################################################################################
        ############################################ REINFORCE TRAINING ############################################
        ############################################################################################################
        if 1:  # originally: iteration % opt.D_scheduling != 0 (GAN step disabled)
            optimizer_G.zero_grad()
            if not sc_flag:
                loss = crit(model(fc_feats, labels),
                            labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(
                    fc_feats, {'sample_max': 0})
                #reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
                sc_reward = get_self_critical_reward(model, fc_feats, data,
                                                     gen_result, logger)
                #gan_reward = get_gan_reward(model, model_D, criterion_D, fc_feats, data, logger)                 # Criterion_D = nn.XEloss()
                distance_loss_reward1 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=False)  # criterion_E = nn.CosEmbedLoss()
                distance_loss_reward2 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=True)  # criterion_E = nn.CosEmbedLoss()
                #cosine_reward = get_distance_reward(model, model_E, criterion_E, fc_feats, data, logger)         # criterion_E = nn.CosSim()
                reward = distance_loss_reward1 + distance_loss_reward2
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))
                loss.backward()

            utils.clip_gradient(optimizer_G, opt.grad_clip)
            optimizer_G.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            end = time.time()

            if not sc_flag:
                log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start)
                logger.write(log)
            else:
                log = "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration,  epoch, np.mean(reward[:,0]), end - start)
                logger.write(log)

        ######################################################################################################
        ############################################ GAN TRAINING ############################################
        ######################################################################################################
        else:  # originally: iteration % opt.D_scheduling == 0 (GAN training)
            # NOTE: unreachable while the branch above reads `if 1:`, and it
            # requires model_D / optimizer_D / criterion_D, which are commented
            # out earlier in this function.
            model_D.zero_grad()
            optimizer_D.zero_grad()

            fc_feats_temp = Variable(fc_feats.data.cpu(), volatile=True).cuda()
            labels = Variable(labels.data.cpu()).cuda()

            sample_res, sample_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 0})  #640, 16
            greedy_res, greedy_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 1})  #640, 16
            gt_res = labels  # 640, 18

            sample_res_embed = model.embed(Variable(sample_res))
            greedy_res_embed = model.embed(Variable(greedy_res))
            gt_res_embed = model.embed(gt_res)

            f_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            r_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            f_label.data.fill_(0)
            r_label.data.fill_(1)

            f_D_output = model_D(sample_res_embed.detach(), fc_feats.detach())
            f_loss = criterion_D(f_D_output, f_label.long())
            f_loss.backward()

            r_D_output = model_D(gt_res_embed.detach(), fc_feats.detach())
            r_loss = criterion_D(r_D_output, r_label.long())
            r_loss.backward()

            D_loss = f_loss + r_loss
            optimizer_D.step()
            torch.cuda.synchronize()

            log = 'iter {} (epoch {}),  Discriminator loss : {}'.format(
                iteration, epoch,
                D_loss.data.cpu().numpy()[0])
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save the latest checkpoint; best_flag gates the -best copy
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer_G.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
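
The REINFORCE branch in example #32 computes a reward from get_self_critical_reward() / get_distance_reward() and feeds it, together with the sampled sequence and its log-probabilities, to utils.RewardCriterion(). A sketch of the policy-gradient loss such a criterion usually computes (-log p(w_t) * advantage, masked beyond the end of each sampled caption); it mirrors common self-critical implementations and is an assumption rather than the project's verified code:

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    """Policy-gradient loss: -logprob * advantage, averaged over real tokens."""
    def __init__(self):
        super(RewardCriterion, self).__init__()

    def forward(self, logprobs, seq, reward):
        # logprobs: (batch, T) log-probability of each sampled token
        # seq:      (batch, T) sampled token ids, 0 after the end token
        # reward:   (batch, T) per-token advantage (e.g. sample vs. baseline)
        logprobs = logprobs.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        # Keep the end token itself in the loss: shift the (seq > 0) mask
        # right by one step so the first padding position still counts.
        mask = (seq > 0).float()
        mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
        mask = mask.contiguous().view(-1)
        loss = -logprobs * reward * mask
        return torch.sum(loss) / torch.sum(mask)
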