Example #1
def preprocess_data(dataset):
    if dataset == 'mnist':
        (train_X, train_y), (test_X, test_y) = mnist.load_data()
        input_shape = (train_X.shape[1], train_X.shape[2], 1)
        train_X = train_X.reshape([-1, input_shape[0], input_shape[1], input_shape[2]])
        test_X = test_X.reshape([-1, input_shape[0], input_shape[1], input_shape[2]])
    elif dataset == 'svhn':
        # data=np.concatenate((train_data, ext_data, test_data))
        data, label, n_train, n_test = datasets.load_svhn()
        train_X = data[:n_train]
        test_X = data[-n_test:]
        train_y = label[:n_train]
        test_y = label[-n_test:]
        input_shape = (32, 32, 3)
        train_X = train_X.reshape([-1, input_shape[0], input_shape[1], input_shape[2]])
        test_X = test_X.reshape([-1, input_shape[0], input_shape[1], input_shape[2]])
        
    print(input_shape)
    return input_shape, train_X, test_X, train_y, test_y
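
# A minimal usage sketch, assuming the Keras-style `mnist` loader and the
# `datasets` module used by the function above are importable; the helper
# name is hypothetical, not part of the original snippet.
def _demo_preprocess():
    input_shape, train_X, test_X, train_y, test_y = preprocess_data('mnist')
    # For MNIST this yields input_shape == (28, 28, 1) and 4-D image arrays.
    assert train_X.shape[1:] == input_shape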
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]
    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
                num_var=train_x.shape[1],
                num_dims=3 if svhn else 1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Count the number of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    ##################################
    # Evaluate after initialization  #
    ##################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size, skip_reparam=True)
    valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size, skip_reparam=True)
    test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_ll_before / train_N,
            valid_ll_before / valid_N,
            test_ll_before / test_N))
    train_lls.append(train_ll_before / train_N)
    valid_lls.append(valid_ll_before / valid_N)
    test_lls.append(test_ll_before / test_N)
    
    ################
    # Experiment 4 #
    ################
    train_labelsz = torch.tensor(train_labels).to(torch.device(device))
    valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
    test_labelsz = torch.tensor(test_labels).to(torch.device(device))

    acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size, skip_reparam=True)
    acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size, skip_reparam=True)
    acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            acc_train_before,
            acc_valid_before,
            acc_test_before))
    train_accs.append(acc_train_before)
    valid_accs.append(acc_valid_before)
    test_accs.append(acc_test_before)
    mixture.train()

    ##################
    # Training phase #
    ##################

    """ Learning each sub Network Generatively """

    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()

    for epoch_count in range(num_epochs):
        for (einet, c) in zip(einets, classes):
            train_x_c = train_x[[l == c for l in train_labels]]
            valid_x_c = valid_x[[l == c for l in valid_labels]]
            test_x_c = test_x[[l == c for l in test_labels]]

            train_N = train_x_c.shape[0]
            valid_N = valid_x_c.shape[0]
            test_N = test_x_c.shape[0]

            idx_batches = torch.randperm(train_N, device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}] class {c}   total loss: {total_loss}')

        mixture.eval()
        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        train_lls.append(train_ll / train_N)
        valid_lls.append(valid_ll / valid_N)
        test_lls.append(test_ll / test_N)

        train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size)
        train_accs.append(acc_train)
        valid_accs.append(acc_valid)
        test_accs.append(acc_test)
        mixture.train()

    end_time = time.time()

    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_ll / train_N,
            valid_ll / valid_N,
            test_ll / test_N))

    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            acc_train,
            acc_valid,
            acc_test))

    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
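
# A hypothetical driver for run_experiment: the attribute names mirror the
# settings fields read at the top of the function, while the values are
# illustrative placeholders, not a configuration from the original experiments.
from types import SimpleNamespace

demo_settings = SimpleNamespace(
    fashion_mnist=False, svhn=False,
    exponential_family=EinsumNetwork.CategoricalArray,
    classes=None, K=10,
    structure='binary-trees', pd_num_pieces=[4],
    depth=3, num_repetitions_mixture=10,
    width=28, height=28,
    num_epochs=5, batch_size=100, SGD_learning_rate=0.1)
# results = run_experiment(demo_settings)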
Example #3
    def new_start(start_train_set, online_offset):
        ############################################################################
        fashion_mnist = settings.fashion_mnist
        svhn = settings.svhn

        exponential_family = settings.exponential_family

        classes = settings.classes

        K = settings.K

        structure = settings.structure

        # 'poon-domingos'
        pd_num_pieces = settings.pd_num_pieces

        # 'binary-trees'
        depth = settings.depth
        num_repetitions = settings.num_repetitions_mixture

        width = settings.width
        height = settings.height

        num_epochs = settings.num_epochs
        batch_size = settings.batch_size
        online_em_frequency = settings.online_em_frequency
        online_em_stepsize = settings.online_em_stepsize

        ############################################################################

        exponential_family_args = None
        if exponential_family == EinsumNetwork.BinomialArray:
            exponential_family_args = {'N': 255}
        if exponential_family == EinsumNetwork.CategoricalArray:
            exponential_family_args = {'K': 256}
        if exponential_family == EinsumNetwork.NormalArray:
            exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

        # get data
        if fashion_mnist:
            train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
        elif svhn:
            train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
        else:
            train_x, train_labels, test_x, test_labels = datasets.load_mnist()

        if exponential_family == EinsumNetwork.NormalArray:
            train_x /= 255.
            test_x /= 255.
            train_x -= .5
            test_x -= .5

        # validation split
        valid_x = train_x[-10000:, :]
        # online_x = train_x[-40000:, :]
        train_x = train_x[:-(10000+online_offset-start_train_set), :]
        valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-40000:]
        train_labels = train_labels[:-(10000+online_offset-start_train_set)]
        
        # # debug setup
        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-45000:, :]
        # train_x = train_x[:-55000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-45000:]
        # train_labels = train_labels[:-55000]

        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-10000:, :]
        # train_x = train_x[:-20000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-10000:]
        # train_labels = train_labels[:-20000]

        # valid_x = train_x[-10000:, :]
        # online_x = train_x[-20000:, :]
        # train_x = train_x[:-30000, :]
        # valid_labels = train_labels[-10000:]
        # online_labels = train_labels[-20000:]
        # train_labels = train_labels[:-30000]

        # pick the selected classes
        if classes is not None:
            train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
            # online_x = online_x[np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
            valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
            test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

            train_labels = [l for l in train_labels if l in classes]
            # online_labels = [l for l in online_labels if l in classes]
            valid_labels = [l for l in valid_labels if l in classes]
            test_labels = [l for l in test_labels if l in classes]
        else:
            classes = np.unique(train_labels).tolist()

            train_labels = [l for l in train_labels if l in classes]
            # online_labels = [l for l in online_labels if l in classes]
            valid_labels = [l for l in valid_labels if l in classes]
            test_labels = [l for l in test_labels if l in classes]

        train_x = torch.from_numpy(train_x).to(torch.device(device))
        # online_x = torch.from_numpy(online_x).to(torch.device(device))
        valid_x = torch.from_numpy(valid_x).to(torch.device(device))
        test_x = torch.from_numpy(test_x).to(torch.device(device))

        ######################################
        # Make EinsumNetworks for each class #
        ######################################
        einets = []
        ps = []
        for c in classes:
            if structure == 'poon-domingos':
                pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            elif structure == 'binary-trees':
                graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
            else:
                raise AssertionError("Unknown Structure")

            args = EinsumNetwork.Args(
                    num_var=train_x.shape[1],
                    num_dims=3 if svhn else 1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)

            init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
            einet.initialize(init_dict)
            einet.to(device)
            einets.append(einet)

            # Count the number of training samples per class
            ps.append(train_labels.count(c))

            print(f'Einsum network for class {c}:')
            print(einet)

        # normalize ps, construct mixture component
        ps = [p / sum(ps) for p in ps]
        ps = torch.tensor(ps).to(torch.device(device))
        mixture = EinetMixture(ps, einets, classes=classes)

        num_params = mixture.eval_size()

        ##################################
        # Evaluate after initialization  #
        ##################################

        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        mixture.eval()
        train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        print()
        print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
                train_ll_before / train_N,
                valid_ll_before / valid_N,
                test_ll_before / test_N))
        train_lls.append(train_ll_before / train_N)
        valid_lls.append(valid_ll_before / valid_N)
        test_lls.append(test_ll_before / test_N)

        ################
        # Experiment 4 #
        ################
        train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz, batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz, batch_size=batch_size)
        print()
        print("Experiment 8: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
                acc_train_before,
                acc_valid_before,
                acc_test_before))
        train_accs.append(acc_train_before)
        valid_accs.append(acc_valid_before)
        test_accs.append(acc_test_before)
Example #4
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    online_x = train_x[-11000:-10000, :]
    train_x = train_x[-56000:-11000, :]
    # init_x = train_x[-13000:-10000, :]

    valid_labels = train_labels[-10000:]
    online_labels = train_labels[-11000:-10000]
    train_labels = train_labels[-56000:-11000]
    # init_labels = train_labels[-13000:-10000]

    # full set of training
    # valid_x = train_x[-10000:, :]
    # online_x = train_x[-11000:-10000, :]
    # train_x = train_x[:-10000, :]

    # valid_labels = train_labels[-10000:]
    # online_labels = train_labels[-11000:-10000]
    # train_labels = train_labels[:-10000]

    # print('train_x:')
    # print(train_x.shape)
    # print(train_labels.shape)
    # print('online_x:')
    # print(online_x.shape)
    # print(online_labels.shape)
    # exit()

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        online_x = online_x[
            np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
        # init_x = init_x[np.any(np.stack([init_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

        train_labels = [l for l in train_labels if l in classes]
        train_labels_backup = train_labels
        online_labels = [l for l in online_labels if l in classes]
        # init_labels = [l for l in init_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()

        train_labels = [l for l in train_labels if l in classes]
        train_labels_backup = train_labels
        online_labels = [l for l in online_labels if l in classes]
        # init_labels = [l for l in init_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    train_x_backup = train_x
    online_x = torch.from_numpy(online_x).to(torch.device(device))
    # init_x = torch.from_numpy(init_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        # init_dict = get_init_dict(einet, init_x, train_labels=init_labels, einet_class=c)
        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Count the number of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # for (einet, c) in zip(einets, classes):
    #     data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    #     weights = einet.einet_layers[-1].params.data.cpu()
    #     np.savetxt(data_file, einet.einet_layers[-1].reparam(weights)[0])

    ##################
    # Training phase #
    ##################

    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()
    """ Learning each sub Network Generatively """
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]

        train_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            print(f'[{epoch_count}]   total loss (negative log-likelihood): {total_loss}')

    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # for (einet, c) in zip(einets, classes):
    #     data_file = os.path.join(data_dir, f"weights_after_{c}.json")
    #     weights = einet.einet_layers[-1].params.data.cpu()
    #     np.savetxt(data_file, einet.einet_layers[-1].reparam(weights)[0])
    # exit()

    ##################################
    # Evaluate after initialization  #
    ##################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    train_lls_ref = []
    valid_lls_ref = []
    test_lls_ref = []
    train_accs_ref = []
    valid_accs_ref = []
    test_accs_ref = []
    added_samples = [0]

    def eval_network(do_print=False, no_OA=False):
        if no_OA:
            train_N = train_x_backup.shape[0]
        else:
            train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        mixture.eval()
        if no_OA:
            train_ll_before = mixture.eval_loglikelihood_batched(
                train_x_backup, batch_size=batch_size)
        else:
            train_ll_before = mixture.eval_loglikelihood_batched(
                train_x, batch_size=batch_size)
        valid_ll_before = mixture.eval_loglikelihood_batched(
            valid_x, batch_size=batch_size)
        test_ll_before = mixture.eval_loglikelihood_batched(
            test_x, batch_size=batch_size)
        if do_print:
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(train_ll_before / train_N, valid_ll_before / valid_N,
                        test_ll_before / test_N))
        if no_OA:
            train_lls_ref.append(train_ll_before / train_N)
            valid_lls_ref.append(valid_ll_before / valid_N)
            test_lls_ref.append(test_ll_before / test_N)
        else:
            train_lls.append(train_ll_before / train_N)
            valid_lls.append(valid_ll_before / valid_N)
            test_lls.append(test_ll_before / test_N)

        ################
        # Experiment 4 #
        ################
        if no_OA:
            train_labelsz = torch.tensor(train_labels_backup).to(
                torch.device(device))
        else:
            train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        if no_OA:
            acc_train_before = mixture.eval_accuracy_batched(
                classes, train_x_backup, train_labelsz, batch_size=batch_size)
        else:
            acc_train_before = mixture.eval_accuracy_batched(
                classes, train_x, train_labelsz, batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(classes,
                                                         valid_x,
                                                         valid_labelsz,
                                                         batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(classes,
                                                        test_x,
                                                        test_labelsz,
                                                        batch_size=batch_size)
        if do_print:
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(acc_train_before, acc_valid_before, acc_test_before))
        if no_OA:
            train_accs_ref.append(acc_train_before)
            valid_accs_ref.append(acc_valid_before)
            test_accs_ref.append(acc_test_before)
        else:
            train_accs.append(acc_train_before)
            valid_accs.append(acc_valid_before)
            test_accs.append(acc_test_before)
        mixture.train()

    eval_network(do_print=True, no_OA=False)
    eval_network(do_print=False, no_OA=True)

    #####################################################
    # Evaluate the network with different training sets #
    #####################################################

    idx_batches = torch.randperm(online_x.shape[0], device=device).split(20)

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            train_x_backup = torch.cat((train_x_backup, batch_x))
            train_labels_backup += [c] * len(batch_x)

        added_samples.append(added_samples[-1] + len(idx))
        eval_network(do_print=False, no_OA=True)

    #####################
    # Online adaptation #
    #####################

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            online_update(einet, batch_x)
            train_x = torch.cat((train_x, batch_x))
            train_labels += [c] * len(batch_x)

        eval_network(do_print=False, no_OA=False)

    print()
    print(f'Network size: {num_params} parameters')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'train_lls_ref': train_lls_ref,
        'valid_lls_ref': valid_lls_ref,
        'test_lls_ref': test_lls_ref,
        'train_accs_ref': train_accs_ref,
        'valid_accs_ref': valid_accs_ref,
        'test_accs_ref': test_accs_ref,
        'network_size': num_params,
        'online_samples': added_samples,
    }
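
# Hypothetical post-processing of the dictionary returned above: plot test
# accuracy of the online-adapted mixture against the no-adaptation reference
# as samples arrive (matplotlib is an assumed extra dependency, and the
# function name is illustrative).
import matplotlib.pyplot as plt

def plot_online_adaptation(res):
    plt.plot(res['online_samples'], res['test_accs'], label='online adaptation')
    plt.plot(res['online_samples'], res['test_accs_ref'], label='reference')
    plt.xlabel('added online samples')
    plt.ylabel('test accuracy')
    plt.legend()
    plt.show()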
Example #5
    def __init__(self, conf):
        self.conf = conf

        # determine and create result dir
        i = 1
        log_path = conf.result_path + 'run0'
        while os.path.exists(log_path):
            log_path = '{}run{}'.format(conf.result_path, i)
            i += 1
        os.makedirs(log_path)
        self.log_path = log_path

        if not os.path.exists(conf.checkpoint_dir):
            os.makedirs(conf.checkpoint_dir)

        self.checkpoint_file = os.path.join(self.conf.checkpoint_dir,
                                            "model.ckpt")
        input_shape = [
            conf.batch_size, conf.scene_width, conf.scene_height, conf.channels
        ]
        # build model
        with tf.device(conf.device):
            self.mdl = model.Supair(conf)
            self.in_ph = tf.placeholder(tf.float32, input_shape)
            self.elbo = self.mdl.elbo(self.in_ph)

            self.mdl.num_parameters()

            self.optimizer = tf.train.AdamOptimizer()
            self.train_op = self.optimizer.minimize(-1 * self.elbo)

        self.sess = tf.Session()

        self.saver = tf.train.Saver()
        if self.conf.load_params:
            self.saver.restore(self.sess, self.checkpoint_file)
        else:
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())

        # load data
        bboxes = None
        if conf.dataset == 'MNIST':
            (x, counts, y,
             bboxes), (x_test, c_test, _,
                       _) = datasets.load_mnist(conf.scene_width,
                                                max_digits=2,
                                                path=conf.data_path)
            visualize.store_images(x[0:10], log_path + '/img_raw')
            if conf.noise:
                x = datasets.add_noise(x)
                x_test = datasets.add_noise(x_test)
                visualize.store_images(x[0:10], log_path + '/img_noisy')
            if conf.structured_noise:
                x = datasets.add_structured_noise(x)
                x_test = datasets.add_structured_noise(x_test)
                visualize.store_images(x[0:10], log_path + '/img_struc_noisy')
            x_color = np.squeeze(x)
        elif conf.dataset == 'sprites':
            (x_color, counts,
             _), (x_test, c_test,
                  _) = datasets.make_sprites(50000, path=conf.data_path)
            if conf.noise:
                x_color = datasets.add_noise(x_color)
            x = visualize.rgb2gray(x_color)
            x = np.clip(x, 0.0, 1.0)
            x_test = visualize.rgb2gray(x_test)
            x_test = np.clip(x_test, 0.0, 1.0)
            if conf.noise:
                x = datasets.add_noise(x)
                x_test = datasets.add_noise(x_test)
                x_color = datasets.add_noise(x_color)
        elif conf.dataset == 'omniglot':
            x = 1 - datasets.load_omniglot(path=conf.data_path)
            counts = np.ones(x.shape[0], dtype=np.int32)
            x_color = np.squeeze(x)
        elif conf.dataset == 'svhn':
            x, counts, objects, bgs = datasets.load_svhn(path=conf.data_path)
            self.pretrain(x, objects, bgs)
            x_color = np.squeeze(x)
        else:
            raise ValueError('unknown dataset', conf.dataset)

        self.x, self.x_color, self.counts = x, x_color, counts
        self.x_test, self.c_test = x_test, c_test
        self.bboxes = bboxes

        print('Built model')
        self.obj_reconstructor = SpnReconstructor(self.mdl.obj_spn)
        self.bg_reconstructor = SpnReconstructor(self.mdl.bg_spn)

        tfgraph = tf.get_default_graph()
        self.tensors_of_interest = {
            'z_where': tfgraph.get_tensor_by_name('z_where:0'),
            'z_pres': tfgraph.get_tensor_by_name('z_pres:0'),
            'bg_score': tfgraph.get_tensor_by_name('bg_score:0'),
            'y': tfgraph.get_tensor_by_name('y:0'),
            'obj_vis': tfgraph.get_tensor_by_name('obj_vis:0'),
            'bg_maps': tfgraph.get_tensor_by_name('bg_maps:0')
        }
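
# Hypothetical helper showing how the tensors collected above can be fetched
# for one input batch with a single TF1 session.run; `trainer` stands for an
# instance of the class in this example.
def fetch_interest_tensors(trainer, batch):
    # session.run accepts a dict of fetches and returns a matching dict
    return trainer.sess.run(trainer.tensors_of_interest,
                            feed_dict={trainer.in_ph: batch})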
Example #6
def run_experiment(settings):
    ############################################################################

    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=1,
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              online_em_frequency=online_em_frequency,
                              online_em_stepsize=online_em_stepsize,
                              use_em=True)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)

    num_params = EinsumNetwork.eval_size(einet)

    data_dir = '../src/experiments/round5/data/weights_analysis/'
    os.makedirs(data_dir, exist_ok=True)  # ensure the output directory exists
    data_file = os.path.join(data_dir, "weights_before.json")
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(data_file, weights[0])

    # Train
    ######################################

    # defined for the SGD variant; the EM loop below never steps it
    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            # in EM mode the gradients of the log-likelihood provide the
            # expected statistics consumed by em_process_batch()
            log_likelihood.backward()
            total_loss -= log_likelihood.detach().item()

            einet.em_process_batch()
        einet.em_update()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    data_dir = '../src/experiments/round5/data/weights_analysis/'
    data_file = os.path.join(data_dir, f"weights_after.json")
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(data_file, weights[0])
    # exit()

    ################
    # Experiment 1 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 1: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 2: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]
    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make a single EinsumNetwork        #
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=len(classes),
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              use_em=False)

    einet = EinsumNetwork.EinsumNetwork(graph, args)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)
    print(einet)

    num_params = EinsumNetwork.eval_size(einet)

    #################################
    # Discriminative training phase #
    #################################

    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            optimizer.zero_grad()
            outputs = einet.forward(batch_x)
            target = torch.tensor([
                classes.index(train_labels[i]) for i in idx
            ]).to(torch.device(device))
            loss = loss_function(outputs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().item()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    ################
    # Experiment 5 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 5: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 6 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 6: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
num_sums = 40

exponential_family = EinsumNetwork.NormalArray
exponential_family_args = {'min_var': 1e-6, 'max_var': 0.01}

num_epochs = 3
batch_size = 10
online_em_frequency = 50
online_em_stepsize = 0.5

height = 32
width = 32
##################################################################

print("loading data")
train_x_all, train_labels, test_x_all, test_labels, extra_x, extra_labels = datasets.load_svhn()

valid_x_all = train_x_all[50000:, ...]
train_x_all = np.concatenate((train_x_all[0:50000, ...], extra_x), 0)

train_x_all = train_x_all.reshape(train_x_all.shape[0], height, width, 3)
valid_x_all = valid_x_all.reshape(valid_x_all.shape[0], height, width, 3)
test_x_all = test_x_all.reshape(test_x_all.shape[0], height, width, 3)
print("done")


def get_clusters(train_x, num_clusters=100):
    cluster_path = "../auxiliary/svhn"
    filename = os.path.join(cluster_path, "kmeans_{}.pkl".format(num_clusters))

    if not os.path.isfile(filename):
        # prepare data for semi-supervised learning
        x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(
            train_x, train_t, num_classes, num_labelled, ssl_data_seed)
        y_labelled = np.int32(y_labelled)

if dataset == 'svhn':
    colorImg = True
    dim_input = (32, 32)
    in_channels = 3
    num_classes = 10
    generation_scale = False
    num_generation = num_classes * num_classes
    vis_epoch = 10
    distribution = 'bernoulli'
    num_features = in_channels * dim_input[0] * dim_input[1]
    print "Using svhn dataset"
    train_x, train_t, valid_x, valid_t, test_x, test_t, avg = load_svhn(
        normalized=True, centered=False)
    if flag == 'validation':
        test_x = valid_x
        test_t = valid_t
    else:
        train_x = np.concatenate([train_x, valid_x])
        train_t = np.hstack((train_t, valid_t))
    train_x_size = train_t.shape[0]
    train_t = np.int32(train_t)
    test_t = np.int32(test_t)
    train_x = train_x.astype(theano.config.floatX)
    test_x = test_x.astype(theano.config.floatX)
    train_x = train_x.reshape((-1, in_channels) + dim_input)
    test_x = test_x.reshape((-1, in_channels) + dim_input)
    # prepare data for semi-supervised learning
    x_labelled, y_labelled, x_unlabelled, _ = create_ssl_data(
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
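        # shift 8-bit pixel values from [0, 255] to the zero-centered range [-0.5, 0.5]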
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # valid_x = train_x[-10000:, :]
    # train_x = train_x[-12000:-11000, :]

    # valid_labels = train_labels[-10000:]
    # train_labels = train_labels[-12000:-11000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()

        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize,
            use_em=False)
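        # use_em=False: this einet is trained by SGD below, so the
        # online-EM settings passed above should have no effect here.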

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()
    """Code for weight analysis, section 7.3"""
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    # weights = {}
    # for (einet, c) in zip(einets, classes):
    #     einet_weights = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             einet_weights[i] = l.reparam(l.params.data)
    #         else:
    #             einet_weights[i] = l.ef_array.reparam(l.ef_array.params.data)
    #     weights[c] = einet_weights

    ##################
    # Training phase #
    ##################

    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)
    """ Learning each sub Network Generatively """

    start_time = time.time()

    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        valid_x_c = valid_x[[l == c for l in valid_labels]]
        test_x_c = test_x[[l == c for l in test_labels]]

        train_N = train_x_c.shape[0]
        valid_N = valid_x_c.shape[0]
        test_N = test_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = -log_likelihood
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            #     einet.em_process_batch()
            # einet.em_update()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()
    """Code for weight analysis"""
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # percentual_change = []
    # percentual_change_formatted = []
    # data_file = os.path.join(data_dir, f"percentual_change.json")
    # for (einet, c) in zip(einets, classes):
    #     einet_change = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             weights_new = l.reparam(l.params.data)
    #         else:
    #             weights_new = l.ef_array.reparam(l.ef_array.params.data)
    #         change = torch.mean(torch.abs(weights[c][i] - weights_new)/weights[c][i]).item()
    #         einet_change[i] = change
    #         if i == 0:
    #             percentual_change_formatted.append(change)

    #     percentual_change.append(einet_change)

    # print(percentual_change)
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change, f)
    # data_file = os.path.join(data_dir, f"percentual_change_formatted.json")
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change_formatted, f)

    ################
    # Experiment 3 #
    ################
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
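    # batched log-likelihoods of the full mixture; per-sample averages are printed below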
    train_ll = mixture.eval_loglikelihood_batched(train_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    valid_ll = mixture.eval_loglikelihood_batched(valid_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                 batch_size=batch_size,
                                                 skip_reparam=True)
    print()
    print(
        "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 4 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

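    # classification accuracy of the class-conditional mixture on each split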
    acc_train = mixture.eval_accuracy_batched(classes,
                                              train_x,
                                              train_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_valid = mixture.eval_accuracy_batched(classes,
                                              valid_x,
                                              valid_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_test = mixture.eval_accuracy_batched(classes,
                                             test_x,
                                             test_labels,
                                             batch_size=batch_size,
                                             skip_reparam=True)
    print()
    print(
        "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
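

# Hypothetical driver (illustration only): a minimal settings object carrying
# exactly the fields run_experiment() reads above. The field values are
# example choices, not the original experiment's configuration.
from types import SimpleNamespace

if __name__ == '__main__':
    settings = SimpleNamespace(
        fashion_mnist=False,
        svhn=True,
        exponential_family=EinsumNetwork.NormalArray,
        classes=None,               # None: use all classes found in the labels
        K=10,                       # sums / input distributions per region
        structure='poon-domingos',  # or 'binary-trees'
        pd_num_pieces=[4],
        depth=3,                    # only used for 'binary-trees'
        num_repetitions_mixture=10,
        width=32,
        height=32,
        num_epochs=3,
        batch_size=100,
        online_em_frequency=50,
        online_em_stepsize=0.5,
        SGD_learning_rate=0.1,
    )
    results = run_experiment(settings)
    print(results['test_ll'], results['test_acc'])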