def __init__(self, opt):
        # Store the full option namespace; other methods read hyperparameters
        # from it.
        super().__init__()
        self.opt = opt
        # Initialize dataset and publish its sizes back onto opt so that
        # model construction below (and other consumers of opt) can see them.
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size

        # Build model; models.setup reads opt.vocab transiently.
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        print(model)
        # Remove the vocab from opt again — presumably so the (large) vocab
        # object does not stay attached to opt (e.g. when opt is pickled
        # into checkpoints); TODO confirm against checkpointing code.
        del opt.vocab

        # wrapper with loss in it.
        lw_model = LossWrapper(model, opt)

        self.model = model
        self.lw_model = lw_model

        # Training-mode flags (structure-loss / self-critical); initialized
        # to None here and set elsewhere during training.
        self.struc_flag = None
        self.sc_flag = None
Ejemplo n.º 2
0
# Skip re-evaluation when results already exist, unless --force was given.
if not opt.force:
    # Probe for existing prediction/result files.
    try:
        # if no pred exists, then continue
        tmp = torch.load(pred_fn)
        # if language_eval == 1, and no pred exists, then continue
        if opt.language_eval == 1:
            json.load(open(result_fn, 'r'))
        print('Result is already there')
        os._exit(0)
    # NOTE(review): bare except is deliberate best-effort here (missing or
    # unreadable files mean "no cached result, keep going"), but it also
    # hides unrelated errors — consider narrowing to (OSError, ValueError).
    except:
        pass

# Setup the model; models.setup reads opt.vocab transiently, so it is
# removed again right after.
opt.vocab = vocab
model = models.setup(opt)
del opt.vocab
model.load_state_dict(torch.load(opt.model))
model.cuda()
model.eval()
crit = losses.LanguageModelCriterion()

