예제 #1
0
    def new_start(start_train_set, online_offset):
        """Build a freshly initialized per-class einet mixture and log its metrics.

        Configuration is read from the enclosing-scope ``settings`` object and
        tensors are placed on the enclosing-scope ``device``.  The function
        loads the selected dataset, uses the last 10000 training rows as the
        validation split, builds and initializes one EinsumNetwork per class,
        combines them into an ``EinetMixture`` weighted by the class
        frequencies, and appends the mean log-likelihoods and the
        classification accuracies of the train/valid/test splits to the
        enclosing-scope lists ``train_lls``, ``valid_lls``, ``test_lls``,
        ``train_accs``, ``valid_accs`` and ``test_accs``.

        Args:
            start_train_set: Number of rows added back to the training split.
            online_offset: Number of rows, on top of the 10000 validation
                rows, that are cut off the end of the training split.

        The training split is ``train_x[:-(10000 + online_offset - start_train_set)]``.
        NOTE(review): if that expression evaluates to 0 the slice is empty -
        confirm callers never pass such a combination.
        """
        ############################################################################
        fashion_mnist = settings.fashion_mnist
        svhn = settings.svhn

        exponential_family = settings.exponential_family

        classes = settings.classes

        K = settings.K

        structure = settings.structure

        # 'poon-domingos'
        pd_num_pieces = settings.pd_num_pieces

        # 'binary-trees'
        depth = settings.depth
        num_repetitions = settings.num_repetitions_mixture

        width = settings.width
        height = settings.height

        num_epochs = settings.num_epochs
        batch_size = settings.batch_size
        online_em_frequency = settings.online_em_frequency
        online_em_stepsize = settings.online_em_stepsize

        ############################################################################

        # Extra arguments required by the selected leaf distribution.
        exponential_family_args = None
        if exponential_family == EinsumNetwork.BinomialArray:
            exponential_family_args = {'N': 255}
        if exponential_family == EinsumNetwork.CategoricalArray:
            exponential_family_args = {'K': 256}
        if exponential_family == EinsumNetwork.NormalArray:
            exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

        # get data
        if fashion_mnist:
            train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
        elif svhn:
            train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
        else:
            train_x, train_labels, test_x, test_labels = datasets.load_mnist()

        # Gaussian leaves work on rescaled, centered pixels
        # (presumably raw values are in [0, 255] - verify against the loaders).
        if exponential_family == EinsumNetwork.NormalArray:
            train_x /= 255.
            test_x /= 255.
            train_x -= .5
            test_x -= .5

        # validation split: the validation rows are always the last 10000; the
        # training split additionally drops `online_offset - start_train_set` rows.
        held_out = 10000 + online_offset - start_train_set
        valid_x = train_x[-10000:, :]
        train_x = train_x[:-held_out, :]
        valid_labels = train_labels[-10000:]
        train_labels = train_labels[:-held_out]

        # pick the selected classes
        if classes is not None:
            def _selected(labels):
                # True for every sample whose label is one of the requested classes.
                return np.any(np.stack([labels == c for c in classes], 1), 1)

            train_x = train_x[_selected(train_labels), :]
            valid_x = valid_x[_selected(valid_labels), :]
            test_x = test_x[_selected(test_labels), :]
        else:
            classes = np.unique(train_labels).tolist()

        # Labels become plain lists restricted to the classes in use, so they
        # stay aligned with the (possibly filtered) sample arrays.
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

        train_x = torch.from_numpy(train_x).to(torch.device(device))
        valid_x = torch.from_numpy(valid_x).to(torch.device(device))
        test_x = torch.from_numpy(test_x).to(torch.device(device))

        ######################################
        # Make EinsumNetworks for each class #
        ######################################
        einets = []
        ps = []
        for c in classes:
            # The structure is rebuilt for every class on purpose: the
            # binary-tree variant is random, so each class gets its own graph.
            if structure == 'poon-domingos':
                pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            elif structure == 'binary-trees':
                graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
            else:
                raise AssertionError("Unknown Structure")

            args = EinsumNetwork.Args(
                    num_var=train_x.shape[1],
                    num_dims=3 if svhn else 1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)

            init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
            einet.initialize(init_dict)
            einet.to(device)
            einets.append(einet)

            # Calculate amount of training samples per class
            ps.append(train_labels.count(c))

            print(f'Einsum network for class {c}:')
            print(einet)

        # normalize ps, construct mixture component
        ps = [p / sum(ps) for p in ps]
        ps = torch.tensor(ps).to(torch.device(device))
        mixture = EinetMixture(ps, einets, classes=classes)

        num_params = mixture.eval_size()

        #################################
        # Evaluate after initialization #
        #################################

        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        mixture.eval()
        train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        print()
        print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
                train_ll_before / train_N,
                valid_ll_before / valid_N,
                test_ll_before / test_N))
        train_lls.append(train_ll_before / train_N)
        valid_lls.append(valid_ll_before / valid_N)
        test_lls.append(test_ll_before / test_N)

        ################
        # Experiment 4 #
        ################
        train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size)
        print()
        # The label used to read "Experiment 8", contradicting the section
        # banner above; it now says "Experiment 4".
        print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
                acc_train_before,
                acc_valid_before,
                acc_test_before))
        train_accs.append(acc_train_before)
        valid_accs.append(acc_valid_before)
        test_accs.append(acc_test_before)
