# Esempio n. 1 (Example 1) — scraped-page marker, commented out so the file parses
# 0
def train(opt):
    """Self-critical (CSGD) RL fine-tuning loop for the caption model.

    Loads a pretrained VSE model (image encoder + caption encoder) that is
    kept frozen and only used to score generated captions, restores the
    caption model either from a previous CSGD checkpoint (``opt.start_from``)
    or from the best MLE checkpoint, then trains with the combined reward

        reward = cs_reward - opt.hdr_w * hd_reward - opt.bdr_w * bd_reward

    Checkpoints (model/optimizer/infos/histories) are written to
    ``<opt.checkpoint_path>/CSGD`` every ``opt.save_checkpoint_every``
    iterations; the best snapshot (CIDEr, or -val_loss when language eval is
    off) is saved separately as ``model-best.pth``.

    :param opt: training options namespace (paths, lr schedule, reward
                weights, logging/eval intervals, ...)
    """

    """   loading VSE model    """
    # Construct the model
    vse = VSE(opt)
    opt.best = os.path.join('./vse/model_best.pth.tar')
    print("=> loading best checkpoint '{}'".format(opt.best))
    checkpoint = torch.load(opt.best)
    vse.load_state_dict(checkpoint['model'])
    vse.val_start()  # eval mode: VSE is frozen, used only for rewards

    """   loading caption model    """
    opt.use_att = utils.if_use_att(opt.caption_model)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    save_path = os.path.join(opt.checkpoint_path, 'CSGD')

    if not os.path.exists(save_path):
        # FIX: 0777 is Python-2-only octal syntax; 0o777 is valid on 2.6+/3.x
        os.makedirs(save_path, 0o777)

    infos = {}
    histories = {}

    RL_trainmodel = os.path.join('RL_%s' % opt.caption_model)

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        start_from_path = os.path.join(opt.start_from, 'CSGD')
        # FIX: pickle files must be opened in binary mode ('rb') to load on Py3
        with open(os.path.join(start_from_path, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(start_from_path, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(start_from_path, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    # Resume the learning rate reached at the end of MLE (cross-entropy) training
    with open(os.path.join(RL_trainmodel, 'MLE', 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
        infos_XE = cPickle.load(f)
        opt.learning_rate = infos_XE['opt'].current_lr

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    print(loader.iterators)

    # FIX: initialize unconditionally, otherwise the best-model comparison
    # below raises NameError when opt.load_best_score != 1
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup_pro(opt)

    if vars(opt).get('start_from', None) is not None:

        start_from_path = os.path.join(opt.start_from, 'CSGD')
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from), " %s must be a path" % opt.start_from
        assert os.path.isfile(os.path.join(start_from_path, "infos_" + opt.id + ".pkl")), "infos.pkl file does not exist in path %s" % opt.start_from
        assert os.path.isfile(os.path.join(start_from_path, "optimizer.pth")), "optimizer.pth.file does not exist in path %s" % opt.start_from

        model_path = os.path.join(start_from_path, 'model.pth')
        optimizer_path = os.path.join(start_from_path, 'optimizer.pth')

    else:
        # cold start: warm-start RL training from the best MLE checkpoint
        model_path = os.path.join(RL_trainmodel, 'MLE', 'model-best.pth')
        optimizer_path = os.path.join(RL_trainmodel, 'MLE', 'optimizer-best.pth')

    model.load_state_dict(torch.load(model_path))
    print("model load from {}".format(model_path))

    model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer.load_state_dict(torch.load(optimizer_path))
    print("optimizer load from {}".format(optimizer_path))

    all_cider = 0  # for computing the average CIDEr score
    all_dis = 0  # for computing the discriminability percentage

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob (held fixed during RL training)
            model.ss_prob = 0.25
            print('learning_rate: %s' % str(opt.current_lr))
            update_lr_flag = False

            # start self critical training
            sc_flag = True

        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['knn_fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['knn_att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]]

        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, knn_fc_feats, knn_att_feats = tmp

        optimizer.zero_grad()

        # Sample twice: one sequence to train on and one as the reward baseline
        gen_result, sample_logprobs = model.sample_score(fc_feats, att_feats, loader, {'sample_max': 0})
        gen_result_baseline, sample_b_logprobs = model.sample_score(fc_feats, att_feats, loader, {'sample_max': 0})
        bd_reward, sample_loss = get_bd_reward(vse, model, fc_feats, att_feats, data, gen_result, gen_result_baseline, loader)
        hd_reward = get_hd_reward(vse, model, fc_feats, knn_fc_feats, data, gen_result, gen_result_baseline, loader)

        cs_reward, m_cider = get_cs_reward(model, fc_feats, att_feats, data, gen_result, gen_result_baseline, loader)
        reward = cs_reward - opt.hdr_w * hd_reward - opt.bdr_w * bd_reward
        loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False))
        # Count "discriminable" samples: VSE loss below the 0.4 threshold
        dis_number = (sample_loss < 0.4).float()
        dis_number = dis_number.data.cpu().numpy().sum()
        all_dis += dis_number
        all_cider += m_cider

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]

        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), hdr = {:.3f}, bdr = {:.3f}, csr = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, np.mean(hd_reward[:, 0]), np.mean(bd_reward[:, 0]), np.mean(cs_reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):

            loss_history[iteration] = np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = evalpro_utils.eval_split(model, crit, loader, eval_kwargs)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            save_path1 = os.path.join(save_path, 'model.pth')
            if not os.path.exists(os.path.dirname(save_path1)):
                os.makedirs(os.path.dirname(save_path1))
            torch.save(model.state_dict(), save_path1)
            print("model saved to {}".format(save_path1))

            optimizer_path1 = os.path.join(save_path, 'optimizer.pth')
            if not os.path.exists(os.path.dirname(optimizer_path1)):
                os.makedirs(os.path.dirname(optimizer_path1))
            torch.save(optimizer.state_dict(), optimizer_path1)
            print("optimizer saved to {}".format(optimizer_path1))

            # Report window averages, then reset the accumulators.
            # FIX: without the reset, every window after the first keeps
            # accumulating on top of the previously divided value, making the
            # reported averages meaningless.
            all_dis = all_dis / opt.save_checkpoint_every
            print("all_dis:%f" % all_dis)
            infos['all_dis'] = all_dis
            all_dis = 0

            all_cider = all_cider / opt.save_checkpoint_every
            print("all_cider:%f" % all_cider)
            infos['all_cider'] = all_cider
            all_cider = 0

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(save_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(save_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                save_path2 = os.path.join(save_path, 'model-best.pth')
                torch.save(model.state_dict(), save_path2)
                optimizer_path2 = os.path.join(save_path, 'optimizer-best.pth')
                torch.save(optimizer.state_dict(), optimizer_path2)
                print("model saved to {}".format(save_path2))
                with open(os.path.join(save_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(save_path, 'histories_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
        "id", "batch_size", "input_fc_dir", "input_att_dir", "beam_size",
        "start_from", "language_eval"
    ]
    for k in vars(infos['opt']).keys():
        if k not in ignore:
            if k in vars(opt):
                assert vars(opt)[k] == vars(
                    infos['opt'])[k], k + ' option not consistent'
            else:
                vars(opt).update({k: vars(infos['opt'])[k]
                                  })  # copy over options from model

    vocab = infos['vocab']  # ix -> word mapping

    opt.model = os.path.join(path, 'model-best.pth')
    model = models.setup_pro(opt)
    model.load_state_dict(torch.load(opt.model))
    model.cuda()
    model.eval()
    crit = utils.LanguageModelCriterion()

    # Create the Data Loader instance
    loader = DataLoader(opt)
    loader.ix_to_word = infos['vocab']

    # Set sample options
    loss, split_predictions, lang_stats = evalpro_utils.eval_split(
        model, crit, loader, vars(opt))
    if opt.dump_json == 1:
        # dump the json
        json.dump(split_predictions, open(Sentencespath, 'w'))
# Esempio n. 3 (Example 3) — scraped-page marker, commented out so the file parses
# 0
def train(opt):
    """MLE (cross-entropy) pre-training loop for the caption model.

    Trains the caption decoder with teacher forcing on ground-truth labels,
    optionally resuming from ``<opt.start_from>/MLE``. Checkpoints
    (model/optimizer/infos/histories) are written to
    ``<opt.checkpoint_path>/MLE`` every ``opt.save_checkpoint_every``
    iterations; the best snapshot (CIDEr, or -val_loss when language eval is
    off) is saved separately as ``model-best.pth``. Scalar summaries are
    logged via TensorFlow when it is available (``tf`` may be None).

    :param opt: training options namespace (paths, lr / scheduled-sampling
                schedules, logging/eval intervals, ...)
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    save_path = os.path.join(opt.checkpoint_path, 'MLE')  # checkpoint_path = RL_
    if not os.path.exists(save_path):
        # FIX: 0777 is Python-2-only octal syntax; 0o777 is valid on 2.6+/3.x
        os.makedirs(save_path, 0o777)

    # `tf` may be None when tensorflow is unavailable; the writer is then None
    # and all summary writes below are guarded by `if tf is not None`.
    tf_summary_writer = tf and tf.summary.FileWriter(save_path)

    infos = {}
    histories = {}

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # FIX: pickle files must be opened in binary mode ('rb') to load on Py3
        with open(os.path.join(opt.start_from, 'MLE', 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'MLE', 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'MLE', 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # FIX: initialize unconditionally, otherwise the best-model comparison
    # below raises NameError when opt.load_best_score != 1
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup_pro(opt)

    # check compatibility if training is continued from previously saved model
    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from), " %s must be a path" % opt.start_from
        assert os.path.isfile(os.path.join(opt.start_from, 'MLE', "infos_" + opt.id + ".pkl")), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(torch.load(os.path.join(opt.start_from, 'MLE', 'model.pth')))

    model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, 'MLE', "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'MLE', 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            print('learning_rate: %s' % str(opt.current_lr))
            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        optimizer.zero_grad()

        # Teacher-forced cross-entropy; targets and masks are shifted by one
        # so the model predicts the next token at every position.
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:], masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]

        torch.cuda.synchronize()
        end = time.time()

        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)

                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = evalpro_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            save_path1 = os.path.join(save_path, 'model.pth')
            if not os.path.exists(os.path.dirname(save_path1)):
                os.makedirs(os.path.dirname(save_path1))
            torch.save(model.state_dict(), save_path1)
            print("model saved to {}".format(save_path1))

            optimizer_path1 = os.path.join(save_path, 'optimizer.pth')
            if not os.path.exists(os.path.dirname(optimizer_path1)):
                os.makedirs(os.path.dirname(optimizer_path1))
            torch.save(optimizer.state_dict(), optimizer_path1)
            print("optimizer saved to {}".format(optimizer_path1))

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(save_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(save_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                save_path2 = os.path.join(save_path, 'model-best.pth')
                torch.save(model.state_dict(), save_path2)
                optimizer_path2 = os.path.join(save_path, 'optimizer-best.pth')
                torch.save(optimizer.state_dict(), optimizer_path2)
                print("model saved to {}".format(save_path2))
                with open(os.path.join(save_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(save_path, 'histories_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break