def run_experiment(settings):
    """Train a mixture of per-class EinsumNetworks with SGD and evaluate it.

    One EinsumNetwork is built and initialised for every class and the
    networks are combined into an ``EinetMixture`` whose mixture weights are
    the relative class frequencies of the training split.  Each network is
    then trained generatively (minimising the summed negative log-likelihood)
    on the training samples of its own class.  Average log-likelihoods and
    classification accuracies on the train / validation / test splits are
    recorded once before training and once after every epoch.

    Args:
        settings: Object exposing the attributes read below
            (``fashion_mnist``, ``svhn``, ``exponential_family``, ``classes``,
            ``K``, ``structure``, ``pd_num_pieces``, ``depth``,
            ``num_repetitions_mixture``, ``width``, ``height``,
            ``num_epochs``, ``batch_size``, ``SGD_learning_rate``).

    Returns:
        dict with the per-evaluation histories ``train_lls``, ``valid_lls``,
        ``test_lls``, ``train_accs``, ``valid_accs``, ``test_accs`` (first
        entry: before training), plus ``network_size`` and ``training_time``
        (seconds spent in the optimisation loops, evaluation excluded).

    Raises:
        AssertionError: If ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        # The two extra values returned by the SVHN loader are not used here.
        train_x, train_labels, test_x, test_labels, _, _ = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves work on rescaled inputs: x / 255 - 0.5.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split: the last 10000 training samples
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes (default: every class present in training)
    if classes is None:
        classes = np.unique(train_labels).tolist()

    def select_classes(x, labels):
        """Keep only samples whose label is in ``classes``.

        Samples and labels are always filtered together so they cannot get
        out of sync (previously the samples were left unfiltered when
        ``settings.classes`` was ``None``).
        """
        keep = np.isin(labels, classes)
        return x[keep, :], [l for l in labels if l in classes]

    train_x, train_labels = select_classes(train_x, train_labels)
    valid_x, valid_labels = select_classes(valid_x, valid_labels)
    test_x, test_labels = select_classes(test_x, test_labels)

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        # A fresh graph is built per class, as in the original code.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Number of training samples of this class (unnormalised prior)
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    total_count = sum(ps)
    ps = [p / total_count for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    ##############
    # Evaluation #
    ##############

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    # The label lists never change in this experiment, so the label tensors
    # are built once instead of once per epoch.
    train_labelsz = torch.tensor(train_labels).to(torch.device(device))
    valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
    test_labelsz = torch.tensor(test_labels).to(torch.device(device))
    splits = (
        (train_x, train_labelsz),
        (valid_x, valid_labelsz),
        (test_x, test_labelsz),
    )

    def evaluate(**eval_kwargs):
        """Return ([train, valid, test] mean LL, [train, valid, test] acc)."""
        mixture.eval()
        lls = [
            mixture.eval_loglikelihood_batched(
                x, batch_size=batch_size, **eval_kwargs) / x.shape[0]
            for x, _ in splits
        ]
        accs = [
            mixture.eval_accuracy_batched(
                classes, x, labels, batch_size=batch_size, **eval_kwargs)
            for x, labels in splits
        ]
        mixture.train()
        return lls, accs

    def record(lls, accs):
        """Append one evaluation to the history lists."""
        for history, value in zip((train_lls, valid_lls, test_lls), lls):
            history.append(value)
        for history, value in zip((train_accs, valid_accs, test_accs), accs):
            history.append(value)

    def report(lls, accs):
        """Print one evaluation in the Experiment 3 / Experiment 4 format."""
        print()
        print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            *lls))
        print()
        print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            *accs))

    # Evaluate after initialization (before any training step)
    lls, accs = evaluate(skip_reparam=True)
    report(lls, accs)
    record(lls, accs)

    ##################
    # Training phase #
    ##################

    # Learning each sub network generatively
    sub_net_parameters = [p for einet in mixture.einets for p in einet.parameters()]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Per-class selection masks are loop invariant; compute them once.
    class_masks = [[l == c for l in train_labels] for c in classes]

    # Accumulated wall-clock time of the optimisation only.  The original
    # code took both timestamps before the loop and always reported ~0s.
    training_time = 0.0

    for epoch_count in range(num_epochs):
        epoch_start = time.time()
        for einet, mask in zip(einets, class_masks):
            train_x_c = train_x[mask]

            idx_batches = torch.randperm(
                train_x_c.shape[0], device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                nll = ll_sample.sum() * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}]   total loss: {total_loss}')
        training_time += time.time() - epoch_start

        lls, accs = evaluate()
        record(lls, accs)

    # Final report: the most recent evaluation.  With num_epochs == 0 this is
    # the pre-training evaluation instead of a NameError.
    report(lls, accs)

    print(f'Network size: {num_params} parameters')
    print(f'Training time: {training_time}s')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'network_size': num_params,
        'training_time': training_time,
    }
Beispiel #2
0
def run_experiment(settings):
    """Train per-class EinsumNetworks with SGD, then adapt them online.

    The training data is split into a training part, an "online" part of
    1000 samples and a validation part of 10000 samples.  One EinsumNetwork
    per class is trained generatively with SGD on the training part.  After
    that two evaluation series are recorded:

    * the ``*_ref`` series: the trained mixture is left untouched while the
      online samples are added, batch by batch, to the reference training
      set it is evaluated on;
    * the regular series: the same batches are fed to ``online_update`` for
      the network of the matching class and the mixture is re-evaluated
      after every batch.

    Args:
        settings: Object exposing the attributes read below
            (``fashion_mnist``, ``svhn``, ``exponential_family``, ``classes``,
            ``K``, ``structure``, ``pd_num_pieces``, ``depth``,
            ``num_repetitions_mixture``, ``width``, ``height``,
            ``num_epochs``, ``batch_size``, ``SGD_learning_rate``).

    Returns:
        dict with the evaluation histories (``train_lls`` ... ``test_accs``
        and their ``*_ref`` counterparts), ``network_size`` and
        ``online_samples`` (cumulative number of online samples after each
        reference evaluation, starting at 0).

    Raises:
        AssertionError: If ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        # The two extra values returned by the SVHN loader are not used here.
        train_x, train_labels, test_x, test_labels, _, _ = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves work on rescaled inputs: x / 255 - 0.5.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation / online / training split (counted from the end of the data)
    valid_x = train_x[-10000:, :]
    online_x = train_x[-11000:-10000, :]
    train_x = train_x[-56000:-11000, :]

    valid_labels = train_labels[-10000:]
    online_labels = train_labels[-11000:-10000]
    train_labels = train_labels[-56000:-11000]

    # pick the selected classes (default: every class present in training)
    if classes is None:
        classes = np.unique(train_labels).tolist()

    def select_classes(x, labels):
        """Keep only samples whose label is in ``classes``.

        Samples and labels are always filtered together so they cannot get
        out of sync (previously the samples were left unfiltered when
        ``settings.classes`` was ``None``).
        """
        keep = np.isin(labels, classes)
        return x[keep, :], [l for l in labels if l in classes]

    train_x, train_labels = select_classes(train_x, train_labels)
    online_x, online_labels = select_classes(online_x, online_labels)
    valid_x, valid_labels = select_classes(valid_x, valid_labels)
    test_x, test_labels = select_classes(test_x, test_labels)

    # Independent copy: both label lists are extended in place later on.  The
    # original ``train_labels_backup = train_labels`` aliased a single list,
    # so every extension hit both names and the labels no longer matched the
    # corresponding sample tensors.
    train_labels_backup = list(train_labels)

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    train_x_backup = train_x
    online_x = torch.from_numpy(online_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        # A fresh graph is built per class, as in the original code.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Number of training samples of this class (unnormalised prior)
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    total_count = sum(ps)
    ps = [p / total_count for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    ##################
    # Training phase #
    ##################

    sub_net_parameters = [p for einet in mixture.einets for p in einet.parameters()]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Learning each sub network generatively: all epochs for one class, then
    # the next class.
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        train_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                nll = ll_sample.sum() * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            # The accumulated value is the *negative* log-likelihood.
            print(f'[{epoch_count}]   total negative log-likelihood: {total_loss}')

    ###########################
    # Evaluate after training #
    ###########################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    train_lls_ref = []
    valid_lls_ref = []
    test_lls_ref = []
    train_accs_ref = []
    valid_accs_ref = []
    test_accs_ref = []
    added_samples = [0]

    def eval_network(do_print=False, no_OA=False):
        """Evaluate the mixture and append the results to the histories.

        With ``no_OA=True`` the reference training set
        (``train_x_backup`` / ``train_labels_backup``) is used and results go
        to the ``*_ref`` lists; otherwise the online-adapted training set
        (``train_x`` / ``train_labels``) and the regular lists are used.
        """
        if no_OA:
            cur_train_x, cur_train_labels = train_x_backup, train_labels_backup
            ll_logs = (train_lls_ref, valid_lls_ref, test_lls_ref)
            acc_logs = (train_accs_ref, valid_accs_ref, test_accs_ref)
        else:
            cur_train_x, cur_train_labels = train_x, train_labels
            ll_logs = (train_lls, valid_lls, test_lls)
            acc_logs = (train_accs, valid_accs, test_accs)

        mixture.eval()
        lls = [
            mixture.eval_loglikelihood_batched(x, batch_size=batch_size)
            / x.shape[0]
            for x in (cur_train_x, valid_x, test_x)
        ]
        if do_print:
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(*lls))
        for log, value in zip(ll_logs, lls):
            log.append(value)

        ################
        # Experiment 4 #
        ################
        labelled_splits = (
            (cur_train_x, cur_train_labels),
            (valid_x, valid_labels),
            (test_x, test_labels),
        )
        accs = [
            mixture.eval_accuracy_batched(
                classes,
                x,
                torch.tensor(labels).to(torch.device(device)),
                batch_size=batch_size)
            for x, labels in labelled_splits
        ]
        if do_print:
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(*accs))
        for log, value in zip(acc_logs, accs):
            log.append(value)
        mixture.train()

    eval_network(do_print=True, no_OA=False)
    eval_network(do_print=False, no_OA=True)

    #####################################################
    # Evaluate the network with different training sets #
    #####################################################

    # The same batches of 20 online samples are used for both series.
    idx_batches = torch.randperm(online_x.shape[0], device=device).split(20)

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        # Samples are appended class by class, so data and labels stay aligned.
        for c in classes:
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            train_x_backup = torch.cat((train_x_backup, batch_x))
            train_labels_backup += [c] * len(batch_x)

        added_samples.append(added_samples[-1] + len(idx))
        eval_network(do_print=False, no_OA=True)

    #####################
    # Online adaptation #
    #####################

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            online_update(einet, batch_x)
            train_x = torch.cat((train_x, batch_x))
            train_labels += [c] * len(batch_x)

        eval_network(do_print=False, no_OA=False)

    print()
    print(f'Network size: {num_params} parameters')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'train_lls_ref': train_lls_ref,
        'valid_lls_ref': valid_lls_ref,
        'test_lls_ref': test_lls_ref,
        'train_accs_ref': train_accs_ref,
        'valid_accs_ref': valid_accs_ref,
        'test_accs_ref': test_accs_ref,
        'network_size': num_params,
        'online_samples': added_samples,
    }
