Example 1
File: train.py Project: cxqj/ECHR
def train(opt):
    exclude_opt = [
        'training_mode', 'tap_epochs', 'cg_epochs', 'tapcg_epochs', 'lr',
        'learning_rate_decay_start', 'learning_rate_decay_every',
        'learning_rate_decay_rate', 'self_critical_after',
        'save_checkpoint_every', 'id', "pretrain", "pretrain_path", "debug",
        "save_all_checkpoint", "min_epoch_when_save"
    ]

    save_folder, logger, tf_writer = build_floder_and_create_logger(opt)
    saved_info = {'best': {}, 'last': {}, 'history': {}}
    is_continue = opt.start_from is not None

    if is_continue:
        infos_path = os.path.join(save_folder, 'info.pkl')
        with open(infos_path, 'rb') as f:
            logger.info('load info from {}'.format(infos_path))
            saved_info = cPickle.load(f)
            pre_opt = saved_info[opt.start_from_mode]['opt']
            if vars(opt).get("no_exclude_opt", False):
                exclude_opt = []
            for opt_name in vars(pre_opt).keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: vars(pre_opt).get(opt_name)})
                if vars(pre_opt).get(opt_name) != vars(opt).get(opt_name):
                    print('change opt: {} from {} to {}'.format(
                        opt_name,
                        vars(pre_opt).get(opt_name),
                        vars(opt).get(opt_name)))

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.CG_vocab_size = loader.vocab_size
    opt.CG_seq_length = loader.seq_length

    # init training option
    epoch = saved_info[opt.start_from_mode].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode].get('best_val_score', 0)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    loader.iterators = saved_info[opt.start_from_mode].get(
        'iterators', loader.iterators)
    loader.split_ix = saved_info[opt.start_from_mode].get(
        'split_ix', loader.split_ix)
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.m_batch = vars(opt).get('m_batch', 1)

    # create the tap_model, fusion_model and cg_model

    tap_model = models.setup_tap(opt)
    lm_model = CaptionGenerator(opt)
    cg_model = lm_model

    if is_continue:
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(
                os.path.join(save_folder,
                             'model_iter_{}.pth'.format(iteration)))
        assert model_pth['iteration'] == iteration
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        tap_model.load_state_dict(model_pth['tap_model'])
        cg_model.load_state_dict(model_pth['cg_model'])

    elif opt.pretrain:
        print('pretrain {} from {}'.format(opt.pretrain, opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path)
        if opt.pretrain == 'tap':
            tap_model.load_state_dict(model_pth['tap_model'])
        elif opt.pretrain == 'cg':
            cg_model.load_state_dict(model_pth['cg_model'])
        elif opt.pretrain == 'tap_cg':
            tap_model.load_state_dict(model_pth['tap_model'])
            cg_model.load_state_dict(model_pth['cg_model'])
        else:
            raise ValueError('unknown opt.pretrain: {}'.format(opt.pretrain))

    tap_model.cuda()
    tap_model.train()  # Assure in training mode

    tap_crit = utils.TAPModelCriterion()

    tap_optimizer = optim.Adam(tap_model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)

    cg_model.cuda()
    cg_model.train()
    cg_optimizer = optim.Adam(cg_model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)
    cg_crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    allmodels = [tap_model, cg_model]
    optimizers = [tap_optimizer, cg_optimizer]

    if is_continue:
        tap_optimizer.load_state_dict(model_pth['tap_optimizer'])
        cg_optimizer.load_state_dict(model_pth['cg_optimizer'])

    update_lr_flag = True
    loss_sum = np.zeros(5)
    bad_video_num = 0
    best_epoch = epoch
    start = time.time()

    print_opt(opt, allmodels, logger)
    logger.info('\nStart training')

    # set a var to indicate what to train in current iteration: "tap", "cg" or "tap_cg"
    flag_training_whats = get_training_list(opt, logger)

    # Iteration begin
    while True:
        if update_lr_flag:
            if (epoch > opt.learning_rate_decay_start
                    and opt.learning_rate_decay_start >= 0):
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            for optimizer in optimizers:
                utils.set_lr(optimizer, opt.current_lr)
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(None)
            else:
                sc_flag = False
            update_lr_flag = False
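            # Worked example of the step decay above (illustrative values):
            # with opt.lr = 1e-4, learning_rate_decay_start = 0,
            # learning_rate_decay_every = 3 and learning_rate_decay_rate = 0.8,
            # epoch 7 gives frac = (7 - 0) // 3 = 2, so
            # current_lr = 1e-4 * 0.8**2 = 6.4e-5.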

        flag_training_what = flag_training_whats[epoch]
        if opt.training_mode == "alter2":
            flag_training_what = flag_training_whats[iteration]

        # get data
        data = loader.get_batch('train')

        if opt.debug:
            print('vid:', data['vid'])
            print('info:', data['infos'])

        torch.cuda.synchronize()

        if (data["proposal_num"] <= 0) or (data['fc_feats'].shape[0] <= 1):
            # print('vid:{} has no good proposal.'.format(data['vid']))
            bad_video_num += 1
            continue

        ind_select_list = data['ind_select_list']
        soi_select_list = data['soi_select_list']
        cg_select_list = data['cg_select_list']
        sampled_ids = data['sampled_ids']

        if flag_training_what in ('cg', 'gt_tap_cg'):
            ind_select_list = data['gts_ind_select_list']
            soi_select_list = data['gts_soi_select_list']
            cg_select_list = data['gts_cg_select_list']

        tmp = [
            data['fc_feats'], data['att_feats'], data['lda_feats'],
            data['tap_labels'], data['tap_masks_for_loss'],
            data['cg_labels'][cg_select_list],
            data['cg_masks'][cg_select_list], data['w1']
        ]

        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
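        # Note: torch.autograd.Variable is a no-op wrapper since PyTorch 0.4;
        # on 0.4+ torch.from_numpy(_).cuda() alone is equivalent.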

        c3d_feats, att_feats, lda_feats, tap_labels, tap_masks_for_loss, cg_labels, cg_masks, w1 = tmp

        if (iteration - 1) % opt.m_batch == 0:
            tap_optimizer.zero_grad()
            cg_optimizer.zero_grad()

        tap_feats, pred_proposals = tap_model(c3d_feats)
        tap_loss = tap_crit(pred_proposals, tap_masks_for_loss, tap_labels, w1)

        loss_sum[0] = loss_sum[0] + tap_loss.item()

        # Backward Propagation
        if flag_training_what == 'tap':
            tap_loss.backward()
            utils.clip_gradient(tap_optimizer, opt.grad_clip)
            if iteration % opt.m_batch == 0:
                tap_optimizer.step()
        else:
            if not sc_flag:
                pred_captions = cg_model(tap_feats,
                                         c3d_feats,
                                         lda_feats,
                                         cg_labels,
                                         ind_select_list,
                                         soi_select_list,
                                         mode='train')
                cg_loss = cg_crit(pred_captions, cg_labels[:, 1:],
                                  cg_masks[:, 1:])

            else:
                gen_result, sample_logprobs, greedy_res = cg_model(
                    tap_feats,
                    c3d_feats,
                    lda_feats,
                    cg_labels,
                    ind_select_list,
                    soi_select_list,
                    mode='train_rl')
                sentence_info = (data['sentences_batch']
                                 if flag_training_what not in ('cg', 'gt_tap_cg')
                                 else data['gts_sentences_batch'])

                reward = get_self_critical_reward2(
                    greedy_res, (data['vid'], sentence_info),
                    gen_result,
                    vocab=loader.get_vocab(),
                    opt=opt)
                cg_loss = rl_crit(sample_logprobs, gen_result,
                                  torch.from_numpy(reward).float().cuda())

            loss_sum[1] = loss_sum[1] + cg_loss.item()

            if flag_training_what in ('cg', 'gt_tap_cg', 'LP_cg'):
                cg_loss.backward()

                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    cg_optimizer.step()
                if flag_training_what == 'gt_tap_cg':
                    utils.clip_gradient(tap_optimizer, opt.grad_clip)
                    if iteration % opt.m_batch == 0:
                        tap_optimizer.step()
            elif flag_training_what == 'tap_cg':
                total_loss = opt.lambda1 * tap_loss + opt.lambda2 * cg_loss
                total_loss.backward()
                utils.clip_gradient(tap_optimizer, opt.grad_clip)
                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    tap_optimizer.step()
                    cg_optimizer.step()

                loss_sum[2] = loss_sum[2] + total_loss.item()

        torch.cuda.synchronize()

        # Updating epoch num
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print losses, Add to summary
        if iteration % opt.losses_log_every == 0:
            end = time.time()
            losses = np.round(loss_sum / opt.losses_log_every, 3)
            logger.info(
                "iter {} (epoch {}, lr {}), avg_iter_loss({}) = {}, time/batch = {:.3f}, bad_vid = {:.3f}" \
                    .format(iteration, epoch, opt.current_lr, flag_training_what, losses,
                            (end - start) / opt.losses_log_every,
                            bad_video_num))

            tf_writer.add_scalar('lr', opt.current_lr, iteration)
            tf_writer.add_scalar('train_tap_loss', losses[0], iteration)
            tf_writer.add_scalar('train_tap_prop_loss', losses[3], iteration)
            tf_writer.add_scalar('train_tap_bound_loss', losses[4], iteration)
            tf_writer.add_scalar('train_cg_loss', losses[1], iteration)
            tf_writer.add_scalar('train_total_loss', losses[2], iteration)
            if sc_flag and (not flag_training_what == 'tap'):
                tf_writer.add_scalar('avg_reward', np.mean(reward[:, 0]),
                                     iteration)
            loss_history[iteration] = losses
            lr_history[iteration] = opt.current_lr
            loss_sum = np.zeros(5)
            start = time.time()
            bad_video_num = 0

        # Evaluation, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save):
            eval_kwargs = {
                'split': 'val',
                'val_all_metrics': 0,
                'topN': 100,
            }

            eval_kwargs.update(vars(opt))

            # eval_kwargs['num_vids_eval'] = int(491)
            eval_kwargs['topN'] = 100

            eval_kwargs2 = {
                'split': 'val',
                'val_all_metrics': 1,
                'num_vids_eval': 4917,
            }
            eval_kwargs2.update(vars(opt))

            if not opt.num_vids_eval:
                eval_kwargs['num_vids_eval'] = 4917
                eval_kwargs2['num_vids_eval'] = 4917

            crits = [tap_crit, cg_crit]
            pred_json_path_T = os.path.join(save_folder, 'pred_sent',
                                            'pred_num{}_iter{}.json')

            # if 'alter' in opt.training_mode:
            if flag_training_what == 'tap':
                eval_kwargs['topN'] = 1000
                predictions, eval_score, val_loss = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                            iteration),
                    eval_kwargs,
                    flag_eval_what='tap')
            else:
                if not vars(opt).get('fast_eval_cg', False):
                    predictions, eval_score, val_loss = eval_utils.eval_split(
                        allmodels,
                        crits,
                        loader,
                        pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                                iteration),
                        eval_kwargs,
                        flag_eval_what='tap_cg')

                predictions2, eval_score2, val_loss2 = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs2['num_vids_eval'],
                                            iteration),
                    eval_kwargs2,
                    flag_eval_what='cg')

                if (vars(opt).get('fast_eval_cg', False)
                        or vars(opt).get('fast_eval_cg_top10', False)):
                    eval_score = eval_score2
                    val_loss = val_loss2
                    predictions = predictions2

            # else:
            #    predictions, eval_score, val_loss = eval_utils.eval_split(allmodels, crits, loader, pred_json_path,
            #                                                              eval_kwargs,
            #                                                              flag_eval_what=flag_training_what)

            f_f1 = lambda x, y: 2 * x * y / (x + y)
            f1 = f_f1(eval_score['Recall'], eval_score['Precision']).mean()
            if flag_training_what != 'tap':  # when the cg branch is trained, use mean METEOR as the final score
                current_score = np.array(eval_score['METEOR']).mean() * 100
            else:  # when only tap is trained, use the F1 of recall and precision as the final score
                current_score = f1

            for model in allmodels:
                for name, param in model.named_parameters():
                    tf_writer.add_histogram(name,
                                            param.clone().cpu().data.numpy(),
                                            iteration,
                                            bins=10)
                    if param.grad is not None:
                        tf_writer.add_histogram(
                            name + '_grad',
                            param.grad.clone().cpu().data.numpy(),
                            iteration,
                            bins=10)

            tf_writer.add_scalar('val_tap_loss', val_loss[0], iteration)
            tf_writer.add_scalar('val_cg_loss', val_loss[1], iteration)
            tf_writer.add_scalar('val_tap_prop_loss', val_loss[3], iteration)
            tf_writer.add_scalar('val_tap_bound_loss', val_loss[4], iteration)
            tf_writer.add_scalar('val_total_loss', val_loss[2], iteration)
            tf_writer.add_scalar('val_score', current_score, iteration)
            if flag_training_what != 'tap':
                tf_writer.add_scalar('val_score_gt_METEOR',
                                     np.array(eval_score2['METEOR']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_Bleu_4',
                                     np.array(eval_score2['Bleu_4']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_CIDEr',
                                     np.array(eval_score2['CIDEr']).mean(),
                                     iteration)
            tf_writer.add_scalar('val_recall', eval_score['Recall'].mean(),
                                 iteration)
            tf_writer.add_scalar('val_precision',
                                 eval_score['Precision'].mean(), iteration)
            tf_writer.add_scalar('f1', f1, iteration)

            val_result_history[iteration] = {
                'val_loss': val_loss,
                'eval_score': eval_score
            }

            if flag_training_what == 'tap':
                logger.info(
                    'Validation result of iter {}, score(f1/meteor):{},\n all:{}'
                    .format(iteration, current_score, eval_score))
            else:
                mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score.items()
                }
                gt_mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score2.items()
                }

                metrics = ['Bleu_4', 'CIDEr', 'METEOR', 'ROUGE_L']
                gt_avg_score = np.array([
                    v for metric, v in gt_mean_score.items()
                    if metric in metrics
                ]).sum()
                logger.info(
                    'Validation result of iter {}, score(f1/meteor):{},\n all:{}\n mean:{} \n\n gt:{} \n mean:{}\n avg_score: {}'
                    .format(iteration, current_score, eval_score, mean_score,
                            eval_score2, gt_mean_score, gt_avg_score))

            # Save model .pth
            saved_pth = {
                'iteration': iteration,
                'cg_model': cg_model.state_dict(),
                'tap_model': tap_model.state_dict(),
                'cg_optimizer': cg_optimizer.state_dict(),
                'tap_optimizer': tap_optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model.pth')
            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to checkpoint file {}.'.format(
                iteration, checkpoint_path))

            # save info.pkl
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {
                    'opt': opt,
                    'iter': iteration,
                    'epoch': epoch,
                    'iterators': loader.iterators,
                    'flag_training_what': flag_training_what,
                    'split_ix': loader.split_ix,
                    'best_val_score': best_val_score,
                    'vocab': loader.get_vocab(),
                }

                best_checkpoint_path = os.path.join(save_folder,
                                                    'model-best.pth')
                torch.save(saved_pth, best_checkpoint_path)
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': opt,
                'iter': iteration,
                'epoch': epoch,
                'iterators': loader.iterators,
                'flag_training_what': flag_training_what,
                'split_ix': loader.split_ix,
                'best_val_score': best_val_score,
                'vocab': loader.get_vocab(),
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.pkl'), 'wb') as f:
                cPickle.dump(saved_info, f)
                logger.info('Save info to info.pkl')

            # Stop criterion
            if epoch >= len(flag_training_whats):
                tf_writer.close()
                break
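
A note on the micro-batch pattern in Example 1: pairing (iteration - 1) % opt.m_batch == 0 (zero the gradients) with iteration % opt.m_batch == 0 (take an optimizer step) accumulates gradients over m_batch micro-batches before each update. A minimal, self-contained sketch of the same idea, with a placeholder model and data (none of these names come from the project):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
m_batch = 4  # accumulate gradients over 4 micro-batches

for iteration in range(1, 101):
    if (iteration - 1) % m_batch == 0:
        optimizer.zero_grad()  # start of an accumulation window
    x = torch.randn(8, 10)
    loss = model(x).pow(2).mean()
    loss.backward()  # gradients sum across the window
    if iteration % m_batch == 0:
        optimizer.step()  # one update per m_batch micro-batches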
Example 2
def train(opt):
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print('Saved optimizer state does not match the current parameters '
                  '(new parameters were probably added); reinitializing the optimizer.')

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start
                        ) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
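        # loss.data[0] is the pre-0.4 PyTorch idiom; on 0.4+ use loss.item()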
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob,
                                             iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight,
                                             iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # Load the retrieval model for evaluation
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v,
                                                 iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                    iteration)
                #tf_summary_writer.add_image('images', utils.make_summary_image(), iteration)
                #utils.make_html(opt.id, iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if True:  # always save; best_flag / best_flag_vse gate the extra "best" copies below
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                    best_val_score_vse = current_score_vse
                    best_flag_vse = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['best_val_score_vse'] = best_val_score_vse
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '-%d.pkl' % (iteration)),
                        'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)
                if best_flag_vse:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model_vse-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_vse_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
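
Example 2's resume logic round-trips infos/histories through cPickle and asserts that architecture options match before reloading. A minimal sketch of that save/load pattern with binary file modes (the paths, keys and option names are illustrative, not the project's exact API):

import os
import pickle  # cPickle on Python 2

def save_infos(checkpoint_path, run_id, infos):
    # pickle files should always be opened in binary mode
    with open(os.path.join(checkpoint_path, 'infos_' + run_id + '.pkl'), 'wb') as f:
        pickle.dump(infos, f)

def load_infos(checkpoint_path, run_id, opt,
               need_be_same=('caption_model', 'rnn_type', 'rnn_size', 'num_layers')):
    with open(os.path.join(checkpoint_path, 'infos_' + run_id + '.pkl'), 'rb') as f:
        infos = pickle.load(f)
    saved_opt = infos['opt']
    for checkme in need_be_same:
        assert vars(saved_opt)[checkme] == vars(opt)[checkme], \
            "command line and saved model disagree on '%s'" % checkme
    return infos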
Example 3
    def train(self, data, loader, iteration, epoch, nmt_epoch):
        nmt_dec_state = None
        nmt_dec_state_zh = None
        torch.cuda.synchronize()
        self.optim.zero_grad()

        tmp = [
            data['fc_feats'], data['attri_feats'], data['att_feats'],
            data['labels'], data['masks'], data['att_masks'],
            data['nmt'] if self.nmt_train_flag else None
        ]
        tmp = [
            _ if _ is None else
            (Variable(torch.from_numpy(_), requires_grad=False).cuda()
             if utils.under_0_4() else torch.from_numpy(_).cuda()) for _ in tmp
        ]
        fc_feats, attri_feats, att_feats, labels, masks, att_masks, nmt_batch = tmp

        if self.i2t_train_flag:
            if self.update_i2t_lr_flag:
                self.optim.update_LearningRate(
                    'i2t', epoch)  # Assign the learning rate
                self.optim.update_ScheduledSampling_prob(
                    self.opt, epoch,
                    self.dp_i2t_model)  # Assign the scheduled sampling prob
                if self.opt.self_critical_after != -1 and epoch >= self.opt.self_critical_after:
                    # If start self critical training
                    self.sc_flag = True
                    init_scorer(self.opt.cached_tokens)
                else:
                    self.sc_flag = False
                self.update_i2t_lr_flag = False

            if not self.sc_flag:
                i2t_outputs = self.dp_i2t_model(fc_feats, attri_feats,
                                                att_feats, labels, att_masks)
                i2t_loss = self.i2t_crit(i2t_outputs, labels[:, 1:],
                                         masks[:, 1:])
            else:
                gen_result, sample_logprobs = self.dp_i2t_model(
                    fc_feats,
                    attri_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    mode='sample')
                reward = get_self_critical_reward(self.dp_i2t_model, fc_feats,
                                                  attri_feats, att_feats,
                                                  att_masks, data, gen_result,
                                                  self.opt)
                i2t_loss = self.i2t_rl_crit(
                    sample_logprobs, gen_result.data,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))

                self.i2t_avg_reward = np.mean(reward[:, 0])
            self.i2t_train_loss = (i2t_loss.data[0] if utils.under_0_4()
                                   else i2t_loss.item())
            i2t_loss.backward(retain_graph=True)

        if self.nmt_train_flag:
            if self.update_nmt_lr_flag:
                self.optim.update_LearningRate(
                    'nmt', nmt_epoch)  # Assign the learning rate
            outputs, attn, dec_state, upper_bounds = self.dp_nmt_model(
                nmt_batch.src, nmt_batch.tgt, nmt_batch.lengths, nmt_dec_state)
            nmt_loss = self.nmt_crit(loader, nmt_batch, outputs, attn)

            if nmt_dec_state is not None: nmt_dec_state.detach()
            if nmt_dec_state_zh is not None: nmt_dec_state_zh.detach()

            self.nmt_crit.report_stats.n_src_words += nmt_batch.lengths.data.sum()
            self.nmt_train_ppl = self.nmt_crit.report_stats.ppl()
            self.nmt_train_acc = self.nmt_crit.report_stats.accuracy()
            # Minimize the word embedding weights
            # wemb_weight_loss = self.weight_trans(self.i2t_model.embed, self.nmt_encoder.embeddings.word_lut)
            # self.wemb_loss = wemb_weight_loss.data[0]

            nmt_loss.backward(retain_graph=True)
        # if self.nmt_train_flag: wemb_weight_loss.backward(retain_graph=True)
        self.optim.step()
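
Example 3 calls backward() on the i2t loss with retain_graph=True so that the nmt loss (and the commented-out word-embedding loss) can still backpropagate through any shared parts of the graph before the single optim.step(). A toy version of several losses sharing one optimizer step (the encoder and heads are placeholders, not the project's models):

import torch
import torch.nn as nn
import torch.optim as optim

encoder = nn.Linear(10, 16)
head_a = nn.Linear(16, 1)
head_b = nn.Linear(16, 1)
params = list(encoder.parameters()) + list(head_a.parameters()) + list(head_b.parameters())
optimizer = optim.Adam(params, lr=1e-4)

optimizer.zero_grad()
shared = encoder(torch.randn(8, 10))  # shared forward pass
loss_a = head_a(shared).pow(2).mean()
loss_b = head_b(shared).pow(2).mean()
loss_a.backward(retain_graph=True)    # keep the shared graph alive
loss_b.backward()                     # gradients from both losses accumulate
optimizer.step()                      # one update for all parameters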
Example 4
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        print("opt.start_from: " + str(opt.start_from))
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    # load model
    if os.path.isfile("log_sc/model.pth"):
        model_path = "log_sc/model.pth"
        state_dict = torch.load(model_path)
        dp_model.load_state_dict(state_dict)

    dp_model.train()

    # create/load vector model
    vectorModel = models.setup_vectorModel().cuda()
    dp_vectorModel = torch.nn.DataParallel(vectorModel)

    # load vector model
    if os.path.isfile("log_sc/model_vec.pth"):
        model_vec_path = "log_sc/model_vec.pth"
        state_dict_vec = torch.load(model_vec_path)
        dp_vectorModel.load_state_dict(state_dict_vec)

    dp_vectorModel.train()

    optimizer = utils.build_optimizer(
        list(model.parameters()) + list(vectorModel.parameters()), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    vec_crit = nn.L1Loss()

    # create idxs for doc2vec vectors
    with open('paragraphs_image_ids.txt', 'r') as file:
        paragraph_image_ids = file.readlines()

    paragraph_image_ids = [int(i) for i in paragraph_image_ids]

    # select corresponding vectors
    with open('paragraphs_vectors.txt', 'r') as the_file:
        vectors = the_file.readlines()

    vectors_list = []
    for string in vectors:
        vectors_list.append([float(s) for s in string.split(' ')])

    vectors_list_np = np.asarray(vectors_list)

    print("Starting training loop!")

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # pad data['att_feats'] axis=1 to have length = 83
        def pad_along_axis(array, target_length, axis=0):

            pad_size = target_length - array.shape[axis]
            axis_nb = len(array.shape)

            if pad_size < 0:
                return array

            npad = [(0, 0) for x in range(axis_nb)]
            npad[axis] = (0, pad_size)

            b = np.pad(array,
                       pad_width=npad,
                       mode='constant',
                       constant_values=0)

            return b

        data['att_feats'] = pad_along_axis(data['att_feats'], 83, axis=1)
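        # e.g. pad_along_axis(np.ones((2, 5, 7)), 83, axis=1).shape == (2, 83, 7):
        # the proposal axis is zero-padded on the right up to length 83.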

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        idx = []
        for element in data['infos']:
            idx.append(paragraph_image_ids.index(element['id']))

        batch_vectors = vectors_list_np[idx]

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        att_feats_reshaped = att_feats.permute(0, 2, 1).cuda()
        semantic_features = dp_vectorModel(att_feats_reshaped.cuda(),
                                           fc_feats)  # (10, 2048)
        batch_vectors = torch.from_numpy(
            batch_vectors).float().cuda()  # (10, 512)
        vec_loss = vec_crit(semantic_features, batch_vectors)
        alpha_ = 1
        loss = loss + (alpha_ * vec_loss)

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        # (note: this runs every iteration; gate it on opt.save_checkpoint_every,
        # as the other examples do, to evaluate less often)
        if True:

            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(dp_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))

            # save vec model
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_vec.pth')
            torch.save(dp_vectorModel.state_dict(), checkpoint_path)
            print("model_vec saved to {}".format(checkpoint_path))

            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(dp_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                # best vec
                model_fname_vec = 'model-best-vec-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname_vec)
                torch.save(dp_vectorModel.state_dict(), checkpoint_path)
                print("model_vec saved to {}".format(checkpoint_path))

                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
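
Example 4 regularizes captioning with a doc2vec regression head: vec_crit = nn.L1Loss() compares predicted semantic features against precomputed paragraph vectors, and the two objectives are combined as loss = loss + alpha_ * vec_loss. A compact sketch of that weighted auxiliary-loss combination (the shapes and the weight are illustrative stand-ins):

import torch
import torch.nn as nn

vec_crit = nn.L1Loss()
alpha = 1.0  # weight of the auxiliary term, as in Example 4

# stand-ins for the real model outputs and targets
caption_loss = torch.randn(8, 1, requires_grad=True).pow(2).mean()
semantic_features = torch.randn(10, 512, requires_grad=True)
batch_vectors = torch.randn(10, 512)  # precomputed doc2vec targets

vec_loss = vec_crit(semantic_features, batch_vectors)
total = caption_loss + alpha * vec_loss  # both terms share one backward pass
total.backward()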
Example 5
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    #pretrained_dict = torch.load(opt.model)
    #model.load_state_dict(pretrained_dict, strict=False)

    num_params = get_n_params(model)
    print('number of parameters:', num_params)

    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp
        batchsize = fc_feats.size(0)
        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            wordact, reconstruct = dp_model(fc_feats, att_feats, labels)
            #loss_dist = F.binary_cross_entropy(dist, dist_label.cpu().float())
            fc_feats_max, _ = att_feats.max(1)
            loss_rec = F.mse_loss(reconstruct.cpu(), fc_feats_max.cpu())
            mask = masks[:, 1:].contiguous()
            wordact = wordact[:, :, :-1]
            wordact_t = wordact.permute(0, 2, 1).contiguous()
            wordact_t = wordact_t.view(
                wordact_t.size(0) * wordact_t.size(1), -1)
            labels = labels.contiguous().view(-1, 6 * 30).cpu()
            wordclass_v = labels[:, 1:]
            wordclass_t = wordclass_v.contiguous().view(\
               wordclass_v.size(0) * wordclass_v.size(1), 1)
            maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
            loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
               wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))
            loss = 5 * loss_xe + loss_rec
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model (checkpointing only starts after a
        # hardcoded 60k-iteration warm-up in this variant)
        if (iteration >= 60000 and iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Track the best validation score seen so far (the model and
            # optimizer were already saved above, so there is no need to
            # save them twice)
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            #histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best.pth'
                infos_fname = 'model-best.pkl'
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
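A minimal sketch of the reward-weighted criterion that rl_crit is assumed to implement in the self-critical branch above (the class below is illustrative, not necessarily the repo's exact utils.RewardCriterion):

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    """Policy-gradient loss: -logprob(sampled word) * reward, masked after EOS."""
    def forward(self, logprobs, seq, reward):
        # logprobs: (batch, seq_len) log-probabilities of the sampled words
        # seq:      (batch, seq_len) sampled word ids, 0 marks padding
        # reward:   (batch, seq_len) per-word advantage from SCST
        mask = (seq > 0).float()
        # keep the first position after the sequence so EOS is still trained
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        loss = -logprobs.reshape(-1) * reward.reshape(-1) * mask.reshape(-1)
        return loss.sum() / mask.sum()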
Example No. 6
0
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.vocab = loader.get_vocab()

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))


    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                    init_scorer(opt.cached_tokens)

                epoch_done = False
                    
            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['sents_mask']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, sents_mask = tmp
            box_inds = None
                
            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, box_inds, epoch, sents_mask)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            
            # make evaluation on validation set, and save model

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
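The loop above can drive a transformer with a Noam schedule via utils.get_std_opt. A small sketch of that schedule, assuming the standard formulation from "Attention Is All You Need" (the NoamOpt class and its attribute names here are illustrative):

import torch

class NoamOpt:
    """Warm up linearly, then decay with the inverse square root of the step."""
    def __init__(self, optimizer, model_size, factor, warmup):
        self.optimizer = optimizer
        self.model_size = model_size  # transformer hidden size d_model
        self.factor = factor
        self.warmup = warmup
        self._step = 0  # resumable, cf. optimizer._step = iteration above

    def rate(self, step=None):
        step = step or self._step
        return (self.factor * self.model_size ** -0.5 *
                min(step ** -0.5, step * self.warmup ** -1.5))

    def step(self):
        self._step += 1
        for group in self.optimizer.param_groups:
            group['lr'] = self.rate()
        self.optimizer.step()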
Example No. 7
0
def train(opt):
    import random
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(0)
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    from dataloader_pair import DataLoader

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    if opt.log_to_file:
        if os.path.exists(os.path.join(opt.checkpoint_path, 'log')):
            suffix = time.strftime("%Y-%m-%d %X", time.localtime())
            print('Warning !!! %s already exists ! use suffix ! ' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(
                os.path.join(opt.checkpoint_path, 'log' + suffix), "w")
        else:
            print('logging to file %s' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(os.path.join(opt.checkpoint_path, 'log'), "w")

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if os.path.isfile(opt.start_from):
            with open(opt.infos, 'rb') as f:
                infos = cPickle.load(f)
                saved_model_opt = infos['opt']
                need_be_same = [
                    "caption_model", "rnn_type", "rnn_size", "num_layers"
                ]
                for checkme in need_be_same:
                    assert vars(saved_model_opt)[checkme] == vars(
                        opt
                    )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        else:
            if opt.load_best != 0:
                print('loading best info')
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl'), 'rb') as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
            else:
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl'), 'rb') as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                try:
                    histories = cPickle.load(f)
                except Exception:
                    print('load history error!')
                    histories = {}

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    start_epoch = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    #Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    if opt.caption_model == 'att2in2p':
        optimized = [
            'logit2', 'ctx2att2', 'core2', 'prev_sent_emb', 'prev_sent_wrap'
        ]
        optimized_param = []
        optimized_param1 = []

        for name, param in model.named_parameters():
            second = False
            for n in optimized:
                if n in name:
                    print('second', name)
                    optimized_param.append(param)
                    second = True
            if 'embed' in name:
                print('all', name)
                optimized_param1.append(param)
                optimized_param.append(param)
            elif not second:
                print('first', name)
                optimized_param1.append(param)

    while True:
        if opt.val_only:
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            print('start evaluating')
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)
            exit(0)
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['pair_fc_feats'], data['pair_att_feats'], data['pair_labels'],
            data['pair_masks'], data['pair_att_masks']
        ]

        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        masks = masks.float()

        optimizer.zero_grad()

        if not sc_flag:
            if opt.onlysecond:
                # only using the second sentence from a visual paraphrase pair. opt.caption_model should be a one-stage decoding model
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 1, :], att_masks),
                    labels[:, 1, 1:], masks[:, 1, 1:])
                loss1 = loss2 = loss / 2
            elif opt.first:
                # using the first sentence
                tmp = [
                    data['first_fc_feats'], data['first_att_feats'],
                    data['first_labels'], data['first_masks'],
                    data['first_att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, labels, masks, att_masks = tmp
                masks = masks.float()
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, :], att_masks),
                    labels[:, 1:], masks[:, 1:])
                loss1 = loss2 = loss / 2
            elif opt.onlyfirst:
                # only using the first sentence from a visual paraphrase pair
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 0, :], att_masks),
                    labels[:, 0, 1:], masks[:, 0, 1:])
                loss1 = loss2 = loss / 2
            else:
                # proposed DCVP model, opt.caption_model should be att2inp
                output1, output2 = dp_model(fc_feats, att_feats, labels,
                                            att_masks, masks[:, 0, 1:])
                loss1 = crit(output1, labels[:, 0, 1:], masks[:, 0, 1:])
                loss2 = crit(output2, labels[:, 1, 1:], masks[:, 1, 1:])
                loss = loss1 + loss2

        else:
            # Our DCVP model does not support self-critical sequence training:
            # we found that RL (SCST) with a CIDEr reward improves conventional
            # metrics (BLEU, CIDEr, etc.) but harms diversity and
            # descriptiveness; please refer to the paper for details.
            raise NotImplementedError

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, loss1 = {:.3f}, loss2 = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, loss.item(), loss1.item(), loss2.item(), end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        sys.stdout.flush()
        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)
            # also keep an iteration-stamped snapshot
            checkpoint_path = os.path.join(
                opt.checkpoint_path, 'model' + str(iteration) + '.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            with open(
                    os.path.join(
                        opt.checkpoint_path,
                        'infos_' + opt.id + '_' + str(iteration) + '.pkl'),
                    'wb') as f:
                cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
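The per-epoch hook above linearly ramps the scheduled-sampling probability. As a self-contained sketch (the helper name is illustrative; the argument names mirror the opt fields):

def scheduled_sampling_prob(epoch, start, increase_every, increase_prob, max_prob):
    """Probability of feeding back sampled words instead of ground truth."""
    if start < 0 or epoch <= start:
        return 0.0
    frac = (epoch - start) // increase_every
    return min(increase_prob * frac, max_prob)

# e.g. with start=0, increase_every=5, increase_prob=0.05, max_prob=0.25:
# epoch 12 -> frac = 2 -> ss_prob = 0.10; from epoch 25 on it stays at 0.25.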
Example No. 8
0
def train(opt):
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    print("Caption model {}".format(opt.caption_model))
    print("refine aoa {}".format(opt.refine_aoa))
    print("Number of aoa module {}".format(opt.aoa_num))
    print("Self Critic After  {}".format(opt.self_critical_after))
    print("learning_rate_decay_every {}".format(opt.learning_rate_decay_every))

    # use more data to fine-tune the model for better challenge results; we do not use this
    if opt.use_val or opt.use_test:
        print("+++++++++++It is a refining training+++++++++++++++")
        print("===========Val is {} used for training ===========".format(
            '' if opt.use_val else 'not'))
        print("===========Test is {} used for training ===========".format(
            '' if opt.use_test else 'not'))
    print("=====================================================")

    # set more detail name of checkpoint paths
    checkpoint_path_suffix = "_bs{}".format(opt.batch_size)
    if opt.use_warmup:
        checkpoint_path_suffix += "_warmup"
    if torch.cuda.device_count() > 1:
        checkpoint_path_suffix += "_gpu{}".format(torch.cuda.device_count())

    if opt.checkpoint_path.endswith('_rl'):
        opt.checkpoint_path = opt.checkpoint_path[:
                                                  -3] + checkpoint_path_suffix + '_rl'
    else:
        opt.checkpoint_path += checkpoint_path_suffix
    print("Save model to {}".format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)
    eval_ = True  # gate for the validation block below; assumed to be defined elsewhere in the original source
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.losses_log_every = len(loader.split_ix['train']) // opt.batch_size
    print("Log losses every {} iterations".format(opt.losses_log_every))
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    # load checkpoint
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]

            # this sanity check may not work well, and comment it if necessary
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("==============================================")
        print("Initialize training process from all begining")
        print("==============================================")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()

    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("==================start from {} iterations -- {} epoch".format(
        iteration, epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_epoch = infos.get('best_epoch', None)
        best_cider = infos.get('best_val_score', 0)
        print("========best history val cider score: {} in epoch {}=======".
              format(best_val_score, best_epoch))

    # sanity check that the saved model name carries the correct index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) - epoch == 1, "mismatch between the model index and the real epoch number"
        epoch += 1
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if torch.cuda.device_count() > 1:
        dp_model = torch.nn.DataParallel(model)
    else:
        dp_model = model
    lw_model = LossWrapper1(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
            print("Save training information to {}".format(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append))))

        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories' + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)
                print("Save training historyes to {}".format(
                    os.path.join(opt.checkpoint_path,
                                 'histories' + '%s.pkl' % (append))))

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor * opt.refine_lr_decay
                    else:
                        opt.current_lr = opt.learning_rate
                    infos['current_lr'] = opt.current_lr
                    print("Current Learning Rate is: {}".format(
                        opt.current_lr))
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for i in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    start = time.time()
                    if (opt.use_warmup
                            == 1) and (iteration < opt.noamopt_warmup):
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')
                    # print('Read data:', time.time() - start)

                    if (iteration % acc_steps == 0):
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'],
                        data['flag_feats'], data['labels'], data['masks'],
                        data['att_masks']
                    ]
                    tmp = [_ if _ is None else _.cuda() for _ in tmp]
                    fc_feats, att_feats, flag_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, flag_feats,
                                            labels, masks, att_masks,
                                            data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if (iteration + 1) % acc_steps == 0:
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    if not sc_flag:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        epoch += 1
                        epoch_done = True
                        # save after each epoch
                        save_checkpoint(model, infos, optimizer)
                        if epoch > 15:  # to save disk space, you can comment this part out
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(epoch))
                        print(
                            "====================================================="
                        )
                        print(
                            "======Best Cider = {} in epoch {}: iter {}!======"
                            .format(best_val_score, best_epoch,
                                    infos.get('best_itr', None)))
                        print(
                            "====================================================="
                        )

                    # Write training history into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        # if (iteration % 10== 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[
                            iteration] = train_loss if not sc_flag else model_out[
                                'reward'].mean()
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # unnecessary to eval from the beginning
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 3:
                        # eval model
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        # evaluate on 'test' when val is folded into training;
                        # default to 'val' otherwise
                        val_split = 'val'
                        if opt.use_val and not opt.use_test:
                            val_split = 'test'

                        eval_kwargs = {
                            'split': val_split,
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                bleu_dict = {}
                                for k, v in lang_stats.items():
                                    if 'Bleu' in k:
                                        bleu_dict[k] = v
                                if len(bleu_dict) > 0:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            infos['best_epoch'] = epoch
                            infos['best_itr'] = iteration
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        save_checkpoint(model, infos, optimizer, histories)
                        if opt.save_history_ckpt:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(iteration))

                        if best_flag:
                            best_epoch = epoch
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))
                    # reset
                    start_Img_idx = 0
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                # save_checkpoint(model, infos, optimizer, append=str(epoch))
                tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):  # KeyboardInterrupt
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
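The inner loop above spreads one optimizer step over acc_steps batches. A generic, self-contained sketch of that gradient-accumulation pattern (the toy model and data below are placeholders, not the project's):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
acc_steps = 4

batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]
for iteration, (x, y) in enumerate(batches):
    if iteration % acc_steps == 0:
        optimizer.zero_grad()          # clear grads at the window start
    loss = nn.functional.cross_entropy(model(x), y)
    (loss / acc_steps).backward()      # scale so the sum matches one big batch
    if (iteration + 1) % acc_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()               # one update per acc_steps batches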
Example No. 9
0
def train(opt):
    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check if models are compatible
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl')):
        with open(
                os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'),
                'rb') as f:
            infos = utils.pickle_load(f)
            print('infos load success')
    infos['opt'] = opt

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.checkpoint_path, 'model.pth')))
        print('model load success')

    # Wrap generation model with loss function(used for training)
    # This allows loss function computed separately on each machine
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    optimizer = utils.ReduceLROnPlateau(optim.Adam(model.parameters(),
                                                   opt.learning_rate),
                                        factor=0.5,
                                        patience=3)
    # Load the optimizer
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.checkpoint_path, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    best_val_score = infos.get('best_val_score', None)
    print('iter {}, epoch {}, best_val_score {}'.format(
        iteration, epoch, best_val_score))

    print(sorted(dict(set(vars(opt).items())).items(), key=lambda x: x[0]))
    # Start training
    if opt.self_critical:
        init_scorer(opt.cached_tokens)
    # Assure in training mode
    dp_lw_model.train()
    try:
        while True:
            # Stop if reaching max_epoch
            if epoch >= opt.max_epochs:
                break

            # Load data from train split (0)
            data = loader.get_batch('train')

            torch.cuda.synchronize()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])))

            loss = model_out['loss'].mean()

            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 0.1)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1

            # Write the training loss summary
            if iteration % opt.losses_log_every == 0:
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                if opt.self_critical:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch

            # make evaluation on validation set, and save model
            if iteration % opt.save_checkpoint_every == 0:
                tb_summary_writer.add_scalar('epoch', epoch, iteration)
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                _, _, lang_stats = eval_utils.eval_split(
                    dp_model, loader, eval_kwargs)

                optimizer.scheduler_step(-lang_stats['CIDEr'])
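                # CIDEr is negated because ReduceLROnPlateau treats the monitored
                # value as a loss to minimize, while CIDEr should be maximized.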
                # Write validation result into summary
                for k, v in lang_stats.items():
                    tb_summary_writer.add_scalar(k, v, iteration)

                # Save model if is improving on validation result
                current_score = lang_stats['CIDEr']

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer)
                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        pass
Example #10
def eval_split_output(model,
                      crit,
                      loader,
                      eval_kwargs={},
                      Discriminator=None,
                      Discriminator_learned=None):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    internal = eval_kwargs.get('internal', 1)
    sim_model = eval_kwargs.get('sim_model', None)
    bleu_option = eval_kwargs.get('bleu_option', 'closest')
    weight_deterministic_flg = eval_kwargs.get('weight_deterministic_flg', 0)
    cut_length = eval_kwargs.get('cut_length', -1)
    baseline_concat = eval_kwargs.get('baseline_concat', 0)

    # Make sure in the evaluation mode
    model.eval()
    # if internal is not None:
    #     internal.eval()

    loader.reset_iterator(split)
    init_scorer('coco-train-idxs', len(loader.ix_to_word))

    n = 0
    count = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
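    # 1e-8 avoids a division by zero in the final loss_sum / loss_evals when no
    # loss is accumulated.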
    predictions = []
    predictions_for_eval = []
    gts_for_wb = {}
    res_for_wb = {}
    while True:
        # get input data
        data = loader.get_batch(split)
        n = n + loader.batch_size

        # Only leave one feature for each image, in case of duplicate samples:
        # np.arange(batch_size) * seq_per_img picks the first of the
        # seq_per_img repeated feature rows per image.
        tmp = [
            data['fc_feats'][np.arange(loader.batch_size) *
                             loader.seq_per_img],
            data['att_feats'][np.arange(loader.batch_size) *
                              loader.seq_per_img],
            data['att_masks'][np.arange(loader.batch_size) *
                              loader.seq_per_img]
            if data['att_masks'] is not None else None,
            data['bbox'][np.arange(loader.batch_size) * loader.seq_per_img]
            if data['bbox'] is not None else None
        ]

        tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, att_masks, bbox = tmp

        if weight_deterministic_flg > 0:
            weight_index = np.array(data['weight_index'])
            weight_index = weight_index[np.arange(loader.batch_size) *
                                        loader.seq_per_img]
        else:
            weight_index = None

        sub_att = None
        # forward the model to also get generated samples for each image
        with torch.no_grad():
            if baseline_concat == 0:
                seq = model(fc_feats,
                            att_feats,
                            att_masks,
                            internal,
                            opt=eval_kwargs,
                            sim_model=sim_model,
                            bbox=bbox,
                            sub_att=sub_att,
                            label_region=data['label_region'],
                            mode='sample',
                            weight_index=weight_index,
                            test_flg=True)[0].data
            else:
                seq, model = make_baseline_result(model, fc_feats, att_feats,
                                                  eval_kwargs, data,
                                                  weight_index)

            if Discriminator is not None:
                hokan = torch.zeros(
                    (len(seq), 1)).type(torch.LongTensor).cuda()
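                # 'hokan' (Japanese: "completion/padding") is a column of zeros
                # appended so the generated sequence matches the input length
                # the discriminator expects.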
                # fake_data = torch.cat((hokan, seq, hokan), 1)
                seq = seq.type(hokan.type())
                fake_data = torch.cat((seq, hokan), 1)
                if cut_length > 0:
                    _, dis_score = Discriminator.fixed_length_forward(
                        fake_data, cut_length)
                    _, dis_learned_score = Discriminator_learned.fixed_length_forward(
                        fake_data, cut_length)
                else:
                    dis_score = Discriminator(fake_data)
                    dis_learned_score = Discriminator_learned(fake_data)
                dis_score = dis_score.data.cpu().numpy()
                dis_learned_score = dis_learned_score.data.cpu().numpy()
                dis_score = dis_score[:, 1]
                dis_learned_score = dis_learned_score[:, 1]
            else:
                dis_score = np.zeros(len(seq))
                dis_learned_score = np.zeros(len(seq))

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                print('\n'.join([
                    utils.decode_sequence(loader.get_vocab(),
                                          _['seq'].unsqueeze(0))[0]
                    for _ in model.done_beams[i]
                ]))
                print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)

        # attention-related statistics for each generated caption
        gen_result_ = seq.data.cpu().numpy()
        word_exist = (gen_result_ > 0).astype(int).reshape(
            gen_result_.shape[0], gen_result_.shape[1], 1)
        weights_for_att = model.weights_p.data.cpu()
        # att_score = get_attnorm_reward(word_exist, weights_for_att).mean(axis=1)
        att_score = np.zeros(len(gen_result_))
        if bbox is not None:
            att_score_hard = get_hardatt_reward(bbox, model.weights.data.cpu(),
                                                seq.data.cpu())
        else:
            att_score_hard = None

        for k, sent in enumerate(sents):
            if len(model.attentions.shape) == 3:
                model.attentions = model.attentions.reshape(
                    1, model.attentions.shape[0], model.attentions.shape[1],
                    model.attentions.shape[2])
            if len(model.similarity.size()) == 1:
                model.similarity = model.similarity.view(
                    1, model.similarity.size(0))
                model.region_b1 = model.region_b1.view(1,
                                                       model.region_b1.size(0))
                model.region_b4 = model.region_b4.view(1,
                                                       model.region_b4.size(0))
                model.region_cider = model.region_cider.view(
                    1, model.region_cider.size(0))

            if att_score_hard is not None:
                entry = {
                    'image_id': data['infos'][k]['id'],
                    'caption': sent,
                    'attention': model.attentions[:, :, k, :],
                    'similarity': model.similarity[k].numpy(),
                    'att_score': att_score[k],
                    'att_score_hard': att_score_hard[k],
                    'dis_score': dis_score[k],
                    'dis_learned_score': dis_learned_score[k],
                    'region_b1': model.region_b1[k].numpy(),
                    'region_b4': model.region_b4[k].numpy(),
                    'region_cider': model.region_cider[k].numpy(),
                    'region_rouge': model.region_rouge[k].numpy(),
                    'region_meteor': model.region_meteor[k].numpy()
                }
            else:
                entry = {
                    'image_id': data['infos'][k]['id'],
                    'caption': sent,
                    'attention': model.attentions[:, :, k, :],
                    'similarity': model.similarity[k].numpy(),
                    'att_score': att_score[k],
                    'att_score_hard': None,
                    'dis_score': dis_score[k],
                    'dis_learned_score': dis_learned_score[k],
                    'region_b1': model.region_b1[k].numpy(),
                    'region_b4': model.region_b4[k].numpy(),
                    'region_cider': model.region_cider[k].numpy(),
                    'region_rouge': model.region_rouge[k].numpy(),
                    'region_meteor': model.region_meteor[k].numpy()
                }
            entry_for_leval = {
                'image_id': data['infos'][k]['id'],
                'caption': sent
            }
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            predictions_for_eval.append(entry_for_leval)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(
                    eval_kwargs['image_root'], data['infos'][k]
                    ['file_path']) + '" vis/imgs/cifar10_output/img' + str(
                        len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()
            predictions_for_eval.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

        current_weights = model.weights_p.data.cpu().numpy()

        # if data['label_region'] is not None:
        #     gts_for_wb, res_for_wb, count, current_weights = preprocess_wbleu(seq.data.cpu().numpy(), data['gts'],
        #                                                         data['label_region'], current_weights, count, gts_for_wb, res_for_wb)
        if model.pre_weights_p is None or baseline_concat == 1:
            r_weights = current_weights
        else:
            r_weights = np.concatenate([r_weights, current_weights], axis=0)
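        # Note: r_weights is first bound in the branch above; this assumes the
        # first batch always takes the pre_weights_p-is-None path, otherwise the
        # concatenation would raise a NameError.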

    lang_stats = None
    if lang_eval == 1:

        lang_stats = language_eval(
            dataset,
            predictions_for_eval,
            eval_kwargs['id'],
            split,
            detail_flg=True,
            wbleu_set=[gts_for_wb, res_for_wb, r_weights],
            option=bleu_option)
        for j in range(len(predictions)):
            predictions[j].update(lang_stats[1][str(
                predictions[j]['image_id'])])
        similarity_calculator(predictions)
        att_score_calculator(predictions)
    elif lang_eval == 2:
        labels = data['labels']
        lang_stats = utils.language_eval_excoco(sents, labels, loader)

    # pdb.set_trace()
    # Switch back to training mode
    # model.train()

    return loss_sum / loss_evals, predictions, lang_stats
Example #11
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}

    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = convcap(opt).cuda()
    #  pretrained_dict = torch.load('log_xe_final_before_review/all2model12000.pth')
    #   model.load_state_dict(pretrained_dict, strict=False)
    back_model = convcap(opt).cuda()
    back_model.train()
    #   d_pretrained_dict = torch.load('log_xe_final_before_review/all2d_model12000.pth')
    #   back_model.load_state_dict(d_pretrained_dict, strict=False)
    dp_model = model
    dp_model.train()
    dis_model = Discriminator(512, 512, 512, 0.2)
    dis_model = dis_model.cuda()
    dis_model.train()
    #    dis_pretrained_dict = torch.load('./log_xe_final_before_review/all2dis_model12000.pth')
    #    dis_model.load_state_dict(dis_pretrained_dict, strict=False)
    d_optimizer = utils.build_optimizer(dis_model.parameters(), opt)
    # Loss functions
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag

    optimizer = utils.build_optimizer_adam(
        chain(model.parameters(), back_model.parameters()), opt)

    #back_optimizer = utils.build_optimizer(back_model.parameters(), opt)
    update_lr_flag = True

    #Load the optimizer

    #   if os.path.isfile(os.path.join('log_xe_final_before_review/',"optimizer.pth")):
    #      optimizer.load_state_dict(torch.load(os.path.join('log_xe_final_before_review/', 'optimizer.pth')))
    #      print ('optimiser loaded')
    #   print (optimizer)
    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        labels[:, :, 0] = 8667
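        # 8667 is presumably the hard-coded start-of-sequence token index for
        # this vocabulary.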
        nd_labels = labels
        batchsize = fc_feats.size(0)
        # Forward pass and loss
        optimizer.zero_grad()
        d_steps = 1
        g_steps = 1
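        # d_steps/g_steps suggest an intended discriminator/generator update
        # ratio, but they are not referenced below; the always-true 'if 1:'
        # guards look like leftover debugging scaffolds.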
        #print (torch.sum(labels!=0), torch.sum(masks!=0))
        if 1:
            if iteration >= 0:

                if 1:
                    dp_model.eval()
                    back_model.eval()
                    with torch.no_grad():
                        _, x_all_d = dp_model(fc_feats, att_feats,
                                              nd_labels.long(), 30, 6)

                        labels_nd = nd_labels.view(batchsize, -1)
                        idx = [
                            i for i in range(labels_nd.size()[1] - 1, -1, -1)
                        ]
                        labels_flip_nd = labels_nd[:, idx]
                        labels_flip_nd = labels_flip_nd.view(batchsize, 6, 30)
                        labels_flip_nd[:, :, 0] = 8667
                        _, x_all_flip_d = back_model(fc_feats, att_feats,
                                                     labels_flip_nd, 30, 6)

                        x_all_d = x_all_d[:, :, :-1]
                        x_all_flip_d = x_all_flip_d[:, :, :-1]

                        idx = [
                            i
                            for i in range(x_all_flip_d.size()[2] - 1, -1, -1)
                        ]
                        idx = torch.LongTensor(idx[1:])
                        idx = Variable(idx).cuda()
                        invert_backstates = x_all_flip_d.index_select(2, idx)
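                        # index_select with the reversed (and shifted) index
                        # re-aligns the backward model's hidden states with the
                        # forward time order before they are compared.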

                        x_all_d = x_all_d.detach()
                        invert_backstates = invert_backstates.detach()
                    x_all_d = x_all_d[:, :, :-1]

                    autoregressive_scores = dis_model(
                        x_all_d.transpose(2, 1).cuda())
                    teacher_forcing_scores = dis_model(
                        invert_backstates.transpose(2, 1).cuda())

                    tf_loss, ar_loss = _calcualte_discriminator_loss(
                        teacher_forcing_scores, autoregressive_scores)

                    tf_loss.backward(retain_graph=True)
                    ar_loss.backward()

                    d_optimizer.step()
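                    # WGAN-style weight clipping: clamping critic parameters to
                    # [-0.01, 0.01] keeps the discriminator approximately
                    # 1-Lipschitz.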
                    for p in dis_model.parameters():
                        p.data.clamp_(-0.01, 0.01)

                    torch.cuda.synchronize()
                    total_time = time.time() - start

                if 1:
                    dp_model.train()
                    back_model.train()
                    wordact, x_all = dp_model(fc_feats, att_feats, labels, 30,
                                              6)
                    mask = masks.view(batchsize, -1)
                    mask = mask[:, 1:].contiguous()
                    wordact = wordact[:, :, :-1]
                    wordact_t = wordact.permute(0, 2, 1).contiguous()
                    wordact_t = wordact_t.view(
                        wordact_t.size(0) * wordact_t.size(1), -1)
                    labels_flat = labels.view(batchsize, -1)
                    wordclass_v = labels_flat[:, 1:]
                    wordclass_t = wordclass_v.contiguous().view(\
                     wordclass_v.size(0) * wordclass_v.size(1), 1)
                    maskids = torch.nonzero(
                        mask.view(-1).cpu()).numpy().reshape(-1)
                    loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
                     wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda()

                    idx = [i for i in range(labels_flat.size()[1] - 1, -1, -1)]
                    labels_flip = labels_flat[:, idx]
                    labels_flip = labels_flip.view(batchsize, 6, 30)
                    labels_flip[:, :, 0] = 8667
                    wordact, x_all_flip = back_model(fc_feats, att_feats,
                                                     labels_flip, 30, 6)
                    mask = masks.view(batchsize, -1).flip((1, ))
                    reverse_mask = mask[:, 1:].contiguous()
                    wordact = wordact[:, :, :-1]
                    wordact_t = wordact.permute(0, 2, 1).contiguous()
                    wordact_t = wordact_t.view(
                        wordact_t.size(0) * wordact_t.size(1), -1)
                    labels_flip = labels_flip.contiguous().view(-1, 6 * 30)
                    wordclass_v = labels_flip[:, 1:]
                    wordclass_t = wordclass_v.contiguous().view(\
                     wordclass_v.size(0) * wordclass_v.size(1), 1)
                    maskids = torch.nonzero(
                        reverse_mask.view(-1).cpu()).numpy().reshape(-1)

                    loss_xe_flip = F.cross_entropy(wordact_t[maskids, ...], \
                     wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda()

                    train_loss = loss_xe

                    x_all_flip = x_all_flip[:, :, :-1].cuda()
                    x_all = x_all[:, :, :-1].cuda()

                    idx = [i for i in range(x_all_flip.size()[2] - 1, -1, -1)]
                    idx = torch.LongTensor(idx[1:])
                    idx = Variable(idx).cuda()

                    invert_backstates = x_all_flip.index_select(2, idx)
                    invert_backstates = invert_backstates.detach()
                    l2_loss = ((x_all[:, :, :-1] -
                                invert_backstates)**2).mean()
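                    # L2 penalty between the forward states and the (detached)
                    # time-reversed backward states; this matches a TwinNet-style
                    # regularizer that encourages the forward RNN to anticipate
                    # the future.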

                    autoregressive_scores = dis_model(
                        x_all.transpose(2, 1).cuda())

                    ad_loss = _calculate_generator_loss(
                        autoregressive_scores).sum()

                    all_loss = loss_xe + loss_xe_flip + l2_loss
                    ad_loss.backward(retain_graph=True)
                    all_loss.backward()
                    #            utils.clip_gradient(optimizer, opt.grad_clip)
                    optimizer.step()

            if 1:
                if iteration % opt.print_freq == 1:
                    print('Read data:', time.time() - start)
                    if not sc_flag:
                        print("iter {} (epoch {}), train_loss = {:.3f},l2_loss= {:.3f}, flip_loss = {:.3f}, loss_ad = {:.3f}, fake = {:.3f}, real = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, loss_xe, l2_loss, loss_xe_flip, ad_loss, ar_loss, tf_loss, data_time, total_time))
                    else:
                        print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                            .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                update_lr_flag = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                loss_history[
                    iteration] = train_loss if not sc_flag else np.mean(
                        reward[:, 0])
                lr_history[iteration] = opt.current_lr
                #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
            if (iteration % opt.save_checkpoint_every == 0):
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2model{:05d}.pth'.format(iteration))
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2d_model{:05d}.pth'.format(iteration))
                torch.save(back_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2dis_model{:05d}.pth'.format(iteration))
                torch.save(dis_model.state_dict(), checkpoint_path)
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                # Evaluate model
        if (iteration % 1000 == 0):
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'd_model.pth')
            torch.save(back_model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'dis_model.pth')
            torch.save(dis_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            #histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'd_model-best.pth')
                torch.save(back_model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'dis_model-best.pth')
                torch.save(dis_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)
Example #12
def eval_split(model, crit, loader, eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    rank_eval = eval_kwargs.get('rank_eval', 0)
    dataset = eval_kwargs.get('dataset', 'person')
    beam_size = eval_kwargs.get('beam_size', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
    use_joint = eval_kwargs.get('use_joint', 0)
    init_scorer('person-' + split + '-words')
    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    losses = {}
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    visual = {
        "image_id": [], "personality": [], "generation": [], "gd": [], "densecap": [],
        "Bleu1_gen/cap": [], "Bleu2_gen/cap": [], "Bleu3_gen/cap": [], "Bleu4_gen/cap": [], "Cider_gen/cap": [],
        "Bleu1_gen/gd": [], "Bleu2_gen/gd": [], "Bleu3_gen/gd": [], "Bleu4_gen/gd": [], "Cider_gen/gd": [],
        "Bleu1_cap/gd": [], "Bleu2_cap/gd": [], "Bleu3_cap/gd": [], "Bleu4_cap/gd": [], "Cider_cap/gd": [],
        "Bleu1_gd/gen": [], "Bleu2_gd/gen": [], "Bleu3_gd/gen": [], "Bleu4_gd/gen": [], "Cider_gd/gen": []
    }
    if split == 'change':
        visual['new_personality'] = []
    minopt = 0
    verbose_loss = True
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'], data['densecap'], data['labels'],
                   data['masks'], data['att_masks'], data['personality']]
            tmp = [_.cuda() if _ is not None else _ for _ in tmp]
            fc_feats, att_feats, densecap, labels, masks, att_masks, personality = tmp
            with torch.no_grad():
                if eval_kwargs.get("use_dl", 0) > 0:
                    gen_result, sample_logprobs, alogprobs = model(
                        fc_feats, att_feats, densecap, att_masks, personality,
                        opt={'sample_method': 'sample'}, mode='sample')
                    loss = crit(model(fc_feats, att_feats, densecap, labels, att_masks, personality),
                                alogprobs, labels[:, 1:], masks[:, 1:]).item()
                else:
                    loss = crit(model(fc_feats, att_feats, densecap, labels, att_masks, personality),
                                labels[:, 1:], masks[:, 1:])

            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1
            if use_joint == 1:
                for k, v in model.loss().items():
                    if k not in losses:
                        losses[k] = 0
                    losses[k] += v
        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp = [data['fc_feats'][np.arange(loader.batch_size)],
               data['att_feats'][np.arange(loader.batch_size)] if data['att_feats'] is not None else None,
               data['densecap'][np.arange(loader.batch_size)],
               data['att_masks'][np.arange(loader.batch_size)] if data['att_masks'] is not None else None,
               data['personality'][np.arange(loader.batch_size)]]
        tmp = [_.cuda() if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, densecap, att_masks, personality = tmp
        if split == 'change':
            for pindex, pid in personality.nonzero():
                personality[pindex][pid] = 0
                newpid = random.choice(range(1, len(personality) - 1))
                personality[pindex][newpid] = 1
        ground_truth = data['labels'][:, 1:]
        # forward the model to also get generated samples for each image
        with torch.no_grad():
            seq = model(fc_feats, att_feats, densecap, att_masks, personality,
                        opt=eval_kwargs, mode='sample')[0].data

        # Print beam search
#        if beam_size > 1 and verbose_beam:
#            for i in range(loader.batch_size):
#                print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
#                print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)
        gd_display = utils.decode_sequence(loader.get_vocab(), ground_truth)
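        # minopt selects which beam to keep below: 1 = lowest-CIDEr beam,
        # -1 = highest-CIDEr beam, otherwise the top beam returned by search.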
        for k, s in enumerate(sents):
            if beam_size > 1 and verbose_beam:
                beam_sents = [utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0]
                              for _ in model.done_beams[k]]
                maxcider = 0
                mincider = 1000
                sentmax = sentmin = sent = s
                for b, sq in enumerate(beam_sents):
                    current_cider = cal_cider(
                        gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img], sq)
                    if current_cider >= maxcider:
                        maxcider = current_cider
                        sentmax = sq
                    if current_cider <= mincider:
                        mincider = current_cider
                        sentmin = sq
                if minopt == 1:
                    sent = sentmin
                elif minopt == -1:
                    sent = sentmax
                else:
                    sent = s
            else:
                sent = s
            #print("best sentence: ",sent) 
            newpidstr = str(personality[k].nonzero()[0].item())
            changed_personality = loader.get_personality()[newpidstr]
            entry = {'image_id': data['infos'][k]['id'] + "_" + data['infos'][k]['personality'],
                     'caption': sent,
                     'gd': gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img]}
            if entry not in predictions:
                densecap_display = utils.decode_sequence(loader.get_vocab(), data['densecap'][k])
                # gd is the densecap and test is the generation: len(common)/len(generation)
                allscore = get_scores_separate([densecap_display], [sent])
                for bk in allscore:
                    visual[bk + "_gen/cap"].append(allscore[bk])
                allscore_gd = get_scores_separate(
                    [gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img]], [sent])
                for bkgd in allscore_gd:
                    visual[bkgd + "_gen/gd"].append(allscore_gd[bkgd])
                allscore_capgd = get_scores_separate(
                    [gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img]], densecap_display)
                for cap_bkgd in allscore_capgd:
                    visual[cap_bkgd + "_cap/gd"].append(allscore_capgd[cap_bkgd])

                allscore_gd_flip = get_scores_separate(
                    [[sent]], gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img])
                for bkgd in allscore_gd_flip:
                    visual[bkgd + "_gd/gen"].append(allscore_gd_flip[bkgd])

                visual["image_id"].append(data['infos'][k]['id'])
                visual["personality"].append(data['infos'][k]['personality'])
                if split == 'change':
                    visual["new_personality"].append(changed_personality)
                visual['generation'].append(sent)
                visual["gd"].append(gd_display[k * loader.seq_per_img:(k + 1) * loader.seq_per_img])
                visual["densecap"].append(densecap_display)
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('--------------------------------------------------------------------')
                if split == 'change':
                    print('image %s{%s--------->%s}: %s' %
                          (entry['image_id'], changed_personality, entry['gd'], entry['caption']))
                else:
                    print('image %s{%s}: %s' %
                          (entry['image_id'], entry['gd'], entry['caption']))
                print('--------------------------------------------------------------------')

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()
        if verbose:
            print('evaluating validation performance... %d/%d (%f)' % (ix0 - 1, ix1, loss))
        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break
    allwords = " ".join(visual['generation'])
    allwords = allwords.split(" ")
    print("sets length of allwords:",len(set(allwords)))
    print("length of allwords:",len(allwords))
    print("rate of set/all:",len(set(allwords))/len(allwords))
    lang_stats = None
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split)
    
    df = pd.DataFrame.from_dict(visual)
    df.to_csv("visual_res/" + eval_kwargs['id'] + "_" + str(split) + "_visual.csv")
    if use_joint == 1:
        ranks = evalrank(model, loader, eval_kwargs) if rank_eval else {}
    # Switch back to training mode
    model.train()
    if use_joint == 1:
        losses = {k: v / loss_evals for k, v in losses.items()}
        losses.update(ranks)
        return losses, predictions, lang_stats
    return loss_sum / loss_evals, predictions, lang_stats
Example #13
def train(opt, num_switching=None):
    global internal
    if opt.gpu2 is None:
        torch.cuda.set_device(opt.gpu)
    RL_count = 0
    pure_reward = None

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
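    # The +5 presumably makes room for box geometry features (e.g. the four
    # normalized coordinates plus relative area) appended to each region feature.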

    # set dataloader
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.baseline_concat = 0

    # set up result logging
    result_path = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.checkpoint_path
    tb_summary_writer = tb and tb.SummaryWriter(result_path)

    infos = {}
    histories = {}


    # --- pretrained model loading --- #
    if opt.start_from is not None:
        opt.start_from = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.start_from
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos = cPickle.load(open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), mode='rb'))
        saved_model_opt = infos['opt']
        # need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[
                checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            histories = cPickle.load(open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl') , mode='rb'))
    if opt.sf_epoch is not None and opt.sf_itr is not None:
        iteration = opt.sf_itr
        epoch = opt.sf_epoch
    else:
        iteration = infos.get('iter', 0)
        epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    #---------------------------------------#

    # I forget what these parameters are for; they may not be used.
    b_regressor = None
    opt.regressor = b_regressor

    # model setting
    if opt.gpu2 is not None:
        model = models.setup(opt).cuda()
        dp_model = torch.nn.DataParallel(model)
    else:
        model = models.setup(opt).cuda()
        dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    # set RL mode, internal critic, and similarity model
    info_json = json.load(open(opt.input_json))
    sim_model = None
    new_internal = None
    if opt.internal_model in ('sim', 'sim_newr', 'sim_dammy'):

        # setting internal critic and similarity prediction network
        sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word']))

        if opt.region_bleu_flg == 0:
            if opt.sim_pred_type == 0:
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
            elif opt.sim_pred_type == 1:
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
            elif opt.sim_pred_type == 2:
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
            else:
                print('select 0, 1 or 2')
                exit()
            checkpoint = torch.load(model_root, map_location='cuda:0')
            sim_model.load_state_dict(checkpoint['model_state_dict'])
            sim_model.cuda()
            sim_model.eval()
            for param in sim_model.parameters():
                param.requires_grad = False
            sim_model_optimizer = None
        elif opt.region_bleu_flg == 1:
            sim_model.cuda()
            if opt.sf_internal_epoch is not None:
                sim_model.load_state_dict(
                    torch.load(os.path.join(opt.start_from, 'sim_model_' + str(opt.sf_internal_epoch) + '_' + str(
                        opt.sf_internal_itr) + '.pth')))
                # sim_model_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str(
                #     opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth')))
            sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt)
        else:
            print('not implemented')
            exit()


        if opt.only_critic_train == 1:
            random.seed(100)
        if opt.critic_encode==1:
            internal = models.CriticModel_with_encoder(opt)
        elif opt.bag_flg == 1:
            internal = models.CriticModel_bag(opt)
        elif opt.ppo == 1:
            # internal = models.CriticModel_sim(opt)
            internal = models.CriticModel_nodropout(opt)
            new_internal = models.CriticModel_nodropout(opt)
            internal.load_state_dict(new_internal.state_dict())
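            # PPO-style setup: keep two critic copies with identical initial
            # weights, so the 'old' and 'new' policies start in sync.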
        elif opt.input_h_flg == 1:
            internal = models.CriticModel_sim(opt)
        else:
            internal = models.CriticModel_sim_h(opt)

        internal = internal.cuda()
        if new_internal is not None:
            new_internal = new_internal.cuda()

        if opt.ppo == 1:
            internal_optimizer = utils.build_internal_optimizer(new_internal.parameters(), opt)
        else:
            internal_optimizer = utils.build_internal_optimizer(internal.parameters(), opt)

        if opt.sf_internal_epoch is not None:
            internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str(
                                                                 opt.sf_internal_itr) + '.pth')))
            internal_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str(
                opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth')))
            # new_internal = models.CriticModel_nodropout(opt)
            new_internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str(
                                                                 opt.sf_internal_itr) + '.pth')))
        if opt.multi_learn_flg != 1:
            if opt.internal_rl_flg == 1:
                internal_rl_flg = True
                dp_model.eval()
            else:
                internal.eval()
                internal_rl_flg = False
        else:
            internal_rl_flg = True
    else:
        if opt.sim_reward_flg > 0:
            # setting internal critic and similarity prediction network
            sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word']))
            if opt.sim_pred_type == 0:
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/no_shuffle_simforcoco/model_37_34000.pt'
            elif opt.sim_pred_type == 1:
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
            elif opt.sim_pred_type == 2:
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
            else:
                print('select 0, 1 or 2')
                exit()

            if opt.region_bleu_flg == 0:
                if opt.sim_pred_type == 0:
                    # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                    opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
                elif opt.sim_pred_type == 1:
                    opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
                elif opt.sim_pred_type == 2:
                    opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
                else:
                    opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/log_python3/log_' + opt.id + '/sim_model' + opt.model[-13:-4] + '.pth'

                checkpoint = torch.load(opt.sim_model_dir, map_location='cuda:0')
                sim_model.load_state_dict(checkpoint['model_state_dict'])
                sim_model.cuda()
                sim_model.eval()
                for param in sim_model.parameters():
                    param.requires_grad = False
                sim_model_optimizer = None
            elif opt.region_bleu_flg == 1:
                sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt)
                sim_model.cuda()

        internal = None
        internal_optimizer = None
        internal_rl_flg = False
        opt.c_current_lr = 0
    # opt.internal = internal

    # set Discriminator
    if opt.discriminator_weight > 0:
        dis_opt = opt
        if opt.dis_type == 'coco':
            discriminator_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/coco/discriminator_150.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator.json'
        elif opt.dis_type == 'iapr':
            discriminator_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/iapr_dict/discriminator_125.pth'
            dis_opt.input_label_h5 = '/mnt/workspace2019/visual_genome_pretrain/iapr_talk_cocodict_label.h5'
            dis_opt.input_json = '/mnt/workspace2018/nakamura/IAPR/iapr_talk_cocodict.json'
        elif opt.dis_type == 'ss':
            discriminator_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/shuttorstock_dict/discriminator_900.pth'
            dis_opt.input_label_h5 = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict_label.h5'
            dis_opt.input_json = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict.json'
        elif opt.dis_type == 'sew':
            discriminator_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew/discriminator_900.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json'
        elif opt.dis_type == 'sew_cut5':
            discriminator_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew_cut5/discriminator_90.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json'
            opt.cut_length = 5
        elif opt.dis_type == 'vg_cut5':
            opt.cut_length = 5
            discriminator_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/vg_cut5/discriminator_200.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_addregions.json'
        else:
            print('select existing discriminative model!')
            exit()

        discriminator_path_learned = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration))
        Discriminator = dis_utils.Discriminator(opt)
        if os.path.isfile(discriminator_path_learned):
            Discriminator.load_state_dict(torch.load(discriminator_path_learned, map_location='cuda:' + str(opt.gpu)))
        else:
            Discriminator.load_state_dict(torch.load(discriminator_model_dir, map_location='cuda:' + str(opt.gpu)))
        Discriminator = Discriminator.cuda()
        # change discriminator learning rate
        # opt.learning_rate = opt.learning_rate/10
        dis_optimizer = utils.build_optimizer(Discriminator.parameters(), opt)
        # for group in dis_optimizer.param_groups:
        #     group['lr'] = opt.learning_rate/100
        Discriminator.eval()
        dis_loss_func = nn.BCELoss().cuda()
        dis_loader = dis_dataloader.DataLoader(dis_opt)
    else:
        Discriminator = None
        dis_loader = None
        dis_optimizer = None

    # set Actor-Critic network
    if opt.actor_critic_flg == 1:
        Q_net = models.Actor_Critic_Net_upper(opt)
        target_Q_net = models.Actor_Critic_Net_upper(opt)
        Q_net.load_state_dict(target_Q_net.state_dict())
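        # Standard target-network initialization: the online and target Q
        # networks start with identical weights; the target is updated only
        # periodically to stabilize bootstrapped value estimates.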
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
        Q_net.cuda()
        target_Q_net.cuda()
        Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt)
    elif opt.actor_critic_flg == 2:
        Q_net = models.Actor_Critic_Net_seq(opt)
        target_Q_net = models.Actor_Critic_Net_seq(opt)
        Q_net.load_state_dict(target_Q_net.state_dict())
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
        Q_net.cuda()
        target_Q_net.cuda()
        Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt)

        seq_mask = torch.zeros((opt.batch_size * opt.seq_per_img, opt.seq_length, opt.seq_length)).cuda().type(torch.cuda.LongTensor)
        for i in range(opt.seq_length):
            seq_mask[:, i, :i] += 1
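        # seq_mask is a per-step lower-triangular mask: position i can attend
        # only to the i preceding tokens.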
    elif opt.t_model_flg == 1:
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
    else:
        target_model = None

    baseline = None
    new_model = None
    # set functions calculating loss
    if opt.caption_model in ('hcatt_hard', 'basicxt_hard', 'hcatt_hard_nregion', 'basicxt_hard_nregion'):
        if opt.ppo == 1:
            new_model = models.setup(opt).cuda()
            new_model.load_state_dict(model.state_dict())
            # new_optimizer = utils.build_optimizer(new_model.parameters(), opt)
            # new_model.eval()

        # If you use hard attention, use this setting (but it is not completely implemented)
        crit = utils.LanguageModelCriterion_hard()
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    rl_crit_hard = utils.RewardCriterion_hard()
    rl_crit_conly = utils.RewardCriterion_conly()
    rl_crit_hard_base = utils.RewardCriterion_hard_baseline()
    att_crit = utils.AttentionCriterion()

    if opt.caption_model == 'hcatt_hard' and opt.ppo == 1:
        optimizer = utils.build_optimizer(new_model.parameters(), opt)
    else:
        # set optimizer
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        if opt.sf_epoch is None:
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        else:
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer_' + str(opt.sf_epoch) + '_' +str(opt.sf_itr) + '.pth')))

    critic_train_count = 0
    total_critic_reward = 0
    pre_para = None

    #------------------------------------------------------------------------------------------------------------#
    # training start
    while True:
        train_loss = 0
        if update_lr_flag:
            # change lr
            opt, optimizer, model, internal_optimizer, dis_optimizer = utils.change_lr(opt, epoch, optimizer, model, internal_optimizer, dis_optimizer)

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                # internal_rl_flg == False
                init_scorer(opt.cached_tokens, len(info_json['ix_to_word']))
            else:
                sc_flag = False

            update_lr_flag = False


        # Load data from train split (0)
        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        # unpack the batch
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'],
               data['bbox'], data['sub_att'], data['fixed_region']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks, bbox, sub_att, fixed_region = tmp

        optimizer.zero_grad()
        # calculating loss...
        if not sc_flag:
            # use cross entropy
            if opt.weight_deterministic_flg > 0:
                weight_index = np.array(data['weight_index'])
                # fc_feats = fc_feats * 0.0
                output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index)
                # output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=None)
            else:
                output = dp_model(fc_feats, att_feats, labels, att_masks, internal)
            if opt.caption_model == 'hcatt_prob':
                print(torch.exp(output).mean(),  model.probs.mean())
                output = output + model.probs.view(output.size(0), output.size(1), 1)
                loss = crit(output, labels[:,1:], masks[:,1:])
            elif opt.caption_model not in ('hcatt_hard', 'hcatt_hard_nregion', 'basicxt_hard_nregion', 'basicxt_hard'):
                loss = crit(output, labels[:,1:], masks[:,1:])
            else:
                if baseline is None:
                    baseline = torch.zeros((output.size(0), output.size(1)))
                    baseline = baseline.cuda()
                    # baseline = torch.log(baseline)
                # print('pre:', baseline.mean().item())
                loss, baseline = crit(output, labels[:,1:], masks[:,1:], baseline, dp_model.weights_p, dp_model.weights)
                # print('after:', baseline.mean().item())
        else:
            # use rl
            if opt.weight_deterministic_flg > 0:
                weight_index = np.array(data['weight_index'])
            else:
                weight_index = None

            if dp_model.training:
                sample_max_flg = 0
            else:
                sample_max_flg = 1

            # get predicted captions and logprobs, similarity
            gen_result, sample_logprobs, word_exist_seq = dp_model(fc_feats, att_feats, att_masks,internal,
                                                   opt={'sample_max':sample_max_flg}, sim_model = sim_model, New_Critic=new_internal,
                                                   bbox=bbox, sub_att=sub_att, label_region = data['label_region'], weight_index=weight_index,mode='sample')
            train_similarity = dp_model.similarity

            # ---------- train the discriminator ----------------
            if Discriminator is not None and opt.dis_adv_flg == 1 and internal_rl_flg == False:
                correct = 0
                Discriminator.train()
                fake_data = gen_result.data.cpu()
                hokan = torch.zeros((len(fake_data), 1)).type(torch.LongTensor)
                fake_data = torch.cat((hokan, fake_data, hokan), 1).cuda()
                fake_data = fake_data[:, 1:]
                label = torch.ones((fake_data.size(0))).cuda()
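                # 'hokan' (Japanese for padding/completion) adds zero columns at both
                # ends; dropping the first column then leaves one trailing pad token.
                # The 0/1 flag passed to learning_func presumably marks fake vs. real.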
                # pdb.set_trace()
                Discriminator, dis_optimizer, correct, neg_loss = \
                    dis_utils.learning_func(Discriminator, dis_optimizer, fake_data, label, correct, 0, opt.cut_length, opt.random_disc, opt.all_switch_end_dis, opt.all_switch_dis,
                                            loss_func=dis_loss_func, weight_index=weight_index, model_gate=model.gate.data.cpu().numpy())

                dis_data = dis_loader.get_batch('train', batch_size=fake_data.size(0))
                real_data = torch.from_numpy(dis_data['labels']).cuda()
                real_data = real_data[:, 1:]
                Discriminator, dis_optimizer, correct, pos_loss = \
                    dis_utils.learning_func(Discriminator, dis_optimizer, real_data, label, correct, 1, opt.cut_length, 0, 0, 0,
                                            loss_func=dis_loss_func, weight_index=weight_index)

                loss_mean = (pos_loss + neg_loss) / 2
                dis_accuracy = correct/(fake_data.size(0) * 2)
                print('Discriminator loss: {}, accuracy: {}'.format(loss_mean, dis_accuracy))
                Discriminator.eval()
            else:
                loss_mean = -1.0
                dis_accuracy = -1.0
            # --------------------------------------------------


            # ---------- calculate att loss -----------
            if opt.att_reward_flg == 1 and model.training:
            # if opt.att_reward_flg == 1 :
                att_loss = att_crit(model, gen_result.data.cpu().numpy())
                att_loss_num = att_loss.data.cpu().numpy()
            else:
                att_loss = 0.0
                att_loss_num = 0.0
            # ------------------------------------------

            # --- get states and actions xt and weights, ccs, seqs ---
            if opt.actor_critic_flg==1 and model.training:
                xts = model.all_xts
                weights_p = model.weights_p
                ccs = internal.output_action
            if opt.actor_critic_flg == 2 and model.training:
                all_logprops = model.all_logprops
                weight_state = model.state_weights
                # xts = model.all_xts
                gen_result_repeat = gen_result.repeat(1, opt.seq_length).view(all_logprops.size(0), opt.seq_length, opt.seq_length)
                # xts = seq_mask * gen_result_repeat
                xts = gen_result_repeat
                weights_p = model.weights_p
                # pdb.set_trace()
                if internal is not None:
                    ccs = internal.output_action
                else:
                    ccs = torch.zeros((len(xts), weights_p.size(1))).cuda()
            if opt.caption_model == 'hcatt_hard' and opt.ppo==1:
                xts = model.all_xts
                weights_p = model.weights_p
                weights = model.weights
            # ----------------------------------------------------------

            # ---------------- Calculate reward (CIDEr, Discriminator, Similarity...)---------------------
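            # Self-critical recipe: the sampled caption's score is baselined by the
            # greedy decode's score, so reward > 0 means the sample beat the baseline.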
            if opt.actor_critic_flg == 2 and model.training:
                reward, pure_reward = get_self_critical_and_similarity_reward_for_actor_critic(dp_model,
                                                                                                   fc_feats,
                                                                                                   att_feats,
                                                                                                   att_masks, data,
                                                                                                   gen_result, opt,
                                                                                                   train_similarity,
                                                                                                   internal=internal,
                                                                                                   sim_model=sim_model,
                                                                                               label_region=data['label_region'],
                                                                                               D=Discriminator)
            else:
                reward, pure_reward, actor_critic_reward, target_update_flg = get_self_critical_and_similarity_reward(dp_model, fc_feats, att_feats,
                                                                          att_masks, data, gen_result, opt,
                                                                          train_similarity,
                                                                          internal=internal,
                                                                          sim_model=sim_model,
                                                                        label_region=data['label_region'],
                                                                          bbox=bbox,
                                                                        D=Discriminator,
                                                                        weight_index=weight_index,
                                                                        fixed_region=fixed_region,
                                                                        target_model=target_model)
                if target_update_flg and target_model is not None:
                    print('----- target model updated ! -----')
                    target_model.load_state_dict(model.state_dict())

                # print(train_similarity.mean(), model.similarity.mean())
            #----------------------------------------------------------


            #-------------------------------- calculate captioning model loss -----------------------------------------
            #------------ Calculate actor critic loss ----------------
            if opt.actor_critic_flg == 1 and model.training:
                # get q_value
                q_value = Q_net(fc_feats, att_feats, xts, weights_p, gen_result)
                # get target_sample
                with torch.no_grad():
                    gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks,
                                                           seqs=gen_result, ccs=ccs, mode='sample')
                    target_q_value = target_Q_net(fc_feats, att_feats, target_model.all_xts, target_model.weights_p, gen_result)
                # calculate actor critic loss
                actor_critic_loss = Q_net.loss_func(actor_critic_reward, q_value, target_q_value)
                add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration, opt.tag)
                Q_net_optimizer.zero_grad()
            elif opt.actor_critic_flg == 2 and model.training:
                # get q_value
                q_value = Q_net(fc_feats, att_feats, xts, weight_state.detach(), weights_p, all_logprops[:,:-1,:], gen_result)
                # get target_sample
                with torch.no_grad():
                    gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks,
                                                         seqs=gen_result, ccs=ccs, mode='sample', state_weights=weight_state)
                    # pdb.set_trace()
                    target_q_value = target_Q_net(fc_feats, att_feats, xts, target_model.state_weights,
                                                  target_model.weights_p, target_model.all_logprops[:,:-1,:], gen_result)
                # calculate actor critic loss
                if reward is None:
                    pdb.set_trace()
                actor_critic_loss = Q_net.loss_func(reward, q_value, target_q_value, gen_result)
                print('actor_critic_loss', actor_critic_loss.item())
                add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration,
                                  opt.tag)
                Q_net_optimizer.zero_grad()
            else:
                actor_critic_loss = 0

            model.att_score = att_loss_num

            # update ppo old policy
            if new_internal is not None and internal.iteration % 1 == 0:
                internal.load_state_dict(new_internal.state_dict())
            if opt.caption_model == 'hcatt_hard' and opt.ppo == 1:
                model.load_state_dict(new_model.state_dict())

            if not internal_rl_flg or opt.multi_learn_flg == 1:
                # -------------- calculate self-critical loss ---------------
                if False:  # disabled; presumably once guarded by `opt.ppo == 1 and opt.caption_model == 'hcatt_hard'`
                    # get coefficient and calculate
                    new_gen_result, new_sample_logprobs = new_model(fc_feats, att_feats, att_masks,
                                                         seqs=gen_result,  mode='sample', decided_att=weights)
                    new_model.pre_weights_p = new_model.weights_p
                    new_model.pre_weights = new_model.weights
                    att_index = np.where(weights.data.cpu() > 0)
                    weights_p_ = weights_p[att_index].view(weights_p.size(0), weights_p.size(1))  # (batch, seq_length)
                    reward_coefficient = 1 / (torch.exp(sample_logprobs) * weights_p_).data.cpu()
                    # train caption network get reward and calculate loss
                    reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                 new_sample_logprobs, gen_result, reward,
                                                                 baseline, new_model, reward_coefficient=reward_coefficient)
                elif (not internal_rl_flg or opt.multi_learn_flg == 1) and opt.actor_critic_flg == 0:
                    # train caption network get reward and calculate loss
                    if opt.weight_deterministic_flg == 7:
                        reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                     sample_logprobs, word_exist_seq, reward,
                                                                     baseline, model)
                    else:
                        reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                     sample_logprobs, gen_result, reward,
                                                                     baseline, model)
                else:
                    reward_loss = 0

                # -------------- calculate auxiliary cross-entropy loss ---------------
                if opt.caption_model in ('hcatt_simple', 'hcatt_simple_switch') and opt.xe_weight > 0.0:
                    output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index)
                    xe_loss = crit(output, labels[:, 1:], masks[:, 1:])
                    print('r_loss: {}, xe_loss: {}'.format(reward_loss.item(), xe_loss.item()))
                    add_summary_value(tb_summary_writer, 'xe_loss', xe_loss.item(), iteration, opt.tag)
                    add_summary_value(tb_summary_writer, 'r_loss', reward_loss.item(), iteration, opt.tag)
                else:
                    xe_loss = 0.0

                loss = opt.rloss_weight * reward_loss + opt.att_lambda * att_loss + actor_critic_loss + opt.xe_weight * xe_loss
            # --------------------------------------------------------------------------------------------------------


        # ------------------------- calculate internal critic loss and update ---------------------------
        if internal_optimizer is not None and internal_rl_flg == True and sc_flag:

            internal_optimizer.zero_grad()
            if opt.region_bleu_flg == 1:
                sim_model_optimizer.zero_grad()
            if opt.only_critic_train == 0:
                internal_loss = rl_crit(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(),
                                    reward_coefficient=internal.pre_reward_coefficient)
            else:
                internal_loss = rl_crit_conly(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(),
                                        reward_coefficient=internal.pre_reward_coefficient, c_count=critic_train_count)
            q_value_prop = torch.exp(internal.pre_output)
            entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2(
                    1 - q_value_prop + 1e-8))
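            # mean binary entropy of the critic's output probabilities, logged below
            # as a rough measure of how decisive the internal critic has become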

            internal_loss.backward()
            internal_optimizer.step()
            if opt.region_bleu_flg == 1:
                sim_model_optimizer.step()

            # ----- record loss and reward to tensorboard -----
            # q_value_prop = torch.exp(internal.pre_output)
            # entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2(1 - q_value_prop + 1e-8))
            if opt.only_critic_train == 1:
                if internal is not None and sc_flag:
                    num_internal_switching = internal.same_action_flg.mean().item()
                else:
                    num_internal_switching = 0
                total_critic_reward += np.mean(pure_reward)
                total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, critic_train_count, opt.tag,
                                                                   tb_summary_writer, reward,
                                                                   pure_reward, entropy,
                                                                   opt.sim_sum_flg,num_internal_switching,
                                                                   total_critic_reward=total_critic_reward)
            else:
                if internal is not None and sc_flag:
                    num_internal_switching = internal.same_action_flg.mean().item()
                else:
                    num_internal_switching = 0
                total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, iteration, opt.tag,
                                         tb_summary_writer, reward, pure_reward, entropy, opt.sim_sum_flg, num_internal_switching)
            # -------------------------------------------------

            critic_train_count += 1

            internal.reset()
            internal.iteration+=1

            print('iter {} (epoch {}), internal_loss: {}, avg_reward: {}, entropy: {}'.format(iteration, epoch,internal_loss, reward.mean(), entropy))
        # --------------------------------------------------------------------------------------------------------
        else:
            #------------------------- updating captioning model ----------------------------
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if opt.actor_critic_flg > 0 and model.training:
                utils.clip_gradient(Q_net_optimizer, opt.grad_clip)
                Q_net_optimizer.step()
                utils.soft_update(target_model, model, 0.001)
                utils.soft_update(target_Q_net, Q_net, 0.001)
                # if iteration % 1000 == 0:
                #     utils.hard_update(target_model, model)
                #     utils.hard_update(target_Q_net, Q_net)
                # else:
                #     utils.soft_update(target_model, model, 0.001)
                #     utils.soft_update(target_Q_net, Q_net, 0.001)

            train_loss = loss.item()
            torch.cuda.synchronize()
            del loss
            end = time.time()
            if internal is not None and sc_flag:
                num_internal_switching = internal.same_action_flg.mean().item()
            else:
                num_internal_switching = 0
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                try:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \
                        .format(iteration, epoch, np.mean(reward[:,0]), model.att_score.item(), end - start))
                    utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag,
                                                opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching)
                except AttributeError:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \
                          .format(iteration, epoch, np.mean(reward[:, 0]), model.att_score, end - start))
                    utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag,
                                                opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching)
                RL_count += 1

            # --------------------------------------------------------------------------------



        # Update the iteration and epoch
        iteration += 1

        # -------------------- change train internal critic or caption network -----------------------------
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            if opt.cycle is None and internal is not None and opt.multi_learn_flg != 1:
                # and entropy < 1.0
                if internal_rl_flg == True and opt.only_critic_train == 0:
                    if opt.actor_critic_flg == 1:
                        utils.hard_update(target_model, model)
                        utils.hard_update(target_Q_net, Q_net)
                    internal_rl_flg = False
                    internal.eval()
                    dp_model.train()
                    if weight_index is not None and loader.weight_deterministic_flg == 4:
                        loader.weight_deterministic_flg = 5

                    if opt.region_bleu_flg == 1:
                        sim_model.eval()
                    train_loss = None
                # elif internal_optimizer is not None and internal_rl_flg == False:
                # elif internal_optimizer is not None and internal_rl_flg == False and (epoch + 1) % 3 == 0 and opt.internal_model != 'sim_dammy':
                # elif internal_optimizer is not None and internal_rl_flg == False and opt.internal_model != 'sim_dammy':
                else:
                    internal_rl_flg = True
                    # internal.load_state_dict(torch.load(result_path + '/internal_best.pth'))
                    if opt.ppo == 1:
                        internal_optimizer = optim.Adam(new_internal.parameters(), opt.c_learning_rate,
                                                        weight_decay=1e-5)
                    else:
                        internal_optimizer = optim.Adam(internal.parameters(), opt.c_learning_rate, weight_decay=1e-5)
                    internal.train()
                    if opt.region_bleu_flg == 1:
                        sim_model.train()
                    dp_model.eval()
                    if weight_index is not None and loader.weight_deterministic_flg == 5:
                        loader.weight_deterministic_flg = 4
                    internal.reset()
                    internal.max_r = 0
        # --------------------------------------------------------------------------------------------------

        # ------------------- Write the training loss summary ------------------------------
        if (iteration % opt.losses_log_every == 0) and internal_rl_flg == False and train_loss is not None:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'critic_learning_rate', opt.c_current_lr, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration, opt.tag)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration, opt.tag)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
        # ----------------------------------------------------------------------------------

        # ------------------------ make evaluation on validation set, and save model ------------------------------
        wdf7_eval_flg = (opt.weight_deterministic_flg != 7 or sc_flag)
        # (the hard-coded iteration numbers below are presumably dataset-specific checkpoints)
        if ((iteration % opt.save_checkpoint_every == 0) or iteration in (39110, 113280, 151045, 78225, 31288, 32850, 46934)) and train_loss is not None:
            if sc_flag and opt.caption_model in ('hcatt_hard', 'basicxt_hard', 'hcatt_hard_nregion', 'basicxt_hard_nregion'):
                if baseline is None:
                    baseline = torch.zeros((sample_logprobs.size(0), sample_logprobs.size(1) + 1))
                    baseline = baseline.cuda()
                    # baseline = torch.log(baseline)
            # eval model
            verbose_loss = not sc_flag


            eval_kwargs = {'split': 'val',
                           'internal': internal,
                           'sim_model': sim_model,
                           'caption_model': opt.caption_model,
                           'baseline': baseline,
                           'gts': data['gts'],
                           'dataset': opt.dataset,
                           'verbose_loss': verbose_loss,
                           'weight_deterministic_flg': opt.weight_deterministic_flg
                           }
            eval_kwargs.update(vars(opt))

            # pdb.set_trace()
            if wdf7_eval_flg:
                # eval_utils.eval_writer(dp_model, iteration, loader, tb_summary_writer, eval_kwargs)
                val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration, opt.tag)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration, opt.tag)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss
            else:
                val_result_history[iteration] = {'loss': None, 'lang_stats': None, 'predictions': None}
                current_score = 0

            best_flag = False
            if True:  # always enter the save block
                if internal_rl_flg == False:
                    if best_val_score is None or current_score > best_val_score:
                        best_val_score = current_score
                        best_flag = True
                    checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(model.state_dict(), checkpoint_path)

                    optimizer_path = os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(optimizer.state_dict(), optimizer_path)
                    print("model saved to {}".format(checkpoint_path))
                    if internal is not None:
                        internal.eval()
                        checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration))
                        torch.save(internal.state_dict(), checkpoint_path)
                        optimizer_path = os.path.join(result_path,
                                                      'internal_optimizer_{}_{}.pth'.format(epoch, iteration))
                        torch.save(internal_optimizer.state_dict(), optimizer_path)
                        print("internal model saved to {}".format(checkpoint_path))
                        checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration))
                        torch.save(sim_model.state_dict(), checkpoint_path)
                        print("sim_model saved to {}".format(checkpoint_path))

                else:
                    checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(model.state_dict(), checkpoint_path)
                    optimizer_path = os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(optimizer.state_dict(), optimizer_path)
                    print("model saved to {}".format(checkpoint_path))
                    if best_val_score is None or current_score > best_val_score:
                        best_val_score = current_score
                        best_flag = True
                    checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration))
                    torch.save(internal.state_dict(), checkpoint_path)

                    optimizer_path = os.path.join(result_path, 'internal_optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(internal_optimizer.state_dict(), optimizer_path)
                    print("internal model saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(sim_model.state_dict(), checkpoint_path)
                    print("sim_model saved to {}".format(checkpoint_path))
                    dp_model.eval()

                if Discriminator is not None:
                    discriminator_path = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration))
                    torch.save(Discriminator.state_dict(), discriminator_path)
                    dis_optimizer_path = os.path.join(result_path, 'dis_optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(dis_optimizer.state_dict(), dis_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()
                infos['internal_rl_flg'] = internal_rl_flg

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                with open(os.path.join(result_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(result_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)
                if best_flag:
                    checkpoint_path = os.path.join(result_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    with open(os.path.join(result_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)
                # pdb.set_trace()
        # ---------------------------------------------------------------------------------------------------------

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
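
# A minimal sketch (an assumption, not code from this project) of the Polyak soft
# update performed above by utils.soft_update(target, source, tau): every target
# parameter is nudged a small step tau toward the online network's parameter,
# which keeps the target networks slowly trailing the trained ones.
def soft_update_sketch(target, source, tau=0.001):
    """theta_target <- (1 - tau) * theta_target + tau * theta_source"""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_((1.0 - tau) * t_param.data + tau * s_param.data)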
Ejemplo n.º 14
0
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model




    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    #TODO: change this to a flag
    crit = utils.LanguageModelCriterion_binary()
    rl_crit = utils.RewardCriterion_binary()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
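                # scheduled sampling: the probability of feeding back the model's own
                # previous prediction (instead of the ground-truth token) grows linearly
                # with epochs and is capped at scheduled_sampling_max_prob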

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 10000:
            loader.reset_iterator('train')
            continue
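        # the reset above appears to restrict training to the first 10000 items of the
        # split: once the iterator passes position 10000 it wraps back to the start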
        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:], dp_model.depth,
                        dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size)
        else:
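            # sc = self-critical (greedy baseline); reinforce = plain REINFORCE;
            # 'arm'/'arsm' presumably refer to the Augment-REINFORCE-Merge family of
            # gradient estimators; the remaining branches are baselined variants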
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'arm':
                loss = dp_model.get_arm_loss_binary_fast(fc_feats, att_feats, att_masks, opt, data, loader)
                #print(loss)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'rf4':
                loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2,2])
            elif opt.rl_type =='mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                         opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
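                # baseline entries left negative by the MCT rollout are presumably
                # placeholders; they fall back to the sampled caption's own reward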
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:  # disabled branch
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        # TODO: make sure all sampling is replaced by greedy for the critic
        # update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # compute a running estimate of the gradient variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.999 * first_order + 0.001 * gradient
        second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
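        # first_order and second_order are EMAs of E[g] and E[g^2] (decay 0.999), so
        # |second_order - first_order^2| estimates the per-parameter gradient variance,
        # averaged here into a single scalar for logging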
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # update the critic

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start


        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_binary.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k,v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True:  # always enter the save block
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
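
# A minimal runnable sketch (an assumption about, not an excerpt of, the helper
# get_self_critical_reward used above): under the self-critical recipe the
# per-sequence advantage is score(sample) - score(greedy), broadcast over time.
import numpy as np

def self_critical_reward_sketch(sample_scores, greedy_scores, seq_length):
    advantage = np.asarray(sample_scores, dtype=np.float32) - \
                np.asarray(greedy_scores, dtype=np.float32)
    # shape (batch, seq_length): each token of a caption shares its sequence reward
    return np.repeat(advantage[:, np.newaxis], seq_length, axis=1)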
Ejemplo n.º 15
0
def train(opt):
    assert opt.annfile is not None and len(opt.annfile) > 0

    print('Checkpoint path is ' + opt.checkpoint_path)
    print('This program is using GPU ' +
          str(os.environ['CUDA_VISIBLE_DEVICES']))
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if opt.load_best:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')
        with open(info_path) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    if opt.learning_rate_decay_start is None:
        opt.learning_rate_decay_start = infos.get(
            'opt', None).learning_rate_decay_start
    # if opt.load_best:
    #     opt.self_critical_after = epoch
    elif opt.learning_rate_decay_start == -1 and opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        opt.learning_rate_decay_start = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_ave_model = infos.get('best_val_score_ave_model', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion(opt.XE_eps)
    rl_crit = utils.RewardCriterion()

    # build_optimizer
    optimizer = build_optimizer(model, opt)

    # Load the optimizer
    if opt.load_opti and vars(opt).get(
            'start_from',
            None) is not None and opt.load_best == 0 and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # initialize the running average of parameters
    avg_param = deepcopy(list(p.data for p in model.parameters()))

    # make evaluation using original model
    best_val_score, histories, infos = eva_original_model(
        best_val_score, crit, epoch, histories, infos, iteration, loader,
        loss_history, lr_history, model, opt, optimizer, ss_prob_history,
        tb_summary_writer, val_result_history)

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                if opt.lr_decay == 'exp':
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                elif opt.lr_decay == 'cosine':
                    lr_epoch = min((epoch - opt.learning_rate_decay_start),
                                   opt.lr_max_epoch)
                    cosine_decay = 0.5 * (
                        1 + math.cos(math.pi * lr_epoch / opt.lr_max_epoch))
                    decay_factor = (1 - opt.lr_cosine_decay_base
                                    ) * cosine_decay + opt.lr_cosine_decay_base
                    opt.current_lr = opt.learning_rate * decay_factor
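                    # cosine schedule: lr = base * ((1 - b) * 0.5 * (1 + cos(pi * e / E)) + b)
                    # with b = lr_cosine_decay_base, e = epochs since decay start,
                    # E = lr_max_epoch, i.e. lr decays from base lr down to b * base lr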
            else:
                opt.current_lr = opt.learning_rate

            lr = [opt.current_lr]
            if opt.att_normalize_method is not None and '6' in opt.att_normalize_method:
                lr = [opt.current_lr, opt.lr_ratio * opt.current_lr]

            utils.set_lr(optimizer, lr)
            print('learning rate is: ' + str(lr))

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Update the iteration
        iteration += 1

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            output = dp_model(fc_feats, att_feats, labels, att_masks)
            # calculate loss
            loss = crit(output[0], labels[:, 1:], masks[:, 1:])

            # add some middle variable histogram
            if iteration % (4 * opt.losses_log_every) == 0:
                outputs = [
                    _.data.cpu().numpy() if _ is not None else None
                    for _ in output
                ]
                variables_histogram(data, iteration, outputs,
                                    tb_summary_writer, opt)

        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_max_norm)
        # add_summary_value(tb_summary_writer, 'grad_L2_norm', grad_norm, iteration)

        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        # compute the running average of parameters
        for p, avg_p in zip(model.parameters(), avg_param):
            avg_p.mul_(opt.beta).add_((1.0 - opt.beta), p.data)
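        # Polyak/running average of the weights: avg_p <- beta * avg_p + (1 - beta) * p;
        # this averaged copy is evaluated separately via eva_ave_model below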

        if iteration % 10 == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if opt.tensorboard_weights_grads and (iteration %
                                              (8 * opt.losses_log_every) == 0):
            # add weights histogram to tensorboard summary
            for name, param in model.named_parameters():
                if (opt.tensorboard_parameters_name is None or any(
                        p_name in name
                        for p_name in opt.tensorboard_parameters_name
                )) and param.grad is not None:
                    tb_summary_writer.add_histogram(
                        'Weights_' + name.replace('.', '/'), param, iteration)
                    tb_summary_writer.add_histogram(
                        'Grads_' + name.replace('.', '/'), param.grad,
                        iteration)

        if opt.tensorboard_buffers and (iteration %
                                        (opt.losses_log_every) == 0):
            for name, buffer in model.named_buffers():
                if (opt.tensorboard_buffers_name is None or any(
                        p_name in name
                        for p_name in opt.tensorboard_buffers_name
                )) and buffer is not None:
                    add_summary_value(tb_summary_writer,
                                      name.replace('.',
                                                   '/'), buffer, iteration)

        if opt.distance_sensitive_coefficient and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The coefficient in intra_att_att_lstm is as follows:')
            print(
                model.core.intra_att_att_lstm.coefficient.data.cpu().tolist())
            print('The coefficient in intra_att_lang_lstm is as follows:')
            print(
                model.core.intra_att_lang_lstm.coefficient.data.cpu().tolist())
        if opt.distance_sensitive_bias and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The bias in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.bias.data.cpu().tolist())
            print('The bias in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.bias.data.cpu().tolist())

        # evaluate using the original (non-averaged) model
        if (iteration % opt.save_checkpoint_every == 0):
            best_val_score, histories, infos = eva_original_model(
                best_val_score, crit, epoch, histories, infos, iteration,
                loader, loss_history, lr_history, model, opt, optimizer,
                ss_prob_history, tb_summary_writer, val_result_history)

        # evaluate using the parameter-averaged model
        if iteration > opt.ave_threshold and (iteration %
                                              opt.save_checkpoint_every == 0):
            best_val_score_ave_model, infos = eva_ave_model(
                avg_param, best_val_score_ave_model, crit, infos, iteration,
                loader, model, opt, tb_summary_writer)

        # # Stop if reaching max epochs
        # if epoch >= opt.max_epochs and opt.max_epochs != -1:
        #     break

        if iteration >= opt.max_iter:
            break
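
A hedged sketch of the parameter-averaging trick used in the loop above: the code keeps an exponential moving average (EMA) of the model weights in avg_param and evaluates that averaged copy separately. The helper names and the beta default below are illustrative, not part of the repo's API.

import torch

def init_avg_param(model):
    # Detached copies of the current weights form the EMA state.
    return [p.data.clone() for p in model.parameters()]

@torch.no_grad()
def update_avg_param(model, avg_param, beta=0.999):
    # avg <- beta * avg + (1 - beta) * p, matching the in-loop update above.
    for p, avg_p in zip(model.parameters(), avg_param):
        avg_p.mul_(beta).add_(p.data, alpha=1.0 - beta)
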
Example No. 16
def train(opt):
    # Handle feature-related options before anything else
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Start self-critical training once the threshold epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # epoch += 1
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        # if (iteration % opt.save_checkpoint_every == 0):
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats  # full stats dict for the log; replaced by CIDEr below
                with open('train_log_%s.txt' % opt.id, 'a') as f:
                    f.write(
                        'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}\n'
                        .format(epoch, str(datetime.now()),
                                str(total_loss / times), str(val_loss),
                                str(current_score)))
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save a checkpoint; best_flag marks improvements
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                # print(str(infos['best_val_score']))
                print("model saved to {}".format(checkpoint_path))
                if opt.save_history_ckpt:
                    checkpoint_path = os.path.join(
                        opt.checkpoint_path, 'model-%d.pth' % (iteration))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    utils.pickle_dump(infos, f)
                if opt.save_history_ckpt:
                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'infos_' + opt.id + '-%d.pkl' % (iteration)),
                            'wb') as f:
                        utils.pickle_dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    utils.pickle_dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
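
In the self-critical branch above, rl_crit implements a REINFORCE-style sequence loss: each sampled word's log-probability is weighted by its reward (typically CIDEr of the sample minus a greedy-decoding baseline) and masked past the end of the sequence. A sketch consistent with how the criterion is called above, though illustrative rather than the repo's exact code:

import torch

def scst_loss(sample_logprobs, gen_result, reward):
    # sample_logprobs: (B, T) log-probs of the sampled words
    # gen_result:      (B, T) sampled word ids (0 marks padding after EOS)
    # reward:          (B, T) per-word advantage, e.g. CIDEr(sample) - CIDEr(greedy)
    mask = (gen_result > 0).float()
    # Shift the mask right by one step so the EOS position is also rewarded.
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
    output = -sample_logprobs * reward * mask
    return output.sum() / mask.sum()
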
Example No. 17
def train(opt):
    acc_steps = getattr(opt, 'acc_steps', 1)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.ix_to_word = loader.ix_to_word

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # create checkpoint_path if it doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            sys.stdout.flush()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)
                    print('Learning Rate: ', opt.current_lr)
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                epoch_done = False

            data = loader.get_batch('train')
            if (iteration % acc_steps == 0):
                optimizer.zero_grad()

            torch.cuda.synchronize()
            start = time.time()
            tmp = [data['fc_feats'], data['att_feats'], data['c3d_feats'], data['labels'], data['masks'], data['att_masks'], data['c3d_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks = tmp

            model_out = dp_lw_model(fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()
            loss_sp = loss / acc_steps

            loss_sp.backward()
            if ((iteration + 1) % acc_steps == 0):
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
            torch.cuda.synchronize()
            train_loss = loss.item()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), reward1 = {:.3f}, reward2 = {:.3f}, reward3 = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, model_out['reward_layer1'].mean(), model_out['reward_layer2'].mean(), model_out['reward_layer3'].mean(), train_loss, end - start))

            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'reward1', model_out['reward_layer1'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward2', model_out['reward_layer2'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward3', model_out['reward_layer3'].mean(), iteration)

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, lw_model.crit, loader, eval_kwargs)
                print('Summary Epoch {} Iteration {}: CIDEr: {} BLEU-4: {}'.format(epoch, iteration, lang_stats['CIDEr'], lang_stats['Bleu_4']))

                if opt.reduce_on_plateau:
                    if opt.reward_metric == 'cider':
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    elif opt.reward_metric == 'bleu':
                        optimizer.scheduler_step(-lang_stats['Bleu_4'])
                    elif opt.reward_metric == 'meteor':
                        optimizer.scheduler_step(-lang_stats['METEOR'])
                    elif opt.reward_metric == 'rouge':
                        optimizer.scheduler_step(-lang_stats['ROUGE_L'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save the model if it is improving on the validation result
                if opt.language_eval == 1:
                    if opt.reward_metric == 'cider':
                        current_score = lang_stats['CIDEr']
                    elif opt.reward_metric == 'bleu':
                        current_score = lang_stats['Bleu_4']
                    elif opt.reward_metric == 'meteor':
                        current_score = lang_stats['METEOR']
                    elif opt.reward_metric == 'rouge':
                        current_score = lang_stats['ROUGE_L']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
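
Example No. 17 above adds gradient accumulation via acc_steps: each micro-batch loss is divided by acc_steps before backward(), and the optimizer steps once every acc_steps iterations. A hedged sketch of that pattern in isolation; compute_loss and batches are placeholders, not the repo's API:

def train_with_accumulation(model, optimizer, batches, compute_loss, acc_steps=4):
    # Accumulate gradients over acc_steps micro-batches before each optimizer step.
    optimizer.zero_grad()
    for i, batch in enumerate(batches):
        loss = compute_loss(model, batch) / acc_steps  # keep gradient magnitude comparable
        loss.backward()
        if (i + 1) % acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
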
Example No. 18
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')):
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
        with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))
    
    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each device.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {split: {'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split]} for split in ['train', 'val', 'test']}
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch; always True at the start
    # so the learning rate etc. get initialized.
    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Start self-critical training once the threshold epoch is reached
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                
                # Start structure-loss training once the threshold epoch is reached
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False
                    
            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp
            
            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()
            
            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save the model if it is improving on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(opt, model, infos, optimizer, append=str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt, model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
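
The scheduled-sampling probability shared by these examples ramps up linearly with the epoch and is capped at a maximum. A small illustrative sketch of the schedule; it returns 0.0 before the schedule starts, which matches the usual default for ss_prob:

def scheduled_sampling_prob(epoch, start, increase_every, increase_prob, max_prob):
    # Linear ramp in steps of increase_prob, capped at max_prob.
    if start >= 0 and epoch > start:
        frac = (epoch - start) // increase_every
        return min(increase_prob * frac, max_prob)
    return 0.0
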
Example No. 19
def train(opt):
    print(opt)

    # To reproduce training results
    init_seed()
    # Image Preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
                             ])
    # Handle feature-related options before anything else
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt, transform=transform)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[
                    checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    opt.vocab = loader.get_vocab()
    if torch.cuda.is_available():
        model = models.setup(opt).cuda()
    else:
        model = models.setup(opt)
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    #fgm = FGM(model)

    cnn_model = ResnetBackbone()
    if torch.cuda.is_available():
        cnn_model = cnn_model.cuda()
    if opt.start_from is not None:
        model_dict = cnn_model.state_dict()
        predict_dict = torch.load(os.path.join(opt.start_from, 'cnn_model-best.pth'))
        model_dict = {k: predict_dict["module." + k] for k, _ in model_dict.items() if "module." + k in predict_dict}
        cnn_model.load_state_dict(model_dict)
    cnn_model = torch.nn.DataParallel(cnn_model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer-best.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth')))

    def save_checkpoint(model, cnn_model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # create checkpoint_path if it doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        # Transformer model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        # CNN model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'cnn_model%s.pth' % (append))
        if not os.path.exists(checkpoint_path):
            torch.save(cnn_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    cnn_after = 3
    try:
        while True:
            if epoch_done:
                if opt.fix_cnn or epoch < cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                    cnn_optimizer = None
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model._modules['module']._modules['resnet_conv'][:5]._modules.values():
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()
                    # Construct the CNN parameter group for optimization; only higher layers are fine-tuned
                    cnn_optimizer = torch.optim.Adam(
                        (filter(lambda p: p.requires_grad, cnn_model.parameters())),
                        lr=2e-6 if (opt.self_critical_after != -1 and epoch >= opt.self_critical_after) else 5e-5, betas=(0.8, 0.999))

                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Start self-critical training once the threshold epoch is reached
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            if iteration % opt.losses_log_every == 0:
                print('Read data:', time.time() - start)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.time()

            if torch.cuda.is_available():
                data['att_feats'] = cnn_model(data['att_feats'].cuda())
            else:
                data['att_feats'] = cnn_model(data['att_feats'])
            data['att_feats'] = repeat_feat(data['att_feats'])
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            if torch.cuda.is_available():
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            if cnn_optimizer is not None:
                cnn_optimizer.zero_grad()

            # if epoch >= cnn_after:
            #     att_feats.register_hook(save_grad("att_feats"))
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()

            #loss.backward(retain_graph=True)

            # adversarial training
            #fgm.attack(emb_name='model.tgt_embed.0.lut.weight')
            #adv_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
            #                      torch.arange(0, len(data['gts'])), sc_flag)

            #adv_loss = adv_out['loss'].mean()
            #adv_loss.backward()
            #fgm.restore(emb_name="model.tgt_embed.0.lut.weight")


            # utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if cnn_optimizer is not None:
                cnn_optimizer.step()
            train_loss = loss.item()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end = time.time()
            if not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            elif iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                eval_kwargs["cnn_model"] = cnn_model
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save the model if it is improving on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
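
Example No. 19 above fine-tunes the CNN backbone in stages: fully frozen for the first few epochs, then unfrozen except for the earliest blocks, which stay fixed. A hedged sketch of that freezing logic; the flat children() layout assumed in set_cnn_trainable is for illustration only:

def set_cnn_trainable(cnn_model, finetune, n_frozen_blocks=5):
    # Freeze or unfreeze the backbone; the earliest blocks stay frozen even when fine-tuning.
    for p in cnn_model.parameters():
        p.requires_grad = finetune
    if finetune:
        for block in list(cnn_model.children())[:n_frozen_blocks]:
            for p in block.parameters():
                p.requires_grad = False
    cnn_model.train(mode=finetune)
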
Example No. 20
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = model

    target_actor = models.setup(opt).cuda()

    ####################### Critic pretrain #####################################################################
    ##### Critic with state as input
    # if opt.critic_model == 'state_critic':
    #     critic_model = CriticModel(opt)
    # else:
    critic_model = AttCriticModel(opt)
    target_critic = AttCriticModel(opt)
    if vars(opt).get('start_from_critic', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from_critic
                             ), "%s must be a path" % opt.start_from_critic
        print(
            os.path.join(opt.start_from_critic,
                         opt.critic_model + '_model.pth'))
        critic_model.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
        target_critic.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
    critic_model = critic_model.cuda()
    target_critic = target_critic.cuda()
    critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt)
    dp_model.eval()
    critic_iter = 0
    init_scorer(opt.cached_tokens)
    critic_model.train()
    error_sum = 0
    loss_vector_sum = 0
    while opt.pretrain_critic == 1:
        if critic_iter > opt.pretrain_critic_steps:
            print('****************Finished critic training!')
            break
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        critic_model.train()
        critic_optimizer.zero_grad()
        assert opt.critic_model == 'att_critic_vocab'
        # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data)
        crit_loss, reward, std = target_critic_loss_fun_mask(
            fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data,
            target_critic, target_actor)
        crit_loss.backward()
        critic_optimizer.step()
        # soft (Polyak) update of the target critic
        for cp, tp in zip(critic_model.parameters(),
                          target_critic.parameters()):
            tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
        crit_train_loss = crit_loss.item()
        torch.cuda.synchronize()
        end = time.time()
        error_sum += crit_train_loss**0.5 - std
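        # diagnostic: running sum of (critic RMSE - reward std), printed below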
        if (critic_iter % opt.losses_log_every == 0):
            print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \
                .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start))
            print(opt.checkpoint_path)
            opt.importance_sampling = 1
            critic_model.eval()
            _, _, _, _ = get_rf_loss(dp_model,
                                     fc_feats,
                                     att_feats,
                                     att_masks,
                                     data,
                                     opt,
                                     loader,
                                     critic_model,
                                     test_critic=True)

        critic_iter += 1

        # make evaluation on validation set, and save model
        if (critic_iter % opt.save_checkpoint_every == 0):
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           opt.critic_model + '_model.pth')
            torch.save(critic_model.state_dict(), checkpoint_path)

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
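            # staircase decay:
            #   current_lr = learning_rate * decay_rate ** ((epoch - decay_start) // decay_every)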
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
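            # linear ramp, capped:
            #   ss_prob = min(increase_prob * ((epoch - ss_start) // increase_every), max_prob)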
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
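        # only use the first 5000 items of the train split; reset once past them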
        if data['bounds']['it_pos_now'] > 5000:
            loader.reset_iterator('train')
            continue
        dp_model.train()
        critic_model.eval()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
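                # SCST: reward = score(sampled caption) - score(greedy caption);
                # rl_crit then computes -reward * logprob over generated tokens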
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
            elif opt.rl_type == 'arsm':
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader)
                #print(loss)
                reward = np.zeros([2, 2])
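                # placeholder so the shared logging below can still call np.mean(reward)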
            elif opt.rl_type == 'rf4':
                loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats,
                                            att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'importance_sampling':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
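                # sequence-level reward broadcast to every time step of gen_result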
            elif opt.rl_type == 'importance_sampling_critic':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    target_actor, fc_feats, att_feats, att_masks, data, opt,
                    loader, target_critic)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks,
                                   data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:  # 'and False' disables this branch
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda * arm_baseline)
            elif opt.rl_type == 'arsm_baseline_critic':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader, critic_model)
                reward, std = get_reward(data, gen_result, opt, critic=True)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - arm_baseline)
            elif opt.rl_type == 'arsm_critic':
                #print(opt.critic_model)
                tic = time.time()
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader, critic_model)
                #print('arm_loss time', str(time.time()-tic))
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'critic_vocab_sum':
                assert opt.critic_model == 'att_critic_vocab'
                tic = time.time()
                gen_result, sample_logprobs_total = dp_model(
                    fc_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    total_probs=True,
                    mode='sample')  #batch, seq, vocab
                #print('generation time', time.time()-tic)
                gen_result_pad = torch.cat([
                    gen_result.new_zeros(
                        gen_result.size(0), 1, dtype=torch.long), gen_result
                ], 1)
                tic = time.time()
                critic_value = critic_model(gen_result_pad, fc_feats,
                                            att_feats, True, opt,
                                            att_masks)  #batch, seq, vocab
                #print('critic time', time.time() - tic)
                probs = torch.sum(
                    F.softmax(sample_logprobs_total, 2) *
                    critic_value.detach(), 2)
                mask = (gen_result > 0).float()
                mask = torch.cat(
                    [mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
                loss = -torch.sum(probs * mask) / torch.sum(mask)
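                # expected critic value under the current policy:
                #   probs[b, t] = sum_w softmax(logits)[b, t, w] * Q[b, t, w],
                # masked-averaged over generated positions and negated to ascend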
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    critic_value, gen_result, sample_logprobs = critic_model(
                        dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value[:, :-1].data)
                elif opt.critic_model == 'att_critic':
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    gen_result_pad = torch.cat([
                        gen_result.new_zeros(gen_result.size(0),
                                             1,
                                             dtype=torch.long), gen_result
                    ], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats,
                                                att_feats, True, opt,
                                                att_masks).squeeze(2)

                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value.data)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(
                dp_model(fc_feats, att_feats, labels, att_masks),
                labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
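        # EMA moment estimates of the flattened gradient (decay 0.9999);
        # variance ~ mean |E[g^2] - (E[g])^2|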
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order -
                                        first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        #### update the critic
        if 'critic' in opt.rl_type:
            dp_model.eval()
            critic_model.train()
            utils.set_lr(critic_optimizer, opt.critic_learning_rate)
            critic_optimizer.zero_grad()
            assert opt.critic_model == 'att_critic_vocab'
            crit_loss, reward, std = target_critic_loss_fun_mask(
                fc_feats,
                att_feats,
                att_masks,
                dp_model,
                critic_model,
                opt,
                data,
                target_critic,
                target_actor,
                gen_result=gen_result,
                sample_logprobs_total=sample_logprobs_total,
                reward=reward)
            crit_loss.backward()
            critic_optimizer.step()
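            # Polyak-average both target networks toward their online copies:
            #   target <- target + gamma * (online - target)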
            for cp, tp in zip(critic_model.parameters(),
                              target_critic.parameters()):
                tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
            for cp, tp in zip(dp_model.parameters(),
                              target_actor.parameters()):
                tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data)
            crit_train_loss = crit_loss.item()
            error_sum += crit_train_loss**0.5 - std
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            elif 'critic' in opt.rl_type:
                print(
                    "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \
                    .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start))
                print(opt.checkpoint_path)
                critic_model.eval()
                _, _, _, _ = get_rf_loss(dp_model,
                                         fc_feats,
                                         att_feats,
                                         att_masks,
                                         data,
                                         opt,
                                         loader,
                                         critic_model,
                                         test_critic=True)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance,
                                  iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward)
            critic_loss_history[
                iteration] = crit_train_loss if 'critic' in opt.rl_type else 0
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               opt.critic_model + '_model.pth')
                torch.save(critic_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time_history'] = time_history
                # histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
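
# A minimal, hypothetical driver for the train() function above; the original
# listing omits its __main__ block, and opts.parse_opt() is assumed here,
# following the conventions of the captioning codebases these examples use.
if __name__ == '__main__':
    opt = opts.parse_opt()
    train(opt)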
Example No. 21
0
if len(opt.input_json) == 0:
    opt.input_json = infos['opt'].input_json
if opt.batch_size == 0:
    opt.batch_size = infos['opt'].batch_size
if len(opt.id) == 0:
    opt.id = infos['opt'].id
ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "block_trigrams",'att_supervise','ground_reward_weight','att_sup_crit']

for k in vars(infos['opt']).keys():
    if k not in ignore:
        if k in vars(opt):
            assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model

vocab = infos['vocab'] # ix -> word mapping
init_scorer(opt.cached_tokens)
# Setup the model
model = models.setup(opt)
model.load_state_dict(torch.load(opt.model))
model.cuda()
model.eval()
crit = utils.LanguageModelCriterion()

# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({'folder_path': opt.image_folder,
                            'coco_json': opt.coco_json,
                            'batch_size': opt.batch_size,
                            'cnn_model': opt.cnn_model})
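
# A hedged sketch (not part of the original fragment) of how such an eval
# script typically continues once the loader is built; eval_utils.eval_split
# is the same helper the training examples above use, and the 'split' value
# here is an assumption.
loader.ix_to_word = infos['vocab']
eval_kwargs = vars(opt).copy()
eval_kwargs.update({'split': 'test', 'dataset': opt.input_json})
loss, split_predictions, lang_stats = eval_utils.eval_split(
    model, crit, loader, eval_kwargs)
print('loss: ', loss)
if lang_stats is not None:
    print(lang_stats)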
Example No. 22
0
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        data_time = time.time() - start

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_freq == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start, data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start, data_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path) # MODIFIED (ADDED)

            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True:
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best-i{}-score{}.pth'.format(iteration, best_val_score))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example No. 23
0
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check if models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    multi_models_list = []
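    # the first loop below builds the N student models; the second builds N EMA
    # "teacher" copies whose parameters are detached and synced to the students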
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
    for order in range(opt.number_of_models, 2 * opt.number_of_models):
        for param in multi_models_list[order].parameters():
            param.detach_()
    for order in range(opt.number_of_models):
        for param, param_ema in zip(
                multi_models_list[order].parameters(),
                multi_models_list[order + opt.number_of_models].parameters()):
            param_ema.data = param.data.clone()
    # multi_models = MultiModels(multi_models_list)
    # multi_models_list.append(SenEncodeModel(opt).cuda())
    multi_models = nn.ModuleList(multi_models_list)
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        multi_models.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation models with their loss functions (used for training).
    # This allows the loss to be computed separately on each device.
    lw_models = nn.ModuleList([
        LossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models = nn.ModuleList([
        KDLossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    lw_models_ema = nn.ModuleList([
        LossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models_ema = nn.ModuleList([
        KDLossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    # Wrap with dataparallel
    dp_models = nn.ModuleList([
        torch.nn.DataParallel(multi_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models = nn.ModuleList([
        torch.nn.DataParallel(lw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_models_ema = nn.ModuleList([
        torch.nn.DataParallel(multi_models[opt.number_of_models + index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(lw_models_ema[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models_ema[index])
        for index in range(opt.number_of_models)
    ])

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(multi_models,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    ##########################
    #  Build loss
    ##########################
    # triplet_loss = nn.TripletMarginLoss()

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in [
                'paired_train', 'unpaired_images_train',
                'unpaired_captions_train', 'train', 'val', 'test'
            ]
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # flag indicating the end of an epoch
    # always set to True at the start so the lr, ss_prob, etc. are initialized
    epoch_done = True
    # Ensure the models are in training mode
    dp_lw_models.train()
    dp_kdlw_models.train()
    dp_lw_models_ema.train()
    dp_kdlw_models_ema.train()

    # Build the ensemble model
    # # Setup the model
    model_ensemble = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                   opt.number_of_models],
                                 weights=None)
    # model_ensemble.seq_length = 20
    model_ensemble.cuda()
    # model_ensemble.eval()
    kd_model_outs_list = []

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    for index in range(opt.number_of_models):
                        multi_models[index].ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                if epoch >= opt.paired_train_epoch:
                    opt.current_lambda_x = opt.hyper_parameter_lambda_x * \
                                         (epoch - (opt.paired_train_epoch - 1)) /\
                                         (opt.max_epochs - opt.paired_train_epoch)
                    opt.current_lambda_y = opt.hyper_parameter_lambda_y * \
                                           (epoch - (opt.paired_train_epoch - 1)) / \
                                           (opt.max_epochs - opt.paired_train_epoch)

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            if epoch < opt.language_pretrain_epoch:
                data = loader.get_batch('unpaired_captions_train')
            elif epoch < opt.paired_train_epoch:
                data = loader.get_batch('paired_train')
            else:
                data = loader.get_batch('paired_train')
                unpaired_data = loader.get_batch('unpaired_images_train')
                unpaired_caption = loader.get_batch('unpaired_captions_train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            if epoch < opt.language_pretrain_epoch:
                tmp = [
                    data['fc_feats'] * 0, data['att_feats'] * 0,
                    data['labels'], data['masks'], data['att_masks']
                ]
            elif epoch < opt.paired_train_epoch:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
            else:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                unpaired_tmp = [
                    unpaired_data['fc_feats'], unpaired_data['att_feats'],
                    unpaired_data['labels'], unpaired_data['masks'],
                    unpaired_data['att_masks']
                ]
                unpaired_caption_tmp = [
                    unpaired_caption['fc_feats'] * 0,
                    unpaired_caption['att_feats'] * 0,
                    unpaired_caption['labels'], unpaired_caption['masks'],
                    unpaired_caption['att_masks']
                ]

            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            if epoch >= opt.paired_train_epoch:
                unpaired_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_tmp
                ]
                unpaired_fc_feats, unpaired_att_feats, unpaired_labels, unpaired_masks, unpaired_att_masks = unpaired_tmp
                unpaired_caption_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_caption_tmp
                ]
                unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks = unpaired_caption_tmp
                unpaired_caption_fc_feats = unpaired_caption_fc_feats.repeat(
                    5, 1)
                unpaired_caption_fc_feats = opt.std_pseudo_visual_feature * torch.randn_like(
                    unpaired_caption_fc_feats)
                unpaired_caption_att_feats = unpaired_caption_att_feats.repeat(
                    5, 1, 1)
                unpaired_caption_fc_feats.requires_grad = True
                unpaired_caption_att_feats.requires_grad = True
                unpaired_caption_labels = unpaired_caption_labels.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)
                unpaired_caption_masks = unpaired_caption_masks.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)
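            # pseudo visual features for the unpaired captions: the fc features
            # are repeated 5x and re-drawn as scaled Gaussian noise, then marked
            # trainable so the inner loop below can fit them to the captions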

            optimizer.zero_grad()
            if epoch < opt.language_pretrain_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            elif epoch < opt.paired_train_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            else:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()
                loss = language_loss

                # else:
                # for unpaired image sentences
                # # Setup the model
                # model_ensemble = AttEnsemble(multi_models_list[:opt.number_of_models], weights=None)
                # model_ensemble.seq_length = 16
                # model_ensemble.cuda()
                # model_ensemble.eval()

                model_ensemble.eval()
                eval_kwargs = dict()
                eval_kwargs.update(vars(opt))

                with torch.no_grad():
                    seq, seq_logprobs = model_ensemble(unpaired_fc_feats,
                                                       unpaired_att_feats,
                                                       unpaired_att_masks,
                                                       opt=eval_kwargs,
                                                       mode='sample')
                    # val_loss, predictions, lang_stats = eval_utils.eval_split(model_ensemble, lw_models[0].crit, loader,
                    #                                                           eval_kwargs)
                # print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #                  model_ensemble.done_beams[0]]))
                # print('++' * 10)
                # for ii in range(10):
                #     sents = utils.decode_sequence(loader.get_vocab(), seq[ii].unsqueeze(0))
                #     gt_sent = utils.decode_sequence(loader.get_vocab(), labels[ii,0].unsqueeze(0))
                #     a=1

                model_ensemble.train()

                model_ensemble_sudo_labels = labels.new_zeros(
                    (opt.batch_size, opt.beam_size,
                     eval_kwargs['max_length'] + 2))
                model_ensemble_sudo_log_prob = masks.new_zeros(
                    (opt.batch_size,
                     opt.beam_size, eval_kwargs['max_length'] + 2,
                     len(loader.get_vocab()) + 1))
                model_ensemble_sum_log_prob = masks.new_zeros(
                    (opt.batch_size, opt.beam_size))
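                # harvest every beam's token ids and per-step log-probs as
                # pseudo labels; 'p' is the beam's accumulated log-probability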

                for batch_index in range(opt.batch_size):
                    for beam_index in range(opt.beam_size):
                        # for beam_index in range(3):
                        pred = model_ensemble.done_beams[batch_index][
                            beam_index]['seq']
                        log_prob = model_ensemble.done_beams[batch_index][
                            beam_index]['logps']
                        model_ensemble_sudo_labels[batch_index, beam_index,
                                                   1:pred.shape[0] + 1] = pred
                        model_ensemble_sudo_log_prob[batch_index, beam_index,
                                                     1:pred.shape[0] +
                                                     1] = log_prob
                        model_ensemble_sum_log_prob[batch_index][
                            beam_index] = model_ensemble.done_beams[
                                batch_index][beam_index]['p']

                # model_ensemble_prob = F.softmax(model_ensemble_sum_log_prob)

                data_ensemble_sudo_gts = list()
                for data_ensemble_sudo_gts_index in range(
                        model_ensemble_sudo_labels.shape[0]):
                    data_ensemble_sudo_gts.append(model_ensemble_sudo_labels[
                        data_ensemble_sudo_gts_index, :,
                        1:-1].data.cpu().numpy())

                # generated_sentences = list()
                # for i in range(unpaired_fc_feats.shape[0]):
                #     generated_sentences.append(
                #         [utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #          model_ensemble.done_beams[i]])
                #
                # pos_tag_results = list()
                # for i in range(unpaired_fc_feats.shape[0]):
                #     generated_sentences_i = generated_sentences[i]
                #     pos_tag_results_i = []
                #     for text in generated_sentences_i:
                #         text_tokenize = nltk.word_tokenize(text)
                #         pos_tag_results_i_jbeam = []
                #         for vob, vob_type in nltk.pos_tag(text_tokenize):
                #             if vob_type == 'NN' or vob_type == 'NNS':
                #                 pos_tag_results_i_jbeam.append(vob)
                #         pos_tag_results_i.append(pos_tag_results_i_jbeam)
                #     pos_tag_results.append(pos_tag_results_i)

                # for i in range(fc_feats.shape[0]):
                #     print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #                      model_ensemble.done_beams[i]]))
                #     print('--' * 10)
                # dets = data['dets']
                #
                # promising_flag = labels.new_zeros(opt.batch_size, opt.beam_size)
                # for batch_index in range(opt.batch_size):
                #     dets_batch = dets[batch_index]
                #     for beam_index in range(opt.beam_size):
                #         indicator = [0] * len(dets_batch)
                #         pos_tag_batch_beam = pos_tag_results[batch_index][beam_index]
                #         for pos_tag_val in pos_tag_batch_beam:
                #             for ii in range(len(dets_batch)):
                #                 possible_list = vob_transform_list[dets_batch[ii]]
                #                 if pos_tag_val in possible_list:
                #                     indicator[ii] = 1
                #         if sum(indicator) == len(dets_batch) or sum(indicator) >= 2:
                #             promising_flag[batch_index, beam_index] = 1
                #
                # # model_ensemble_sudo_log_prob = model_ensemble_sudo_log_prob * promising_flag.unsqueeze(-1).unsqueeze(-1)
                # model_ensemble_sudo_labels = model_ensemble_sudo_labels * promising_flag.unsqueeze(-1)

                #sudo_masks_for_model = sudo_masks_for_model.detach()
                distilling_loss = 0
                # We use the random study mechanism
                who_to_study = random.randint(0, opt.number_of_models - 1)
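                # distillation: the chosen student learns from the EMA
                # ensemble's beam-search pseudo labels assembled above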

                # for index in range(opt.number_of_models):
                #     model_out = dp_kdlw_models[index](unpaired_fc_feats, unpaired_att_feats, model_ensemble_sudo_labels,
                #                                     model_ensemble_sudo_log_prob, att_masks, data_ensemble_sudo_gts,
                #                                     torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag,
                #                                     struc_flag, model_ensemble_sum_log_prob)
                #     kd_model_outs_list.append(model_out)

                model_out = dp_kdlw_models[who_to_study](
                    unpaired_fc_feats, unpaired_att_feats,
                    model_ensemble_sudo_labels, model_ensemble_sudo_log_prob,
                    att_masks, data_ensemble_sudo_gts,
                    torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag,
                    struc_flag, model_ensemble_sum_log_prob)
                # kd_model_outs_list.append(model_out)
                distilling_loss += model_out['loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_x * distilling_loss

                ###################################################################
                # use unlabelled captions
                # simple_sgd = utils.gradient_descent(unpaired_caption_fc_feats, stepsize=1e3)
                simple_sgd = utils.gradient_descent_adagrad(
                    unpaired_caption_fc_feats, stepsize=1)
                gts_tmp = unpaired_caption['gts']
                new_gts = []
                for ii in range(len(data['gts'])):
                    for jj in range(gts_tmp[ii].shape[0]):
                        new_gts.append(gts_tmp[ii][jj])
                unpaired_caption['gts'] = new_gts
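                # Inner loop: treat the unpaired caption features as free variables,
                # backpropagate the captioning loss of the EMA models into them, and
                # refine them with the adagrad-style feature optimizer built above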
                for itr in range(opt.inner_iteration):
                    unlabelled_caption_model_out = dp_lw_models_ema[
                        itr % opt.number_of_models](
                            unpaired_caption_fc_feats,
                            unpaired_caption_att_feats,
                            unpaired_caption_labels, unpaired_caption_masks,
                            unpaired_caption_att_masks,
                            unpaired_caption['gts'],
                            torch.arange(0, len(unpaired_caption['gts'])),
                            sc_flag, struc_flag)
                    unlabelled_caption_loss = unlabelled_caption_model_out[
                        'loss'].mean()
                    unlabelled_caption_loss.backward()
                    # print(unlabelled_caption_loss)
                    simple_sgd.update(unpaired_caption_fc_feats)

                unpaired_caption_fc_feats.requires_grad = False
                unpaired_caption_att_feats.requires_grad = False
                unlabelled_caption_model_out = dp_lw_models[who_to_study](
                    unpaired_caption_fc_feats, unpaired_caption_att_feats,
                    unpaired_caption_labels, unpaired_caption_masks,
                    unpaired_caption_att_masks, unpaired_caption['gts'],
                    torch.arange(0, len(unpaired_caption['gts'])), sc_flag,
                    struc_flag)
                unlabelled_caption_loss = unlabelled_caption_model_out[
                    'loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_y * unlabelled_caption_loss

            loss.backward()
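            # getattr dispatches to torch.nn.utils.clip_grad_norm_ or
            # clip_grad_value_, depending on opt.grad_clip_mode ('norm' or 'value')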
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(multi_models.parameters(),
                                              opt.grad_clip_value)
            optimizer.step()

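            # Mean-teacher style EMA update: the second half of multi_models_list
            # holds the averaged copies, tracked with momentum opt.alpha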
            for order in range(opt.number_of_models):
                for param, param_ema in zip(
                        multi_models_list[order].parameters(),
                        multi_models_list[order +
                                          opt.number_of_models].parameters()):
                    param_ema.data = opt.alpha * param_ema.data + (
                        1 - opt.alpha) * param.data

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            # if struc_flag:
            #     print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            # elif not sc_flag:
            #     print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, train_loss, end - start))
            # else:
            #     print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, model_out['reward'].mean(), end - start))
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss/opt.number_of_models, sum([model_outs_list[index]['lm_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            sum([model_outs_list[index]['struc_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, language_loss.item()/opt.number_of_models, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, sum([model_outs_list[index]['reward'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, end - start))

            # Update the iteration and epoch
            iteration += 1
            # Both training phases advance the epoch identically once the split wraps
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                # tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
                for index in range(opt.number_of_models):
                    model_id = 'model_{}'.format(index)
                    tb_summary_writer.add_scalars('language_loss', {
                        model_id:
                        model_outs_list[index]['loss'].mean().item()
                    }, iteration)
                if epoch >= opt.paired_train_epoch:
                    # for index in range(opt.number_of_models):
                    #     model_id = 'model_{}'.format(index)
                    #     kd_model_outs_val = 0 if len(kd_model_outs_list) == 0 else kd_model_outs_list[index]['loss'].mean().item()
                    #     tb_summary_writer.add_scalars('distilling_loss',
                    #                                   {model_id: kd_model_outs_val},
                    #                                   iteration)
                    tb_summary_writer.add_scalar('distilling_loss',
                                                 distilling_loss.item(),
                                                 iteration)
                    tb_summary_writer.add_scalar(
                        'unlabelled_caption_loss',
                        unlabelled_caption_loss.item(), iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_x',
                                                 opt.current_lambda_x,
                                                 iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_y',
                                                 opt.current_lambda_y,
                                                 iteration)
                # tb_summary_writer.add_scalar('triplet_loss', triplet_loss_val.item(), iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             multi_models[0].ss_prob,
                                             iteration)
                if sc_flag:
                    for index in range(opt.number_of_models):
                        # tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration)
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'avg_reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                elif struc_flag:
                    # tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('reward_var', model_out['reward'].var(1).mean(), iteration)
                    for index in range(opt.number_of_models):
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'lm_loss', {
                                model_id:
                                model_outs_list[index]
                                ['lm_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'struc_loss', {
                                model_id:
                                model_outs_list[index]
                                ['struc_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward_var', {
                                model_id:
                                model_outs_list[index]['reward'].var(1).mean()
                            }, iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else sum([
                        model_outs_list[index]['reward'].mean().item()
                        for index in range(opt.number_of_models)
                    ]) / opt.number_of_models
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = multi_models[
                    0].ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch and epoch >= opt.paired_train_epoch) or \
                (epoch_done and opt.save_every_epoch and epoch >= opt.paired_train_epoch):
                # load ensemble
                # Setup the model
                model = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                      opt.number_of_models],
                                    weights=None)
                model.seq_length = opt.max_length
                model.cuda()
                model.eval()
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                # eval_kwargs['beam_size'] = 5
                # eval_kwargs['verbose_beam'] = 1
                # eval_kwargs['verbose_loss'] = 1
                # val_loss, predictions, lang_stats = eval_utils.eval_split(
                #     dp_model, lw_model.crit, loader, eval_kwargs)
                with torch.no_grad():
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        model, lw_models[0].crit, loader, eval_kwargs)
                model.train()

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, multi_models, infos, optimizer,
                                      histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        multi_models,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          multi_models,
                                          infos,
                                          optimizer,
                                          append='best')

            # if epoch_done and epoch == opt.paired_train_epoch:
            #     utils.save_checkpoint(opt, multi_models, infos, optimizer, histories)
            #     if opt.save_history_ckpt:
            #         utils.save_checkpoint(opt, multi_models, infos, optimizer,
            #                               append=str(epoch) if opt.save_every_epoch else str(iteration))
            #     cmd = 'cp -r ' + 'log_' + opt.id + ' ' + 'log_' + opt.id + '_backup'
            #     os.system(cmd)

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, multi_models, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
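Note: the per-parameter update after optimizer.step() above is a standard exponential-moving-average (mean-teacher) rule. Below is a minimal, self-contained sketch of the same pattern; the function and variable names are illustrative, not part of this repository:

import copy
import torch
import torch.nn as nn

def ema_update(student: nn.Module, teacher: nn.Module, alpha: float):
    # teacher <- alpha * teacher + (1 - alpha) * student, as in the loop above
    with torch.no_grad():
        for p, p_ema in zip(student.parameters(), teacher.parameters()):
            p_ema.mul_(alpha).add_(p, alpha=1.0 - alpha)

student = nn.Linear(4, 4)
teacher = copy.deepcopy(student)  # the EMA copy starts identical to the student
# ... after each optimizer.step() on the student:
ema_update(student, teacher, alpha=0.99)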
Example No. 24
0
def train(opt):
    logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log'))
    print = logger.info  # route all print() calls in this function through the logger

    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Print out the option variables
    print("*" * 20)
    for k, v in opt.__dict__.items():
        print("%r: %r" % (k, v))
    print("*" * 20)

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f:
            infos = json.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    start_time = time.time()
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if 0 <= opt.learning_rate_decay_start < epoch:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
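                # overall: current_lr = learning_rate * decay_rate ** floor((epoch - start) / every)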
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if 0 <= opt.scheduled_sampling_start < epoch:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer()
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        batch_data = loader.get_batch('train')
        torch.cuda.synchronize()

        tmp = [
            batch_data['fc_feats'], batch_data['att_feats'],
            batch_data['labels'], batch_data['masks'], batch_data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            outputs = dp_model(fc_feats, att_feats, labels, att_masks)
            loss = crit(outputs, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, batch_data,
                                              gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        if batch_data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print train loss or avg reward
        if iteration % opt.losses_print_every == 0:
            if not sc_flag:
                print(
                    "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format(
                        iteration, epoch, loss.item(),
                        time.time() - start_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time = {:.3f}".
                      format(iteration, epoch, np.mean(reward[:, 0]),
                             time.time() - start_time))
            start_time = time.time()

        # make evaluation on validation set, and save model
        if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\
                or (opt.save_checkpoint_every <= 0 and update_lr_flag):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.simple_eval_split(
                dp_model, loader, eval_kwargs)

            # Save model if is improving on validation result
            if not os.path.exists(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = vars(opt)
            infos['vocab'] = loader.get_vocab()

            with open(os.path.join(opt.checkpoint_path, 'infos.json'),
                      'w') as f:
                json.dump(infos, f, sort_keys=True, indent=4)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.json'),
                          'w') as f:
                    json.dump(infos, f, sort_keys=True, indent=4)

            # Stop if reaching max epochs
            if opt.max_epochs != -1 and epoch >= opt.max_epochs:
                break
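Note: in the self-critical branch above, rl_crit weights each sampled token's log-probability by its reward. Below is a minimal sketch of such a reward criterion, assuming chosen-token log-probabilities of shape (batch, time), the sampled sequence, and a per-token reward tensor; the actual utils.RewardCriterion may differ in detail:

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    def forward(self, logprobs, seq, reward):
        # keep tokens up to and including the first end-of-sequence token (0)
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        # REINFORCE-style objective: maximize reward-weighted log-likelihood
        return -(logprobs * reward * mask).sum() / mask.sum()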
Example No. 25
0
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.fs_index = loader.fs_index

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(
                    opt.checkpoint_path, 'infos_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            #need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'histories_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')):
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(opt.start_from), '04') + '.pkl')) as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    # Loss function
    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(
                opt.checkpoint_path, 'optimizer' + opt.id +
                format(int(opt.start_from), '04') + '.pth')):
        optimizer.load_state_dict(
            torch.load(
                os.path.join(
                    opt.checkpoint_path, 'optimizer' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])
    reset_optimizer_index = 1

    # Training loop
    while True:
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after and reset_optimizer_index:
            # Switch to the RL learning-rate schedule once self-critical training starts
            opt.learning_rate_decay_start = opt.self_critical_after
            opt.learning_rate_decay_rate = opt.learning_rate_decay_rate_rl
            opt.learning_rate_decay_every = opt.learning_rate_decay_every_rl
            opt.learning_rate = opt.learning_rate_rl
            reset_optimizer_index = 0

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        # Unpack data
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # # Evaluate model
            # eval_kwargs = {'split': 'val',
            #                 'dataset': opt.input_json}
            # eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)
            #
            # # Write validation result into summary
            # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            # if lang_stats is not None:
            #     for k,v in lang_stats.items():
            #         add_summary_value(tb_summary_writer, k, v, iteration)
            # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
            #
            # # Our metric is CIDEr if available, otherwise validation loss
            # if opt.language_eval == 1:
            #     current_score = lang_stats['CIDEr']
            # else:
            #     current_score = - val_loss
            current_score = 0

            # Save model in checkpoint path
            best_flag = False
            if True:  # evaluation above is disabled, so save unconditionally
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'infos_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'histories_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
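Note: Example No. 25 accumulates gradients over opt.accumulate_number micro-batches, dividing each loss by that count so the summed gradients match a single large batch, and stepping the optimizer only every opt.accumulate_number iterations. Below is a self-contained sketch of that pattern; the model, data, and K are illustrative stand-ins:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
K = 4  # plays the role of opt.accumulate_number
optimizer.zero_grad()
for i in range(16):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y) / K  # scale so accumulated grads match one big batch
    loss.backward()
    if (i + 1) % K == 0:
        # clip before stepping, analogous to utils.clip_gradient above
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
        optimizer.step()
        optimizer.zero_grad()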