예제 #2
0
def run_experiment(settings):
    """Train a per-class EinsumNetwork mixture with SGD and collect its metrics.

    The selected dataset is loaded, the last 10000 training rows become the
    validation split, one EinsumNetwork per class is built and initialized,
    and the networks are combined into an ``EinetMixture`` weighted by the
    class frequencies.  Each network is then trained generatively (minimizing
    the summed negative log-likelihood) on the training samples of its own
    class.  The mixture is evaluated once before training and once after
    every epoch.

    Args:
        settings: Configuration object; the attributes read are
            ``fashion_mnist``, ``svhn``, ``exponential_family``, ``classes``,
            ``K``, ``structure``, ``pd_num_pieces``, ``depth``,
            ``num_repetitions_mixture``, ``width``, ``height``,
            ``num_epochs``, ``batch_size`` and ``SGD_learning_rate``.

    Returns:
        dict with the metric histories ``train_lls``, ``valid_lls``,
        ``test_lls``, ``train_accs``, ``valid_accs``, ``test_accs`` (one entry
        for the initialized model plus one per epoch), ``network_size`` and
        ``training_time`` in seconds (the per-epoch evaluation is included).

    Raises:
        AssertionError: If ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    # Extra arguments required by the selected leaf distribution.
    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves work on rescaled, centered pixels
    # (presumably raw values are in [0, 255] - verify against the loaders).
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to the classes in use, so they
    # stay aligned with the (possibly filtered) sample arrays.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        # The structure is rebuilt for every class on purpose: the
        # binary-tree variant is random, so each class gets its own graph.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
                num_var=train_x.shape[1],
                num_dims=3 if svhn else 1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    #################################
    # Evaluate after initialization #
    #################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size, skip_reparam=True)
    valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size, skip_reparam=True)
    test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_ll_before / train_N,
            valid_ll_before / valid_N,
            test_ll_before / test_N))
    train_lls.append(train_ll_before / train_N)
    valid_lls.append(valid_ll_before / valid_N)
    test_lls.append(test_ll_before / test_N)

    ################
    # Experiment 4 #
    ################
    # The labels never change below, so these tensors are built once and
    # reused for every per-epoch evaluation.
    train_labelsz = torch.tensor(train_labels).to(torch.device(device))
    valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
    test_labelsz = torch.tensor(test_labels).to(torch.device(device))

    acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size, skip_reparam=True)
    acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size, skip_reparam=True)
    acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            acc_train_before,
            acc_valid_before,
            acc_test_before))
    train_accs.append(acc_train_before)
    valid_accs.append(acc_valid_before)
    test_accs.append(acc_test_before)
    mixture.train()

    ##################
    # Training phase #
    ##################

    """ Learning each sub Network Generatively """

    # One optimizer over the parameters of every sub-network followed by the
    # parameters reported by the mixture itself (same order as before).
    sub_net_parameters = [p for einet in mixture.einets for p in einet.parameters()]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()

    for epoch_count in range(num_epochs):
        for (einet, c) in zip(einets, classes):
            # Each sub-network only sees the training samples of its class.
            train_x_c = train_x[[l == c for l in train_labels]]
            class_N = train_x_c.shape[0]

            idx_batches = torch.randperm(class_N, device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}]   total loss: {total_loss}')

        # Per-epoch evaluation of the whole mixture on the full splits.
        mixture.eval()
        train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        train_lls.append(train_ll / train_N)
        valid_lls.append(valid_ll / valid_N)
        test_lls.append(test_ll / test_N)

        acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size)
        train_accs.append(acc_train)
        valid_accs.append(acc_valid)
        test_accs.append(acc_test)
        mixture.train()

    # Taken after the loop; it used to be taken right after start_time, which
    # made the reported training time always (almost) zero.
    end_time = time.time()

    # Report the most recent history entries.  They equal the values of the
    # last epoch and, unlike the loop locals, also exist when num_epochs == 0.
    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_lls[-1],
            valid_lls[-1],
            test_lls[-1]))

    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            train_accs[-1],
            valid_accs[-1],
            test_accs[-1]))

    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
예제 #3
0
    import os
    # Create the output directory if needed; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.save_dir, exist_ok=True)

    # load dataset
    from datasets import load_mnist, load_usps, load_fashion_mnist
    if args.dataset == 'mnist':
        x, y = load_mnist()
    elif args.dataset == 'usps':
        x, y = load_usps('data/usps')
    elif args.dataset == 'mnist-test':
        x, y = load_mnist()
        # Keep only the samples from index 60000 on, matching the
        # 'fashion-test' branch (previously 69000, which kept just the last
        # 1000 samples).  Presumably the loader returns the 60000 training
        # samples followed by the 10000 test samples - verify.
        x, y = x[60000:], y[60000:]
    elif args.dataset == 'fashion':
        x, y = load_fashion_mnist()
    elif args.dataset == 'fashion-test':
        x, y = load_fashion_mnist()
        x, y = x[60000:], y[60000:]
    else:
        # Fail with a clear message instead of a NameError on `x` below.
        raise ValueError(f'Unknown dataset: {args.dataset!r}')

    # prepare the DCEC model
    dcec = DCEC(input_shape=x.shape[1:],
                filters=[32, 64, 128, 10],
                n_clusters=args.n_clusters)
    plot_model(dcec.model,
               to_file=args.save_dir + '/dcec_model.png',
               show_shapes=True)
    dcec.model.summary()

    # begin clustering.
    optimizer = 'adam'
예제 #4
0
batch_size = 100
online_em_frequency = 1
online_em_stepsize = 0.05
############################################################################

# Extra arguments required by the selected leaf distribution; families
# without an entry below get None.
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
elif exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
elif exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}
else:
    exponential_family_args = None

# get data: Fashion-MNIST when requested, plain MNIST otherwise
train_x, train_labels, test_x, test_labels = (
    datasets.load_fashion_mnist() if fashion_mnist else datasets.load_mnist())

# Gaussian leaves work on rescaled, centered pixels (in-place updates)
if exponential_family == EinsumNetwork.NormalArray:
    train_x /= 255.
    train_x -= .5
    test_x /= 255.
    test_x -= .5

# validation split: the last 10000 training rows are held out
valid_x, train_x = train_x[-10000:, :], train_x[:-10000, :]
valid_labels, train_labels = train_labels[-10000:], train_labels[:-10000]
예제 #5
0
def run_experiment(settings):
    """Train a per-class einet mixture with SGD, then evaluate online adaptation.

    The data is split by position in the loaded training arrays: the last
    10000 rows are the validation split, the 1000 rows before them are the
    "online" samples and the 45000 rows before those are the training split.
    One EinsumNetwork per class is built, initialized and trained
    generatively with SGD on the training samples of its class.

    Afterwards two evaluation series are recorded, each starting with one
    evaluation of the trained mixture:

    * reference series (``*_ref``): the online samples are only appended, in
      batches of 20, to a copy of the training set and the unchanged mixture
      is re-evaluated after every batch;
    * online-adaptation series: the same batches are fed to
      ``online_update`` for the matching class network, appended to the
      training set, and the mixture is re-evaluated after every batch.

    Args:
        settings: Configuration object; the attributes read are
            ``fashion_mnist``, ``svhn``, ``exponential_family``, ``classes``,
            ``K``, ``structure``, ``pd_num_pieces``, ``depth``,
            ``num_repetitions_mixture``, ``width``, ``height``,
            ``num_epochs``, ``batch_size`` and ``SGD_learning_rate``.

    Returns:
        dict with the metric histories of both series (``train_lls`` ...
        ``test_accs`` and ``train_lls_ref`` ... ``test_accs_ref``),
        ``network_size`` and ``online_samples`` (cumulative number of online
        samples after each reference evaluation, starting at 0).

    Raises:
        AssertionError: If ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    # Extra arguments required by the selected leaf distribution.
    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist(
        )
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn(
        )
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves work on rescaled, centered pixels
    # (presumably raw values are in [0, 255] - verify against the loaders).
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation / online / training split (all taken from the tail of the
    # loaded training arrays; the right-hand sides below still refer to the
    # full arrays until train_x / train_labels are reassigned last)
    valid_x = train_x[-10000:, :]
    online_x = train_x[-11000:-10000, :]
    train_x = train_x[-56000:-11000, :]

    valid_labels = train_labels[-10000:]
    online_labels = train_labels[-11000:-10000]
    train_labels = train_labels[-56000:-11000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        online_x = online_x[
            np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to the classes in use, so they
    # stay aligned with the (possibly filtered) sample arrays.
    train_labels = [l for l in train_labels if l in classes]
    # Independent copy: both label lists are extended in place (`+=`) further
    # down.  A plain alias made every reference-series append also grow
    # `train_labels`, so it no longer matched `train_x`.
    train_labels_backup = list(train_labels)
    online_labels = [l for l in online_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    # Sharing the tensor is fine here: it is only ever replaced through
    # torch.cat, which returns a new tensor.
    train_x_backup = train_x
    online_x = torch.from_numpy(online_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        # The structure is rebuilt for every class on purpose: the
        # binary-tree variant is random, so each class gets its own graph.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    ##################
    # Training phase #
    ##################

    # One optimizer over the parameters of every sub-network followed by the
    # parameters reported by the mixture itself (same order as before).
    sub_net_parameters = [
        p for einet in mixture.einets for p in einet.parameters()
    ]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()
    """ Learning each sub Network Generatively """
    for (einet, c) in zip(einets, classes):
        # Each sub-network only sees the training samples of its class.
        train_x_c = train_x[[l == c for l in train_labels]]

        train_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            # NOTE(review): the value printed is the summed *negative*
            # log-likelihood although the label says "log-likelihood".
            print(f'[{epoch_count}]   total log-likelihood: {total_loss}')

    ###########################
    # Evaluate after training #
    ###########################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    train_lls_ref = []
    valid_lls_ref = []
    test_lls_ref = []
    train_accs_ref = []
    valid_accs_ref = []
    test_accs_ref = []
    added_samples = [0]

    def eval_network(do_print=False, no_OA=False):
        """Evaluate the mixture and append the metrics to the matching histories.

        Args:
            do_print: When true, also print the log-likelihoods and accuracies.
            no_OA: When true, score the reference training set
                (``train_x_backup`` / ``train_labels_backup``) and record into
                the ``*_ref`` lists; otherwise score ``train_x`` /
                ``train_labels`` and record into the regular lists.
        """
        # The enclosing-scope names are looked up at call time, so the sets
        # grown by the loops below are always the ones being scored.
        if no_OA:
            cur_train_x, cur_train_labels = train_x_backup, train_labels_backup
            ll_hist = (train_lls_ref, valid_lls_ref, test_lls_ref)
            acc_hist = (train_accs_ref, valid_accs_ref, test_accs_ref)
        else:
            cur_train_x, cur_train_labels = train_x, train_labels
            ll_hist = (train_lls, valid_lls, test_lls)
            acc_hist = (train_accs, valid_accs, test_accs)

        train_N = cur_train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]

        mixture.eval()
        train_ll = mixture.eval_loglikelihood_batched(cur_train_x,
                                                      batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x,
                                                      batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                     batch_size=batch_size)
        if do_print:
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(train_ll / train_N, valid_ll / valid_N,
                        test_ll / test_N))
        ll_hist[0].append(train_ll / train_N)
        ll_hist[1].append(valid_ll / valid_N)
        ll_hist[2].append(test_ll / test_N)

        ################
        # Experiment 4 #
        ################
        train_labelsz = torch.tensor(cur_train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        acc_train = mixture.eval_accuracy_batched(classes,
                                                  cur_train_x,
                                                  train_labelsz,
                                                  batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes,
                                                  valid_x,
                                                  valid_labelsz,
                                                  batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes,
                                                 test_x,
                                                 test_labelsz,
                                                 batch_size=batch_size)
        if do_print:
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(acc_train, acc_valid, acc_test))
        acc_hist[0].append(acc_train)
        acc_hist[1].append(acc_valid)
        acc_hist[2].append(acc_test)
        mixture.train()

    eval_network(do_print=True, no_OA=False)
    eval_network(do_print=False, no_OA=True)

    #####################################################
    # Evaluate the network with different training sets #
    #####################################################

    # The same random batches of 20 online samples drive both series.
    idx_batches = torch.randperm(online_x.shape[0], device=device).split(20)

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx.tolist()]

        for c in classes:
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            train_x_backup = torch.cat((train_x_backup, batch_x))
            train_labels_backup += [c] * batch_x.shape[0]

        added_samples.append(added_samples[-1] + len(idx))
        eval_network(do_print=False, no_OA=True)

    #####################
    # Online adaptation #
    #####################

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx.tolist()]

        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            online_update(einet, batch_x)
            train_x = torch.cat((train_x, batch_x))
            train_labels += [c] * batch_x.shape[0]

        eval_network(do_print=False, no_OA=False)

    print()
    print(f'Network size: {num_params} parameters')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'train_lls_ref': train_lls_ref,
        'valid_lls_ref': valid_lls_ref,
        'test_lls_ref': test_lls_ref,
        'train_accs_ref': train_accs_ref,
        'valid_accs_ref': valid_accs_ref,
        'test_accs_ref': test_accs_ref,
        'network_size': num_params,
        'online_samples': added_samples,
    }