# Create the Data Loader instance: preprocessed-feature loader by default,
# raw-image loader when an image folder is supplied.
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({
        'folder_path': opt.image_folder,
        'coco_json': opt.coco_json,
        'batch_size': opt.batch_size,
        'cnn_model': opt.cnn_model
Ejemplo n.º 3
0
    def _get_model(self):
        """Build and return ``(opt, infos, model)`` ready for evaluation.

        Starts from a baseline option namespace, merges in the options
        pickled alongside the checkpoint, then restores the weights on CPU
        and moves the model to the configured device in eval mode.
        """
        # Baseline evaluation options; empty/zero values may be overridden
        # by the options stored in the infos pickle below.
        defaults = dict(batch_size=0,
                        beam_size=1,
                        block_trigrams=0,
                        coco_json='',
                        decoding_constraint=0,
                        diversity_lambda=0.5,
                        dump_images=1,
                        dump_json=1,
                        dump_path=0,
                        group_size=1,
                        id='',
                        image_folder='',
                        image_root='',
                        input_att_dir='',
                        input_box_dir='',
                        input_fc_dir='',
                        input_json='',
                        input_label_h5='',
                        language_eval=0,
                        length_penalty='',
                        max_length=20,
                        num_images=-1,
                        remove_bad_endings=0,
                        sample_method='greedy',
                        split='test',
                        suppress_UNK=1,
                        temperature=1.0,
                        verbose_beam=1,
                        verbose_loss=0)
        opt = argparse.Namespace(**defaults)

        opt.model = self._model_path
        opt.infos_path = self._infos_path
        opt.device = self._device
        opt.dataset = opt.input_json

        # Load the infos pickle saved at training time.
        with open(opt.infos_path, 'rb') as f:
            infos = utils.pickle_load(f)

        # Options whose falsy defaults fall back to the checkpoint's values.
        replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir',
                   'input_label_h5', 'input_json', 'batch_size', 'id']
        # Options never copied over from the checkpoint.
        ignore = ['start_from']

        for key, saved_value in vars(infos['opt']).items():
            if key in replace:
                setattr(opt, key,
                        getattr(opt, key) or getattr(infos['opt'], key, ''))
            elif key not in ignore and not hasattr(opt, key):
                # Copy over any remaining training-time option untouched.
                setattr(opt, key, saved_value)

        # models.setup reads opt.vocab transiently; drop it right after.
        opt.vocab = infos['vocab']
        model = models.setup(opt)
        del opt.vocab
        model.load_state_dict(torch.load(opt.model, map_location='cpu'))
        model.to(opt.device)
        model.eval()

        return opt, infos, model
Ejemplo n.º 4
0
def train(opt):
    """Main training loop for the captioning / trace model.

    Builds the dataloader, model, loss wrapper, and optimizer (optionally
    restoring all of them from ``opt.start_from``), then iterates over the
    train split until ``opt.max_epochs``. Periodically evaluates on the
    validation split and checkpoints, tracking the best model by CIDEr (when
    language eval runs) or negative validation loss. On RuntimeError or
    KeyboardInterrupt a final checkpoint is saved before the traceback is
    printed.
    """
    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if there is one) and check if models are compatible.
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    # models.setup reads opt.vocab transiently; drop it again afterwards.
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap generation model with loss function (used for training).
    # This allows the loss to be computed separately on each machine.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_model.vocab = getattr(model, 'vocab', None)  # nasty
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      optim_func=opt.optim,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(
            optimizer,
            factor=opt.reduce_on_plateau_factor,
            patience=opt.reduce_on_plateau_patience)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For back compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    # BUG FIX: best_val_score used to be bound only when
    # opt.load_best_score == 1, but it is read unconditionally at the first
    # checkpoint comparison below, which raised NameError otherwise.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # flag indicating finish of an epoch
    # Always set to True at the beginning to initialize the lr or etc.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            if opt.use_warmup and (iteration < opt.noamopt_warmup):
                # Linear learning-rate warmup over the first noamopt_warmup
                # iterations.
                opt.current_lr = opt.learning_rate * (iteration +
                                                      1) / opt.noamopt_warmup
                utils.set_lr(optimizer, opt.current_lr)
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['trace_feats'],
                data['box_feats'], data['labels'], data['masks'],
                data['att_masks'], data['trace_masks']
            ]
            # Some entries (e.g. att_masks) may legitimately be None.
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, trace_feats, box_feats, labels, masks, att_masks, trace_masks = tmp

            optimizer.zero_grad()

            model_out = dp_lw_model(fc_feats, att_feats, trace_feats,
                                    box_feats, labels, masks, att_masks,
                                    trace_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(model.parameters(),
                                              opt.grad_clip_value)
            # Skip the optimizer step on nan losses (and in final-eval mode).
            if not torch.isnan(loss):
                if opt.language_eval == 1:
                    print('Doing final model evaluation, not updating model.')
                else:
                    optimizer.step()
            else:
                print('Meet nan loss', data['gts'], model_out)

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)
                    tb_summary_writer.add_scalar(
                        'reward_var', model_out['reward'].var(1).mean(),
                        iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if opt.language_eval == 1 or (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))

                # The eval task must be consistent with the training task.
                assert (opt.task in ['caption', 'c_joint_t'] and opt.eval_task == 'caption') or \
                       (opt.task in ['trace', 'c_joint_t'] and opt.eval_task == 'trace') or \
                       (opt.task == 'pred_both' and opt.eval_task == 'pred_both')

                if opt.eval_task == 'caption':
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        dp_model, lw_model.crit_caption, loader, 'caption',
                        eval_kwargs)
                elif opt.eval_task == 'trace':
                    # This is a little time consuming due to the linear
                    # programming solve.
                    val_loss = eval_utils.eval_trace_generation(
                        dp_model,
                        lw_model.crit_trace,
                        loader,
                        window_size=0,
                        eval_kwargs=eval_kwargs
                    )  # Adjust the window_size as needed
                    lang_stats = None
                    predictions = None
                elif opt.eval_task == 'pred_both':
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        dp_model, lw_model.crit_caption, loader, 'both',
                        eval_kwargs)  # caption generation
                    val_loss_trace = eval_utils.eval_trace_generation(
                        dp_model,
                        lw_model.crit_trace,
                        loader,
                        window_size=0,
                        eval_kwargs=eval_kwargs
                    )  # Adjust the window_size as needed

                if opt.language_eval == 1:
                    break  # The language eval is done during testing, after the training finishes.

                if opt.reduce_on_plateau:
                    # BUG FIX: lang_stats is None on the trace-eval path, and
                    # `'CIDEr' in None` raises TypeError — guard it.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if it is improving on the validation result.
                # NOTE(review): this branch is unreachable when
                # opt.language_eval == 1 because of the `break` above.
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        model,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        # Save a last checkpoint so the run can be resumed, then surface
        # the traceback for diagnosis.
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Ejemplo n.º 5
0
def main():
    """Export a pretrained captioning model to ONNX.

    Restores the saved training options and vocabulary from the infos
    pickle, loads the checkpoint weights, then traces the model with dummy
    inputs and writes ``./image_captioning.onnx``.
    """
    use_gpu = torch.cuda.is_available()

    parser = argparse.ArgumentParser()
    # Input paths
    parser.add_argument('--model', type=str, default='',
                        help='path to model to evaluate')
    parser.add_argument('--cnn_model', type=str, default='resnet101',
                        help='resnet101, resnet152')
    parser.add_argument('--infos_path', type=str, default='',
                        help='path to infos to evaluate')
    parser.add_argument('--only_lang_eval', type=int, default=0,
                        help='lang eval on saved results')
    parser.add_argument(
        '--force', type=int, default=0,
        help='force to evaluate no matter if there are results available')
    opts.add_eval_options(parser)
    opts.add_diversity_opts(parser)
    opt = parser.parse_args()

    # Hard-coded overrides for this export script.
    opt.caption_model = 'newfc'
    opt.infos_path = '/home/zzgyf/github_yifan/ImageCaptioning.pytorch/models/infos_fc_nsc-best.pkl'
    with open(opt.infos_path, 'rb') as f:
        infos = utils.pickle_load(f)

    # Merge saved training options into opt: path/id options fall back to
    # the saved values when not given on the command line; all remaining
    # saved options are copied over unless explicitly ignored.
    replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir',
               'input_label_h5', 'input_json', 'batch_size', 'id']
    ignore = ['start_from']
    for key, saved_value in vars(infos['opt']).items():
        if key in replace:
            setattr(opt, key,
                    getattr(opt, key) or getattr(infos['opt'], key, ''))
        elif key not in ignore and not hasattr(opt, key):
            setattr(opt, key, saved_value)  # copy over options from model

    opt.vocab = infos['vocab']  # ix -> word mapping
    model = models.setup(opt)

    checkpoint = torch.load(
        "/home/zzgyf/github_yifan/ImageCaptioning.pytorch/models/model-best.pth"
    )
    model.load_state_dict(checkpoint)

    model_onnx_path = "./image_captioning.onnx"
    # Switch to inference behavior (dropout off etc.) before tracing.
    model.train(False)

    # Dummy inputs matching the model's (fc_feats, att_feats, labels, masks)
    # signature; sizes mirror the coco test features.
    dummy_fc_feats = Variable(torch.randn(10, 2048))
    dummy_att_feats = Variable(torch.randn(10, 0, 0))
    dummy_label_ids = Variable(torch.randint(5200, (10, 5, 18)))
    dummy_label_masks = Variable(torch.randint(1, (10, 5, 18)))

    # Export the model to an ONNX file.
    output = torch_onnx.export(model,
                               (dummy_fc_feats, dummy_att_feats,
                                dummy_label_ids, dummy_label_masks),
                               model_onnx_path,
                               verbose=False)
    print("Export of torch_model.onnx complete!")
