    if k in vars(opt) and getattr(opt, k) is not None:
        assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
    else:
        vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from model

vocab = infos['vocab']  # ix -> word mapping

assert opt.seq_per_img == 5
opt.vse_loss_weight = vars(opt).get('vse_loss_weight', 1)
opt.caption_loss_weight = vars(opt).get('caption_loss_weight', 1)

# Set up the model
model = models.JointModel(opt)
utils.load_state_dict(model, torch.load(opt.model))
if opt.initialize_retrieval is not None:
    print("Make sure the vse opt are the same !!!!!\n" * 100)
    utils.load_state_dict(
        model,
        {k: v for k, v in torch.load(opt.initialize_retrieval).items() if 'vse' in k})
model.cuda()
model.eval()

# Create the DataLoader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
def eval_split(model, crit, loader, eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)

    if eval_kwargs.get('rank', 0):
        infos_path = 'log_fc_con/infos_vse_fc_con-best.pkl'  # 'log_fc_con_discsplit/infos_vse_fc_con_discsplit-best.pkl'
        model_path = 'log_fc_con/model_vse-best.pth'  # 'log_fc_con_discsplit/model_vse-best.pth'
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        rank_model = models.JointModel(infos['opt'])
        utils.load_state_dict(rank_model, torch.load(model_path))
        rank_model.cuda()
        rank_model.eval()
        print('successfully loaded retrieval model!')

    # Make sure we are in evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    seqs = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size
        sys.stdout.flush()

        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'], data['labels'],
                   data['masks'], data['att_masks']]
            tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            with torch.no_grad():
                loss = crit(model(fc_feats, att_feats, labels, att_masks),
                            labels[:, 1:], masks[:, 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # Only keep one feature per image, since features are duplicated
        # seq_per_img times in the batch
        tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
               data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img]
               if data['att_masks'] is not None else None]
        tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, att_masks = tmp

        # forward the model to also get generated samples for each image
        with torch.no_grad():
            seq = model(fc_feats, att_feats, att_masks,
                        opt=eval_kwargs, mode='sample')[0].data

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                pass
                # print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                # print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'],
                                            data['infos'][k]['file_path']) + \
                      '" vis/imgs/img' + str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up the val images budget, bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if eval_kwargs.get('rank', 0):
            seqs.append(padding(seq, 30))
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()
        if n > ix1:
            seq = seq[:(ix1 - n) * loader.seq_per_img]

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' % (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    if eval_kwargs.get('rank', 0):
        seqs = torch.cat(seqs, 0).contiguous()
        seqs = change_seq(seqs, loader.ix_to_word)
    if eval_kwargs.get('vsepp', 0):
        from eval_vsepp import evalrank_vsepp
        from eval_utils_pair import get_transform
        import torchvision.transforms as transforms
        from PIL import Image
        imgids = [_['image_id'] for _ in predictions]
        seqs = seqs[:num_images]
        transform = get_transform('COCO', 'val', None)
        imgs = []
        for i, imgid in enumerate(imgids):
            img_path = '../imgcap/data/raw_images/val2014/COCO_val2014_' + \
                       str(imgid).zfill(12) + '.jpg'
            if i % 100 == 0:
                print('loaded %d images' % i)
            image = Image.open(img_path).convert('RGB')
            image = transform(image)
            imgs.append(image.unsqueeze(0))
        imgs = torch.cat(imgs, 0).contiguous()
        lengths = torch.sum((seqs > 0), 1) + 1
        lengths = lengths.cpu()
        with torch.no_grad():
            evalrank_vsepp(imgs, loader.ix_to_word, seqs, lengths)

    lang_stats = None
    if lang_eval == 1:
        lang_stats = language_eval(eval_kwargs.get('data', 'coco'),
                                   predictions, eval_kwargs['id'], split)

    if eval_kwargs.get('rank', 0):
        ranks = evalrank(rank_model, loader, seqs, eval_kwargs)

    # Switch back to training mode
    model.train()
    return loss_sum / loss_evals, predictions, lang_stats
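
# `padding` (and `change_seq`) are used by eval_split above but are not defined
# in this section. Below is a minimal sketch of `padding`, under the assumption
# that it right-pads each sampled sequence with the pad token 0 to a fixed
# length so per-batch outputs can be concatenated; the repo's actual helper may
# differ.
def padding(seq, length):
    # seq: LongTensor of shape (batch, t); returns a (batch, length) tensor
    out = seq.new_zeros(seq.size(0), length)
    t = min(seq.size(1), length)
    out[:, :t] = seq[:, :t]
    return out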
def train(opt):
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and \
            os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print('Optimizer param group number not matched? '
                  'There must be new parameters. Reinit the optimizer.')

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * \
                    (opt.retrieval_reward_weight_decay_rate ** frac)
            update_lr_flag = False

        start = time.time()
        # Load data from the train split
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['att_masks'],
               data['labels'], data['masks']]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
              .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss, iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob, iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight, iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # Load the retrieval model for evaluation
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v, iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text('Captions',
                                           '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                                           iteration)
                # tf_summary_writer.add_image('images', utils.make_summary_image(), iteration)
                # utils.make_html(opt.id, iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {'loss': val_loss,
                                             'lang_stats': lang_stats,
                                             'predictions': predictions}

            # Save model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                best_val_score_vse = current_score_vse
                best_flag_vse = True

            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model-%d.pth' % (iteration))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['best_val_score_vse'] = best_val_score_vse
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '-%d.pkl' % (iteration)), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
            if best_flag_vse:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_vse-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_vse_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
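
# --- Hypothetical usage sketch (not part of the original source) ---
# A minimal driver for the train(opt) above, assuming the conventional
# `opts.parse_opt()` argument module used by these captioning codebases;
# the module and flag names are assumptions, not the repo's confirmed API.
if __name__ == '__main__':
    import opts  # assumed: repo-level module defining command-line options
    opt = opts.parse_opt()
    train(opt)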
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)
    # opt.ss_prob = 0.0

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
        infos['pix_perss'] = loader.get_personality()
        infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("current epoch: ", epoch)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    opt.xpersonality = loader.get_personality()
    if opt.use_joint == 0:
        # torch.cuda.set_device(0)
        model = models.setup(opt).cuda()
    elif opt.use_joint == 1:
        model = models.JointModel(opt)
        model.cuda()
        # model = models.setup(opt)
    del opt.vocab
    if opt.start_from is not None:
        opt.model = os.path.join(opt.start_from, 'model' + '.pth')
        model.load_state_dict(torch.load(opt.model))
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    # dp_lw_model = LossWrapper(model, opt)  # this is for no cuda

    epoch_done = True
    # Assure in training mode
    # dp_lw_model = lw_model
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and \
            os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
    else:
        print('Optimizer param group number not matched? There must be new parameters. '
              'Reinit the optimizer.')

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist, create it
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                      opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # Assign retrieval loss weight
                if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                    frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every
                    model.retrieval_reward_weight = opt.retrieval_reward_weight * \
                        (opt.retrieval_reward_weight_decay_rate ** frac)

                epoch_done = False

            start = time.time()
            # Load data from the train split
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            with torch.autograd.set_detect_anomaly(True):
                tmp = [data['fc_feats'], data['att_feats'], data['densecap'],
                       data['labels'], data['masks'], data['att_masks'], data['personality']]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, densecap, labels, masks, att_masks, personality = tmp

                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, densecap, labels, masks,
                                        att_masks, personality, data['gts'],
                                        torch.arange(0, len(data['gts'])), sc_flag)

                loss = model_out['loss'].mean()
                loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}, train_loss = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start, train_loss))
            if opt.use_joint == 1:
                prt_str = ""
                for k, v in model.loss().items():
                    prt_str += "{} = {:.3f} ".format(k, v)
                print(prt_str)

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)

                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss,
                                                 'lang_stats': lang_stats,
                                                 'predictions': predictions}

                # Save model if it is improving on the validation result
                if opt.language_eval == 1:
                    if opt.use_joint == 1:
                        current_score = lang_stats['SPICE'] * 100
                    elif opt.use_joint == 0:
                        current_score = lang_stats['CIDEr']  # could use SPICE
                else:
                    if opt.use_joint == 0:
                        current_score = -val_loss
                    elif opt.use_joint == 1:
                        current_score = -val_loss['loss_cap']
                if opt.use_joint == 1:
                    current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

                best_flag = False
                best_flag_vse = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if opt.use_joint == 1:
                    if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                        best_val_score_vse = current_score_vse
                        best_flag_vse = True
                    infos['best_val_score_vse'] = best_val_score_vse

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')
                if best_flag_vse:
                    save_checkpoint(model, infos, optimizer, append='vse-best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
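
# `add_summary_value` is called throughout the second train() but is not
# defined in this section. A minimal sketch, assuming `writer` is a
# tensorboardX / torch.utils.tensorboard SummaryWriter (or None when
# tensorboard is unavailable); the repo's actual helper may differ.
def add_summary_value(writer, key, value, iteration):
    # No-op when tensorboard logging is disabled (writer is None).
    if writer:
        writer.add_scalar(key, value, iteration)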