def train_one_epoch(self, epoch, accum_iter):
    self.model.train()
    self.lr_scheduler.step()

    average_meter_set = AverageMeterSet()
    tqdm_dataloader = tqdm(self.train_loader)

    for batch_idx, batch in enumerate(tqdm_dataloader):
        batch_size = batch[0].size(0)
        batch = [x.to(self.device) for x in batch]

        self.optimizer.zero_grad()
        loss = self.calculate_loss(batch)
        loss.backward()
        self.optimizer.step()

        average_meter_set.update('loss', loss.item())
        tqdm_dataloader.set_description(
            'Epoch {}, loss {:.3f} '.format(epoch + 1, average_meter_set['loss'].avg))

        accum_iter += batch_size

        if self._needs_to_log(accum_iter):
            tqdm_dataloader.set_description('Logging to Tensorboard')
            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.log_extra_train_info(log_data)
            self.logger_service.log_train(log_data)

    return accum_iter
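# Every routine in this file leans on an AverageMeterSet for running metric
# statistics. The sketch below is an assumption inferred from how the meters
# are used here (update(name, value, n), __getitem__, averages(), and the
# '{meters[batch_time]:.3f}' print strings, which require a custom
# __format__); it mirrors the widely used mean-teacher utility and is not
# necessarily the exact implementation in this repo.

class AverageMeter:
    """Tracks the latest value and running average of one metric."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, fmt):
        # makes format strings like 'Time {meters[batch_time]:.3f}' work
        return '{self.val:{fmt}} ({self.avg:{fmt}})'.format(self=self, fmt=fmt)


class AverageMeterSet:
    """A dict of lazily created AverageMeters, keyed by metric name."""

    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def averages(self, postfix=''):
        return {name + postfix: meter.avg for name, meter in self.meters.items()}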
def eval_one_epoch(self, eval_loader, epoch=None):
    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(eval_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = self.batch_to_device(batch)
            metrics = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)

            if self.args.local and batch_idx > 20:
                break
            if batch_idx % 10 == 0 and batch_idx > 0:
                descr = get_metric_descr(average_meter_set, self.metric_ks)
                tqdm_dataloader.set_description(descr)

    descr = get_metric_descr(average_meter_set, self.metric_ks)
    if epoch is not None:
        print("\n Epoch {} avg.: {}".format(epoch + 1, descr))
    else:
        print("\n")
    return average_meter_set
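# get_metric_descr is not defined in this file. A minimal sketch of what it
# plausibly does, inferred from how the description strings are built by hand
# in the validate/test functions below (name and signature are assumptions):

def get_metric_descr(average_meter_set, metric_ks):
    metric_names = ['NDCG@%d' % k for k in metric_ks[:3]] + \
                   ['Recall@%d' % k for k in metric_ks[:3]]
    descr = ', '.join(s + ' {:.3f}' for s in metric_names)
    descr = descr.replace('NDCG', 'N').replace('Recall', 'R')
    return descr.format(*(average_meter_set[k].avg for k in metric_names))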
def test(self):
    print('Test best model with test set!')

    best_model = torch.load(
        os.path.join(self.export_root, 'models', 'best_acc_model.pth')).get('model_state_dict')
    self.model.load_state_dict(best_model)
    self.model.eval()

    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.test_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]

            metrics, preds = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            # this runs on the test set, so label the progress bar 'Test'
            description = 'Test: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(
                *(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)

        average_metrics = average_meter_set.averages()
        with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
            json.dump(average_metrics, f, indent=4)
        print(average_metrics)
def validate(self, epoch, accum_iter):
    self.model.eval()

    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.val_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]

            metrics = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(
                *(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)

        log_data = {
            'state_dict': (self._create_state_dict()),
            'epoch': epoch,
            'accum_iter': accum_iter,
        }
        log_data.update(average_meter_set.averages())
        self.logger_service.log_val(log_data)
def test(self):
    self.model.eval()

    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.test_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]

            metrics = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            description = 'Test: ' + ', '.join(
                s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(
                *(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)

    return {
        'dataset': self.dataset_code,
        'pruning_code': self.prune_code,
        'pruning_perc': self.pruning_perc,
        'pruning_perc_embed': self.pruning_perc_embed,
        'pruning_perc_feed': self.pruning_perc_feed,
        'pruning_epochs': self.num_prune_epochs,
        'num_epochs': self.num_epochs,
        'result': description,
    }
def validate_voc_file(eval_loader, model, thre, epoch, print_freq, results_dir):
    start_time = time.time()
    meters = AverageMeterSet()

    model.eval()
    Sig = torch.nn.Sigmoid()

    end = time.time()
    preds = []
    targets = []
    names = []
    for i, data in enumerate(eval_loader):
        assert len(data) >= 4
        input, target, name = data[0], data[1], data[3]
        meters.update('data_time', time.time() - end)

        # compute output
        with torch.no_grad():
            output = Sig(model(input.cuda()))

        # for mAP calculation
        preds.append(output.cpu())
        targets.append(target.cpu())
        names.extend(name)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  .format(i, len(eval_loader), meters=meters))

    preds = torch.cat(preds).numpy()
    targs = torch.cat(targets).numpy()

    if results_dir is not None:
        # save per-class scores to the results dir, one file per class
        os.makedirs(results_dir, exist_ok=True)
        for i in range(20):
            cls_name = eval_loader.dataset.class_list[i]
            filename = '{}_{}.txt'.format(cls_name, eval_loader.dataset.image_set)
            with open(os.path.join(results_dir, filename), 'w') as f:
                for j in range(len(names)):
                    f.write('{} {}\n'.format(names[j], preds[j, i]))

    AP = eval_loader.dataset.eval_file(results_dir)
    eval_loader.dataset.show_AP(AP, print_func=LOG.info)
    mAP = 100 * AP.mean()
    print(" * TEST [{}] VOC2012 mAP: {}".format(epoch, mAP))
    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return mAP
def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
    if mode == 'val':
        loader = self.val_loader
    elif mode == 'test':
        loader = self.test_loader
    else:
        raise ValueError

    self.model.eval()
    average_meter_set = AverageMeterSet()
    num_instance = 0
    with torch.no_grad():
        tqdm_dataloader = tqdm(loader) if not self.pilot else loader
        for batch_idx, batch in enumerate(tqdm_dataloader):
            if self.pilot and batch_idx >= self.pilot_batch_cnt:
                # print('Break validation due to pilot mode')
                break
            batch = {k: v.to(self.device) for k, v in batch.items()}
            batch_size = next(iter(batch.values())).size(0)
            num_instance += batch_size

            metrics = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            if not self.pilot:
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks] + \
                                      ['Recall@%d' % k for k in self.metric_ks]
                description = '{}: '.format(mode.capitalize()) + ', '.join(
                    s + ' {:.4f}' for s in description_metrics)
                description = description.replace('NDCG', 'N').replace('Recall', 'R')
                description = description.format(
                    *(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

    log_data = {
        'state_dict': (self._create_state_dict(epoch, accum_iter)),
        'epoch': epoch,
        'accum_iter': accum_iter,
        'num_eval_instance': num_instance,
    }
    log_data.update(average_meter_set.averages())
    log_data.update(kwargs)

    if doLog:
        if mode == 'val':
            self.logger_service.log_val(log_data)
        elif mode == 'test':
            self.logger_service.log_test(log_data)
        else:
            raise ValueError
    return log_data
def validate(self, epoch, accum_iter):
    self.model.eval()
    self.all_preds = []

    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.val_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]

            metrics, preds = self.calculate_metrics(batch)
            for p in preds:
                self.all_preds.append(p.tolist())

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(
                *(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)

        log_data = {
            'state_dict': (self._create_state_dict()),
            'epoch': epoch + 1,
            'accum_iter': accum_iter,
        }
        log_data.update(average_meter_set.averages())
        self.log_extra_val_info(log_data)
        self.logger_service.log_val(log_data)

    df = pd.DataFrame(
        self.all_preds,
        columns=['prediction_' + str(i) for i in range(len(self.all_preds[0]))])
    if not os.path.isdir(self.args.output_predictions_folder):
        os.makedirs(self.args.output_predictions_folder)
    with open(os.path.join(self.args.output_predictions_folder, 'config.json'), 'w') as f:
        self.args.recommender = "BERT4rec"
        self.args.seed = str(self.args.model_init_seed)
        args_dict = {}
        args_dict['args'] = vars(self.args)
        f.write(json.dumps(args_dict, indent=4, sort_keys=True))
    df.to_csv(self.args.output_predictions_folder + "/predictions.csv", index=False)
def train_one_epoch(self, epoch, accum_iter):
    self.model.train()

    average_meter_set = AverageMeterSet()
    tqdm_dataloader = tqdm(self.train_loader)

    for batch_idx, batch in enumerate(tqdm_dataloader):
        batch = self.batch_to_device(batch)
        batch_size = self.args.train_batch_size

        # forward pass
        self.optimizer.zero_grad()
        loss = self.calculate_loss(batch)

        # backward pass
        loss.backward()
        self.optimizer.step()

        # update metrics; read the live learning rate from param_groups,
        # since optimizer.defaults['lr'] never changes once the scheduler steps
        average_meter_set.update('loss', loss.item())
        average_meter_set.update('lr', self.optimizer.param_groups[0]['lr'])
        tqdm_dataloader.set_description(
            'Epoch {}, loss {:.3f} '.format(epoch + 1, average_meter_set['loss'].avg))

        accum_iter += batch_size

        if self._needs_to_log(accum_iter):
            tqdm_dataloader.set_description('Logging to Tensorboard')
            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch + 1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.log_extra_train_info(log_data)
            self.logger_service.log_train(log_data)

        if self.args.local and batch_idx == 20:
            break

    # adapt learning rate
    if self.args.enable_lr_schedule:
        self.lr_scheduler.step()
        if epoch % self.lr_scheduler.step_size == 0:
            print(self.optimizer.param_groups[0]['lr'])

    return accum_iter
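# _needs_to_log gates Tensorboard logging on accumulated training instances
# rather than on batch indices. A sketch of the usual BERT4Rec-style check,
# assuming a log_period_as_iter argument (the exact condition in this repo
# may differ):

def _needs_to_log(self, accum_iter):
    # fire roughly once per log_period_as_iter instances; the modulo window
    # must be at least one batch wide, or the event could be stepped over
    return accum_iter % self.args.log_period_as_iter < self.args.train_batch_size \
        and accum_iter != 0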
def validate(eval_loader, model, epoch, print_freq, type_string=''):
    start_time = time.time()
    class_criterion = nn.CrossEntropyLoss().cuda()
    meters = AverageMeterSet()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, data in enumerate(eval_loader):
        input, target = data[0], data[1]
        meters.update('data_time', time.time() - end)

        with torch.no_grad():
            input = input.cuda()
            target = target.cuda()

            # compute output
            model_out = model(input)
            if isinstance(model_out, tuple):
                feat, class_logit = model_out
            else:
                class_logit = model_out
            class_loss = class_criterion(class_logit, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(class_logit.data, target.data, topk=(1, 5))
        minibatch_size = len(target)
        meters.update('class_loss', class_loss.item(), minibatch_size)
        meters.update('top1', prec1, minibatch_size)
        meters.update('top5', prec5, minibatch_size)

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Prec@1 {meters[top1]:.3f}\t'
                  'Prec@5 {meters[top5]:.3f}'
                  .format(i, len(eval_loader), meters=meters))

    print(' * Prec@1 {top1.avg:.3f}\tPrec@5 {top5.avg:.3f}'
          .format(top1=meters['top1'], top5=meters['top5']))
    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return meters['top1'].avg
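# accuracy() above is the standard top-k helper from the PyTorch ImageNet
# example; a self-contained sketch is included here on the assumption that
# it matches what this repo imports:

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # top-k class indices per sample
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))  # as a percentage
    return res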
def validate_voc(eval_loader, model, thre, epoch, print_freq):
    start_time = time.time()
    meters = AverageMeterSet()

    model.eval()
    Sig = torch.nn.Sigmoid()

    end = time.time()
    preds = []
    targets = []
    for i, data in enumerate(eval_loader):
        input, target = data[0], data[1]
        meters.update('data_time', time.time() - end)

        # compute output
        with torch.no_grad():
            output = Sig(model(input.cuda())).cpu()

        # for mAP calculation
        preds.append(output.cpu())
        targets.append(target.cpu())

        # measure accuracy and record loss
        this_prec, this_rec = prec_recall_for_batch(output.data, target, thre)
        meters.update('prec', float(this_prec), input.size(0))
        meters.update('rec', float(this_rec), input.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Prec {meters[prec]:.2f}\t'
                  'Recall {meters[rec]:.2f}'
                  .format(i, len(eval_loader), meters=meters))

    targs = torch.cat(targets).numpy()
    preds = torch.cat(preds).numpy()

    AP = eval_loader.dataset.eval(preds, targs)
    eval_loader.dataset.show_AP(AP)
    mAP = 100 * AP.mean()
    print(" * TEST [{}] mAP: {}".format(epoch, mAP))
    print("--- testing epoch in {} seconds ---".format(time.time() - start_time))
    return mAP
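# prec_recall_for_batch is used by both VOC paths and by the train() functions
# below but is not defined in this file. A minimal sketch under the assumption
# that it thresholds the sigmoid scores and returns micro-averaged precision
# and recall in percent (name kept, body assumed):

def prec_recall_for_batch(output, target, thre):
    pred = (output > thre).float()
    tp = (pred * target).sum()
    precision = 100.0 * tp / max(pred.sum().item(), 1)
    recall = 100.0 * tp / max(target.sum().item(), 1)
    return precision, recall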
def validate(self, epoch, accum_iter):
    self.model.eval()

    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.val_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]

            metrics = self.calculate_metrics(batch)

            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] + \
                                  ['Recall@%d' % k for k in self.metric_ks[:3]]
            if 'accuracy' in self.args.metrics_to_log:
                description_metrics = ['accuracy']
            description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(
                *(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)

        log_data = {
            'state_dict': (self._create_state_dict()),
            'epoch': epoch + 1,
            'accum_iter': accum_iter,
            'user_embedding': self.model.embedding.user.weight.cpu().detach().numpy()
                if self.args.dump_useritem_embeddings == 'True'
                and self.model.embedding.user is not None else None,
            'item_embedding': self.model.embedding.token.weight.cpu().detach().numpy()
                if self.args.dump_useritem_embeddings == 'True' else None,
        }
        log_data.update(average_meter_set.averages())
        self.log_extra_val_info(log_data)
        self.logger_service.log_val(log_data)
def train_one_epoch(self, epoch, accum_iter, train_loader, **kwargs):
    self.model.train()

    average_meter_set = AverageMeterSet()
    num_instance = 0
    tqdm_dataloader = tqdm(train_loader) if not self.pilot else train_loader

    for batch_idx, batch in enumerate(tqdm_dataloader):
        if self.pilot and batch_idx >= self.pilot_batch_cnt:
            # print('Break training due to pilot mode')
            break
        batch_size = next(iter(batch.values())).size(0)
        batch = {k: v.to(self.device) for k, v in batch.items()}
        num_instance += batch_size

        # linear KL warm-up: anneal grows with update_count, capped at anneal_cap
        if self.total_anneal_steps > 0:
            anneal = min(self.anneal_cap,
                         1. * self.update_count / self.total_anneal_steps)
        else:
            anneal = self.anneal_cap

        self.optimizer.zero_grad()
        loss = self.calculate_loss(batch, anneal)
        if isinstance(loss, tuple):
            loss, extra_info = loss
            for k, v in extra_info.items():
                average_meter_set.update(k, v)
        loss.backward()
        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        self.update_count += 1

        average_meter_set.update('loss', loss.item())
        if not self.pilot:
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch, average_meter_set['loss'].avg))

        accum_iter += batch_size

        if self._needs_to_log(accum_iter):
            if not self.pilot:
                tqdm_dataloader.set_description('Logging')
            log_data = {
                # 'state_dict': (self._create_state_dict()),
                'epoch': epoch,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            log_data.update(kwargs)
            self.log_extra_train_info(log_data)
            self.logger_service.log_train(log_data)

    log_data = {
        # 'state_dict': (self._create_state_dict()),
        'epoch': epoch,
        'accum_iter': accum_iter,
        'num_train_instance': num_instance,
    }
    log_data.update(average_meter_set.averages())
    log_data.update(kwargs)
    self.log_extra_train_info(log_data)
    self.logger_service.log_train(log_data)

    return accum_iter
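# Worked example for the anneal schedule above: with total_anneal_steps=200000
# and anneal_cap=0.2, the KL weight rises linearly as update_count/200000 and
# saturates at 0.2 after 40000 optimizer updates. This is the linear beta
# warm-up used for VAE-style objectives (e.g. Mult-VAE); whether this trainer
# is exactly that model is an inference from the variable names.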
def train(train_loader, model, criterion, optimizer, epoch, args):
    start_time = time.time()
    meters = AverageMeterSet()
    Sig = torch.nn.Sigmoid()

    """
    Keep the model in eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), and those statistics are part of the model parameters too.
    """
    if args.linear_eval:
        model.eval()
    else:
        model.eval()  # also kept in eval mode here; see the note above

    end = time.time()
    for i, data in enumerate(train_loader):
        images, target = data[0], data[1]
        # measure data loading time
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader), args)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.float().cuda(args.gpu, non_blocking=True)

        # compute output
        model_output = model(images)
        if isinstance(model_output, tuple):
            feat, class_logit = model_output
        else:
            class_logit = model_output
        output = Sig(class_logit)
        loss = criterion(class_logit, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meters.update('lr', optimizer.param_groups[0]['lr'])
        meters.update('class_loss', loss.item())

        # measure accuracy and record loss
        this_prec, this_rec = prec_recall_for_batch(output.data, target, args.thre)
        meters.update('prec', float(this_prec), images.size(0))
        meters.update('rec', float(this_rec), images.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Prec {meters[prec]:.3f}\t'
                  'Rec {meters[rec]:.3f}\t'.format(epoch, i, len(train_loader), meters=meters))

    print(' * TRAIN Prec {:.3f} ({:.1f}/{:.1f}) Recall {:.3f} ({:.1f}/{:.1f})'.format(
        meters['prec'].avg, meters['prec'].sum / 100, meters['prec'].count,
        meters['rec'].avg, meters['rec'].sum / 100, meters['rec'].count))
    print("--- training epoch in {} seconds ---".format(time.time() - start_time))
def train(train_loader, model, class_criterion, optimizer, epoch):
    global global_step
    start_time = time.time()
    Sig = torch.nn.Sigmoid()
    meters = AverageMeterSet()

    # switch to train mode
    model.train()

    end = time.time()
    for i, data in enumerate(train_loader):
        input, target = data[0], data[1]
        # measure data loading time
        meters.update('data_time', time.time() - end)

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
        input, target = input.cuda(), target.float().cuda()

        model_out = model(input)
        if isinstance(model_out, tuple):
            feat, class_logit = model_out
        else:
            class_logit = model_out
        output = Sig(class_logit)
        class_loss = class_criterion(class_logit, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        class_loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        global_step += 1

        meters.update('lr', optimizer.param_groups[0]['lr'])
        minibatch_size = len(target)
        meters.update('class_loss', class_loss.item())

        # measure accuracy and record loss
        this_prec, this_rec = prec_recall_for_batch(output.data, target, thre)
        meters.update('prec', float(this_prec), input.size(0))
        meters.update('rec', float(this_rec), input.size(0))

        # measure elapsed time
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {meters[batch_time]:.3f}\t'
                  'Data {meters[data_time]:.3f}\t'
                  'Class {meters[class_loss]:.4f}\t'
                  'Prec {meters[prec]:.3f}\t'
                  'Rec {meters[rec]:.3f}\t'.format(epoch, i, len(train_loader), meters=meters))

    print(' * TRAIN Prec {:.3f} ({:.1f}/{:.1f}) Recall {:.3f} ({:.1f}/{:.1f})'.format(
        meters['prec'].avg, meters['prec'].sum / 100, meters['prec'].count,
        meters['rec'].avg, meters['rec'].sum / 100, meters['rec'].count))
    print("--- training epoch in {} seconds ---".format(time.time() - start_time))
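# adjust_learning_rate is called once per iteration in both train() functions
# but is not defined here (the variant above takes an explicit args parameter,
# this one reads module-level globals). A minimal step-decay sketch, purely as
# an assumption about its shape; the real schedule may use warmup or cosine:

def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch):
    epoch = epoch + step_in_epoch / total_steps_in_epoch  # fractional epoch
    lr = args.lr * (0.1 ** (epoch // 30))  # assumed: decay 10x every 30 epochs
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr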