Ejemplo n.º 1
0
def parse_opt():
    """Build the training option parser and return the parsed options.

    Option values are resolved in three steps:
      1. argparse defaults;
      2. the YAML file given by ``--cfg`` (if any) -- keys that are not known
         options are reported with a warning and ignored;
      3. flags given on the command line, which override the config file.

    After parsing, basic sanity checks are run (``AssertionError`` on failure)
    and a few derived fields are filled in: ``checkpoint_path`` defaults to
    ``./log_<id>``, ``start_from`` defaults to ``checkpoint_path``,
    ``use_fc``/``use_att`` are derived from ``caption_model`` and
    ``att_feat_size`` is increased by 5 when ``use_box`` is set.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser()
    # Data input settings
    parser.add_argument(
        '--input_json',
        type=str,
        default='data/coco.json',
        help='path to the json file containing additional info and vocab')
    parser.add_argument(
        '--input_fc_dir',
        type=str,
        default='data/cocotalk_fc',
        help='path to the directory containing the preprocessed fc feats')
    parser.add_argument(
        '--input_att_dir',
        type=str,
        default='data/cocotalk_att',
        help='path to the directory containing the preprocessed att feats')
    parser.add_argument(
        '--input_box_dir',
        type=str,
        default='data/cocotalk_box',
        help='path to the directory containing the boxes of att feats')
    parser.add_argument(
        '--input_label_h5',
        type=str,
        default='data/coco_label.h5',
        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument(
        '--start_from',
        type=str,
        default=None,
        help=
        """continue training from saved model at this path. Path must contain files saved by previous training process: 
                        'infos.pkl'         : configuration;
                        'checkpoint'        : paths to model file(s) (created by tf).
                                              Note: this file contains absolute paths, be careful when moving files around;
                        'model.ckpt-*'      : file(s) with model definition (created by tf)
                    """)
    parser.add_argument(
        '--cached_tokens',
        type=str,
        default='coco-train-idxs',
        help=
        'Cached token file for calculating cider score during self critical training.'
    )

    # Model settings
    parser.add_argument(
        '--caption_model',
        type=str,
        default="show_tell",
        help=
        'show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer'
    )
    parser.add_argument(
        '--rnn_size',
        type=int,
        default=512,
        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type',
                        type=str,
                        default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument(
        '--input_encoding_size',
        type=int,
        default=512,
        help='the encoding size of each token in the vocabulary, and the image.'
    )
    parser.add_argument(
        '--att_hid_size',
        type=int,
        default=512,
        help=
        'the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer'
    )
    parser.add_argument('--fc_feat_size',
                        type=int,
                        default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size',
                        type=int,
                        default=2048,
                        help='2048 for resnet, 512 for vgg')
    # Help previously duplicated the --num_layers text by mistake.
    parser.add_argument('--logit_layers',
                        type=int,
                        default=1,
                        help='number of layers in the logit (output) layer')

    parser.add_argument(
        '--use_bn',
        type=int,
        default=0,
        help=
        'If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed'
    )

    # feature manipulation
    parser.add_argument('--norm_att_feat',
                        type=int,
                        default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box',
                        type=int,
                        default=0,
                        help='If use box features')
    parser.add_argument('--norm_box_feat',
                        type=int,
                        default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs',
                        type=int,
                        default=-1,
                        help='number of epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='minibatch size')
    parser.add_argument(
        '--grad_clip_mode',
        type=str,
        default='value',  #5.,
        help='value or norm')
    parser.add_argument(
        '--grad_clip_value',
        type=float,
        default=0.1,  #5.,
        help='clip gradients at this value/max_norm')
    parser.add_argument('--drop_prob_lm',
                        type=float,
                        default=0.5,
                        help='strength of dropout in the Language Model RNN')
    # Help previously described CNN finetuning (copy-paste error).
    parser.add_argument(
        '--self_critical_after',
        type=int,
        default=-1,
        help=
        'After what epoch do we start self-critical training? (-1 = disable; 0 = from the start)'
    )
    parser.add_argument(
        '--seq_per_img',
        type=int,
        default=5,
        help=
        'number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image'
    )

    # Sample related
    parser.add_argument(
        '--beam_size',
        type=int,
        default=1,
        help=
        'used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.'
    )
    parser.add_argument('--max_length',
                        type=int,
                        default=20,
                        help='Maximum length during sampling')
    parser.add_argument('--length_penalty',
                        type=str,
                        default='',
                        help='wu_X or avg_X, X is the alpha')
    parser.add_argument('--block_trigrams',
                        type=int,
                        default=0,
                        help='block repeated trigram.')
    parser.add_argument('--remove_bad_endings',
                        type=int,
                        default=0,
                        help='Remove bad endings')

    #Optimization: for the Language Model
    parser.add_argument(
        '--optim',
        type=str,
        default='adam',
        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=4e-4,
                        help='learning rate')
    parser.add_argument(
        '--learning_rate_decay_start',
        type=int,
        default=-1,
        help=
        'at what iteration to start decaying learning rate? (-1 = dont) (in epoch)'
    )
    parser.add_argument(
        '--learning_rate_decay_every',
        type=int,
        default=3,
        help='every how many iterations thereafter to drop LR?(in epoch)')
    # Help previously duplicated the --learning_rate_decay_every text.
    parser.add_argument(
        '--learning_rate_decay_rate',
        type=float,
        default=0.8,
        help='multiplicative factor applied to the learning rate at each decay step')
    parser.add_argument('--optim_alpha',
                        type=float,
                        default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta',
                        type=float,
                        default=0.999,
                        help='beta used for adam')
    parser.add_argument(
        '--optim_epsilon',
        type=float,
        default=1e-8,
        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0,
                        help='weight_decay')
    # Transformer
    parser.add_argument('--label_smoothing', type=float, default=0, help='')
    parser.add_argument('--noamopt', action='store_true', help='')
    parser.add_argument('--noamopt_warmup', type=int, default=2000, help='')
    parser.add_argument('--noamopt_factor', type=float, default=1, help='')
    parser.add_argument('--reduce_on_plateau', action='store_true', help='')

    parser.add_argument('--scheduled_sampling_start',
                        type=int,
                        default=-1,
                        help='at what iteration to start decay gt probability')
    parser.add_argument(
        '--scheduled_sampling_increase_every',
        type=int,
        default=5,
        help='every how many iterations thereafter to increase the scheduled sampling probability')
    parser.add_argument('--scheduled_sampling_increase_prob',
                        type=float,
                        default=0.05,
                        help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob',
                        type=float,
                        default=0.25,
                        help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument(
        '--val_images_use',
        type=int,
        default=3200,
        help=
        'how many images to use when periodically evaluating the validation loss? (-1 = all)'
    )
    parser.add_argument(
        '--save_checkpoint_every',
        type=int,
        default=2500,
        help='how often to save a model checkpoint (in iterations)?')
    parser.add_argument(
        '--save_every_epoch',
        action='store_true',
        help='Save checkpoint every epoch, will overwrite save_checkpoint_every'
    )
    parser.add_argument('--save_history_ckpt',
                        type=int,
                        default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path',
                        type=str,
                        default=None,
                        help='directory to store checkpointed models')
    parser.add_argument(
        '--language_eval',
        type=int,
        default=0,
        help=
        'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.'
    )
    # The validation below requires a positive value, so the help no longer
    # advertises "0 = disable".
    parser.add_argument(
        '--losses_log_every',
        type=int,
        default=25,
        help=
        'How often do we snapshot losses, for inclusion in the progress dump? (must be > 0)'
    )
    parser.add_argument(
        '--load_best_score',
        type=int,
        default=1,
        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument(
        '--id',
        type=str,
        default='',
        help=
        'an id identifying this run/job. used in cross-val and appended when writing progress files'
    )
    parser.add_argument('--train_only',
                        type=int,
                        default=0,
                        help='if true then use 80k, else use 110k')

    # Reward
    parser.add_argument('--cider_reward_weight',
                        type=float,
                        default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight',
                        type=float,
                        default=0,
                        help='The reward weight from bleu4')

    # Structure_loss
    parser.add_argument('--structure_loss_weight',
                        type=float,
                        default=1,
                        help='')
    parser.add_argument('--structure_after', type=int, default=-1, help='T')
    parser.add_argument('--structure_loss_type',
                        type=str,
                        default='seqnll',
                        help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
    parser.add_argument('--entropy_reward_weight',
                        type=float,
                        default=0,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight',
                        type=float,
                        default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is need during training
    # Help previously duplicated the --cider_reward_weight text.
    parser.add_argument('--train_sample_n',
                        type=int,
                        default=16,
                        help='Number of samples drawn when sampling is needed during training')
    parser.add_argument('--train_sample_method',
                        type=str,
                        default='sample',
                        help='')
    parser.add_argument('--train_beam_size', type=int, default=1, help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument(
        '--cfg',
        type=str,
        default=None,
        help='configuration; similar to what is used in detectron')
    # How will config be used
    # 1) read cfg argument, and load the cfg file if it's not None
    # 2) parse config argument
    # 3) in the end, parse command line argument

    # step 1: read cfg_fn
    args = parser.parse_args()
    if args.cfg is not None:
        from misc.config import CfgNode
        cn = CfgNode.load_yaml_with_base(args.cfg)
        # step 2: config values overwrite the argparse defaults
        for k, v in cn.items():
            if hasattr(args, k):
                setattr(args, k, v)
            else:
                print('Warning: key %s not in args' % k)
        # step 3: re-parse into the same namespace. argparse only fills in a
        # default when the attribute is missing, so config values survive and
        # only flags explicitly given on the command line override them.
        args = parser.parse_args(namespace=args)

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert 0 <= args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval in (0, 1), "language_eval should be 0 or 1"
    # These two messages previously (wrongly) named language_eval.
    assert args.load_best_score in (0, 1), "load_best_score should be 0 or 1"
    assert args.train_only in (0, 1), "train_only should be 0 or 1"

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './log_%s' % args.id
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    args.use_fc, args.use_att = utils.if_use_feat(args.caption_model)
    # NOTE(review): callers must not apply this +5 again (e.g. in train()),
    # otherwise att_feat_size is enlarged twice.
    if args.use_box:
        args.att_feat_size = args.att_feat_size + 5

    return args
Ejemplo n.º 2
0
def train(opt):
    """Train a captioning model, periodically evaluating and checkpointing it.

    Builds the data loader, model, loss wrapper and optimizer from ``opt``.
    When ``opt.start_from`` is set, it resumes bookkeeping (``infos_*.pkl``,
    optional ``histories_*.pkl``, optional ``optimizer.pth``) from that
    directory. It then loops over training batches until ``opt.max_epochs``
    is reached, or forever when it is -1.

    Every ``opt.save_checkpoint_every`` iterations the model is evaluated on
    the ``val`` split and saved to ``opt.checkpoint_path``. An extra
    ``*-best`` checkpoint is written whenever the validation score improves.
    A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the training
    loop triggers a final checkpoint save; the traceback is printed and the
    exception is not re-raised.

    Args:
        opt: parsed option namespace. It is mutated in place (``use_fc``,
            ``use_att``, ``att_feat_size``, ``vocab_size``, ``seq_length``,
            ``vocab``, ``current_lr`` and ``ss_prob`` are written here).
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    # NOTE(review): use_box enlarges the attention feature size by 5. Make sure
    # the option parser has not already applied this, or it is applied twice.
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.vocab = loader.get_vocab()

    # `tb` may be falsy (presumably when tensorboard is unavailable); then no
    # writer is created.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Resume files are named after the last path component of start_from.
        # normpath strips a trailing separator, so 'logs/run/' still yields
        # 'run' instead of an empty name ('infos_.pkl').
        # NOTE(review): save_checkpoint below names these files after opt.id,
        # so resuming only finds them when the directory name equals the id.
        start_from_name = os.path.basename(os.path.normpath(opt.start_from))
        infos_path = os.path.join(opt.start_from, 'infos_' + start_from_name + '.pkl')
        histories_path = os.path.join(opt.start_from, 'histories_' + start_from_name + '.pkl')

        # open old infos and check if models are compatible
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        # Fresh run: start counters at zero and record the loader state.
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    # Restore the loader position so a resumed run continues where it stopped.
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score. It is compared on every evaluation, so leaving
    # it unbound when load_best_score != 1 raised UnboundLocalError.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    # True forces the per-epoch setup (lr, scheduled sampling, sc_flag) to run
    # before the first batch.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Save model/optimizer state, infos and (if given) histories.

        Files are written to opt.checkpoint_path, which is created if needed.
        A non-empty ``append`` adds a '-<append>' suffix to every file name
        (e.g. 'model-best.pth').
        """
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate: step decay every
                    # learning_rate_decay_every epochs after decay_start.
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob (capped at the max prob)
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Self-critical training starts once self_critical_after is
                # reached (-1 disables it). The scorer is initialised every
                # epoch either way.
                sc_flag = opt.self_critical_after != -1 and epoch >= opt.self_critical_after
                init_scorer(opt.cached_tokens)

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            # Synchronize so the per-batch timing measures finished GPU work.
            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['sents_mask']]
            tmp = [t if t is None else t.cuda() for t in tmp]
            fc_feats, att_feats, labels, masks, att_masks, sents_mask = tmp
            box_inds = None

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, box_inds, epoch, sents_mask)

            # DataParallel returns one loss per device; average them.
            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # lang_stats may be None (language eval off); fall back to
                    # the validation loss instead of crashing on `in None`.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best effort: persist progress, report the error, and return normally.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Ejemplo n.º 3
0
def train(opt):
    """Train a captioning model driven by the options namespace ``opt``.

    Prints a summary of the run configuration, derives a more specific
    ``opt.checkpoint_path`` (batch-size / warmup / GPU-count suffixes),
    optionally resumes from ``opt.start_from`` (infos, histories and
    optimizer state), then loops over the training split: cross-entropy
    training until epoch ``opt.self_critical_after`` and self-critical
    training afterwards (``-1`` disables self-critical training).

    Every ``opt.save_checkpoint_every`` iterations (once ``epoch > 3``) the
    model is evaluated and checkpoints are written; an improved validation
    score is additionally saved with the ``_best`` suffix. Training stops once
    ``opt.max_epochs`` is reached (``-1`` means no limit). A ``RuntimeError``
    or ``KeyboardInterrupt`` saves an ``_interrupt`` checkpoint and prints the
    traceback instead of propagating.

    Args:
        opt: argparse-style namespace of training options. Mutated in place
            (``checkpoint_path``, ``use_fc``/``use_att``, ``att_feat_size``,
            ``vocab_size``, ``seq_length``, ``losses_log_every``,
            ``current_lr``, ``ss_prob``, ...).
    """
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    print("Caption model {}".format(opt.caption_model))
    print("refine aoa {}".format(opt.refine_aoa))
    print("Number of aoa module {}".format(opt.aoa_num))
    print("Self Critic After  {}".format(opt.self_critical_after))
    print("learning_rate_decay_every {}".format(opt.learning_rate_decay_every))

    # Optionally fold val/test data into training to fine-tune for challenge
    # submissions (not used for the reported results).
    if opt.use_val or opt.use_test:
        print("+++++++++++It is a refining training+++++++++++++++")
        print("===========Val is {} used for training ===========".format(
            '' if opt.use_val else 'not'))
        print("===========Test is {} used for training ===========".format(
            '' if opt.use_test else 'not'))
    print("=====================================================")

    # Make the checkpoint directory name more specific.
    checkpoint_path_suffix = "_bs{}".format(opt.batch_size)
    if opt.use_warmup:
        checkpoint_path_suffix += "_warmup"
    if torch.cuda.device_count() > 1:
        checkpoint_path_suffix += "_gpu{}".format(torch.cuda.device_count())

    # Keep an existing '_rl' marker at the very end of the path.
    if opt.checkpoint_path.endswith('_rl'):
        opt.checkpoint_path = (opt.checkpoint_path[:-3] +
                               checkpoint_path_suffix + '_rl')
    else:
        opt.checkpoint_path += checkpoint_path_suffix
    print("Save model to {}".format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        # Box features widen each attention feature by 5 entries.
        opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    # Log roughly once per epoch; clamp to 1 so `iteration % losses_log_every`
    # never divides by zero when the train split is smaller than one batch.
    opt.losses_log_every = max(
        1, len(loader.split_ix['train']) // opt.batch_size)
    print("Evaluate on each {} iterations".format(opt.losses_log_every))
    # Stays None when summaries are disabled (or tensorboard is unavailable);
    # the direct writer calls below check for that.
    tb_summary_writer = None
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    # load checkpoint
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]

        # this sanity check may not work well, and comment it if necessary
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("==============================================")
        print("Initialize training process from all begining")
        print("==============================================")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()

    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("==================start from {} iterations -- {} epoch".format(
        iteration, epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # Position inside the train split; the first batch loop (and its progress
    # bar) resumes from here, later epochs start from 0.
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Defaults so the best-model bookkeeping below works even when the
    # previous best score is not restored (load_best_score != 1).
    best_val_score = None
    best_epoch = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_epoch = infos.get('best_epoch', None)
        print("========best history val cider score: {} in epoch {}=======".
              format(best_val_score, best_epoch))

    # sanity check that the saved model name has a correct index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) - epoch == 1, "dismatch in the model index and the real epoch number"
        epoch += 1
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if torch.cuda.device_count() > 1:
        dp_model = torch.nn.DataParallel(model)
    else:
        dp_model = model
    lw_model = LossWrapper1(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Write model, optimizer, infos (and optional histories) files.

        Files go to ``opt.checkpoint_path`` (created if missing); a non-empty
        ``append`` is inserted as ``_<append>`` before each file extension.
        """
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        infos_path = os.path.join(opt.checkpoint_path,
                                  'infos' + '%s.pkl' % (append))
        with open(infos_path, 'wb') as f:
            utils.pickle_dump(infos, f)
        print("Save training information to {}".format(infos_path))

        if histories:
            histories_path = os.path.join(opt.checkpoint_path,
                                          'histories' + '%s.pkl' % (append))
            with open(histories_path, 'wb') as f:
                utils.pickle_dump(histories, f)
            print("Save training historyes to {}".format(histories_path))

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Step decay: multiply by decay_rate once every
                    # learning_rate_decay_every epochs after decay_start.
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor * opt.refine_lr_decay
                    else:
                        opt.current_lr = opt.learning_rate
                    infos['current_lr'] = opt.current_lr
                    print("Current Learning Rate is: {}".format(
                        opt.current_lr))
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for _ in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    if (opt.use_warmup
                            == 1) and (iteration < opt.noamopt_warmup):
                        # Linear LR warmup over the first noamopt_warmup
                        # iterations.
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')

                    # Gradient accumulation: clear gradients at the start of
                    # each group of acc_steps iterations, step at its end.
                    if iteration % acc_steps == 0:
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'],
                        data['flag_feats'], data['labels'], data['masks'],
                        data['att_masks']
                    ]
                    tmp = [t if t is None else t.cuda() for t in tmp]
                    fc_feats, att_feats, flag_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, flag_feats,
                                            labels, masks, att_masks,
                                            data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if (iteration + 1) % acc_steps == 0:
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    # Plain Python float so the pickled histories do not
                    # contain (possibly CUDA) tensors.
                    avg_reward = (model_out['reward'].mean().item()
                                  if sc_flag else None)
                    end = time.time()
                    if not sc_flag:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, avg_reward, end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        epoch += 1
                        epoch_done = True
                        # save after each epoch
                        save_checkpoint(model, infos, optimizer)
                        if epoch > 15:  # To save memory, you can comment this part
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(epoch))
                        print(
                            "====================================================="
                        )
                        print(
                            "======Best Cider = {} in epoch {}: iter {}!======"
                            .format(best_val_score, best_epoch,
                                    infos.get('best_itr', None)))
                        print(
                            "====================================================="
                        )

                    # Write training history into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              avg_reward, iteration)

                        loss_history[iteration] = (train_loss if not sc_flag
                                                   else avg_reward)
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # unnecessary to eval from the beginning
                    # NOTE(review): `eval_` is not defined in this function;
                    # presumably a module-level flag -- confirm.
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 3:
                        # eval model
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        # Evaluate on 'test' when val is folded into training
                        # but test is not; otherwise on 'val' (when both are
                        # used for training there is no held-out split left).
                        if opt.use_val and not opt.use_test:
                            val_split = 'test'
                        else:
                            val_split = 'val'

                        eval_kwargs = {
                            'split': val_split,
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if lang_stats is not None and 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                # All BLEU-n curves share one multi-line plot.
                                bleu_dict = {
                                    k: v
                                    for k, v in lang_stats.items()
                                    if 'Bleu' in k
                                }
                                if bleu_dict and tb_summary_writer:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            infos['best_epoch'] = epoch
                            infos['best_itr'] = iteration
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        save_checkpoint(model, infos, optimizer, histories)
                        if opt.save_history_ckpt:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(iteration))

                        if best_flag:
                            best_epoch = epoch
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))
                    # Only the first batch loop after a resume starts
                    # mid-split; every later epoch starts at 0.
                    start_Img_idx = 0
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                if tb_summary_writer:
                    tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best effort: keep the current state on crash / Ctrl-C.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Ejemplo n.º 4