Ejemplo n.º 6
0
def check():
    """Sanity-check the exported ONNX model against the PyTorch model.

    Rebuilds the captioning model from the hard-coded checkpoint, runs the
    same dummy inputs through both the PyTorch model and the ONNX Runtime
    session loaded from ``image_captioning.onnx``, and asserts that the two
    outputs are numerically close.
    """
    parser = argparse.ArgumentParser()
    # Input paths
    parser.add_argument('--model', type=str, default='',
                        help='path to model to evaluate')
    parser.add_argument('--cnn_model', type=str, default='resnet101',
                        help='resnet101, resnet152')
    parser.add_argument('--infos_path', type=str, default='',
                        help='path to infos to evaluate')
    parser.add_argument('--only_lang_eval', type=int, default=0,
                        help='lang eval on saved results')
    parser.add_argument(
        '--force', type=int, default=0,
        help='force to evaluate no matter if there are results available')
    opts.add_eval_options(parser)
    opts.add_diversity_opts(parser)
    opt = parser.parse_args()
    opt.caption_model = 'newfc'
    opt.infos_path = '/home/zzgyf/github_yifan/ImageCaptioning.pytorch/models/infos_fc_nsc-best.pkl'
    with open(opt.infos_path, 'rb') as f:
        infos = utils.pickle_load(f)

    # Merge saved training options: path/id options fall back to the saved
    # values when unset; everything else is copied unless ignored.
    replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir',
               'input_label_h5', 'input_json', 'batch_size', 'id']
    ignore = ['start_from']
    for k in vars(infos['opt']):
        if k in replace:
            setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
        elif k not in ignore and k not in vars(opt):
            vars(opt).update({k: vars(infos['opt'])[k]
                              })  # copy over options from model

    opt.vocab = infos['vocab']  # ix -> word mapping
    model = models.setup(opt)

    checkpoint = torch.load(
        "/home/zzgyf/github_yifan/ImageCaptioning.pytorch/models/model-best.pth"
    )
    model.load_state_dict(checkpoint)

    model.eval()
    ort_session = onnxruntime.InferenceSession("image_captioning.onnx")

    # Dummy inputs; must match the tuple used when exporting the ONNX graph.
    dummy_cocotest_bu_fc = Variable(torch.randn(10, 2048))
    dummy_cocotest_bu_att = Variable(torch.randn(10, 0, 0))
    dummy_labels = Variable(torch.randint(5200, (10, 5, 18)))
    dummy_masks = Variable(torch.randint(1, (10, 5, 18)))
    x = (dummy_cocotest_bu_fc, dummy_cocotest_bu_att, dummy_labels,
         dummy_masks)

    # BUG FIX: torch_out was never computed (the assignment was commented
    # out), so the comparison below raised NameError. torch.onnx.export
    # invokes model(*args) for a tuple of args, so mirror that call here.
    with torch.no_grad():
        torch_out = model(*x)

    # compute ONNX Runtime output prediction
    # NOTE(review): feeding to_numpy(x) of a *tuple* to a single graph
    # input looks suspect — confirm against the exported graph's inputs.
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(to_numpy(torch_out),
                               ort_outs[0],
                               rtol=1e-03,
                               atol=1e-05)
