示例#1
0
def train_gan(d_model,
              g_model,
              train_dataloader,
              dev_dataloader,
              d_optimizer,
              g_optimizer,
              loss_fn,
              params,
              model_dir,
              restore_file=None):
    """Train a GAN for ``params.num_epochs`` epochs, checkpointing every epoch.

    Args:
        d_model: discriminator network.
        g_model: generator network.
        train_dataloader: data loader handed to the per-epoch ``train`` call.
        dev_dataloader: currently unused; kept for interface compatibility.
        d_optimizer: optimizer for ``d_model``.
        g_optimizer: optimizer for ``g_model``.
        loss_fn: loss handed to the per-epoch ``train`` call.
        params: (Params) hyperparameters; ``params.num_epochs`` is read here.
        model_dir: directory for checkpoints and sample CSV files.
        restore_file: currently unused; kept for interface compatibility.

    After every epoch both networks are checkpointed into ``model_dir``
    (``ntype='d'`` and ``ntype='g'``, never marked as best).  When ``train``
    returns test samples they are mapped with ``x * 127.5 + 127.5`` (i.e.
    [-1, 1] -> [0, 255], assuming generator outputs lie in [-1, 1] — TODO
    confirm), rounded to int, concatenated column-wise with the module-level
    ``test_noise`` and ``test_labels`` and written to
    ``samples_epoch_<epoch>.csv``.  The plotting figure is always closed,
    even if training raises.
    """
    fig = display_results.create_figure()
    try:
        for epoch in range(params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

            test_samples = train(d_model, g_model, d_optimizer, g_optimizer,
                                 loss_fn, train_dataloader, params, epoch, fig)

            # Checkpoint discriminator and generator after every epoch.
            for net_model, net_optimizer, ntype in ((d_model, d_optimizer, 'd'),
                                                    (g_model, g_optimizer, 'g')):
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': net_model.state_dict(),
                        'optim_dict': net_optimizer.state_dict()
                    },
                    is_best=False,
                    checkpoint=model_dir,
                    ntype=ntype)

            if test_samples is None:
                continue

            np_test_samples = np.around(
                np.array(test_samples) * 127.5 + 127.5).astype(int)
            np_test_out = test_noise.cpu().numpy()
            np_test_labels = test_labels.view(test_labels.shape[0],
                                              -1).cpu().numpy()

            test_all_data = np.concatenate(
                (np_test_samples, np_test_out, np_test_labels),
                axis=1).tolist()
            last_csv_path = os.path.join(
                model_dir, "samples_epoch_{}.csv".format(epoch + 1))
            utils.save_incorrect_to_csv(test_all_data, last_csv_path)
    finally:
        # Release the plotting figure even if an epoch raises.
        display_results.close_figure(fig)
示例#2
0
def train_and_evaluate(model,
                       train_dataloader,
                       dev_dataloader,
                       optimizer,
                       loss_fn,
                       metrics,
                       incorrect,
                       correct_fn,
                       params,
                       model_dir,
                       restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        dev_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        incorrect: a function that collects all samples with incorrect classification
        correct_fn: a function that collects all samples with correct classification
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar),
            resolved relative to ``model_dir``

    Model selection uses ``dev_metrics['accuracy']`` (strict improvement).
    When the global ``args.early_stop`` is set, an ``EarlyStopping`` monitor on
    the validation loss may end training early; the stopping epoch's results
    are then not saved.  Per-epoch gradient/weight statistics and accuracies
    are appended to the module-level ``grads_per_epoch``, ``vals_per_epoch``
    and ``accuracy`` lists.
    """
    # Reload weights from restore_file if specified.  Use the arguments that
    # were passed in, not the global ``args`` namespace.
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_dev_acc = 0.0

    early_stopping = None
    if args.early_stop:
        early_stopping = EarlyStopping(patience=round(0.1 * params.num_epochs),
                                       verbose=False)

    fig = display_results.create_figure()
    try:
        for epoch in range(params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

            # One full pass over the training set
            train(model, optimizer, loss_fn, train_dataloader, metrics, params,
                  epoch, fig, model_dir, losses)

            # Evaluate for one epoch on validation set
            dev_metrics, incorrect_samples, correct_samples = evaluate(
                model, loss_fn, dev_dataloader, metrics, incorrect, correct_fn,
                params, epoch)

            if early_stopping is not None:
                early_stopping(dev_metrics['loss'], model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    logging.info("Early stopping")
                    break

            dev_acc = dev_metrics['accuracy']
            is_best = dev_acc > best_dev_acc

            grads_graph, _ = get_network_grads(model)
            vals_graph = collect_network_statistics(model)

            grads_per_epoch.append(grads_graph)
            vals_per_epoch.append(vals_graph)

            # Save weights
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict()
                },
                is_best=is_best,
                checkpoint=model_dir)

            # If best_eval, best_save_path
            if is_best:
                logging.info("- Found new best accuracy")
                print("Epoch {}/{}".format(epoch + 1, params.num_epochs))
                print("- Found new best accuracy")
                best_dev_acc = dev_acc
                print("accuracy is {:05.3f}".format(best_dev_acc))

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir,
                                              "metrics_dev_best_weights.json")
                utils.save_dict_to_json(dev_metrics, best_json_path, epoch + 1)

                best_inc_csv_path = os.path.join(model_dir,
                                                 "incorrect_best_samples.csv")
                utils.save_incorrect_to_csv(incorrect_samples, best_inc_csv_path)

                best_c_csv_path = os.path.join(model_dir,
                                               "correct_best_samples.csv")
                utils.save_incorrect_to_csv(correct_samples, best_c_csv_path)

            # Save latest val metrics in a json file in the model directory
            last_json_path = os.path.join(model_dir,
                                          "metrics_dev_last_weights.json")
            utils.save_dict_to_json(dev_metrics, last_json_path, epoch + 1)

            last_inc_csv_path = os.path.join(model_dir,
                                             "incorrect_last_samples.csv")
            utils.save_incorrect_to_csv(incorrect_samples, last_inc_csv_path)

            last_c_csv_path = os.path.join(model_dir, "correct_last_samples.csv")
            utils.save_incorrect_to_csv(correct_samples, last_c_csv_path)

            # Record this epoch's validation accuracy in the module-level history
            accuracy.append(dev_acc)
    finally:
        # Release the plotting figure even if an epoch raises.
        display_results.close_figure(fig)
def train_gan(d_model, g_model, train_dataloader, d_optimizer, g_optimizer,
              r_f_loss_fn, c_loss_fn, params, model_dir):
    """Train a GAN and keep best checkpoints/metrics per tracked quantity.

    Every epoch the per-epoch ``train`` call returns
    ``(test_samples, loss, accuracy, prediction, incorrect_samples)``.  Three
    quantities are tracked in ``best_dict``: ``loss`` and ``prediction``
    (lower is better) and ``accuracy`` (higher is better); ties count as
    improvements.  For each improved quantity ``it``:

    * ``stats/metrics_dev_best_<it>_weights.json`` receives the whole
      ``best_dict`` (after all of this epoch's updates);
    * ``stats/incorrect_{real,fake}_best_<it>_samples.csv`` receive
      ``incorrect_samples[0]`` / ``incorrect_samples[1]``;
    * if the epoch produced ``test_samples``, both networks are checkpointed
      into ``stats`` with ``is_best=True`` and ``best_type=it``.

    Whenever ``test_samples`` is not None they are rescaled from [-1, 1] to
    [0, 255], concatenated with the module-level ``test_noise`` /
    ``test_labels`` and written to ``data/samples_epoch_<epoch>.csv``.
    Gradient and weight statistics of both networks are appended to the
    module-level ``grads_dict`` / ``vals_dict`` every epoch.

    Args:
        d_model: discriminator network.
        g_model: generator network.
        train_dataloader: data loader handed to ``train``.
        d_optimizer: optimizer of ``d_model``.
        g_optimizer: optimizer of ``g_model``.
        r_f_loss_fn: real/fake loss handed to ``train``.
        c_loss_fn: second loss handed to ``train`` (presumably a class loss — TODO confirm).
        params: (Params) hyperparameters; ``params.num_epochs`` is read here.
        model_dir: output root; ``stats`` and ``data`` sub-directories are created as needed.
    """
    # Running best value of each tracked quantity and its comparison direction.
    best_dict = {'loss': np.inf, 'accuracy': 0.0, 'prediction': 1.0}
    lower_is_better = {'loss': True, 'accuracy': False, 'prediction': True}

    # Generator outputs are mapped from [curr_min, curr_max] to [dest_min, dest_max].
    dest_min, dest_max = 0, 255
    curr_min, curr_max = -1, 1

    stats_dir = os.path.join(model_dir, 'stats')
    os.makedirs(stats_dir, exist_ok=True)
    data_path = os.path.join(model_dir, 'data')

    fig = display_results.create_figure()
    try:
        for epoch in range(params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

            (test_samples, loss_mean_sum, accuracy_sum, preds_sum,
             incorrect_samples) = train(d_model, d_optimizer, g_model,
                                        g_optimizer, r_f_loss_fn, c_loss_fn,
                                        train_dataloader, params, epoch, fig)

            curr_vals_dict = {
                'loss': loss_mean_sum,
                'accuracy': accuracy_sum,
                'prediction': preds_sum
            }
            is_best_dict = {
                it: (curr_vals_dict[it] <= best_dict[it]
                     if lower_is_better[it] else
                     curr_vals_dict[it] >= best_dict[it])
                for it in best_dict
            }

            g_grads_graph, _ = get_network_grads(g_model)
            d_grads_graph, _ = get_network_grads(d_model)
            g_vals_graph = collect_network_statistics(g_model)
            d_vals_graph = collect_network_statistics(d_model)

            grads_dict['grads_per_epoch_g'].append(g_grads_graph)
            grads_dict['grads_per_epoch_d'].append(d_grads_graph)
            vals_dict['vals_per_epoch_g'].append(g_vals_graph)
            vals_dict['vals_per_epoch_d'].append(d_vals_graph)

            improved = [it for it, is_best in is_best_dict.items() if is_best]

            # Update every best value before writing any JSON so each file
            # reflects all of this epoch's improvements.
            for it in improved:
                logging.info("- Found new best {}".format(it))
                print("Epoch {}/{}".format(epoch + 1, params.num_epochs))
                print("- Found new best {}".format(it))
                best_dict[it] = curr_vals_dict[it]
                print("mean {} is {:05.3f}".format(it, best_dict[it]))

            # Save best metrics and misclassified samples in the stats directory
            for it in improved:
                best_json_path = os.path.join(
                    stats_dir, "metrics_dev_best_{}_weights.json".format(it))
                best_csv_real_path = os.path.join(
                    stats_dir, "incorrect_real_best_{}_samples.csv".format(it))
                best_csv_fake_path = os.path.join(
                    stats_dir, "incorrect_fake_best_{}_samples.csv".format(it))

                utils.save_dict_to_json(best_dict, best_json_path, epoch + 1)
                utils.save_incorrect_to_csv(incorrect_samples[0],
                                            best_csv_real_path)
                utils.save_incorrect_to_csv(incorrect_samples[1],
                                            best_csv_fake_path)

            if test_samples is None:
                continue

            # Best checkpoints are only written on epochs that produce samples.
            for it in improved:
                for net_model, net_optimizer, ntype in (
                        (d_model, d_optimizer, 'd'),
                        (g_model, g_optimizer, 'g')):
                    utils.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': net_model.state_dict(),
                            'optim_dict': net_optimizer.state_dict()
                        },
                        is_best=True,
                        checkpoint=stats_dir,
                        ntype=ntype,
                        best_type=it)

            np_test_samples = np.array(test_samples)
            # convert back to range [0, 255]
            np_test_samples = \
                dest_min + (dest_max - dest_min) * (np_test_samples - curr_min) / (curr_max - curr_min)
            np_test_samples = np.around(np_test_samples).astype(int)
            np_test_out = test_noise.cpu().numpy()
            np_test_labels = test_labels.view(test_labels.shape[0],
                                              -1).cpu().numpy()

            os.makedirs(data_path, exist_ok=True)

            test_all_data = np.concatenate(
                (np_test_samples, np_test_out, np_test_labels),
                axis=1).tolist()
            last_csv_path = os.path.join(
                data_path, "samples_epoch_{}.csv".format(epoch + 1))
            utils.save_incorrect_to_csv(test_all_data, last_csv_path)
    finally:
        # Release the plotting figure even if an epoch raises.
        display_results.close_figure(fig)
示例#4
0
def after_transfer_train_and_evaluate(model, train_dataloader, dev_dataloader,
                                      optimizer, loss_fn, metrics, incorrect,
                                      correct_fn, params, model_dir,
                                      model_out_dir, restore_file):
    """Fine-tune a transferred model, validating after every epoch.

    Each epoch trains with ``grayscale=True``, evaluates on ``dev_dataloader``
    and scores the model by ``dev_metrics['accuracy_two_labels']``; a score
    equal to or above the best so far counts as a new best.  All outputs
    (checkpoints, best/last metric JSON files, best/last misclassified-sample
    CSVs) go to ``model_out_dir``.  Per-epoch gradient/weight statistics and
    scores are appended to the module-level ``grads_per_epoch``,
    ``vals_per_epoch`` and ``accuracy`` lists.

    ``model_dir`` and ``restore_file`` are accepted but not used here.
    """
    top_score = 0.0
    figure = display_results.create_figure()

    for ep in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(ep + 1, params.num_epochs))

        # One full pass over the training set
        train(model, optimizer, loss_fn, train_dataloader, metrics, params,
              ep, figure, model_out_dir, losses, grayscale=True)

        # Validation pass; correctly-classified samples are not used here
        dev_metrics, wrong_samples, _ = evaluate(
            model, loss_fn, dev_dataloader, metrics, incorrect, correct_fn,
            params, ep)

        score = dev_metrics['accuracy_two_labels']
        improved = score >= top_score

        grad_stats = get_network_grads(model)[0]
        weight_stats = collect_network_statistics(model)
        grads_per_epoch.append(grad_stats)
        vals_per_epoch.append(weight_stats)

        snapshot = {
            'epoch': ep + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict()
        }
        utils.save_checkpoint(snapshot, is_best=improved,
                              checkpoint=model_out_dir)

        if improved:
            logging.info("- Found new best accuracy")
            print("Epoch {}/{}".format(ep + 1, params.num_epochs))
            print("- Found new best accuracy")
            top_score = score
            print("accuracy is {:05.3f}".format(top_score))
            print("loss is {:05.3f}".format(dev_metrics['loss']))

            utils.save_dict_to_json(
                dev_metrics,
                os.path.join(model_out_dir, "metrics_dev_best_weights.json"),
                ep + 1)
            utils.save_incorrect_to_csv(
                wrong_samples,
                os.path.join(model_out_dir, "incorrect_best_samples.csv"))

        # The "last" files are refreshed every epoch
        utils.save_dict_to_json(
            dev_metrics,
            os.path.join(model_out_dir, "metrics_dev_last_weights.json"),
            ep + 1)
        utils.save_incorrect_to_csv(
            wrong_samples,
            os.path.join(model_out_dir, "incorrect_last_samples.csv"))

        accuracy.append(score)

    display_results.close_figure(figure)

    return
示例#5
0
def train_g(g_model, train_dataloader, g_optimizer, mse_loss_fn, params, model_dir):
    """Train the generator alone with an MSE loss for ``params.num_epochs`` epochs.

    Whenever the mean epoch loss returned by ``train`` is <= the best seen so
    far, the following happens:

    * the generator is checkpointed as best;
    * the loss is written to ``metrics_min_avg_loss_best_weights.json``;
    * its network statistics are plotted;
    * if test samples were produced, they are saved to
      ``data/best_samples_epoch_<epoch>.csv``.

    Separately, every epoch that produces test samples saves a regular
    ``ntype='g'`` checkpoint and ``data/samples_epoch_<epoch>.csv``.  All
    outputs go under ``model_dir``.

    Args:
        g_model: generator network.
        train_dataloader: data loader handed to ``train``.
        g_optimizer: optimizer of ``g_model``.
        mse_loss_fn: loss handed to ``train``.
        params: (Params) hyperparameters; ``params.num_epochs`` is read here.
        model_dir: directory for checkpoints, metrics, plots and sample CSVs.
    """
    best_loss = np.inf
    # Generator outputs are mapped from [curr_min, curr_max] to [dest_min, dest_max].
    dest_min, dest_max = 0, 255
    curr_min, curr_max = -1, 1
    data_path = os.path.join(model_dir, 'data')

    def _save_samples(samples, file_name):
        """Rescale ``samples``, append the module-level test noise/labels, write to ``data_path``."""
        np_test_samples = np.array(samples)
        np_test_samples = \
            dest_min + (dest_max - dest_min) * (np_test_samples - curr_min) / (curr_max - curr_min)
        np_test_samples = np.around(np_test_samples).astype(int)
        np_test_out = test_noise.cpu().numpy()
        np_test_labels = test_labels.view(test_labels.shape[0], -1).cpu().numpy()

        os.makedirs(data_path, exist_ok=True)
        test_all_data = np.concatenate((np_test_samples, np_test_out, np_test_labels), axis=1).tolist()
        utils.save_incorrect_to_csv(test_all_data, os.path.join(data_path, file_name))

    fig = display_results.create_figure()
    try:
        for epoch in range(params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

            test_samples, _real_samples, loss_mean = train(
                g_model, g_optimizer, mse_loss_fn, train_dataloader, params, epoch, fig)

            is_best = loss_mean <= best_loss

            if is_best:
                logging.info("- Found new best loss")
                print("Epoch {}/{}".format(epoch + 1, params.num_epochs))
                print("- Found new best loss")
                best_loss = loss_mean
                print("mean loss is {:05.3f}".format(loss_mean))
                loss_metric_dict = {'loss': loss_mean}

                utils.save_checkpoint({'epoch': epoch + 1,
                                       'state_dict': g_model.state_dict(),
                                       'optim_dict': g_optimizer.state_dict()}, is_best=is_best, checkpoint=model_dir)

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "metrics_min_avg_loss_best_weights.json")
                utils.save_dict_to_json(loss_metric_dict, best_json_path)

                # NOTE(review): plotted as "Grads_Best" but built from
                # collect_network_statistics, which sibling functions use for
                # weight values (get_network_grads is used for grads) — confirm.
                best_g_grads_graph = collect_network_statistics(g_model)
                # Plot into this run's model_dir rather than the global args.model_dir.
                display_results.plot_graph(best_g_grads_graph, [], "Grads_Best", model_dir, epoch=epoch + 1)

                if test_samples is not None:
                    _save_samples(test_samples, "best_samples_epoch_{}.csv".format(epoch + 1))

            if test_samples is not None:
                utils.save_checkpoint({'epoch': epoch + 1,
                                       'state_dict': g_model.state_dict(),
                                       'optim_dict': g_optimizer.state_dict()}, is_best=False, checkpoint=model_dir,
                                      ntype='g')

                _save_samples(test_samples, "samples_epoch_{}.csv".format(epoch + 1))
    finally:
        # Release the plotting figure even if an epoch raises.
        display_results.close_figure(fig)
def train_and_evaluate(model, train_dataloader, dev_dataloader, optimizer, loss_fn, metrics, incorrect, params, model_dir,
                       restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        dev_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        incorrect: a function that collects all samples with incorrect classification
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar),
            resolved relative to ``model_dir``

    Model selection uses ``dev_metrics['accuracy']``; a score equal to or
    above the best so far counts as a new best.
    """
    # Reload weights from restore_file if specified.  Use the arguments that
    # were passed in, not the global ``args`` namespace.
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_dev_acc = 0.0

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # One full pass over the training set
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        dev_metrics, incorrect_samples = evaluate(model, loss_fn, dev_dataloader, metrics, incorrect, params)

        dev_acc = dev_metrics['accuracy']
        is_best = dev_acc >= best_dev_acc

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              is_best=is_best,
                              checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_dev_acc = dev_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_dev_best_weights.json")
            utils.save_dict_to_json(dev_metrics, best_json_path)

            best_csv_path = os.path.join(model_dir, "incorrect_best_samples.csv")
            utils.save_incorrect_to_csv(incorrect_samples, best_csv_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir, "metrics_dev_last_weights.json")
        utils.save_dict_to_json(dev_metrics, last_json_path)

        last_csv_path = os.path.join(model_dir, "incorrect_last_samples.csv")
        utils.save_incorrect_to_csv(incorrect_samples, last_csv_path)
示例#7
0
def train_gan(d_model, g_model, train_dataloader, dev_dataloader, d_optimizer, g_optimizer, loss_fn, params, model_dir,
              restore_file=None):
    """Train a GAN, tracking the lowest mean epoch loss as "best".

    Every epoch, gradient and weight statistics of both networks are
    appended to the module-level ``grads_per_epoch_g``/``grads_per_epoch_d``
    and ``vals_per_epoch_g``/``vals_per_epoch_d`` lists.  When the loss
    returned by ``train`` is <= the best so far, it is written to
    ``metrics_dev_best_weights.json``.  In that case, if test samples exist,
    they are also saved to ``data/best_samples_epoch_<epoch>.csv``.  Every
    epoch that produces test samples checkpoints both networks (marked best
    when the loss improved) and saves ``data/samples_epoch_<epoch>.csv``.
    Samples are rescaled from [-1, 1] to [0, 255] and concatenated with the
    module-level ``test_noise``.

    Args:
        d_model: discriminator network.
        g_model: generator network.
        train_dataloader: data loader handed to ``train``.
        dev_dataloader: currently unused; kept for interface compatibility.
        d_optimizer: optimizer of ``d_model``.
        g_optimizer: optimizer of ``g_model``.
        loss_fn: loss handed to ``train``.
        params: (Params) hyperparameters; ``params.num_epochs`` is read here.
        model_dir: directory for checkpoints, metrics and sample CSVs.
        restore_file: currently unused; kept for interface compatibility.
    """
    best_loss = np.inf
    # Generator outputs are mapped from [curr_min, curr_max] to [dest_min, dest_max].
    dest_min, dest_max = 0, 255
    curr_min, curr_max = -1, 1
    data_path = os.path.join(model_dir, 'data')

    def _save_samples(samples, file_name):
        """Rescale ``samples``, append the module-level test noise, write to ``data_path``."""
        np_test_samples = np.array(samples)
        np_test_samples = \
            dest_min + (dest_max - dest_min) * (np_test_samples - curr_min) / (curr_max - curr_min)
        np_test_samples = np.around(np_test_samples).astype(int)
        np_test_out = test_noise.cpu().numpy()

        os.makedirs(data_path, exist_ok=True)
        test_all_data = np.concatenate((np_test_samples, np_test_out), axis=1).tolist()
        utils.save_incorrect_to_csv(test_all_data, os.path.join(data_path, file_name))

    fig = display_results.create_figure()
    try:
        for epoch in range(params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

            test_samples, loss_mean_sum = \
                train(d_model, g_model, d_optimizer, g_optimizer, loss_fn, train_dataloader, params, epoch, fig)

            is_best = loss_mean_sum <= best_loss

            g_grads_graph, _ = get_network_grads(g_model)
            d_grads_graph, _ = get_network_grads(d_model)
            g_vals_graph = collect_network_statistics(g_model)
            d_vals_graph = collect_network_statistics(d_model)

            grads_per_epoch_g.append(g_grads_graph)
            grads_per_epoch_d.append(d_grads_graph)
            vals_per_epoch_g.append(g_vals_graph)
            vals_per_epoch_d.append(d_vals_graph)

            if is_best:
                logging.info("- Found new best loss")
                print("Epoch {}/{}".format(epoch + 1, params.num_epochs))
                print("- Found new best loss")
                best_loss = loss_mean_sum
                print("mean loss is {:05.3f}".format(loss_mean_sum))
                loss_metric_dict = {'loss': loss_mean_sum}

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "metrics_dev_best_weights.json")
                utils.save_dict_to_json(loss_metric_dict, best_json_path, epoch + 1)

                if test_samples is not None:
                    _save_samples(test_samples, "best_samples_epoch_{}.csv".format(epoch + 1))

            # NOTE(review): checkpoints (including "best" ones) are only written
            # on epochs that produce test samples — confirm this is intended.
            if test_samples is not None:
                utils.save_checkpoint({'epoch': epoch + 1,
                                       'state_dict': d_model.state_dict(),
                                       'optim_dict': d_optimizer.state_dict()}, is_best=is_best, checkpoint=model_dir,
                                      ntype='d')

                utils.save_checkpoint({'epoch': epoch + 1,
                                       'state_dict': g_model.state_dict(),
                                       'optim_dict': g_optimizer.state_dict()}, is_best=is_best, checkpoint=model_dir,
                                      ntype='g')

                _save_samples(test_samples, "samples_epoch_{}.csv".format(epoch + 1))
    finally:
        # Release the plotting figure even if an epoch raises.
        display_results.close_figure(fig)
示例#8
0
def train_vae(model, train_dataloader, optimizer, loss_fn, params, model_dir):
    """Train the model for ``params.num_epochs`` epochs (no validation pass).

    Model selection uses the mean training loss returned by the per-epoch
    ``train`` call: a loss <= the best so far counts as a new best.  Every
    epoch the model is checkpointed into ``model_dir`` (marked best when the
    loss improved); on improvement the loss is also written to
    ``metrics_min_avg_loss_best_weights.json``.  When ``train`` returns
    reconstructed samples they are multiplied by 255 (assumes values in
    [0, 1] — TODO confirm), rounded to int and saved to
    ``data/samples_epoch_<epoch>.csv``.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log

    NOTE(review): everything after ``display_results.close_figure(fig)``
    builds and evaluates a separate classifier using globals (``args``,
    ``device``, ``test_dl``) and looks like code from a different evaluation
    script that was merged into this function — confirm whether it belongs.
    """

    best_loss = math.inf

    fig = display_results.create_figure()

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # One full pass over the training set; returns samples (or None) and the mean loss
        reconstructed_samples, loss_mean = train(model, optimizer, loss_fn,
                                                 train_dataloader, params,
                                                 epoch, fig)

        # Ties count as an improvement
        is_best = loss_mean <= best_loss

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best loss")
            print("Epoch {}/{}".format(epoch + 1, params.num_epochs))
            print("- Found new best loss")
            best_loss = loss_mean
            print("mean loss is {:05.3f}".format(loss_mean))
            loss_metric_dict = {'loss': loss_mean}

            # utils.save_checkpoint({'epoch': epoch + 1,
            #                        'state_dict': model.state_dict(),
            #                        'optim_dict': optimizer.state_dict()}, is_best=is_best, checkpoint=model_dir)

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(
                model_dir, "metrics_min_avg_loss_best_weights.json")
            utils.save_dict_to_json(loss_metric_dict, best_json_path)

            # best_csv_path = os.path.join(model_dir, "reconstructed_min_avg_loss_best_samples.csv")
            # utils.save_incorrect_to_csv(reconstructed_samples, best_csv_path)

        if reconstructed_samples is not None:
            np_reconstructed_samples = np.array(reconstructed_samples)
            # Scale to [0, 255] integers (assumes samples are in [0, 1] — TODO confirm)
            np_reconstructed_samples = np.around(np_reconstructed_samples *
                                                 255).astype(int)

            data_path = os.path.join(model_dir, 'data')
            # NOTE(review): os.mkdir fails if model_dir itself does not exist
            if not os.path.isdir(data_path):
                os.mkdir(data_path)

            last_csv_path = os.path.join(
                data_path, "samples_epoch_{}.csv".format(epoch + 1))
            utils.save_incorrect_to_csv(np_reconstructed_samples,
                                        last_csv_path)

    display_results.close_figure(fig)
    # NOTE(review): from here on the trained `model` and `optimizer` are
    # rebound to a new net.NeuralNet classifier; this looks like a separate
    # evaluation script fused into train_vae — verify.
    model = net.NeuralNet(params).cuda() if params.cuda else net.NeuralNet(
        params)

    # changing last fully connected layer (now 20 outputs; trailing note shows it used to be 10)
    num_ftrs = model.fc4.in_features
    model.fc4 = nn.Linear(num_ftrs, 20)  # 10)

    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=params.learning_rate)

    loss_fn = net.loss_fn_two_labels
    metrics = net.metrics
    incorrect = net.incorrect_two_labels

    logging.info("Starting evaluation")

    # Reload weights from the saved file
    # NOTE(review): `model` is not passed to load_model and its return value is
    # discarded, so it is unclear whether these weights reach `model` — verify.
    load_model(args.model_dir, args.restore_file)

    # Evaluate on the global test loader `test_dl` and save metrics / misclassified samples
    test_metrics, incorrect_samples = evaluate_after_transfer(
        model, loss_fn, test_dl, metrics, incorrect, params)
    save_path = os.path.join(args.model_dir,
                             "metrics_test_{}.json".format(args.restore_file))
    utils.save_dict_to_json(test_metrics, save_path)

    best_inc_csv_path = os.path.join(
        args.model_dir, "evaluate_" + "test" + "_incorrect_samples.csv")
    utils.save_incorrect_to_csv(incorrect_samples, best_inc_csv_path)