def start(self, train_loader, train_set, valid_set=None, valid_loader=None):
    self.train_num_batches = math.ceil(train_set.num_images / float(self.cf.train_batch_size))
    self.val_num_batches = 0 if valid_set is None else math.ceil(
        valid_set.num_images / float(self.cf.valid_batch_size))

    # Define early stopping control
    if self.cf.early_stopping:
        early_stopping = EarlyStopping(self.cf)
    else:
        early_stopping = None

    # Train process
    for epoch in tqdm(range(self.curr_epoch, self.cf.epochs + 1),
                      desc='Training', file=sys.stdout):
        # Shuffle train data
        train_set.update_indexes()

        # Initialize logger
        self.logger_stats.write('\n\t ------ Epoch: ' + str(epoch) + ' ------ \n')

        # Initialize stats
        self.stats.epoch = epoch
        self.train_loss = AverageMeter()
        self.confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

        # Train epoch
        self.training_loop(epoch, train_loader)

        # Save stats
        self.stats.train.conf_m = self.confm_list
        self.compute_stats(self.confm_list, self.train_loss)
        self.save_stats_epoch(epoch)
        self.logger_stats.write_stat(
            self.stats.train, epoch,
            os.path.join(self.cf.train_json_path,
                         'train_epoch_' + str(epoch) + '.json'))

        # Validate epoch
        self.validate_epoch(valid_set, valid_loader, early_stopping, epoch)

        # Update scheduler
        if self.model.scheduler is not None:
            self.model.scheduler.step(self.stats.val.loss)

        # Save model on score improvement
        new_best = self.model.save(self.stats)
        if new_best:
            self.logger_stats.write_best_stats(self.stats, epoch, self.cf.best_json_file)

        if self.stop:
            return

    # Save model without training
    if self.cf.epochs == 0:
        self.model.save_model()
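# AverageMeter is used above (and in several snippets below) but never defined
# in this section. A minimal sketch of the conventional running-average helper,
# assuming the usual update(val, n) interface; the project's own class may differ.
class AverageMeter:
    """Track the running sum, count, and average of a scalar metric."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count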
def train(self, setting): """Training Function. Args: setting: Name used to save the model Returns: model: Trained model """ # Load different datasets train_loader = self._get_data(flag='train') vali_loader = self._get_data(flag='val') test_loader = self._get_data(flag='test') path = os.path.join(self.args.checkpoints, setting) if not os.path.exists(path): os.makedirs(path) time_now = time.time() train_steps = len(train_loader) early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) # Setting optimizer and loss functions model_optim = self._select_optimizer() criterion = nn.MSELoss() all_training_loss = [] all_validation_loss = [] # Training Loop for epoch in range(self.args.train_epochs): iter_count = 0 train_loss = [] self.model.train() epoch_time = time.time() for i, (batch_x, batch_y) in enumerate(train_loader): iter_count += 1 model_optim.zero_grad() if self.model_type == 'SDT': (pred, panelty), true = self._process_one_batch( batch_x, batch_y) loss = criterion(pred, true) + panelty else: pred, true = self._process_one_batch(batch_x, batch_y) loss = criterion(pred, true) train_loss.append(loss.item()) if (i + 1) % 100 == 0: print( '\titers: {0}/{1}, epoch: {2} | loss: {3:.7f}'.format( i + 1, train_steps, epoch + 1, loss.item())) speed = (time.time() - time_now) / iter_count left_time = speed * ( (self.args.train_epochs - epoch) * train_steps - i) print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format( speed, left_time)) iter_count = 0 time_now = time.time() loss.backward() model_optim.step() print('Epoch: {} cost time: {}'.format(epoch + 1, time.time() - epoch_time)) train_loss = np.average(train_loss) all_training_loss.append(train_loss) vali_loss = self.vali(vali_loader, criterion) all_validation_loss.append(vali_loss) test_loss = self.vali(test_loader, criterion) print( 'Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}' .format(epoch + 1, train_steps, train_loss, vali_loss, test_loss)) early_stopping(vali_loss, self.model, path) # Plotting train and validation loss if ((epoch + 1) % 5 == 0 and self.args.plot): check_folder = os.path.isdir(self.args.plot_dir) # If folder doesn't exist, then create it. if not check_folder: os.makedirs(self.args.plot_dir) plt.figure() plt.plot(all_training_loss, label='train loss') plt.plot(all_validation_loss, label='Val loss') plt.legend() plt.savefig(self.args.plot_dir + setting + '.png') plt.show() plt.close() # If ran out of patience stop training if early_stopping.early_stop: if self.args.plot: plt.figure() plt.plot(all_training_loss, label='train loss') plt.plot(all_validation_loss, label='Val loss') plt.legend() plt.savefig(self.args.plot_dir + setting + '.png') plt.show() print('Early stopping') break best_model_path = path + '/' + 'checkpoint.pth' self.model.load_state_dict(torch.load(best_model_path)) return self.model
def train(self, setting):
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    test_data, test_loader = self._get_data(flag='test')

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        epoch_time = time.time()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()

            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float()
            batch_x_mark = batch_x_mark.float().to(self.device)
            batch_y_mark = batch_y_mark.float().to(self.device)

            # decoder input: label_len known steps followed by pred_len zeros
            dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                                dim=1).float().to(self.device)

            # encoder - decoder
            if self.args.use_amp:
                with torch.cuda.amp.autocast():
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                    f_dim = -1 if self.args.features == 'MS' else 0
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())
            else:
                if self.args.output_attention:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                if self.args.inverse:
                    outputs = train_data.inverse_transform(outputs)
                f_dim = -1 if self.args.features == 'MS' else 0
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                    i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}"
              .format(epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model
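# adjust_learning_rate is called at the end of each epoch in most snippets here
# but is not defined in this section. A sketch of the Informer-style schedule
# that halves the learning rate every epoch (an assumption; the repository may
# use a different decay rule):
def adjust_learning_rate(optimizer, epoch, args):
    lr = args.learning_rate * (0.5 ** (epoch - 1))  # halve once per epoch
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print('Updating learning rate to {}'.format(lr))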
def train(self, ii, logger):
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    next_data, next_loader = self._get_data(flag='train')
    test_data, test_loader = self._get_data(flag='test')
    if self.args.rank == 1:
        train_data, train_loader = self._get_data(flag='train')

    path = os.path.join(self.args.path, str(ii))
    try:
        os.mkdir(path)
    except FileExistsError:
        pass

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience,
                                   verbose=True, rank=self.args.rank)

    W_optim, A_optim = self._select_optimizer()
    criterion = self._select_criterion()

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []
        rate_counter = AverageMeter()
        Ag_counter, A_counter, Wg_counter, W_counter = (
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter())

        self.model.train()
        epoch_time = time.time()
        for i, (trn_data, val_data, next_data) in enumerate(
                zip(train_loader, vali_loader, next_loader)):
            # Move every tensor of the three batches to the device. Use a
            # separate index here: the original reused `i`, which shadowed the
            # enumerate counter and corrupted the logging and ETA below.
            for j in range(len(trn_data)):
                trn_data[j] = trn_data[j].float().to(self.device)
                val_data[j] = val_data[j].float().to(self.device)
                next_data[j] = next_data[j].float().to(self.device)
            iter_count += 1

            # Architecture step: unrolled bilevel gradient, then all-reduce
            A_optim.zero_grad()
            rate = self.arch.unrolled_backward(
                self.args, trn_data, val_data, next_data,
                W_optim.param_groups[0]['lr'], W_optim)
            rate_counter.update(rate)
            # for r in range(1, self.args.world_size):
            #     for n, h in self.model.named_H():
            #         if "proj.{}".format(r) in n:
            #             if self.args.rank <= r:
            #                 with torch.no_grad():
            #                     dist.all_reduce(h.grad)
            #                     h.grad *= self.args.world_size/r+1
            #             else:
            #                 z = torch.zeros(h.shape).to(self.device)
            #                 dist.all_reduce(z)
            for a in self.model.A():
                with torch.no_grad():
                    dist.all_reduce(a.grad)
            a_g_norm = 0
            a_norm = 0
            n = 0
            for a in self.model.A():
                a_g_norm += a.grad.mean()
                a_norm += a.mean()
                n += 1
            Ag_counter.update(a_g_norm / n)
            A_counter.update(a_norm / n)
            A_optim.step()

            # Weight step
            W_optim.zero_grad()
            pred, true = self._process_one_batch(train_data, trn_data)
            loss = criterion(pred, true)
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                logger.info("\tR{0} iters: {1}, epoch: {2} | loss: {3:.7f}".format(
                    self.args.rank, i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                logger.info('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                    speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(W_optim)
                scaler.update()
            else:
                loss.backward()
                w_g_norm = 0
                w_norm = 0
                n = 0
                for w in self.model.W():
                    w_g_norm += w.grad.mean()
                    w_norm += w.mean()
                    n += 1
                Wg_counter.update(w_g_norm / n)
                W_counter.update(w_norm / n)
                W_optim.step()

        logger.info("R{} Epoch: {} W:{} Wg:{} A:{} Ag:{} rate:{}".format(
            self.args.rank, epoch + 1, W_counter.avg, Wg_counter.avg,
            A_counter.avg, Ag_counter.avg, rate_counter.avg))
        logger.info("R{} Epoch: {} cost time: {}".format(
            self.args.rank, epoch + 1, time.time() - epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        logger.info(
            "R{0} Epoch: {1}, Steps: {2} | Train Loss: {3:.7f} Vali Loss: {4:.7f} Test Loss: {5:.7f}"
            .format(self.args.rank, epoch + 1, train_steps, train_loss,
                    vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)

        # Stop only when every rank has run out of patience: gather each
        # rank's 0/1 early-stop flag (this code assumes a world size of two)
        flag = torch.tensor([1]) if early_stopping.early_stop else torch.tensor([0])
        flag = flag.to(self.device)
        flags = [torch.tensor([1]).to(self.device), torch.tensor([1]).to(self.device)]
        dist.all_gather(flags, flag)
        if flags[0].item() == 1 and flags[1].item() == 1:
            logger.info("Early stopping")
            break

        adjust_learning_rate(W_optim, epoch + 1, self.args)

    best_model_path = path + '/' + '{}_checkpoint.pth'.format(self.args.rank)
    self.model.load_state_dict(torch.load(best_model_path))
    return self.model
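# The flag exchange above hardcodes exactly two ranks. A sketch of the same
# consensus test for an arbitrary world size (a hypothetical helper, not part
# of the original code): every rank contributes a 0/1 flag via all_gather, and
# training stops only when all ranks agree.
def all_ranks_want_to_stop(early_stop, device, world_size):
    flag = torch.tensor([1 if early_stop else 0], device=device)
    flags = [torch.zeros_like(flag) for _ in range(world_size)]
    dist.all_gather(flags, flag)
    return all(f.item() == 1 for f in flags)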
def start(self, train_loader, train_set, valid_set=None, valid_loader=None):
    self.train_num_batches = math.ceil(train_set.num_images / float(self.cf.train_batch_size))
    self.val_num_batches = 0 if valid_set is None else math.ceil(
        valid_set.num_images / float(self.cf.valid_batch_size))

    # Define early stopping control
    if self.cf.early_stopping:
        early_stopping = EarlyStopping(self.cf)
    else:
        early_stopping = None

    prev_msg = '\nTotal estimated training time...\n'
    self.global_bar = ProgressBar(
        (self.cf.epochs + 1 - self.curr_epoch) *
        (self.train_num_batches + self.val_num_batches), lenBar=20)
    self.global_bar.set_prev_msg(prev_msg)

    # Train process
    for epoch in range(self.curr_epoch, self.cf.epochs + 1):
        # Shuffle train data
        train_set.update_indexes()

        # Initialize logger
        epoch_time = time.time()
        self.logger_stats.write('\t ------ Epoch: ' + str(epoch) + ' ------ \n')

        # Initialize epoch progress bar
        self.msg.accum_str = '\n\nEpoch %d/%d estimated time...\n' % (epoch, self.cf.epochs)
        epoch_bar = ProgressBar(self.train_num_batches, lenBar=20)
        epoch_bar.update(show=False)

        # Initialize stats
        self.stats.epoch = epoch
        self.train_loss = AverageMeter()
        self.confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

        # Train epoch
        self.training_loop(epoch, train_loader, epoch_bar)

        # Save stats
        self.stats.train.conf_m = self.confm_list
        self.compute_stats(np.asarray(self.confm_list), self.train_loss)
        self.save_stats_epoch(epoch)
        self.logger_stats.write_stat(
            self.stats.train, epoch,
            os.path.join(self.cf.train_json_path,
                         'train_epoch_' + str(epoch) + '.json'))

        # Validate epoch
        self.validate_epoch(valid_set, valid_loader, early_stopping, epoch, self.global_bar)

        # Update scheduler
        if self.model.scheduler is not None:
            self.model.scheduler.step(self.stats.val.loss)

        # Save model on score improvement
        new_best = self.model.save(self.stats)
        if new_best:
            self.logger_stats.write_best_stats(self.stats, epoch, self.cf.best_json_file)

        # Update display values
        self.update_messages(epoch, epoch_time, new_best)

        if self.stop:
            return

    # Save model without training
    if self.cf.epochs == 0:
        self.model.save_model()
def train(self, setting):
    # Fetch the datasets and loaders via PyTorch's data utilities
    # (dataset, dataloader pairs)
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')  # validation split
    test_data, test_loader = self._get_data(flag='test')

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()  # how the loss is computed

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    # Keep per-epoch average losses across the whole run; initializing this
    # inside the epoch loop (as before) would retain only the last epoch.
    train_loss_avg_list = []

    for epoch in range(self.args.train_epochs):  # train_epochs defaults to 6
        iter_count = 0
        train_loss = []

        self.model.train()  # 1. put the model in training mode
        epoch_time = time.time()
        # Iterating the DataLoader with for-in yields ready-made batches
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(tqdm(train_loader)):
            # print("Shape of batch_x")
            # print(batch_x.shape)
            iter_count += 1
            model_optim.zero_grad()  # reset gradients
            # Do not call model.eval() during training.
            # The core step: turn x and y into a prediction and a target
            pred, true = self._process_one_batch(
                train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)  # prediction and ground truth
            if self.args.interpret is True:
                # prediction when high-attention positions are masked
                mask_attention_pred, true = self._process_one_batch(
                    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
                # mask_attention_output
                loss = criterion(pred, true, mask_attention_pred)
            else:
                loss = criterion(pred, true)  # compute the loss
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                    i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()     # backpropagate
                model_optim.step()  # update parameters

        print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        train_loss_avg = np.average(train_loss)
        train_loss_avg_list.append(train_loss_avg)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
            epoch + 1, train_steps, train_loss_avg, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

        # LINE notification
        if self.args.notify:
            send_line_notify(message="Epoch: {} cost time: {}".format(
                epoch + 1, time.time() - epoch_time) +
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss_avg, vali_loss, test_loss))

    # Reload the best checkpoint, which EarlyStopping wrote whenever the
    # validation loss improved
    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    # Save the loss history
    folder_path = './results/' + setting + '/'
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    np.save(folder_path + '/' + 'train_loss_avg_list.npy', train_loss_avg_list)

    return self.model
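# send_line_notify is not defined in this section. A minimal sketch against the
# LINE Notify REST API, assuming a personal access token is available; this is
# a hypothetical helper and the project's own version may differ.
import requests

def send_line_notify(message, token='YOUR_LINE_NOTIFY_TOKEN'):
    requests.post('https://notify-api.line.me/api/notify',
                  headers={'Authorization': 'Bearer ' + token},
                  data={'message': message})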
def train(self, setting):
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    test_data, test_loader = self._get_data(flag='test')

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        epoch_time = time.time()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()
            pred, true = self._process_one_batch(
                train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            loss = criterion(pred, true)
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                    i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
            epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model
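# _process_one_batch is called above but not shown. A sketch reconstructed from
# the inline encoder-decoder logic of the other snippets in this section (an
# assumption about the exact body; details such as the inverse transform vary
# between variants):
def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
    batch_x = batch_x.float().to(self.device)
    batch_y = batch_y.float()
    batch_x_mark = batch_x_mark.float().to(self.device)
    batch_y_mark = batch_y_mark.float().to(self.device)

    # decoder input: label_len known steps followed by pred_len zeros
    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                        dim=1).float().to(self.device)

    if self.args.output_attention:
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
    else:
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    if self.args.inverse:
        outputs = dataset_object.inverse_transform(outputs)

    f_dim = -1 if self.args.features == 'MS' else 0
    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
    return outputs, batch_y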
def train(self, setting):
    train_data, train_loader = self._get_data(flag='train')
    valid_data, valid_loader = self._get_data(flag='val')
    print(f'number of batches in train data={len(train_loader)}')
    print(f'number of batches in valid data={len(valid_loader)}')

    path = './checkpoints/' + setting
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion(self.args.data)
    # print(self.model)

    best_utility = 0
    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []
        train_auc = []

        self.model.train()
        # batch_x:      (batch_size, seq_len, n_features)
        # batch_y:      (batch_size, label_len + pred_len, n_features)
        # batch_x_mark: (batch_size, seq_len)
        # batch_y_mark: (batch_size, label_len + pred_len)
        for i, (s_begin, s_end, r_begin, r_end,
                batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            # for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            # print(f'{i} s_begin: ', s_begin)
            # print(f'{i} s_end: ', s_end)
            # print(f'{i} r_begin: ', r_begin)
            # print(f'{i} r_end: ', r_end)
            iter_count += 1
            # print(f'x : {batch_x}')
            # print(f'y : {batch_y}')
            # print(f'x_mark : {batch_x_mark}')
            # print(f'y_mark : {batch_y_mark}')
            model_optim.zero_grad()

            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float()
            batch_x_mark = batch_x_mark.float().to(self.device)
            batch_y_mark = batch_y_mark.float().to(self.device)

            # decoder input
            dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                                dim=1).float().to(self.device)

            # encoder - decoder
            if self.args.output_attention:
                y_pred = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            else:
                y_pred = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            f_dim = -1 if self.args.features == 'MS' else 0
            y_true = batch_y[:, -self.args.pred_len:, -self.args.c_out:].to(self.device)
            # y_true = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

            loss = criterion(y_pred, y_true)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            model_optim.step()

            # print(y_pred)
            y_pred = np.where(y_pred.sigmoid().detach().cpu().numpy() >= 0.5, 1, 0).astype(int)
            y_true = np.where(y_true.sigmoid().detach().cpu().numpy() >= 0.5, 1, 0).astype(int)
            # y_true = y_true.detach().cpu().numpy().astype(int)
            # print('y_true: ', np.median(y_true, axis=1))
            # print('y_pred: ', np.median(y_pred, axis=1))
            train_auc.append(roc_auc_score(np.median(y_true, axis=1), np.median(y_pred, axis=1)))
            loss = loss.item()
            train_loss.append(loss)

            if (i + 1) % 100 == 0:
                print(f'\ttrain_iters={i+1} | epoch={epoch+1} | '
                      f'batch_loss={loss:.4f} | running_loss={np.mean(train_loss):.4f} | '
                      f'running_auc={np.mean(train_auc):.4f}')
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print(f'\tspeed={speed:.4f}s/batch | left_time={left_time:.4f}s')
                # print(torch.cuda.memory_summary(abbreviated=True))
                iter_count = 0
                time_now = time.time()

            # Free batch tensors aggressively to cap GPU memory
            del batch_x, batch_x_mark, batch_y, batch_y_mark, dec_inp, y_true, y_pred
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated(self.device)

        returns = self.evaluate(valid_data, valid_loader, criterion)
        valid_loss, valid_preds, valid_trues, v_start, v_end = returns
        # print(valid_data.data_x[b_start:b_end, -self.args.c_out:].shape)
        # print('before where: ', valid_preds)
        # print(valid_data.data_x[b_start:b_end, -self.args.c_out:])
        # valid_trues = valid_data.data_x[b_start:b_end, -self.args.c_out:]
        # print('valid_preds.shape: ', valid_preds.shape)
        valid_preds = np.median(valid_preds, axis=1)
        print(pd.DataFrame(valid_preds).describe())
        valid_preds = np.where(valid_preds >= 0.5, 1, 0).astype(int)
        # print('after where: ', valid_preds)
        # valid_trues = 1/(1+np.exp(-valid_trues))
        valid_trues = np.median(valid_trues, axis=1)
        valid_trues = np.where(valid_trues >= 0.5, 1, 0).astype(int)
        # print('valid_trues shape: ', valid_trues.shape)
        # print('valid_trues: ', valid_trues)
        valid_auc = roc_auc_score(valid_trues, valid_preds)
        valid_u_score = utility_score_bincount(date=valid_data.data_stamp[v_start:v_end],
                                               weight=valid_data.weight[v_start:v_end],
                                               resp=valid_data.resp[v_start:v_end],
                                               action=valid_preds)
        max_u_score = utility_score_bincount(date=valid_data.data_stamp[v_start:v_end],
                                             weight=valid_data.weight[v_start:v_end],
                                             resp=valid_data.resp[v_start:v_end],
                                             action=valid_trues)
        best_utility = max(best_utility, valid_u_score)
        print(f'epoch={epoch+1} | '
              f'average_train_loss={np.mean(train_loss):.4f} | average_valid_loss={valid_loss:.4f} | '
              f'valid_utility={valid_u_score:.4f}/{max_u_score:.4f} | valid_auc={valid_auc:.4f}')
        early_stopping(valid_auc, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            print(f"Best utility score is {best_utility:.4f}")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))
    return self.model
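# utility_score_bincount is not defined in this section. A sketch of the Jane
# Street Market Prediction utility it appears to compute, inferred from the
# date/weight/resp/action argument names (an assumption, not the verified
# implementation): per-date P&L is binned by date, then scaled by an
# annualized Sharpe-like factor clipped to [0, 6].
def utility_score_bincount(date, weight, resp, action):
    Pi = np.bincount(date, weight * resp * action)  # P&L summed per date
    t = np.sum(Pi) / np.sqrt(np.sum(Pi ** 2)) * np.sqrt(250 / len(Pi))
    return min(max(t, 0), 6) * np.sum(Pi)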
def train(self, setting):
    print(self.model)
    print(sum(p.numel() for p in self.model.parameters()))

    data_dir = "/mnt/ufs18/home-052/surunze/biostat_project/archive_1/transcheckkernels1200/dataset/"
    train_data, train_loader = self._get_data(flag='train', data_dir=data_dir)
    print("train data loaded")
    vali_data, vali_loader = self._get_data(flag='val', data_dir=data_dir)
    print("valid data loaded")
    test_data, test_loader = self._get_data(flag='test', data_dir=data_dir)
    print("test data loaded")

    # Inspect the shapes of the first few training samples
    s = train_data
    print("train data train data", len(s[0]))
    print("train data train data", s[1][0].shape)
    for k in range(4):
        print("train data train data",
              s[0][k].shape, s[1][k].shape, s[2][k].shape, s[3][k].shape)

    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()

    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        epoch_time = time.time()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()

            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float().to(self.device)
            batch_x_mark = batch_x_mark.float().to(self.device)
            batch_y_mark = batch_y_mark.float().to(self.device)

            # decoder input
            dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                                dim=1).float().to(self.device)

            # encoder - decoder
            if self.args.use_amp:
                with torch.cuda.amp.autocast():
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                    f_dim = -1 if self.args.features == 'MS' else 0
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    # print(outputs.shape, batch_y.shape)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())
            else:
                if self.args.output_attention:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                if self.args.inverse:
                    outputs = train_data.inverse_transform(outputs)
                f_dim = -1 if self.args.features == 'MS' else 0
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                # print(outputs.shape, batch_y.shape)
                loss = criterion(outputs[:, :, 0], batch_y[:, :, 0])
                train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                    i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
        train_loss = np.average(train_loss)
        # vali_loss = train_loss
        # test_loss = train_loss
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.test(setting)

        print("Training Summary", epoch + 1, train_steps, train_loss, vali_loss, test_loss)
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))
    return self.model
def train(self, setting):
    train_data, train_loader = self._get_data(flag='train')
    vali_data, vali_loader = self._get_data(flag='val')
    test_data, test_loader = self._get_data(flag='test')

    total_para, trainable_para = self._get_number_parameters()
    print('Total number of parameters: {:d}'.format(total_para))
    print('Number of trainable parameters: {:d}'.format(trainable_para))

    path = './checkpoints/' + setting
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []

        self.model.train()
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            model_optim.zero_grad()

            batch_x = batch_x.double().to(self.device)
            batch_y = batch_y.double()
            batch_x_mark = batch_x_mark.double().to(self.device)
            batch_y_mark = batch_y_mark.double().to(self.device)

            # decoder input
            dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).double()
            dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp],
                                dim=1).double().to(self.device)

            # encoder - decoder
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
            loss = criterion(outputs, batch_y)
            train_loss.append(loss.item())

            if (i + 1) % 100 == 0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(
                    i + 1, epoch + 1, loss.item()))
                speed = (time.time() - time_now) / iter_count
                left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            loss.backward()
            model_optim.step()

        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}"
              .format(epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))

    return self.model
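# _get_number_parameters is not shown in this section. A sketch of the usual
# implementation, mirroring the two printouts above (an assumption about the
# exact body):
def _get_number_parameters(self):
    total = sum(p.numel() for p in self.model.parameters())
    trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
    return total, trainable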
def train(self, setting): """Training Function. Args: setting: Name used to save the model Returns: model: Trained model """ # Load different datasets train_loader, train_loader_shuffled = self._get_data(flag="train") vali_loader = self._get_data(flag="val") test_loader = self._get_data(flag="test") path = os.path.join(self.args.checkpoints, setting) if not os.path.exists(path): os.makedirs(path) train_steps = len(train_loader) early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) # Setting optimizer and loss functions model_optim, gate_optim = self._select_optimizer() criterion, criterion_kl = self._select_criterion() loss_train = [] acc_loss_train = [] utilization_loss_train = [] smoothness_loss_train = [] diversity_loss_train = [] loss_val = [] acc_loss_val = [] utilization_loss_val = [] smoothness_loss_val = [] diversity_loss_val = [] mse_train_ = [] mse_test_ = [] mse_val_ = [] upper_bound_train_ = [] upper_bound_test_ = [] upper_bound_val_ = [] oracle_acc_test_ = [] oracle_acc_train_ = [] oracle_acc_val_ = [] # Getting the intial mse, oracle accuracy and upper bound (mse_train, upper_bound_train, oracle_acc_train) = self.get_upper_bound_accuracy( train_loader, flag="train") (mse_test, upper_bound_test, oracle_acc_test) = self.get_upper_bound_accuracy( test_loader, flag="test") (mse_val, upper_bound_val, oracle_acc_val) = self.get_upper_bound_accuracy( vali_loader, flag="val") mse_train_.append(mse_train) mse_test_.append(mse_test) mse_val_.append(mse_val) upper_bound_train_.append(upper_bound_train) upper_bound_test_.append(upper_bound_test) upper_bound_val_.append(upper_bound_val) oracle_acc_train_.append(oracle_acc_train) oracle_acc_test_.append(oracle_acc_test) oracle_acc_val_.append(oracle_acc_val) # Training loop for epoch in range(self.args.train_epochs): self.model.train() loss_all = 0 # Add noise to the weights of the expert this promotes diversity if self.args.noise: with torch.no_grad(): for param in self.model.experts.parameters(): param.add_(torch.randn(param.size()).to(self.device) * 0.01) for i, (batch_x, index, batch_y) in enumerate(train_loader_shuffled): # get past error made by experts past_errors = self.get_past_errors(index, "train") model_optim.zero_grad() if self.args.expert_type == "SDT": pred, true, weights, (reg_out, panelty) = self._process_one_batch( batch_x, batch_y, past_errors=past_errors) accuracy_loss = self.accuracy_loss(pred, true, weights) + panelty else: pred, true, weights, reg_out = self._process_one_batch( batch_x, batch_y, past_errors=past_errors) accuracy_loss = self.accuracy_loss(pred, true, weights) batch_size = pred.shape[0] # Calcuate gate loss gate_loss = criterion(reg_out, true) # Calcuate utilization loss if self.args.utilization_hp != 0: batch_expert_utilization = torch.sum( weights.squeeze(-1), dim=0) / batch_size expert_utilization_loss = self.expert_utilization_loss( batch_expert_utilization) else: expert_utilization_loss = 0 # Calcuate smoothness loss if self.args.smoothness_hp != 0: previous_weight = self.get_gate_assignment_weights(index, "train") smoothness_loss = criterion_kl( torch.log(weights.squeeze(-1) + eps), previous_weight) self.set_gate_assignment_weights( weights.squeeze(-1).detach(), index, "train") else: smoothness_loss = 0 # Calcuate diversity loss if self.args.diversity_hp != 0: batch_x_noisy = add_gaussian_noise(batch_x) pred_noisy, _, _, _ = self._process_one_batch( batch_x_noisy, batch_y, past_errors=past_errors) diversity_loss = self.diversity_loss(pred, pred_noisy) # avg_diversity_loss += 
self.args.diversity_hp * diversity_loss.item() else: diversity_loss = 0 # set expert error error = self.model_assignment_error(pred, true) self.set_past_errors(error.detach(), index, "train") loss = ( self.args.accuracy_hp * accuracy_loss + self.args.gate_hp * gate_loss + self.args.utilization_hp * expert_utilization_loss + self.args.smoothness_hp * smoothness_loss + self.args.diversity_hp * diversity_loss) loss.backward() model_optim.step() loss_all += loss.item() if (i + 1) % 50 == 0: print("\tOne iters: {0}/{1}, epoch: {2} | loss: {3:.7f}".format( i + 1, len(train_loader), epoch + 1, loss_all / i)) # update past error matrix self.error_scaler.fit( self.past_train_error.detach().cpu().numpy().flatten().reshape(-1, 1)) self.past_train_error = torch.Tensor( self.error_scaler.transform( self.past_train_error.detach().cpu().numpy().flatten().reshape( -1, 1))).reshape(-1, self.num_experts).to(self.device) # Getting different losses (train_loss, acc_train, utilization_train, smoothness_train, diversity_train) = self.vali(train_loader, "train") loss_train.append(train_loss) acc_loss_train.append(acc_train) utilization_loss_train.append(utilization_train) smoothness_loss_train.append(smoothness_train) diversity_loss_train.append(diversity_train) (val_loss, acc_val, utilization_val, smoothness_val, diversity_val) = self.vali( vali_loader, flag="val") loss_val.append(val_loss) acc_loss_val.append(acc_val) utilization_loss_val.append(utilization_val) smoothness_loss_val.append(smoothness_val) diversity_loss_val.append(diversity_val) # getting mse, oracle accuracy and upper bound (mse_train, upper_bound_train, oracle_acc_train) = self.get_upper_bound_accuracy( train_loader, flag="train") (mse_test, upper_bound_test, oracle_acc_test) = self.get_upper_bound_accuracy( test_loader, flag="test") (mse_val, upper_bound_val, oracle_acc_val) = self.get_upper_bound_accuracy( vali_loader, flag="val") mse_train_.append(mse_train) mse_test_.append(mse_test) mse_val_.append(mse_val) upper_bound_train_.append(upper_bound_train) upper_bound_test_.append(upper_bound_test) upper_bound_val_.append(upper_bound_val) oracle_acc_train_.append(oracle_acc_train) oracle_acc_test_.append(oracle_acc_test) oracle_acc_val_.append(oracle_acc_test) # early stopping depends on the validation accuarcy loss early_stopping(acc_val, self.model, path) print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}" .format(epoch + 1, train_steps, train_loss, val_loss)) if ((epoch + 1) % 10 == 0 and self.args.plot): self.plot_all(setting, mse_test_, mse_train_, mse_val_, upper_bound_test_, upper_bound_train_, upper_bound_val_, oracle_acc_test_, oracle_acc_train_, oracle_acc_val_, loss_train, loss_val, acc_loss_train, acc_loss_val, utilization_loss_train, utilization_loss_val, smoothness_loss_train, smoothness_loss_val, diversity_loss_train, diversity_loss_val) # when training runs out of patience if early_stopping.early_stop: break # if freezing experts and tunning gate network if (self.args.freeze and oracle_acc_train != 1): # load the best model best_model_path = path + "/" + "checkpoint.pth" self.model.load_state_dict(torch.load(best_model_path)) # set past errors to zero self.past_train_error[:, :] = 0 self.past_test_error[:, :] = 0 self.past_val_error[:, :] = 0 # get validation accuracy on the best model (val_loss, acc_val, utilization_val, smoothness_val, diversity_val) = self.vali( vali_loader, flag="val") # reseting and adjusting early_stopping early_stopping.val_loss_min = acc_val early_stopping.counter = 0 
early_stopping.early_stop = False for e in range(self.args.train_epochs): self.model.train() loss_all = 0 for i, (batch_x, index, batch_y) in enumerate(train_loader_shuffled): past_errors = self.get_past_errors(index, "train") gate_optim.zero_grad() if self.args.expert_type == "SDT": pred, true, weights, (reg_out, panelty) = self._process_one_batch( batch_x, batch_y, past_errors=past_errors) accuracy_loss = self.accuracy_loss(pred, true, weights) else: pred, true, weights, reg_out = self._process_one_batch( batch_x, batch_y, past_errors=past_errors) accuracy_loss = self.accuracy_loss(pred, true, weights) # set expert error error = self.model_assignment_error(pred, true) self.set_past_errors(error.detach(), index, "train") loss = accuracy_loss loss.backward() # clear the expert gradients since we want them frozen self.model.experts.zero_grad() gate_optim.step() loss_all += loss.item() if (i + 1) % 50 == 0: print( "\tFreeze iters: {0}/{1}, epoch: {2} sub epoch {4} | loss: {3:.7f}" .format(i + 1, len(train_loader), epoch + 1, loss_all / i, e)) # update past error matrix self.error_scaler.fit( self.past_train_error.detach().cpu().numpy().flatten().reshape( -1, 1)) self.past_train_error = torch.Tensor( self.error_scaler.transform( self.past_train_error.detach().cpu().numpy().flatten().reshape( -1, 1))).reshape(-1, self.num_experts).to(self.device) # Getting different losses (train_loss, acc_train, utilization_train, smoothness_train, diversity_train) = self.vali( train_loader, flag="train") loss_train.append(train_loss) acc_loss_train.append(acc_train) utilization_loss_train.append(utilization_train) smoothness_loss_train.append(smoothness_train) diversity_loss_train.append(diversity_train) (val_loss, acc_val, utilization_val, smoothness_val, diversity_val) = self.vali( vali_loader, flag="val") loss_val.append(val_loss) acc_loss_val.append(acc_val) utilization_loss_val.append(utilization_val) smoothness_loss_val.append(smoothness_val) diversity_loss_val.append(diversity_val) # getting mse, oracle accuracy and upper bound (mse_train, upper_bound_train, oracle_acc_train) = self.get_upper_bound_accuracy( train_loader, flag="train") (mse_test, upper_bound_test, oracle_acc_test) = self.get_upper_bound_accuracy( test_loader, flag="test") (mse_val, upper_bound_val, oracle_acc_val) = self.get_upper_bound_accuracy( vali_loader, flag="val") mse_train_.append(mse_train) mse_test_.append(mse_test) mse_val_.append(mse_val) upper_bound_train_.append(upper_bound_train) upper_bound_test_.append(upper_bound_test) upper_bound_val_.append(upper_bound_val) oracle_acc_train_.append(oracle_acc_train) oracle_acc_test_.append(oracle_acc_test) oracle_acc_val_.append(oracle_acc_test) # early stopping depends on the validation accuarcy loss early_stopping(acc_val, self.model, path) print( "Frozen Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}" .format(e + 1, train_steps, train_loss, val_loss)) if early_stopping.early_stop: print("Early stopping") break if self.args.plot: self.plot_all(setting, mse_test_, mse_train_, mse_val_, upper_bound_test_, upper_bound_train_, upper_bound_val_, oracle_acc_test_, oracle_acc_train_, oracle_acc_val_, loss_train, loss_val, acc_loss_train, acc_loss_val, utilization_loss_train, utilization_loss_val, smoothness_loss_train, smoothness_loss_val, diversity_loss_train, diversity_loss_val) # Load the best model on the validation dataset best_model_path = path + "/" + "checkpoint.pth" self.model.load_state_dict(torch.load(best_model_path)) pickle.dump(self.error_scaler, open(path 
+ "/" + "std_scaler.bin", "wb")) return self.model
def train(self, setting):
    print('prepare data...')
    train_data_loaders, vali_data_loaders, test_data_loaders = self._get_data()
    print('Number of data loaders:', len(train_data_loaders))

    path = './checkpoints/' + setting
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

    model_optim = self._select_optimizer()
    criterion = self._select_criterion()

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        epoch_losses = []  # losses across every loader, for the epoch summary

        self.model.train()
        for index in range(len(train_data_loaders)):
            train_loader = train_data_loaders[index]
            train_loss = []
            begin_ = time.time()
            for i, (batch_x, batch_y) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.double()  # .to(self.device)
                batch_y = batch_y.double()
                outputs = self.model(batch_x).view(-1, 24)
                batch_y = batch_y[:, -self.args.pred_len:, -1].view(-1, 24)  # .to(self.device)
                loss = criterion(outputs, batch_y)  # + 0.1*corr
                train_loss.append(loss.item())
                loss.backward()
                model_optim.step()
            print('INDEX Finished', index, 'train loss',
                  np.average(train_loss), 'COST', time.time() - begin_)
            epoch_losses.extend(train_loss)

        # Average over every loader, not just the last one
        train_loss = np.average(epoch_losses)
        vali_loss, mae, score = self.test('1')
        # Negate the score: higher is better, EarlyStopping expects lower
        early_stopping(-score, self.model, path)
        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} score: {4:.7f}"
              .format(epoch + 1, iter_count, train_loss, vali_loss, score))
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch + 1, self.args)

    best_model_path = path + '/' + 'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))
    print('Model is saved at', best_model_path)
    self.model.eval()
    return self.model