Beispiel #3
0
    def new_start(start_train_set, online_offset):
        """Build and initialise a fresh mixture of per-class EinsumNetworks and evaluate it.

        The training slice is ``train_x[:-(10000 + online_offset - start_train_set)]``,
        so both arguments together determine how many samples are cut off the
        end of the loaded training data; the last 10000 loaded samples form
        the validation split.

        ``settings`` and the result lists ``train_lls``, ``valid_lls``,
        ``test_lls``, ``train_accs``, ``valid_accs`` and ``test_accs`` are not
        defined in this function; they are taken from the enclosing scope.
        One log-likelihood and one accuracy entry per split is appended to
        those lists.  No training step is performed in the visible code and
        nothing is returned.

        Args:
            start_train_set: Added back to the number of samples kept for
                training (see the slice above).
            online_offset: Number of additional samples, on top of the 10000
                validation samples, removed from the end of the training data.

        Raises:
            AssertionError: If ``settings.structure`` is neither
                ``'poon-domingos'`` nor ``'binary-trees'``.
        """
        ############################################################################
        fashion_mnist = settings.fashion_mnist
        svhn = settings.svhn

        exponential_family = settings.exponential_family

        classes = settings.classes

        K = settings.K

        structure =  settings.structure

        # 'poon-domingos'
        pd_num_pieces = settings.pd_num_pieces

        # 'binary-trees'
        depth = settings.depth
        num_repetitions = settings.num_repetitions_mixture

        width = settings.width
        height = settings.height

        # NOTE(review): num_epochs is read but not used in the visible part of this function.
        num_epochs = settings.num_epochs
        batch_size = settings.batch_size
        online_em_frequency = settings.online_em_frequency
        online_em_stepsize = settings.online_em_stepsize

        ############################################################################

        # Leaf-distribution specific arguments; stays None for any other family.
        exponential_family_args = None
        if exponential_family == EinsumNetwork.BinomialArray:
            exponential_family_args = {'N': 255}
        if exponential_family == EinsumNetwork.CategoricalArray:
            exponential_family_args = {'K': 256}
        if exponential_family == EinsumNetwork.NormalArray:
            exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

        # get data
        if fashion_mnist:
            train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
        elif svhn:
            train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
        else:
            train_x, train_labels, test_x, test_labels = datasets.load_mnist()

        # Double negation: this branch runs exactly when the family IS NormalArray.
        # The data is then rescaled in place to x / 255 - 0.5.
        if not exponential_family != EinsumNetwork.NormalArray:
            train_x /= 255.
            test_x /= 255.
            train_x -= .5
            test_x -= .5

        # validation split
        valid_x = train_x[-10000:, :]
        # online_x = train_x[-40000:, :]
        train_x = train_x[:-(10000+online_offset-start_train_set), :]
        valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-40000:]
        train_labels = train_labels[:-(10000+online_offset-start_train_set)]
        
        # # debug setup
        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-45000:, :]
        # train_x = train_x[:-55000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-45000:]
        # train_labels = train_labels[:-55000]

        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-10000:, :]
        # train_x = train_x[:-20000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-10000:]
        # train_labels = train_labels[:-20000]

        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-20000:, :]
        # train_x = train_x[:-30000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-20000:]
        # train_labels = train_labels[:-30000]

        # pick the selected classes
        if classes is not None:
            # Boolean row mask: True where the label equals any requested class.
            train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
            # online_x = online_x[np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
            valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
            test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

            train_labels = [l for l in train_labels if l in classes]
            # online_labels = [l for l in online_labels if l in classes]
            valid_labels = [l for l in valid_labels if l in classes]
            test_labels = [l for l in test_labels if l in classes]
        else:
            # No explicit selection: use every class present in the training labels.
            classes = np.unique(train_labels).tolist()

            # Only the label sequences are filtered here (and turned into lists);
            # the sample arrays are left as they are in this branch.
            train_labels = [l for l in train_labels if l in classes]
            # online_labels = [l for l in online_labels if l in classes]
            valid_labels = [l for l in valid_labels if l in classes]
            test_labels = [l for l in test_labels if l in classes]

        train_x = torch.from_numpy(train_x).to(torch.device(device))
        # online_x = torch.from_numpy(online_x).to(torch.device(device))
        valid_x = torch.from_numpy(valid_x).to(torch.device(device))
        test_x = torch.from_numpy(test_x).to(torch.device(device))

        ######################################
        # Make EinsumNetworks for each class #
        ######################################
        einets = []
        ps = []
        for c in classes:
            # A new graph is constructed for every class.
            if structure == 'poon-domingos':
                pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            elif structure == 'binary-trees':
                graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
            else:
                raise AssertionError("Unknown Structure")

            args = EinsumNetwork.Args(
                    num_var=train_x.shape[1],
                    num_dims=3 if svhn else 1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)

            init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
            einet.initialize(init_dict)
            einet.to(device)
            einets.append(einet)

            # Calculate amount of training samples per class
            ps.append(train_labels.count(c))

            print(f'Einsum network for class {c}:')
            print(einet)

        # normalize ps, construct mixture component
        ps = [p / sum(ps) for p in ps]
        ps = torch.tensor(ps).to(torch.device(device))
        mixture = EinetMixture(ps, einets, classes=classes)

        num_params = mixture.eval_size()

        #################################
        # Evaluate after initialization #
        #################################

        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        mixture.eval()
        train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        print()
        print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
                train_ll_before / train_N,
                valid_ll_before / valid_N,
                test_ll_before / test_N))
        # The result lists live in the enclosing scope; values are per-sample averages.
        train_lls.append(train_ll_before / train_N)
        valid_lls.append(valid_ll_before / valid_N)
        test_lls.append(test_ll_before / test_N)

        ################
        # Experiment 4 #
        ################
        train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size)
        print()
        print("Experiment 8: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
                acc_train_before,
                acc_valid_before,
                acc_test_before))
        train_accs.append(acc_train_before)
        valid_accs.append(acc_valid_before)
        test_accs.append(acc_test_before)