예제 #6
0
def _save_last_layer_weights(einet, data_dir, file_name):
    """Write the first slice of the last einet layer's parameters as text.

    The target directory is created on demand so that ``np.savetxt`` does not
    fail when the experiment is run from a fresh checkout.
    """
    os.makedirs(data_dir, exist_ok=True)
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(os.path.join(data_dir, file_name), weights[0])


def run_experiment(settings):
    """Train a single EinsumNetwork generatively with EM and evaluate it.

    The function loads (Fashion-)MNIST or SVHN, holds out the last 10000
    training samples for validation, optionally restricts the data to
    ``settings.classes``, builds one EinsumNetwork, trains it for
    ``settings.num_epochs`` epochs and reports average log-likelihoods
    (Experiment 1) and classification accuracies (Experiment 2).

    Args:
        settings: object providing the attributes ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions``, ``width``,
            ``height``, ``num_epochs``, ``batch_size``,
            ``online_em_frequency`` and ``online_em_stepsize``.

    Returns:
        dict with the per-sample log-likelihoods (``train_ll``, ``valid_ll``,
        ``test_ll``), the accuracies (``train_acc``, ``valid_acc``,
        ``test_acc``), ``network_size`` and ``training_time`` in seconds.

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.

    Side effects:
        Writes the last layer's weights before and after training into
        ``../src/experiments/round5/data/weights_analysis/`` and prints
        progress to stdout. Uses the module-level ``device``.
    """
    ############################################################################

    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize

    ############################################################################

    # Leaf-distribution hyper-parameters depend on the chosen family.
    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    elif exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    elif exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = \
            datasets.load_fashion_mnist()
    elif svhn:
        # The extra SVHN split is loaded but not used in this experiment.
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = \
            datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Normal leaves work on rescaled, shifted inputs (in place); the
    # validation split below is taken afterwards and is therefore rescaled
    # as well.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split: the last 10000 training samples
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:

        def class_mask(labels):
            # True for every sample whose label is one of the chosen classes.
            return np.any(np.stack([labels == c for c in classes], 1), 1)

        train_x = train_x[class_mask(train_labels), :]
        valid_x = valid_x[class_mask(valid_labels), :]
        test_x = test_x[class_mask(test_labels), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to the classes in use.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=1,
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              online_em_frequency=online_em_frequency,
                              online_em_stepsize=online_em_stepsize,
                              use_em=True)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)

    num_params = EinsumNetwork.eval_size(einet)

    data_dir = '../src/experiments/round5/data/weights_analysis/'
    _save_last_layer_weights(einet, data_dir, "weights_before.json")

    # Train
    ######################################

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            # The EM step consumes the gradients of the log-likelihood.
            log_likelihood.backward()
            # Report the negative log-likelihood summed over the epoch
            # (previously this counter was never updated and printed 0.0).
            total_loss -= log_likelihood.detach().item()

            einet.em_process_batch()
        einet.em_update()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    _save_last_layer_weights(einet, data_dir, "weights_after.json")

    ################
    # Experiment 1 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 1: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 2: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
