# Imports assumed by the code in this excerpt (the original file-level imports
# are not shown). `mkutils`, `deepmk.criteria`, and helpers such as
# `_check_opt_sched`, `_save_wts`, `get_weights`, and `get_normloss` come from
# elsewhere in the project.
import copy
import time

import numpy as np
import torch
import matplotlib.pyplot as plt


def train(self):
    if self.verbose:
        progress_disp = mkutils.ProgressDisplay()

    self.start()
    for i in range(self.numepochs):
        self.start_epoch()

        # display the epoch information
        if self.verbose:
            s = "Epoch %6d/%6d" % (i+1, self.numepochs)
            print(s)
            print("-" * len(s))

        total_batches = len(self.dataloader)
        for j, (phase, data) in enumerate(self.dataloader):
            self.update_phase(phase, data)

            # show the progress bar
            if self.verbose:
                progress_disp.show(j+1, total_batches)

        self.end_epoch()

        # add a blank line between epochs
        if self.verbose:
            print("")
    self.end()
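# A minimal sketch of the hook interface the loop above assumes: a trainer
# object exposing start/start_epoch/update_phase/end_epoch/end plus the
# `verbose`, `numepochs`, and `dataloader` attributes. The class and attribute
# contents below are illustrative assumptions, not part of the original code.
class _ExampleTrainer:
    def __init__(self, dataloader, numepochs=10, verbose=True):
        self.dataloader = dataloader   # iterable of (phase, data) pairs
        self.numepochs = numepochs
        self.verbose = verbose

    def start(self):
        pass  # e.g. move the model to the device, reset counters

    def start_epoch(self):
        pass  # e.g. reset the running loss for the epoch

    def update_phase(self, phase, data):
        pass  # e.g. forward/backward for "train", evaluation for "val"

    def end_epoch(self):
        pass  # e.g. step the scheduler, log the epoch loss

    def end(self):
        pass  # e.g. restore the best weights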
def validate(model, dataloader, val_criterion, device=None, verbose=1,
             load_wts_from=None):
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load the model to the device
    model = model.to(device)

    # load the weights
    if load_wts_from is not None:
        model.load_state_dict(torch.load(load_wts_from))

    # set to evaluation mode
    model.eval()

    # reset the validation criterion
    val_criterion.reset()

    num_batches = 0  # num batches
    total_batches = len(dataloader)
    if verbose >= 2:
        progress_disp = mkutils.ProgressDisplay()

    for inputs, labels in dataloader:
        num_batches += 1

        # load the data to the device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # calculate the validation criterion
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            val_criterion.feed(outputs, labels)

        # write the progress bar
        if verbose >= 2:
            progress_disp.show(num_batches, total_batches)

    print("Validation with %s criterion: %e" %
          (val_criterion.name, val_criterion.getval()))
    return float(val_criterion.getval())
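# Illustrative usage of `validate` above: a minimal sketch with a toy linear
# model and random data. Wrapping a plain callable in
# `deepmk.criteria.MeanCriterion` follows the same convention used by `train`
# below; the criterion only needs the reset/feed/getval interface and a
# `name` attribute. The toy names here are stand-ins, not project code.
def _example_validate_usage():
    model = torch.nn.Linear(4, 1)
    dataset = torch.utils.data.TensorDataset(torch.randn(32, 4),
                                             torch.randn(32, 1))
    loader = torch.utils.data.DataLoader(dataset, batch_size=8)
    criterion = deepmk.criteria.MeanCriterion(torch.nn.MSELoss())
    return validate(model, loader, criterion, verbose=1)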
def train(model, dataloaders, criteria, optimizer, scheduler=None,
          num_epochs=25, device=None, verbose=1, plot=0, save_wts_to=None,
          save_model_to=None, return_history=False, return_best_last=9e99):
    """
    Performs a training of the model.

    Args:
        model :
            A torch trainable class method that accepts "inputs" and returns
            prediction of "outputs".
        dataloaders (dict or torch.utils.data.DataLoader):
            Dictionary with two keys: ["train", "val"] with every value an
            iterable with two outputs: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs". If it is a DataLoader,
            then it is only for the training, nothing for validation.
        criteria (dict or callable or deepmk.criteria):
            Dictionary with two keys: ["train", "val"] with every value a
            callable or deepmk.criteria to calculate the criterion for the
            corresponding phase. If it is not a dictionary, then the criterion
            is set for both training and validation phases. If it is a
            callable, it is wrapped by a deepmk.criteria.MeanCriterion object
            to calculate the mean criterion. The criterion for the training
            needs to be differentiable and it will be minimized during the
            training.
        optimizer (torch.optim optimizer or dict):
            Optimizer class in training the model. If it is a dictionary, it
            must have "train" and "val" keys and it makes it a meta-learning
            problem.
        scheduler (torch.optim.lr_scheduler object or dict):
            Scheduler of how the learning rate evolves through the epochs. If
            it is None, it does not update the learning rate. It can be a
            dictionary like the optimizer argument. (default: None)
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device where to do the training. None to choose cuda:0 if
            available, otherwise, cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        plot (int):
            Whether to plot the loss of training and validation data.
            (default: 0)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do
            not save. (default: None)
        save_model_to (str):
            Name of a file to save the best whole model. If None, then do not
            save. (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses
            history will be returned. (default: False)
        return_best_last (int):
            Return the best model over the last `return_best_last` epochs.
            (default: 9e99)

    Returns:
        best_model :
            The trained model with the lowest loss criterion during the "val"
            phase.
    """
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check optimizer and scheduler types and decide if this is a
    # meta-learning problem
    metalearning = _check_opt_sched(optimizer, scheduler)
    if metalearning and verbose >= 1:
        print("We are doing meta-learning")

    # set interactive plot
    if plot:
        plt.ion()

    # check if the dataloader is for validation as well
    if type(dataloaders) != dict:
        dataloaders = {"train": dataloaders, "val": []}

    # set the criteria object right
    if type(criteria) != dict:
        criteria = {"train": criteria, "val": criteria}
    for phase in ["train", "val"]:
        if not issubclass(criteria[phase].__class__, deepmk.criteria.Criterion):
            criteria[phase] = deepmk.criteria.MeanCriterion(criteria[phase])
        criteria[phase].reset()

    # prepare the memory of the last best weights
    if return_best_last < num_epochs:
        weights_history = [None for _ in range(return_best_last)]

    # load the model to the device first
    model = model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    train_losses = []
    val_losses = []

    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    try:
        best_epoch = 0
        for epoch in range(num_epochs):
            if verbose >= 1:
                print("Epoch %d/%d" % (epoch+1, num_epochs))
                print("-"*10)

            # to time the progress
            epoch_start_time = time.time()

            # progress counter
            num_batches = 0  # num batches in training and validation
            if verbose >= 2:
                progress_disp = mkutils.ProgressDisplay()

            # every epoch has a training and a validation phase
            for phase in ["train", "val"]:
                # skip the phase if the dataloader for the current phase is empty
                if dataloaders[phase] == []:
                    continue

                # set the model's mode
                if not metalearning:
                    if phase == "train":
                        if scheduler is not None:
                            scheduler.step()  # adjust the training learning rate
                        model.train()  # set the model to the training mode
                    else:
                        model.eval()  # set the model to the evaluation mode
                else:
                    if scheduler is not None:
                        scheduler[phase].step()
                    model.train()

                # the total loss during this epoch
                running_loss = 0.0

                # iterate over the data
                dataset_size = 0

                # reset the criteria before the training epoch starts
                criteria[phase].reset()

                for inputs, labels in dataloaders[phase]:
                    # get the size of the dataset
                    dataset_size += inputs.size(0)
                    num_batches += 1

                    # write the progress bar
                    if verbose >= 2:
                        progress_disp.show(num_batches, total_batches)

                    # load the inputs and the labels to the working device
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # reset the model gradient to 0
                    if not metalearning:
                        optimizer.zero_grad()
                    else:
                        optimizer["train"].zero_grad()
                        optimizer["val"].zero_grad()

                    # forward
                    # track history only in train
                    grad_enabled = (phase == "train" or metalearning)
                    with torch.set_grad_enabled(grad_enabled):
                        outputs = model(inputs)
                        loss = criteria[phase].feed(outputs, labels)

                        # backward gradient computation and optimize in training
                        if not metalearning:
                            if phase == "train":
                                loss.backward()
                                optimizer.step()
                        else:
                            loss.backward()
                            optimizer[phase].step()

                # get the mean loss in this epoch
                mult = -1 if (criteria[phase].best == "max") else 1
                crit_val = criteria[phase].getval()
                epoch_loss = mult * crit_val

                # save the losses
                if phase == "train":
                    train_losses.append(crit_val.data)
                elif phase == "val":
                    val_losses.append(crit_val.data)

                # save the model history
                if return_best_last < num_epochs:
                    weights_history[epoch % return_best_last] = \
                        copy.deepcopy(model.state_dict())

                # copy the best model
                if phase == "val" and \
                        ((epoch_loss < best_loss) or
                         (epoch - best_epoch > return_best_last)):
                    if epoch - best_epoch > return_best_last:
                        # get the index of the next best last
                        val_losses_n = val_losses[-return_best_last:]
                        min_idx_rel = np.argmin(val_losses_n)
                        min_idx = min_idx_rel + len(val_losses) - return_best_last

                        # get the best conditions
                        best_epoch = min_idx
                        best_model_weights = \
                            weights_history[best_epoch % return_best_last]
                    else:
                        best_epoch = epoch
                        best_model_weights = copy.deepcopy(model.state_dict())

                    # save the best conditions
                    best_loss = val_losses[best_epoch]

                    # save the model
                    _save_wts(best_model_weights, save_wts_to)

            # show the loss in the current epoch
            if verbose >= 1:
                print("train %s: %.4e, val %s: %.4e, done in %fs (best val: %.3e)" %
                      (criteria["train"].name, train_losses[-1],
                       criteria["val"].name, val_losses[-1],
                       time.time()-since, best_loss))

            # plot the losses
            if plot:
                xs_plot = range(1, epoch+2)
                plt.clf()
                plt.plot(xs_plot, train_losses, 'o-')
                plt.plot(xs_plot, val_losses, 'o-')
                plt.legend(["Train", "Validation"])
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.pause(0.001)

            print("")
    except KeyboardInterrupt:
        print("Interrupted. Returning the results.")

    if verbose >= 1:
        time_elapsed = time.time() - since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # return the model
    model.load_state_dict(best_model_weights)
    if return_history:
        return model, best_loss, train_losses, val_losses
    return model, best_loss
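# Illustrative call of the `train` function above (the GAN and REINFORCE
# variants below presumably live in separate modules in the project):
# dataloaders and criteria given as {"train", "val"} dicts, a plain optimizer
# and an optional StepLR scheduler. The toy model and data are stand-ins.
def _example_train_usage():
    model = torch.nn.Linear(4, 1)

    def make_loader():
        ds = torch.utils.data.TensorDataset(torch.randn(64, 4),
                                            torch.randn(64, 1))
        return torch.utils.data.DataLoader(ds, batch_size=16)

    dataloaders = {"train": make_loader(), "val": make_loader()}
    criteria = torch.nn.MSELoss()  # wrapped into MeanCriterion internally
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    best_model, best_loss = train(model, dataloaders, criteria, optimizer,
                                  scheduler=scheduler, num_epochs=5,
                                  save_wts_to="best_weights.pkl")
    return best_model, best_loss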
def train(g_model, d_model, dataloaders, lambda_g, g_opt, d_opt,
          train_g_after=0, g_sched=None, d_sched=None, gan_criteria="hinge",
          spv_criteria="mse", num_epochs=25, device=None, verbose=1, plot=0,
          save_wts_to=None, return_history=False):
    """
    Performs a supervised + GAN training procedure. The generative and
    discriminative models are trained with a GAN procedure while the mapper
    is trained with a supervised procedure. In making the prediction,
    `m_model` is concatenated with `g_model` to generate a signal from a
    given set of parameters.

    In one training batch:

    * `d_model` is trained by maximizing the d-score for real signal and
      minimizing it for fake signal.
    * `g_model` is trained by maximizing the d-score for its generated signal
      and minimizing the supervised loss on the labelled data.

    Args:
        g_model :
            A torch trainable generative model from the parameters space to
            the signal space.
        d_model :
            A torch trainable discriminative model that receives the signal
            as the input and gives a low score for fake and a high score for
            real signal.
        dataloaders (dict or torch.utils.data.DataLoader):
            Dictionary with two keys: ["train", "val"] with every value an
            iterable with two outputs: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs". If it is a DataLoader,
            then it is only for the training, nothing for validation.
        lambda_g (float):
            The penalty factor of the discriminator regularization.
        g_opt, d_opt (torch.optim optimizer):
            Optimizer class in training the g_model and d_model, respectively.
        train_g_after (int):
            Only add the GAN term to the generator loss after this many
            epochs. (default: 0)
        g_sched, d_sched (torch.optim.lr_scheduler object):
            Optimizer scheduler in training the g_model and d_model,
            respectively. (default: None)
        gan_criteria (str, optional):
            Criteria in training GAN. The options are "hinge", "wgan-gp", and
            "bce". (default: "hinge")
        spv_criteria (str, optional):
            Criteria in the supervised training. For now, the option is only
            "mse". (default: "mse")
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device where to do the training. None to choose cuda:0 if
            available, otherwise, cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do
            not save. (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses
            history will be returned. (default: False)

    Returns:
        best_model :
            The trained model with the lowest loss criterion during the "val"
            phase.
    """
    lambda_gp = 10.0

    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check if the dataloader is for validation as well
    if type(dataloaders) != dict:
        dataloaders = {"train": dataloaders, "val": []}

    # load the models to the device first
    g_model = g_model.to(device)
    d_model = d_model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = get_weights(g_model, d_model)
    best_loss = np.inf
    train_losses = []
    val_losses = []

    # book keeping scores
    d_loss_real_mean = {"train": 0.0, "val": 0.0}
    d_loss_fake_mean = {"train": 0.0, "val": 0.0}
    g_loss_mean = {"train": 0.0, "val": 0.0}
    m_loss_mean = {"train": 0.0, "val": 0.0}
    if return_history:
        d_losses_real_train = []
        d_losses_fake_train = []
        g_losses_train = []
        m_losses_train = []
        d_losses_real_val = []
        d_losses_fake_val = []
        g_losses_val = []
        m_losses_val = []

    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    for epoch in range(num_epochs):
        if verbose >= 1:
            print("Epoch %d/%d" % (epoch + 1, num_epochs))
            print("-" * 10)

        # to time the progress
        epoch_start_time = time.time()

        # progress counter
        num_batches = 0  # num batches in training and validation
        if verbose >= 2:
            progress_disp = mkutils.ProgressDisplay()

        # every epoch has a training and a validation phase
        for phase in ["train", "val"]:
            # skip the phase if the dataloader for the current phase is empty
            if dataloaders[phase] == []:
                continue

            # set the models' mode
            if phase == "train":
                if g_sched is not None:
                    g_sched.step()
                if d_sched is not None:
                    d_sched.step()
                # set the models to the training mode
                g_model.train()
                d_model.train()
            else:
                # set the models to the evaluation mode
                g_model.eval()
                d_model.eval()

            # book keeping scores
            d_loss_real_total = 0.0
            d_loss_fake_total = 0.0
            g_loss_total = 0.0
            m_loss_total = 0.0
            ndata_total = 0

            for params, signal in dataloaders[phase]:
                # write the progress bar
                num_batches += 1
                if verbose >= 2:
                    progress_disp.show(num_batches, total_batches)

                batch_size = params.shape[0]
                ndata = batch_size
                ndata_total += ndata

                # load to device
                params = params.to(device)
                signal = signal.to(device)

                ################ train the discriminator ################
                # calculate the d-scores for real and fake signals
                d_score_real = d_model(signal)
                z = torch.rand((params.shape[0], params.shape[1])).to(device)
                fake_signal = g_model(z)
                d_score_fake = d_model(fake_signal.detach())

                # maximize the score for the real signal and
                # minimize the score for the fake signal
                if gan_criteria == "hinge":
                    d_loss_real = torch.clamp(1.0 - d_score_real, 0.0).mean()
                    d_loss_fake = torch.clamp(1.0 + d_score_fake, 0.0).mean()
                elif gan_criteria == "wgan-gp":
                    d_loss_real = -d_score_real.mean()
                    d_loss_fake = d_score_fake.mean()
                elif gan_criteria == "bce":
                    real_label = torch.full((batch_size,), 1.0, device=device)
                    fake_label = torch.full((batch_size,), 0.0, device=device)
                    d_loss_real = torch.nn.BCELoss()(d_score_real, real_label)
                    d_loss_fake = torch.nn.BCELoss()(d_score_fake, fake_label)

                # backprop the discriminator
                d_loss = d_loss_fake + d_loss_real
                if phase == "train":
                    d_model.zero_grad()
                    d_opt.zero_grad()
                    d_loss.backward()
                    d_opt.step()

                    if gan_criteria == "wgan-gp":
                        # gradient penalty on an interpolation between the
                        # real and the fake signals
                        alpha = torch.rand(signal.shape[0], 1, 1)\
                            .to(device).expand_as(signal)
                        interpolated = torch.zeros_like(signal)
                        interpolated.data = alpha * signal.data + \
                            (1 - alpha) * fake_signal.data
                        interpolated.requires_grad = True
                        d_interp = d_model(interpolated)
                        grad = torch.autograd.grad(
                            outputs=d_interp,
                            inputs=interpolated,
                            grad_outputs=torch.ones(d_interp.size()).to(device),
                            retain_graph=True,
                            create_graph=True,
                            only_inputs=True)[0]
                        grad = grad.view(grad.size(0), -1)
                        grad_l2norm = torch.sqrt(torch.sum(grad * grad, dim=1))
                        d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                        # backward + optimize
                        d_loss = lambda_gp * d_loss_gp
                        d_opt.zero_grad()
                        d_loss.backward()
                        d_opt.step()

                # book keeping
                d_loss_real_total += d_loss_real.data * ndata
                d_loss_fake_total += d_loss_fake.data * ndata

                ################ train the generator ################
                # score the fake signal with the updated discriminator
                d_score_fake = d_model(fake_signal)

                # maximize the d-score for the fake signal
                if gan_criteria in ["hinge", "wgan-gp"]:
                    g_loss = -d_score_fake.mean()
                elif gan_criteria == "bce":
                    g_loss = torch.nn.BCELoss()(d_score_fake, real_label)

                # book keeping
                g_loss_total += g_loss.data * ndata

                ################ train the mapper ################
                # get the signal from the parameters
                predict_signal = g_model(params)

                # calculate the loss function
                if spv_criteria == "mse":
                    sig_err = (predict_signal - signal)
                    m_loss = (sig_err * sig_err).mean()

                # calculate the total loss function that the generator will be
                # trained on
                if epoch >= train_g_after:
                    mg_loss = m_loss + lambda_g * g_loss
                else:
                    mg_loss = m_loss

                # backprop the mapper model
                if phase == "train":
                    g_model.zero_grad()
                    g_opt.zero_grad()
                    mg_loss.backward()
                    g_opt.step()

                # book keeping
                m_loss_total += m_loss.data * ndata

            # finish one part of the epoch (either train or val):
            # get the mean values
            d_loss_real_mean[phase] = d_loss_real_total / ndata_total
            d_loss_fake_mean[phase] = d_loss_fake_total / ndata_total
            g_loss_mean[phase] = g_loss_total / ndata_total
            m_loss_mean[phase] = m_loss_total / ndata_total

            # copy the best model
            if phase == "val" and m_loss_mean[phase] < best_loss:
                best_loss = m_loss_mean[phase].data
                best_model_weights = get_weights(g_model, d_model)

                # save the model
                if save_wts_to is not None:
                    mkutils.save(best_model_weights, save_wts_to)

        # finish one epoch: print the message
        if verbose > 0:
            print("Done in %fs (best val loss: %.3e)" %
                  (time.time() - since, best_loss))
            print("D-loss real: (train) %.3e, (val) %.3e" %
                  (d_loss_real_mean["train"], d_loss_real_mean["val"]))
            print("D-loss fake: (train) %.3e, (val) %.3e" %
                  (d_loss_fake_mean["train"], d_loss_fake_mean["val"]))
            print("G-loss fake: (train) %.3e, (val) %.3e" %
                  (g_loss_mean["train"], g_loss_mean["val"]))
            print("M-loss     : (train) %.3e, (val) %.3e" %
                  (m_loss_mean["train"], m_loss_mean["val"]))

        if return_history:
            d_losses_real_train.append(d_loss_real_mean["train"].data)
            d_losses_fake_train.append(d_loss_fake_mean["train"].data)
            g_losses_train.append(g_loss_mean["train"].data)
            m_losses_train.append(m_loss_mean["train"].data)
            d_losses_real_val.append(d_loss_real_mean["val"].data)
            d_losses_fake_val.append(d_loss_fake_mean["val"].data)
            g_losses_val.append(g_loss_mean["val"].data)
            m_losses_val.append(m_loss_mean["val"].data)

    # finish all epochs
    if verbose >= 1:
        time_elapsed = time.time() - since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # load the best models
    g_model.load_state_dict(best_model_weights[0])
    d_model.load_state_dict(best_model_weights[1])
    if return_history:
        return g_model, d_model, best_loss, \
            d_losses_real_train, d_losses_fake_train, g_losses_train, m_losses_train, \
            d_losses_real_val, d_losses_fake_val, g_losses_val, m_losses_val
    return g_model, d_model, best_loss
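# `get_weights` is used above but not defined in this excerpt. Consistent
# with how its result is used (indexed by [0] for the generator and [1] for
# the discriminator, and passed to mkutils.save), a plausible sketch is shown
# below; the project's actual helper may differ.
def get_weights(g_model, d_model):
    # detached copies of both state dicts, generator first
    return (copy.deepcopy(g_model.state_dict()),
            copy.deepcopy(d_model.state_dict()))

# Illustrative call of the GAN + supervised `train` above (all names are
# stand-ins):
#     g, d, best = train(g_model, d_model, {"train": tr_loader, "val": vl_loader},
#                        lambda_g=0.1, g_opt=g_opt, d_opt=d_opt,
#                        gan_criteria="hinge", num_epochs=50)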
def train(model, dataloaders, criteria, optimizer, scheduler=None, dvbatch=1,
          num_epochs=25, device=None, verbose=1, plot=0, save_wts_to=None,
          save_model_to=None, return_history=False, return_best_last=9e99,
          revert_every=9e99, train_update_every=1):
    """
    Performs a training of the model.

    Args:
        model :
            A torch trainable class method that accepts "inputs" and returns
            prediction of "outputs". The model needs to return the output and
            the log-probability.
        dataloaders (dict):
            Dictionary with two keys: ["train", "val"] with every value an
            iterable with two outputs: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs".
        criteria (dict):
            Dictionary with two keys: ["train", "val"] with every value a
            callable or deepmk.criteria to calculate the criterion for the
            corresponding phase. If a value is a callable, it is wrapped by a
            deepmk.criteria.MeanCriterion object to calculate the mean
            criterion. The criterion for the training needs to be
            differentiable and it will be minimized during the training.
        optimizer (dict):
            Dictionary of torch.optim optimizers with "train" and "val" keys
            (and optionally "diffval"), making this a meta-learning problem.
        scheduler (dict or None):
            Schedulers of how the learning rate evolves through the epochs,
            with the same keys as the optimizer argument. If it is None, it
            does not update the learning rate. (default: None)
        dvbatch (int):
            If differentiable validation applies, it averages the loss over
            this many batches before applying the backprop. (default: 1)
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device where to do the training. None to choose cuda:0 if
            available, otherwise, cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        plot (int):
            Whether to plot the loss of training and validation data.
            (default: 0)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do
            not save. (default: None)
        save_model_to (str):
            Name of a file to save the best whole model. If None, then do not
            save. (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses
            history will be returned. (default: False)
        return_best_last (int):
            Return the best model over the last `return_best_last` epochs.
            (default: 9e99)
        revert_every (int):
            Revert the model to the best model every this many epochs when a
            better one is not found. (default: 9e99)
        train_update_every (int):
            Apply the training optimizer step every this many batches.
            (default: 1)

    Returns:
        best_model :
            The trained model with the lowest loss criterion during the "val"
            phase.
    """
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check some variables if they are dictionaries with "train" and "val" keys
    def _check(var, name):
        if var is None:
            return
        if not (type(var) == dict and "train" in var and "val" in var):
            raise TypeError("The variable %s must be a dictionary with "
                            "'train' and 'val' in it" % name)

    _check(optimizer, "optimizer")
    _check(scheduler, "scheduler")
    _check(dataloaders, "dataloaders")
    _check(criteria, "criteria")

    # set interactive plot
    if plot:
        plt.ion()

    for phase in ["train", "val"]:
        if not issubclass(criteria[phase].__class__, deepmk.criteria.Criterion):
            criteria[phase] = deepmk.criteria.MeanCriterion(criteria[phase])
        criteria[phase].reset()

    # prepare the memory of the last best weights
    if return_best_last < num_epochs:
        weights_history = [None for _ in range(return_best_last)]

    # load the model to the device first
    model = model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    train_losses = []
    val_losses = []

    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    try:
        best_epoch = 0
        for epoch in range(num_epochs):
            if verbose >= 1:
                print("Epoch %d/%d" % (epoch + 1, num_epochs))
                print("-" * 10)

            # to time the progress
            epoch_start_time = time.time()

            # progress counter
            num_batches = 0  # num batches in training and validation
            if verbose >= 2:
                progress_disp = mkutils.ProgressDisplay()

            # to store the losses in validation for REINFORCE
            losses = torch.zeros(len(dataloaders["val"])).to(device)
            logps = torch.zeros(len(dataloaders["val"])).to(device)

            # every epoch has a training and a validation phase
            sum_dval_loss = 0.0
            count_dval = 0
            optimizer["train"].zero_grad()
            optimizer["val"].zero_grad()
            for phase in ["train", "val"]:
                # skip the phase if the dataloader for the current phase is empty
                if dataloaders[phase] == []:
                    continue

                # set the model's mode
                if scheduler is not None:
                    scheduler[phase].step()
                    if phase == "val" and "diffval" in scheduler:
                        scheduler["diffval"].step()
                model.train()

                # the total loss during this epoch
                running_loss = 0.0

                # iterate over the data
                dataset_size = 0

                # reset the criteria before the training epoch starts
                criteria[phase].reset()

                count_i = 0
                count_train_update = 0
                for inputs, labels in dataloaders[phase]:
                    count_train_update += 1

                    # get the size of the dataset
                    dataset_size += inputs.size(0)
                    num_batches += 1

                    # write the progress bar
                    if verbose >= 2:
                        progress_disp.show(num_batches, total_batches)

                    # load the inputs and the labels to the working device
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # reset the model gradient to 0
                    if phase == "val":
                        optimizer["val"].zero_grad()
                        if "diffval" in optimizer:
                            optimizer["diffval"].zero_grad()

                    # forward
                    outputs, logp = model(inputs)
                    loss = criteria[phase].feed(outputs, labels)

                    # backward gradient computation and optimize in training
                    if phase == "train":
                        loss.backward()
                        if count_train_update % train_update_every == 0 or \
                                count_train_update == len(dataloaders[phase]):
                            optimizer["train"].step()
                            optimizer["train"].zero_grad()
                    else:
                        if "diffval" in optimizer:
                            sum_dval_loss = sum_dval_loss + loss
                            count_dval += 1

                            # apply the backprop
                            if count_dval == dvbatch:
                                mean_dval_loss = sum_dval_loss / count_dval
                                mean_dval_loss.backward()
                                optimizer["diffval"].step()
                                count_dval = 0
                                sum_dval_loss = 0.0

                        # we need the gradient for logp, but not for loss
                        losses[count_i] += loss.data
                        logps[count_i] += logp
                        count_i += 1

                # apply the backprop if there is still a differentiable
                # validation loss left
                if count_dval != 0:
                    mean_dval_loss = sum_dval_loss / count_dval
                    mean_dval_loss.backward()
                    optimizer["diffval"].step()
                    count_dval = 0
                    sum_dval_loss = 0.0

                # do the REINFORCE step
                if phase == "val":
                    # transform the loss into some ranking function
                    # (lower loss, lower rank)
                    normlosses = get_normloss(losses)
                    # we choose sum instead of mean because the training step
                    # is only done once, so we want to make it larger
                    # (it is approximately mean, but doing it for every batch)
                    loss = (normlosses * logps).sum()
                    loss.backward()
                    optimizer[phase].step()

                # get the mean loss in this epoch
                mult = -1 if (criteria[phase].best == "max") else 1
                crit_val = criteria[phase].getval()
                epoch_loss = mult * crit_val

                # save the losses
                if phase == "train":
                    train_losses.append(crit_val.data)
                elif phase == "val":
                    val_losses.append(crit_val.data)

                # save the model history
                if return_best_last < num_epochs:
                    weights_history[epoch % return_best_last] = \
                        copy.deepcopy(model.state_dict())

                # copy the best model
                if phase == "val" and \
                        ((epoch_loss < best_loss) or
                         (epoch - best_epoch > return_best_last) or
                         (epoch - best_epoch > revert_every)):
                    if epoch - best_epoch > return_best_last:
                        # get the index of the next best last
                        val_losses_n = val_losses[-return_best_last:]
                        min_idx_rel = np.argmin(val_losses_n)
                        min_idx = min_idx_rel + len(val_losses) - return_best_last

                        # get the best conditions
                        best_epoch = min_idx
                        best_model_weights = \
                            weights_history[best_epoch % return_best_last]
                    elif epoch - best_epoch > revert_every:
                        # revert the model to the best model
                        model.load_state_dict(best_model_weights)
                    else:
                        best_epoch = epoch
                        best_model_weights = copy.deepcopy(model.state_dict())

                    # save the best conditions
                    best_loss = val_losses[best_epoch]

                    # save the model
                    _save_wts(best_model_weights, save_wts_to)

            # show the loss in the current epoch
            if verbose >= 1:
                print("train %s: %.4e, val %s: %.4e, done in %fs (best val: %.3e)" %
                      (criteria["train"].name, train_losses[-1],
                       criteria["val"].name, val_losses[-1],
                       time.time()-since, best_loss))

            # plot the losses
            if plot:
                xs_plot = range(1, epoch + 2)
                plt.clf()
                plt.plot(xs_plot, train_losses, 'o-')
                plt.plot(xs_plot, val_losses, 'o-')
                plt.legend(["Train", "Validation"])
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.pause(0.001)

            print("")
    except KeyboardInterrupt:
        print("Interrupted. Returning the results.")

    if verbose >= 1:
        time_elapsed = time.time() - since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # return the model
    model.load_state_dict(best_model_weights)
    if return_history:
        return model, best_loss, train_losses, val_losses
    return model, best_loss
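# Illustrative call of the REINFORCE-style `train` above. Every argument that
# `_check` validates must be a {"train", "val"} dict; an optional "diffval"
# optimizer/scheduler entry enables the differentiable-validation branch, and
# the model must return (outputs, logp). The toy model, data, and learning
# rates below are stand-ins, not project code.
def _example_reinforce_train_usage():
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 1)
            self.logstd = torch.nn.Parameter(torch.zeros(1))

        def forward(self, x):
            out = self.lin(x)
            # toy log-probability; a real model would return the log-prob of
            # its sampled stochastic decisions
            logp = self.logstd.sum()
            return out, logp

    model = ToyModel()

    def make_loader():
        ds = torch.utils.data.TensorDataset(torch.randn(64, 4),
                                            torch.randn(64, 1))
        return torch.utils.data.DataLoader(ds, batch_size=16)

    dataloaders = {"train": make_loader(), "val": make_loader()}
    criteria = {"train": torch.nn.MSELoss(), "val": torch.nn.MSELoss()}
    optimizer = {
        "train": torch.optim.Adam(model.parameters(), lr=1e-3),
        "val": torch.optim.Adam(model.parameters(), lr=1e-4),
    }
    return train(model, dataloaders, criteria, optimizer,
                 num_epochs=3, dvbatch=2, train_update_every=4)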