def save_einet(einet, graph, out_path):
    """Write an EiNet and its region graph into the directory ``out_path``.

    The directory is created if it does not exist. The graph goes to
    ``einet.rg`` via ``Graph.write_gpickle`` and the model object goes to
    ``einet.pth`` via ``torch.save``; each destination is printed after it
    has been written.

    Args:
        einet: model object handed to ``torch.save``.
        graph: region graph handed to ``Graph.write_gpickle``.
        out_path: target directory.
    """
    os.makedirs(out_path, exist_ok=True)

    # (log template, file name, payload, writer) -- handled in this order.
    artifacts = (
        ("Saved PC graph to {}", "einet.rg", graph, Graph.write_gpickle),
        ("Saved model to {}", "einet.pth", einet, torch.save),
    )
    for template, file_name, payload, writer in artifacts:
        destination = os.path.join(out_path, file_name)
        writer(payload, destination)
        print(template.format(destination))
def load_einet_state(model_path,
                     einet_file='einet.pth',
                     graph_file='einet.rg',
                     n_vars=None, n_classes=None, n_sums=None, n_input_dists=None,
                     exp_fam=None, exp_fam_args=None,
                     use_em=None, em_freq=None, em_stepsize=None,
                     graph=None):
    """Rebuild an EiNet with ``make_einet`` and load a saved state dict into it.

    Args:
        model_path: directory containing the saved files.
        einet_file: name of the state-dict file inside ``model_path``.
        graph_file: name of the pickled graph inside ``model_path``; only
            consulted when ``graph`` is ``None``.
        n_vars, n_classes, n_sums, n_input_dists, exp_fam, exp_fam_args,
        use_em, em_freq, em_stepsize: forwarded unchanged to ``make_einet``.
        graph: an already constructed region graph; when given, nothing is
            read from ``graph_file``.

    Returns:
        Tuple ``(einet, graph)``.

    Raises:
        ValueError: if neither ``graph`` nor a non-empty ``graph_file`` is
            supplied.
    """
    if graph is None:
        # Without a graph object or a graph file there is nothing to build from.
        if not graph_file:
            raise ValueError("Cannot create graph")
        graph = Graph.read_gpickle(os.path.join(model_path, graph_file))

    state_path = os.path.join(model_path, einet_file)

    hyper_params = dict(
        n_vars=n_vars,
        n_classes=n_classes,
        n_sums=n_sums,
        n_input_dists=n_input_dists,
        exp_fam=exp_fam,
        exp_fam_args=exp_fam_args,
        use_em=use_em,
        em_freq=em_freq,
        em_stepsize=em_stepsize,
    )
    network = make_einet(graph, **hyper_params)
    network.load_state_dict(torch.load(state_path))

    print("Loaded model from {}".format(state_path))

    return network, graph
def make_region_graph(structure="poon-domingos",
                      height=None, width=None,
                      pd_pieces=None, depth=None, n_repetitions=None):
    """Create a region graph of the requested structure.

    Args:
        structure: ``'poon-domingos'`` or ``'binary-trees'``.
        height, width: grid dimensions. They form the ``shape`` of the
            Poon-Domingos structure, and their product is the number of
            variables of the random binary trees.
        pd_pieces: divisors used to derive the Poon-Domingos ``delta`` values
            (``[height / d, width / d]`` for each ``d``); required for
            ``'poon-domingos'``.
        depth, n_repetitions: forwarded to ``Graph.random_binary_trees``.

    Returns:
        The graph produced by ``Graph.poon_domingos_structure`` or
        ``Graph.random_binary_trees``.

    Raises:
        AssertionError: if ``structure`` is unknown, or if ``pd_pieces`` is
            missing for the ``'poon-domingos'`` structure.
    """
    if structure == 'poon-domingos':
        # Raised explicitly (not via ``assert``) so the check also runs under
        # ``python -O``; the exception type is the one ``assert`` would raise.
        if pd_pieces is None:
            raise AssertionError(
                "pd_pieces is required for the 'poon-domingos' structure")
        pd_delta = [[height / d, width / d] for d in pd_pieces]
        return Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)

    if structure == 'binary-trees':
        n_vars = height * width
        return Graph.random_binary_trees(num_var=n_vars,
                                         depth=depth,
                                         num_repetitions=n_repetitions)

    raise AssertionError("Unknown Structure")
def load_einet(model_path, einet_file='einet.pth', graph_file='einet.rg'):
    """Load a saved EiNet object and, optionally, its region graph.

    Args:
        model_path: directory containing the saved files.
        einet_file: file name of the object stored with ``torch.save``.
        graph_file: file name of the pickled graph; pass a falsy value to
            skip loading the graph.

    Returns:
        Tuple ``(einet, graph)``; ``graph`` is ``None`` when ``graph_file``
        is falsy.
    """
    checkpoint_path = os.path.join(model_path, einet_file)
    # NOTE(review): torch.load may unpickle arbitrary objects (depends on the
    # torch version / weights_only setting) -- only load trusted files.
    network = torch.load(checkpoint_path)
    print("Loaded model from {}".format(checkpoint_path))

    region_graph = (
        Graph.read_gpickle(os.path.join(model_path, graph_file))
        if graph_file
        else None
    )
    return network, region_graph
    def __init__(self, graph, args=None):
        """Build an EinsumNetwork from a region graph.

        Validates ``graph``, assigns ``num_dist`` to every leaf and sum node,
        arranges the graph in topological layers and creates the torch layers:
        one leaf layer, an einsum layer for every product layer, and a mixing
        layer for every sum layer containing nodes with more than one
        successor.

        Args:
            graph: region graph; must pass ``Graph.check_graph`` and have
                exactly one root whose scope equals
                ``tuple(range(args.num_var))``. Its nodes are modified in
                place (``num_dist`` is set).
            args: hyper-parameter container; ``Args()`` is used when omitted.

        Raises:
            AssertionError: if the graph check fails, the graph does not have
                exactly one root, or the root scope is not
                ``tuple(range(num_var))``.
        """
        super(EinsumNetwork, self).__init__()

        is_valid, reason = Graph.check_graph(graph)
        if not is_valid:
            raise AssertionError(reason)
        self.graph = graph

        self.args = Args() if args is None else args

        roots = Graph.get_roots(self.graph)
        if len(roots) != 1:
            raise AssertionError("Currently only EinNets with single root node supported.")

        root = roots[0]
        if tuple(range(self.args.num_var)) != root.scope:
            raise AssertionError("The graph should be over tuple(range(num_var)).")

        # Leaves carry the input distributions ...
        for leaf in Graph.get_leaves(self.graph):
            leaf.num_dist = self.args.num_input_distributions

        # ... the root sum has one output per class, all other sums use num_sums.
        for sum_node in Graph.get_sums(self.graph):
            sum_node.num_dist = (
                self.args.num_classes if sum_node is root else self.args.num_sums
            )

        # Algorithm 1 in the paper -- organize the PC in layers
        self.graph_layers = Graph.topological_layers(self.graph)

        # input layer
        leaf_layer = FactorizedLeafLayer(self.graph_layers[0],
                                         self.args.num_var,
                                         self.args.num_dims,
                                         self.args.exponential_family,
                                         self.args.exponential_family_args,
                                         use_em=self.args.use_em)
        layers = [leaf_layer]

        # internal layers: product and sum layers alternate, products first
        for position, nodes in enumerate(self.graph_layers[1:]):
            if position % 2 == 0:
                layers.append(
                    EinsumLayer(self.graph, nodes, layers, use_em=self.args.use_em))
                continue

            # the Mixing layer is only for regions which have multiple partitions as children.
            multi_sums = [n for n in nodes if len(graph.succ[n]) > 1]
            if multi_sums:
                layers.append(
                    EinsumMixingLayer(graph, multi_sums, layers[-1], use_em=self.args.use_em))

        self.einet_layers = torch.nn.ModuleList(layers)
        self.em_set_hyperparams(self.args.online_em_frequency, self.args.online_em_stepsize)
