def train(ini_file):
    '''Performs training according to an .ini file, with early stopping.

    Builds the network, criterion, optimizer and data loaders described by the
    configuration, then trains for up to ``config['train']['epochs']`` epochs.
    After every epoch the model is evaluated on the held-out split. An epoch
    becomes the new "best" only when BOTH its validation accuracy and its
    training accuracy exceed the best values so far. The best values start at
    0.65 (validation) and 0.70 (training), so these act as minimum thresholds.
    On a new best, the model/optimizer state is saved to
    ``models_dir/<ini file name>.pth``. After ``patience`` consecutive epochs
    without a new best, training stops and the best validation ROC curve (if
    any) is plotted.

    :param ini_file: (String) the path of .ini file
    :return: tuple ``(best_acc, _best_acc)``, i.e. the best validation accuracy
        and the training accuracy of that same epoch. If no epoch ever beats
        the initial thresholds, ``(0.65, 0.70)`` is returned.
    '''
    # reads configuration from .ini file
    config = read_config(ini_file)
    train_cfg = config['train']

    # builds network|criterion|optimizer based on configuration
    model = Net(config['network']).to(device)
    criterion = Criterion(config['network'], device).to(device)
    # Resolve the optimizer class by attribute lookup on ``optim`` instead of
    # eval()-ing a string from the config file (eval would run arbitrary code).
    optimizer_cls = getattr(optim, train_cfg['optimizer'])
    optimizer = optimizer_cls(model.parameters(), lr=train_cfg['learning_rate'])

    # constructs data loaders based on configuration; batch_size equals the
    # dataset length, so each loader yields the whole split as one batch
    train_dataset = MakeDataset(train_cfg['h5_file'], is_train=True, device=device)
    test_dataset = MakeDataset(train_cfg['h5_file'], is_train=False, device=device)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=len(train_dataset))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset))

    # Checkpoint file is named after the .ini file's base name. Split on both
    # separators: splitting only on '\\' left POSIX paths whole, so
    # os.path.join() nested the file under (or, for absolute paths, outside)
    # models_dir.
    checkpoint_path = os.path.join(
        models_dir, ini_file.replace('\\', '/').split('/')[-1] + '.pth')

    # training state; the initial accuracies are minimum thresholds
    _best_acc = 0.70  # training ACC of the best epoch
    best_acc = 0.65   # validation ACC of the best epoch
    best_ep = 0
    epochs_without_improvement = 0
    _best_auc = 0     # training AUC of the best epoch
    best_auc = 0      # validation AUC of the best epoch
    best_roc = None   # validation ROC curve of the best epoch
    for epoch in range(1, train_cfg['epochs'] + 1):
        # adjusts learning rate
        lr = adjust_learning_rate(optimizer, epoch,
                                  train_cfg['learning_rate'],
                                  train_cfg['lr_decay_rate'])

        # train step
        model.train()
        for X, y in train_loader:
            # makes predictions
            pred = model(X)
            train_loss = criterion(pred, y, model)
            train_FPR, train_TPR, train_ACC, train_roc, train_roc_auc, _, _, _, _ = Auc(pred, y)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # valid step (no gradients needed)
        model.eval()
        with torch.no_grad():
            for X, y in test_loader:
                pred = model(X)
                valid_loss = criterion(pred, y, model)
                valid_FPR, valid_TPR, valid_ACC, valid_roc, valid_roc_auc, _, _, _, _ = Auc(pred, y)

        # Both loaders have exactly one batch, so the metrics above cover the
        # full train / validation split for this epoch.
        if valid_ACC > best_acc and train_ACC > _best_acc:
            epochs_without_improvement = 0
            best_acc = valid_ACC
            _best_acc = train_ACC
            best_ep = epoch
            best_auc = valid_roc_auc
            _best_auc = train_roc_auc
            best_roc = valid_roc
            # saves the best model
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch}, checkpoint_path)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # early stop: report (and plot) the best epoch seen
                print('epoch: {}\t{:.8f}({:.8f})'.format(best_ep, _best_acc, best_acc))
                if best_roc is not None:
                    plt.plot(best_roc[:, 0], best_roc[:, 1])
                    plt.title('ep:{} AUC: {:.4f}({:.4f}) ACC: {:.4f}({:.4f})'.format(
                        best_ep, _best_auc, best_auc, _best_acc, best_acc))
                    plt.show()
                return best_acc, _best_acc

        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tACC: {:.8f}({:.8f})\tAUC: {}({})\tFPR: {:.8f}({:.8f})\tTPR: {:.8f}({:.8f})\tlr: {:g}\n'.format(
            epoch, train_loss.item(), valid_loss.item(), train_ACC, valid_ACC,
            train_roc_auc, valid_roc_auc, train_FPR, valid_FPR, train_TPR,
            valid_TPR, lr), end='', flush=False)
    return best_acc, _best_acc
