# sys.exit()
else:
    os.mkdir('./figs/' + name)
    os.mkdir('./losses/' + name)
    os.mkdir('./models/' + name)

del onlydirs

f = open("args/" + name + ".txt", "w+")
f.write(str(locals()))
f.close()

# Change to True !!
X = MNISTGraphDataset(num_hits, train=TRAIN, num=NUM, intensities=INTENSITIES, mnist8m=MNIST8M)
X_loaded = DataLoader(X, shuffle=True, batch_size=batch_size)

if (LOAD_MODEL):
    start_epoch = 255
    G = torch.load("models/" + name + "/G_" + str(start_epoch) + ".pt")
    D = torch.load("models/" + name + "/D_" + str(start_epoch) + ".pt")
else:
    start_epoch = 0
    G = Simple_GRU(node_size, fe_out_size, gru_hidden_size, gru_num_layers, num_iters,
gru_hidden_size = 100
gru_num_layers = 3
dropout = 0.3
batch_size = 1024
num_thresholded = 100
gen_in_dim = 100
lr = 0.00005
lr_disc = 0.0001
lr_gen = 0.00005
num_critic = 1
weight_clipping_limit = 1

torch.manual_seed(4)

# Change to True !!
X = MNISTGraphDataset(num_thresholded, train=True)
X_loaded = DataLoader(X, shuffle=True, batch_size=batch_size)

name = "22_wgan"

if (LOAD_MODEL):
    start_epoch = 10
    G = torch.load("models/" + name + "_G_" + str(start_epoch) + ".pt")
    D = torch.load("models/" + name + "_D_" + str(start_epoch) + ".pt")
else:
    start_epoch = 0
    G = Simple_GRU(input_size, output_size, gen_in_dim, gru_hidden_size, gru_num_layers, dropout, batch_size).cuda()
    D = Critic((num_thresholded, input_size), dropout, batch_size, wgan=True).cuda()
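
# Minimal sketch (not part of this file): how the `weight_clipping_limit` above is
# typically used in a WGAN critic update -- clamp every critic parameter after each
# optimizer step to enforce the Lipschitz constraint (Arjovsky et al. 2017).
def clip_critic_weights(critic, limit=weight_clipping_limit):
    for p in critic.parameters():
        p.data.clamp_(-limit, limit)

# e.g. called once per critic step, immediately after D's optimizer.step()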
def main(args):
    args = init(args)

    def pf(data):
        return data.y == args.num

    pre_filter = pf if args.num != -1 else None

    print("loading data")

    if (args.sparse_mnist):
        X = MNISTGraphDataset(args.dataset_path, args.num_hits, train=args.train, num=args.num)
        X_loaded = DataLoader(X, shuffle=True, batch_size=args.batch_size, pin_memory=True)
    else:
        if (args.gcnn):
            X = MNISTSuperpixels(args.dir_path, train=args.train, pre_transform=T.Cartesian(), pre_filter=pre_filter)
            X_loaded = tgDataLoader(X, shuffle=True, batch_size=args.batch_size)
        else:
            X = SuperpixelsDataset(args.dataset_path, args.num_hits, train=args.train, num=args.num)
            X_loaded = DataLoader(X, shuffle=True, batch_size=args.batch_size, pin_memory=True)

    print("loaded data")

    # model
    if (args.load_model):
        G = torch.load(args.model_path + args.name + "/G_" + str(args.start_epoch) + ".pt", map_location=args.device)
        D = torch.load(args.model_path + args.name + "/D_" + str(args.start_epoch) + ".pt", map_location=args.device)
    else:
        # G = Graph_Generator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.fn_hidden_size, args.fn_num_layers, args.mp_iters_gen, args.num_hits, args.gen_dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, int_diffs=args.int_diffs, pos_diffs=args.pos_diffs, gru=args.gru, batch_norm=args.batch_norm, device=device).to(args.device)
        if (args.gcnn):
            G = GaussianGenerator(args=deepcopy(args)).to(args.device)
            D = MoNet(args=deepcopy(args)).to(args.device)
            # D = Gaussian_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.mp_hidden_size, args.mp_num_layers, args.num_iters, args.num_hits, args.dropout, args.leaky_relu_alpha, kernel_size=args.kernel_size, hidden_node_size=args.hidden_node_size, int_diffs=args.int_diffs, gru=GRU, batch_norm=args.batch_norm, device=device).to(args.device)
        else:
            # D = Graph_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.fn_hidden_size, args.fn_num_layers, args.mp_iters_disc, args.num_hits, args.disc_dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, wgan=args.wgan, int_diffs=args.int_diffs, pos_diffs=args.pos_diffs, gru=args.gru, batch_norm=args.batch_norm, device=device).to(args.device)
            print("Generator")
            G = Graph_GAN(gen=True, args=deepcopy(args)).to(args.device)
            print("Discriminator")
            D = Graph_GAN(gen=False, args=deepcopy(args)).to(args.device)

    print("Models loaded")

    # optimizer
    if args.spectral_norm_gen:
        G_params = filter(lambda p: p.requires_grad, G.parameters())
    else:
        G_params = G.parameters()

    if args.spectral_norm_disc:
        D_params = filter(lambda p: p.requires_grad, D.parameters())
    else:
        D_params = D.parameters()

    if (args.optimizer == 'rmsprop'):
        G_optimizer = optim.RMSprop(G_params, lr=args.lr_gen)
        D_optimizer = optim.RMSprop(D_params, lr=args.lr_disc)
    elif (args.optimizer == 'adadelta'):
        G_optimizer = optim.Adadelta(G_params, lr=args.lr_gen)
        D_optimizer = optim.Adadelta(D_params, lr=args.lr_disc)
    elif (args.optimizer == 'acgd'):
        optimizer = ACGD(max_params=G_params, min_params=D_params, lr_max=args.lr_gen, lr_min=args.lr_disc, device=args.device)
    elif (args.optimizer == 'adam' or args.optimizer == 'None'):
        G_optimizer = optim.Adam(G_params, lr=args.lr_gen, weight_decay=5e-4, betas=(args.beta1, args.beta2))
        D_optimizer = optim.Adam(D_params, lr=args.lr_disc, weight_decay=5e-4, betas=(args.beta1, args.beta2))

    if (args.load_model):
        try:
            if (not args.optimizer == 'acgd'):
                G_optimizer.load_state_dict(torch.load(args.model_path + args.name + "/G_optim_" + str(args.start_epoch) + ".pt", map_location=args.device))
                D_optimizer.load_state_dict(torch.load(args.model_path + args.name + "/D_optim_" + str(args.start_epoch) + ".pt", map_location=args.device))
            else:
                optimizer.load_state_dict(torch.load(args.model_path + args.name + "/optim_" + str(args.start_epoch) + ".pt", map_location=args.device))
        except Exception:
            print("Error loading optimizer")

    print("optimizers loaded")

    if args.fid:
        C, mu2, sigma2 = evaluation.load(args, X_loaded)

    normal_dist = Normal(torch.tensor(0.).to(args.device), torch.tensor(args.sd).to(args.device))

    lns = args.latent_node_size if args.latent_node_size else args.hidden_node_size
    args.noise_file_name = "num_samples_" + str(args.num_samples) + "_num_nodes_" + str(args.num_hits) + "_latent_node_size_" + str(lns) + "_sd_" + str(args.sd) + ".pt"
    if args.gcnn:
        args.noise_file_name = "gcnn_" + args.noise_file_name

    noise_file_names = listdir(args.noise_path)
    if args.noise_file_name not in noise_file_names:
        if (args.gcnn):
            torch.save(normal_dist.sample((args.num_samples * 5, 2 + args.channels[0])), args.noise_path + args.noise_file_name)
        else:
            torch.save(normal_dist.sample((args.num_samples, args.num_hits, lns)), args.noise_path + args.noise_file_name)

    losses = {}

    if (args.load_model):
        try:
            losses['D'] = np.loadtxt(args.losses_path + args.name + "/" + "D.txt").tolist()[:args.start_epoch]
            losses['Dr'] = np.loadtxt(args.losses_path + args.name + "/" + "Dr.txt").tolist()[:args.start_epoch]
            losses['Df'] = np.loadtxt(args.losses_path + args.name + "/" + "Df.txt").tolist()[:args.start_epoch]
            losses['G'] = np.loadtxt(args.losses_path + args.name + "/" + "G.txt").tolist()[:args.start_epoch]
            if args.fid:
                losses['fid'] = np.loadtxt(args.losses_path + args.name + "/" + "fid.txt").tolist()[:args.start_epoch]
            if (args.gp):
                losses['gp'] = np.loadtxt(args.losses_path + args.name + "/" + "gp.txt").tolist()[:args.start_epoch]
        except Exception:
            print("couldn't load losses")
            losses['D'] = []
            losses['Dr'] = []
            losses['Df'] = []
            losses['G'] = []
            if args.fid:
                losses['fid'] = []
            if (args.gp):
                losses['gp'] = []
    else:
        losses['D'] = []
        losses['Dr'] = []
        losses['Df'] = []
        losses['G'] = []
        if args.fid:
            losses['fid'] = []
        if (args.gp):
            losses['gp'] = []

    Y_real = torch.ones(args.batch_size, 1).to(args.device)
    Y_fake = torch.zeros(args.batch_size, 1).to(args.device)

    def train_D(data, gen_data=None, unrolled=False):
        if args.debug:
            print("dtrain")
        D.train()
        D_optimizer.zero_grad()

        run_batch_size = data.shape[0] if not args.gcnn else data.y.shape[0]

        if gen_data is None:
            gen_data = utils.gen(args, G, normal_dist, run_batch_size)
            if (args.gcnn):
                gen_data = utils.convert_to_batch(args, gen_data, run_batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            data = augment.augment(args, data, p)
            gen_data = augment.augment(args, gen_data, p)

        D_real_output = D(data.clone())
        D_fake_output = D(gen_data)

        D_loss, D_loss_items = utils.calc_D_loss(args, D, data, gen_data, D_real_output, D_fake_output, run_batch_size, Y_real, Y_fake)
        D_loss.backward(create_graph=unrolled)

        D_optimizer.step()
        return D_loss_items

    def train_G(data):
        if args.debug:
            print("gtrain")
        G.train()
        G_optimizer.zero_grad()

        gen_data = utils.gen(args, G, normal_dist, args.batch_size)
        if (args.gcnn):
            gen_data = utils.convert_to_batch(args, gen_data, args.batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            gen_data = augment.augment(args, gen_data, p)

        if (args.unrolled_steps > 0):
            D_backup = deepcopy(D)
            for i in range(args.unrolled_steps - 1):
                train_D(data, gen_data=gen_data, unrolled=True)

        D_fake_output = D(gen_data)

        G_loss = utils.calc_G_loss(args, D_fake_output, Y_real)

        G_loss.backward()
        G_optimizer.step()

        if (args.unrolled_steps > 0):
            D.load(D_backup)

        return G_loss.item()

    def train_acgd(data):
        if args.debug:
            print("acgd train")
        D.train()
        G.train()
        optimizer.zero_grad()

        run_batch_size = data.shape[0] if not args.gcnn else data.y.shape[0]

        gen_data = utils.gen(args, G, normal_dist, run_batch_size)
        if (args.gcnn):
            gen_data = utils.convert_to_batch(args, gen_data, run_batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            data = utils.rand_translate(args, data, p)
            gen_data = utils.rand_translate(args, gen_data, p)

        D_real_output = D(data.clone())
        D_fake_output = D(gen_data)

        D_loss, D_loss_items = utils.calc_D_loss(args, D, data, gen_data, D_real_output, D_fake_output, run_batch_size)
        optimizer.step(loss=D_loss)

        G.eval()
        with torch.no_grad():
            G_loss = utils.calc_G_loss(args, D_fake_output)

        return D_loss_items, G_loss.item()

    def train():
        k = 0
        temp_ng = args.num_gen

        if (args.fid):
            losses['fid'].append(evaluation.get_fid(args, C, G, normal_dist, mu2, sigma2))
        if (args.save_zero):
            save_outputs.save_sample_outputs(args, D, G, normal_dist, args.name, 0, losses)

        for i in range(args.start_epoch, args.num_epochs):
            print("Epoch %d %s" % ((i + 1), args.name))
            Dr_loss = 0
            Df_loss = 0
            G_loss = 0
            D_loss = 0
            gp_loss = 0

            lenX = len(X_loaded)

            for batch_ndx, data in tqdm(enumerate(X_loaded), total=lenX):
                data = data.to(args.device)

                if (args.gcnn):
                    data.pos = (data.pos - 14) / 28
                    row, col = data.edge_index
                    data.edge_attr = (data.pos[col] - data.pos[row]) / (2 * args.cutoff) + 0.5

                if (not args.optimizer == 'acgd'):
                    if (args.num_critic > 1):
                        D_loss_items = train_D(data)
                        D_loss += D_loss_items['D']
                        Dr_loss += D_loss_items['Dr']
                        Df_loss += D_loss_items['Df']
                        if (args.gp):
                            gp_loss += D_loss_items['gp']

                        if ((batch_ndx - 1) % args.num_critic == 0):
                            G_loss += train_G(data)
                    else:
                        if (batch_ndx == 0 or (batch_ndx - 1) % args.num_gen == 0):
                            D_loss_items = train_D(data)
                            D_loss += D_loss_items['D']
                            Dr_loss += D_loss_items['Dr']
                            Df_loss += D_loss_items['Df']
                            if (args.gp):
                                gp_loss += D_loss_items['gp']

                        G_loss += train_G(data)
                else:
                    D_loss_items, G_loss_item = train_acgd(data)
                    D_loss += D_loss_items['D']
                    Dr_loss += D_loss_items['Dr']
                    Df_loss += D_loss_items['Df']
                    G_loss += G_loss_item

                if args.bottleneck:
                    if (batch_ndx == 10):
                        return

            losses['D'].append(D_loss / (lenX / args.num_gen))
            losses['Dr'].append(Dr_loss / (lenX / args.num_gen))
            losses['Df'].append(Df_loss / (lenX / args.num_gen))
            losses['G'].append(G_loss / (lenX / args.num_critic))
            if (args.gp):
                losses['gp'].append(gp_loss / (lenX / args.num_gen))

            print("d loss: " + str(losses['D'][-1]))
            print("g loss: " + str(losses['G'][-1]))
            print("dr loss: " + str(losses['Dr'][-1]))
            print("df loss: " + str(losses['Df'][-1]))
            if (args.gp):
                print("gp loss: " + str(losses['gp'][-1]))

            gloss = losses['G'][-1]
            drloss = losses['Dr'][-1]
            dfloss = losses['Df'][-1]
            dloss = (drloss + dfloss) / 2

            if (args.bgm):
                if (i > 20 and gloss > dloss + args.bag):
                    print("num gen upping to 10")
                    args.num_gen = 10
                else:
                    print("num gen normal")
                    args.num_gen = temp_ng
            elif (args.gom):
                if (i > 20 and gloss > dloss + args.bag):
                    print("G loss too high - training G only")
                    j = 0
                    print("starting g loss: " + str(gloss))
                    print("starting d loss: " + str(dloss))

                    while (gloss > dloss + args.bag * 0.5):
                        print(j)
                        gloss = 0
                        for l in tqdm(range(lenX)):
                            gloss += train_G()
                        gloss /= lenX

                        print("g loss: " + str(gloss))
                        print("d loss: " + str(dloss))

                        losses['D'].append(dloss * 2)
                        losses['Dr'].append(drloss)
                        losses['Df'].append(dfloss)
                        losses['G'].append(gloss)

                        if (j % 5 == 0):
                            save_outputs.save_sample_outputs(args, D, G, normal_dist, args.name, i + 1, losses, k=k, j=j)

                        j += 1
                        k += 1
            elif (args.rd):
                if (i > 20 and gloss > dloss + args.bag):
                    print("gloss too high, resetting D params")
                    D.reset_params()

            if ((i + 1) % 5 == 0):
                optimizers = optimizer if args.optimizer == 'acgd' else (D_optimizer, G_optimizer)
                save_outputs.save_models(args, D, G, optimizers, args.name, i + 1)

            if (args.fid and (i + 1) % 1 == 0):
                losses['fid'].append(evaluation.get_fid(args, C, G, normal_dist, mu2, sigma2))

            if ((i + 1) % 5 == 0):
                save_outputs.save_sample_outputs(args, D, G, normal_dist, args.name, i + 1, losses)

    train()
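
# Sketch only -- the actual gradient-penalty computation lives in utils.calc_D_loss,
# which is not shown here. The 'gp' loss tracked above is the standard WGAN-GP term
# (Gulrajani et al. 2017); a minimal version, assuming `real` and `fake` are dense
# (batch, nodes, features) tensors on the same device:
def gradient_penalty(D, real, fake, device, lambda_gp=10.0):
    eps = torch.rand(real.shape[0], 1, 1, device=device)  # per-sample interpolation weight
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_interp = D(interp)
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    return lambda_gp * ((grads.reshape(grads.shape[0], -1).norm(2, dim=1) - 1) ** 2).mean()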
def main(args):
    args = init_dirs(args)

    pt = T.Cartesian() if args.cartesian else T.Polar()

    if args.dataset == 'sp':
        train_dataset = MNISTSuperpixels(args.dataset_path, True, pre_transform=pt)
        test_dataset = MNISTSuperpixels(args.dataset_path, False, pre_transform=pt)
        train_loader = tgDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = tgDataLoader(test_dataset, batch_size=args.batch_size)
    elif args.dataset == 'sm':
        train_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=True)
        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, pin_memory=True)
        test_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=False)
        test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, pin_memory=True)

    if (args.load_model):
        C = torch.load(args.model_path + args.name + "/C_" + str(args.start_epoch) + ".pt").to(device)
    else:
        C = MoNet(args.kernel_size).to(device)

    C_optimizer = torch.optim.Adam(C.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if (args.scheduler):
        C_scheduler = torch.optim.lr_scheduler.StepLR(C_optimizer, args.decay_step, gamma=args.lr_decay)

    train_losses = []
    test_losses = []

    def plot_losses(epoch, train_losses, test_losses):
        fig = plt.figure()

        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(train_losses)
        ax1.set_title('training')

        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(test_losses)
        ax2.set_title('testing')

        plt.savefig(args.losses_path + args.name + "/" + str(epoch) + ".png")
        plt.close()

    def save_model(epoch):
        torch.save(C, args.model_path + args.name + "/C_" + str(epoch) + ".pt")

    def train_C(data, y):
        C.train()
        C_optimizer.zero_grad()

        output = C(data)

        # nll_loss takes class labels as target, so one-hot encoding is not needed
        C_loss = F.nll_loss(output, y)

        C_loss.backward()
        C_optimizer.step()

        return C_loss.item()

    def test(epoch):
        C.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data in test_loader:
                if args.dataset == 'sp':
                    output = C(data.to(device))
                    y = data.y.to(device)
                elif args.dataset == 'sm':
                    output = C(tg_transform(args, data[0].to(device)))
                    y = data[1].to(device)

                test_loss += F.nll_loss(output, y, reduction='sum').item()
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)

        print('test')
        f = open(args.out_path + args.name + '.txt', 'a')
        print(args.out_path + args.name + '.txt')
        s = "After {} epochs, on test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(epoch, test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))
        print(s)
        f.write(s)
        f.close()

    for i in range(args.start_epoch, args.num_epochs):
        print("Epoch %d %s" % ((i + 1), args.name))
        C_loss = 0
        test(i)
        for batch_ndx, data in tqdm(enumerate(train_loader), total=len(train_loader)):
            if args.dataset == 'sp':
                C_loss += train_C(data.to(device), data.y.to(device))
            elif args.dataset == 'sm':
                C_loss += train_C(tg_transform(args, data[0].to(device)), data[1].to(device))

        train_losses.append(C_loss / len(train_loader))

        if (args.scheduler):
            C_scheduler.step()

        if ((i + 1) % 10 == 0):
            save_model(i + 1)
            plot_losses(i + 1, train_losses, test_losses)

    test(args.num_epochs)
import setGPU
import torch
import torchvision
import torch.nn as nn
from torch.optim import Adam

from gcn import GCN_classifier
from graph_dataset_mnist import MNISTGraphDataset

batch_size = 128
num_thresholded = 100

transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

X_test = MNISTGraphDataset(num_thresholded, train=False)
# X_train = MNISTGraphDataset(num_thresholded, train=True)
# X_train_loaded = torch.utils.data.DataLoader(X_train, shuffle=True, batch_size=batch_size)
X_test_loaded = torch.utils.data.DataLoader(X_test, shuffle=False, batch_size=batch_size)

model = GCN_classifier(3, 256, 10, 0.3)
loss_func = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

for e in range(5):
    print(e)
    running_loss = 0
    for i, data in enumerate(X_test_loaded):
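        # Sketch of the missing loop body, built only from the loss_func, optimizer,
        # and running_loss defined above; it assumes the dataset yields
        # (node_features, label) pairs and that GCN_classifier takes the node
        # feature tensor directly -- both are assumptions, not confirmed here.
        x, y = data
        optimizer.zero_grad()
        out = model(x)
        loss = loss_func(out, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch %d avg loss: %.4f" % (e, running_loss / len(X_test_loaded)))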