    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       *args, **kwargs):
        # warm up lr
        opt = self.opt
        iteration = self.trainer.global_step
        if opt.use_warmup and (iteration < opt.noamopt_warmup):
            opt.current_lr = opt.learning_rate * \
                (iteration+1) / opt.noamopt_warmup
            utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
                               *args, **kwargs)
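
    # A worked example of the warmup ramp above (numbers are illustrative
    # assumptions): with learning_rate=5e-4 and noamopt_warmup=2000,
    #   iteration 0    -> 5e-4 * 1/2000    = 2.5e-7
    #   iteration 999  -> 5e-4 * 1000/2000 = 2.5e-4
    #   iteration 1999 -> 5e-4 * 2000/2000 = 5e-4 (full LR from here on)
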
    def on_epoch_start(self, trainer, pl_module):
        # Update LR / training stage / scheduled sampling prob, etc.
        opt = pl_module.opt
        model = pl_module.model
        epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        if not opt.noamopt and not opt.reduce_on_plateau:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
        # Assign the scheduled sampling prob
        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start
                    ) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                              opt.scheduled_sampling_max_prob)
            model.ss_prob = opt.ss_prob

        # If self-critical training has started
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
            sc_flag = True
            init_scorer(opt.cached_tokens)
        else:
            sc_flag = False
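
        # Note: init_scorer loads the cached n-gram statistics referenced by
        # opt.cached_tokens, which the CIDEr-style reward needs once
        # self-critical (SCST) training begins.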

        # If structure-loss training has started
        if opt.structure_after != -1 and epoch >= opt.structure_after:
            struc_flag = True
            init_scorer(opt.cached_tokens)
        else:
            struc_flag = False

        # drop_worst flag
        if opt.drop_worst_after != -1 and epoch >= opt.drop_worst_after:
            drop_worst_flag = True
        else:
            drop_worst_flag = False

        pl_module.struc_flag = struc_flag
        pl_module.sc_flag = sc_flag
        pl_module.drop_worst_flag = drop_worst_flag
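
For reference, the two per-epoch schedules that on_epoch_start applies can be reproduced in isolation. A minimal sketch, assuming the opt field names from the snippet above; the default numbers are illustrative assumptions, not the repo's values:

def step_decay_lr(epoch, base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8):
    """Step decay: scale the LR by decay_rate every decay_every epochs past decay_start."""
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr

def scheduled_sampling_prob(epoch, ss_start=0, increase_every=5, increase_prob=0.05, max_prob=0.25):
    """Scheduled sampling probability increases in steps and saturates at max_prob."""
    if ss_start >= 0 and epoch > ss_start:
        frac = (epoch - ss_start) // increase_every
        return min(increase_prob * frac, max_prob)
    return 0.0  # pre-schedule default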
Example #3
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible

    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:

    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with its loss function (used for training).
    # This lets the loss be computed separately on each device.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_model.vocab = getattr(model, 'vocab', None)  # nasty
    dp_lw_model = torch.nn.DataParallel(lw_model)
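
    # Note on the wrapping order above: because LossWrapper (model + criterion)
    # sits inside DataParallel, each replica computes the loss on its own shard
    # and only per-replica scalars are gathered on GPU 0 (rather than full
    # logits), which keeps GPU 0's memory footprint down. loss.mean() later
    # averages those per-replica values.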

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      optim_func=opt.optim,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
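        # For reference, get_std_opt follows the Noam schedule from
        # "Attention Is All You Need":
        #   rate(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
        # i.e. a linear ramp for `warmup` steps, then 1/sqrt(step) decay.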
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(
            optimizer,
            factor=opt.reduce_on_plateau_factor,
            patience=opt.reduce_on_plateau_patience)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
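        # Restoring the wrapper's internal step counter matters: the Noam
        # schedule derives its LR from it, so a resumed run would otherwise
        # restart the warmup ramp from step 0.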
    # Flag indicating the end of an epoch.
    # Always True at the start so the LR, scheduled sampling, etc. get initialized.
    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If self-critical training has started
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If structure-loss training has started
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            if opt.use_warmup and (iteration < opt.noamopt_warmup):
                opt.current_lr = opt.learning_rate * (iteration +
                                                      1) / opt.noamopt_warmup
                utils.set_lr(optimizer, opt.current_lr)
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['trace_feats'],
                data['box_feats'], data['labels'], data['masks'],
                data['att_masks'], data['trace_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, trace_feats, box_feats, labels, masks, att_masks, trace_masks = tmp

            optimizer.zero_grad()

            model_out = dp_lw_model(fc_feats, att_feats, trace_feats,
                                    box_feats, labels, masks, att_masks,
                                    trace_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(model.parameters(),
                                              opt.grad_clip_value)
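                # The getattr dispatch above resolves to
                # torch.nn.utils.clip_grad_norm_ (grad_clip_mode == 'norm') or
                # torch.nn.utils.clip_grad_value_ (grad_clip_mode == 'value'),
                # e.g. clip_grad_norm_(model.parameters(), opt.grad_clip_value).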
            if not torch.isnan(loss):
                if opt.language_eval == 1:
                    print('Doing final model evaluation, not updating model.')
                else:
                    optimizer.step()
            else:
                print('Encountered NaN loss', data['gts'], model_out)

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)
                    tb_summary_writer.add_scalar(
                        'reward_var', model_out['reward'].var(1).mean(),
                        iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if opt.language_eval == 1 or (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))

                assert (opt.task in ['caption', 'c_joint_t'] and opt.eval_task == 'caption') or \
                       (opt.task in ['trace', 'c_joint_t'] and opt.eval_task == 'trace') or \
                       (opt.task == 'pred_both' and opt.eval_task == 'pred_both')

                if opt.eval_task == 'caption':
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        dp_model, lw_model.crit_caption, loader, 'caption',
                        eval_kwargs)
                elif opt.eval_task == 'trace':
                    # This is a little time-consuming due to the linear-programming solve.
                    val_loss = eval_utils.eval_trace_generation(
                        dp_model,
                        lw_model.crit_trace,
                        loader,
                        window_size=0,
                        eval_kwargs=eval_kwargs
                    )  # Adjust the window_size as needed
                    lang_stats = None
                    predictions = None
                elif opt.eval_task == 'pred_both':
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        dp_model, lw_model.crit_caption, loader, 'both',
                        eval_kwargs)  # caption generation
                    val_loss_trace = eval_utils.eval_trace_generation(
                        dp_model,
                        lw_model.crit_trace,
                        loader,
                        window_size=0,
                        eval_kwargs=eval_kwargs
                    )  # Adjust the window_size as needed

                if opt.language_eval == 1:
                    break  # The language eval is done during testing, after the training finishes.

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save the model if it improves on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        model,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
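
The infos/histories bookkeeping above is what makes this loop resumable: counters, the loader state, and the best score are all persisted alongside the weights. A minimal sketch of the same pattern, with hypothetical file names (utils.save_checkpoint's real layout lives in the repo):

import os
import pickle

import torch

def save_state(path, model, optimizer, infos):
    """Persist weights, optimizer state, and loop counters for later resumption."""
    torch.save(model.state_dict(), os.path.join(path, 'model.pth'))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pth'))
    with open(os.path.join(path, 'infos.pkl'), 'wb') as f:
        pickle.dump(infos, f)

def load_state(path, model, optimizer):
    """Restore a run; returns the infos dict holding the 'iter'/'epoch' counters."""
    model.load_state_dict(torch.load(os.path.join(path, 'model.pth')))
    optimizer.load_state_dict(torch.load(os.path.join(path, 'optimizer.pth')))
    with open(os.path.join(path, 'infos.pkl'), 'rb') as f:
        return pickle.load(f)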
Example #4
def train(opt):
    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize training infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
        'stage': 1,
        'stage_saved': 1  # For interruption handling: records the stage at the time of interruption, used to decide whether to reload the best model
    }

    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme

    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # file logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if opt.finetune_only == 1:
        if os.path.isfile(os.path.join(opt.start_from, 'model_best.pth')):
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'model_best.pth')))
    else:
        if opt.start_from is not None and os.path.isfile(
                os.path.join(opt.start_from, 'model.pth')):
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Author's note: a model-facing loss wrapper, keeping the loss computation separate and reducing GPU 0's load in multi-GPU training
    lw_model = LossWrapper(model, opt)
    # Multi-GPU wrapper
    dp_model = torch.nn.DataParallel(model)
    dp_model.vocab = getattr(model, 'vocab', None)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    model.set_stage(infos['stage'])
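    # set_stage re-applies the stage recorded in infos, so a resumed run
    # re-enters the correct phase (1 = base training, 2 = the finetuning below).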

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      optim_func=opt.optim,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(
            optimizer,
            factor=opt.reduce_on_plateau_factor,
            patience=opt.reduce_on_plateau_patience)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    if opt.finetune_only == 1:
        if os.path.isfile(os.path.join(opt.start_from, "optimizer_best.pth")):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer_best.pth')))
    else:
        if opt.start_from is not None and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Training
    #########################

    # Preparation
    iteration = infos['iter']
    epoch = infos['epoch']
    loader.load_state_dict(infos['loader_state_dict'])
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration

    # Author's note: epoch-done flag, used to adjust training parameters when a new epoch begins
    epoch_done = True
    eval_done = False

    dp_lw_model.train()

    # Start training! The classic (base) stage first
    if infos['stage'] == 1 and opt.finetune_only != 1:
        try:
            while True:
                # Reached the max-epoch limit; exit base training
                # (chained comparison: epoch >= max_epochs_base AND max_epochs_base != -1)
                if epoch >= opt.max_epochs_base != -1:
                    if eval_done:
                        break
                    else:
                        # Run one final evaluation at the end
                        eval_kwargs = {
                            'split': 'base_val',
                            'dataset': opt.input_json
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats, _ = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write evaluation results to the log
                        tb_summary_writer.add_scalar('validation loss',
                                                     val_loss, iteration)
                        if lang_stats is not None:
                            for k, v in lang_stats.items():
                                tb_summary_writer.add_scalar(k, v, iteration)

                        histories['val_result_history'][iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Select the best model by the CIDEr metric
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            best_flag = True

                        infos['best_val_score'] = best_val_score

                        utils.save_checkpoint(opt, model, infos, optimizer,
                                              histories)

                        if opt.save_history_ckpt:
                            utils.save_checkpoint(
                                opt,
                                model,
                                infos,
                                optimizer,
                                append=str(epoch)
                                if opt.save_every_epoch else str(iteration))

                        if best_flag:
                            utils.save_checkpoint(opt,
                                                  model,
                                                  infos,
                                                  optimizer,
                                                  append='best')

                        break

                eval_done = False

                # Set training parameters
                if epoch_done:
                    # Transformer-related
                    if not opt.noamopt and not opt.reduce_on_plateau:
                        if epoch > opt.learning_rate_decay_start >= 0:
                            frac = (epoch - opt.learning_rate_decay_start
                                    ) // opt.learning_rate_decay_every
                            decay_factor = opt.learning_rate_decay_rate**frac
                            opt.current_lr = opt.learning_rate_base * decay_factor
                        else:
                            opt.current_lr = opt.learning_rate_base
                        utils.set_lr(optimizer, opt.current_lr)

                    # scheduled sampling
                    if epoch > opt.scheduled_sampling_start >= 0:
                        frac = (epoch - opt.scheduled_sampling_start
                                ) // opt.scheduled_sampling_increase_every
                        opt.ss_prob = min(
                            opt.scheduled_sampling_increase_prob * frac,
                            opt.scheduled_sampling_max_prob)
                        model.ss_prob = opt.ss_prob

                    # SCST
                    if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                        sc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        sc_flag = False

                    # Structure loss
                    if opt.structure_after != -1 and epoch >= opt.structure_after:
                        struc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        struc_flag = False

                    epoch_done = False

                # start = time.time()
                # Transformer Warmup
                if opt.use_warmup and (iteration < opt.noamopt_warmup):
                    opt.current_lr = opt.learning_rate_base * (
                        iteration + 1) / opt.noamopt_warmup
                    utils.set_lr(optimizer, opt.current_lr)

                data = loader.get_batch('base_train')
                # print('\r Read data:', time.time() - start, end="")

                torch.cuda.synchronize()
                start = time.time()

                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks = tmp

                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                        att_masks, data['gts'],
                                        torch.arange(0, len(data['gts'])),
                                        sc_flag, struc_flag)

                loss = model_out['loss'].mean()

                loss.backward()

                # Gradient clipping
                if opt.grad_clip_value != 0:
                    getattr(torch.nn.utils, 'clip_grad_{}_'.format(
                        opt.grad_clip_mode))(model.parameters(),
                                             opt.grad_clip_value)

                optimizer.step()

                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()

                # Logging output
                if struc_flag:
                    print('Base Training:', "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
                elif not sc_flag:
                    print('Base Training:', "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, end - start))
                else:
                    print('Base Training:', "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, model_out['reward'].mean(), end - start))

                # Update the iteration counter; at an epoch boundary some parameters need adjusting
                iteration += 1
                if data['bounds']['wrapped']:
                    epoch += 1
                    epoch_done = True

                # Write training results to the log
                if iteration % opt.losses_log_every == 0:
                    tb_summary_writer.add_scalar('train_loss', train_loss,
                                                 iteration)
                    if opt.noamopt:
                        opt.current_lr = optimizer.rate()
                    elif opt.reduce_on_plateau:
                        opt.current_lr = optimizer.current_lr
                    tb_summary_writer.add_scalar('learning_rate',
                                                 opt.current_lr, iteration)
                    tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                                 model.ss_prob, iteration)
                    if sc_flag:
                        tb_summary_writer.add_scalar(
                            'avg_reward', model_out['reward'].mean(),
                            iteration)
                    elif struc_flag:
                        tb_summary_writer.add_scalar(
                            'lm_loss', model_out['lm_loss'].mean().item(),
                            iteration)
                        tb_summary_writer.add_scalar(
                            'struc_loss',
                            model_out['struc_loss'].mean().item(), iteration)
                        tb_summary_writer.add_scalar(
                            'reward', model_out['reward'].mean().item(),
                            iteration)
                        tb_summary_writer.add_scalar(
                            'reward_var', model_out['reward'].var(1).mean(),
                            iteration)

                    histories['loss_history'][
                        iteration] = train_loss if not sc_flag else model_out[
                            'reward'].mean()
                    histories['lr_history'][iteration] = opt.current_lr
                    histories['ss_prob_history'][iteration] = model.ss_prob

                # Update infos
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['loader_state_dict'] = loader.state_dict()

                # Evaluate the model as needed, under either of the two save modes
                if (iteration % opt.save_checkpoint_every == 0
                        and not opt.save_every_epoch) or (
                            epoch_done and opt.save_every_epoch):
                    eval_kwargs = {
                        'split': 'base_val',
                        'dataset': opt.input_json
                    }
                    eval_kwargs.update(vars(opt))
                    val_loss, predictions, lang_stats, _ = eval_utils.eval_split(
                        dp_model, lw_model.crit, loader, eval_kwargs)

                    if opt.reduce_on_plateau:
                        if 'CIDEr' in lang_stats:
                            optimizer.scheduler_step(-lang_stats['CIDEr'])
                        else:
                            optimizer.scheduler_step(val_loss)

                    # Write evaluation results to the log
                    tb_summary_writer.add_scalar('validation loss', val_loss,
                                                 iteration)
                    if lang_stats is not None:
                        for k, v in lang_stats.items():
                            tb_summary_writer.add_scalar(k, v, iteration)

                    histories['val_result_history'][iteration] = {
                        'loss': val_loss,
                        'lang_stats': lang_stats,
                        'predictions': predictions
                    }

                    # Select the best model by the CIDEr metric
                    if opt.language_eval == 1:
                        current_score = lang_stats['CIDEr']
                    else:
                        current_score = -val_loss

                    best_flag = False

                    if best_val_score is None or current_score > best_val_score:
                        best_val_score = current_score
                        best_flag = True

                    infos['best_val_score'] = best_val_score

                    utils.save_checkpoint(opt, model, infos, optimizer,
                                          histories)

                    if opt.save_history_ckpt:
                        utils.save_checkpoint(
                            opt,
                            model,
                            infos,
                            optimizer,
                            append=str(epoch)
                            if opt.save_every_epoch else str(iteration))

                    if best_flag:
                        utils.save_checkpoint(opt,
                                              model,
                                              infos,
                                              optimizer,
                                              append='best')

                    eval_done = True

        except (RuntimeError, KeyboardInterrupt):
            print('Save ckpt on exception ...')
            utils.save_checkpoint(opt, model, infos, optimizer)
            print('Save ckpt done.')
            stack_trace = traceback.format_exc()
            print(stack_trace)
            os._exit(0)

        infos['stage'] = 2

    # Under the dummy config, no finetuning is performed
    if opt.train_only == 1:
        # Finetuning stage
        infos['stage'] = 2
        epoch_done = True
        loader.reset_iterator('support')

        # Load the best model; if the run was interrupted during stage 2, skip reloading
        if opt.start_from and infos['stage_saved'] == 2:
            pass
        else:
            # Otherwise load the best stage-1 model for finetuning
            print('Finetuning:', "loading best model from stage 1")
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from,
                                        'model_best' + '.pth')))
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(opt.start_from, 'optimizer_best' + '.pth')))

            lw_model = LossWrapper(model, opt)
            # Multi-GPU wrapper
            dp_model = torch.nn.DataParallel(model)
            dp_model.vocab = getattr(model, 'vocab', None)
            dp_lw_model = torch.nn.DataParallel(lw_model)

        model.set_stage(infos['stage'])
        infos['stage_saved'] = 2

        # Freeze all parameters except the final logit layer
        for name, parameter in dp_lw_model.module.named_parameters():
            if 'logit' not in name:
                parameter.requires_grad = False
            else:
                parameter.requires_grad = True
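
        # Hedged note: only requires_grad is toggled here; the optimizer still
        # holds every parameter, so frozen weights keep their momentum buffers
        # but receive no new gradients. An alternative is rebuilding the
        # optimizer over filter(lambda p: p.requires_grad, model.parameters()).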

        # The epoch counter was never reset, so the two stage limits are simply added
        max_epochs_all = opt.max_epochs_base + opt.max_epochs_finetune

        # Prepare: decide whether the finetune hyperparameters inherit from the base stage
        if opt.learning_rate_decay_start_finetune < 0:
            opt.learning_rate_decay_start_finetune = opt.learning_rate_decay_start - opt.max_epochs_base

        if opt.learning_rate_finetune < 0:
            opt.learning_rate_finetune = opt.learning_rate_base

        if opt.scheduled_sampling_start_finetune < 0:
            opt.scheduled_sampling_start_finetune = opt.scheduled_sampling_start - opt.max_epochs_base
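
        # Negative values act as "inherit from the base stage" sentinels; the
        # -max_epochs_base shift exists because the epoch counter keeps running
        # across stages instead of resetting at the start of finetuning.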

        try:
            while True:
                # Reached the max-epoch limit; exit
                # (-2 is the "no limit" sentinel: both stage limits set to -1)
                if epoch >= max_epochs_all != -2:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          histories,
                                          append='finetune')
                    break

                # Set training parameters
                if epoch_done:
                    # Transformer-related
                    if not opt.noamopt and not opt.reduce_on_plateau:
                        if epoch > opt.learning_rate_decay_start_finetune + opt.max_epochs_base >= 0:
                            frac = (epoch -
                                    opt.learning_rate_decay_start_finetune -
                                    opt.max_epochs_base
                                    ) // opt.learning_rate_decay_every_finetune
                            decay_factor = opt.learning_rate_decay_rate_finetune**frac
                            opt.current_lr = opt.learning_rate_finetune * decay_factor
                        else:
                            opt.current_lr = opt.learning_rate_finetune

                        utils.set_lr(optimizer, opt.current_lr)

                    # scheduled sampling
                    if epoch > opt.scheduled_sampling_start_finetune + opt.max_epochs_base >= 0:
                        frac = (
                            epoch - opt.scheduled_sampling_start_finetune -
                            opt.max_epochs_base
                        ) // opt.scheduled_sampling_increase_every_finetune
                        opt.ss_prob = min(
                            opt.scheduled_sampling_increase_prob_finetune *
                            frac, opt.scheduled_sampling_max_prob_finetune)
                        model.ss_prob = opt.ss_prob

                    # SCST
                    if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                        sc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        sc_flag = False

                    # Structure loss
                    if opt.structure_after != -1 and epoch >= opt.structure_after:
                        struc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        struc_flag = False

                    epoch_done = False

                # start = time.time()
                # Transformer Warmup
                # if opt.use_warmup and (iteration < opt.noamopt_warmup):
                #     opt.current_lr = opt.learning_rate * (iteration + 1) / opt.noamopt_warmup
                #     utils.set_lr(optimizer, opt.current_lr)

                data = loader.get_batch('support')

                torch.cuda.synchronize()
                start = time.time()

                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks = tmp

                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                        att_masks, data['gts'],
                                        torch.arange(0, len(data['gts'])),
                                        sc_flag, struc_flag)

                loss = model_out['loss'].mean()

                loss.backward()

                # Gradient clipping
                if opt.grad_clip_value != 0:
                    getattr(torch.nn.utils, 'clip_grad_{}_'.format(
                        opt.grad_clip_mode))(model.parameters(),
                                             opt.grad_clip_value)

                optimizer.step()

                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()

                # Logging output
                if struc_flag:
                    print('Finetuning:', "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
                elif not sc_flag:
                    print('Finetuning:', "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, end - start))
                else:
                    print('Finetuning:', "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, model_out['reward'].mean(), end - start))

                # Update the iteration counter; at an epoch boundary some parameters need adjusting
                iteration += 1
                if data['bounds']['wrapped']:
                    epoch += 1
                    epoch_done = True

                # Write training results to the log
                if iteration % opt.losses_log_every == 0:
                    tb_summary_writer.add_scalar('train_loss', train_loss,
                                                 iteration)
                    if opt.noamopt:
                        opt.current_lr = optimizer.rate()
                    elif opt.reduce_on_plateau:
                        opt.current_lr = optimizer.current_lr
                    tb_summary_writer.add_scalar('learning_rate',
                                                 opt.current_lr, iteration)
                    tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                                 model.ss_prob, iteration)
                    if sc_flag:
                        tb_summary_writer.add_scalar(
                            'avg_reward', model_out['reward'].mean(),
                            iteration)
                    elif struc_flag:
                        tb_summary_writer.add_scalar(
                            'lm_loss', model_out['lm_loss'].mean().item(),
                            iteration)
                        tb_summary_writer.add_scalar(
                            'struc_loss',
                            model_out['struc_loss'].mean().item(), iteration)
                        tb_summary_writer.add_scalar(
                            'reward', model_out['reward'].mean().item(),
                            iteration)
                        tb_summary_writer.add_scalar(
                            'reward_var', model_out['reward'].var(1).mean(),
                            iteration)

                    histories['loss_history'][
                        iteration] = train_loss if not sc_flag else model_out[
                            'reward'].mean()
                    histories['lr_history'][iteration] = opt.current_lr
                    histories['ss_prob_history'][iteration] = model.ss_prob

                # Update infos
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['loader_state_dict'] = loader.state_dict()

                if (iteration % opt.save_checkpoint_every == 0
                        and not opt.save_every_epoch) or (
                            epoch_done and opt.save_every_epoch):
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          histories,
                                          append='finetune')

                    if opt.save_history_ckpt:
                        utils.save_checkpoint(
                            opt,
                            model,
                            infos,
                            optimizer,
                            append=str(epoch)
                            if opt.save_every_epoch else str(iteration))

        except (RuntimeError, KeyboardInterrupt):
            print('Save ckpt on exception ...')
            utils.save_checkpoint(opt, model, infos, optimizer)
            print('Save ckpt done.')
            stack_trace = traceback.format_exc()
            print(stack_trace)
            os._exit(0)
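
A closing note on the loop skeleton both stages share: epochs are not counted up front; the custom DataLoader flags when it wraps past its last batch. A minimal sketch of that pattern (loader stands in for the repo's iterator-style DataLoader, not torch.utils.data.DataLoader):

def train_epochs(loader, max_epochs):
    """Skeleton of the wrapped-epoch loop used by both stages above."""
    iteration, epoch, epoch_done = 0, 0, True
    while epoch < max_epochs:
        if epoch_done:
            # per-epoch setup: LR decay, scheduled sampling, SCST/structure flags, ...
            epoch_done = False
        data = loader.get_batch('train')
        # ... forward / backward / optimizer.step() ...
        iteration += 1
        if data['bounds']['wrapped']:  # loader signals it cycled past the last batch
            epoch += 1
            epoch_done = True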