def test(self, args):
    '''Runs the trained model on every entry of the test list and saves outputs.

    For each entry in ``args.test_list`` (one per line), this method:

    - builds an ``EvalDataset`` from ``test_mixture_path/<entry>``;
    - runs ``self.eval_forward`` on every mixture, truncating the estimate to
      the mixture's size;
    - writes the reference signal to ``<name>_s_ideal.dat`` and the estimate
      to ``<name>_s_est.dat`` under ``args.prediction_path``. These are HDF5
      files with one float32 dataset per utterance, keyed by its index.

    Per-utterance MSE, time, and running mean time are printed.

    :param args: namespace providing ``test_list``, ``model_name``,
        ``model_file``, ``test_mixture_path`` and ``prediction_path``.
    '''
    with open(args.test_list, 'r') as test_list_file:
        # Skip blank lines: an empty entry would otherwise be joined onto
        # test_mixture_path and handed to EvalDataset.
        self.test_list = [line.strip() for line in test_list_file if line.strip()]
    self.model_name = args.model_name
    self.model_file = args.model_file
    self.test_mixture_path = args.test_mixture_path
    self.prediction_path = args.prediction_path

    # create a network
    print('model', self.model_name)
    net = Net(device=self.device, L=self.frame_size, width=self.width)
    net.to(self.device)
    print('Number of learnable parameters: %d' % numParams(net))
    print(net)

    net.eval()
    print('Load model from "%s"' % self.model_file)
    checkpoint = Checkpoint()
    checkpoint.load(self.model_file)
    net.load_state_dict(checkpoint.state_dict)

    with torch.no_grad():
        for i, entry in enumerate(self.test_list):
            # read the mixture for resynthesis
            filename_input = entry.split('/')[-1]
            start1 = timeit.default_timer()
            print('{}/{}, Started working on {}.'.format(i + 1, len(self.test_list), entry))
            print('')
            filename_mix = filename_input.replace('.samp', '_mix.dat')
            filename_s_ideal = filename_input.replace('.samp', '_s_ideal.dat')
            filename_s_est = filename_input.replace('.samp', '_s_est.dat')

            # Context managers guarantee the HDF5 handles are closed (and the
            # output files flushed) even if inference raises mid-file.
            with h5py.File(os.path.join(self.test_mixture_path, filename_mix), 'r') as f_mix, \
                    h5py.File(os.path.join(self.prediction_path, filename_s_ideal), 'w') as f_s_ideal, \
                    h5py.File(os.path.join(self.prediction_path, filename_s_est), 'w') as f_s_est:
                # create a test dataset and a data loader for it
                testSet = EvalDataset(os.path.join(self.test_mixture_path, entry),
                                      self.num_test_sentences)
                test_loader = DataLoader(testSet, batch_size=1, shuffle=False,
                                         num_workers=2, collate_fn=EvalCollate())

                ttime = 0.
                cnt = 0
                for k, (mix_raw, cln_raw) in enumerate(test_loader):
                    start = timeit.default_timer()
                    est_s = self.eval_forward(mix_raw, net)
                    est_s = est_s[:mix_raw.size]
                    # NOTE(review): `mix` is never used below; the read is kept
                    # so a missing key in the mixture file still raises.
                    mix = f_mix[str(k)][:]
                    ideal_s = cln_raw
                    f_s_ideal.create_dataset(str(k), data=ideal_s.astype(np.float32), chunks=True)
                    f_s_est.create_dataset(str(k), data=est_s.astype(np.float32), chunks=True)

                    # compute eval_loss (MSE between estimate and reference)
                    test_loss = np.mean((est_s - ideal_s) ** 2)
                    cnt += 1
                    curr_time = timeit.default_timer() - start
                    ttime += curr_time
                    # Running mean over all utterances so far. The former extra
                    # update ((mtime*k + t)/(k+1)) counted the current one twice.
                    mtime = ttime / cnt
                    print('{}/{}, test_loss = {:.4f}, time/utterance = {:.4f}, '
                          'mtime/utternace = {:.4f}'.format(k + 1, self.num_test_sentences,
                                                            test_loss, curr_time, mtime))

                end1 = timeit.default_timer()
                print('********** Finisehe working on {}. time taken = {:.4f} **********'.format(
                    filename_input, end1 - start1))
                print('')