예제 #7
0
def run_experiment(settings):
    """Train one EinsumNetwork discriminatively with SGD and evaluate it.

    The function loads (Fashion-)MNIST or SVHN, holds out the last 10000
    training samples for validation, optionally restricts the data to
    ``settings.classes``, builds a single EinsumNetwork with one output per
    class and minimises the cross-entropy of its outputs with SGD. Afterwards
    it reports average log-likelihoods (Experiment 5) and classification
    accuracies (Experiment 6).

    Args:
        settings: object providing the attributes ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions``, ``width``,
            ``height``, ``num_epochs``, ``batch_size`` and
            ``SGD_learning_rate``.

    Returns:
        dict with the per-sample log-likelihoods (``train_ll``, ``valid_ll``,
        ``test_ll``), the accuracies (``train_acc``, ``valid_acc``,
        ``test_acc``), ``network_size`` and ``training_time`` in seconds.

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.

    Side effects:
        Prints progress and results to stdout. Uses the module-level
        ``device``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    # Leaf-distribution hyper-parameters depend on the chosen family.
    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    elif exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    elif exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = \
            datasets.load_fashion_mnist()
    elif svhn:
        # The extra SVHN split is loaded but not used in this experiment.
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = \
            datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Normal leaves work on rescaled, shifted inputs (in place); the
    # validation split below is taken afterwards and is therefore rescaled
    # as well.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split: the last 10000 training samples
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:

        def class_mask(labels):
            # True for every sample whose label is one of the chosen classes.
            return np.any(np.stack([labels == c for c in classes], 1), 1)

        train_x = train_x[class_mask(train_labels), :]
        valid_x = valid_x[class_mask(valid_labels), :]
        test_x = test_x[class_mask(test_labels), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to the classes in use.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=len(classes),
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              use_em=False)

    einet = EinsumNetwork.EinsumNetwork(graph, args)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)
    print(einet)

    num_params = EinsumNetwork.eval_size(einet)

    #################################
    # Discriminative training phase #
    #################################

    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    # Map every training label to its position in ``classes`` once, instead
    # of rebuilding the target tensor element by element for every batch.
    train_targets = torch.tensor(
        [classes.index(label) for label in train_labels],
        dtype=torch.long).to(torch.device(device))

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            target = train_targets[idx]
            optimizer.zero_grad()
            outputs = einet.forward(batch_x)
            loss = loss_function(outputs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().item()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    ################
    # Experiment 5 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 5: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 6 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 6: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
예제 #8
0
# Step size handed to the online EM configuration; it is only assigned in
# this part of the script.
online_em_stepsize = 0.05
############################################################################

# Hyper-parameters of the leaf distributions, selected by the exponential
# family in use. Stays None for any family not listed here.
exponential_family_args = None
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
if exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
if exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}
if exponential_family == EinsumNetwork.MultivariateNormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.01}

# Load the raw data set; width and height are forwarded to the loader.
if fashion_mnist:
    train_x_raw, train_labels, test_x_raw, test_labels = datasets.load_fashion_mnist(width, height)
else:
    train_x_raw, train_labels, test_x_raw, test_labels = datasets.load_mnist(width, height)

# TODO: Rework this section
# Move the images into the frequency domain: every sample is reshaped to
# (width, height) and passed through a real-input FFT with 'forward'
# normalisation.
# NOTE(review): the reshape uses (width, height) order; this only matches the
# usual (height, width) layout for square images - confirm.
train_x = torch.fft.rfft(torch.tensor(train_x_raw.reshape((-1, width, height))), norm='forward')
test_x = torch.fft.rfft(torch.tensor(test_x_raw.reshape((-1, width, height))), norm='forward')

# Flatten the two trailing dimensions into a single axis per sample.
train_x = train_x.reshape((-1, train_x.shape[1] * train_x.shape[2]))
test_x = test_x.reshape((-1, test_x.shape[1] * test_x.shape[2]))

# Split the complex coefficients into real and imaginary parts, stacked along
# a new last dimension of size 2.
train_x = torch.stack([train_x.real, train_x.imag], dim=-1)
test_x = torch.stack([test_x.real, test_x.imag], dim=-1)

# Validation split: the last 10000 training samples.
valid_x = train_x[-10000:, :]
예제 #9
0
def run_experiment(settings):
    """Train one EinsumNetwork per class with SGD and evaluate their mixture.

    The function loads (Fashion-)MNIST or SVHN, holds out the last 10000
    training samples for validation, optionally restricts the data to
    ``settings.classes`` and builds one EinsumNetwork per class. Each network
    is trained generatively (SGD on the negative log-likelihood) on the
    samples of its own class. The networks are combined in an
    ``EinetMixture`` weighted by the class frequencies of the training set,
    which is then evaluated for log-likelihood (Experiment 3) and
    classification accuracy (Experiment 4).

    Args:
        settings: object providing the attributes ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions_mixture``,
            ``width``, ``height``, ``num_epochs``, ``batch_size``,
            ``online_em_frequency``, ``online_em_stepsize`` and
            ``SGD_learning_rate``.

    Returns:
        dict with the per-sample log-likelihoods (``train_ll``, ``valid_ll``,
        ``test_ll``), the accuracies (``train_acc``, ``valid_acc``,
        ``test_acc``), ``network_size`` and ``training_time`` in seconds.

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.

    Side effects:
        Prints progress and results to stdout. Uses the module-level
        ``device``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    # Leaf-distribution hyper-parameters depend on the chosen family.
    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    elif exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    elif exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = \
            datasets.load_fashion_mnist()
    elif svhn:
        # The extra SVHN split is loaded but not used in this experiment.
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = \
            datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Normal leaves work on rescaled, shifted inputs (in place); the
    # validation split below is taken afterwards and is therefore rescaled
    # as well.
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split: the last 10000 training samples
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # valid_x = train_x[-10000:, :]
    # train_x = train_x[-12000:-11000, :]

    # valid_labels = train_labels[-10000:]
    # train_labels = train_labels[-12000:-11000]

    # pick the selected classes
    if classes is not None:

        def class_mask(labels):
            # True for every sample whose label is one of the chosen classes.
            return np.any(np.stack([labels == c for c in classes], 1), 1)

        train_x = train_x[class_mask(train_labels), :]
        valid_x = valid_x[class_mask(valid_labels), :]
        test_x = test_x[class_mask(test_labels), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Labels become plain lists restricted to the classes in use.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        # The structure is rebuilt per class, so every class network gets its
        # own graph object.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    total_samples = sum(ps)
    ps = [p / total_samples for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()
    # Code for weight analysis, section 7.3
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    # weights = {}
    # for (einet, c) in zip(einets, classes):
    #     einet_weights = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             einet_weights[i] = l.reparam(l.params.data)
    #         else:
    #             einet_weights[i] = l.ef_array.reparam(l.ef_array.params.data)
    #     weights[c] = einet_weights

    ##################
    # Training phase #
    ##################

    # One optimizer over the parameters of every class network plus whatever
    # the mixture itself exposes.
    sub_net_parameters = [
        param for sub_net in mixture.einets for param in sub_net.parameters()
    ]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)
    # Learning each sub network generatively

    start_time = time.time()

    for (einet, c) in zip(einets, classes):
        # Each class network only ever sees the training samples of its class.
        train_x_c = train_x[[l == c for l in train_labels]]
        class_train_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(class_train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                # SGD minimises, so descend on the negative log-likelihood.
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()
    # Code for weight analysis
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # percentual_change = []
    # percentual_change_formatted = []
    # data_file = os.path.join(data_dir, f"percentual_change.json")
    # for (einet, c) in zip(einets, classes):
    #     einet_change = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             weights_new = l.reparam(l.params.data)
    #         else:
    #             weights_new = l.ef_array.reparam(l.ef_array.params.data)
    #         change = torch.mean(torch.abs(weights[c][i] - weights_new)/weights[c][i]).item()
    #         einet_change[i] = change
    #         if i == 0:
    #             percentual_change_formatted.append(change)

    #     percentual_change.append(einet_change)

    # print(percentual_change)
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change, f)
    # data_file = os.path.join(data_dir, f"percentual_change_formatted.json")
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change_formatted, f)

    ################
    # Experiment 3 #
    ################
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll = mixture.eval_loglikelihood_batched(train_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    valid_ll = mixture.eval_loglikelihood_batched(valid_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                 batch_size=batch_size,
                                                 skip_reparam=True)
    print()
    print(
        "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 4 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = mixture.eval_accuracy_batched(classes,
                                              train_x,
                                              train_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_valid = mixture.eval_accuracy_batched(classes,
                                              valid_x,
                                              valid_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_test = mixture.eval_accuracy_batched(classes,
                                             test_x,
                                             test_labels,
                                             batch_size=batch_size,
                                             skip_reparam=True)
    print()
    print(
        "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }