def train_one_epoch(self, epoch, accum_iter, train_loader, **kwargs):
    """Run one training epoch over *train_loader*.

    Args:
        epoch: current epoch index (used for progress display and logging).
        accum_iter: running count of training instances seen so far; used
            to decide when intermediate logging fires via _needs_to_log.
        train_loader: iterable yielding dict batches of tensors
            (assumed: every value tensor shares the same batch dimension —
            batch size is read from an arbitrary value; TODO confirm).
        **kwargs: extra key/values merged into every log record.

    Returns:
        The updated ``accum_iter`` (incremented by the number of training
        instances consumed this epoch).
    """
    self.model.train()
    average_meter_set = AverageMeterSet()
    num_instance = 0
    # Pilot mode: no tqdm wrapper, and only the first pilot_batch_cnt batches.
    tqdm_dataloader = tqdm(train_loader) if not self.pilot else train_loader

    for batch_idx, batch in enumerate(tqdm_dataloader):
        if self.pilot and batch_idx >= self.pilot_batch_cnt:
            break
        batch_size = next(iter(batch.values())).size(0)
        batch = {k: v.to(self.device) for k, v in batch.items()}
        num_instance += batch_size

        self.optimizer.zero_grad()
        loss = self.calculate_loss(batch)
        # calculate_loss may return (loss, extra_info_dict); the extra
        # scalars are tracked alongside the loss in the meter set.
        if isinstance(loss, tuple):
            loss, extra_info = loss
            for k, v in extra_info.items():
                average_meter_set.update(k, v)
        loss.backward()
        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.clip_grad_norm)
        self.optimizer.step()

        average_meter_set.update('loss', loss.item())
        if not self.pilot:
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch, average_meter_set['loss'].avg))

        accum_iter += batch_size
        if self._needs_to_log(accum_iter):
            if not self.pilot:
                tqdm_dataloader.set_description('Logging')
            self._emit_train_log(epoch, accum_iter, average_meter_set, kwargs)

    # End-of-epoch log additionally records how many instances were seen.
    self._emit_train_log(epoch, accum_iter, average_meter_set, kwargs,
                         num_train_instance=num_instance)
    return accum_iter

def _emit_train_log(self, epoch, accum_iter, average_meter_set, extra, **counts):
    """Assemble one training log record and hand it to the logger service.

    ``counts`` carries optional extra counters (e.g. num_train_instance at
    epoch end). Update order matches the original inline code: base dict,
    meter averages, then caller kwargs (later updates win on key clashes).
    """
    log_data = {
        # 'state_dict': (self._create_state_dict()),
        'epoch': epoch,
        'accum_iter': accum_iter,
    }
    log_data.update(counts)
    log_data.update(average_meter_set.averages())
    log_data.update(extra)
    self.log_extra_train_info(log_data)
    self.logger_service.log_train(log_data)