# Example #6
# 0
def run_experiment(settings):
    """Train a single EiNet discriminatively and evaluate it.

    Loads the dataset selected by ``settings.fashion_mnist`` /
    ``settings.svhn`` (falling back to ``datasets.load_mnist()``), holds out
    the last 10000 training samples for validation, optionally restricts all
    splits to ``settings.classes``, builds one EiNet whose root has one output
    per class, trains it with SGD on a cross-entropy loss and reports
    log-likelihoods and classification accuracies.

    Args:
        settings: object exposing ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions``, ``width``,
            ``height``, ``num_epochs``, ``batch_size`` and
            ``SGD_learning_rate``.

    Returns:
        dict with ``train_ll`` / ``valid_ll`` / ``test_ll`` (log-likelihood
        divided by the number of samples of the split), ``train_acc`` /
        ``valid_acc`` / ``test_acc``, ``network_size`` and ``training_time``
        (seconds spent in the training loop).

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        # the extra SVHN split is returned by the loader but not used here
        train_x, train_labels, test_x, test_labels, _extra_x, _extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves: rescale by 1/255 and shift by -0.5 (done before the
    # validation split, so the validation data is rescaled as well).
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[
            np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[
            np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[
            np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Keep only labels of the selected classes (as plain lists), matching the
    # rows kept above.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=len(classes),
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              use_em=False)

    einet = EinsumNetwork.EinsumNetwork(graph, args)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)
    print(einet)

    num_params = EinsumNetwork.eval_size(einet)

    #################################
    # Discriminative training phase #
    #################################

    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    # Class index of every training sample, computed once; batches gather from
    # it by index instead of rebuilding the targets in Python for each batch.
    train_targets = torch.tensor(
        [classes.index(label) for label in train_labels],
        dtype=torch.long).to(torch.device(device))

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            optimizer.zero_grad()
            outputs = einet.forward(batch_x)
            target = train_targets[idx]
            loss = loss_function(outputs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().item()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    ################
    # Experiment 5 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 5: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 6 #
    ################
    train_labels_t = torch.tensor(train_labels).to(torch.device(device))
    valid_labels_t = torch.tensor(valid_labels).to(torch.device(device))
    test_labels_t = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels_t,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels_t,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels_t,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 6: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
# Example #7
# 0
    def construct_network_and_train(
            num_starts, svhn, exponential_family, classes, K, structure,
            pd_num_pieces, depth, num_repetitions, width, height, num_epochs,
            batch_size, online_em_frequency, online_em_stepsize,
            exponential_family_args, train_x, train_labels, test_x,
            test_labels, valid_x, valid_labels):
        """Train a per-class EiNet mixture with EM for several restarts.

        For every restart one EiNet is built per entry of ``classes``, the
        EiNets are combined into an ``EinetMixture`` weighted by the relative
        class frequencies in ``train_labels``, each EiNet is trained with EM on
        the training samples of its own class, and the mixture is evaluated on
        the train, validation and test splits.

        Args:
            num_starts: number of restarts.
            svhn: when truthy the EiNets use ``num_dims=3``, otherwise 1.
            exponential_family, exponential_family_args, K,
            online_em_frequency, online_em_stepsize: forwarded to
                ``EinsumNetwork.Args`` (``K`` is used both for ``num_sums``
                and ``num_input_distributions``).
            classes: class labels, one EiNet each.
            structure: ``'poon-domingos'`` or ``'binary-trees'``.
            pd_num_pieces, width, height: Poon-Domingos structure parameters.
            depth, num_repetitions: random-binary-tree structure parameters.
            num_epochs, batch_size: training loop parameters.
            train_x, valid_x, test_x: data tensors, indexed by sample along the
                first axis.
            train_labels, valid_labels, test_labels: per-sample labels;
                ``train_labels`` must support ``.count`` (e.g. a list).

        Returns:
            dict with per-restart lists ``train_ll``, ``valid_ll``,
            ``test_ll``, ``train_acc``, ``valid_acc``, ``test_acc`` and
            ``training_time``, plus ``network_size`` (taken from the first
            restart; ``None`` if there were no restarts).

        Raises:
            AssertionError: if ``structure`` is unknown.
        """
        train_lls = []
        valid_lls = []
        test_lls = []
        train_accs = []
        valid_accs = []
        test_accs = []
        training_times = []
        num_params = None

        # Split sizes and label tensors are identical for every restart.
        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        train_labels_tensor = torch.tensor(train_labels).to(
            torch.device(device))
        valid_labels_tensor = torch.tensor(valid_labels).to(
            torch.device(device))
        test_labels_tensor = torch.tensor(test_labels).to(
            torch.device(device))

        for s in range(num_starts):
            print(f"""
                Running start: {s}/{num_starts}
                """)

            ######################################
            # Make EinsumNetworks for each class #
            ######################################
            einets = []
            class_counts = []
            for c in classes:
                if structure == 'poon-domingos':
                    pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                    graph = Graph.poon_domingos_structure(shape=(height,
                                                                 width),
                                                          delta=pd_delta)
                elif structure == 'binary-trees':
                    graph = Graph.random_binary_trees(
                        num_var=train_x.shape[1],
                        depth=depth,
                        num_repetitions=num_repetitions)
                else:
                    raise AssertionError("Unknown Structure")

                args = EinsumNetwork.Args(
                    num_var=train_x.shape[1],
                    num_dims=3 if svhn else 1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

                einet = EinsumNetwork.EinsumNetwork(graph, args)

                init_dict = get_init_dict(einet,
                                          train_x,
                                          train_labels=train_labels,
                                          einet_class=c)
                einet.initialize(init_dict)
                einet.to(device)
                einets.append(einet)

                # Calculate amount of training samples per class
                class_counts.append(train_labels.count(c))

            # normalize the class counts, construct mixture component
            total_count = sum(class_counts)
            ps = [count / total_count for count in class_counts]
            ps = torch.tensor(ps).to(torch.device(device))
            mixture = EinetMixture(ps, einets, classes=classes)

            if num_params is None:
                num_params = mixture.eval_size()

            ##################
            # Training phase #
            ##################
            # Learning each sub network generatively

            start_time = time.time()

            for (einet, c) in zip(einets, classes):
                # training samples of this class only
                train_x_c = train_x[[l == c for l in train_labels]]
                class_N = train_x_c.shape[0]

                for epoch_count in range(num_epochs):
                    idx_batches = torch.randperm(
                        class_N, device=device).split(batch_size)

                    total_ll = 0.0
                    for idx in idx_batches:
                        batch_x = train_x_c[idx, :]
                        outputs = einet.forward(batch_x)
                        ll_sample = EinsumNetwork.log_likelihoods(outputs)
                        log_likelihood = ll_sample.sum()
                        log_likelihood.backward()

                        einet.em_process_batch()
                        total_ll += log_likelihood.detach().item()
                    einet.em_update()
                    print(
                        f'[{epoch_count}]   total log-likelihood: {total_ll}')

            end_time = time.time()

            ################
            # Experiment 3 #
            ################
            mixture.eval()
            train_ll = mixture.eval_loglikelihood_batched(
                train_x, batch_size=batch_size)
            valid_ll = mixture.eval_loglikelihood_batched(
                valid_x, batch_size=batch_size)
            test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                         batch_size=batch_size)
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(train_ll / train_N, valid_ll / valid_N,
                        test_ll / test_N))

            ################
            # Experiment 4 #
            ################
            acc_train = mixture.eval_accuracy_batched(classes,
                                                      train_x,
                                                      train_labels_tensor,
                                                      batch_size=batch_size)
            acc_valid = mixture.eval_accuracy_batched(classes,
                                                      valid_x,
                                                      valid_labels_tensor,
                                                      batch_size=batch_size)
            acc_test = mixture.eval_accuracy_batched(classes,
                                                     test_x,
                                                     test_labels_tensor,
                                                     batch_size=batch_size)
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(acc_train, acc_valid, acc_test))

            print()
            print(f'Network size: {num_params} parameters')
            print(f'Training time: {end_time - start_time}s')

            train_lls.append(train_ll / train_N)
            valid_lls.append(valid_ll / valid_N)
            test_lls.append(test_ll / test_N)
            train_accs.append(acc_train)
            valid_accs.append(acc_valid)
            test_accs.append(acc_test)
            training_times.append(end_time - start_time)

        return {
            'train_ll': train_lls,
            'valid_ll': valid_lls,
            'test_ll': test_lls,
            'train_acc': train_accs,
            'valid_acc': valid_accs,
            'test_acc': test_accs,
            'network_size': num_params,
            'training_time': training_times,
        }
# Example #8
# 0
def run_experiment(settings):
    """Train a per-class EiNet mixture with SGD and track metrics per epoch.

    Loads the dataset selected by ``settings.fashion_mnist`` /
    ``settings.svhn`` (falling back to ``datasets.load_mnist()``), holds out
    the last 10000 training samples for validation, optionally restricts all
    splits to ``settings.classes``, builds one EiNet per class and combines
    them in an ``EinetMixture`` weighted by the training class frequencies.
    The mixture is evaluated once after initialization and once after every
    epoch; in each epoch every EiNet is trained on the samples of its own
    class by minimizing the negative log-likelihood with SGD.

    Args:
        settings: object exposing ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions_mixture``,
            ``width``, ``height``, ``num_epochs``, ``batch_size`` and
            ``SGD_learning_rate``.

    Returns:
        dict with the lists ``train_lls``, ``valid_lls``, ``test_lls``,
        ``train_accs``, ``valid_accs``, ``test_accs`` (first entry: after
        initialization, then one entry per epoch), ``network_size`` and
        ``training_time`` (seconds spent in the training part of the epochs,
        excluding the per-epoch evaluation).

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        # the extra SVHN split is returned by the loader but not used here
        train_x, train_labels, test_x, test_labels, _extra_x, _extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    # Gaussian leaves: rescale by 1/255 and shift by -0.5 (done before the
    # validation split, so the validation data is rescaled as well).
    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()

    # Keep only labels of the selected classes (as plain lists), matching the
    # rows kept above.
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    class_counts = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
                num_var=train_x.shape[1],
                num_dims=3 if svhn else 1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        class_counts.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize the class counts, construct mixture component
    total_count = sum(class_counts)
    ps = [count / total_count for count in class_counts]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    #################################
    # Evaluate after initialization #
    #################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size, skip_reparam=True)
    valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size, skip_reparam=True)
    test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_ll_before / train_N,
            valid_ll_before / valid_N,
            test_ll_before / test_N))
    train_lls.append(train_ll_before / train_N)
    valid_lls.append(valid_ll_before / valid_N)
    test_lls.append(test_ll_before / test_N)

    ################
    # Experiment 4 #
    ################
    # Label tensors are reused for every evaluation below.
    train_labels_t = torch.tensor(train_labels).to(torch.device(device))
    valid_labels_t = torch.tensor(valid_labels).to(torch.device(device))
    test_labels_t = torch.tensor(test_labels).to(torch.device(device))

    acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labels_t, batch_size=batch_size, skip_reparam=True)
    acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labels_t, batch_size=batch_size, skip_reparam=True)
    acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labels_t, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            acc_train_before,
            acc_valid_before,
            acc_test_before))
    train_accs.append(acc_train_before)
    valid_accs.append(acc_valid_before)
    test_accs.append(acc_test_before)
    mixture.train()

    ##################
    # Training phase #
    ##################
    # Learning each sub network generatively

    # NOTE(review): if EinetMixture registers the EiNets as sub-modules,
    # mixture.parameters() repeats parameters already collected here -- confirm.
    sub_net_parameters = [p for einet in mixture.einets for p in einet.parameters()]
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Time spent training, accumulated per epoch so that the evaluation done
    # after each epoch is not counted.
    training_time = 0.0

    # Start from the pre-training metrics so the final report is also defined
    # when num_epochs is 0.
    train_ll, valid_ll, test_ll = train_ll_before, valid_ll_before, test_ll_before
    acc_train, acc_valid, acc_test = acc_train_before, acc_valid_before, acc_test_before

    for epoch_count in range(num_epochs):
        epoch_start = time.time()
        for (einet, c) in zip(einets, classes):
            # training samples of this class only
            train_x_c = train_x[[l == c for l in train_labels]]
            class_N = train_x_c.shape[0]

            idx_batches = torch.randperm(class_N, device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}]   total loss: {total_loss}')
        training_time += time.time() - epoch_start

        # evaluate the whole mixture after this epoch
        mixture.eval()
        train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        train_lls.append(train_ll / train_N)
        valid_lls.append(valid_ll / valid_N)
        test_lls.append(test_ll / test_N)

        acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labels_t, batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labels_t, batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labels_t, batch_size=batch_size)
        train_accs.append(acc_train)
        valid_accs.append(acc_valid)
        test_accs.append(acc_test)
        mixture.train()

    print()
    print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
            train_ll / train_N,
            valid_ll / valid_N,
            test_ll / test_N))

    print()
    print("Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
            acc_train,
            acc_valid,
            acc_test))

    print(f'Network size: {num_params} parameters')
    print(f'Training time: {training_time}s')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'network_size': num_params,
        'training_time': training_time,
    }
