def save_einet(einet, graph, out_path):
    # save model
    os.makedirs(out_path, exist_ok=True)
    graph_file = os.path.join(out_path, "einet.rg")
    Graph.write_gpickle(graph, graph_file)
    print("Saved PC graph to {}".format(graph_file))
    model_file = os.path.join(out_path, "einet.pth")
    torch.save(einet, model_file)
    print("Saved model to {}".format(model_file))
def load_einet_state(model_path, einet_file='einet.pth', graph_file='einet.rg',
                     n_vars=None, n_classes=None, n_sums=None, n_input_dists=None,
                     exp_fam=None, exp_fam_args=None, use_em=None, em_freq=None,
                     em_stepsize=None, graph=None):
    # reload model: rebuild the einet via make_einet and restore its state dict;
    # the region graph is either passed in or read from disk
    if graph is None:
        if graph_file:
            graph_file = os.path.join(model_path, graph_file)
            graph = Graph.read_gpickle(graph_file)
        else:
            raise ValueError("Cannot create graph: no graph object or graph file given")
    model_file = os.path.join(model_path, einet_file)
    einet = make_einet(graph, n_vars=n_vars, n_classes=n_classes, n_sums=n_sums,
                       n_input_dists=n_input_dists, exp_fam=exp_fam,
                       exp_fam_args=exp_fam_args, use_em=use_em,
                       em_freq=em_freq, em_stepsize=em_stepsize)
    einet.load_state_dict(torch.load(model_file))
    print("Loaded model from {}".format(model_file))
    return einet, graph
def make_region_graph(structure="poon-domingos", height=None, width=None,
                      pd_pieces=None, depth=None, n_repetitions=None):
    # build a region graph: Poon-Domingos structure (for images) or random binary trees
    if structure == 'poon-domingos':
        assert pd_pieces is not None
        pd_delta = [[height / d, width / d] for d in pd_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
    elif structure == 'binary-trees':
        n_vars = height * width
        graph = Graph.random_binary_trees(num_var=n_vars, depth=depth,
                                          num_repetitions=n_repetitions)
    else:
        raise AssertionError("Unknown Structure")
    return graph
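# Minimal usage sketch for make_region_graph, assuming the Graph module from the
# EinsumNetworks codebase is importable. The concrete values (28x28 inputs,
# pd_pieces, depth, repetitions) are illustrative, not taken from the runs below.
pd_graph = make_region_graph(structure="poon-domingos", height=28, width=28,
                             pd_pieces=[4, 7])
bt_graph = make_region_graph(structure="binary-trees", height=28, width=28,
                             depth=3, n_repetitions=20)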
def load_einet(model_path, einet_file='einet.pth', graph_file='einet.rg'):
    # reload a fully pickled model (the counterpart to save_einet)
    graph = None
    model_file = os.path.join(model_path, einet_file)
    einet = torch.load(model_file)
    print("Loaded model from {}".format(model_file))
    if graph_file:
        graph_file = os.path.join(model_path, graph_file)
        graph = Graph.read_gpickle(graph_file)
    return einet, graph
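# Round-trip sketch for the helpers above. save_einet pickles the whole module,
# so load_einet needs no architecture arguments; load_einet_state instead rebuilds
# the network (via make_einet) and only restores the state dict. The path and the
# hyperparameter values below are placeholders.
save_einet(einet, graph, "checkpoints/run0")
einet2, graph2 = load_einet("checkpoints/run0")
# or, rebuilding from the state dict:
einet3, _ = load_einet_state("checkpoints/run0", n_vars=784, n_classes=1,
                             n_sums=10, n_input_dists=10,
                             exp_fam=EinsumNetwork.CategoricalArray,
                             exp_fam_args={'K': 256}, use_em=True,
                             em_freq=1, em_stepsize=0.05, graph=graph2)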
def __init__(self, graph, args=None):
    """Make an EinsumNetwork."""
    super(EinsumNetwork, self).__init__()

    check_flag, check_msg = Graph.check_graph(graph)
    if not check_flag:
        raise AssertionError(check_msg)
    self.graph = graph

    self.args = args if args is not None else Args()

    if len(Graph.get_roots(self.graph)) != 1:
        raise AssertionError("Currently only EinNets with single root node supported.")

    root = Graph.get_roots(self.graph)[0]
    if tuple(range(self.args.num_var)) != root.scope:
        raise AssertionError("The graph should be over tuple(range(num_var)).")

    for node in Graph.get_leaves(self.graph):
        node.num_dist = self.args.num_input_distributions

    for node in Graph.get_sums(self.graph):
        if node is root:
            node.num_dist = self.args.num_classes
        else:
            node.num_dist = self.args.num_sums

    # Algorithm 1 in the paper -- organize the PC in layers
    self.graph_layers = Graph.topological_layers(self.graph)

    # input layer
    einet_layers = [FactorizedLeafLayer(self.graph_layers[0],
                                        self.args.num_var,
                                        self.args.num_dims,
                                        self.args.exponential_family,
                                        self.args.exponential_family_args,
                                        use_em=self.args.use_em)]

    # internal layers
    for c, layer in enumerate(self.graph_layers[1:]):
        if c % 2 == 0:  # product layer
            einet_layers.append(EinsumLayer(self.graph, layer, einet_layers,
                                            use_em=self.args.use_em))
        else:  # sum layer
            # the Mixing layer is only for regions which have multiple partitions as children.
            multi_sums = [n for n in layer if len(graph.succ[n]) > 1]
            if multi_sums:
                einet_layers.append(EinsumMixingLayer(graph, multi_sums,
                                                      einet_layers[-1],
                                                      use_em=self.args.use_em))

    self.einet_layers = torch.nn.ModuleList(einet_layers)
    self.em_set_hyperparams(self.args.online_em_frequency, self.args.online_em_stepsize)
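# Construction sketch for the constructor above; the values are illustrative.
# Note the constraints it enforces: the root scope must be exactly
# tuple(range(num_var)), and the single root's num_dist becomes num_classes.
graph = Graph.random_binary_trees(num_var=784, depth=3, num_repetitions=10)
args = EinsumNetwork.Args(num_var=784, num_dims=1, num_classes=10,
                          num_sums=20, num_input_distributions=20,
                          exponential_family=EinsumNetwork.CategoricalArray,
                          exponential_family_args={'K': 256},
                          use_em=False)
einet = EinsumNetwork.EinsumNetwork(graph, args)
einet.initialize()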
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces
    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        # scale to [-0.5, 0.5] for Gaussian leaves
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################

    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=len(classes),
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              use_em=False)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)
    print(einet)

    num_params = EinsumNetwork.eval_size(einet)

    #################################
    # Discriminative training phase #
    #################################

    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()
    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)
        total_loss = 0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            optimizer.zero_grad()
            outputs = einet.forward(batch_x)
            target = torch.tensor([classes.index(train_labels[i]) for i in idx]).to(torch.device(device))
            loss = loss_function(outputs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.detach().item()
        print(f'[{epoch_count}] total loss: {total_loss}')
    end_time = time.time()

    ################
    # Experiment 5 #
    ################

    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet, train_x, batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet, valid_x, batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet, test_x, batch_size=batch_size)
    print()
    print("Experiment 5: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 6 #
    ################

    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet, classes, train_x, train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet, classes, valid_x, valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet, classes, test_x, test_labels,
                                                   batch_size=batch_size)
    print()
    print("Experiment 6: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
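# Invocation sketch for run_experiment. A SimpleNamespace stands in here for
# whatever settings object the experiment runner actually constructs; every
# value below is illustrative.
from types import SimpleNamespace

settings = SimpleNamespace(
    fashion_mnist=False, svhn=False,
    exponential_family=EinsumNetwork.CategoricalArray,
    classes=[0, 1, 2], K=10,
    structure='binary-trees', pd_num_pieces=None,
    depth=3, num_repetitions=10, width=28, height=28,
    num_epochs=5, batch_size=100, SGD_learning_rate=0.1)
results = run_experiment(settings)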
def construct_network_and_train(num_starts, svhn, exponential_family, classes, K,
                                structure, pd_num_pieces, depth, num_repetitions,
                                width, height, num_epochs, batch_size,
                                online_em_frequency, online_em_stepsize,
                                exponential_family_args, train_x, train_labels,
                                test_x, test_labels, valid_x, valid_labels):
    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    training_times = []
    num_params = None

    for s in range(num_starts):
        print(f"""
Running start: {s}/{num_starts}
""")

        ######################################
        # Make EinsumNetworks for each class #
        ######################################

        einets = []
        ps = []
        for c in classes:
            if structure == 'poon-domingos':
                pd_delta = [[height / d, width / d] for d in pd_num_pieces]
                graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            elif structure == 'binary-trees':
                graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                                  num_repetitions=num_repetitions)
            else:
                raise AssertionError("Unknown Structure")

            args = EinsumNetwork.Args(
                num_var=train_x.shape[1],
                num_dims=3 if svhn else 1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                online_em_frequency=online_em_frequency,
                online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)
            init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
            einet.initialize(init_dict)
            einet.to(device)
            einets.append(einet)

            # Calculate amount of training samples per class
            ps.append(train_labels.count(c))

        # normalize ps, construct mixture component
        ps = [p / sum(ps) for p in ps]
        ps = torch.tensor(ps).to(torch.device(device))
        mixture = EinetMixture(ps, einets, classes=classes)

        if num_params is None:
            num_params = mixture.eval_size()

        ##################
        # Training phase #
        ##################

        """ Learning each sub Network Generatively """
        start_time = time.time()
        for (einet, c) in zip(einets, classes):
            train_x_c = train_x[[l == c for l in train_labels]]
            valid_x_c = valid_x[[l == c for l in valid_labels]]
            test_x_c = test_x[[l == c for l in test_labels]]
            train_N = train_x_c.shape[0]
            valid_N = valid_x_c.shape[0]
            test_N = test_x_c.shape[0]

            for epoch_count in range(num_epochs):
                idx_batches = torch.randperm(train_N, device=device).split(batch_size)
                total_ll = 0.0
                for idx in idx_batches:
                    batch_x = train_x_c[idx, :]
                    outputs = einet.forward(batch_x)
                    ll_sample = EinsumNetwork.log_likelihoods(outputs)
                    log_likelihood = ll_sample.sum()
                    log_likelihood.backward()
                    einet.em_process_batch()
                    total_ll += log_likelihood.detach().item()
                einet.em_update()
                print(f'[{epoch_count}] total log-likelihood: {total_ll}')
        end_time = time.time()

        ################
        # Experiment 3 #
        ################

        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]

        mixture.eval()
        train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        print()
        print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
            train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

        ################
        # Experiment 4 #
        ################

        train_labels_tensor = torch.tensor(train_labels).to(torch.device(device))
        valid_labels_tensor = torch.tensor(valid_labels).to(torch.device(device))
        test_labels_tensor = torch.tensor(test_labels).to(torch.device(device))

        acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labels_tensor,
                                                  batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labels_tensor,
                                                  batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labels_tensor,
                                                 batch_size=batch_size)
        print()
        print("Experiment 4: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
            acc_train, acc_valid, acc_test))

        print()
        print(f'Network size: {num_params} parameters')
        print(f'Training time: {end_time - start_time}s')

        train_lls.append(train_ll / train_N)
        valid_lls.append(valid_ll / valid_N)
        test_lls.append(test_ll / test_N)
        train_accs.append(acc_train)
        valid_accs.append(acc_valid)
        test_accs.append(acc_test)
        training_times.append(end_time - start_time)

    return {
        'train_ll': train_lls,
        'valid_ll': valid_lls,
        'test_ll': test_lls,
        'train_acc': train_accs,
        'valid_acc': valid_accs,
        'test_acc': test_accs,
        'network_size': num_params,
        'training_time': training_times,
    }
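# Multi-restart invocation sketch for construct_network_and_train. Every value
# below is illustrative, and train_x/train_labels etc. are assumed to be the
# preprocessed tensors and label lists produced by a data-loading step like the
# ones in the run_experiment functions.
results = construct_network_and_train(
    num_starts=5, svhn=False,
    exponential_family=EinsumNetwork.CategoricalArray,
    classes=[0, 1], K=10, structure='binary-trees', pd_num_pieces=None,
    depth=3, num_repetitions=10, width=28, height=28,
    num_epochs=5, batch_size=100,
    online_em_frequency=1, online_em_stepsize=0.05,
    exponential_family_args={'K': 256},
    train_x=train_x, train_labels=train_labels,
    test_x=test_x, test_labels=test_labels,
    valid_x=valid_x, valid_labels=valid_labels)
print(np.mean(results['test_ll']), np.std(results['test_ll']))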
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces
    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################

    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                              num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)
    num_params = mixture.eval_size()

    #################################
    # Evaluate after initialization #
    #################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    mixture.eval()
    train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size,
                                                         skip_reparam=True)
    valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size,
                                                         skip_reparam=True)
    test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size,
                                                        skip_reparam=True)
    print()
    print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll_before / train_N, valid_ll_before / valid_N, test_ll_before / test_N))
    train_lls.append(train_ll_before / train_N)
    valid_lls.append(valid_ll_before / valid_N)
    test_lls.append(test_ll_before / test_N)

    ################
    # Experiment 4 #
    ################

    train_labelsz = torch.tensor(train_labels).to(torch.device(device))
    valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
    test_labelsz = torch.tensor(test_labels).to(torch.device(device))

    acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz,
                                                     batch_size=batch_size, skip_reparam=True)
    acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz,
                                                     batch_size=batch_size, skip_reparam=True)
    acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz,
                                                    batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 4: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train_before, acc_valid_before, acc_test_before))
    train_accs.append(acc_train_before)
    valid_accs.append(acc_valid_before)
    test_accs.append(acc_test_before)

    mixture.train()

    ##################
    # Training phase #
    ##################

    """ Learning each sub Network Generatively """
    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()
    for epoch_count in range(num_epochs):
        for (einet, c) in zip(einets, classes):
            train_x_c = train_x[[l == c for l in train_labels]]
            valid_x_c = valid_x[[l == c for l in valid_labels]]
            test_x_c = test_x[[l == c for l in test_labels]]
            train_N = train_x_c.shape[0]
            valid_N = valid_x_c.shape[0]
            test_N = test_x_c.shape[0]

            idx_batches = torch.randperm(train_N, device=device).split(batch_size)
            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}] total loss: {total_loss}')

        # per-epoch evaluation on the full splits
        mixture.eval()
        train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]
        train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        train_lls.append(train_ll / train_N)
        valid_lls.append(valid_ll / valid_N)
        test_lls.append(test_ll / test_N)

        train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))
        acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labelsz,
                                                  batch_size=batch_size)
        acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz,
                                                  batch_size=batch_size)
        acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labelsz,
                                                 batch_size=batch_size)
        train_accs.append(acc_train)
        valid_accs.append(acc_valid)
        test_accs.append(acc_test)
        mixture.train()
    end_time = time.time()  # total training time, including per-epoch evaluations

    print()
    print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll / train_N, valid_ll / valid_N, test_ll / test_N))
    print()
    print("Experiment 4: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train, acc_valid, acc_test))

    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
def new_start(start_train_set, online_offset):
    ############################################################################
    # note: this function reads the experiment configuration from the
    # module-level `settings` object
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces
    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions_mixture
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    # online_x = train_x[-40000:, :]
    train_x = train_x[:-(10000 + online_offset - start_train_set), :]
    valid_labels = train_labels[-10000:]
    # online_labels = train_labels[-40000:]
    train_labels = train_labels[:-(10000 + online_offset - start_train_set)]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        # online_x = online_x[np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
        train_labels = [l for l in train_labels if l in classes]
        # online_labels = [l for l in online_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()
        train_labels = [l for l in train_labels if l in classes]
        # online_labels = [l for l in online_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    # online_x = torch.from_numpy(online_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################

    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                              num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)
    num_params = mixture.eval_size()

    #################################
    # Evaluate after initialization #
    #################################

    # metric traces, one entry per evaluation
    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    mixture.eval()
    train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
    valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
    test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
    print()
    print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll_before / train_N, valid_ll_before / valid_N, test_ll_before / test_N))
    train_lls.append(train_ll_before / train_N)
    valid_lls.append(valid_ll_before / valid_N)
    test_lls.append(test_ll_before / test_N)

    ################
    # Experiment 4 #
    ################

    train_labelsz = torch.tensor(train_labels).to(torch.device(device))
    valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
    test_labelsz = torch.tensor(test_labels).to(torch.device(device))

    acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz,
                                                     batch_size=batch_size)
    acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz,
                                                     batch_size=batch_size)
    acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz,
                                                    batch_size=batch_size)
    print()
    print("Experiment 8: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train_before, acc_valid_before, acc_test_before))
    train_accs.append(acc_train_before)
    valid_accs.append(acc_valid_before)
    test_accs.append(acc_test_before)
train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]

train_x = torch.from_numpy(train_x).to(torch.device(device))
valid_x = torch.from_numpy(valid_x).to(torch.device(device))
test_x = torch.from_numpy(test_x).to(torch.device(device))

# Make EinsumNetwork
######################################
if structure == 'poon-domingos':
    pd_delta = [[height / d, width / d] for d in pd_num_pieces]
    graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
elif structure == 'binary-trees':
    # graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions)
    graph = Graph.random_binary_trees(num_var=num_var, depth=depth,
                                      num_repetitions=num_repetitions)
else:
    raise AssertionError("Unknown Structure")

args = EinsumNetwork.Args(
    num_var=num_var,  # train_x.shape[1]
    num_dims=2 if use_pair else 1,
    num_classes=1,
    num_sums=K,
    num_input_distributions=K,
    exponential_family=exponential_family,
    exponential_family_args=exponential_family_args)
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces
    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    SGD_learning_rate = settings.SGD_learning_rate
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    online_x = train_x[-11000:-10000, :]
    train_x = train_x[-56000:-11000, :]
    # init_x = train_x[-13000:-10000, :]
    valid_labels = train_labels[-10000:]
    online_labels = train_labels[-11000:-10000]
    train_labels = train_labels[-56000:-11000]
    # init_labels = train_labels[-13000:-10000]

    # full set of training
    # valid_x = train_x[-10000:, :]
    # online_x = train_x[-11000:-10000, :]
    # train_x = train_x[:-10000, :]
    # valid_labels = train_labels[-10000:]
    # online_labels = train_labels[-11000:-10000]
    # train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        online_x = online_x[np.any(np.stack([online_labels == c for c in classes], 1), 1), :]
        # init_x = init_x[np.any(np.stack([init_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
        train_labels = [l for l in train_labels if l in classes]
        # copy, so that the online additions below do not also grow train_labels
        train_labels_backup = list(train_labels)
        online_labels = [l for l in online_labels if l in classes]
        # init_labels = [l for l in init_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()
        train_labels = [l for l in train_labels if l in classes]
        train_labels_backup = list(train_labels)
        online_labels = [l for l in online_labels if l in classes]
        # init_labels = [l for l in init_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    train_x_backup = train_x
    online_x = torch.from_numpy(online_x).to(torch.device(device))
    # init_x = torch.from_numpy(init_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################

    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                              num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        # init_dict = get_init_dict(einet, init_x, train_labels=init_labels, einet_class=c)
        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # Calculate amount of training samples per class
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)
    num_params = mixture.eval_size()

    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # for (einet, c) in zip(einets, classes):
    #     data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    #     weights = einet.einet_layers[-1].params.data.cpu()
    #     np.savetxt(data_file, einet.einet_layers[-1].reparam(weights)[0])

    ##################
    # Training phase #
    ##################

    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    start_time = time.time()

    """ Learning each sub Network Generatively """
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        train_N = train_x_c.shape[0]
        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N, device=device).split(batch_size)
            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                nll = log_likelihood * -1
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}] total loss: {total_loss}')

    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # for (einet, c) in zip(einets, classes):
    #     data_file = os.path.join(data_dir, f"weights_after_{c}.json")
    #     weights = einet.einet_layers[-1].params.data.cpu()
    #     np.savetxt(data_file, einet.einet_layers[-1].reparam(weights)[0])

    #################################
    # Evaluate after initialization #
    #################################

    train_lls = []
    valid_lls = []
    test_lls = []
    train_accs = []
    valid_accs = []
    test_accs = []
    train_lls_ref = []
    valid_lls_ref = []
    test_lls_ref = []
    train_accs_ref = []
    valid_accs_ref = []
    test_accs_ref = []
    added_samples = [0]

    def eval_network(do_print=False, no_OA=False):
        # no_OA evaluates against the reference training set (backup), i.e.
        # the data grown without online adaptation
        if no_OA:
            train_N = train_x_backup.shape[0]
        else:
            train_N = train_x.shape[0]
        valid_N = valid_x.shape[0]
        test_N = test_x.shape[0]

        mixture.eval()
        if no_OA:
            train_ll_before = mixture.eval_loglikelihood_batched(train_x_backup, batch_size=batch_size)
        else:
            train_ll_before = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size)
        valid_ll_before = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size)
        test_ll_before = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size)
        if do_print:
            print()
            print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
                train_ll_before / train_N, valid_ll_before / valid_N, test_ll_before / test_N))
        if no_OA:
            train_lls_ref.append(train_ll_before / train_N)
            valid_lls_ref.append(valid_ll_before / valid_N)
            test_lls_ref.append(test_ll_before / test_N)
        else:
            train_lls.append(train_ll_before / train_N)
            valid_lls.append(valid_ll_before / valid_N)
            test_lls.append(test_ll_before / test_N)

        ################
        # Experiment 4 #
        ################

        if no_OA:
            train_labelsz = torch.tensor(train_labels_backup).to(torch.device(device))
        else:
            train_labelsz = torch.tensor(train_labels).to(torch.device(device))
        valid_labelsz = torch.tensor(valid_labels).to(torch.device(device))
        test_labelsz = torch.tensor(test_labels).to(torch.device(device))

        if no_OA:
            acc_train_before = mixture.eval_accuracy_batched(classes, train_x_backup, train_labelsz,
                                                             batch_size=batch_size)
        else:
            acc_train_before = mixture.eval_accuracy_batched(classes, train_x, train_labelsz,
                                                             batch_size=batch_size)
        acc_valid_before = mixture.eval_accuracy_batched(classes, valid_x, valid_labelsz,
                                                         batch_size=batch_size)
        acc_test_before = mixture.eval_accuracy_batched(classes, test_x, test_labelsz,
                                                        batch_size=batch_size)
        if do_print:
            print()
            print("Experiment 4: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
                acc_train_before, acc_valid_before, acc_test_before))
        if no_OA:
            train_accs_ref.append(acc_train_before)
            valid_accs_ref.append(acc_valid_before)
            test_accs_ref.append(acc_test_before)
        else:
            train_accs.append(acc_train_before)
            valid_accs.append(acc_valid_before)
            test_accs.append(acc_test_before)
        mixture.train()

    eval_network(do_print=True, no_OA=False)
    eval_network(do_print=False, no_OA=True)

    #####################################################
    # Evaluate the network with different training sets #
    #####################################################

    idx_batches = torch.randperm(online_x.shape[0], device=device).split(20)
    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]
        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            train_x_backup = torch.cat((train_x_backup, batch_x))
            train_labels_backup += [c for i in batch_x]
        added_samples.append(added_samples[-1] + len(idx))
        eval_network(do_print=False, no_OA=True)

    #####################
    # Online adaptation #
    #####################

    for idx in tqdm(idx_batches):
        online_x_idx = online_x[idx]
        online_labels_idx = [online_labels[i] for i in idx]
        for (einet, c) in zip(einets, classes):
            batch_x = online_x_idx[[l == c for l in online_labels_idx]]
            online_update(einet, batch_x)
            train_x = torch.cat((train_x, batch_x))
            train_labels += [c for i in batch_x]
        eval_network(do_print=False, no_OA=False)

    print()
    print(f'Network size: {num_params} parameters')

    return {
        'train_lls': train_lls,
        'valid_lls': valid_lls,
        'test_lls': test_lls,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'train_lls_ref': train_lls_ref,
        'valid_lls_ref': valid_lls_ref,
        'test_lls_ref': test_lls_ref,
        'train_accs_ref': train_accs_ref,
        'valid_accs_ref': valid_accs_ref,
        'test_accs_ref': test_accs_ref,
        'network_size': num_params,
        'online_samples': added_samples,
    }
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn
    exponential_family = settings.exponential_family
    classes = settings.classes
    K = settings.K
    structure = settings.structure
    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces
    # 'binary-trees'
    depth = settings.depth
    num_repetitions = settings.num_repetitions
    width = settings.width
    height = settings.height
    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]
    else:
        classes = np.unique(train_labels).tolist()
        train_labels = [l for l in train_labels if l in classes]
        valid_labels = [l for l in valid_labels if l in classes]
        test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(num_var=train_x.shape[1],
                              num_dims=3 if svhn else 1,
                              num_classes=1,
                              num_sums=K,
                              num_input_distributions=K,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              online_em_frequency=online_em_frequency,
                              online_em_stepsize=online_em_stepsize,
                              use_em=True)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)

    num_params = EinsumNetwork.eval_size(einet)

    data_dir = '../src/experiments/round5/data/weights_analysis/'
    data_file = os.path.join(data_dir, "weights_before.json")
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(data_file, weights[0])

    # Train
    ######################################
    optimizer = torch.optim.SGD(einet.parameters(), lr=SGD_learning_rate)

    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    start_time = time.time()
    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)
        total_ll = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :]
            # optimizer.zero_grad()
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()
            # nll = log_likelihood * -1
            # nll.backward()
            # optimizer.step()
            einet.em_process_batch()
            total_ll += log_likelihood.detach().item()
        einet.em_update()
        print(f'[{epoch_count}] total log-likelihood: {total_ll}')
    end_time = time.time()

    data_dir = '../src/experiments/round5/data/weights_analysis/'
    data_file = os.path.join(data_dir, "weights_after.json")
    weights = einet.einet_layers[-1].params.data.cpu()
    np.savetxt(data_file, weights[0])

    ################
    # Experiment 1 #
    ################

    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet, train_x, batch_size=batch_size)
    valid_ll = EinsumNetwork.eval_loglikelihood_batched(einet, valid_x, batch_size=batch_size)
    test_ll = EinsumNetwork.eval_loglikelihood_batched(einet, test_x, batch_size=batch_size)
    print()
    print("Experiment 1: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 2 #
    ################

    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = EinsumNetwork.eval_accuracy_batched(einet, classes, train_x, train_labels,
                                                    batch_size=batch_size)
    acc_valid = EinsumNetwork.eval_accuracy_batched(einet, classes, valid_x, valid_labels,
                                                    batch_size=batch_size)
    acc_test = EinsumNetwork.eval_accuracy_batched(einet, classes, test_x, test_labels,
                                                   batch_size=batch_size)
    print()
    print("Experiment 2: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
def run_experiment(settings):
    ############################################################################
    exponential_family = EinsumNetwork.BinomialArray

    K = 1

    # structure = 'poon-domingos'
    structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [2]

    # 'binary-trees'
    depth = 1
    num_repetitions = 2
    width = 4

    num_epochs = 2
    batch_size = 3
    online_em_frequency = 1
    online_em_stepsize = 0.05
    SGD_learning_rate = 0.05
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 80}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    iris = datasets.load_iris()
    train_x = iris.data * 10
    train_labels = iris.target

    # self generated data
    # train_x = np.array([np.array([1, 1, 1, 1]) for i in range(50)]
    #                    + [np.array([3, 3, 3, 3]) for i in range(50)]
    #                    + [np.array([8, 8, 8, 8]) for i in range(50)])
    # train_labels = np.array([0 for i in range(50)] + [1 for i in range(50)] + [2 for i in range(50)])

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        train_x -= .5

    classes = np.unique(train_labels).tolist()

    train_x = torch.from_numpy(train_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(width,), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                              num_repetitions=num_repetitions)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        print(einet)

        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einet = einet.float()
        einets.append(einet)

        ps.append(np.count_nonzero(train_labels == c))

    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)
    num_params = mixture.eval_size()

    # Train
    ######################################

    """ Generative training """
    start_time = time.time()
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        train_N = train_x_c.shape[0]
        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N, device=device).split(batch_size)
            total_ll = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :].float()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                log_likelihood.backward()
                einet.em_process_batch()
                total_ll += log_likelihood.detach().item()
            einet.em_update()
            print(f'[{epoch_count}] total log-likelihood: {total_ll / train_N}')
    end_time = time.time()

    """ Discriminative training """
    def discriminative_learning():
        sub_net_parameters = None
        for einet in mixture.einets:
            if sub_net_parameters is None:
                sub_net_parameters = list(einet.parameters())
            else:
                sub_net_parameters += list(einet.parameters())
        sub_net_parameters += list(mixture.parameters())

        optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)
        loss_function = torch.nn.CrossEntropyLoss()

        train_N = train_x.shape[0]
        start_time = time.time()
        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N, device=device).split(batch_size)
            total_loss = 0
            for idx in idx_batches:
                batch_x = train_x[idx, :].float()
                optimizer.zero_grad()
                outputs = mixture.forward(batch_x)
                target = torch.tensor([classes.index(train_labels[i]) for i in idx]).to(torch.device(device))
                loss = loss_function(outputs, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.detach().item()
            print(f'[{epoch_count}] total loss: {total_loss}')
        end_time = time.time()

    ################
    # Experiment 1 #
    ################

    train_N = train_x.shape[0]
    mixture.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(mixture, train_x.float(), batch_size=1)
    print()
    print("Experiment 1: Log-likelihoods --- train LL {}".format(train_ll / train_N))

    ################
    # Experiment 2 #
    ################

    train_labels = torch.tensor(train_labels).to(torch.device(device))
    acc_train = EinsumNetwork.eval_accuracy_batched(mixture, classes, train_x.float(),
                                                    train_labels, batch_size=1)
    print()
    print("Experiment 2: Classification accuracies --- train acc {}".format(acc_train))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'train_acc': acc_train,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
train_x_orig, test_x_orig, valid_x_orig = datasets.load_debd(dataset, dtype='float32')
train_x = train_x_orig
test_x = test_x_orig
valid_x = valid_x_orig

# to torch
train_x = torch.from_numpy(train_x).to(torch.device(device))
valid_x = torch.from_numpy(valid_x).to(torch.device(device))
test_x = torch.from_numpy(test_x).to(torch.device(device))

train_N, num_dims = train_x.shape
valid_N = valid_x.shape[0]
test_N = test_x.shape[0]

graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth,
                                  num_repetitions=num_repetitions)

args = EinsumNetwork.Args(
    num_classes=1,
    num_input_distributions=num_input_distributions,
    exponential_family=EinsumNetwork.CategoricalArray,
    exponential_family_args={'K': 2},
    num_sums=num_sums,
    num_var=train_x.shape[1],
    online_em_frequency=1,
    online_em_stepsize=0.05)

einet = EinsumNetwork.EinsumNetwork(graph, args)
einet.initialize()
einet.to(device)
print(einet)
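# The snippet above only builds the einet for the binary DEBD data; a minimal
# EM training loop in the style of the other experiments in this file might
# look as follows (num_epochs and batch_size are assumed to be defined
# elsewhere in the script).
for epoch_count in range(num_epochs):
    idx_batches = torch.randperm(train_N, device=device).split(batch_size)
    total_ll = 0.0
    for idx in idx_batches:
        outputs = einet.forward(train_x[idx, :])
        log_likelihood = EinsumNetwork.log_likelihoods(outputs).sum()
        log_likelihood.backward()
        einet.em_process_batch()
        total_ll += log_likelihood.detach().item()
    einet.em_update()
    print(f'[{epoch_count}] train LL {total_ll / train_N}')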
def train(einet, mean, train_x, valid_x, test_x, result_path):
    # note: graph, num_epochs, batch_size, height, width and
    # exponential_family_args are taken from the enclosing script scope
    model_file = os.path.join(result_path, 'einet.mdl')
    graph_file = os.path.join(result_path, 'einet.pc')
    record_file = os.path.join(result_path, 'record.pkl')
    sample_dir = os.path.join(result_path, 'samples')
    utils.mkdir_p(sample_dir)

    record = {
        'train_ll': [],
        'valid_ll': [],
        'test_ll': [],
        'best_validation_ll': None
    }

    for epoch_count in range(num_epochs):
        shuffled_batch = make_shuffled_batch(len(train_x), batch_size)
        for batch_counter, batch_idx in enumerate(shuffled_batch):
            batch = torch.tensor(train_x[batch_idx, :]).to(device).float()
            batch = batch.reshape(batch.shape[0], height * width, 3)

            # we subtract the mean for this cluster -- centered data seems to help EM learning
            # we will re-add the mean to the Gaussian means below
            batch = batch - mean
            batch = batch / 255.

            ll_sample = einet.forward(batch)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()
            einet.em_process_batch()
        einet.em_update()

        ##### evaluate
        train_ll = eval_ll(einet, mean, train_x, batch_size=batch_size)
        valid_ll = eval_ll(einet, mean, valid_x, batch_size=batch_size)
        test_ll = eval_ll(einet, mean, test_x, batch_size=batch_size)

        ##### store results
        record['train_ll'].append(train_ll)
        record['valid_ll'].append(valid_ll)
        record['test_ll'].append(test_ll)
        pickle.dump(record, open(record_file, 'wb'))

        print("[{}] train LL {} valid LL {} test LL {}".format(
            epoch_count, train_ll, valid_ll, test_ll))

        if record['best_validation_ll'] is None or valid_ll > record['best_validation_ll']:
            record['best_validation_ll'] = valid_ll
            torch.save(einet, model_file)
            Graph.write_gpickle(graph, graph_file)

        if epoch_count % 10 == 0:
            # draw some samples
            samples = einet.sample(num_samples=25, std_correction=0.0).cpu().numpy()
            samples = samples + mean.detach().cpu().numpy() / 255.
            samples -= samples.min()
            samples /= samples.max()
            samples = samples.reshape(samples.shape[0], height, width, 3)

            img = np.zeros((height * 5 + 40, width * 5 + 40, 3))
            for h in range(5):
                for w in range(5):
                    img[h * (height + 10):h * (height + 10) + height,
                        w * (width + 10):w * (width + 10) + width, :] = samples[h * 5 + w, :]
            img = Image.fromarray(np.round(img * 255.).astype(np.uint8))
            img.save(os.path.join(sample_dir, "samples{}.jpg".format(epoch_count)))

    # We subtract the mean for the current cluster from the data (centering it at 0).
    # Here we re-add the mean to the Gaussian means. A hacky solution at the moment...
    einet = torch.load(model_file)
    with torch.no_grad():
        params = einet.einet_layers[0].ef_array.params
        mu2 = params[..., 0:3]**2
        params[..., 3:] -= mu2
        params[..., 3:] = torch.clamp(params[..., 3:],
                                      exponential_family_args['min_var'],
                                      exponential_family_args['max_var'])
        params[..., 0:3] += mean.reshape((width * height, 1, 1, 3)) / 255.
        params[..., 3:] += params[..., 0:3]**2
    torch.save(einet, model_file)
for cluster_n in range(num_clusters):
    train_x = train_x_all[cluster_idx == cluster_n, ...]
    valid_x = valid_x_all[valid_cluster_idx == cluster_n, ...]
    test_x = test_x_all[test_cluster_idx == cluster_n, ...]

    mean = cluster_means[cluster_n, ...]
    mean = mean.reshape(1, height * width, 3)
    mean = torch.tensor(mean, device=device)

    result_path = result_base_path
    result_path = os.path.join(result_path, "num_clusters_{}".format(num_clusters))
    result_path = os.path.join(result_path, "cluster_{}".format(cluster_n))

    graph = Graph.poon_domingos_structure(shape=(height, width), axes=[1], delta=[8])

    args = EinsumNetwork.Args(num_var=height * width,
                              num_dims=3,
                              num_classes=1,
                              num_sums=num_sums,
                              num_input_distributions=num_sums,
                              exponential_family=exponential_family,
                              exponential_family_args=exponential_family_args,
                              online_em_frequency=online_em_frequency,
                              online_em_stepsize=online_em_stepsize)

    print()
    print(result_path)
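    # A sketch of how this loop presumably continues, connecting it to the
    # train() function above: instantiate the per-cluster einet and train it.
    # utils.mkdir_p and the surrounding variables are assumed from this script.
    utils.mkdir_p(result_path)
    einet = EinsumNetwork.EinsumNetwork(graph, args)
    einet.initialize()
    einet.to(device)
    train(einet, mean, train_x, valid_x, test_x, result_path)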
def run_experiment(settings):
    ############################################################################
    exponential_family = EinsumNetwork.BinomialArray

    K = 2

    # structure = 'poon-domingos'
    structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [2]

    # 'binary-trees'
    depth = 1
    num_repetitions = 2
    width = 4

    num_epochs = 2
    batch_size = 3
    online_em_frequency = 1
    online_em_stepsize = 0.05

    print_weights = True
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 80}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    iris = datasets.load_iris()
    train_x = iris.data * 10
    train_labels = iris.target

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        train_x -= .5

    classes = np.unique(train_labels).tolist()

    train_x = torch.from_numpy(train_x).to(torch.device(device))

    # Make EinsumNetwork
    ######################################
    if structure == 'poon-domingos':
        pd_delta = [[width / d] for d in pd_num_pieces]
        graph = Graph.poon_domingos_structure(shape=(width,), delta=pd_delta)
    elif structure == 'binary-trees':
        graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                          depth=depth,
                                          num_repetitions=num_repetitions)
    else:
        raise AssertionError("Unknown Structure")

    args = EinsumNetwork.Args(
        num_var=train_x.shape[1],
        num_dims=1,
        num_classes=1,
        num_sums=K,
        num_input_distributions=K,
        exponential_family=exponential_family,
        exponential_family_args=exponential_family_args,
        online_em_frequency=online_em_frequency,
        online_em_stepsize=online_em_stepsize)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    print(einet)

    init_dict = get_init_dict(einet, train_x)
    einet.initialize(init_dict)
    einet.to(device)
    einet = einet.float()

    num_params = EinsumNetwork.eval_size(einet)

    # Train
    ######################################
    train_N = train_x.shape[0]

    start_time = time.time()

    for epoch_count in range(num_epochs):
        idx_batches = torch.randperm(train_N, device=device).split(batch_size)

        total_ll = 0.0
        for idx in idx_batches:
            batch_x = train_x[idx, :].float()
            outputs = einet.forward(batch_x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()

            einet.em_process_batch()
            total_ll += log_likelihood.detach().item()

        einet.em_update()
        print(f'[{epoch_count}] total log-likelihood: {total_ll / train_N}')

    end_time = time.time()

    ################
    # Experiment 1 #
    ################
    einet.eval()
    train_ll = EinsumNetwork.eval_loglikelihood_batched(einet,
                                                        train_x.float(),
                                                        batch_size=batch_size)
    print()
    print("Experiment 1: Log-likelihoods --- train LL {}".format(train_ll / train_N))

    ################
    # Experiment 2 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    acc_train = EinsumNetwork.eval_accuracy_batched(einet,
                                                    classes,
                                                    train_x.float(),
                                                    train_labels,
                                                    batch_size=batch_size)
    print()
    print("Experiment 2: Classification accuracies --- train acc {}".format(acc_train))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    if print_weights:
        for l in einet.einet_layers:
            print()
            if isinstance(l, FactorizedLeafLayer.FactorizedLeafLayer):
                print(l.ef_array.params)
            else:
                print(l.params)

    return {
        'train_ll': train_ll / train_N,
        'train_acc': acc_train,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
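A minimal invocation sketch: this variant hardcodes all hyperparameters, so the settings argument is effectively unused and None can be passed (the printed keys follow the returned dict):

if __name__ == '__main__':
    results = run_experiment(settings=None)  # hyperparameters are hardcoded above
    print('train LL per sample:', results['train_ll'])
    print('train accuracy:', results['train_acc'])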
""" A simple initializer for normalized sum-weights. :return: initial parameters """ params = 0.01 + 0.98 * torch.rand(layer.params_shape) with torch.no_grad(): if layer.params_mask is not None: params.data *= layer.params_mask params.data = params.data / (params.data.sum(layer.normalization_dims, keepdim=True)) return params if structure == 'poon-domingos': pd_delta = [[height / d, width / d] for d in pd_num_pieces] graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta) elif structure == 'binary-trees': graph = Graph.random_binary_trees(num_var=train_x.shape[1], depth=depth, num_repetitions=num_repetitions) else: raise AssertionError("Unknown Structure") args = EinsumNetwork.Args(num_var=train_x.shape[1], num_dims=1, num_classes=len(classes), num_sums=K, num_input_distributions=K, exponential_family=exponential_family, exponential_family_args=exponential_family_args, use_em=False)
def run_experiment(settings):
    ############################################################################
    fashion_mnist = settings.fashion_mnist
    svhn = settings.svhn

    exponential_family = settings.exponential_family

    classes = settings.classes

    K = settings.K

    structure = settings.structure

    # 'poon-domingos'
    pd_num_pieces = settings.pd_num_pieces

    # 'binary-trees'
    depth = settings.depth
    num_repetitions_mixture = settings.num_repetitions_mixture

    width = settings.width
    height = settings.height

    num_epochs = settings.num_epochs
    batch_size = settings.batch_size
    online_em_frequency = settings.online_em_frequency
    online_em_stepsize = settings.online_em_stepsize
    SGD_learning_rate = settings.SGD_learning_rate
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    # get data
    if fashion_mnist:
        train_x, train_labels, test_x, test_labels = datasets.load_fashion_mnist()
    elif svhn:
        train_x, train_labels, test_x, test_labels, extra_x, extra_labels = datasets.load_svhn()
    else:
        train_x, train_labels, test_x, test_labels = datasets.load_mnist()

    if exponential_family == EinsumNetwork.NormalArray:
        train_x /= 255.
        test_x /= 255.
        train_x -= .5
        test_x -= .5

    # validation split
    valid_x = train_x[-10000:, :]
    train_x = train_x[:-10000, :]
    valid_labels = train_labels[-10000:]
    train_labels = train_labels[:-10000]

    # pick the selected classes
    if classes is not None:
        train_x = train_x[np.any(np.stack([train_labels == c for c in classes], 1), 1), :]
        valid_x = valid_x[np.any(np.stack([valid_labels == c for c in classes], 1), 1), :]
        test_x = test_x[np.any(np.stack([test_labels == c for c in classes], 1), 1), :]
    else:
        classes = np.unique(train_labels).tolist()
    # convert the label arrays to plain lists (needed for .count() below)
    train_labels = [l for l in train_labels if l in classes]
    valid_labels = [l for l in valid_labels if l in classes]
    test_labels = [l for l in test_labels if l in classes]

    train_x = torch.from_numpy(train_x).to(torch.device(device))
    valid_x = torch.from_numpy(valid_x).to(torch.device(device))
    test_x = torch.from_numpy(test_x).to(torch.device(device))

    ######################################
    # Make EinsumNetworks for each class #
    ######################################
    einets = []
    ps = []
    for c in classes:
        if structure == 'poon-domingos':
            pd_delta = [[height / d, width / d] for d in pd_num_pieces]
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        elif structure == 'binary-trees':
            graph = Graph.random_binary_trees(num_var=train_x.shape[1],
                                              depth=depth,
                                              num_repetitions=num_repetitions_mixture)
        else:
            raise AssertionError("Unknown Structure")

        args = EinsumNetwork.Args(
            num_var=train_x.shape[1],
            num_dims=3 if svhn else 1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize,
            use_em=False)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        init_dict = get_init_dict(einet, train_x, train_labels=train_labels, einet_class=c)
        einet.initialize(init_dict)
        einet.to(device)
        einets.append(einet)

        # count the training samples per class (used as mixture weights)
        ps.append(train_labels.count(c))

        print(f'Einsum network for class {c}:')
        print(einet)

    # normalize ps, construct mixture component
    ps = [p / sum(ps) for p in ps]
    ps = torch.tensor(ps).to(torch.device(device))
    mixture = EinetMixture(ps, einets, classes=classes)

    num_params = mixture.eval_size()

    # Code for weight analysis (Section 7.3):
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # data_file = os.path.join(data_dir, f"weights_before_{c}.json")
    # weights = {}
    # for (einet, c) in zip(einets, classes):
    #     einet_weights = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             einet_weights[i] = l.reparam(l.params.data)
    #         else:
    #             einet_weights[i] = l.ef_array.reparam(l.ef_array.params.data)
    #     weights[c] = einet_weights

    ##################
    # Training phase #
    ##################
    sub_net_parameters = None
    for einet in mixture.einets:
        if sub_net_parameters is None:
            sub_net_parameters = list(einet.parameters())
        else:
            sub_net_parameters += list(einet.parameters())
    sub_net_parameters += list(mixture.parameters())

    optimizer = torch.optim.SGD(sub_net_parameters, lr=SGD_learning_rate)

    # Learning each sub-network generatively
    start_time = time.time()
    for (einet, c) in zip(einets, classes):
        train_x_c = train_x[[l == c for l in train_labels]]
        valid_x_c = valid_x[[l == c for l in valid_labels]]
        test_x_c = test_x[[l == c for l in test_labels]]

        train_N = train_x_c.shape[0]
        valid_N = valid_x_c.shape[0]
        test_N = test_x_c.shape[0]

        for epoch_count in range(num_epochs):
            idx_batches = torch.randperm(train_N, device=device).split(batch_size)
            total_loss = 0.0
            for idx in idx_batches:
                batch_x = train_x_c[idx, :]
                optimizer.zero_grad()
                outputs = einet.forward(batch_x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.sum()
                # minimize the negative log-likelihood with SGD
                # (online EM updates are disabled here; use_em=False above)
                nll = -log_likelihood
                nll.backward()
                optimizer.step()
                total_loss += nll.detach().item()
            print(f'[{epoch_count}] total loss: {total_loss}')
    end_time = time.time()

    # Code for weight analysis (Section 7.3, continued):
    # data_dir = '../src/experiments/round5/data/weights_analysis/'
    # utils.mkdir_p(data_dir)
    # percentual_change = []
    # percentual_change_formatted = []
    # data_file = os.path.join(data_dir, f"percentual_change.json")
    # for (einet, c) in zip(einets, classes):
    #     einet_change = {}
    #     for i, l in enumerate(reversed(einet.einet_layers)):
    #         if type(l) != FactorizedLeafLayer.FactorizedLeafLayer:
    #             weights_new = l.reparam(l.params.data)
    #         else:
    #             weights_new = l.ef_array.reparam(l.ef_array.params.data)
    #         change = torch.mean(torch.abs(weights[c][i] - weights_new) / weights[c][i]).item()
    #         einet_change[i] = change
    #         if i == 0:
    #             percentual_change_formatted.append(change)
    #     percentual_change.append(einet_change)
    # print(percentual_change)
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change, f)
    # data_file = os.path.join(data_dir, f"percentual_change_formatted.json")
    # with open(data_file, 'w') as f:
    #     json.dump(percentual_change_formatted, f)

    ################
    # Experiment 3 #
    ################
    train_N = train_x.shape[0]
    valid_N = valid_x.shape[0]
    test_N = test_x.shape[0]

    mixture.eval()
    train_ll = mixture.eval_loglikelihood_batched(train_x, batch_size=batch_size, skip_reparam=True)
    valid_ll = mixture.eval_loglikelihood_batched(valid_x, batch_size=batch_size, skip_reparam=True)
    test_ll = mixture.eval_loglikelihood_batched(test_x, batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 3: Log-likelihoods --- train LL {} valid LL {} test LL {}".format(
        train_ll / train_N, valid_ll / valid_N, test_ll / test_N))

    ################
    # Experiment 4 #
    ################
    train_labels = torch.tensor(train_labels).to(torch.device(device))
    valid_labels = torch.tensor(valid_labels).to(torch.device(device))
    test_labels = torch.tensor(test_labels).to(torch.device(device))

    acc_train = mixture.eval_accuracy_batched(classes, train_x, train_labels,
                                              batch_size=batch_size, skip_reparam=True)
    acc_valid = mixture.eval_accuracy_batched(classes, valid_x, valid_labels,
                                              batch_size=batch_size, skip_reparam=True)
    acc_test = mixture.eval_accuracy_batched(classes, test_x, test_labels,
                                             batch_size=batch_size, skip_reparam=True)
    print()
    print("Experiment 4: Classification accuracies --- train acc {} valid acc {} test acc {}".format(
        acc_train, acc_valid, acc_test))

    print()
    print(f'Network size: {num_params} parameters')
    print(f'Training time: {end_time - start_time}s')

    return {
        'train_ll': train_ll / train_N,
        'valid_ll': valid_ll / valid_N,
        'test_ll': test_ll / test_N,
        'train_acc': acc_train,
        'valid_acc': acc_valid,
        'test_acc': acc_test,
        'network_size': num_params,
        'training_time': end_time - start_time,
    }
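For completeness, a hedged usage sketch: the function only reads attributes off settings, so any object exposing them works, for instance via types.SimpleNamespace. All values below are purely illustrative, not the settings used in the experiments:

from types import SimpleNamespace

# illustrative values only -- not the experiment configuration
settings = SimpleNamespace(
    fashion_mnist=False, svhn=False,
    exponential_family=EinsumNetwork.NormalArray,
    classes=[0, 1], K=10,
    structure='binary-trees', pd_num_pieces=[4],
    depth=3, num_repetitions_mixture=10,
    width=28, height=28,
    num_epochs=5, batch_size=100,
    online_em_frequency=1, online_em_stepsize=0.05,
    SGD_learning_rate=0.1)

results = run_experiment(settings)
print(results['test_ll'], results['test_acc'])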