0
def train(opt):
    """Train a captioning model with per-epoch validation and checkpointing.

    Optionally resumes from ``opt.start_from`` (``infos_<id>.pkl``,
    ``histories_<id>.pkl`` and ``optimizer.pth``). Each iteration runs either
    cross-entropy training or, from epoch ``opt.self_critical_after`` on
    (``-1`` disables it), self-critical training with sampled captions.

    Whenever the data loader wraps around the train split, the epoch counter
    is advanced and the model is evaluated on 'val'. When
    ``opt.language_eval == 1`` a line is appended to ``train_log_<id>.txt``.
    Model, optimizer, infos and histories are then saved under
    ``opt.checkpoint_path``, plus ``-best`` copies when the validation score
    improves. Training stops once ``opt.max_epochs`` is reached (``-1`` means
    no limit).

    Args:
        opt: argparse-style namespace of training options; mutated in place
            (``use_fc``/``use_att``, ``att_feat_size``, ``vocab_size``,
            ``seq_length``, ``current_lr``, ``ss_prob``).
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        # Box features widen each attention feature by 5 entries.
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos_name = 'infos_' + opt.id + '.pkl'
    histories_name = 'histories_' + opt.id + '.pkl'

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, infos_name), 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, histories_name)
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Default so the best-model comparison below works even when the
    # previous best score is not restored (load_best_score != 1).
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        resume_optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(resume_optimizer_path):
            optimizer.load_state_dict(torch.load(resume_optimizer_path))

    # Running sum / count of training losses since the last log-file write,
    # used for the mean TrainLoss recorded there.
    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Step decay: multiply by decay_rate once every
                # learning_rate_decay_every epochs after decay_start.
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # Batch entries are numpy arrays (or None); move them to the GPU.
        tmp = [t if t is None else torch.from_numpy(t).cuda() for t in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            # Cross-entropy against the label sequence without its first
            # token (labels[:, 1:]).
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # `epoch` itself is advanced in the evaluation block below.
            epoch_done = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model once per epoch,
        # i.e. whenever the loader wraps around the train split.
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                # Append one summary line per evaluation to the training log.
                with open('train_log_%s.txt' % opt.id, 'a') as log_file:
                    log_file.write(
                        'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}'
                        .format(epoch, str(datetime.now()),
                                str(total_loss / times), str(val_loss),
                                str(lang_stats)))
                    log_file.write('\n')
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if not os.path.isdir(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            if opt.save_history_ckpt:
                checkpoint_path = os.path.join(
                    opt.checkpoint_path, 'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, infos_name),
                      'wb') as f:
                utils.pickle_dump(infos, f)
            if opt.save_history_ckpt:
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '-%d.pkl' % (iteration)),
                        'wb') as f:
                    # Same serializer as every other infos file.
                    utils.pickle_dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, histories_name),
                      'wb') as f:
                utils.pickle_dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Ejemplo n.º 5
0
def train(opt):
    """Train the caption model jointly with a ResNet CNN backbone.

    Builds an augmenting image DataLoader, the caption model (wrapped in
    LossWrapper and DataParallel) and a ResnetBackbone.  The backbone is frozen
    while ``opt.fix_cnn`` is set or for the first ``cnn_after`` epochs; after
    that it is fine-tuned with its own Adam optimizer, except for the first
    five children of ``resnet_conv``.  Every ``opt.save_checkpoint_every``
    iterations the model is evaluated on the 'val' split and checkpoints are
    written to ``opt.checkpoint_path``.  A '-best' set of files is also written
    whenever the validation score improves.  When ``opt.start_from`` is set,
    training resumes from the '-best' files in that directory.

    On RuntimeError or KeyboardInterrupt a checkpoint is saved and the
    traceback is printed; the exception is not re-raised.

    Args:
        opt: argparse-style namespace of options.  It is mutated in place
            (vocab_size, seq_length, use_fc/use_att, current_lr, ss_prob, ...).
    """
    print(opt)

    # To reproduce training results
    init_seed()
    # Image Preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        # Box features are appended to every attention feature (+5 dims).
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt, transform=transform)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[
                    checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: it is compared against on every evaluation,
    # even when the previous best score is not being restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    if torch.cuda.is_available():
        model = models.setup(opt).cuda()
    else:
        model = models.setup(opt)
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    #fgm = FGM(model)

    cnn_model = ResnetBackbone()
    if torch.cuda.is_available():
        cnn_model = cnn_model.cuda()
    if opt.start_from is not None:
        model_dict = cnn_model.state_dict()
        predict_dict = torch.load(os.path.join(opt.start_from, 'cnn_model-best.pth'))
        # The checkpoint is saved from the DataParallel wrapper (see
        # save_checkpoint), so its keys carry a "module." prefix.
        model_dict = {k: predict_dict["module." + k] for k in model_dict if "module." + k in predict_dict}
        cnn_model.load_state_dict(model_dict)
    cnn_model = torch.nn.DataParallel(cnn_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer.  Check for the same '-best' file that is loaded, to
    # stay consistent with the '-best' infos and CNN weights restored above.
    if opt.start_from is not None:
        optimizer_ckpt = os.path.join(opt.start_from, 'optimizer-best.pth')
        if os.path.isfile(optimizer_ckpt):
            optimizer.load_state_dict(torch.load(optimizer_ckpt))

    def save_checkpoint(model, cnn_model, infos, optimizer, histories=None, append=''):
        """Write model, CNN, optimizer, infos (and histories) to opt.checkpoint_path.

        ``append`` is an optional suffix (e.g. 'best' or an iteration number)
        that is joined to every file name with a '-'.
        """
        if len(append) > 0:
            append = '-' + append
        # Create the checkpoint directory if needed.
        os.makedirs(opt.checkpoint_path, exist_ok=True)
        #Transformer model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        #CNN model: always overwrite, since the backbone is fine-tuned after
        # cnn_after epochs and a stale file would be reloaded on resume.
        checkpoint_path = os.path.join(opt.checkpoint_path, 'cnn_model%s.pth' % (append))
        torch.save(cnn_model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    # Number of epochs during which the CNN backbone stays frozen.
    cnn_after = 3
    try:
        while True:
            if epoch_done:
                if opt.fix_cnn or epoch < cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                    cnn_optimizer = None
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model._modules['module']._modules['resnet_conv'][:5]._modules.values():
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()
                    # Constructing CNN parameters for optimization, only fine-tuning higher layers
                    cnn_optimizer = torch.optim.Adam(
                        (filter(lambda p: p.requires_grad, cnn_model.parameters())),
                        lr=2e-6 if (opt.self_critical_after != -1 and epoch >= opt.self_critical_after) else 5e-5, betas=(0.8, 0.999))

                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            if iteration % opt.losses_log_every == 0:
                print('Read data:', time.time() - start)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.time()

            # The loader yields images in 'att_feats'; the CNN turns them into
            # attention features before they reach the caption model.
            if torch.cuda.is_available():
                data['att_feats'] = cnn_model(data['att_feats'].cuda())
            else:
                data['att_feats'] = cnn_model(data['att_feats'])
            data['att_feats'] = repeat_feat(data['att_feats'])
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            if torch.cuda.is_available():
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            if cnn_optimizer is not None:
                cnn_optimizer.zero_grad()

            # if epoch >= cnn_after:
            #     att_feats.register_hook(save_grad("att_feats"))
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()

            #loss.backward(retain_graph=True)

            # adversarial training
            #fgm.attack(emb_name='model.tgt_embed.0.lut.weight')
            #adv_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
            #                      torch.arange(0, len(data['gts'])), sc_flag)

            #adv_loss = adv_out['loss'].mean()
            #adv_loss.backward()
            #fgm.restore(emb_name="model.tgt_embed.0.lut.weight")

            # utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if cnn_optimizer is not None:
                cnn_optimizer.step()
            train_loss = loss.item()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end = time.time()
            if not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            elif iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if iteration % opt.losses_log_every == 0:
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if iteration % opt.save_checkpoint_every == 0:
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                eval_kwargs["cnn_model"] = cnn_model
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # lang_stats may be None (see the guard below), so test it first.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: persist progress before giving up; the error is only printed.
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Ejemplo n.º 6
0
def parse_opt():
    """Build the command-line parser, merge optional YAML config, and return options.

    Precedence, from lowest to highest: argparse defaults, the ``--cfg`` file,
    ``--set_cfgs`` key/value pairs, then flags given explicitly on the command
    line (the final ``parse_args(namespace=args)`` re-applies them).

    Returns:
        argparse.Namespace of validated options.  ``checkpoint_path`` falls
        back to ``./log_<id>`` when empty, ``start_from`` falls back to
        ``checkpoint_path``, and ``use_att`` is derived from ``caption_model``.

    Raises:
        AssertionError: if one of the option sanity checks fails.
    """
    parser = argparse.ArgumentParser()
    # Data input settings
    parser.add_argument(
        '--input_json',
        type=str,
        default='./data/output.json',
        help='path to the json file containing additional info and vocab')
    parser.add_argument(
        '--input_att_dir',
        type=str,
        default='./data/bottom_up_att',
        help=
        'path to the directory containing the preprocessed default att feats')
    parser.add_argument(
        '--input_box_dir',
        type=str,
        default='./data/bottom_up_box',
        help='path to the directory containing the boxes of default att feats')
    parser.add_argument(
        '--input_label_h5',
        type=str,
        default='./data/output_label.h5',
        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument(
        '--input_att1_dir',
        type=str,
        default='./data/bottom_up_att1',
        help=
        'path to the directory containing the preprocessed semantic att feats')
    parser.add_argument(
        '--input_box1_dir',
        type=str,
        default='./data/bottom_up_box1',
        help='path to the directory containing the boxes of semantic att feats'
    )
    parser.add_argument(
        '--start_from',
        type=str,
        default="./checkpoint",
        help=
        """continue training from saved model at this path. Path must contain files saved by previous training process: 
                        'infos.pkl'         : configuration;
                        'checkpoint'        : paths to model file(s) (created by tf).
                                              Note: this file contains absolute paths, be careful when moving files around;
                        'model.ckpt-*'      : file(s) with model definition (created by tf)
                    """)

    # Model settings
    parser.add_argument('--caption_model',
                        type=str,
                        default="updown",
                        help='updown, topdown')
    parser.add_argument(
        '--rnn_size',
        type=int,
        default=512,
        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers',
                        type=int,
                        default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type',
                        type=str,
                        default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument(
        '--input_encoding_size',
        type=int,
        default=300,
        help='the encoding size of each token in the vocabulary, and the image.'
    )
    parser.add_argument(
        '--att_hid_size',
        type=int,
        default=512,
        help=
        'the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer'
    )
    parser.add_argument('--fc_feat_size',
                        type=int,
                        default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size',
                        type=int,
                        default=2048,
                        help='2048 for resnet, 512 for vgg')
    parser.add_argument('--logit_layers',
                        type=int,
                        default=1,
                        help='number of layers used to compute the output logits')
    parser.add_argument(
        '--use_bn',
        type=int,
        default=0,
        help=
        'If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed'
    )

    # feature manipulation
    parser.add_argument('--norm_att_feat',
                        type=int,
                        default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box',
                        type=int,
                        default=1,
                        help='If use box features')
    parser.add_argument('--norm_box_feat',
                        type=int,
                        default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs',
                        type=int,
                        default=20,
                        help='number of epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='minibatch size')
    parser.add_argument('--grad_clip_mode',
                        type=str,
                        default='value',
                        help='value or norm')
    parser.add_argument(
        '--grad_clip_value',
        type=float,
        default=0.1,
        help='clip gradients at this value/max_norm, 0 means no clipping')
    parser.add_argument('--drop_prob_lm',
                        type=float,
                        default=0.5,
                        help='strength of dropout in the Language Model RNN')
    parser.add_argument(
        '--self_critical_after',
        type=int,
        default=-1,
        help=
        'After what epoch do we start self-critical training? (-1 = disable; 0 = self-critical from the start)'
    )
    parser.add_argument(
        '--seq_per_img',
        type=int,
        default=2,
        help='number of captions to sample for each image during training.')

    # Sample related
    add_eval_sample_opts(parser)

    #Optimization: for the Language Model
    parser.add_argument(
        '--optim',
        type=str,
        default='adam',
        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.0005,
                        help='learning rate')
    parser.add_argument(
        '--learning_rate_decay_start',
        type=int,
        default=-1,
        help=
        'at what iteration to start decaying learning rate? (-1 = dont) (in epoch)'
    )
    parser.add_argument(
        '--learning_rate_decay_every',
        type=int,
        default=10,
        help='every how many iterations thereafter to drop LR?(in epoch)')
    parser.add_argument(
        '--learning_rate_decay_rate',
        type=float,
        default=0.8,
        help='factor by which the learning rate is multiplied at each decay step')
    parser.add_argument('--optim_alpha',
                        type=float,
                        default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta',
                        type=float,
                        default=0.999,
                        help='beta used for adam')
    parser.add_argument(
        '--optim_epsilon',
        type=float,
        default=1e-8,
        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0,
                        help='weight_decay')

    parser.add_argument('--scheduled_sampling_start',
                        type=int,
                        default=100,
                        help='at what iteration to start decay gt probability')
    parser.add_argument(
        '--scheduled_sampling_increase_every',
        type=int,
        default=5,
        help='every how many iterations thereafter to increase the scheduled sampling probability')
    parser.add_argument('--scheduled_sampling_increase_prob',
                        type=float,
                        default=0.05,
                        help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob',
                        type=float,
                        default=0.25,
                        help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument(
        '--val_images_use',
        type=int,
        default=-1,
        help=
        'how many images to use when periodically evaluating the validation loss? (-1 = all)'
    )
    parser.add_argument(
        '--save_checkpoint_every',
        type=int,
        default=1500,
        help='how often to save a model checkpoint (in iterations)?')
    # NOTE(review): store_false means this option defaults to True and passing
    # the flag turns it OFF, the opposite of what the name/help suggest.  Kept
    # as-is so existing command lines keep their meaning.
    parser.add_argument(
        '--save_every_epoch',
        action='store_false',
        help='Save checkpoint every epoch, will overwrite save_checkpoint_every'
    )
    parser.add_argument('--save_history_ckpt',
                        type=int,
                        default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path',
                        type=str,
                        default="./checkpoint",
                        help='directory to store checkpointed models')
    parser.add_argument(
        '--language_eval',
        type=int,
        default=1,
        help=
        'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.'
    )
    parser.add_argument(
        '--losses_log_every',
        type=int,
        default=1000,
        help=
        'How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)'
    )
    parser.add_argument(
        '--load_best_score',
        type=int,
        default=1,
        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument(
        '--id',
        type=str,
        default='up',
        help=
        'an id identifying this run/job. used in cross-val and appended when writing progress files'
    )

    # Reward
    parser.add_argument('--cider_reward_weight',
                        type=float,
                        default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight',
                        type=float,
                        default=0,
                        help='The reward weight from bleu4')

    # Structure_loss
    parser.add_argument('--structure_loss_weight',
                        type=float,
                        default=1,
                        help='')
    parser.add_argument('--structure_after', type=int, default=-1, help='T')
    parser.add_argument('--structure_loss_type',
                        type=str,
                        default='new_self_critical',
                        help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
    parser.add_argument('--entropy_reward_weight',
                        type=float,
                        default=1,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight',
                        type=float,
                        default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is need during training
    parser.add_argument('--train_sample_n',
                        type=int,
                        default=3,
                        help='number of sequences sampled per image when sampling during training')
    parser.add_argument('--train_sample_method',
                        type=str,
                        default='sample',
                        help='')
    parser.add_argument('--train_beam_size', type=int, default=1, help='')

    # Used for self critical
    parser.add_argument('--sc_sample_method',
                        type=str,
                        default='greedy',
                        help='')
    parser.add_argument('--sc_beam_size', type=int, default=1, help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument(
        '--cfg',
        type=str,
        default=None,
        help='configuration; similar to what is used in detectron')
    parser.add_argument(
        '--set_cfgs',
        dest='set_cfgs',
        help='Set config keys. Key value sequence separated by whitespace. '
        'e.g. [key] [value] [key] [value]\n This has higher priority '
        'than cfg file but lower than other args. (You can only overwrite '
        'arguments that have already been defined in config file.)',
        default=None,
        nargs='+')
    # How will config be used
    # 1) read cfg argument, and load the cfg file if it's not None
    # 2) Overwrite cfg argument with set_cfgs
    # 3) parse config argument to args.
    # 4) in the end, parse command line argument and overwrite args

    # step 1: read cfg_fn
    args = parser.parse_args()
    if args.cfg is not None or args.set_cfgs is not None:
        from misc.config import CfgNode
        if args.cfg is not None:
            cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
        else:
            cn = CfgNode()
        if args.set_cfgs is not None:
            cn.merge_from_list(args.set_cfgs)
        for k, v in cn.items():
            if not hasattr(args, k):
                print('Warning: key %s not in args' % k)
            setattr(args, k, v)
        # Re-parse into the same namespace so explicit command-line flags win
        # over values that came from the config.
        args = parser.parse_args(namespace=args)

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
    assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './log_%s' % args.id
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    # NOTE(review): the whole return value of utils.if_use_feat is stored in
    # use_att; if that helper returns a (use_fc, use_att) pair this would be a
    # truthy tuple -- confirm against this project's utils.
    args.use_att = utils.if_use_feat(args.caption_model)

    return args