def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
    """Evaluate the model on the 'val' or 'test' split.

    Computes metrics over the chosen loader under no_grad, shows a running
    N@k / R@k description on the progress bar (outside pilot mode), and —
    when doLog is True — forwards the assembled record to the logger
    service. Always returns the log_data dict.
    """
    if mode == 'val':
        loader = self.val_loader
    elif mode == 'test':
        loader = self.test_loader
    else:
        raise ValueError

    self.model.eval()
    average_meter_set = AverageMeterSet()
    num_instance = 0

    with torch.no_grad():
        dataloader = loader if self.pilot else tqdm(loader)
        for batch_idx, batch in enumerate(dataloader):
            # Pilot mode runs only the first pilot_batch_cnt batches.
            if self.pilot and batch_idx >= self.pilot_batch_cnt:
                break
            batch = {k: v.to(self.device) for k, v in batch.items()}
            batch_size = next(iter(batch.values())).size(0)
            num_instance += batch_size

            for metric_name, metric_value in self.calculate_metrics(batch).items():
                average_meter_set.update(metric_name, metric_value)

            if not self.pilot:
                # Show the first three NDCG@k / Recall@k averages,
                # abbreviated to N@k / R@k to keep the bar compact.
                shown = (['NDCG@%d' % k for k in self.metric_ks[:3]]
                         + ['Recall@%d' % k for k in self.metric_ks[:3]])
                template = '{}: '.format(mode.capitalize()) + ', '.join(
                    name + ' {:.3f}' for name in shown)
                template = template.replace('NDCG', 'N').replace('Recall', 'R')
                dataloader.set_description(
                    template.format(*(average_meter_set[name].avg for name in shown)))

    log_data = {
        'state_dict': (self._create_state_dict(epoch, accum_iter)),
        'epoch': epoch,
        'accum_iter': accum_iter,
        'num_eval_instance': num_instance,
    }
    log_data.update(average_meter_set.averages())
    log_data.update(kwargs)

    if doLog:
        if mode == 'val':
            self.logger_service.log_val(log_data)
        elif mode == 'test':
            self.logger_service.log_test(log_data)
        else:
            raise ValueError
    return log_data
def validate(self, epoch, accum_iter, mode, doLog=True, **kwargs):
    """Evaluate the model on the 'val' or 'test' split and return log data.

    NOTE(review): this definition shadows an identical ``validate`` defined
    earlier in the class. It previously carried experimental debug
    scaffolding (per-batch prediction dumps, a CSV export of results, stray
    prints, and unused ``my_final_result`` / ``my_dtype`` locals) — all of
    it dead code (commented out or inside an unused string literal). That
    scaffolding has been removed; the live logic below is unchanged.

    Args:
        epoch: current epoch index (recorded in the log and state dict).
        accum_iter: running training-instance counter (recorded likewise).
        mode: 'val' or 'test'; selects the loader. Any other value raises
            ValueError.
        doLog: when True, the record is also sent to the logger service.
        **kwargs: extra key/values merged into the log record.

    Returns:
        The assembled log_data dict (metrics averages, state dict, counters).
    """
    if mode == 'val':
        loader = self.val_loader
    elif mode == 'test':
        loader = self.test_loader
    else:
        raise ValueError

    self.model.eval()
    average_meter_set = AverageMeterSet()
    num_instance = 0

    with torch.no_grad():
        # Pilot mode: no tqdm wrapper, only the first pilot_batch_cnt batches.
        tqdm_dataloader = tqdm(loader) if not self.pilot else loader
        for batch_idx, batch in enumerate(tqdm_dataloader):
            if self.pilot and batch_idx >= self.pilot_batch_cnt:
                break
            batch = {k: v.to(self.device) for k, v in batch.items()}
            batch_size = next(iter(batch.values())).size(0)
            num_instance += batch_size

            metrics = self.calculate_metrics(batch)
            for k, v in metrics.items():
                average_meter_set.update(k, v)

            if not self.pilot:
                # Progress bar shows the first three NDCG@k / Recall@k
                # averages, abbreviated to N@k / R@k.
                description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                    ['Recall@%d' % k for k in self.metric_ks[:3]]
                description = '{}: '.format(mode.capitalize()) + ', '.join(
                    s + ' {:.3f}' for s in description_metrics)
                description = description.replace('NDCG', 'N').replace(
                    'Recall', 'R')
                description = description.format(
                    *(average_meter_set[k].avg for k in description_metrics))
                tqdm_dataloader.set_description(description)

    log_data = {
        'state_dict': (self._create_state_dict(epoch, accum_iter)),
        'epoch': epoch,
        'accum_iter': accum_iter,
        'num_eval_instance': num_instance,
    }
    log_data.update(average_meter_set.averages())
    log_data.update(kwargs)

    if doLog:
        if mode == 'val':
            self.logger_service.log_val(log_data)
        elif mode == 'test':
            self.logger_service.log_test(log_data)
        else:
            raise ValueError
    return log_data