def run_train(opt, model, crit, optimizer, loader, device, logger=None, epoch=-1, return_all_info=False, **kwargs):
    """Run one optimisation pass over ``loader`` and return the recorded loss.

    For every batch the gradients are zeroed, the forward results are obtained
    through ``get_forword_results``, the loss from ``crit.get_loss`` is
    back-propagated, ``clip_grad_value_`` is applied with ``opt['grad_clip']``
    and the optimizer is stepped.

    Args:
        opt: Option mapping; ``opt['grad_clip']`` is read here and the whole
            mapping is forwarded to ``get_forword_results``.
        model: Model to optimise; switched to train mode first.
        crit: Criterion exposing ``reset_loss_recorder``, ``get_loss`` and
            ``get_loss_info``.
        optimizer: Object exposing ``zero_grad`` and ``step``.
        loader: Iterable of batches whose ``dataset`` provides ``get_vocab``.
        device: Forwarded to ``get_forword_results``.
        logger: Optional object with ``write_text``; receives one
            tab-separated summary line when given.
        epoch: Forwarded to ``crit.get_loss``.
        return_all_info: When true, return every value reported by
            ``crit.get_loss_info``; otherwise only the first one.
        **kwargs: Forwarded to ``get_forword_results``.

    Returns:
        The full loss-info sequence, or just its first element.
    """
    model.train()
    crit.reset_loss_recorder()
    vocab = loader.dataset.get_vocab()

    progress = ProgressBar(len(loader))
    progress.start()
    for batch in loader:
        optimizer.zero_grad()
        outputs = get_forword_results(
            opt, model, batch,
            device=device, only_data=False, vocab=vocab, **kwargs)
        batch_loss = crit.get_loss(outputs, epoch=epoch)
        batch_loss.backward()
        clip_grad_value_(model.parameters(), opt['grad_clip'])
        optimizer.step()
        progress.update()

    name, loss_info = crit.get_loss_info()
    if logger is not None:
        summary = '\t'.join(
            '%10s: %05.3f' % (label, value)
            for label, value in zip(name, loss_info))
        logger.write_text(summary)

    return loss_info if return_all_info else loss_info[0]