Beispiel #4
0
    def construct_network_and_train(
            num_starts, svhn, exponential_family, classes, K, structure,
            pd_num_pieces, depth, num_repetitions, width, height, num_epochs,
            batch_size, online_em_frequency, online_em_stepsize,
            exponential_family_args, train_x, train_labels, test_x,
            test_labels, valid_x, valid_labels):
        """Build, train and evaluate a per-class EinetMixture several times.

        For each of the ``num_starts`` restarts, one EinsumNetwork is created
        per entry of ``classes``, initialised through ``get_init_dict``,
        and combined into an ``EinetMixture`` whose component weights are the
        relative class frequencies in ``train_labels``. Every sub-network is
        then trained on the samples of its own class with the networks' EM
        routines, and the mixture is evaluated on all three splits.

        Args:
            num_starts: number of independent restarts.
            svhn: if truthy, networks use 3 dimensions per variable, else 1.
            exponential_family: leaf distribution class passed to the networks.
            classes: iterable of class labels; one sub-network per label.
            K: used for both ``num_sums`` and ``num_input_distributions``.
            structure: ``'poon-domingos'`` or ``'binary-trees'``.
            pd_num_pieces: split counts used to derive the Poon-Domingos deltas.
            depth: depth of the random binary trees.
            num_repetitions: number of random binary tree repetitions.
            width: image width (Poon-Domingos structure only).
            height: image height (Poon-Domingos structure only).
            num_epochs: training epochs per sub-network.
            batch_size: mini-batch size for training and evaluation.
            online_em_frequency: forwarded to ``EinsumNetwork.Args``.
            online_em_stepsize: forwarded to ``EinsumNetwork.Args``.
            exponential_family_args: extra arguments for the leaf distribution.
            train_x: training samples, indexed by row; ``shape[1]`` is used as
                the number of variables.
            train_labels: training labels; must support ``.count`` (a list).
            test_x: test samples.
            test_labels: test labels.
            valid_x: validation samples.
            valid_labels: validation labels.

        Returns:
            dict with per-restart lists under ``'train_ll'``, ``'valid_ll'``,
            ``'test_ll'`` (log-likelihood divided by the split size),
            ``'train_acc'``, ``'valid_acc'``, ``'test_acc'`` and
            ``'training_time'`` (seconds), plus ``'network_size'`` (parameter
            count of the first mixture built, or ``None`` if no restart ran).

        Raises:
            AssertionError: if ``structure`` is not a known structure name.
        """
        train_lls = []
        valid_lls = []
        test_lls = []
        train_accs = []
        valid_accs = []
        test_accs = []
        training_times = []
        num_params = None

        # Split sizes and label tensors are identical for every restart, so
        # they are computed once up front.
        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        train_labels_tensor = torch.tensor(train_labels).to(
            torch.device(device))
        valid_labels_tensor = torch.tensor(valid_labels).to(
            torch.device(device))
        test_labels_tensor = torch.tensor(test_labels).to(
            torch.device(device))

        for s in range(num_starts):
            print(f"""
                Running start: {s}/{num_starts}
                """)

            ######################################
            # Make EinsumNetworks for each class #
            ######################################
            einets = []
            ps = []
            for c in classes:
                if structure == 'poon-domingos':
                    pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                    graph = Graph.poon_domingos_structure(shape=(height,
                                                                 width),
                                                          delta=pd_delta)
                elif structure == 'binary-trees':
                    graph = Graph.random_binary_trees(
                        num_var=train_x.shape[1],
                        depth=depth,
                        num_repetitions=num_repetitions)
                else:
                    raise AssertionError("Unknown Structure")

                args = EinsumNetwork.Args(
                    num_var=train_x.shape[1],
                    num_dims=3 if svhn else 1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

                einet = EinsumNetwork.EinsumNetwork(graph, args)

                init_dict = get_init_dict(einet,
                                          train_x,
                                          train_labels=train_labels,
                                          einet_class=c)
                einet.initialize(init_dict)
                einet.to(device)
                einets.append(einet)

                # Number of training samples of this class (unnormalised
                # mixture weight).
                ps.append(train_labels.count(c))

            # Normalise the class counts and build the mixture.
            ps = [p / sum(ps) for p in ps]
            ps = torch.tensor(ps).to(torch.device(device))
            mixture = EinetMixture(ps, einets, classes=classes)

            # The size is only evaluated for the first mixture built.
            if num_params is None:
                num_params = mixture.eval_size()

            ##################
            # Training phase #
            ##################
            # Each sub-network is trained generatively on its own class.
            start_time = time.time()

            for einet, c in zip(einets, classes):
                train_x_c = train_x[[l == c for l in train_labels]]
                num_train_c = train_x_c.shape[0]

                for epoch_count in range(num_epochs):
                    idx_batches = torch.randperm(
                        num_train_c, device=device).split(batch_size)

                    total_ll = 0.0
                    for idx in idx_batches:
                        batch_x = train_x_c[idx, :]
                        outputs = einet.forward(batch_x)
                        ll_sample = EinsumNetwork.log_likelihoods(outputs)
                        log_likelihood = ll_sample.sum()
                        log_likelihood.backward()

                        einet.em_process_batch()
                        total_ll += log_likelihood.detach().item()
                    einet.em_update()
                    print(
                        f'[{epoch_count}]   total log-likelihood: {total_ll}')

            end_time = time.time()

            ################
            # Experiment 3 #
            ################
            mixture.eval()
            train_ll = mixture.eval_loglikelihood_batched(
                train_x, batch_size=batch_size)
            valid_ll = mixture.eval_loglikelihood_batched(
                valid_x, batch_size=batch_size)
            test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                         batch_size=batch_size)
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(train_ll / train_N, valid_ll / valid_N,
                        test_ll / test_N))

            ################
            # Experiment 4 #
            ################
            acc_train = mixture.eval_accuracy_batched(classes,
                                                      train_x,
                                                      train_labels_tensor,
                                                      batch_size=batch_size)
            acc_valid = mixture.eval_accuracy_batched(classes,
                                                      valid_x,
                                                      valid_labels_tensor,
                                                      batch_size=batch_size)
            acc_test = mixture.eval_accuracy_batched(classes,
                                                     test_x,
                                                     test_labels_tensor,
                                                     batch_size=batch_size)
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(acc_train, acc_valid, acc_test))

            print()
            print(f'Network size: {num_params} parameters')
            print(f'Training time: {end_time - start_time}s')

            train_lls.append(train_ll / train_N)
            valid_lls.append(valid_ll / valid_N)
            test_lls.append(test_ll / test_N)
            train_accs.append(acc_train)
            valid_accs.append(acc_valid)
            test_accs.append(acc_test)
            training_times.append(end_time - start_time)

        return {
            'train_ll': train_lls,
            'valid_ll': valid_lls,
            'test_ll': test_lls,
            'train_acc': train_accs,
            'valid_acc': valid_accs,
            'test_acc': test_accs,
            'network_size': num_params,
            'training_time': training_times,
        }
Beispiel #5
0
            # NOTE(review): this excerpt starts in the middle of nested loops
            # whose headers are not visible here; indentation is kept as found.
            init_dict[layer] = custom_initializer(layer, train_x)

    # Apply the collected initial parameters, move the network to `device`
    # and register it as one mixture component.
    einet.initialize(init_dict)
    einet.to(device)
    einets.append(einet)

    # Calculate amount of training samples per class
    ps.append(train_labels.count(c))

    print(f'Einsum network for class {c}:')
    print(einet)

# normalize ps, construct mixture component
ps = [p / sum(ps) for p in ps]
ps = torch.tensor(ps).to(torch.device(device))
mixture = EinetMixture(ps, einets, classes=classes)

##################
# Training phase #
##################

if train_discriminatively:
    """ Train the EinetMixture discriminatively """

    # Collect the parameters of every sub-network followed by the mixture's
    # own parameters into a single list.
    # NOTE(review): the consumer of this list (presumably an optimizer) is
    # not part of this excerpt.
    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())
print("done")

# Load the k-means result; only the second element (the cluster index of each
# sample) is used here.
# NOTE: pickle.load can execute arbitrary code from the file -- only load
# artifacts produced by a trusted run.
with open('../auxiliary/svhn/kmeans_{}.pkl'.format(num_clusters),
          'rb') as kmeans_file:
    _, cluster_idx = pickle.load(kmeans_file)

# make a mixture of EiNets
# Mixture weights: histogram of the cluster indices over `num_clusters` bins,
# normalised to sum to one.
p = np.histogram(cluster_idx, num_clusters)[0].astype(np.float32)
p = p / p.sum()

einets = []
for k in range(num_clusters):
    print("Load model for cluster {}".format(k))
    model_file = os.path.join(einet_path, 'cluster_{}'.format(k), 'einet.mdl')
    # NOTE: torch.load unpickles the file as well; the same trust caveat
    # applies.
    einets.append(torch.load(model_file).to(device))

mixture = EinetMixture(p, einets)

# Draw L*L samples and save them as an L x L image grid.
L = 7
samples = mixture.sample(L**2, std_correction=0.0)
utils.save_image_stack(samples.reshape(-1, height, width, 3),
                       L,
                       L,
                       os.path.join(sample_path, 'einet_samples.png'),
                       margin=2,
                       margin_gray_val=0.,
                       frame=2,
                       frame_gray_val=0.0)
print("Saved samples to {}".format(
    os.path.join(sample_path, 'einet_samples.png')))

num_reconstructions = 10
def run_experiment(settings):
    """Train and evaluate a per-class EinetMixture on the Iris data set.

    All hyper-parameters are hard-coded in the body; ``settings`` is accepted
    but never read. One EinsumNetwork per class is built, trained on the
    samples of its class with the networks' EM routines, and the resulting
    mixture is evaluated on the training data.

    Args:
        settings: unused.

    Returns:
        dict with ``'train_ll'`` (training log-likelihood divided by the
        number of training samples), ``'train_acc'``, ``'network_size'``
        (parameter count of the mixture) and ``'training_time'`` (seconds
        spent in generative training).

    Raises:
        AssertionError: if ``structure`` is not a known structure name.
    """
    ############################################################################
    exponential_family = EinsumNetwork.BinomialArray

    K = 1

    # structure = 'poon-domingos'
    structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [2]

    # 'binary-trees'
    depth = 1
    num_repetitions = 2

    width = 4

    num_epochs = 2
    batch_size = 3
    online_em_frequency = 1
    online_em_stepsize = 0.05
    SGD_learning_rate = 0.05

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 80}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    iris = datasets.load_iris()
    # NOTE(review): features are scaled by 10, presumably so they fit the
    # integer range of the Binomial leaves (N = 80) -- confirm.
    train_x = iris.data * 10
    train_labels = iris.target

    # self generated data
    # train_x = np.array([np.array([1, 1, 1, 1]) for i in range(50)] + [np.array([3, 3, 3, 3]) for i in range(50)] + [np.array([8, 8, 8, 8]) for i in range(50)])
    # train_labels = np.array([0 for i in range(50)] + [1 for i in range(50)] + [2 for i in range(50)])

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        train_x -= .5

    classes = np.unique(train_labels).tolist()

    train_x = torch.from_numpy(train_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################

    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[width / d] for d in pd_num_pieces]
            # `(width)` is just an int; the shape must be a 1-tuple.
            graph = Graph.poon_domingos_structure(shape=(width,),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                              depth=depth,
                                              num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        print(einet)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)

        einet = einet.float()
        einets.append(einet)

        # Number of training samples of this class (unnormalised weight).
        ps.append(np.count_nonzero(train_labels == c))

    # Normalise the class counts and build the mixture.
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    # Train
    ######################################
    # Generative training: each sub-network only sees its own class.

    start_time = time.time()

    for einet, c in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        num_train_c = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(num_train_c,
                                         device=device).split(batch_size)

            total_ll = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :].float()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                log_likelihood.backward()

                einet.em_process_batch()
                total_ll += log_likelihood.detach().item()
            einet.em_update()
            print(
                f'[{epoch_count}]   total log-likelihood: {total_ll/num_train_c}')

    end_time = time.time()

    def discriminative_learning():
        """Train the whole mixture with SGD on a cross-entropy loss.

        Defined for experimentation; it is not invoked by this function.
        """
        sub_net_parameters = [
            param for einet in mixture.einets for param in einet.parameters()
        ]
        sub_net_parameters += list(mixture.parameters())

        optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)
        loss_function = torch.nn.CrossEntropyLoss()

        num_samples = train_x.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(num_samples,
                                         device=device).split(batch_size)

            total_loss = 0
            for idx in idx_batches:
                batch_x = train_x[idx, :].float()
                optimizer.zero_grad()
                outputs = mixture.forward(batch_x)
                # Targets are positions in `classes`, not raw label values.
                target = torch.tensor([
                    classes.index(train_labels[i]) for i in idx
                ]).to(torch.device(device))
                loss = loss_function(outputs, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.detach().item()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    ################
    # Experiment 1 #
    ################
    # Normalise by the size of the full training set; the per-class count
    # used during training must not leak into the reported average.
    train_N = train_x.shape[0]
    mixture.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(mixture,
                                                        train_x.float(),
                                                        batch_size=1)
    print()
    print("Experiment 1: Log-likelihoods  --- train LL {}".format(train_ll /
                                                                  train_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(mixture,
                                                    classes,
                                                    train_x.float(),
                                                    train_labels,
                                                    batch_size=1)
    print()
    print("Experiment 2: Classification accuracies  --- train acc {}".format(
        acc_train))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'train_acc': acc_train,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
def run_experiment(settings):
    """Train a per-class EinetMixture with SGD and evaluate it on all splits.

    The data set is Fashion-MNIST, SVHN or MNIST depending on ``settings``.
    The last 10000 training samples are held out as validation data. One
    EinsumNetwork per class is built (with ``use_em=False``) and trained on
    the samples of its class by minimising the negative log-likelihood with
    SGD. The mixture is then evaluated for log-likelihood and classification
    accuracy on the train, validation and test splits.

    Args:
        settings: object providing the attributes ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes`` (or ``None`` for all classes
            found in the training labels), ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions_mixture``,
            ``width``, ``height``, ``num_epochs``, ``batch_size``,
            ``online_em_frequency``, ``online_em_stepsize`` and
            ``SGD_learning_rate``.

    Returns:
        dict with ``'train_ll'``, ``'valid_ll'``, ``'test_ll'`` (each
        log-likelihood divided by the split size), ``'train_acc'``,
        ``'valid_acc'``, ``'test_acc'``, ``'network_size'`` (parameter count
        of the mixture) and ``'training_time'`` (seconds).

    Raises:
        AssertionError: if ``settings.structure`` is not a known structure.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist(
        )
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn(
        )
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Normal leaves work on data rescaled by 1/255 and shifted by -0.5.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # valid_x = train_x[-10000:, :]
    # train_x = train_x[-12000:-11000, :]

    # valid_labels = train_labels[-10000:]
    # train_labels = train_labels[-12000:-11000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to `classes`; `.count` is used on
    # them further down.
    train_labels = [label for label in train_labels if label in classes]
    valid_labels = [label for label in valid_labels if label in classes]
    test_labels = [label for label in test_labels if label in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()
    # Code for weight analysis, section 7.3
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    # weights = {}
    # for (einet, c) in zip(einets, classes):
    #     einet_weights = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             einet_weights[i] = l.reparam(l.params.data)
    #         else:
    #             einet_weights[i] = l.ef_array.reparam(l.ef_array.params.data)
    #     weights[c] = einet_weights

    ##################
    # Training phase #
    ##################

    # A single optimizer holds the parameters of every sub-network followed
    # by the mixture's own parameters.
    sub_net_parameters = [
        param for einet in mixture.einets for param in einet.parameters()
    ]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Learn each sub-network generatively on the samples of its own class.
    start_time = time.time()

    for einet, c in zip(einets, classes):
        train_x_c = train_x[[label == c for label in train_labels]]
        num_train_c = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(num_train_c,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                # log_likelihood.backward()
                # SGD minimises, so the negative log-likelihood is the loss.
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            #     einet.em_process_batch()
            # einet.em_update()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()
    # Code for weight analysis
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # percentual_change = []
    # percentual_change_formatted = []
    # data_file = os.path.join(data_dir, f"percentual_change.json")
    # for (einet, c) in zip(einets, classes):
    #     einet_change = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             weights_new = l.reparam(l.params.data)
    #         else:
    #             weights_new = l.ef_array.reparam(l.ef_array.params.data)
    #         change = torch.mean(torch.abs(weights[c][i] - weights_new)/weights[c][i]).item()
    #         einet_change[i] = change
    #         if i == 0:
    #             percentual_change_formatted.append(change)

    #     percentual_change.append(einet_change)

    # print(percentual_change)
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change, f)
    # data_file = os.path.join(data_dir, f"percentual_change_formatted.json")
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change_formatted, f)

    ################
    # Experiment 3 #
    ################
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll = mixture.eval_loglikelihood_batched(train_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    valid_ll = mixture.eval_loglikelihood_batched(valid_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                 batch_size=batch_size,
                                                 skip_reparam=True)
    print()
    print(
        "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 4 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = mixture.eval_accuracy_batched(classes,
                                              train_x,
                                              train_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_valid = mixture.eval_accuracy_batched(classes,
                                              valid_x,
                                              valid_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_test = mixture.eval_accuracy_batched(classes,
                                             test_x,
                                             test_labels,
                                             batch_size=batch_size,
                                             skip_reparam=True)
    print()
    print(
        "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }