Example #1
    def train(self):
        if self.verbose:
            progress_disp = mkutils.ProgressDisplay()

        self.start()

        for i in range(self.numepochs):

            self.start_epoch()

            # display the epoch information
            if self.verbose:
                s = "Epoch %6d/%6d" % (i+1, numepochs)
                print(s)
                print("-" * len(s))

            total_batches = len(self.dataloader)
            for j, (phase, data) in enumerate(self.dataloader):

                self.update_phase(phase, data)

                # show the progress bar
                if self.verbose:
                    progress_disp.show(j+1, total_batches)

            self.end_epoch()

            # print a blank line between epochs
            if self.verbose:
                print("")

        self.end()
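
A hedged note: the method above only works inside a trainer class that provides the attributes and hooks it calls. The sketch below is a minimal, hypothetical skeleton of such a class; the attribute names (verbose, numepochs, dataloader) and the hook methods (start, start_epoch, update_phase, end_epoch, end) are inferred from the method body, not taken from a published API.

# Hypothetical trainer skeleton inferred from the train() method above.
class BaseTrainer:
    def __init__(self, dataloader, numepochs, verbose=True):
        self.dataloader = dataloader  # iterable of (phase, data) pairs
        self.numepochs = numepochs
        self.verbose = verbose

    # hooks a concrete trainer is expected to override
    def start(self): pass                       # once, before the first epoch
    def start_epoch(self): pass                 # at the beginning of each epoch
    def update_phase(self, phase, data): pass   # for every batch
    def end_epoch(self): pass                   # at the end of each epoch
    def end(self): pass                         # once, after the last epoch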
Example #2
def validate(model, dataloader, val_criterion, device=None, verbose=1,
             load_wts_from=None):
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load the model to the device
    model = model.to(device)

    # load the weights
    if load_wts_from is not None:
        model.load_state_dict(torch.load(load_wts_from))

    # set to evaluation mode
    model.eval()

    # reset the validation criterion
    val_criterion.reset()

    num_batches = 0  # batch counter for the progress bar
    total_batches = len(dataloader)
    if verbose >= 2:
        progress_disp = mkutils.ProgressDisplay()
    for inputs, labels in dataloader:
        num_batches += 1

        # load the data to the device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # calculate the validation criterion
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            val_criterion.feed(outputs, labels)

        # write the progress bar
        if verbose >= 2:
            progress_disp.show(num_batches, total_batches)

    print("Validation with %s criterion: %e" %
          (val_criterion.name, val_criterion.getval()))
    return float(val_criterion.getval())
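
validate() relies on a criterion object that exposes reset(), feed(outputs, labels), getval(), and name. The sketch below shows one plausible shape of that interface and a usage call; MeanSquaredCriterion is a hypothetical stand-in, not the actual deepmk.criteria implementation.

# Minimal sketch of the criterion interface assumed by validate().
# MeanSquaredCriterion is illustrative only, not part of deepmk.
class MeanSquaredCriterion:
    name = "MSE"

    def __init__(self):
        self.reset()

    def reset(self):
        self._total = 0.0
        self._count = 0

    def feed(self, outputs, labels):
        loss = ((outputs - labels) ** 2).mean()
        self._total += loss.item() * labels.shape[0]
        self._count += labels.shape[0]
        return loss

    def getval(self):
        return self._total / max(self._count, 1)

# Hypothetical usage, assuming `model` and `val_loader` already exist:
# val_score = validate(model, val_loader, MeanSquaredCriterion(), verbose=2)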
Example #3
def train(model, dataloaders, criteria, optimizer, scheduler=None,
          num_epochs=25, device=None, verbose=1, plot=0, save_wts_to=None,
          save_model_to=None, return_history=False, return_best_last=9e99):
    """
    Performs training of the model.

    Args:
        model :
            A torch trainable class method that accepts "inputs" and returns
            prediction of "outputs".
        dataloaders (dict or torch.utils.data.DataLoader):
            Dictionary with two keys, "train" and "val", where each value is an
            iterable yielding two items: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs". If it is a single DataLoader,
            it is used for training only, with no validation.
        criteria (dict or callable or deepmk.criteria):
            Dictionary with two keys, "train" and "val", where each value is a
            callable or a deepmk.criteria object that calculates the criterion
            for the corresponding phase. If it is not a dictionary, the same
            criterion is used for both the training and the validation phases.
            If it is a plain callable, it is wrapped in a
            deepmk.criteria.MeanCriterion object to compute the mean criterion.
            The training criterion needs to be differentiable and is minimized
            during training.
        optimizer (torch.optim optimizer or dict):
            Optimizer used to train the model. If it is a dictionary, it must
            have "train" and "val" keys, which turns the training into a
            meta-learning problem.
        scheduler (torch.optim.lr_scheduler object or dict):
            Scheduler controlling how the learning rate evolves over the epochs.
            If None, the learning rate is not updated. It can be a dictionary
            like the optimizer argument. (default: None)
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device on which to run the training. If None, cuda:0 is chosen when
            available, otherwise cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        plot (int):
            Whether to plot the loss of training and validation data. (default: 0)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do not
            save. (default: None)
        save_model_to (str):
            Name of a file to save the best whole model. If None, then do not save.
            (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses history
            will be returned. (default: False)
        return_best_last (int):
            Return the best model over the last `return_best_last` epochs.
            (default: 9e99)

    Returns:
        best_model :
            The trained model with the lowest loss criterion during the "val"
            phase, together with the best validation loss. If return_history is
            True, the training and validation loss histories are also returned.
    """
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check optimizer and scheduler types and decide if this is a meta
    # learning problem
    metalearning = _check_opt_sched(optimizer, scheduler)
    if metalearning and verbose >= 1:
        print("We are doing meta-learning")

    # set interactive plot
    if plot:
        plt.ion()

    # check if the dataloader is for validation as well
    if not isinstance(dataloaders, dict):
        dataloaders = {"train": dataloaders, "val": []}

    # set the criteria object right
    if not isinstance(criteria, dict):
        criteria = {"train": criteria, "val": criteria}
    for phase in ["train", "val"]:
        if not isinstance(criteria[phase], deepmk.criteria.Criterion):
            criteria[phase] = deepmk.criteria.MeanCriterion(criteria[phase])
        criteria[phase].reset()

    # prepare the memory of the last best weights
    if return_best_last < num_epochs:
        weights_history = [None for _ in range(return_best_last)]

    # load the model to the device first
    model = model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = copy.deepcopy(model.state_dict())
    best_loss = np.inf

    train_losses = []
    val_losses = []

    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    try:
        best_epoch = 0
        for epoch in range(num_epochs):
            if verbose >= 1:
                print("Epoch %d/%d" % (epoch+1, num_epochs))
                print("-"*10)

            # to time the progress
            epoch_start_time = time.time()

            # progress counter
            num_batches = 0 # num batches in training and validation
            if verbose >= 2:
                progress_disp = mkutils.ProgressDisplay()

            # every epoch has a training and a validation phase
            for phase in ["train", "val"]:

                # skip phase if the dataloaders for the current phase is empty
                if dataloaders[phase] == []: continue

                # set the model's mode
                if not metalearning:
                    if phase == "train":
                        if scheduler is not None:
                            scheduler.step() # adjust the training learning rate
                        model.train() # set the model to the training mode
                    else:
                        model.eval() # set the model to the evaluation mode
                else:
                    if scheduler is not None:
                        scheduler[phase].step()
                    model.train()

                # the total loss during this epoch
                running_loss = 0.0

                # iterate over the data
                dataset_size = 0

                # reset the criteria before the training epoch starts
                criteria[phase].reset()
                for inputs, labels in dataloaders[phase]:
                    # get the size of the dataset
                    dataset_size += inputs.size(0)
                    num_batches += 1

                    # write the progress bar
                    if verbose >= 2:
                        progress_disp.show(num_batches, total_batches)

                    # load the inputs and the labels to the working device
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # reset the model gradient to 0
                    if not metalearning:
                        optimizer.zero_grad()
                    else:
                        optimizer["train"].zero_grad()
                        optimizer["val"].zero_grad()

                    # forward
                    # track history if only in train
                    grad_enabled = (phase == "train" or metalearning)
                    with torch.set_grad_enabled(grad_enabled):
                        outputs = model(inputs)
                        loss = criteria[phase].feed(outputs, labels)

                        # backward gradient computation and optimize in training
                        if not metalearning:
                            if phase == "train":
                                loss.backward()
                                optimizer.step()
                        else:
                            loss.backward()
                            optimizer[phase].step()

                # get the mean loss in this epoch
                mult = -1 if (criteria[phase].best == "max") else 1
                crit_val = criteria[phase].getval()
                epoch_loss = mult * crit_val

                # save the losses
                if phase == "train":
                    train_losses.append(crit_val.data)
                elif phase == "val":
                    val_losses.append(crit_val.data)

                # save the model history
                if return_best_last < num_epochs:
                    weights_history[epoch % return_best_last] = copy.deepcopy(model.state_dict())

                # copy the best model
                if phase == "val" and \
                        ((epoch_loss < best_loss) or \
                         (epoch - best_epoch > return_best_last)):
                    if epoch - best_epoch > return_best_last:
                        # get the index of the next best last
                        val_losses_n = val_losses[-return_best_last:]
                        min_idx_rel = np.argmin(val_losses_n)
                        min_idx = min_idx_rel + len(val_losses) - return_best_last

                        # get the best conditions
                        best_epoch = min_idx
                        best_model_weights = weights_history[best_epoch % return_best_last]
                    else:
                        best_epoch = epoch
                        best_model_weights = copy.deepcopy(model.state_dict())

                    # save the best conditions
                    best_loss = val_losses[best_epoch]

                    # save the model
                    _save_wts(best_model_weights, save_wts_to)

            # show the loss in the current epoch
            if verbose >= 1:
                print("train %s: %.4e, val %s: %.4e, done in %fs (best val: %.3e)" % \
                      (criteria["train"].name, train_losses[-1],
                       criteria["val"].name, val_losses[-1],
                       time.time()-since, best_loss))
            # plot the losses
            if plot:
                xs_plot = range(1,epoch+2)
                plt.clf()
                plt.plot(xs_plot, train_losses, 'o-')
                plt.plot(xs_plot, val_losses, 'o-')
                plt.legend(["Train", "Validation"])
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.pause(0.001)

            print("")
    except KeyboardInterrupt:
        print("Interrupted. Returning the results.")

    if verbose >= 1:
        time_elapsed = time.time()- since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # return the model
    model.load_state_dict(best_model_weights)
    if return_history:
        return model, best_loss, train_losses, val_losses
    return model, best_loss
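
A hedged usage sketch for the train() routine above. The dataset, model, and criterion are placeholders; passing a plain callable as `criteria` relies on the MeanCriterion wrapping described in the docstring, and the helpers _check_opt_sched and _save_wts referenced in the body are assumed to be defined elsewhere in the same module.

# Hypothetical usage of train(); the model, data, and file name are illustrative.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

x, y = torch.randn(256, 10), torch.randn(256, 1)
dataloaders = {
    "train": DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True),
    "val": DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32),
}

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# a plain callable criterion; per the docstring it is wrapped in
# deepmk.criteria.MeanCriterion for both phases
criteria = nn.MSELoss()

# best_model, best_loss = train(model, dataloaders, criteria, optimizer,
#                               num_epochs=5, save_wts_to="best_weights.pkl")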
Example #4
def train(g_model,
          d_model,
          dataloaders,
          lambda_g,
          g_opt,
          d_opt,
          train_g_after=0,
          g_sched=None,
          d_sched=None,
          gan_criteria="hinge",
          spv_criteria="mse",
          num_epochs=25,
          device=None,
          verbose=1,
          plot=0,
          save_wts_to=None,
          return_history=False):
    """
    Performs a supervised + GAN training procedure. The generative and
    discriminative models are trained with a GAN procedure, while the mapping
    from parameters to signals is trained with a supervised procedure.
    For prediction, `g_model` generates a signal from a given set of
    parameters.

    In one training batch:
    * `d_model` is trained by maximizing the d-score for real signals and
                minimizing it for fake signals.
    * `g_model` is trained by maximizing the d-score for its generated signals
                and by minimizing the supervised loss on the given data.

    Args:
        g_model :
            A torch trainable generative model from the parameters space to the
            signal space.
        d_model :
            A torch trainable discriminative model that receives the signal
            as the input and gives low score for fake and high score for real.
        dataloaders (dict or torch.utils.data.DataLoader):
            Dictionary with two keys, "train" and "val", where each value is an
            iterable yielding two items: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs". If it is a single DataLoader,
            it is used for training only, with no validation.
        lambda_g (float):
            The weight of the GAN (generator) loss added to the supervised loss
            when training `g_model`.
        g_opt, d_opt (torch.optim optimizer):
            Optimizers used to train g_model and d_model, respectively.
        train_g_after (int):
            The epoch index from which the GAN loss is added to the generator's
            supervised loss. Default: 0.
        g_sched, d_sched (torch.optim.lr_scheduler object):
            Learning rate schedulers for g_opt and d_opt, respectively.
            Default: None.
        gan_criteria (str, optional):
            Criterion for the GAN training: "hinge", "wgan-gp", or "bce".
            Default: "hinge".
        spv_criteria (str, optional):
            Criterion for the supervised training. For now, the only option is
            "mse". Default: "mse".
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device where to do the training. None to choose cuda:0 if available,
            otherwise, cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do not
            save. (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses history
            will be returned. (default: False)

    Returns:
        best_model :
            The trained g_model and d_model with the lowest supervised loss
            during the "val" phase, together with the best validation loss. If
            return_history is True, the loss histories are also returned.
    """
    lambda_gp = 10.0

    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check if the dataloader is for validation as well
    if not isinstance(dataloaders, dict):
        dataloaders = {"train": dataloaders, "val": []}

    # load the model to the device first
    g_model = g_model.to(device)
    d_model = d_model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = get_weights(g_model, d_model)
    best_loss = np.inf

    train_losses = []
    val_losses = []

    # book keeping scores
    d_loss_real_mean = {"train": 0.0, "val": 0.0}
    d_loss_fake_mean = {"train": 0.0, "val": 0.0}
    g_loss_mean = {"train": 0.0, "val": 0.0}
    m_loss_mean = {"train": 0.0, "val": 0.0}

    if return_history:
        d_losses_real_train = []
        d_losses_fake_train = []
        g_losses_train = []
        m_losses_train = []
        d_losses_real_val = []
        d_losses_fake_val = []
        g_losses_val = []
        m_losses_val = []
    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    for epoch in range(num_epochs):
        if verbose >= 1:
            print("Epoch %d/%d" % (epoch + 1, num_epochs))
            print("-" * 10)

        # to time the progress
        epoch_start_time = time.time()

        # progress counter
        num_batches = 0  # num batches in training and validation
        if verbose >= 2:
            progress_disp = mkutils.ProgressDisplay()

        # every epoch has a training and a validation phase
        for phase in ["train", "val"]:

            # skip phase if the dataloaders for the current phase is empty
            if dataloaders[phase] == []: continue

            # set the model's mode
            if phase == "train":
                if g_sched is not None:
                    g_sched.step()
                if d_sched is not None:
                    d_sched.step()
                # set the model to the training mode
                g_model.train()
                d_model.train()
            else:
                # set the model to the evaluation mode
                g_model.eval()
                d_model.eval()

            # book keeping score
            d_loss_real_total = 0.0
            d_loss_fake_total = 0.0
            g_loss_total = 0.0
            m_loss_total = 0.0
            ndata_total = 0
            for params, signal in dataloaders[phase]:
                # write the progress bar
                num_batches += 1
                if verbose >= 2:
                    progress_disp.show(num_batches, total_batches)

                batch_size = params.shape[0]
                ndata = batch_size
                ndata_total += ndata

                # load to device
                params = params.to(device)
                signal = signal.to(device)

                ################ train the discriminator ################
                # calculate the d-scores for real and fake signals
                d_score_real = d_model(signal)
                z = torch.rand((params.shape[0], params.shape[1])).to(device)
                fake_signal = g_model(z)
                d_score_fake = d_model(fake_signal.detach())

                # maximizing score for the real signal
                # minimizing score for the fake signal
                if gan_criteria == "hinge":
                    d_loss_real = torch.clamp(1.0 - d_score_real, 0.0).mean()
                    d_loss_fake = torch.clamp(1.0 + d_score_fake, 0.0).mean()
                elif gan_criteria == "wgan-gp":
                    d_loss_real = -d_score_real.mean()
                    d_loss_fake = d_score_fake.mean()
                elif gan_criteria == "bce":
                    real_label = torch.full((batch_size, ), 1.0, device=device)
                    fake_label = torch.full((batch_size, ), 0.0, device=device)
                    d_loss_real = torch.nn.BCELoss()(d_score_real, real_label)
                    d_loss_fake = torch.nn.BCELoss()(d_score_fake, fake_label)

                # backprop the discriminator
                d_loss = d_loss_fake + d_loss_real
                if phase == "train":
                    d_model.zero_grad()
                    d_opt.zero_grad()
                    d_loss.backward()
                    d_opt.step()

                if gan_criteria == "wgan-gp":
                    alpha = torch.rand(signal.shape[0], 1,
                                       1).to(device).expand_as(signal)
                    interpolated = torch.zeros_like(signal)
                    interpolated.data = alpha * signal.data + (
                        1 - alpha) * fake_signal.data
                    interpolated.requires_grad = True

                    d_interp = d_model(interpolated)
                    grad = torch.autograd.grad(outputs=d_interp,
                                               inputs=interpolated,
                                               grad_outputs=torch.ones_like(d_interp),
                                               retain_graph=True,
                                               create_graph=True,
                                               only_inputs=True)[0]

                    grad = grad.view(grad.size(0), -1)
                    grad_l2norm = torch.sqrt(torch.sum(grad * grad, dim=1))
                    d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                    # backward + optimize
                    d_loss = lambda_gp * d_loss_gp
                    d_opt.zero_grad()
                    d_loss.backward()
                    d_opt.step()

                # book keeping
                d_loss_real_total += d_loss_real.data * ndata
                d_loss_fake_total += d_loss_fake.data * ndata

                ################ train the generator ################
                # re-score the fake signal with the updated discriminator
                d_score_fake = d_model(fake_signal)

                # maximize the d-score for the fake signal
                if gan_criteria in ["hinge", "wgan-gp"]:
                    g_loss = -d_score_fake.mean()
                elif gan_criteria == "bce":
                    g_loss = torch.nn.BCELoss()(d_score_fake, real_label)

                # book keeping
                g_loss_total += g_loss.data * ndata

                ################ train the mapper ################
                # get the signal from the parameters
                predict_signal = g_model(params)

                # calculate the loss function
                if spv_criteria == "mse":
                    sig_err = (predict_signal - signal)
                    m_loss = (sig_err * sig_err).mean()

                # calculate the total loss function that the generator will be
                # trained on
                if epoch >= train_g_after:
                    mg_loss = m_loss + lambda_g * g_loss
                else:
                    mg_loss = m_loss

                # backprop the mapper model
                if phase == "train":
                    g_model.zero_grad()
                    g_opt.zero_grad()

                    mg_loss.backward()
                    g_opt.step()

                # book keeping
                m_loss_total += m_loss.data * ndata

            # finish one part of the epoch (either train or val)
            # get the mean values
            d_loss_real_mean[phase] = d_loss_real_total / ndata_total
            d_loss_fake_mean[phase] = d_loss_fake_total / ndata_total
            g_loss_mean[phase] = g_loss_total / ndata_total
            m_loss_mean[phase] = m_loss_total / ndata_total

            # copy the best model
            if phase == "val" and m_loss_mean[phase] < best_loss:
                best_loss = m_loss_mean[phase].data
                best_model_weights = get_weights(g_model, d_model)

                # save the model
                if save_wts_to is not None:
                    mkutils.save(best_model_weights, save_wts_to)

        # finish one epoch

        # print the message
        if verbose > 0:
            print("Done in %fs (best val loss: %.3e)" %
                  (time.time() - since, best_loss))
            print("D-loss real: (train) %.3e, (val) %.3e" % \
                (d_loss_real_mean["train"], d_loss_real_mean["val"]))
            print("D-loss fake: (train) %.3e, (val) %.3e" % \
                (d_loss_fake_mean["train"], d_loss_fake_mean["val"]))
            print("G-loss fake: (train) %.3e, (val) %.3e" % \
                (g_loss_mean["train"], g_loss_mean["val"]))
            print("M-loss     : (train) %.3e, (val) %.3e" % \
                (m_loss_mean["train"], m_loss_mean["val"]))

        if return_history:
            d_losses_real_train.append(d_loss_real_mean["train"].data)
            d_losses_fake_train.append(d_loss_fake_mean["train"].data)
            g_losses_train.append(g_loss_mean["train"].data)
            m_losses_train.append(m_loss_mean["train"].data)
            d_losses_real_val.append(d_loss_real_mean["val"].data)
            d_losses_fake_val.append(d_loss_fake_mean["val"].data)
            g_losses_val.append(g_loss_mean["val"].data)
            m_losses_val.append(m_loss_mean["val"].data)

    # finish all epochs
    if verbose >= 1:
        time_elapsed = time.time() - since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # load the best models
    g_model.load_state_dict(best_model_weights[0])
    d_model.load_state_dict(best_model_weights[1])
    if return_history:
        return g_model, d_model, best_loss, \
               d_losses_real_train, d_losses_fake_train, g_losses_train, m_losses_train, \
               d_losses_real_val, d_losses_fake_val, g_losses_val, m_losses_val
    return g_model, d_model, best_loss
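
A hedged usage sketch for the GAN + supervised train() above. The architectures and sizes are placeholders, and get_weights() (called in the body) is assumed to return the pair (g_model.state_dict(), d_model.state_dict()) that is loaded back at the end of training.

# Hypothetical usage; g_model maps parameters to signals, d_model scores signals.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

n_params, n_signal = 8, 64
params = torch.randn(512, n_params)
signal = torch.randn(512, n_signal)
dataloaders = {
    "train": DataLoader(TensorDataset(params[:400], signal[:400]),
                        batch_size=32, shuffle=True),
    "val": DataLoader(TensorDataset(params[400:], signal[400:]), batch_size=32),
}

g_model = nn.Sequential(nn.Linear(n_params, 128), nn.ReLU(),
                        nn.Linear(128, n_signal))
d_model = nn.Sequential(nn.Linear(n_signal, 128), nn.LeakyReLU(0.2),
                        nn.Linear(128, 1))

g_opt = torch.optim.Adam(g_model.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(d_model.parameters(), lr=2e-4, betas=(0.5, 0.999))

# g_model, d_model, best_loss = train(g_model, d_model, dataloaders, lambda_g=0.1,
#                                     g_opt=g_opt, d_opt=d_opt,
#                                     gan_criteria="hinge", num_epochs=5)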
Example #5
def train(model,
          dataloaders,
          criteria,
          optimizer,
          scheduler=None,
          dvbatch=1,
          num_epochs=25,
          device=None,
          verbose=1,
          plot=0,
          save_wts_to=None,
          save_model_to=None,
          return_history=False,
          return_best_last=9e99,
          revert_every=9e99,
          train_update_every=1):
    """
    Performs training of the model.

    Args:
        model :
            A torch trainable class method that accepts "inputs" and returns
            prediction of "outputs".
            The model needs to return the output and the logprobability.
        dataloaders (dict or torch.utils.data.DataLoader):
            Dictionary with two keys, "train" and "val", where each value is an
            iterable yielding two items: (1) the "inputs" to the model and
            (2) the ground truth of the "outputs". If it is a single DataLoader,
            it is used for training only, with no validation.
        criteria (dict or callable or deepmk.criteria):
            Dictionary with two keys, "train" and "val", where each value is a
            callable or a deepmk.criteria object that calculates the criterion
            for the corresponding phase. If it is not a dictionary, the same
            criterion is used for both the training and the validation phases.
            If it is a plain callable, it is wrapped in a
            deepmk.criteria.MeanCriterion object to compute the mean criterion.
            The training criterion needs to be differentiable and is minimized
            during training.
        optimizer (torch.optim optimizer or dict):
            Optimizer used to train the model. If it is a dictionary, it must
            have "train" and "val" keys, which turns the training into a
            meta-learning problem.
        scheduler (torch.optim.lr_scheduler object or dict):
            Scheduler controlling how the learning rate evolves over the epochs.
            If None, the learning rate is not updated. It can be a dictionary
            like the optimizer argument. (default: None)
        dvbatch (int):
            If differentiable validation is used, the validation loss is
            averaged over this many batches before the backprop is applied.
            (default: 1)
        num_epochs (int):
            The number of epochs in training. (default: 25)
        device :
            Device on which to run the training. If None, cuda:0 is chosen when
            available, otherwise cpu. (default: None)
        verbose (int):
            The level of verbosity from 0 to 2. (default: 1)
        plot (int):
            Whether to plot the loss of training and validation data. (default: 0)
        save_wts_to (str):
            Name of a file to save the best model's weights. If None, then do not
            save. (default: None)
        save_model_to (str):
            Name of a file to save the best whole model. If None, then do not save.
            (default: None)
        return_history (bool):
            A flag to indicate whether the training and validation losses history
            will be returned. (default: False)
        return_best_last (int):
            Return the best model over the last `return_best_last` epochs.
            (default: 9e99)
        revert_every (int):
            Revert the model to the best model after this many epochs without
            finding a better one. (default: 9e99)
        train_update_every (int):
            Apply the training optimizer step (and reset its gradients) every
            this many batches. (default: 1)

    Returns:
        best_model :
            The trained model with the lowest loss criterion during the "val"
            phase, together with the best validation loss. If return_history is
            True, the training and validation loss histories are also returned.
    """
    # get the device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:")
    print(device)

    # check that these variables are dictionaries with "train" and "val" keys
    def _check(var, name):
        if var is None: return
        if not (isinstance(var, dict) and "train" in var and "val" in var):
            raise TypeError("The variable %s must be a dictionary with "
                            "'train' and 'val' in it" % name)

    _check(optimizer, "optimizer")
    _check(scheduler, "scheduler")
    _check(dataloaders, "dataloaders")
    _check(criteria, "criteria")

    # set interactive plot
    if plot:
        plt.ion()

    for phase in ["train", "val"]:
        if not isinstance(criteria[phase], deepmk.criteria.Criterion):
            criteria[phase] = deepmk.criteria.MeanCriterion(criteria[phase])
        criteria[phase].reset()

    # prepare the memory of the last best weights
    if return_best_last < num_epochs:
        weights_history = [None for _ in range(return_best_last)]

    # load the model to the device first
    model = model.to(device)

    if verbose >= 1:
        since = time.time()
    best_model_weights = copy.deepcopy(model.state_dict())
    best_loss = np.inf

    train_losses = []
    val_losses = []

    total_batches = len(dataloaders["train"]) + len(dataloaders["val"])
    try:
        best_epoch = 0
        for epoch in range(num_epochs):
            if verbose >= 1:
                print("Epoch %d/%d" % (epoch + 1, num_epochs))
                print("-" * 10)

            # to time the progress
            epoch_start_time = time.time()

            # progress counter
            num_batches = 0  # num batches in training and validation
            if verbose >= 2:
                progress_disp = mkutils.ProgressDisplay()

            # to store the losses in validation for REINFORCE
            losses = torch.zeros(len(dataloaders["val"])).to(device)
            logps = torch.zeros(len(dataloaders["val"])).to(device)

            # every epoch has a training and a validation phase
            sum_dval_loss = 0.0
            count_dval = 0
            optimizer["train"].zero_grad()
            optimizer["val"].zero_grad()
            for phase in ["train", "val"]:

                # skip phase if the dataloaders for the current phase is empty
                if dataloaders[phase] == []: continue

                # set the model's mode
                if scheduler is not None:
                    scheduler[phase].step()
                    if phase == "val" and "diffval" in scheduler:
                        scheduler["diffval"].step()
                model.train()

                # the total loss during this epoch
                running_loss = 0.0

                # iterate over the data
                dataset_size = 0

                # reset the criteria before the training epoch starts
                criteria[phase].reset()
                count_i = 0
                count_train_update = 0
                for inputs, labels in dataloaders[phase]:
                    count_train_update += 1

                    # get the size of the dataset
                    dataset_size += inputs.size(0)
                    num_batches += 1

                    # write the progress bar
                    if verbose >= 2:
                        progress_disp.show(num_batches, total_batches)

                    # load the inputs and the labels to the working device
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # reset the model gradient to 0
                    if phase == "val":
                        optimizer["val"].zero_grad()
                        if "diffval" in optimizer:
                            optimizer["diffval"].zero_grad()

                    # forward
                    outputs, logp = model(inputs)
                    loss = criteria[phase].feed(outputs, labels)

                    # backward gradient computation and optimize in training
                    if phase == "train":
                        loss.backward()
                        if count_train_update % train_update_every == 0 or \
                           count_train_update == len(dataloaders[phase]):
                            optimizer["train"].step()
                            optimizer["train"].zero_grad()
                    else:
                        if "diffval" in optimizer:
                            sum_dval_loss = sum_dval_loss + loss
                            count_dval += 1

                            # applying the backprop
                            if count_dval == dvbatch:
                                mean_dval_loss = sum_dval_loss / count_dval
                                mean_dval_loss.backward()
                                optimizer["diffval"].step()
                                count_dval = 0
                                sum_dval_loss = 0.0

                        # we need the gradient for logp, but not for loss
                        losses[count_i] += loss.data
                        logps[count_i] += logp
                        count_i += 1

                # apply backprop if there's still diff validation left
                if count_dval != 0:
                    mean_dval_loss = sum_dval_loss / count_dval
                    mean_dval_loss.backward()
                    optimizer["diffval"].step()
                    count_dval = 0
                    sum_dval_loss = 0.0

                # do the reinforce
                if phase == "val":
                    # transform the loss into some ranking function (min loss lower)
                    normlosses = get_normloss(losses)
                    # we choose sum instead of mean because the training step
                    # is only done once, so we want to make it larger
                    # (it is approximately mean, but doing it for every batch)
                    loss = (normlosses * logps).sum()
                    loss.backward()
                    optimizer[phase].step()

                # get the mean loss in this epoch
                mult = -1 if (criteria[phase].best == "max") else 1
                crit_val = criteria[phase].getval()
                epoch_loss = mult * crit_val

                # save the losses
                if phase == "train":
                    train_losses.append(crit_val.data)
                elif phase == "val":
                    val_losses.append(crit_val.data)

                # save the model history
                if return_best_last < num_epochs:
                    weights_history[epoch % return_best_last] = copy.deepcopy(
                        model.state_dict())

                # copy the best model
                if phase == "val" and \
                        ((epoch_loss < best_loss) or \
                         (epoch - best_epoch > return_best_last) or \
                         (epoch - best_epoch > revert_every)):
                    if epoch - best_epoch > return_best_last:
                        # get the index of the next best last
                        val_losses_n = val_losses[-return_best_last:]
                        min_idx_rel = np.argmin(val_losses_n)
                        min_idx = min_idx_rel + len(
                            val_losses) - return_best_last

                        # get the best conditions
                        best_epoch = min_idx
                        best_model_weights = weights_history[best_epoch %
                                                             return_best_last]

                    elif epoch - best_epoch > revert_every:
                        # revert the model to the best model
                        model.load_state_dict(best_model_weights)

                    else:
                        best_epoch = epoch
                        best_model_weights = copy.deepcopy(model.state_dict())

                    # save the best conditions
                    best_loss = val_losses[best_epoch]

                    # save the model
                    _save_wts(best_model_weights, save_wts_to)

            # show the loss in the current epoch
            if verbose >= 1:
                print("train %s: %.4e, val %s: %.4e, done in %fs (best val: %.3e)" % \
                      (criteria["train"].name, train_losses[-1],
                       criteria["val"].name, val_losses[-1],
                       time.time()-since, best_loss))
            # plot the losses
            if plot:
                xs_plot = range(1, epoch + 2)
                plt.clf()
                plt.plot(xs_plot, train_losses, 'o-')
                plt.plot(xs_plot, val_losses, 'o-')
                plt.legend(["Train", "Validation"])
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.pause(0.001)

            print("")
    except KeyboardInterrupt:
        print("Interrupted. Returning the results.")

    if verbose >= 1:
        time_elapsed = time.time() - since
        print("Training complete in %fs" % time_elapsed)
        print("Best val loss: %.4f" % best_loss)

    # return the model
    model.load_state_dict(best_model_weights)
    if return_history:
        return model, best_loss, train_losses, val_losses
    return model, best_loss
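
A hedged sketch of the argument layout this REINFORCE-style variant expects: dataloaders, criteria, optimizer, and scheduler must all be dictionaries with "train" and "val" keys (the optimizer may also carry an optional "diffval" entry for the differentiable-validation branch), and the model must return both the outputs and a log-probability. The model below and get_normloss() (called in the body) are assumptions for illustration only.

# Hypothetical argument layout for the REINFORCE-style train() above.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class StochasticModel(nn.Module):
    """Illustrative model that returns (outputs, log-probability)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 1)
        self.logstd = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        mean = self.net(x)
        dist = torch.distributions.Normal(mean, self.logstd.exp())
        sample = dist.rsample()
        return sample, dist.log_prob(sample).sum()

model = StochasticModel()
x, y = torch.randn(128, 10), torch.randn(128, 1)
dataloaders = {
    "train": DataLoader(TensorDataset(x[:96], y[:96]), batch_size=16, shuffle=True),
    "val": DataLoader(TensorDataset(x[96:], y[96:]), batch_size=16),
}
criteria = {"train": nn.MSELoss(), "val": nn.MSELoss()}
optimizer = {
    "train": torch.optim.Adam(model.net.parameters(), lr=1e-3),
    "val": torch.optim.Adam([model.logstd], lr=1e-3),
    # an optional "diffval" optimizer enables the differentiable-validation branch
}

# model, best_loss = train(model, dataloaders, criteria, optimizer, num_epochs=5)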