class DummyIterBasedRunner(IterBasedRunner):
    """Fake iteration-based runner.

    Nothing is optimised here: every simulated iteration only fires the
    train-iteration hooks and records the current learning rate, so that
    the full learning-rate trajectory can be returned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built un-started; ``run`` starts it right before the main loop.
        self.progress_bar = ProgressBar(self._max_iters, start=False)

    def train(self, data_loader, **kwargs):
        """Simulate one training iteration.

        Returns:
            A one-element list holding the value of ``self.current_lr()``
            sampled between the before/after train-iter hooks.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        next(data_loader)  # advance the loader; the batch itself is unused

        self.call_hook('before_train_iter')
        sampled_lr = self.current_lr()
        self.call_hook('after_train_iter')

        self._inner_iter += 1
        self._iter += 1
        self.progress_bar.update(1)
        return [sampled_lr]

    def run(self, data_loaders, workflow, **kwargs):
        """Walk through ``workflow`` until ``max_iters`` iterations are done.

        Args:
            data_loaders: List of loaders, one per workflow entry.
            workflow: List of ``(mode, iters)`` tuples; ``mode`` names a
                method of this runner.
            **kwargs: Forwarded to the per-iteration method.

        Returns:
            Flat list of everything returned by the per-iteration method.

        Raises:
            ValueError: If a workflow mode is not a string naming an
                attribute of this runner.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        wrapped_loaders = [IterLoader(loader) for loader in data_loaders]

        self.call_hook('before_epoch')

        self.progress_bar.start()
        collected_lrs = []
        while self.iter < self._max_iters:
            for stage_idx, stage in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = stage
                if not (isinstance(mode, str) and hasattr(self, mode)):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                step_fn = getattr(self, mode)
                for _ in range(iters):
                    budget_spent = self.iter >= self._max_iters
                    if mode == 'train' and budget_spent:
                        break
                    collected_lrs.extend(
                        step_fn(wrapped_loaders[stage_idx], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')
        return collected_lrs
class DummyEpochBasedRunner(EpochBasedRunner):
    """Fake epoch-based runner.

    Nothing is optimised here: every simulated iteration only fires the
    train hooks and records the current learning rate, so that the full
    learning-rate trajectory can be returned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built un-started; ``run`` starts it right before the main loop.
        self.progress_bar = ProgressBar(self._max_epochs, start=False)

    def train(self, data_loader, **kwargs):
        """Simulate one training epoch over ``data_loader``.

        Returns:
            List with one ``self.current_lr()`` sample per loader step.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)

        epoch_lrs = []
        self.call_hook('before_train_epoch')
        for step in range(len(self.data_loader)):
            self._inner_iter = step
            self.call_hook('before_train_iter')
            epoch_lrs.append(self.current_lr())
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')

        self._epoch += 1
        self.progress_bar.update(1)
        return epoch_lrs

    def run(self, data_loaders, workflow, **kwargs):
        """Walk through ``workflow`` until ``max_epochs`` epochs are done.

        Args:
            data_loaders: List of loaders, one per workflow entry.
            workflow: List of ``(mode, epochs)`` tuples; ``mode`` names a
                method of this runner.
            **kwargs: Forwarded to the per-epoch method.

        Returns:
            Flat list of everything returned by the per-epoch method.

        Raises:
            TypeError: If a workflow mode is not a string.
            ValueError: If a workflow mode does not name an attribute of
                this runner.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        # The first 'train' stage (if any) fixes the total iteration budget.
        train_idx = next(
            (idx for idx, (mode, _) in enumerate(workflow)
             if mode == 'train'),
            None)
        if train_idx is not None:
            self._max_iters = self._max_epochs * len(data_loaders[train_idx])

        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        self.progress_bar.start()
        collected_lrs = []
        while self.epoch < self._max_epochs:
            for stage_idx, stage in enumerate(workflow):
                mode, epochs = stage
                if not isinstance(mode, str):
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_fn = getattr(self, mode)

                for _ in range(epochs):
                    budget_spent = self.epoch >= self._max_epochs
                    if mode == 'train' and budget_spent:
                        break
                    collected_lrs.extend(
                        epoch_fn(data_loaders[stage_idx], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        return collected_lrs
def run_eval(
        opt, model, crit, loader, vocab, device,
        json_path='', json_name='', scorer=COCOScorer(),
        teacher_model=None, dict_mapping=None,
        no_score=False, print_sent=False, analyze=False,
        collect_best_candidate_iterative_results=False, collect_path=None,
        extra_opt=None, summarywriter=None, global_step=0):
    """Decode captions for every batch of ``loader`` and optionally score them.

    Args:
        opt: Option mapping. It is updated in place with ``extra_opt`` and
            with the ``collect_best_candidate_iterative_results`` flag.
        model: Model to evaluate; switched to eval mode.
        crit: Optional criterion; when given, its loss recorder is reset and
            fed with the forward outputs of every batch.
        loader: Iterable of batches; ``loader.dataset`` must provide
            ``get_references`` (and ``captions``/``splits`` if ``analyze``).
        vocab: Vocabulary forwarded to the translator and ``to_sentence``.
        device: Forwarded to ``get_forword_results``.
        json_path: Directory for the prediction dump; nothing is written
            when empty.
        json_name: File name of the prediction dump inside ``json_path``.
        scorer: Object whose ``score(gts, preds, keys)`` returns
            ``(scores, details)``.
        teacher_model: Optional second model forwarded to the translator.
        dict_mapping: Forwarded to ``Translator``; ``None`` means ``{}``.
        no_score: Skip caption scoring when true.
        print_sent: Unused; kept for interface compatibility.
        analyze: Also run ``analyze_length_novel_unique`` and report its
            statistics.
        collect_best_candidate_iterative_results: Collect candidates and
            pickle them to ``collect_path``.
        collect_path: Pickle destination; required when collecting.
        extra_opt: Extra options merged into ``opt``; ``None`` means none.
        summarywriter: Optional writer; every entry of the result is logged
            through ``add_scalar``.
        global_step: Step passed to ``summarywriter.add_scalar``.

    Returns:
        dict mapping metric names to values.
    """
    if extra_opt:
        opt.update(extra_opt)
    if dict_mapping is None:
        dict_mapping = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    gt_captions = loader.dataset.get_references()
    pred_captions = defaultdict(list)

    opt['collect_best_candidate_iterative_results'] = collect_best_candidate_iterative_results
    translator = Translator(model=model, opt=opt, teacher_model=teacher_model, dict_mapping=dict_mapping)

    best_candidate_sents = defaultdict(list)
    best_candidate_score = defaultdict(list)

    all_time = 0
    if crit is not None:
        crit.reset_loss_recorder()

    collect_ar_flag = (opt['decoding_type'] == 'ARFormer' and collect_best_candidate_iterative_results)

    pb = ProgressBar(len(loader))
    pb.start()
    for data in loader:
        with torch.no_grad():
            encoder_outputs, category, labels = get_forword_results(
                opt, model, data, device=device, only_data=True, vocab=vocab)
            if crit is not None:
                _ = crit.get_loss(encoder_outputs)

            if teacher_model is not None:
                teacher_encoder_outputs, *_ = get_forword_results(
                    opt, teacher_model, data, device=device, only_data=True, vocab=vocab)
            else:
                teacher_encoder_outputs = None

            # Decoding time is only accumulated when batches hold one sample.
            if opt['batch_size'] == 1:
                start_time = time.time()
            all_hyp, all_scores = translator.translate_batch(
                encoder_outputs, category, labels, vocab,
                teacher_encoder_outputs=teacher_encoder_outputs)
            if opt['batch_size'] == 1:
                all_time += (time.time() - start_time)

            # 2-D tensors get a singleton middle axis so that both results
            # can be walked as [sample][candidate] nested lists below.
            if isinstance(all_hyp, torch.Tensor):
                if len(all_hyp.shape) == 2:
                    all_hyp = all_hyp.unsqueeze(1)
                all_hyp = all_hyp.tolist()
            if isinstance(all_scores, torch.Tensor):
                if len(all_scores.shape) == 2:
                    all_scores = all_scores.unsqueeze(1)
                all_scores = all_scores.tolist()

            video_ids = np.array(data['video_ids']).reshape(-1)

            for k, hyps in enumerate(all_hyp):
                video_id = video_ids[k]
                if not no_score:
                    assert len(hyps) == 1

                for j, hyp in enumerate(hyps):
                    sent = to_sentence(hyp, vocab)
                    if opt.get('duplicate', False) and opt['decoding_type'] == 'NARFormer':
                        sent, _ = duplicate(sent)

                    if not collect_ar_flag:
                        # for evaluation
                        pred_captions[video_id].append({'image_id': video_id, 'caption': sent})
                    else:
                        # for collection
                        pred_captions[video_id].append({'caption': sent, 'score': all_scores[k][j]})

            if collect_best_candidate_iterative_results and not collect_ar_flag:
                assert isinstance(all_scores, tuple)
                all_sents = all_scores[0].tolist()
                all_score = all_scores[1].tolist()

                if len(video_ids) != len(all_sents):
                    # More candidate rows than videos: repeat every id
                    # opt['length_beam_size'] times to line them up.
                    video_ids = np.array(data['video_ids'])[:, np.newaxis].repeat(
                        opt['length_beam_size'], axis=1).reshape(-1)
                    assert len(video_ids) == len(all_sents)

                for k, (hyps, scores) in enumerate(zip(all_sents, all_score)):
                    video_id = video_ids[k]
                    pre_sent_len = 0
                    assert len(hyps) == len(scores)

                    for hyp, score in zip(hyps, scores):
                        sent = to_sentence(hyp, vocab)
                        # Every candidate of one row must have the same
                        # number of space-separated tokens.
                        if not pre_sent_len:
                            pre_sent_len = len(sent.split(' '))
                        else:
                            assert len(sent.split(' ')) == pre_sent_len

                        best_candidate_sents[video_id].append(sent)
                        best_candidate_score[video_id].append(score)
        pb.update()

    if collect_best_candidate_iterative_results:
        assert collect_path is not None
        if not collect_ar_flag:
            payload = [best_candidate_sents, best_candidate_score]
        else:
            payload = pred_captions
        # Context manager so the file handle is closed deterministically.
        with open(collect_path, 'wb') as collect_file:
            pickle.dump(payload, collect_file)

    if opt['batch_size'] == 1:
        latency = all_time / len(loader)
        print(latency, len(loader))

    res = {}
    if analyze:
        ave_length, novel, unique, usage, _hy_res, gram4 = analyze_length_novel_unique(
            loader.dataset.captions, pred_captions, vocab, splits=loader.dataset.splits, n=1)
        res.update({'ave_length': ave_length, 'novel': novel, 'unique': unique, 'usage': usage, 'gram4': gram4})

    # Defined up front so the JSON dump below also works when scoring is
    # skipped (previously a NameError with no_score=True and json_path set).
    valid_score = {}
    if not no_score:
        valid_score, _ = scorer.score(gt_captions, pred_captions, pred_captions.keys())

        res.update(valid_score)
        metric_sum = opt.get('metric_sum', [1, 1, 1, 1])
        candidate = [res["Bleu_4"], res["METEOR"], res["ROUGE_L"], res["CIDEr"]]
        res['Sum'] = sum(item for index, item in enumerate(candidate) if metric_sum[index])

        # NOTE(review): loss info is reported only together with caption
        # scores here - confirm this nesting matches the intended behaviour.
        if crit is not None:
            names, metrics = crit.get_loss_info()
            for n, m in zip(names, metrics):
                res[n] = m

    if summarywriter is not None:
        for k, v in res.items():
            summarywriter.add_scalar(k, v, global_step=global_step)

    if json_path:
        os.makedirs(json_path, exist_ok=True)
        with open(os.path.join(json_path, json_name), 'w') as prediction_results:
            json.dump({"predictions": pred_captions, "scores": valid_score}, prediction_results)

    return res
def train_network_all(opt, model, device, **kwargs):
    """Train ``model`` with periodic validation, checkpointing and final test.

    Args:
        opt: Option mapping. Keys read here include ``checkpoint_path``,
            ``epochs``, ``start_epoch``, ``start_eval_epoch``,
            ``save_checkpoint_every``, ``method``, ``dataset``, ``scope``,
            ``field`` and ``base_checkpoint_path``; optional keys are read
            with ``opt.get``.
        model: Model to train; moved to ``device``.
        device: Target device for the model (and the teacher, if any).
        **kwargs: Forwarded to ``get_loader`` for the training split.
    """
    if opt.get('load_teacher_weights', False):
        assert opt.get('teacher_path', None) is not None
        model = load_satisfied_weights(
            model=model,
            checkpoint_path=opt['teacher_path'],
            str_mapping={'decoder.bert.': 'decoder.'}
        )

    model.to(device)
    summarywriter = SummaryWriter(os.path.join(opt['checkpoint_path'], 'trainval'))
    optimizer = get_optimizer(opt, model, summarywriter=summarywriter)
    crit = get_criterion(opt, summarywriter=summarywriter)
    crit_eval = get_criterion_during_evaluation(opt)

    if opt.get('with_teacher', False) and opt['method'] in ['NAB', 'NACF']:
        assert opt.get('teacher_path', None) is not None
        teacher_model, _ = load_model_and_opt(opt['teacher_path'], device)
    else:
        teacher_model = None

    folder_path = os.path.join(opt["checkpoint_path"], 'tmp_models')
    best_model = k_PriorityQueue(
        k_best_model=opt.get('k_best_model', 1),
        folder_path=folder_path,
        standard=opt.get('standard', ['METEOR', 'CIDEr'])
    )

    train_loader = get_loader(opt, 'train', print_info=False, **kwargs)
    vali_loader = get_loader(opt, 'validate', print_info=False)
    vocab = vali_loader.dataset.get_vocab()

    logger = CsvLogger(
        filepath=opt["checkpoint_path"],
        filename='trainning_record.csv',
        fieldsnames=[
            'epoch', 'train_loss',
            'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
            'METEOR', 'ROUGE_L', 'CIDEr', 'Sum']
        + crit.get_fieldsnames()
    )

    start_epoch = opt['start_epoch']
    pb = ProgressBar(opt['epochs'])
    pb.start()
    for epoch in range(opt['epochs']):
        # Epochs before ``start_epoch`` are skipped entirely.
        if epoch < start_epoch:
            continue

        train_loader.dataset.shuffle()

        logger.write_text("epoch %d lr=%g (ss_prob=%g)" % (epoch, optimizer.get_lr(), opt.get('teacher_prob', 1)))
        # training
        train_loss = run_train(opt, model, crit, optimizer, train_loader, device, logger=logger, epoch=epoch)

        optimizer.epoch_update_learning_rate()

        if (epoch + 1) > opt['start_eval_epoch'] and (epoch + 1) % opt["save_checkpoint_every"] == 0:
            res = run_eval(opt, model, crit_eval, vali_loader, vocab, device,
                           teacher_model=teacher_model, analyze=True,
                           summarywriter=summarywriter, global_step=epoch)
            res['train_loss'] = train_loss
            res['epoch'] = epoch
            logger.write(res)

            save_checkpoint(
                {'epoch': epoch + 1, 'state_dict': model.state_dict(), 'validate_result': res, 'settings': opt},
                False,
                filepath=opt["checkpoint_path"],
                filename='checkpoint.pth.tar'
            )

            model_name = 'model_%04d.pth.tar' % res['epoch']
            model_path = os.path.join(folder_path, model_name)
            not_break, info = best_model.check(res, opt, model_path, model_name)
            if not not_break:
                # reached the tolerance
                break
            logger.write_text(info)

        pb.update()

    if not opt.get('no_test', False):
        # Drop the training objects and clear the CUDA cache before the
        # separate test process is launched.
        model = model.to('cpu')
        del model
        del optimizer
        torch.cuda.empty_cache()
        # NOTE(review): the command is assembled as a shell string from
        # ``opt`` values; these must come from trusted configuration.
        os.system(
            'python translate.py --default --method {} --dataset {} --record --scope {} --field {} -em test --use_ct --base_checkpoint_path {}'.format(
                opt['method'], opt['dataset'], opt['scope'] if opt['scope'] else '\"\"', ' '.join(opt['field']),
                opt['base_checkpoint_path'])
        )

    # Same default as the queue construction above, so a missing
    # 'k_best_model' option no longer raises KeyError at the very end.
    # NOTE(review): cleanup is assumed to run whether or not testing was
    # skipped - confirm against the intended nesting.
    if opt.get('k_best_model', 1) > 1:
        shutil.rmtree(folder_path)