# Example #9
# 0
    def new_start(start_train_set, online_offset):
        """Build and initialise one EinsumNetwork per class and log baseline metrics.

        Configuration is read from ``settings``. The average log-likelihoods
        and the classification accuracies of the freshly initialised mixture
        are printed and appended to ``train_lls`` / ``valid_lls`` /
        ``test_lls`` and ``train_accs`` / ``valid_accs`` / ``test_accs``,
        which are taken from the enclosing scope.

        Args:
            start_train_set: subtracted from the number of rows removed from
                the end of the training data.
            online_offset: added (on top of the 10000 validation rows) to the
                number of rows removed from the end of the training data.
        """
        # ---- configuration: every attribute is read up front ----
        fashion_mnist = settings.fashion_mnist
        svhn = settings.svhn
        exponential_family = settings.exponential_family
        classes = settings.classes
        K = settings.K
        structure = settings.structure
        pd_num_pieces = settings.pd_num_pieces  # only used by 'poon-domingos'
        depth = settings.depth  # only used by 'binary-trees'
        num_repetitions = settings.num_repetitions_mixture  # 'binary-trees'
        width = settings.width
        height = settings.height
        num_epochs = settings.num_epochs  # read here, not used below
        batch_size = settings.batch_size
        online_em_frequency = settings.online_em_frequency
        online_em_stepsize = settings.online_em_stepsize

        # ---- leaf-distribution arguments ----
        if exponential_family == EinsumNetwork.NormalArray:
            exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}
        elif exponential_family == EinsumNetwork.CategoricalArray:
            exponential_family_args = {'K': 256}
        elif exponential_family == EinsumNetwork.BinomialArray:
            exponential_family_args = {'N': 255}
        else:
            exponential_family_args = None

        # ---- data ----
        if fashion_mnist:
            train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
        elif svhn:
            (train_x, train_labels, test_x, test_labels,
             _extra_x, _extra_labels) = datasets.load_svhn()
        else:
            train_x, train_labels, test_x, test_labels = datasets.load_mnist()

        if exponential_family == EinsumNetwork.NormalArray:
            # Rescale in place: divide by 255, then shift by -0.5.
            train_x /= 255.
            train_x -= .5
            test_x /= 255.
            test_x -= .5

        # The last 10000 rows are the validation set; the training set loses
        # those rows plus (online_offset - start_train_set) more from its end.
        n_valid = 10000
        n_cut = n_valid + online_offset - start_train_set
        valid_x, valid_labels = train_x[-n_valid:, :], train_labels[-n_valid:]
        train_x, train_labels = train_x[:-n_cut, :], train_labels[:-n_cut]

        # ---- restrict to the selected classes ----
        if classes is None:
            # No explicit selection: use every label present in the training data.
            classes = np.unique(train_labels).tolist()
        else:
            def rows_in_classes(labels):
                # Boolean row mask: True where the label is one of `classes`.
                return np.any(np.stack([labels == c for c in classes], 1), 1)

            train_x = train_x[rows_in_classes(train_labels), :]
            valid_x = valid_x[rows_in_classes(valid_labels), :]
            test_x = test_x[rows_in_classes(test_labels), :]

        # From here on the labels are plain Python lists.
        train_labels = [lbl for lbl in train_labels if lbl in classes]
        valid_labels = [lbl for lbl in valid_labels if lbl in classes]
        test_labels = [lbl for lbl in test_labels if lbl in classes]

        target = torch.device(device)
        train_x = torch.from_numpy(train_x).to(target)
        valid_x = torch.from_numpy(valid_x).to(target)
        test_x = torch.from_numpy(test_x).to(target)

        ######################################
        # Make EinsumNetworks for each class #
        ######################################
        num_var = train_x.shape[1]
        einets = []
        class_counts = []
        for c in classes:
            # A fresh region graph is built for every class.
            if structure == 'poon-domingos':
                pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            elif structure == 'binary-trees':
                graph = Graph.random_binary_trees(
                    num_var=num_var, depth=depth, num_repetitions=num_repetitions)
            else:
                raise AssertionError("Unknown Structure")

            args = EinsumNetwork.Args(
                num_var=num_var,
                num_dims=3 if svhn else 1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                online_em_frequency=online_em_frequency,
                online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)
            einet.initialize(
                get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c))
            einet.to(device)
            einets.append(einet)

            # Number of training samples of this class (mixture weight numerator).
            class_counts.append(train_labels.count(c))

            print(f'Einsum network for class {c}:')
            print(einet)

        # Mixture weights are the relative class frequencies.
        total_count = sum(class_counts)
        ps = torch.tensor([n / total_count for n in class_counts]).to(target)
        mixture = EinetMixture(ps, einets, classes=classes)

        num_params = mixture.eval_size()

        #################################
        # Evaluate after initialization #
        #################################
        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        mixture.eval()
        train_ll_avg = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size) / train_N
        valid_ll_avg = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size) / valid_N
        test_ll_avg = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size) / test_N
        print()
        print("Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}".format(
                train_ll_avg,
                valid_ll_avg,
                test_ll_avg))
        train_lls.append(train_ll_avg)
        valid_lls.append(valid_ll_avg)
        test_lls.append(test_ll_avg)

        # Classification accuracy of the mixture on each split.
        train_label_t = torch.tensor(train_labels).to(target)
        valid_label_t = torch.tensor(valid_labels).to(target)
        test_label_t = torch.tensor(test_labels).to(target)

        acc_train_before = mixture.eval_accuracy_batched(
            classes, train_x, train_label_t, batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(
            classes, valid_x, valid_label_t, batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(
            classes, test_x, test_label_t, batch_size=batch_size)
        print()
        print("Experiment 8: Classification accuracies  --- train acc {}   valid acc {}   test acc {}".format(
                acc_train_before,
                acc_valid_before,
                acc_test_before))
        train_accs.append(acc_train_before)
        valid_accs.append(acc_valid_before)
        test_accs.append(acc_test_before)
Пример #10
0
    train_x = train_x[np.any(np.stack([train_labels == c
                                       for c in classes], 1), 1), :]
    valid_x = valid_x[np.any(np.stack([valid_labels == c
                                       for c in classes], 1), 1), :]
    test_x = test_x[np.any(np.stack([test_labels == c
                                     for c in classes], 1), 1), :]

train_x = torch.from_numpy(train_x).to(torch.device(device))
valid_x = torch.from_numpy(valid_x).to(torch.device(device))
test_x = torch.from_numpy(test_x).to(torch.device(device))

# Make EinsumNetwork
######################################
if structure == 'poon-domingos':
    pd_delta = [[height / d, width / d] for d in pd_num_pieces]
    graph = Graph.poon_domingos_structure(shape=(height, width),
                                          delta=pd_delta)
elif structure == 'binary-trees':
    #graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
    graph = Graph.random_binary_trees(num_var=num_var,
                                      depth=depth,
                                      num_repetitions=num_repetitions)
else:
    raise AssertionError("Unknown Structure")

args = EinsumNetwork.Args(
    num_var=num_var,  #train_x.shape[1],
    num_dims=2 if use_pair else 1,
    num_classes=1,
    num_sums=K,
    num_input_distributions=K,
    exponential_family=exponential_family,
Пример #11
0
def run_experiment(settings):
    """Train a per-class mixture of EinsumNetworks with SGD, then adapt it online.

    Steps, in order:
      1. Load the dataset selected by ``settings`` and split it into a
         training part, an "online" part and a validation part.
      2. Build and initialise one EinsumNetwork per class and combine them in
         an ``EinetMixture`` weighted by the class frequencies.
      3. Train every sub-network on the samples of its own class with SGD.
      4. Reference pass: append the online batches to a *copy* of the
         training set without touching the model, evaluating after each batch.
      5. Adaptation pass: feed the same batches to ``online_update`` and
         append them to the training set, evaluating after each batch.

    Args:
        settings: object providing ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions_mixture``,
            ``width``, ``height``, ``num_epochs``, ``batch_size`` and
            ``SGD_learning_rate``.

    Returns:
        dict with the metric histories of the adaptation pass
        (``train_lls``, ``valid_lls``, ``test_lls``, ``train_accs``,
        ``valid_accs``, ``test_accs``), of the reference pass (same keys with
        a ``_ref`` suffix), ``network_size`` and ``online_samples`` (running
        count of online samples seen in the reference pass).

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    # ---- configuration ----
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    pd_num_pieces = settings.pd_num_pieces  # only used by 'poon-domingos'
    depth = settings.depth  # only used by 'binary-trees'
    num_repetitions_mixture = settings.num_repetitions_mixture  # 'binary-trees'
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # ---- data ----
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        (train_x, train_labels, test_x, test_labels,
         _extra_x, _extra_labels) = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        # Rescale in place: divide by 255, then shift by -0.5.
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # Split, counted from the end of the loaded training data:
    # last 10000 rows -> validation, the 1000 rows before them -> online
    # samples, the 45000 rows before those -> training.
    valid_x = train_x[-10000:, :]
    online_x = train_x[-11000:-10000, :]
    train_x = train_x[-56000:-11000, :]

    valid_labels = train_labels[-10000:]
    online_labels = train_labels[-11000:-10000]
    train_labels = train_labels[-56000:-11000]

    # ---- restrict to the selected classes ----
    if classes is None:
        # No explicit selection: use every label present in the training data.
        classes = np.unique(train_labels).tolist()
    else:
        def rows_in_classes(labels):
            # Boolean row mask: True where the label is one of `classes`.
            return np.any(np.stack([labels == c for c in classes], 1), 1)

        train_x = train_x[rows_in_classes(train_labels), :]
        online_x = online_x[rows_in_classes(online_labels), :]
        valid_x = valid_x[rows_in_classes(valid_labels), :]
        test_x = test_x[rows_in_classes(test_labels), :]

    # From here on the labels are plain Python lists.
    train_labels = [lbl for lbl in train_labels if lbl in classes]
    online_labels = [lbl for lbl in online_labels if lbl in classes]
    valid_labels = [lbl for lbl in valid_labels if lbl in classes]
    test_labels = [lbl for lbl in test_labels if lbl in classes]

    # Copy, do not alias: the reference list is extended in place during the
    # reference pass, and that must not also grow ``train_labels`` (which has
    # to stay in step with ``train_x``).
    train_labels_backup = list(train_labels)

    target = torch.device(device)
    train_x = torch.from_numpy(train_x).to(target)
    # Sharing the tensor is fine: it is only ever replaced via torch.cat,
    # never modified in place.
    train_x_backup = train_x
    online_x = torch.from_numpy(online_x).to(target)
    valid_x = torch.from_numpy(valid_x).to(target)
    test_x = torch.from_numpy(test_x).to(target)

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    class_counts = []
    for c in classes:
        # A fresh region graph is built for every class.
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Number of training samples of this class (mixture weight numerator).
        class_counts.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # Mixture weights are the relative class frequencies.
    total_count = sum(class_counts)
    ps = torch.tensor([n / total_count for n in class_counts]).to(target)
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    ##################
    # Training phase #
    ##################

    # One optimizer over the parameters of all sub-networks plus the mixture.
    sub_net_parameters = [
        param for einet in mixture.einets for param in einet.parameters()
    ]
    sub_net_parameters += list(mixture.parameters())
    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Each sub-network is trained generatively on the samples of its class.
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[lbl == c for lbl in train_labels]]
        train_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                # Minimise the negative log-likelihood of the batch.
                nll = ll_sample.sum() * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            # NOTE: the value printed here is the summed *negative*
            # log-likelihood of the epoch, despite the label.
            print(f'[{epoch_count}]   total log-likelihood: {total_loss}')

    ##############
    # Evaluation #
    ##############

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    train_lls_ref = []
    valid_lls_ref = []
    test_lls_ref = []
    train_accs_ref = []
    valid_accs_ref = []
    test_accs_ref = []
    added_samples = [0]

    def eval_network(do_print=False, no_OA=False):
        """Score the mixture on train/valid/test and append to the histories.

        With ``no_OA=True`` the reference training set (``train_x_backup`` /
        ``train_labels_backup``) is scored and the results go to the
        ``*_ref`` lists; otherwise the adapted training set and the plain
        lists are used. The mixture is put back into train mode at the end.
        """
        if no_OA:
            cur_x, cur_labels = train_x_backup, train_labels_backup
            ll_hist = (train_lls_ref, valid_lls_ref, test_lls_ref)
            acc_hist = (train_accs_ref, valid_accs_ref, test_accs_ref)
        else:
            cur_x, cur_labels = train_x, train_labels
            ll_hist = (train_lls, valid_lls, test_lls)
            acc_hist = (train_accs, valid_accs, test_accs)

        splits_x = (cur_x, valid_x, test_x)
        mixture.eval()

        # Average log-likelihood per sample on each split.
        avg_lls = [
            mixture.eval_loglikelihood_batched(x, batch_size=batch_size)
            / x.shape[0]
            for x in splits_x
        ]
        if do_print:
            print()
            print(
                "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
                .format(*avg_lls))
        for history, value in zip(ll_hist, avg_lls):
            history.append(value)

        # Classification accuracy on each split.
        splits_y = [
            torch.tensor(labels).to(torch.device(device))
            for labels in (cur_labels, valid_labels, test_labels)
        ]
        accs = [
            mixture.eval_accuracy_batched(classes, x, y, batch_size=batch_size)
            for x, y in zip(splits_x, splits_y)
        ]
        if do_print:
            print()
            print(
                "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
                .format(*accs))
        for history, value in zip(acc_hist, accs):
            history.append(value)
        mixture.train()

    eval_network(do_print=True, no_OA=False)
    eval_network(do_print=False, no_OA=True)

    # The same random batches of 20 online samples are used by both passes.
    idx_batches = torch.randperm(online_x.shape[0], device=device).split(20)

    ##############################################################
    # Reference pass: grow the training set, leave the model as is
    ##############################################################
    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        for c in classes:
            batch_x = online_x_idx[[lbl == c for lbl in online_labels_idx]]
            train_x_backup = torch.cat((train_x_backup, batch_x))
            train_labels_backup += [c] * batch_x.shape[0]

        added_samples.append(added_samples[-1] + len(idx))
        eval_network(do_print=False, no_OA=True)

    #####################
    # Online adaptation #
    #####################
    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]

        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[lbl == c for lbl in online_labels_idx]]
            online_update(einet, batch_x)
            train_x = torch.cat((train_x, batch_x))
            train_labels += [c] * batch_x.shape[0]

        eval_network(do_print=False, no_OA=False)

    print()
    print(f'Network size: {num_params} parameters')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'train_lls_ref': train_lls_ref,
        'valid_lls_ref': valid_lls_ref,
        'test_lls_ref': test_lls_ref,
        'train_accs_ref': train_accs_ref,
        'valid_accs_ref': valid_accs_ref,
        'test_accs_ref': test_accs_ref,
        'network_size': num_params,
        'online_samples': added_samples,
    }
Пример #12
0
def run_experiment(settings):
    """Train a single EinsumNetwork with EM and report likelihoods and accuracies.

    The network is built for the dataset selected by ``settings``, its
    last-layer parameters are dumped to text files before and after training,
    it is trained for ``settings.num_epochs`` epochs (EM statistics are
    collected per batch, the EM update is applied once per epoch), and it is
    finally evaluated on the train/validation/test splits.

    Args:
        settings: object providing ``fashion_mnist``, ``svhn``,
            ``exponential_family``, ``classes``, ``K``, ``structure``,
            ``pd_num_pieces``, ``depth``, ``num_repetitions``, ``width``,
            ``height``, ``num_epochs``, ``batch_size``,
            ``online_em_frequency`` and ``online_em_stepsize``.

    Returns:
        dict with ``train_ll`` / ``valid_ll`` / ``test_ll`` (log-likelihood
        divided by the number of samples), ``train_acc`` / ``valid_acc`` /
        ``test_acc``, ``network_size`` and ``training_time`` (seconds).

    Raises:
        AssertionError: if ``settings.structure`` is neither
            ``'poon-domingos'`` nor ``'binary-trees'``.
    """
    # ---- configuration ----
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    pd_num_pieces = settings.pd_num_pieces  # only used by 'poon-domingos'
    depth = settings.depth  # only used by 'binary-trees'
    num_repetitions = settings.num_repetitions  # only used by 'binary-trees'
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # ---- data ----
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        (train_x, train_labels, test_x, test_labels,
         _extra_x, _extra_labels) = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        # Rescale in place: divide by 255, then shift by -0.5.
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # The last 10000 training rows become the validation set.
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # ---- restrict to the selected classes ----
    if classes is None:
        # No explicit selection: use every label present in the training data.
        classes = np.unique(train_labels).tolist()
    else:
        def rows_in_classes(labels):
            # Boolean row mask: True where the label is one of `classes`.
            return np.any(np.stack([labels == c for c in classes], 1), 1)

        train_x = train_x[rows_in_classes(train_labels), :]
        valid_x = valid_x[rows_in_classes(valid_labels), :]
        test_x = test_x[rows_in_classes(test_labels), :]

    train_labels = [lbl for lbl in train_labels if lbl in classes]
    valid_labels = [lbl for lbl in valid_labels if lbl in classes]
    test_labels = [lbl for lbl in test_labels if lbl in classes]

    target = torch.device(device)
    train_x = torch.from_numpy(train_x).to(target)
    valid_x = torch.from_numpy(valid_x).to(target)
    test_x = torch.from_numpy(test_x).to(target)

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width),
                                              delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=1,
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              online_em_frequency=online_em_frequency,
                              online_em_stepsize=online_em_stepsize,
                              use_em=True)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)

    num_params = EinsumNetwork.eval_size(einet)

    # Dump the first slice of the last layer's parameters before training.
    data_dir = '../src/experiments/round5/data/weights_analysis/'
    # np.savetxt does not create missing directories.
    os.makedirs(data_dir, exist_ok=True)
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(os.path.join(data_dir, "weights_before.json"), weights[0])

    # Train (EM only; no gradient optimizer is involved)
    ######################################
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_loss = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            # The backward pass is run before em_process_batch().
            log_likelihood.backward()
            einet.em_process_batch()
            # Loss is the negative log-likelihood, summed over the epoch.
            total_loss -= log_likelihood.detach().item()
        einet.em_update()

        print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()

    # Dump the same parameter slice after training.
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(os.path.join(data_dir, "weights_after.json"), weights[0])

    ################
    # Experiment 1 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x,
                                                        batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        valid_x,
                                                        batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                       test_x,
                                                       batch_size=batch_size)
    print()
    print(
        "Experiment 1: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(target)
    valid_labels = torch.tensor(valid_labels).to(target)
    test_labels = torch.tensor(test_labels).to(target)

    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x,
                                                    train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    valid_x,
                                                    valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet,
                                                   classes,
                                                   test_x,
                                                   test_labels,
                                                   batch_size=batch_size)
    print()
    print(
        "Experiment 2: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
def run_experiment(settings):
    """
    Fit one Einsum Network per Iris class with online EM and evaluate the
    class-prior-weighted mixture of those networks on the training data.

    All hyper-parameters are hard-coded in the configuration section below;
    ``settings`` is accepted for interface compatibility but is not read.
    The function relies on the module-level name ``device``.

    :param settings: unused
    :return: dict with the keys 'train_ll' (training log-likelihood divided
             by the number of training samples), 'train_acc', 'network_size'
             and 'training_time' (seconds spent in the generative training
             loop)
    """
    ############################################################################
    exponential_family = EinsumNetwork.BinomialArray

    K = 1

    # structure = 'poon-domingos'
    structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [2]

    # 'binary-trees'
    depth = 1
    num_repetitions = 2

    width = 4

    num_epochs = 2
    batch_size = 3
    online_em_frequency = 1
    online_em_stepsize = 0.05
    SGD_learning_rate = 0.05

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 80}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    iris = datasets.load_iris()
    # Multiplying creates a fresh array, so the in-place ops below never
    # touch the array owned by the dataset object.
    train_x = iris.data * 10
    train_labels = iris.target

    # self generated data
    # train_x = np.array([np.array([1, 1, 1, 1]) for i in range(50)] + [np.array([3, 3, 3, 3]) for i in range(50)] + [np.array([8, 8, 8, 8]) for i in range(50)])
    # train_labels = np.array([0 for i in range(50)] + [1 for i in range(50)] + [2 for i in range(50)])

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        train_x -= .5

    classes = np.unique(train_labels).tolist()

    train_x = torch.from_numpy(train_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################

    einets = []
    ps = []  # number of training samples per class, normalised to priors below
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[width / d] for d in pd_num_pieces]
            # ``(width)`` is just an int; the shape has to be a 1-tuple.
            graph = Graph.poon_domingos_structure(shape=(width,),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                              depth=depth,
                                              num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        print(einet)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)

        einet = einet.float()
        einets.append(einet)

        ps.append(np.count_nonzero(train_labels == c))

    total_count = sum(ps)
    ps = [p / total_count for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    # Train
    ######################################
    # Generative training: every sub-network only sees the samples of its
    # own class.

    start_time = time.time()

    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[label == c for label in train_labels]]

        # Per-class sample count; deliberately NOT named ``train_N`` so the
        # whole-dataset normalisation further down cannot pick it up.
        class_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(class_N,
                                         device=device).split(batch_size)

            total_ll = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :].float()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                log_likelihood.backward()

                einet.em_process_batch()
                total_ll += log_likelihood.detach().item()
            einet.em_update()
            print(
                f'[{epoch_count}]   total log-likelihood: {total_ll/class_N}')

    end_time = time.time()

    def discriminative_learning():
        """
        SGD fine-tuning of the whole mixture with a cross-entropy loss.

        NOTE: kept for experimentation only -- it is never called by
        ``run_experiment``, so it does not influence the returned results.
        """
        sub_net_parameters = []
        for sub_net in mixture.einets:
            sub_net_parameters += list(sub_net.parameters())
        sub_net_parameters += list(mixture.parameters())

        optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)
        loss_function = torch.nn.CrossEntropyLoss()

        num_samples = train_x.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(num_samples,
                                         device=device).split(batch_size)

            total_loss = 0
            for idx in idx_batches:
                batch_x = train_x[idx, :].float()
                optimizer.zero_grad()
                outputs = mixture.forward(batch_x)
                target = torch.tensor([
                    classes.index(train_labels[i]) for i in idx
                ]).to(torch.device(device))
                loss = loss_function(outputs, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.detach().item()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    # The evaluations below run on the complete training set, so the
    # log-likelihood has to be normalised by the total number of samples
    # (previously the sample count of the last class leaked in here).
    train_N = train_x.shape[0]
    train_x_float = train_x.float()

    ################
    # Experiment 1 #
    ################
    mixture.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(mixture,
                                                        train_x_float,
                                                        batch_size=1)
    print()
    print("Experiment 1: Log-likelihoods  --- train LL {}".format(train_ll /
                                                                  train_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(mixture,
                                                    classes,
                                                    train_x_float,
                                                    train_labels,
                                                    batch_size=1)
    print()
    print("Experiment 2: Classification accuracies  --- train acc {}".format(
        acc_train))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'train_acc': acc_train,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
Пример #14
0
# Load the DEBD data set (requested with dtype='float32'). The loader's
# result is unpacked in the order (train, test, valid).
train_x_orig, test_x_orig, valid_x_orig = datasets.load_debd(dataset, dtype='float32')

# Turn every split into a torch tensor on the target device; the *_orig
# names keep referring to the arrays returned by the loader.
train_x, valid_x, test_x = (
    torch.from_numpy(split).to(torch.device(device))
    for split in (train_x_orig, valid_x_orig, test_x_orig)
)

# Number of rows of each split, plus the number of columns of the training data.
train_N, num_dims = train_x.shape
valid_N, test_N = valid_x.shape[0], test_x.shape[0]

# Region graph made of random binary trees over all columns.
graph = Graph.random_binary_trees(
    num_var=num_dims,
    depth=depth,
    num_repetitions=num_repetitions,
)

# Network configuration: a single output, categorical leaves with K=2,
# online EM with frequency 1 and step size 0.05.
args = EinsumNetwork.Args(
    num_var=num_dims,
    num_classes=1,
    num_sums=num_sums,
    num_input_distributions=num_input_distributions,
    exponential_family=EinsumNetwork.CategoricalArray,
    exponential_family_args={'K': 2},
    online_em_frequency=1,
    online_em_stepsize=0.05,
)

# Build the network, initialise its parameters and move it to the device.
einet = EinsumNetwork.EinsumNetwork(graph, args)
einet.initialize()
einet.to(device)
print(einet)
def train(einet, mean, train_x, valid_x, test_x, result_path):
    """
    Train ``einet`` with EM on mean-centred image data and keep the checkpoint
    that achieves the best validation log-likelihood.

    Per epoch the function
      * runs EM over shuffled mini-batches (each batch is reshaped to
        ``(batch, height * width, 3)``, has ``mean`` subtracted and is
        divided by 255),
      * evaluates train/valid/test log-likelihood via ``eval_ll`` and writes
        the running record to ``record.pkl``,
      * saves the model ('einet.mdl') and the module-level ``graph``
        ('einet.pc') whenever the validation log-likelihood improves,
      * every 10th epoch writes a 5x5 grid of samples to ``samples/``.

    After the last epoch the best checkpoint is reloaded, ``mean`` is added
    back onto its leaf parameters and the checkpoint is saved again.

    Relies on the module-level names ``num_epochs``, ``batch_size``,
    ``height``, ``width``, ``graph``, ``device`` and
    ``exponential_family_args`` (keys 'min_var' and 'max_var').

    :param einet: network to train
    :param mean: tensor subtracted from every batch; it is reshaped to
                 ``(width * height, 1, 1, 3)`` when re-added at the end
    :param train_x: training data, indexed as ``train_x[batch_idx, :]``
    :param valid_x: validation data, only passed on to ``eval_ll``
    :param test_x: test data, only passed on to ``eval_ll``
    :param result_path: directory receiving the model, graph, record and
                        sample files
    :return: None
    """
    model_file = os.path.join(result_path, 'einet.mdl')
    graph_file = os.path.join(result_path, 'einet.pc')
    record_file = os.path.join(result_path, 'record.pkl')
    sample_dir = os.path.join(result_path, 'samples')
    utils.mkdir_p(sample_dir)

    record = {
        'train_ll': [],
        'valid_ll': [],
        'test_ll': [],
        'best_validation_ll': None
    }

    for epoch_count in range(num_epochs):

        shuffled_batch = make_shuffled_batch(len(train_x), batch_size)
        for batch_idx in shuffled_batch:
            batch = torch.tensor(train_x[batch_idx, :]).to(device).float()
            batch = batch.reshape(batch.shape[0], height * width, 3)
            # we subtract the mean for this cluster -- centered data seems to help EM learning
            # we will re-add the mean to the Gaussian means below
            batch = batch - mean
            batch = batch / 255.

            ll_sample = einet.forward(batch)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()
            einet.em_process_batch()
        einet.em_update()

        ##### evaluate
        train_ll = eval_ll(einet, mean, train_x, batch_size=batch_size)
        valid_ll = eval_ll(einet, mean, valid_x, batch_size=batch_size)
        test_ll = eval_ll(einet, mean, test_x, batch_size=batch_size)

        ##### store results
        record['train_ll'].append(train_ll)
        record['valid_ll'].append(valid_ll)
        record['test_ll'].append(test_ll)

        # Use a context manager so the file handle is closed (and flushed)
        # every epoch instead of being left to the garbage collector.
        with open(record_file, 'wb') as record_fh:
            pickle.dump(record, record_fh)

        print("[{}]   train LL {}   valid LL {}   test LL {}".format(
            epoch_count, train_ll, valid_ll, test_ll))

        if record['best_validation_ll'] is None or valid_ll > record[
                'best_validation_ll']:
            record['best_validation_ll'] = valid_ll
            torch.save(einet, model_file)
            Graph.write_gpickle(graph, graph_file)

        if epoch_count % 10 == 0:
            # draw some samples
            samples = einet.sample(num_samples=25,
                                   std_correction=0.0).cpu().numpy()
            # undo the centering, then rescale to [0, 1] for display
            samples = samples + mean.detach().cpu().numpy() / 255.
            samples -= samples.min()
            samples /= samples.max()
            samples = samples.reshape(samples.shape[0], height, width, 3)
            # 5x5 tiles separated by 10-pixel gaps (4 gaps -> 40 extra pixels)
            img = np.zeros((height * 5 + 40, width * 5 + 40, 3))
            for h in range(5):
                for w in range(5):
                    img[h * (height + 10):h * (height + 10) + height,
                        w * (width + 10):w * (width + 10) +
                        width, :] = samples[h * 5 + w, :]
            img = Image.fromarray(np.round(img * 255.).astype(np.uint8))
            img.save(
                os.path.join(sample_dir, "samples{}.jpg".format(epoch_count)))

    # We subtract the mean for the current cluster from the data (centering it at 0).
    # Here we re-add the mean to the Gaussian means. A hacky solution at the moment...
    einet = torch.load(model_file)
    with torch.no_grad():
        params = einet.einet_layers[0].ef_array.params
        # params[..., 3:] minus the squared first three entries is treated as
        # a variance here: it is clamped to [min_var, max_var], the first
        # three entries are shifted by the mean, and the squared (shifted)
        # entries are added back.
        mu2 = params[..., 0:3]**2
        params[..., 3:] -= mu2
        params[..., 3:] = torch.clamp(params[..., 3:],
                                      exponential_family_args['min_var'],
                                      exponential_family_args['max_var'])
        params[..., 0:3] += mean.reshape((width * height, 1, 1, 3)) / 255.
        params[..., 3:] += params[..., 0:3]**2
    torch.save(einet, model_file)
for cluster_n in range(num_clusters):
    # Restrict every split to the samples assigned to the current cluster.
    train_x = train_x_all[cluster_idx == cluster_n, ...]
    valid_x = valid_x_all[valid_cluster_idx == cluster_n, ...]
    test_x = test_x_all[test_cluster_idx == cluster_n, ...]

    # Mean of this cluster, reshaped to (1, height * width, 3) and placed on
    # the target device as a tensor.
    mean = torch.tensor(
        cluster_means[cluster_n, ...].reshape(1, height * width, 3),
        device=device)

    # Output directory: <result_base_path>/num_clusters_<K>/cluster_<n>
    result_path = os.path.join(
        result_base_path,
        "num_clusters_{}".format(num_clusters),
        "cluster_{}".format(cluster_n))

    # Poon-Domingos region graph over the image grid, split along axis 1
    # with a delta of 8.
    graph = Graph.poon_domingos_structure(
        shape=(height, width), delta=[8], axes=[1])

    # Three dimensions per variable; the same count (num_sums) is used for
    # sums and for input distributions.
    args = EinsumNetwork.Args(
        num_var=height * width,
        num_dims=3,
        num_classes=1,
        num_sums=num_sums,
        num_input_distributions=num_sums,
        exponential_family=exponential_family,
        exponential_family_args=exponential_family_args,
        online_em_frequency=online_em_frequency,
        online_em_stepsize=online_em_stepsize,
    )

    print()
    print(result_path)
Пример #17
0
def run_experiment(settings):
    """
    Fit a single Einsum Network to the Iris data with online EM and report
    training log-likelihood, training accuracy, size and training time.

    All hyper-parameters are hard-coded in the configuration section below;
    ``settings`` is accepted for interface compatibility but is not read.
    The function relies on the module-level name ``device``.

    :param settings: unused
    :return: dict with the keys 'train_ll' (training log-likelihood divided
             by the number of training samples), 'train_acc', 'network_size'
             and 'training_time' (seconds spent in the training loop)
    """
    ############################################################################
    exponential_family = EinsumNetwork.BinomialArray

    K = 2

    # structure = 'poon-domingos'
    structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [2]

    # 'binary-trees'
    depth = 1
    num_repetitions = 2

    width = 4

    num_epochs = 2
    batch_size = 3
    online_em_frequency = 1
    online_em_stepsize = 0.05

    # Set to False to skip dumping the layer parameters at the end.
    print_weights = True

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 80}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    iris = datasets.load_iris()
    # Multiplying creates a fresh array, so the in-place ops below never
    # touch the array owned by the dataset object.
    train_x = iris.data * 10
    train_labels = iris.target

    # self generated data
    # train_x = np.array([np.array([1, 1, 1, 1]) for i in range(50)] + [np.array([3, 3, 3, 3]) for i in range(50)] + [np.array([8, 8, 8, 8]) for i in range(50)])
    # train_labels = [0 for i in range(50)] + [1 for i in range(50)] + [2 for i in range(50)]

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        train_x -= .5

    classes = np.unique(train_labels).tolist()

    train_x = torch.from_numpy(train_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[width / d] for d in pd_num_pieces]
        # ``(width)`` is just an int; the shape has to be a 1-tuple.
        graph = Graph.poon_domingos_structure(shape=(width,), delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(
        num_var=train_x.shape[1],
        num_dims=1,
        num_classes=1,
        num_sums=K,
        num_input_distributions=K,
        exponential_family=exponential_family,
        exponential_family_args=exponential_family_args,
        online_em_frequency=online_em_frequency,
        online_em_stepsize=online_em_stepsize)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)

    einet = einet.float()

    num_params = EinsumNetwork.eval_size(einet)

    # Train
    ######################################

    train_N = train_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_ll = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :].float()
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()

            einet.em_process_batch()
            total_ll += log_likelihood.detach().item()
        einet.em_update()
        print(f'[{epoch_count}]   total log-likelihood: {total_ll/train_N}')

    end_time = time.time()

    train_x_float = train_x.float()

    ################
    # Experiment 1 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(
        einet, train_x_float, batch_size=batch_size)
    print()
    print("Experiment 1: Log-likelihoods  --- train LL {}".format(
            train_ll / train_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(
        einet, classes, train_x_float, train_labels, batch_size=batch_size)
    print()
    print("Experiment 2: Classification accuracies  --- train acc {}".format(
            acc_train))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    if print_weights:
        # Leaf layers keep their parameters on ``ef_array``; every other
        # layer exposes them directly as ``params``.
        for layer in einet.einet_layers:
            print()
            if isinstance(layer, FactorizedLeafLayer.FactorizedLeafLayer):
                print(layer.ef_array.params)
            else:
                print(layer.params)

    return {
        'train_ll': train_ll / train_N,
        'train_acc': acc_train,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
    """
    A simple initializer for normalized sum-weights.
    :return: initial parameters
    """
    params = 0.01 + 0.98 * torch.rand(layer.params_shape)
    with torch.no_grad():
        if layer.params_mask is not None:
            params.data *= layer.params_mask
        params.data = params.data / (params.data.sum(layer.normalization_dims,
                                                     keepdim=True))
    return params


# Pick the region-graph layout: random binary trees over all variables, or a
# Poon-Domingos grid decomposition of the (height, width) image.
if structure == 'binary-trees':
    graph = Graph.random_binary_trees(
        num_var=train_x.shape[1],
        depth=depth,
        num_repetitions=num_repetitions,
    )
elif structure == 'poon-domingos':
    # One [height step, width step] pair per requested number of pieces.
    pd_delta = [[height / pieces, width / pieces] for pieces in pd_num_pieces]
    graph = Graph.poon_domingos_structure(
        shape=(height, width),
        delta=pd_delta,
    )
else:
    raise AssertionError("Unknown Structure")

# One output per class, K sums and K input distributions; EM is switched off.
args = EinsumNetwork.Args(
    num_var=train_x.shape[1],
    num_dims=1,
    num_classes=len(classes),
    num_sums=K,
    num_input_distributions=K,
    exponential_family=exponential_family,
    exponential_family_args=exponential_family_args,
    use_em=False,
)
Пример #19
0
def run_experiment(settings):
    """
    Train one Einsum Network per class by SGD on the negative log-likelihood
    and evaluate the class-prior-weighted mixture of these networks.

    The data set is Fashion-MNIST if ``settings.fashion_mnist`` is truthy,
    otherwise SVHN if ``settings.svhn`` is truthy, otherwise MNIST. The last
    10000 training samples are held out as validation split. The function
    relies on the module-level name ``device``.

    :param settings: object providing the attributes fashion_mnist, svhn,
                     exponential_family, classes (``None`` selects every class
                     found in the training labels), K, structure,
                     pd_num_pieces, depth, num_repetitions_mixture, width,
                     height, num_epochs, batch_size, online_em_frequency,
                     online_em_stepsize and SGD_learning_rate
    :return: dict with the keys 'train_ll', 'valid_ll', 'test_ll'
             (log-likelihood divided by the number of samples of the split),
             'train_acc', 'valid_acc', 'test_acc', 'network_size' and
             'training_time' (seconds spent in the training loop)
    :raises AssertionError: if ``settings.structure`` is neither
                            'poon-domingos' nor 'binary-trees'
    """
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate

    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist(
        )
    elif svhn:
        # the two trailing "extra" return values are not used in this experiment
        train_x, train_labels, test_x, test_labels, _, _ = datasets.load_svhn(
        )
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # valid_x = train_x[-10000:, :]
    # train_x = train_x[-12000:-11000, :]

    # valid_labels = train_labels[-10000:]
    # train_labels = train_labels[-12000:-11000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.isin(train_labels, classes), :]
        valid_x = valid_x[np.isin(valid_labels, classes), :]
        test_x = test_x[np.isin(test_labels, classes), :]
    else:
        classes = np.unique(train_labels).tolist()

    # From here on the labels are plain lists restricted to ``classes``
    # (``list.count`` is used for the class priors below).
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width),
                                                  delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(
                num_var=train_x.shape[1],
                depth=depth,
                num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)

        init_dict = get_init_dict(einet,
                                  train_x,
                                  train_labels=train_labels,
                                  einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    total_count = sum(ps)
    ps = [p / total_count for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()
    # Code for weight analysis, section 7.3
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    # weights = {}
    # for (einet, c) in zip(einets, classes):
    #     einet_weights = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             einet_weights[i] = l.reparam(l.params.data)
    #         else:
    #             einet_weights[i] = l.ef_array.reparam(l.ef_array.params.data)
    #     weights[c] = einet_weights

    ##################
    # Training phase #
    ##################

    # A single optimizer over the parameters of every sub-network plus the
    # parameters of the mixture itself.
    sub_net_parameters = []
    for sub_net in mixture.einets:
        sub_net_parameters += list(sub_net.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Learning each sub network generatively: every network only sees the
    # training samples of its own class.
    start_time = time.time()

    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        class_N = train_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(class_N,
                                         device=device).split(batch_size)

            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                # SGD minimises, so descend on the negative log-likelihood
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()

            print(f'[{epoch_count}]   total loss: {total_loss}')

    end_time = time.time()
    # Code for weight analysis
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # percentual_change = []
    # percentual_change_formatted = []
    # data_file = os.path.join(data_dir, f"percentual_change.json")
    # for (einet, c) in zip(einets, classes):
    #     einet_change = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             weights_new = l.reparam(l.params.data)
    #         else:
    #             weights_new = l.ef_array.reparam(l.ef_array.params.data)
    #         change = torch.mean(torch.abs(weights[c][i] - weights_new)/weights[c][i]).item()
    #         einet_change[i] = change
    #         if i == 0:
    #             percentual_change_formatted.append(change)

    #     percentual_change.append(einet_change)

    # print(percentual_change)
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change, f)
    # data_file = os.path.join(data_dir, f"percentual_change_formatted.json")
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change_formatted, f)

    ################
    # Experiment 3 #
    ################
    # The evaluations run on the complete splits, so normalise by their sizes.
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]
    mixture.eval()
    train_ll = mixture.eval_loglikelihood_batched(train_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    valid_ll = mixture.eval_loglikelihood_batched(valid_x,
                                                  batch_size=batch_size,
                                                  skip_reparam=True)
    test_ll = mixture.eval_loglikelihood_batched(test_x,
                                                 batch_size=batch_size,
                                                 skip_reparam=True)
    print()
    print(
        "Experiment 3: Log-likelihoods  --- train LL {}   valid LL {}   test LL {}"
        .format(train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 4 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = mixture.eval_accuracy_batched(classes,
                                              train_x,
                                              train_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_valid = mixture.eval_accuracy_batched(classes,
                                              valid_x,
                                              valid_labels,
                                              batch_size=batch_size,
                                              skip_reparam=True)
    acc_test = mixture.eval_accuracy_batched(classes,
                                             test_x,
                                             test_labels,
                                             batch_size=batch_size,
                                             skip_reparam=True)
    print()
    print(
        "Experiment 4: Classification accuracies  --- train acc {}   valid acc {}   test acc {}"
        .format(acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }