def save_einet(einet, graph, out_path):
    """Persist an Einsum network and its PC graph under ``out_path``.

    The directory is created when it does not exist yet. The graph is
    written to ``einet.rg`` with ``Graph.write_gpickle`` and the model
    object to ``einet.pth`` with ``torch.save``; each destination is
    printed right after it has been written.
    """
    os.makedirs(out_path, exist_ok=True)

    # Resolve both destinations first; the writes below keep their order.
    graph_path = os.path.join(out_path, "einet.rg")
    model_path = os.path.join(out_path, "einet.pth")

    Graph.write_gpickle(graph, graph_path)
    print("Saved PC graph to {}".format(graph_path))

    torch.save(einet, model_path)
    print("Saved model to {}".format(model_path))
####################
# save and re-load #
####################

# Reference log-likelihoods of the in-memory model on every split
# (evaluated in the order train, valid, test), taken before saving.
einet.eval()
train_ll_before, valid_ll_before, test_ll_before = (
    EinsumNetwork.eval_loglikelihood_batched(
        einet, split_x, batch_size=batch_size)
    for split_x in (train_x, valid_x, test_x)
)

# Write the PC graph and the model object into model_dir.
graph_file = os.path.join(model_dir, "einet.pc")
model_file = os.path.join(model_dir, "einet.mdl")

Graph.write_gpickle(graph, graph_file)
print("Saved PC graph to {}".format(graph_file))

torch.save(einet, model_file)
print("Saved model to {}".format(model_file))

# Drop the in-memory model so the evaluation below can only use the copy
# read back from disk.
del einet

einet = torch.load(model_file)
print("Loaded model from {}".format(model_file))

# Log-likelihood of the re-loaded model on the training split.
train_ll = EinsumNetwork.eval_loglikelihood_batched(
    einet, train_x, batch_size=batch_size)
def train(einet, mean, train_x, valid_x, test_x, result_path):
    """Train ``einet`` with EM on mean-centred images and checkpoint the best model.

    Every epoch the training data is visited in shuffled mini-batches,
    the log-likelihoods on all three splits are appended to a record that
    is pickled to ``record.pkl``, and the model plus its graph are saved
    whenever the validation log-likelihood improves. Every tenth epoch a
    5x5 grid of samples is written to ``samples/samples<epoch>.jpg``.
    Finally the best checkpoint is re-loaded, the subtracted mean is
    folded back into its first-layer parameters, and it is saved again.

    Args:
        einet: Model exposing ``forward``, ``em_process_batch``,
            ``em_update`` and ``sample``.
        mean: Tensor subtracted from every (unscaled) batch; it is later
            reshaped to ``(width * height, 1, 1, 3)``.
        train_x: Training data, indexed as ``train_x[batch_idx, :]``.
        valid_x: Validation data, passed to ``eval_ll``.
        test_x: Test data, passed to ``eval_ll``.
        result_path: Directory receiving ``einet.mdl``, ``einet.pc``,
            ``record.pkl`` and the ``samples`` sub-directory.

    Returns:
        None. Results are written to ``result_path``.

    Note:
        Relies on module-level names: ``num_epochs``, ``batch_size``,
        ``device``, ``height``, ``width``, ``graph``,
        ``exponential_family_args``, ``eval_ll`` and
        ``make_shuffled_batch``.
    """
    model_file = os.path.join(result_path, 'einet.mdl')
    graph_file = os.path.join(result_path, 'einet.pc')
    record_file = os.path.join(result_path, 'record.pkl')
    sample_dir = os.path.join(result_path, 'samples')
    utils.mkdir_p(sample_dir)

    record = {
        'train_ll': [],
        'valid_ll': [],
        'test_ll': [],
        'best_validation_ll': None
    }

    for epoch_count in range(num_epochs):
        shuffled_batch = make_shuffled_batch(len(train_x), batch_size)
        for batch_idx in shuffled_batch:
            batch = torch.tensor(train_x[batch_idx, :]).to(device).float()
            batch = batch.reshape(batch.shape[0], height * width, 3)
            # we subtract the mean for this cluster -- centered data seems to help EM learning
            # we will re-add the mean to the Gaussian means below
            batch = batch - mean
            batch = batch / 255.

            ll_sample = einet.forward(batch)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()
            einet.em_process_batch()

        # NOTE(review): the original formatting was collapsed, so nesting is
        # reconstructed from statement order; em_update() is placed after the
        # batch loop (one update per epoch) -- confirm.
        einet.em_update()

        ##### evaluate
        train_ll = eval_ll(einet, mean, train_x, batch_size=batch_size)
        valid_ll = eval_ll(einet, mean, valid_x, batch_size=batch_size)
        test_ll = eval_ll(einet, mean, test_x, batch_size=batch_size)

        ##### store results
        record['train_ll'].append(train_ll)
        record['valid_ll'].append(valid_ll)
        record['test_ll'].append(test_ll)
        # Use a context manager so the handle is closed (and the record
        # flushed to disk) every epoch instead of leaking an open file.
        with open(record_file, 'wb') as record_handle:
            pickle.dump(record, record_handle)

        print("[{}] train LL {} valid LL {} test LL {}".format(
            epoch_count, train_ll, valid_ll, test_ll))

        # Checkpoint only when the validation log-likelihood improves.
        if (record['best_validation_ll'] is None
                or valid_ll > record['best_validation_ll']):
            record['best_validation_ll'] = valid_ll
            torch.save(einet, model_file)
            Graph.write_gpickle(graph, graph_file)

        if epoch_count % 10 == 0:
            # draw some samples
            samples = einet.sample(num_samples=25,
                                   std_correction=0.0).cpu().numpy()
            # Undo the centring, then rescale to [0, 1] for display.
            samples = samples + mean.detach().cpu().numpy() / 255.
            samples -= samples.min()
            samples /= samples.max()
            samples = samples.reshape(samples.shape[0], height, width, 3)

            # 5x5 grid of tiles separated by 10-pixel gaps (4 gaps = 40).
            img = np.zeros((height * 5 + 40, width * 5 + 40, 3))
            for h in range(5):
                for w in range(5):
                    top = h * (height + 10)
                    left = w * (width + 10)
                    img[top:top + height,
                        left:left + width, :] = samples[h * 5 + w, :]
            img = Image.fromarray(np.round(img * 255.).astype(np.uint8))
            img.save(
                os.path.join(sample_dir, "samples{}.jpg".format(epoch_count)))

    # We subtract the mean for the current cluster from the data (centering it at 0).
    # Here we re-add the mean to the Gaussian means. A hacky solution at the moment...
    einet = torch.load(model_file)
    with torch.no_grad():
        params = einet.einet_layers[0].ef_array.params
        # The trailing entries appear to hold second moments (mu^2 + variance):
        # removing mu^2 leaves the quantity clamped to [min_var, max_var];
        # the squared, shifted means are added back afterwards.
        mu2 = params[..., 0:3]**2
        params[..., 3:] -= mu2
        params[..., 3:] = torch.clamp(params[..., 3:],
                                      exponential_family_args['min_var'],
                                      exponential_family_args['max_var'])
        params[..., 0:3] += mean.reshape((width * height, 1, 1, 3)) / 255.
        params[..., 3:] += params[..., 0:3]**2
    torch.save(einet, model_file)