def print_and_save_auc(self, auc):
    """Display the Area Under Curve metric and persist it as YAML.

    Args:
        auc: scalar tensor holding the AUC value (``.item()`` is called
            to obtain a plain float for serialization).
    """
    print('')
    print('Area Under Curve (AUC): {:.5f}'.format(auc))
    results = {"auc": auc.item()}
    ydump(results, self.address, "net_result.yaml")
def __init__(self, res_dir, tb_dir, net_class, net_params, address, dt):
    """Set up result paths, build the network, and — when resuming from an
    existing address — reload its parameters and trained weights.

    Args:
        res_dir: directory for results.
        tb_dir: directory for tensorboard logs.
        net_class: network class, instantiated with ``net_params``.
        net_params: keyword arguments for ``net_class``.
        address: existing run directory to resume from, or ``None`` to
            start a fresh run.
        dt: sampling period (s).
    """
    self.res_dir = res_dir
    self.tb_dir = tb_dir
    self.net_class = net_class
    self.net_params = net_params
    self.dt = dt  # (s)
    self.figsize = (20, 12)
    self._ready = False
    self.train_params = {}
    self.address, self.tb_address = self.find_address(address)
    if address is not None:
        # resuming: pick up the stored network/training parameters
        self.net_params = pload(self.address, 'net_params.p')
        self.train_params = pload(self.address, 'train_params.p')
        self._ready = True
    else:
        # fresh run: persist the network parameters at the new address
        pdump(self.net_params, self.address, 'net_params.p')
        ydump(self.net_params, self.address, 'net_params.yaml')
    self.path_weights = os.path.join(self.address, 'weights.pt')
    self.net = self.net_class(**self.net_params)
    if self._ready:
        # previously trained: restore the saved weights into the network
        self.load_weights()
def train(self, dataset_class, dataset_params, train_params):
    """Train the neural network. GPU is assumed.

    Args:
        dataset_class: dataset constructor, called with ``mode`` in
            {'train', 'val', 'test'} plus ``dataset_params``.
        dataset_params: keyword arguments forwarded to ``dataset_class``.
        train_params: dict providing 'optimizer_class'/'scheduler_class'/
            'loss_class' and the matching 'optimizer'/'scheduler'/'loss'
            parameter dicts, plus 'dataloader', 'freq_val' and 'n_epochs'.
    """
    self.train_params = train_params
    pdump(self.train_params, self.address, 'train_params.p')
    ydump(self.train_params, self.address, 'train_params.yaml')

    hparams = self.get_hparams(dataset_class, dataset_params, train_params)
    ydump(hparams, self.address, 'hparams.yaml')

    # define datasets
    dataset_train = dataset_class(**dataset_params, mode='train')
    dataset_train.init_train()
    dataset_val = dataset_class(**dataset_params, mode='val')
    dataset_val.init_val()

    # get classes
    Optimizer = train_params['optimizer_class']
    Scheduler = train_params['scheduler_class']
    Loss = train_params['loss_class']
    # get parameters
    dataloader_params = train_params['dataloader']
    optimizer_params = train_params['optimizer']
    scheduler_params = train_params['scheduler']
    loss_params = train_params['loss']

    # define optimizer, scheduler and loss
    dataloader = DataLoader(dataset_train, **dataloader_params)
    optimizer = Optimizer(self.net.parameters(), **optimizer_params)
    scheduler = Scheduler(optimizer, **scheduler_params)
    criterion = Loss(**loss_params)

    # remaining training parameters
    freq_val = train_params['freq_val']
    n_epochs = train_params['n_epochs']

    # init net w.r.t dataset
    self.net = self.net.cuda()
    mean_u, std_u = dataset_train.mean_u, dataset_train.std_u
    self.net.set_normalized_factors(mean_u, std_u)

    # start tensorboard writer
    writer = SummaryWriter(self.tb_address)
    start_time = time.time()
    best_loss = torch.Tensor([float('Inf')])

    # define some functions for seeing evolution of training
    def write(epoch, loss_epoch):
        """Log the training loss and current learning rate for this epoch."""
        writer.add_scalar('loss/train', loss_epoch.item(), epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        print('Train Epoch: {:2d} \tLoss: {:.4f}'.format(
            epoch, loss_epoch.item()))
        # BUGFIX: scheduler.step(epoch) used to be called here AND in the
        # training loop below, advancing the LR schedule twice per epoch.
        # The single step now lives in the loop.

    def write_time(epoch, start_time):
        """Log the wall-clock time spent on the last validation interval."""
        delta_t = time.time() - start_time
        print("Amount of time spent for epochs " +
            "{}-{}: {:.1f}s\n".format(epoch - freq_val, epoch, delta_t))
        writer.add_scalar('time_spend', delta_t, epoch)

    def write_val(loss, best_loss):
        """Compare validation loss to the best so far; save weights when it
        qualifies as an improvement, and return the updated best loss."""
        # NOTE(review): 0.5*loss <= best_loss accepts any loss up to twice
        # the best seen so far as a "decrease" — confirm this generous
        # margin is intentional.
        if 0.5*loss <= best_loss:
            msg = 'validation loss decreases! :) '
            msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(),
                best_loss.item())
            cprint(msg, 'green')
            best_loss = loss
            self.save_net()
        else:
            msg = 'validation loss increases! :( '
            msg += '(curr/prev loss {:.4f}/{:.4f})'.format(loss.item(),
                best_loss.item())
            cprint(msg, 'yellow')
        writer.add_scalar('loss/val', loss.item(), epoch)
        return best_loss

    # training loop !
    for epoch in range(1, n_epochs + 1):
        loss_epoch = self.loop_train(dataloader, optimizer, criterion)
        write(epoch, loss_epoch)
        # advance the LR schedule exactly once per epoch
        scheduler.step(epoch)
        if epoch % freq_val == 0:
            loss = self.loop_val(dataset_val, criterion)
            write_time(epoch, start_time)
            best_loss = write_val(loss, best_loss)
            start_time = time.time()

    # training is over! test the best saved weights on new data
    dataset_test = dataset_class(**dataset_params, mode='test')
    # NOTE(review): unlike train/val, no init_test() is called here —
    # confirm the test dataset needs no initialization.
    self.load_weights()
    test_loss = self.loop_val(dataset_test, criterion)
    dict_loss = {
        'final_loss/val': best_loss.item(),
        'final_loss/test': test_loss.item()
    }
    writer.add_hparams(hparams, dict_loss)
    ydump(dict_loss, self.address, 'final_loss.yaml')
    writer.close()