Ejemplo n.º 7
0
def train(opt):
    """Two-stage training entry point for a captioning model.

    Stage 1 ("base") trains the whole model on the ``base_train`` split with
    periodic validation on ``base_val``; stage 2 ("finetune", enabled via
    ``opt.train_only == 1``) reloads the best stage-1 weights, freezes every
    parameter except the final logit layer, and fine-tunes on the ``support``
    split.  Checkpoints, ``infos`` and ``histories`` are written periodically
    and on interruption so training can be resumed via ``opt.start_from``.

    Args:
        opt: argparse.Namespace carrying all model / optimization / logging
            options (mutated in place: ``vocab_size``, ``seq_length``,
            ``current_lr``, ``ss_prob``, finetune fallbacks, ...).
    """
    ################################
    # Create the dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize training bookkeeping
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
        'stage': 1,
        # Stage recorded at interruption time; used to decide whether the
        # best stage-1 model must be reloaded when resuming.
        'stage_saved': 1
    }

    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            # These architecture options must agree with the saved model,
            # otherwise the checkpoint cannot be loaded meaningfully.
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme

    infos['opt'] = opt

    #########################
    # Create loggers
    #########################
    # File-based histories (losses / lr / val results keyed by iteration).
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build the model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if opt.finetune_only == 1:
        if os.path.isfile(os.path.join(opt.start_from, 'model_best.pth')):
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'model_best.pth')))
    else:
        if opt.start_from is not None and os.path.isfile(
                os.path.join(opt.start_from, 'model.pth')):
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrapping the loss with the model keeps the loss computation distributed
    # under DataParallel and reduces the load on GPU 0.
    lw_model = LossWrapper(model, opt)
    # Multi-GPU wrappers.
    dp_model = torch.nn.DataParallel(model)
    dp_model.vocab = getattr(model, 'vocab', None)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    model.set_stage(infos['stage'])

    ##########################
    # Build the optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      optim_func=opt.optim,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(
            optimizer,
            factor=opt.reduce_on_plateau_factor,
            patience=opt.reduce_on_plateau_patience)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    if opt.finetune_only == 1:
        if os.path.isfile(os.path.join(opt.start_from, "optimizer_best.pth")):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer_best.pth')))
    else:
        if opt.start_from is not None and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Training
    #########################

    # Preparation.
    iteration = infos['iter']
    epoch = infos['epoch']
    loader.load_state_dict(infos['loader_state_dict'])
    # FIX: always bind best_val_score. The original only assigned it when
    # load_best_score == 1, so the comparison during validation raised
    # NameError for any other setting.
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None
    if opt.noamopt:
        optimizer._step = iteration

    # Epoch-boundary flag: triggers (re-)configuration of lr decay,
    # scheduled sampling, SCST and structure loss at each new epoch.
    epoch_done = True
    eval_done = False

    dp_lw_model.train()

    def log_training(train_loss, sc_flag, struc_flag, model_out):
        """Write per-iteration training statistics to tensorboard/histories."""
        tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
        if opt.noamopt:
            opt.current_lr = optimizer.rate()
        elif opt.reduce_on_plateau:
            opt.current_lr = optimizer.current_lr
        tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                     iteration)
        tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob,
                                     iteration)
        if sc_flag:
            tb_summary_writer.add_scalar('avg_reward',
                                         model_out['reward'].mean(), iteration)
        elif struc_flag:
            tb_summary_writer.add_scalar('lm_loss',
                                         model_out['lm_loss'].mean().item(),
                                         iteration)
            tb_summary_writer.add_scalar('struc_loss',
                                         model_out['struc_loss'].mean().item(),
                                         iteration)
            tb_summary_writer.add_scalar('reward',
                                         model_out['reward'].mean().item(),
                                         iteration)
            tb_summary_writer.add_scalar('reward_var',
                                         model_out['reward'].var(1).mean(),
                                         iteration)

        histories['loss_history'][
            iteration] = train_loss if not sc_flag else model_out[
                'reward'].mean()
        histories['lr_history'][iteration] = opt.current_lr
        histories['ss_prob_history'][iteration] = model.ss_prob

    def evaluate_and_checkpoint(best_val_score):
        """Validate on base_val, log results, and save regular/best ckpts.

        Returns the (possibly updated) best validation score.  This replaces
        the two identical copies of the evaluation/checkpoint sequence in the
        original stage-1 loop.
        """
        eval_kwargs = {'split': 'base_val', 'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))
        val_loss, predictions, lang_stats, _ = eval_utils.eval_split(
            dp_model, lw_model.crit, loader, eval_kwargs)

        if opt.reduce_on_plateau:
            # lang_stats may be None when language evaluation is disabled.
            if lang_stats is not None and 'CIDEr' in lang_stats:
                optimizer.scheduler_step(-lang_stats['CIDEr'])
            else:
                optimizer.scheduler_step(val_loss)

        # Write the evaluation results to the logs.
        tb_summary_writer.add_scalar('validation loss', val_loss, iteration)
        if lang_stats is not None:
            for k, v in lang_stats.items():
                tb_summary_writer.add_scalar(k, v, iteration)

        histories['val_result_history'][iteration] = {
            'loss': val_loss,
            'lang_stats': lang_stats,
            'predictions': predictions
        }

        # Model selection: CIDEr when language eval is on, otherwise -loss.
        if opt.language_eval == 1:
            current_score = lang_stats['CIDEr']
        else:
            current_score = -val_loss

        best_flag = False
        if best_val_score is None or current_score > best_val_score:
            best_val_score = current_score
            best_flag = True

        infos['best_val_score'] = best_val_score

        utils.save_checkpoint(opt, model, infos, optimizer, histories)

        if opt.save_history_ckpt:
            utils.save_checkpoint(
                opt,
                model,
                infos,
                optimizer,
                append=str(epoch)
                if opt.save_every_epoch else str(iteration))

        if best_flag:
            utils.save_checkpoint(opt,
                                  model,
                                  infos,
                                  optimizer,
                                  append='best')

        return best_val_score

    # Stage 1: classic training.
    if infos['stage'] == 1 and opt.finetune_only != 1:
        try:
            while True:
                # Max-epoch limit reached (-1 disables the limit); evaluate
                # one last time if the final iteration was not evaluated.
                if epoch >= opt.max_epochs_base != -1:
                    if not eval_done:
                        best_val_score = evaluate_and_checkpoint(
                            best_val_score)
                    break

                eval_done = False

                # Configure per-epoch learning parameters.
                if epoch_done:
                    # Learning-rate decay (transformer variants use noamopt
                    # or reduce_on_plateau instead).
                    if not opt.noamopt and not opt.reduce_on_plateau:
                        if epoch > opt.learning_rate_decay_start >= 0:
                            frac = (epoch - opt.learning_rate_decay_start
                                    ) // opt.learning_rate_decay_every
                            decay_factor = opt.learning_rate_decay_rate**frac
                            opt.current_lr = opt.learning_rate_base * decay_factor
                        else:
                            opt.current_lr = opt.learning_rate_base
                        utils.set_lr(optimizer, opt.current_lr)

                    # Scheduled sampling probability ramp-up.
                    if epoch > opt.scheduled_sampling_start >= 0:
                        frac = (epoch - opt.scheduled_sampling_start
                                ) // opt.scheduled_sampling_increase_every
                        opt.ss_prob = min(
                            opt.scheduled_sampling_increase_prob * frac,
                            opt.scheduled_sampling_max_prob)
                        model.ss_prob = opt.ss_prob

                    # Self-critical sequence training (SCST).
                    if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                        sc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        sc_flag = False

                    # Structure loss.
                    if opt.structure_after != -1 and epoch >= opt.structure_after:
                        struc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        struc_flag = False

                    epoch_done = False

                # Transformer warmup: linear ramp of the learning rate.
                if opt.use_warmup and (iteration < opt.noamopt_warmup):
                    opt.current_lr = opt.learning_rate_base * (
                        iteration + 1) / opt.noamopt_warmup
                    utils.set_lr(optimizer, opt.current_lr)

                data = loader.get_batch('base_train')

                torch.cuda.synchronize()
                start = time.time()

                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks = tmp

                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                        att_masks, data['gts'],
                                        torch.arange(0, len(data['gts'])),
                                        sc_flag, struc_flag)

                loss = model_out['loss'].mean()

                loss.backward()

                # Gradient clipping; grad_clip_mode selects
                # clip_grad_norm_ / clip_grad_value_.
                if opt.grad_clip_value != 0:
                    getattr(torch.nn.utils, 'clip_grad_{}_'.format(
                        opt.grad_clip_mode))(model.parameters(),
                                             opt.grad_clip_value)

                optimizer.step()

                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()

                # Progress output.
                if struc_flag:
                    print('Base Training:', "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
                elif not sc_flag:
                    print('Base Training:', "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, end - start))
                else:
                    print('Base Training:', "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, model_out['reward'].mean(), end - start))

                # Advance counters; epoch boundary re-triggers parameter setup.
                iteration += 1
                if data['bounds']['wrapped']:
                    epoch += 1
                    epoch_done = True

                # Write the training statistics to the logs.
                if iteration % opt.losses_log_every == 0:
                    log_training(train_loss, sc_flag, struc_flag, model_out)

                # Keep resumable state up to date.
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['loader_state_dict'] = loader.state_dict()

                # Evaluate / checkpoint on the configured schedule.
                if (iteration % opt.save_checkpoint_every == 0
                        and not opt.save_every_epoch) or (
                            epoch_done and opt.save_every_epoch):
                    best_val_score = evaluate_and_checkpoint(best_val_score)
                    eval_done = True

        except (RuntimeError, KeyboardInterrupt):
            print('Save ckpt on exception ...')
            utils.save_checkpoint(opt, model, infos, optimizer)
            print('Save ckpt done.')
            stack_trace = traceback.format_exc()
            print(stack_trace)
            os._exit(0)

        infos['stage'] = 2

    # Stage 2: finetuning (skipped under the dummy configuration).
    if opt.train_only == 1:
        infos['stage'] = 2
        epoch_done = True
        loader.reset_iterator('support')

        # If we are resuming from an interruption that already happened in
        # stage 2, keep the current weights; otherwise load the best stage-1
        # model for finetuning.
        if opt.start_from and infos['stage_saved'] == 2:
            pass
        else:
            # NOTE(review): this branch assumes opt.start_from is set;
            # os.path.join(None, ...) would raise otherwise — confirm callers.
            print('Finetuning:', "loading best model from stage 1")
            model.load_state_dict(
                torch.load(os.path.join(opt.start_from,
                                        'model_best' + '.pth')))
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(opt.start_from, 'optimizer_best' + '.pth')))

            lw_model = LossWrapper(model, opt)
            # Re-wrap for multi-GPU after reloading the weights.
            dp_model = torch.nn.DataParallel(model)
            dp_model.vocab = getattr(model, 'vocab', None)
            dp_lw_model = torch.nn.DataParallel(lw_model)

        model.set_stage(infos['stage'])
        infos['stage_saved'] = 2

        # Freeze everything except the final logit layer.
        for name, parameter in dp_lw_model.module.named_parameters():
            if 'logit' not in name:
                parameter.requires_grad = False
            else:
                parameter.requires_grad = True

        # Counters were not reset after stage 1, so the finetune limit is
        # stacked on top of the base limit.
        max_epochs_all = opt.max_epochs_base + opt.max_epochs_finetune

        # Fall back to the base-stage hyper-parameters when the finetune
        # ones were not given explicitly (negative means "inherit").
        if opt.learning_rate_decay_start_finetune < 0:
            opt.learning_rate_decay_start_finetune = opt.learning_rate_decay_start - opt.max_epochs_base

        if opt.learning_rate_finetune < 0:
            opt.learning_rate_finetune = opt.learning_rate_base

        if opt.scheduled_sampling_start_finetune < 0:
            opt.scheduled_sampling_start_finetune = opt.scheduled_sampling_start - opt.max_epochs_base

        try:
            while True:
                # Stop at the combined epoch limit (-2 == both stage limits
                # were -1, i.e. unlimited).
                if epoch >= max_epochs_all != -2:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          histories,
                                          append='finetune')
                    break

                # Configure per-epoch learning parameters.
                if epoch_done:
                    # Learning-rate decay (transformer variants use noamopt
                    # or reduce_on_plateau instead).
                    if not opt.noamopt and not opt.reduce_on_plateau:
                        if epoch > opt.learning_rate_decay_start_finetune + opt.max_epochs_base >= 0:
                            frac = (epoch -
                                    opt.learning_rate_decay_start_finetune -
                                    opt.max_epochs_base
                                    ) // opt.learning_rate_decay_every_finetune
                            decay_factor = opt.learning_rate_decay_rate_finetune**frac
                            opt.current_lr = opt.learning_rate_finetune * decay_factor
                        else:
                            opt.current_lr = opt.learning_rate_finetune

                        utils.set_lr(optimizer, opt.current_lr)

                    # Scheduled sampling probability ramp-up.
                    if epoch > opt.scheduled_sampling_start_finetune + opt.max_epochs_base >= 0:
                        frac = (
                            epoch - opt.scheduled_sampling_start_finetune -
                            opt.max_epochs_base
                        ) // opt.scheduled_sampling_increase_every_finetune
                        opt.ss_prob = min(
                            opt.scheduled_sampling_increase_prob_finetune *
                            frac, opt.scheduled_sampling_max_prob_finetune)
                        model.ss_prob = opt.ss_prob

                    # Self-critical sequence training (SCST).
                    if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                        sc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        sc_flag = False

                    # Structure loss.
                    if opt.structure_after != -1 and epoch >= opt.structure_after:
                        struc_flag = True
                        init_scorer(opt.cached_tokens)
                    else:
                        struc_flag = False

                    epoch_done = False

                data = loader.get_batch('support')

                torch.cuda.synchronize()
                start = time.time()

                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks = tmp

                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                        att_masks, data['gts'],
                                        torch.arange(0, len(data['gts'])),
                                        sc_flag, struc_flag)

                loss = model_out['loss'].mean()

                loss.backward()

                # Gradient clipping; grad_clip_mode selects
                # clip_grad_norm_ / clip_grad_value_.
                if opt.grad_clip_value != 0:
                    getattr(torch.nn.utils, 'clip_grad_{}_'.format(
                        opt.grad_clip_mode))(model.parameters(),
                                             opt.grad_clip_value)

                optimizer.step()

                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()

                # Progress output.
                if struc_flag:
                    print('Finetuning:', "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
                elif not sc_flag:
                    print('Finetuning:', "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, train_loss, end - start))
                else:
                    print('Finetuning:', "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, model_out['reward'].mean(), end - start))

                # Advance counters; epoch boundary re-triggers parameter setup.
                iteration += 1
                if data['bounds']['wrapped']:
                    epoch += 1
                    epoch_done = True

                # Write the training statistics to the logs.
                if iteration % opt.losses_log_every == 0:
                    log_training(train_loss, sc_flag, struc_flag, model_out)

                # Keep resumable state up to date.
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['loader_state_dict'] = loader.state_dict()

                if (iteration % opt.save_checkpoint_every == 0
                        and not opt.save_every_epoch) or (
                            epoch_done and opt.save_every_epoch):
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          histories,
                                          append='finetune')

                    if opt.save_history_ckpt:
                        utils.save_checkpoint(
                            opt,
                            model,
                            infos,
                            optimizer,
                            append=str(epoch)
                            if opt.save_every_epoch else str(iteration))

        except (RuntimeError, KeyboardInterrupt):
            print('Save ckpt on exception ...')
            utils.save_checkpoint(opt, model, infos, optimizer)
            print('Save ckpt done.')
            stack_trace = traceback.format_exc()
            print(stack_trace)
            os._exit(0)

# Ensemble setup: all member models must agree on the feature options.
opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos])
# FIX: the original asserts compared max([...]) against an identical
# max([...]), which is always true.  Comparing max against min actually
# enforces that every model uses the same setting.
assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == min([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Not support different norm_att_feat'
assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == min([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Not support different norm_box_feat'

vocab = infos['vocab']  # ix -> word mapping

# Setup the ensemble model
from models.AttEnsemble import AttEnsemble

_models = []
for i in range(len(model_infos)):
    # Member weights are loaded explicitly below, so never resume them.
    model_infos[i]['opt'].start_from = None
    model_infos[i]['opt'].vocab = vocab
    tmp = models.setup(model_infos[i]['opt'])
    tmp.load_state_dict(torch.load(model_paths[i]))
    _models.append(tmp)

if opt.weights is not None:
    opt.weights = [float(_) for _ in opt.weights]
model = AttEnsemble(_models, weights=opt.weights)
model.seq_length = opt.max_length
model.cuda()
model.eval()
crit = losses.LanguageModelCriterion()

# Create the Data Loader instance
if len(opt.image_folder) == 0:
  loader = DataLoader(opt)
else: