Example No. 1
    def __init__(self,
                 model,
                 N_train,
                 lr=1e-2,
                 momentum=0.5,
                 cuda=True,
                 schedule=None,
                 seed=None,
                 weight_decay=0,
                 MC_samples=10,
                 Adam=False):
        super(regression_baseline_net, self).__init__()

        cprint('y', 'Regression baseline net with homoscedastic Gaussian output')
        self.lr = lr
        self.momentum = momentum
        self.Adam = Adam
        self.weight_decay = weight_decay
        self.MC_samples = MC_samples
        self.model = model
        self.cuda = cuda
        self.seed = seed

        self.f_neg_loglike = homo_Gauss_mloglike(self.model.output_dim, None)
        self.f_neg_loglike_test = self.f_neg_loglike

        self.N_train = N_train
        self.create_net()
        self.create_opt()
        self.schedule = schedule  # e.g. [50, 200, 400, 600]: milestone epochs for lr decay
        if self.schedule is not None and len(self.schedule) > 0:
            self.make_scheduler(gamma=0.1, milestones=self.schedule)
        self.epoch = 0
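
A minimal instantiation sketch for the wrapper above; `my_model` is a placeholder for any `nn.Module` exposing an `output_dim` attribute (required by `homo_Gauss_mloglike`), and all values are illustrative, not defaults from the source:

# Hypothetical usage; `my_model` and all values below are placeholders.
net = regression_baseline_net(
    model=my_model,                # must expose `output_dim`
    N_train=50000,                 # training-set size, used to scale the objective
    lr=1e-2,
    momentum=0.9,
    cuda=torch.cuda.is_available(),
    schedule=[50, 200, 400, 600],  # epochs at which lr is multiplied by gamma=0.1
    MC_samples=10,
    Adam=False,                    # plain SGD with momentum when False
)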
Example No. 2
    def optimize(self):
        """
        grid: (l1_regularizer, l2_regularizer, log_lr)
        """
        # enumerate initial grids
        for log_lr in self.model_queue.log_lr_grid:
            model = self.generate_model(num_classes=self.params["n_classes"],
                                        backbone=self.backbone)

            cprint("[Grid Search]", "Training at log_lr: " + str(log_lr))

            # model = add_regularization(model=model, conv_reg=1e-6, dense_reg=1e-5)
            self.fit(model=model, log_lr=log_lr)
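
The loop assumes `self.model_queue.log_lr_grid` already holds the grid named in the docstring. A sketch of what such a grid might look like (the attribute layout and values are assumptions inferred from the snippet):

log_lr_grid = [-4.0, -3.0, -2.0, -1.0]  # hypothetical grid of log10 learning rates
for log_lr in log_lr_grid:
    lr = 10.0 ** log_lr                 # i.e. 1e-4, 1e-3, 1e-2, 1e-1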
Example No. 3
    def __init__(self, model, N_train, lr=1e-2, cuda=True, schedule=None):
        super(MF_BNN_cat, self).__init__()

        cprint('y', 'MF BNN categorical output')
        self.lr = lr
        self.model = model
        self.cuda = cuda
        self.f_neg_loglike = F.cross_entropy  # TODO restructure declaration of this function

        self.N_train = N_train
        self.create_net()
        self.create_opt()
        self.schedule = schedule  # e.g. [50, 200, 400, 600]: milestone epochs for lr decay
        self.epoch = 0
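
The TODO above flags that `f_neg_loglike` is just `F.cross_entropy`, i.e. the mean negative log-likelihood of the targets under the predicted logits. A self-contained check:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)             # batch of 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))
nll = F.cross_entropy(logits, targets)  # scalar mean negative log-likelihood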
Example No. 4
def get_dataset(dataset_path,
                model_path,
                batch_size=32,
                imgsize=260,
                val_split=0.2,
                debug=False):
    """
    :param debug: debug mode returns only the first 100 images
    """
    X_train, X_valid, y_train, y_valid, num_classes, class_weights = create_dataset(
        dataset_path, model_path, imgsize, val_split)
    # Debug mode: keep only the first 100 images per split
    if debug:
        X_train, X_valid = X_train[:100], X_valid[:100]
        y_train, y_valid = y_train[:100], y_valid[:100]
    train_set = (X_train, y_train)
    valid_set = (X_valid, y_valid)

    cprint(
        "[INFO]",
        "Data generator built! train data size {}, valid data size {}, classes {} \n"
        .format(len(y_train), len(y_valid), num_classes))

    data_transform = transforms.Compose([
        # transforms.CenterCrop(imgsize),
        transforms.Resize((imgsize, imgsize)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dataset_params = {"n_classes": num_classes, "class_weights": class_weights}

    gen_train = DataGenerator(train_set, data_transform, dataset_path,
                              num_classes, batch_size)
    gen_valid = DataGenerator(valid_set, data_transform, dataset_path,
                              num_classes, batch_size)
    return gen_train, gen_valid, dataset_params
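
An illustrative call; both paths are placeholders and `debug=True` exercises the first-100-images shortcut:

gen_train, gen_valid, dataset_params = get_dataset(
    dataset_path="data/my_images",  # placeholder
    model_path="checkpoints/",      # placeholder
    batch_size=32,
    imgsize=260,
    val_split=0.2,
    debug=True,                     # quick smoke test on 100 images per split
)
print(dataset_params["n_classes"], dataset_params["class_weights"])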
Example No. 5
    def __init__(self,
                 model,
                 prob_model,
                 N_train,
                 lr=1e-2,
                 momentum=0.5,
                 weight_decay=0,
                 cuda=True,
                 schedule=None,
                 regression=False,
                 pred_sig=None):
        super(DUN, self).__init__()

        cprint('y', 'DUN learnt with marginal likelihood categorical output')
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.model = model
        self.prob_model = prob_model
        self.cuda = cuda
        self.regression = regression
        self.pred_sig = pred_sig
        if self.regression:
            self.f_neg_loglike = homo_Gauss_mloglike(self.model.output_dim,
                                                     self.pred_sig)
            self.f_neg_loglike_test = self.f_neg_loglike
        else:
            self.f_neg_loglike = nn.CrossEntropyLoss(
                reduction='none')  # This one takes logits
            self.f_neg_loglike_test = nn.NLLLoss(
                reduction='none')  # This one takes log probs

        self.N_train = N_train
        self.create_net()
        self.create_opt()
        self.schedule = schedule
        if self.schedule is not None and len(self.schedule) > 0:
            self.make_scheduler(gamma=0.1, milestones=self.schedule)
        self.epoch = 0
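
A construction sketch highlighting the loss switch above: with `regression=False` the wrapper trains on logits via `nn.CrossEntropyLoss` and evaluates on log-probabilities via `nn.NLLLoss`. `backbone` and `depth_prob_model` are placeholder names:

dun = DUN(
    model=backbone,               # network producing one prediction per depth
    prob_model=depth_prob_model,  # variational distribution over depths
    N_train=60000,
    lr=1e-2,
    momentum=0.9,
    weight_decay=1e-4,
    cuda=True,
    schedule=[50, 200, 400, 600],
    regression=False,             # classification losses as described above
)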
Example No. 6
def train_loop(net,
               dname,
               data_dir,
               epochs=90,
               workers=4,
               resume='',
               savedir='./',
               save_all_epochs=False,
               q_nograd_its=0,
               batch_size=256):
    mkdir(savedir)
    global best_err1

    # Load data here:
    _, train_loader, val_loader, _, _, Ntrain = \
        get_image_loader(dname, batch_size, cuda=True, workers=workers, distributed=False, data_dir=data_dir)

    net.N_train = Ntrain

    start_epoch = 0

    marginal_loglike = np.zeros(epochs)
    train_loss = np.zeros(epochs)
    dev_loss = np.zeros(epochs)

    err_train = np.zeros(epochs)
    err_dev = np.zeros(epochs)

    # optionally resume from a checkpoint
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            start_epoch, best_err1 = net.load(resume)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume, start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

        candidate_progress_file = resume.split('/')
        candidate_progress_file = '/'.join(
            candidate_progress_file[:-1]) + '/stats_array.pkl'

        if os.path.isfile(candidate_progress_file):
            print("=> found progress file at '{}'".format(
                candidate_progress_file))
            try:
                marginal_loglike, err_train, train_loss, err_dev, dev_loss = \
                    load_object(candidate_progress_file)
                print("=> Loaded progress file at '{}'".format(
                    candidate_progress_file))
            except Exception:
                print("=> Unable to load progress file at '{}'".format(
                    candidate_progress_file))
        else:
            print("=> NOT found progress file at '{}'".format(
                candidate_progress_file))

    if q_nograd_its > 0:
        net.prob_model.q_logits.requires_grad = False

    for epoch in range(start_epoch, epochs):
        if q_nograd_its > 0 and epoch == q_nograd_its:
            net.prob_model.q_logits.requires_grad = True

        tic = time.time()
        nb_samples = 0
        for x, y in train_loader:
            marg_loglike_estimate, minus_loglike, err = net.fit(x, y)

            marginal_loglike[epoch] += marg_loglike_estimate * x.shape[0]
            err_train[epoch] += err * x.shape[0]
            train_loss[epoch] += minus_loglike * x.shape[0]
            nb_samples += len(x)

        marginal_loglike[epoch] /= nb_samples
        train_loss[epoch] /= nb_samples
        err_train[epoch] /= nb_samples

        toc = time.time()

        # ---- print
        print('\n depth approx posterior',
              net.prob_model.current_posterior.data.cpu().numpy())
        print(
            "it %d/%d, ELBO/evidence %.4f, pred minus loglike = %f, err = %f" %
            (epoch, epochs, marginal_loglike[epoch], train_loss[epoch],
             err_train[epoch]),
            end="")
        cprint('r', '   time: %f seconds\n' % (toc - tic))

        net.update_lr()

        # ---- dev
        tic = time.time()
        nb_samples = 0
        for x, y in val_loader:
            minus_loglike, err = net.eval(x, y)

            dev_loss[epoch] += minus_loglike * x.shape[0]
            err_dev[epoch] += err * x.shape[0]
            nb_samples += len(x)

        dev_loss[epoch] /= nb_samples
        err_dev[epoch] /= nb_samples

        toc = time.time()

        cprint('g',
               '     pred minus loglike = %f, err = %f\n' %
               (dev_loss[epoch], err_dev[epoch]),
               end="")
        cprint('g', '    time: %f seconds\n' % (toc - tic))

        filename = 'checkpoint.pth.tar'
        if save_all_epochs:
            filename = str(epoch) + '_' + filename
        net.save(os.path.join(savedir, filename), best_err1)
        if err_dev[epoch] < best_err1:
            best_err1 = err_dev[epoch]
            cprint('b', 'best top1 dev err: %f' % err_dev[epoch])
            shutil.copyfile(os.path.join(savedir, filename),
                            os.path.join(savedir, 'model_best.pth.tar'))

        all_results = [
            marginal_loglike, err_train, train_loss, err_dev, dev_loss
        ]
        save_object(all_results, os.path.join(savedir, 'stats_array.pkl'))
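
Note that `train_loop` declares `best_err1` as a global and reads it when comparing dev errors, so it must exist at module level before the call. A hypothetical invocation (dataset key and paths are placeholders):

best_err1 = np.inf  # module-level; lowered whenever dev error improves

train_loop(
    net,
    dname="cifar10",           # dataset key understood by get_image_loader (assumed)
    data_dir="./data",
    epochs=90,
    workers=4,
    resume="",                 # or a checkpoint path to resume training
    savedir="./results/run0",
    q_nograd_its=10,           # freeze depth-posterior logits for the first 10 epochs
    batch_size=256,
)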
Example No. 7
models_dir = args.models_dir
# Where to save plots and error, accuracy vectors
results_dir = args.results_dir

mkdir(models_dir)
mkdir(results_dir)
# ------------------------------------------------------------------------------------------------------
# train config
NTrainPointsMNIST = 60000
batch_size = 128
nb_epochs = args.epochs
log_interval = 1

# ------------------------------------------------------------------------------------------------------
# dataset
cprint('c', '\nData:')

# load data

# input preprocessing (no augmentation is applied for MNIST here)
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
])

use_cuda = torch.cuda.is_available()
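
The normalization constants (mean 0.1307, std 0.3081) are the standard MNIST statistics, consistent with `NTrainPointsMNIST = 60000`. A sketch of wiring these transforms to the torchvision MNIST datasets (root path and download flag are assumptions):

from torchvision import datasets

trainset = datasets.MNIST(root='./data', train=True, download=True,
                          transform=transform_train)
testset = datasets.MNIST(root='./data', train=False, download=True,
                         transform=transform_test)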
Example No. 8
def train_fc_baseline(net,
                      name,
                      save_dir,
                      batch_size,
                      nb_epochs,
                      trainloader,
                      valloader,
                      cuda,
                      seed,
                      flat_ims=False,
                      nb_its_dev=1,
                      early_stop=None,
                      track_posterior=False,
                      track_exact_ELBO=False,
                      tags=None,
                      load_path=None,
                      save_freq=None):

    rand_name = next(tempfile._get_candidate_names())  # note: private tempfile API
    basedir = os.path.join(save_dir, name, rand_name)

    media_dir = basedir + '/media/'
    models_dir = basedir + '/models/'
    mkdir(models_dir)
    mkdir(media_dir)

    if seed is not None:
        torch.manual_seed(seed)

    epoch = 0

    # we can use this to approximately track the true value by averaging batches
    marginal_loglike_estimate = np.zeros(nb_epochs)
    train_mean_predictive_loglike = np.zeros(nb_epochs)
    dev_mean_predictive_loglike = np.zeros(nb_epochs)
    err_train = np.zeros(nb_epochs)
    err_dev = np.zeros(nb_epochs)

    true_d_posterior = []
    approx_d_posterior = []
    true_likelihood = []
    exact_ELBO = []

    best_epoch = 0
    best_marginal_loglike = -np.inf
    # best_dev_err = -np.inf
    # best_dev_ll = -np.inf

    tic0 = time.time()
    for i in range(epoch, nb_epochs):
        net.set_mode_train(True)
        tic = time.time()
        nb_samples = 0
        for x, y in trainloader:
            if flat_ims:
                x = x.view(x.shape[0], -1)

            marg_loglike_estimate, minus_loglike, err = net.fit(x, y)

            marginal_loglike_estimate[i] += marg_loglike_estimate * x.shape[0]
            err_train[i] += err * x.shape[0]
            train_mean_predictive_loglike[i] += minus_loglike * x.shape[0]
            nb_samples += len(x)

        marginal_loglike_estimate[i] /= nb_samples
        train_mean_predictive_loglike[i] /= nb_samples
        err_train[i] /= nb_samples

        toc = time.time()

        # print('\n depth approx posterior', net.prob_model.current_posterior.data.cpu().numpy())
        print(
            "it %d/%d, ELBO/evidence %.4f, pred minus loglike = %f, err = %f" %
            (i, nb_epochs, marginal_loglike_estimate[i],
             train_mean_predictive_loglike[i], err_train[i]),
            end="")

        cprint('r', '   time: %f seconds\n' % (toc - tic))
        net.update_lr()

        if i % nb_its_dev == 0:
            tic = time.time()
            nb_samples = 0
            for x, y in valloader:
                if flat_ims:
                    x = x.view(x.shape[0], -1)

                minus_loglike, err = net.eval(x, y)

                dev_mean_predictive_loglike[i] += minus_loglike * x.shape[0]
                err_dev[i] += err * x.shape[0]
                nb_samples += len(x)

            dev_mean_predictive_loglike[i] /= nb_samples
            err_dev[i] /= nb_samples
            toc = time.time()

            cprint('g',
                   '     pred minus loglike = %f, err = %f\n' %
                   (dev_mean_predictive_loglike[i], err_dev[i]),
                   end="")
            cprint('g', '    time: %f seconds\n' % (toc - tic))

        if save_freq is not None and i % save_freq == 0:
            net.save(models_dir + '/theta_last.dat')

        if marginal_loglike_estimate[i] > best_marginal_loglike:
            best_marginal_loglike = marginal_loglike_estimate[i]

            # best_dev_ll = dev_mean_predictive_loglike[i]
            # best_dev_err = err_dev[i]
            best_epoch = i
            cprint('b', 'best marginal loglike: %f' % best_marginal_loglike)
            if i % 2 == 0:
                net.save(models_dir + '/theta_best.dat')

        if early_stop is not None and (i - best_epoch) > early_stop:
            cprint('r', '   stopped early!\n')
            break

    toc0 = time.time()
    runtime_per_it = (toc0 - tic0) / float(i + 1)
    cprint('r', '   average time: %f seconds\n' % runtime_per_it)

    # fig cost vs its
    if track_posterior:
        approx_d_posterior = np.stack(approx_d_posterior, axis=0)
        true_d_posterior = np.stack(true_d_posterior, axis=0)
        true_likelihood = np.stack(true_likelihood, axis=0)
    if track_exact_ELBO:
        exact_ELBO = np.stack(exact_ELBO, axis=0)

    return (marginal_loglike_estimate, train_mean_predictive_loglike,
            dev_mean_predictive_loglike, err_train, err_dev,
            approx_d_posterior, true_d_posterior, true_likelihood,
            exact_ELBO, basedir)
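
A caution grounded in the body above: `approx_d_posterior`, `true_d_posterior`, and `true_likelihood` are never appended to, so `track_posterior=True` fails at `np.stack` on empty lists. An illustrative call with placeholder loaders:

results = train_fc_baseline(
    net, name="fc_baseline", save_dir="./saves",
    batch_size=128, nb_epochs=100,
    trainloader=trainloader, valloader=valloader,  # standard DataLoaders
    cuda=True, seed=0,
    flat_ims=True,        # flatten images for a fully connected net
    nb_its_dev=1, early_stop=20,
)
marginal_loglike_estimate = results[0]  # per-epoch ELBO/evidence estimates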
Example No. 9
def train_VI_classification(net,
                            name,
                            save_dir,
                            batch_size,
                            nb_epochs,
                            trainset,
                            valset,
                            cuda,
                            flat_ims=False,
                            nb_its_dev=1,
                            early_stop=None,
                            load_path=None,
                            save_freq=20,
                            stop_criteria='test_ELBO',
                            tags=None,
                            show=False):
    exp = Experiment(name=name, debug=False, save_dir=save_dir, autosave=True)

    if load_path is not None:
        net.load(load_path)

    exp_version = exp.version

    media_dir = exp.get_media_path(name, exp_version)
    models_dir = exp.get_data_path(name, exp_version) + '/models'
    mkdir(models_dir)

    exp.tag({
        'n_layers': net.model.n_layers,
        'batch_size': batch_size,
        'init_lr': net.lr,
        'lr_schedule': net.schedule,
        'nb_epochs': nb_epochs,
        'early_stop': early_stop,
        'stop_criteria': stop_criteria,
        'nb_its_dev': nb_its_dev,
        'model_loaded': load_path,
        'cuda': cuda,
    })

    if net.model.__class__.__name__ == 'arq_uncert_conv2d_resnet':
        exp.tag({
            'outer_width': net.model.outer_width,
            'inner_width': net.model.inner_width
        })
    else:
        exp.tag({'width': net.model.width})

    exp.tag({
        'prob_model': net.model.prob_model.name,
        'prob_model_summary': net.model.prob_model.summary
    })
    if tags is not None:
        exp.tag(tags)

    if cuda:
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  pin_memory=True,
                                                  num_workers=3)
        valloader = torch.utils.data.DataLoader(valset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                pin_memory=True,
                                                num_workers=3)

    else:
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  pin_memory=False,
                                                  num_workers=3)
        valloader = torch.utils.data.DataLoader(valset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                pin_memory=False,
                                                num_workers=3)
    ## ---------------------------------------------------------------------------------------------------------------------
    # net dims
    cprint('c', '\nNetwork:')
    epoch = 0
    ## ---------------------------------------------------------------------------------------------------------------------
    # train
    cprint('c', '\nTrain:')

    print('  init cost variables:')
    mloglike_train = np.zeros(nb_epochs)
    KL_train = np.zeros(nb_epochs)
    ELBO_train = np.zeros(nb_epochs)
    ELBO_test = np.zeros(nb_epochs)
    err_train = np.zeros(nb_epochs)
    mloglike_dev = np.zeros(nb_epochs)
    err_dev = np.zeros(nb_epochs)
    best_epoch = 0
    best_train_ELBO = -np.inf
    best_test_ELBO = -np.inf
    best_dev_ll = -np.inf

    tic0 = time.time()
    for i in range(epoch, nb_epochs):
        net.set_mode_train(True)
        tic = time.time()
        nb_samples = 0
        for x, y in trainloader:

            if flat_ims:
                x = x.view(x.shape[0], -1)

            KL, minus_loglike, err = net.fit(x, y)
            err_train[i] += err
            mloglike_train[i] += minus_loglike / len(trainloader)
            KL_train[i] += KL / len(trainloader)
            nb_samples += len(x)

        # mloglike_train[i] *= nb_samples
        # KL_train[i] *= nb_samples
        ELBO_train[i] = (-KL_train[i] - mloglike_train[i]) * nb_samples
        err_train[i] /= nb_samples

        toc = time.time()

        # ---- print
        print("it %d/%d, sample minus loglike = %f, sample KL = %.10f, err = %f, ELBO = %f" % \
              (i, nb_epochs, mloglike_train[i], KL_train[i], err_train[i], ELBO_train[i]), end="")
        exp.log({
            'epoch': i,
            'MLL': mloglike_train[i],
            'KLD': KL_train[i],
            'err': err_train[i],
            'ELBO': ELBO_train[i]
        })
        cprint('r', '   time: %f seconds\n' % (toc - tic))
        net.update_lr(i, 0.1)

        # ---- dev
        if i % nb_its_dev == 0:
            tic = time.time()
            nb_samples = 0
            for j, (x, y) in enumerate(valloader):
                if flat_ims:
                    x = x.view(x.shape[0], -1)

                minus_loglike, err = net.eval(x, y)

                mloglike_dev[i] += minus_loglike / len(valloader)
                err_dev[i] += err
                nb_samples += len(x)

            ELBO_test[i] = (-KL_train[i] - mloglike_dev[i]) * nb_samples
            err_dev[i] /= nb_samples
            toc = time.time()

            cprint('g',
                   '    sample minus loglike = %f, err = %f, ELBO = %f\n' %
                   (mloglike_dev[i], err_dev[i], ELBO_test[i]),
                   end="")
            cprint(
                'g',
                '    (prev best it = %i, sample minus loglike = %f, ELBO = %f)\n'
                % (best_epoch, best_dev_ll, best_test_ELBO),
                end="")
            cprint('g', '    time: %f seconds\n' % (toc - tic))
            exp.log({
                'epoch': i,
                'MLL_val': mloglike_dev[i],
                'err_val': err_dev[i],
                'ELBO_val': ELBO_test[i]
            })

            if stop_criteria == 'test_LL' and -mloglike_dev[i] > best_dev_ll:
                best_dev_ll = -mloglike_dev[i]
                best_epoch = i
                cprint('b', 'best test loglike: %f' % best_dev_ll)
                net.save(models_dir + '/theta_best.dat')
                probs = net.model.prob_model.get_q_probs().data.cpu().numpy()
                cutoff = np.max(probs) * 0.95
                exp.tag({
                    "q_vec":
                    net.model.get_q_vector().cpu().detach().numpy(),
                    "q_probs":
                    net.model.prob_model.get_q_probs().cpu().detach().numpy(),
                    "expected_depth":
                    np.sum(probs * np.arange(net.model.n_layers + 1)),
                    "95th_depth":
                    np.argmax(probs > cutoff),
                    "best_epoch":
                    best_epoch,
                    "best_dev_ll":
                    best_dev_ll
                })

            if stop_criteria == 'test_ELBO' and ELBO_test[i] > best_test_ELBO:
                best_test_ELBO = ELBO_test[i]
                best_epoch = i
                cprint('b', 'best test ELBO: %f' % best_test_ELBO)
                net.save(models_dir + '/theta_best.dat')
                probs = net.model.prob_model.get_q_probs().data.cpu().numpy()
                cutoff = np.max(probs) * 0.95
                exp.tag({
                    "q_vec":
                    net.model.get_q_vector().cpu().detach().numpy(),
                    "q_probs":
                    net.model.prob_model.get_q_probs().cpu().detach().numpy(),
                    "expected_depth":
                    np.sum(probs * np.arange(net.model.n_layers + 1)),
                    "95th_depth":
                    np.argmax(probs > cutoff),
                    "best_epoch":
                    best_epoch,
                    "best_test_ELBO":
                    best_test_ELBO
                })

        if stop_criteria == 'train_ELBO' and ELBO_train[i] > best_train_ELBO:
            best_train_ELBO = ELBO_train[i]
            best_epoch = i
            cprint('b', 'best train ELBO: %f' % best_train_ELBO)
            net.save(models_dir + '/theta_best.dat')
            probs = net.model.prob_model.get_q_probs().data.cpu().numpy()
            cutoff = np.max(probs) * 0.95
            exp.tag({
                "q_vec":
                net.model.get_q_vector().cpu().detach().numpy(),
                "q_probs":
                net.model.prob_model.get_q_probs().cpu().detach().numpy(),
                "expected_depth":
                np.sum(probs * np.arange(net.model.n_layers + 1)),
                "95th_depth":
                np.argmax(probs > cutoff),
                "best_epoch":
                best_epoch,
                "best_train_ELBO":
                best_train_ELBO
            })

        if save_freq is not None and i % save_freq == 0:
            exp.tag({
                "final_q_vec":
                net.model.get_q_vector().cpu().detach().numpy(),
                "final_q_probs":
                net.model.prob_model.get_q_probs().cpu().detach().numpy(),
                "final_expected_depth":
                np.sum(net.model.prob_model.get_q_probs().data.cpu().numpy() *
                       np.arange(net.model.n_layers + 1))
            })
            net.save(models_dir + '/theta_last.dat')

        if early_stop is not None and (i - best_epoch) > early_stop:
            exp.tag({"early_stop_epoch": i})
            cprint('r', '   stopped early!\n')
            break

    toc0 = time.time()
    runtime_per_it = (toc0 - tic0) / float(i + 1)
    cprint('r', '   average time: %f seconds\n' % runtime_per_it)

    ## ---------------------------------------------------------------------------------------------------------------------
    # fig cost vs its
    textsize = 15
    marker = 5

    fig, ax1 = plt.subplots(dpi=100)
    ax1.plot(range(0, i, nb_its_dev),
             np.clip(mloglike_dev[:i:nb_its_dev], a_min=-5, a_max=5), 'b-')
    ax1.plot(np.clip(mloglike_train[:i], a_min=-5, a_max=5), 'r--')
    ax1.set_ylabel('Cross Entropy')
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    lgd = plt.legend(['test', 'train'],
                     markerscale=marker,
                     prop={
                         'size': textsize,
                         'weight': 'normal'
                     })
    ax = plt.gca()
    plt.title('classification costs')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    plt.savefig(media_dir + '/cost.png',
                bbox_extra_artists=(lgd, ),
                bbox_inches='tight')
    if show:
        plt.show()

    fig, ax1 = plt.subplots(dpi=100)
    ax1.plot(range(0, i), KL_train[:i], 'b-')
    ax1.set_ylabel('KL')
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    lgd = plt.legend(['KL'],
                     markerscale=marker,
                     prop={
                         'size': textsize,
                         'weight': 'normal'
                     })
    ax = plt.gca()
    plt.title('KL divided by number of samples')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    plt.savefig(media_dir + '/KL.png',
                bbox_extra_artists=(lgd, ),
                bbox_inches='tight')
    if show:
        plt.show()

    fig, ax1 = plt.subplots(dpi=100)
    ax1.plot(range(0, i), ELBO_train[:i], 'b-')
    ax1.set_ylabel('nats')
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    lgd = plt.legend(['ELBO'],
                     markerscale=marker,
                     prop={
                         'size': textsize,
                         'weight': 'normal'
                     })
    ax = plt.gca()
    plt.title('ELBO')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    plt.savefig(media_dir + '/ELBO.png',
                bbox_extra_artists=(lgd, ),
                bbox_inches='tight')
    if show:
        plt.show()

    fig, ax2 = plt.subplots(dpi=100)
    ax2.set_ylabel('% error')
    ax2.semilogy(range(0, i, nb_its_dev), err_dev[:i:nb_its_dev], 'b-')
    ax2.semilogy(err_train[:i], 'r--')
    ax2.set_ylim(top=1, bottom=1e-3)
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    ax2.get_yaxis().set_minor_formatter(matplotlib.ticker.ScalarFormatter())
    ax2.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    lgd = plt.legend(['test error', 'train error'],
                     markerscale=marker,
                     prop={
                         'size': textsize,
                         'weight': 'normal'
                     })
    ax = plt.gca()
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    plt.savefig(media_dir + '/err.png',
                bbox_extra_artists=(lgd, ),
                bbox_inches='tight')
    if show:
        plt.show()

    return exp, mloglike_train, KL_train, ELBO_train, err_train, mloglike_dev, err_dev
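
A closing usage sketch; `stop_criteria` must be one of 'train_ELBO', 'test_ELBO', or 'test_LL' as handled above, and `trainset`/`valset` are plain datasets because the function builds its own DataLoaders:

exp, mll_train, kl_train, elbo_train, err_train, mll_dev, err_dev = \
    train_VI_classification(net, name="vi_mnist", save_dir="./saves",
                            batch_size=128, nb_epochs=200,
                            trainset=trainset, valset=valset,
                            cuda=use_cuda, flat_ims=True,
                            stop_criteria='test_ELBO', show=False)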