def parse_opt():
    """Build the training argument parser, parse the options and validate them.

    Options are resolved in three layers, later layers winning:

    1. argparse defaults,
    2. values from the YAML file given by ``--cfg`` (only keys that are
       already known options are applied; unknown keys produce a warning),
    3. values given explicitly on the command line.

    Returns:
        argparse.Namespace: the parsed options, with ``checkpoint_path`` and
        ``start_from`` defaulted, ``use_fc`` / ``use_att`` derived from the
        caption model, and ``att_feat_size`` enlarged by 5 when ``use_box``
        is set.

    Raises:
        AssertionError: if one of the sanity checks on the options fails.
    """
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument(
        '--input_json', type=str, default='data/coco.json',
        help='path to the json file containing additional info and vocab')
    parser.add_argument(
        '--input_fc_dir', type=str, default='data/cocotalk_fc',
        help='path to the directory containing the preprocessed fc feats')
    parser.add_argument(
        '--input_att_dir', type=str, default='data/cocotalk_att',
        help='path to the directory containing the preprocessed att feats')
    parser.add_argument(
        '--input_box_dir', type=str, default='data/cocotalk_box',
        help='path to the directory containing the boxes of att feats')
    parser.add_argument(
        '--input_label_h5', type=str, default='data/coco_label.h5',
        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument(
        '--start_from', type=str, default=None,
        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                'infos.pkl' : configuration;
                'checkpoint' : paths to model file(s) (created by tf).
                Note: this file contains absolute paths, be careful when moving files around;
                'model.ckpt-*' : file(s) with model definition (created by tf)
             """)
    parser.add_argument(
        '--cached_tokens', type=str, default='coco-train-idxs',
        help='Cached token file for calculating cider score during self critical training.')

    # Model settings
    parser.add_argument(
        '--caption_model', type=str, default="show_tell",
        help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
    parser.add_argument(
        '--rnn_size', type=int, default=512,
        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument(
        '--input_encoding_size', type=int, default=512,
        help='the encoding size of each token in the vocabulary, and the image.')
    parser.add_argument(
        '--att_hid_size', type=int, default=512,
        help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
    parser.add_argument('--fc_feat_size', type=int, default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size', type=int, default=2048,
                        help='2048 for resnet, 512 for vgg')
    # NOTE(review): this help text duplicates --num_layers; presumably the
    # option controls the depth of the logit head -- confirm and reword.
    parser.add_argument('--logit_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument(
        '--use_bn', type=int, default=0,
        help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')

    # feature manipulation
    parser.add_argument('--norm_att_feat', type=int, default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box', type=int, default=0,
                        help='If use box features')
    parser.add_argument('--norm_box_feat', type=int, default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs', type=int, default=-1,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='minibatch size')
    parser.add_argument('--grad_clip_mode', type=str, default='value',
                        help='value or norm')
    parser.add_argument('--grad_clip_value', type=float, default=0.1,
                        help='clip gradients at this value/max_norm')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5,
                        help='strength of dropout in the Language Model RNN')
    # NOTE(review): the help text talks about finetuning the CNN, but the
    # option name suggests it gates self-critical training -- confirm.
    parser.add_argument(
        '--self_critical_after', type=int, default=-1,
        help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
    parser.add_argument(
        '--seq_per_img', type=int, default=5,
        help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')

    # Sample related
    parser.add_argument(
        '--beam_size', type=int, default=1,
        help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
    parser.add_argument('--max_length', type=int, default=20,
                        help='Maximum length during sampling')
    parser.add_argument('--length_penalty', type=str, default='',
                        help='wu_X or avg_X, X is the alpha')
    parser.add_argument('--block_trigrams', type=int, default=0,
                        help='block repeated trigram.')
    parser.add_argument('--remove_bad_endings', type=int, default=0,
                        help='Remove bad endings')

    # Optimization: for the Language Model
    parser.add_argument(
        '--optim', type=str, default='adam',
        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
    parser.add_argument('--learning_rate', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument(
        '--learning_rate_decay_start', type=int, default=-1,
        help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
    parser.add_argument(
        '--learning_rate_decay_every', type=int, default=3,
        help='every how many iterations thereafter to drop LR?(in epoch)')
    # The help text used to be a copy of the one above; the training loop
    # uses this value as the base of ``rate ** frac``.
    parser.add_argument(
        '--learning_rate_decay_rate', type=float, default=0.8,
        help='multiplicative factor applied to the LR at each decay step')
    parser.add_argument('--optim_alpha', type=float, default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument(
        '--optim_epsilon', type=float, default=1e-8,
        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight_decay')

    # Transformer
    parser.add_argument('--label_smoothing', type=float, default=0, help='')
    parser.add_argument('--noamopt', action='store_true', help='')
    parser.add_argument('--noamopt_warmup', type=int, default=2000, help='')
    parser.add_argument('--noamopt_factor', type=float, default=1, help='')
    parser.add_argument('--reduce_on_plateau', action='store_true', help='')

    parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
                        help='at what iteration to start decay gt probability')
    parser.add_argument(
        '--scheduled_sampling_increase_every', type=int, default=5,
        help='every how many iterations thereafter to gt probability')
    parser.add_argument('--scheduled_sampling_increase_prob', type=float,
                        default=0.05, help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob', type=float,
                        default=0.25,
                        help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument(
        '--val_images_use', type=int, default=3200,
        help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
    parser.add_argument(
        '--save_checkpoint_every', type=int, default=2500,
        help='how often to save a model checkpoint (in iterations)?')
    parser.add_argument(
        '--save_every_epoch', action='store_true',
        help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
    parser.add_argument('--save_history_ckpt', type=int, default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='directory to store checkpointed models')
    parser.add_argument(
        '--language_eval', type=int, default=0,
        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    # The old help advertised "0 = disable", but the value is asserted to be
    # positive below and is used as a modulus in the training loop.
    parser.add_argument(
        '--losses_log_every', type=int, default=25,
        help='How often do we snapshot losses, for inclusion in the progress dump? (must be > 0)')
    parser.add_argument(
        '--load_best_score', type=int, default=1,
        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument(
        '--id', type=str, default='',
        help='an id identifying this run/job. used in cross-val and appended when writing progress files')
    parser.add_argument('--train_only', type=int, default=0,
                        help='if true then use 80k, else use 110k')

    # Reward
    parser.add_argument('--cider_reward_weight', type=float, default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight', type=float, default=0,
                        help='The reward weight from bleu4')

    # Structure_loss
    parser.add_argument('--structure_loss_weight', type=float, default=1,
                        help='')
    parser.add_argument('--structure_after', type=int, default=-1, help='T')
    parser.add_argument('--structure_loss_type', type=str, default='seqnll',
                        help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true',
                        help='')
    parser.add_argument('--entropy_reward_weight', type=float, default=0,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight', type=float, default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is need during training
    # NOTE(review): help text duplicates --cider_reward_weight; presumably
    # this is the number of samples drawn per image -- confirm and reword.
    parser.add_argument('--train_sample_n', type=int, default=16,
                        help='The reward weight from cider')
    parser.add_argument('--train_sample_method', type=str, default='sample',
                        help='')
    parser.add_argument('--train_beam_size', type=int, default=1, help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument(
        '--cfg', type=str, default=None,
        help='configuration; similar to what is used in detectron')

    # How will config be used
    # 1) read cfg argument, and load the cfg file if it's not None
    # 2) parse config argument
    # 3) in the end, parse command line argument

    # step 1: read cfg_fn
    args = parser.parse_args()
    if args.cfg is not None:
        from misc.config import CfgNode
        cn = CfgNode.load_yaml_with_base(args.cfg)
        for k, v in cn.items():
            if hasattr(args, k):
                setattr(args, k, v)
            else:
                print('Warning: key %s not in args' % k)
        # Re-parse on top of the config-populated namespace: argparse only
        # fills in defaults for attributes that are missing, so config values
        # survive unless the command line sets the option explicitly.
        args = parser.parse_args(namespace=args)

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert 0 <= args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval in (0, 1), "language_eval should be 0 or 1"
    assert args.load_best_score in (0, 1), "load_best_score should be 0 or 1"
    assert args.train_only in (0, 1), "train_only should be 0 or 1"

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './log_%s' % args.id
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    args.use_fc, args.use_att = utils.if_use_feat(args.caption_model)
    if args.use_box:
        args.att_feat_size = args.att_feat_size + 5

    return args
def train(opt):
    """Run the captioning training loop until ``opt.max_epochs`` is reached.

    The function builds the data loader and model, optionally resumes from
    the checkpoint directory ``opt.start_from``, then alternates training
    steps with periodic validation (every ``opt.save_checkpoint_every``
    iterations), writing checkpoints, ``infos`` and ``histories`` pickles
    into ``opt.checkpoint_path``.

    Args:
        opt: option namespace (as produced by the option parser). It is
            mutated in place: ``use_fc``, ``use_att``, ``att_feat_size``,
            ``vocab_size``, ``seq_length``, ``vocab``, ``current_lr`` and
            ``ss_prob`` are written during training.

    Note:
        ``RuntimeError`` and ``KeyboardInterrupt`` raised inside the loop are
        caught: a checkpoint is saved, the traceback is printed and the
        function returns normally.
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.vocab = loader.get_vocab()

    # ``tb`` may be falsy when tensorboard is unavailable; the writer is then
    # falsy as well.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # The pickle names are derived from the last component of the
        # start_from path (not from opt.id).
        # NOTE(review): save_checkpoint below writes 'infos_' + opt.id, so
        # resuming only finds the files when the directory name equals the
        # id -- confirm this is the intended convention.
        run_tag = opt.start_from.split('/')[-1]
        infos_file = os.path.join(opt.start_from, 'infos_' + run_tag + '.pkl')
        histories_file = os.path.join(opt.start_from,
                                      'histories_' + run_tag + '.pkl')

        # open old infos and check if models are compatible
        # NOTE: pickle files must come from a trusted source.
        with open(infos_file, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(histories_file):
            with open(histories_file, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: it used to be left undefined when
    # load_best_score != 1, which raised NameError at the first validation.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Write model/optimizer state and the infos/histories pickles.

        ``append`` (when non-empty) is added as a ``-<append>`` suffix to
        every file name; ``histories`` is only written when truthy.
        """
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path,
                               'infos_' + opt.id + '%s.pkl' % (append)),
                  'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '%s.pkl' % (append)),
                      'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                      opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                sc_flag = (opt.self_critical_after != -1
                           and epoch >= opt.self_critical_after)
                # The scorer is initialised in both modes, as before.
                init_scorer(opt.cached_tokens)

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'],
                   data['masks'], data['att_masks'], data['sents_mask']]
            tmp = [t if t is None else t.cuda() for t in tmp]
            fc_feats, att_feats, labels, masks, att_masks, sents_mask = tmp
            box_inds = None

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])),
                                    sc_flag, box_inds, epoch, sents_mask)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # lang_stats can be None (it is checked for None below),
                    # so guard before the membership test.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss,
                                                 'lang_stats': lang_stats,
                                                 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best effort: keep the progress made so far, then report the error.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):
    """Train a captioning model epoch by epoch with a progress bar.

    Prints the run configuration, derives a more detailed checkpoint
    directory name from the batch size / warm-up / GPU count, optionally
    resumes from ``opt.start_from`` (files suffixed with
    ``opt.name_append``), and then trains until ``opt.max_epochs``.
    A checkpoint is written at the end of every epoch; validation runs
    every ``opt.save_checkpoint_every`` iterations once ``epoch > 3``.

    Args:
        opt: option namespace. It is mutated in place (``checkpoint_path``,
            ``use_fc``, ``use_att``, ``att_feat_size``, ``vocab_size``,
            ``seq_length``, ``losses_log_every``, ``current_lr``,
            ``ss_prob``).

    Note:
        ``RuntimeError`` and ``KeyboardInterrupt`` raised inside the loop are
        caught: an ``interrupt`` checkpoint is saved, the traceback is
        printed and the function returns normally.
    """
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    print("Caption model {}".format(opt.caption_model))
    print("refine aoa {}".format(opt.refine_aoa))
    print("Number of aoa module {}".format(opt.aoa_num))
    print("Self Critic After {}".format(opt.self_critical_after))
    print("learning_rate_decay_every {}".format(opt.learning_rate_decay_every))

    # Use more data to fine tune the model for better challenge results.
    if opt.use_val or opt.use_test:
        print("+++++++++++It is a refining training+++++++++++++++")
        print("===========Val is {} used for training ===========".format(
            '' if opt.use_val else 'not'))
        print("===========Test is {} used for training ===========".format(
            '' if opt.use_test else 'not'))
    print("=====================================================")

    # set more detail name of checkpoint paths
    checkpoint_path_suffix = "_bs{}".format(opt.batch_size)
    if opt.use_warmup:
        checkpoint_path_suffix += "_warmup"
    if torch.cuda.device_count() > 1:
        checkpoint_path_suffix += "_gpu{}".format(torch.cuda.device_count())

    # Keep a trailing '_rl' marker at the very end of the directory name.
    if opt.checkpoint_path.endswith('_rl'):
        opt.checkpoint_path = opt.checkpoint_path[:-3] + checkpoint_path_suffix + '_rl'
    else:
        opt.checkpoint_path += checkpoint_path_suffix
    print("Save model to {}".format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    # Number of batches whose gradients are accumulated per optimizer step.
    acc_steps = getattr(opt, 'acc_steps', 1)

    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    # Log once per epoch. Clamp to 1 so a train split smaller than one batch
    # does not turn the modulus below into a division by zero.
    opt.losses_log_every = max(
        1, len(loader.split_ix['train']) // opt.batch_size)
    print("Evaluate on each {} iterations".format(opt.losses_log_every))

    # The writer stays None when summaries are disabled or tensorboard is
    # unavailable; every use below is guarded accordingly.
    tb_summary_writer = None
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    # load checkpoint
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # NOTE: pickle files must come from a trusted source.
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]

        # this sanity check may not work well, and comment it if necessary
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        # start from scratch
        print("==============================================")
        print("Initialize training process from all begining")
        print("==============================================")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("==================start from {} iterations -- {} epoch".format(
        iteration, epoch))

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind both names: they used to be undefined when
    # load_best_score != 1, which raised NameError later in the loop.
    best_val_score = None
    best_epoch = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_epoch = infos.get('best_epoch', None)
        print("========best history val cider score: {} in epoch {}=======".
              format(best_val_score, best_epoch))

    # sanity check for the saved model name has a correct index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) - epoch == 1, "dismatch in the model index and the real epoch number"
        # Per-epoch checkpoints are written before infos['epoch'] is
        # refreshed, so the stored epoch lags the file index by one.
        epoch += 1

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if torch.cuda.device_count() > 1:
        dp_model = torch.nn.DataParallel(model)
    else:
        dp_model = model
    lw_model = LossWrapper1(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Write model/optimizer state and the infos/histories pickles.

        ``append`` (when non-empty) is added as a ``_<append>`` suffix to
        every file name; ``histories`` is only written when truthy.
        """
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        infos_file = os.path.join(opt.checkpoint_path,
                                  'infos' + '%s.pkl' % (append))
        with open(infos_file, 'wb') as f:
            utils.pickle_dump(infos, f)
        print("Save training information to {}".format(infos_file))

        if histories:
            histories_file = os.path.join(opt.checkpoint_path,
                                          'histories' + '%s.pkl' % (append))
            with open(histories_file, 'wb') as f:
                utils.pickle_dump(histories, f)
            print("Save training historyes to {}".format(histories_file))

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor * opt.refine_lr_decay
                    else:
                        opt.current_lr = opt.learning_rate
                    infos['current_lr'] = opt.current_lr
                    print("Current Learning Rate is: {}".format(
                        opt.current_lr))
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                # The range only paces the progress bar; the loader keeps
                # its own position in the split.
                for _ in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    start = time.time()
                    # Linear LR warm-up over the first noamopt_warmup steps.
                    if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup):
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)

                    # Load data from train split (0)
                    data = loader.get_batch('train')

                    # Gradients are cleared only at the start of each
                    # accumulation window.
                    if (iteration % acc_steps == 0):
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'],
                        data['flag_feats'], data['labels'], data['masks'],
                        data['att_masks']
                    ]
                    tmp = [t if t is None else t.cuda() for t in tmp]
                    fc_feats, att_feats, flag_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, flag_feats,
                                            labels, masks, att_masks,
                                            data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
                    # Scale so the accumulated gradient is an average.
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if (iteration + 1) % acc_steps == 0:
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    if not sc_flag:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss,
                                    end - start))
                    else:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        epoch += 1
                        epoch_done = True
                        # save after each epoch
                        save_checkpoint(model, infos, optimizer)
                        if epoch > 15:
                            # To save memory, you can comment this part
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(epoch))
                        print(
                            "====================================================="
                        )
                        print(
                            "======Best Cider = {} in epoch {}: iter {}!======"
                            .format(best_val_score, best_epoch,
                                    infos.get('best_itr', None)))
                        print(
                            "====================================================="
                        )

                    # Write training history into summary
                    if (iteration % opt.losses_log_every == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer,
                                          'loss/train_loss', train_loss,
                                          iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[iteration] = (
                            train_loss if not sc_flag
                            else model_out['reward'].mean())
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # unnecessary to eval from the beginning
                    # NOTE(review): eval_ is not defined in this function;
                    # presumably a module-level switch -- confirm.
                    if (iteration % opt.save_checkpoint_every == 0) and eval_ and epoch > 3:
                        # eval model
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        # Evaluate on 'test' when only val was folded into
                        # training, otherwise on 'val'. The case where both
                        # val and test are used for training previously left
                        # val_split unbound; it now falls back to 'val'.
                        # NOTE(review): confirm the desired split there.
                        if opt.use_val and not opt.use_test:
                            val_split = 'test'
                        else:
                            val_split = 'val'
                        eval_kwargs = {
                            'split': val_split,
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            # lang_stats can be None (checked below), so
                            # guard before the membership test.
                            if lang_stats is not None and 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss',
                                              val_loss, iteration)

                            if lang_stats is not None:
                                bleu_dict = {
                                    k: v
                                    for k, v in lang_stats.items()
                                    if 'Bleu' in k
                                }
                                # The writer is None without tensorboard.
                                if bleu_dict and tb_summary_writer:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            infos['best_epoch'] = epoch
                            infos['best_itr'] = iteration
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        save_checkpoint(model, infos, optimizer, histories)
                        if opt.save_history_ckpt:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(iteration))

                        if best_flag:
                            best_epoch = epoch
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))

            # reset: later epochs start from the beginning of the split
            start_Img_idx = 0

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                # The writer is None when summaries are disabled.
                if tb_summary_writer:
                    tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best effort: keep the progress made so far, then report the error.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):
    """Train a captioning model, evaluating and checkpointing once per epoch.

    The loop runs until ``opt.max_epochs`` is reached (``-1`` means run
    forever). Each time the data loader reports that the train split wrapped
    around, the model is evaluated on the ``'val'`` split and the model,
    optimizer, ``infos`` and ``histories`` are written to
    ``opt.checkpoint_path``.

    Args:
        opt: options namespace (as produced by ``parse_opt``). It is mutated
            in place: ``use_fc``, ``use_att``, ``vocab_size``, ``seq_length``,
            ``current_lr`` and ``ss_prob`` are written during training.

    Returns:
        None.
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # `tb` may be falsy (tensorboard unavailable); the writer is then falsy too.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: it is compared against after the first
    # evaluation even when the previous best score is not loaded.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        # Resume the warmup schedule from the restored iteration count.
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        saved_optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(saved_optimizer_path):
            optimizer.load_state_dict(torch.load(saved_optimizer_path))

    # Running sum of the training loss and number of batches since the last
    # line written to the text log.
    total_loss = 0
    times = 0

    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'],
            data['masks'], data['att_masks']
        ]
        # Entries may be None; everything else is moved to the GPU.
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # The epoch counter itself is advanced in the evaluation block
            # below, which is guarded by the same condition.
            epoch_done = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = (train_loss if not sc_flag
                                       else np.mean(reward[:, 0]))
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model (once per epoch)
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                # lang_stats can be None (see the check below), so guard the
                # membership test instead of raising TypeError.
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                # The full language-stats object is written to the text log;
                # only CIDEr is used as the model-selection score.
                with open('train_log_%s.txt' % opt.id, 'a') as log_file:
                    log_file.write(
                        'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}'
                        .format(epoch, str(datetime.now()),
                                str(total_loss / times), str(val_loss),
                                str(lang_stats)))
                    log_file.write('\n')
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            if not os.path.isdir(opt.checkpoint_path):
                # makedirs also creates missing parent directories.
                os.makedirs(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            if opt.save_history_ckpt:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history

            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(infos, f)
            if opt.save_history_ckpt:
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-%d.pkl' % (iteration)),
                          'wb') as f:
                    # Same helper as every other dump in this function.
                    utils.pickle_dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'),
                          'wb') as f:
                    utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def train(opt):
    """Train a captioning model together with a ResNet image backbone.

    Images are passed through ``cnn_model`` on every batch; the backbone is
    kept frozen while ``opt.fix_cnn`` is set or ``epoch < cnn_after`` and is
    partially fine-tuned afterwards. Every ``opt.save_checkpoint_every``
    iterations the model is evaluated on the ``'val'`` split and checkpoints
    are written to ``opt.checkpoint_path``.

    ``RuntimeError`` and ``KeyboardInterrupt`` raised inside the loop are
    caught: a checkpoint is saved, the traceback is printed and the function
    returns normally.

    Args:
        opt: options namespace (as produced by ``parse_opt``). It is mutated
            in place (``use_fc``, ``use_att``, ``vocab_size``, ``seq_length``,
            ``current_lr``, ``ss_prob``).

    Returns:
        None.
    """
    print(opt)
    # To reproduce training results
    init_seed()

    # Image Preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt, transform=transform)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # `tb` may be falsy (tensorboard unavailable); the writer is then falsy too.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos_' + opt.id + '-best.pkl')
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    # NOTE(review): placed after the if/else so resumed runs record the
    # current options as well — confirm against the intended layout.
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: it is compared against after the first
    # evaluation even when the previous best score is not loaded.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    use_cuda = torch.cuda.is_available()

    # The vocab is attached only for model construction and removed again.
    opt.vocab = loader.get_vocab()
    model = models.setup(opt)
    if use_cuda:
        model = model.cuda()
    del opt.vocab

    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    cnn_model = ResnetBackbone()
    if use_cuda:
        cnn_model = cnn_model.cuda()
    if opt.start_from is not None:
        # Saved CNN weights carry a "module." prefix (they are saved from the
        # DataParallel wrapper below); strip it to load into the bare model.
        model_dict = cnn_model.state_dict()
        predict_dict = torch.load(
            os.path.join(opt.start_from, 'cnn_model-best.pth'))
        model_dict = {
            k: predict_dict["module." + k]
            for k in model_dict
            if "module." + k in predict_dict
        }
        cnn_model.load_state_dict(model_dict)
    cnn_model = torch.nn.DataParallel(cnn_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        # Resume the warmup schedule from the restored iteration count.
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # Load the optimizer. Check for the same file that is actually loaded
    # (the '-best' snapshot, matching the infos and CNN weights above).
    if vars(opt).get('start_from', None) is not None:
        saved_optimizer_path = os.path.join(opt.start_from,
                                            'optimizer-best.pth')
        if os.path.isfile(saved_optimizer_path):
            optimizer.load_state_dict(torch.load(saved_optimizer_path))

    def save_checkpoint(model, cnn_model, infos, optimizer, histories=None,
                        append=''):
        """Write model, CNN, optimizer, infos (and histories) to disk.

        ``append`` is a suffix tag; a non-empty tag is inserted as
        ``-<append>`` before the file extension.
        """
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)

        # Transformer model
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))

        # CNN model. Always overwrite: the backbone may be fine-tuned, so
        # skipping an existing file would leave stale CNN weights next to a
        # newer captioning model.
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'cnn_model%s.pth' % (append))
        torch.save(cnn_model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path,
                               'infos_' + opt.id + '%s.pkl' % (append)),
                  'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '%s.pkl' % (append)),
                      'wb') as f:
                utils.pickle_dump(histories, f)

    # First epoch at which the CNN may be fine-tuned (unless opt.fix_cnn).
    cnn_after = 3

    try:
        while True:
            if epoch_done:
                if opt.fix_cnn or epoch < cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                    cnn_optimizer = None
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    frozen = cnn_model._modules['module']._modules['resnet_conv'][:5]
                    for module in frozen._modules.values():
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()
                    # Constructing CNN parameters for optimization, only
                    # fine-tuning higher layers. The optimizer is rebuilt at
                    # every epoch boundary.
                    in_self_critical = (opt.self_critical_after != -1
                                        and epoch >= opt.self_critical_after)
                    cnn_optimizer = torch.optim.Adam(
                        (filter(lambda p: p.requires_grad,
                                cnn_model.parameters())),
                        lr=2e-6 if in_self_critical else 5e-5,
                        betas=(0.8, 0.999))

                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate

                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            if iteration % opt.losses_log_every == 0:
                print('Read data:', time.time() - start)

            if use_cuda:
                torch.cuda.synchronize()
            start = time.time()

            # Run the images through the CNN; the result replaces the raw
            # input under the same key.
            if use_cuda:
                data['att_feats'] = cnn_model(data['att_feats'].cuda())
            else:
                data['att_feats'] = cnn_model(data['att_feats'])
            data['att_feats'] = repeat_feat(data['att_feats'])

            tmp = [data['fc_feats'], data['att_feats'], data['labels'],
                   data['masks'], data['att_masks']]
            if use_cuda:
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            if cnn_optimizer is not None:
                cnn_optimizer.zero_grad()

            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])),
                                    sc_flag)

            loss = model_out['loss'].mean()
            loss.backward()
            # Gradient clipping (utils.clip_gradient) and FGM adversarial
            # training are commented out in this variant of the loop.
            optimizer.step()
            if cnn_optimizer is not None:
                cnn_optimizer.step()

            train_loss = loss.item()
            if use_cuda:
                torch.cuda.synchronize()
            end = time.time()
            if not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            elif iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(),
                              end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if iteration % opt.losses_log_every == 0:
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      model_out['reward'].mean(), iteration)

                loss_history[iteration] = (train_loss if not sc_flag
                                           else model_out['reward'].mean())
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if iteration % opt.save_checkpoint_every == 0:
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                eval_kwargs["cnn_model"] = cnn_model
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # lang_stats can be None (see the check below), so guard
                    # the membership test instead of raising TypeError.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)

                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss',
                                  val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos, optimizer,
                                    append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best-effort rescue checkpoint; the exception is reported but not
        # re-raised.
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def parse_opt():
    """Build the training argument parser, parse ``sys.argv`` and validate.

    Options may also come from a YAML config (``--cfg``) and/or inline
    overrides (``--set_cfgs``); explicit command-line arguments take the
    highest priority.

    Returns:
        argparse.Namespace: the parsed options, with ``checkpoint_path``,
        ``start_from`` and ``use_att`` filled in.

    Raises:
        AssertionError: if a numeric option is outside its valid range.
    """
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument(
        '--input_json', type=str, default='./data/output.json',
        help='path to the json file containing additional info and vocab')
    parser.add_argument(
        '--input_att_dir', type=str, default='./data/bottom_up_att',
        help='path to the directory containing the preprocessed default att feats')
    parser.add_argument(
        '--input_box_dir', type=str, default='./data/bottom_up_box',
        help='path to the directory containing the boxes of default att feats')
    parser.add_argument(
        '--input_label_h5', type=str, default='./data/output_label.h5',
        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument(
        '--input_att1_dir', type=str, default='./data/bottom_up_att1',
        help='path to the directory containing the preprocessed semantic att feats')
    parser.add_argument(
        '--input_box1_dir', type=str, default='./data/bottom_up_box1',
        help='path to the directory containing the boxes of semantic att feats')
    parser.add_argument(
        '--start_from', type=str, default="./checkpoint",
        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                        'infos.pkl'         : configuration;
                        'checkpoint'        : paths to model file(s) (created by tf).
                                              Note: this file contains absolute paths, be careful when moving files around;
                        'model.ckpt-*'      : file(s) with model definition (created by tf)
                    """)

    # Model settings
    parser.add_argument('--caption_model', type=str, default="updown",
                        help='updown, topdown')
    parser.add_argument(
        '--rnn_size', type=int, default=512,
        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument(
        '--input_encoding_size', type=int, default=300,
        help='the encoding size of each token in the vocabulary, and the image.')
    parser.add_argument(
        '--att_hid_size', type=int, default=512,
        help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
    parser.add_argument('--fc_feat_size', type=int, default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size', type=int, default=2048,
                        help='2048 for resnet, 512 for vgg')
    # NOTE(review): this help text duplicates --num_layers and is presumably
    # a copy-paste; confirm what logit_layers controls before rewording.
    parser.add_argument('--logit_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument(
        '--use_bn', type=int, default=0,
        help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')

    # feature manipulation
    parser.add_argument('--norm_att_feat', type=int, default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box', type=int, default=1,
                        help='If use box features')
    parser.add_argument('--norm_box_feat', type=int, default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs', type=int, default=20,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='minibatch size')
    parser.add_argument('--grad_clip_mode', type=str, default='value',
                        help='value or norm')
    parser.add_argument(
        '--grad_clip_value', type=float, default=0.1,
        help='clip gradients at this value/max_norm, 0 means no clipping')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5,
                        help='strength of dropout in the Language Model RNN')
    # Help corrected: the training loops use this option to switch on
    # self-critical training, not CNN fine-tuning.
    parser.add_argument(
        '--self_critical_after', type=int, default=-1,
        help='After what epoch do we start self-critical training? (-1 = disable; never, 0 = from start)')
    parser.add_argument(
        '--seq_per_img', type=int, default=2,
        help='number of captions to sample for each image during training.')

    # Sample related
    add_eval_sample_opts(parser)

    # Optimization: for the Language Model
    parser.add_argument(
        '--optim', type=str, default='adam',
        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
    parser.add_argument('--learning_rate', type=float, default=0.0005,
                        help='learning rate')
    parser.add_argument(
        '--learning_rate_decay_start', type=int, default=-1,
        help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
    parser.add_argument(
        '--learning_rate_decay_every', type=int, default=10,
        help='every how many iterations thereafter to drop LR?(in epoch)')
    # Help corrected: this is the multiplicative factor, not an interval
    # (lr = learning_rate * decay_rate ** number_of_decay_steps).
    parser.add_argument(
        '--learning_rate_decay_rate', type=float, default=0.8,
        help='multiplicative factor applied to the LR at each decay step')
    parser.add_argument('--optim_alpha', type=float, default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument(
        '--optim_epsilon', type=float, default=1e-8,
        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight_decay')

    parser.add_argument('--scheduled_sampling_start', type=int, default=100,
                        help='at what iteration to start decay gt probability')
    parser.add_argument(
        '--scheduled_sampling_increase_every', type=int, default=5,
        help='every how many iterations thereafter to gt probability')
    parser.add_argument('--scheduled_sampling_increase_prob', type=float,
                        default=0.05, help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob', type=float,
                        default=0.25, help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument(
        '--val_images_use', type=int, default=-1,
        help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
    parser.add_argument(
        '--save_checkpoint_every', type=int, default=1500,
        help='how often to save a model checkpoint (in iterations)?')
    # NOTE(review): store_false means this defaults to True and passing the
    # flag turns it OFF, which reads inverted relative to the help text.
    # Left unchanged because flipping it would alter the default behaviour.
    parser.add_argument(
        '--save_every_epoch', action='store_false',
        help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
    parser.add_argument('--save_history_ckpt', type=int, default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path', type=str, default="./checkpoint",
                        help='directory to store checkpointed models')
    parser.add_argument(
        '--language_eval', type=int, default=1,
        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    parser.add_argument(
        '--losses_log_every', type=int, default=1000,
        help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
    parser.add_argument(
        '--load_best_score', type=int, default=1,
        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument(
        '--id', type=str, default='up',
        help='an id identifying this run/job. used in cross-val and appended when writing progress files')

    # Reward
    parser.add_argument('--cider_reward_weight', type=float, default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight', type=float, default=0,
                        help='The reward weight from bleu4')

    # Structure_loss
    parser.add_argument('--structure_loss_weight', type=float, default=1,
                        help='')
    parser.add_argument('--structure_after', type=int, default=-1, help='T')
    parser.add_argument('--structure_loss_type', type=str,
                        default='new_self_critical', help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true',
                        help='')
    parser.add_argument('--entropy_reward_weight', type=float, default=1,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight', type=float, default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is need during training
    # NOTE(review): this help text duplicates --cider_reward_weight and is
    # presumably a copy-paste; confirm the intended meaning.
    parser.add_argument('--train_sample_n', type=int, default=3,
                        help='The reward weight from cider')
    parser.add_argument('--train_sample_method', type=str, default='sample',
                        help='')
    parser.add_argument('--train_beam_size', type=int, default=1, help='')

    # Used for self critical
    parser.add_argument('--sc_sample_method', type=str, default='greedy',
                        help='')
    parser.add_argument('--sc_beam_size', type=int, default=1, help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument(
        '--cfg', type=str, default=None,
        help='configuration; similar to what is used in detectron')
    # The help fragments are joined by implicit concatenation; each fragment
    # ends with a space so words do not run together.
    parser.add_argument(
        '--set_cfgs', dest='set_cfgs',
        help='Set config keys. Key value sequence seperate by whitespace. '
             'e.g. [key] [value] [key] [value]\n This has higher priority '
             'than cfg file but lower than other args. (You can only overwrite '
             'arguments that have alerady been defined in config file.)',
        default=None, nargs='+')

    # How will config be used
    # 1) read cfg argument, and load the cfg file if it's not None
    # 2) Overwrite cfg argument with set_cfgs
    # 3) parse config argument to args.
    # 4) in the end, parse command line argument and overwrite args

    # step 1: read cfg_fn
    args = parser.parse_args()
    if args.cfg is not None or args.set_cfgs is not None:
        from misc.config import CfgNode
        if args.cfg is not None:
            cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
        else:
            cn = CfgNode()
        if args.set_cfgs is not None:
            cn.merge_from_list(args.set_cfgs)
        for k, v in cn.items():
            if not hasattr(args, k):
                print('Warning: key %s not in args' % k)
            setattr(args, k, v)
        # Re-parse on top of the config-populated namespace so that options
        # given explicitly on the command line win over the config values.
        args = parser.parse_args(namespace=args)

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
    assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './log_%s' % args.id
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    # NOTE(review): other callers in this file unpack two values from
    # utils.if_use_feat (use_fc, use_att); verify the return shape here.
    args.use_att = utils.if_use_feat(args.caption_model)
    return args