def reparameterize(self, mu, var):
    # Reparameterization trick: `var` holds the log-variance returned by the
    # encoder, so std = exp(0.5 * log_var). In eval mode the mean is returned directly.
    if self.training:
        std = var.mul(0.5).exp_()
        eps = utils.var_or_cuda(std.data.new(std.size()).normal_())
        z = eps.mul(std).add_(mu)
        return z
    else:
        return mu
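# NOTE: `utils.var_or_cuda` is defined elsewhere in the repo and is not shown in
# this section. A minimal sketch of what such a helper typically does (move a
# tensor onto the GPU when one is available) is given below as an assumption;
# the repo's actual implementation may differ.
import torch

def var_or_cuda_sketch(x):
    # hypothetical helper: place the tensor on the default CUDA device if possible
    if torch.cuda.is_available():
        x = x.cuda()
    return x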
def forward(self, images):
    # Encode every view of the object into its own latent distribution, sample
    # each with the reparameterization trick, then fuse the per-view samples
    # with `self.combine`. The latent size (200) is hard-coded here.
    means = utils.var_or_cuda(
        torch.zeros(self.args.num_views, self.args.batch_size, 200))
    vars = utils.var_or_cuda(
        torch.zeros(self.args.num_views, self.args.batch_size, 200))
    zs = utils.var_or_cuda(
        torch.zeros(self.args.num_views, self.args.batch_size, 200))

    for i, image in enumerate(images):
        image = utils.var_or_cuda(image)
        z_mean, z_log_var = self.single_image_forward(image)
        # index the i-th view only (the original `[i:]` slice assigned the same
        # value to every remaining view before it was overwritten)
        zs[i] = self.reparameterize(z_mean, z_log_var)
        means[i] = z_mean
        vars[i] = z_log_var

    #z_mu = self.combine(means)
    #z_var = self.combine(vars)
    return self.combine(zs), means, vars
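# NOTE: `self.combine` is defined elsewhere in the repo and is not shown here.
# Purely as an illustration, a max-pooling fusion over the view dimension
# (suggested by the "3DVAEGAN_MULTIVIEW_MAX" checkpoint name used later in this
# section) could look like the hypothetical sketch below.
import torch

def combine_max_sketch(zs):
    # zs: (num_views, batch_size, z_size) -> (batch_size, z_size)
    # keep, for every latent dimension, the strongest activation across views
    return torch.max(zs, dim=0)[0]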
def test_3DVAEGAN(args):
    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetPlusImageDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    E = _E(args)
    G = _G(args)

    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)
    E_solver = optim.Adam(E.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        G.cuda()
        E.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN'
    # only the encoder and generator are needed at test time; G/G_solver are
    # passed again in the discriminator slots of read_pickle
    read_pickle(pickle_path, G, G_solver, G, G_solver, E, E_solver)

    recon_loss_total = 0
    for i, (image, model_3d) in enumerate(dset_loaders):
        X = var_or_cuda(model_3d)
        image = var_or_cuda(image)

        # encode the image, sample a latent code, and decode it to a voxel grid
        z_mu, z_var = E(image)
        Z_vae = E.reparameterize(z_mu, z_var)
        G_vae = G(Z_vae)

        # per-sample squared reconstruction error over the voxel grid
        recon_loss = torch.sum(torch.pow((G_vae - X), 2), dim=(1, 2, 3))
        print(recon_loss.size())
        print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
        recon_loss_total += recon_loss

        samples = G_vae.cpu().data[:8].squeeze().numpy()
        image_path = args.output_dir + args.image_dir + '3DVAEGAN_test'
        if not os.path.exists(image_path):
            os.makedirs(image_path)
        SavePloat_Voxels(samples, image_path, i)
def test_3DGAN(args):
    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "test/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    pickle_path = "." + args.pickle_dir + '3DVAEGAN_MULTIVIEW_MAX'
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    recon_loss_total = 0
    for i, X in enumerate(dset_loaders):
        #X = X.view(-1, 1, args.cube_len, args.cube_len, args.cube_len)
        X = var_or_cuda(X)
        print(X.size())

        # sample from the prior and generate fake voxel grids
        Z = generateZ(args)
        print(Z.size())
        fake = G(Z).squeeze()
        print(fake.size())

        # note: the generated samples are compared against randomly paired real
        # shapes, so this is a rough sanity check rather than a true reconstruction error
        recon_loss = torch.sum(torch.pow((fake - X), 2), dim=(1, 2, 3))
        print(recon_loss.size())
        print("RECON LOSS ITER: ", i, " - ", torch.mean(recon_loss))
        recon_loss_total += recon_loss

        samples = fake.cpu().data[:8].squeeze().numpy()
        image_path = args.output_dir + args.image_dir + '3DVAEGAN_MULTIVIEW_MAX_test'
        if not os.path.exists(image_path):
            os.makedirs(image_path)
        SavePloat_Voxels(samples, image_path, i)
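# NOTE: `generateZ` is defined elsewhere in the repo and is not shown in this
# section. A minimal sketch of prior sampling, assuming `args.z_dis` selects
# between a normal and a uniform prior (the WGAN trainer further below calls it
# with a different signature); the repo's actual implementation may differ.
import torch

def generateZ_sketch(args):
    # hypothetical helper: draw a batch of latent codes from the chosen prior
    if args.z_dis == "norm":
        Z = torch.randn(args.batch_size, args.z_size)
    else:
        Z = torch.rand(args.batch_size, args.z_size)
    return var_or_cuda(Z)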
def train(args):
    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]
    hyparam_dict = OrderedDict(((arg, value) for arg, value in hyparam_list))
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # tensorboard logging (TF1-style summary writer)
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

        inject_summary = inject_summary

    # dataset definition
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    dsets = ShapeNetDataset(dsets_path, args)
    dset_loaders = torch.utils.data.DataLoader(dsets,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if args.lrsh:
        D_scheduler = MultiStepLR(D_solver, milestones=[500, 1000])

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        for i, X in enumerate(dset_loaders):
            X = var_or_cuda(X)

            if X.size()[0] != int(args.batch_size):
                # drop the last, incomplete batch
                continue

            Z = generateZ(args)

            real_labels = var_or_cuda(torch.ones(args.batch_size))
            fake_labels = var_or_cuda(torch.zeros(args.batch_size))

            if args.soft_label:
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.7, 1.2))
                fake_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0, 0.3))

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # only update D while its accuracy is below the threshold, so the
            # discriminator does not overpower the generator
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each iteration ===============#
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])

            if args.use_tensorboard:
                log_save_path = args.output_dir + args.log_dir + log_param
                if not os.path.exists(log_save_path):
                    os.makedirs(log_save_path)

                # .item() replaces the deprecated .data[0] indexing of scalar tensors
                info = {
                    'loss/loss_D_R': d_real_loss.item(),
                    'loss/loss_D_F': d_fake_loss.item(),
                    'loss/loss_D': d_loss.item(),
                    'loss/loss_G': g_loss.item(),
                    'loss/acc_D': d_total_acu.item()
                }

                for tag, value in info.items():
                    inject_summary(summary_writer, tag, value, iteration)

                summary_writer.flush()

        # =============== each epoch: log, save images, save model ===============#
        print('Iter-{}; D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
              .format(iteration, d_loss.item(), g_loss.item(), d_total_acu.item(),
                      D_solver.state_dict()['param_groups'][0]["lr"]))

        if (epoch + 1) % args.image_save_step == 0:
            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D, D_solver)

        if args.lrsh:
            try:
                D_scheduler.step()
            except Exception as e:
                print("fail lr scheduling", e)
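# NOTE: `read_pickle` / `save_new_pickle` are repo helpers that are not shown in
# this section. The sketch below is only a guess at what such a checkpointing
# pair might do (one state_dict file per component, tagged by iteration); the
# repo's actual file names and format may differ.
import os
import torch

def save_new_pickle_sketch(path, iteration, G, G_solver, D, D_solver):
    # hypothetical helper: persist module and optimizer states under the run folder
    if not os.path.exists(path):
        os.makedirs(path)
    torch.save(G.state_dict(), os.path.join(path, "G_{}.pth".format(iteration)))
    torch.save(G_solver.state_dict(), os.path.join(path, "G_solver_{}.pth".format(iteration)))
    torch.save(D.state_dict(), os.path.join(path, "D_{}.pth".format(iteration)))
    torch.save(D_solver.state_dict(), os.path.join(path, "D_solver_{}.pth".format(iteration)))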
def train(args):
    # visdom server for live voxel plots
    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    viz = Visdom(DEFAULT_HOSTNAME, DEFAULT_PORT, ipv6=False)

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
        ("sl", args.soft_label),
    ]
    hyparam_dict = OrderedDict(((arg, value) for arg, value in hyparam_list))
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # tensorboard logging (TF1-style summary writer)
    if args.use_tensorboard:
        import tensorflow as tf

        summary_writer = tf.summary.FileWriter(args.output_dir + args.log_dir +
                                               log_param)

        def inject_summary(summary_writer, tag, value, step):
            summary = tf.Summary(
                value=[tf.Summary.Value(tag=tag, simple_value=value)])
            summary_writer.add_summary(summary, global_step=step)

        inject_summary = inject_summary

    # dataset definition: pre-voxelized 3D MNIST volumes, flattened per sample
    dsets_path = args.input_dir + args.data_dir + "train/"
    print(dsets_path)
    x_train = np.load("voxels_3DMNIST_16.npy")
    dataset = x_train.reshape(-1, args.cube_len * args.cube_len * args.cube_len)
    print(dataset.shape)
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)

    # model definition
    D = _D(args)
    G = _G(args)

    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    criterion = nn.BCELoss()

    pickle_path = "." + args.pickle_dir + log_param
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))

        for i, X in enumerate(dset_loaders):
            X = var_or_cuda(X)
            X = X.type(torch.cuda.FloatTensor)  # assumes a CUDA device is available

            if X.size()[0] != int(args.batch_size):
                # drop the last, incomplete batch
                continue

            Z = generateZ(args)

            real_labels = var_or_cuda(torch.ones(args.batch_size)).view(
                -1, 1, 1, 1, 1)
            fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                -1, 1, 1, 1, 1)

            if args.soft_label:
                # one-sided label smoothing: only the real labels are softened
                real_labels = var_or_cuda(
                    torch.Tensor(args.batch_size).uniform_(0.9, 1.1)).view(
                        -1, 1, 1, 1, 1)
                #fake_labels = var_or_cuda(torch.Tensor(args.batch_size).uniform_(0, 0.3)).view(-1, 1, 1, 1, 1)
                fake_labels = var_or_cuda(torch.zeros(args.batch_size)).view(
                    -1, 1, 1, 1, 1)

            # ============= Train the discriminator =============#
            d_real = D(X)
            d_real_loss = criterion(d_real, real_labels)

            fake = G(Z)
            d_fake = D(fake)
            d_fake_loss = criterion(d_fake, fake_labels)

            d_loss = d_real_loss + d_fake_loss

            d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            # only update D while its accuracy is below the threshold
            if d_total_acu <= args.d_thresh:
                D.zero_grad()
                d_loss.backward()
                D_solver.step()

            # =============== Train the generator ===============#
            Z = generateZ(args)

            fake = G(Z)
            d_fake = D(fake)
            g_loss = criterion(d_fake, real_labels)

            D.zero_grad()
            G.zero_grad()
            g_loss.backward()
            G_solver.step()

            # =============== logging each iteration ===============#
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])

            # every 300 iterations, plot a few generated samples in visdom
            if int(iteration) % 300 == 0:
                #pickle_save_path = args.output_dir + args.pickle_dir + log_param
                #save_new_pickle(pickle_save_path, iteration, G, G_solver, D, D_solver)
                samples = fake.cpu().data[:8].squeeze().numpy()
                for s in range(8):
                    plotVoxelVisdom(samples[s, ...], viz,
                                    "Iteration:{:.4}".format(iteration))
                # image_path = args.output_dir + args.image_dir + log_param
                # if not os.path.exists(image_path):
                #     os.makedirs(image_path)
                # SavePloat_Voxels(samples, image_path, iteration)

        # =============== each epoch: log, save images, save model ===============#
        print('Iter-{}; D_loss : {:.4}, G_loss : {:.4}, D_acu : {:.4}, D_lr : {:.4}'
              .format(iteration, d_loss.item(), g_loss.item(), d_total_acu.item(),
                      D_solver.state_dict()['param_groups'][0]["lr"]))

        epoch_end_time = time.time()

        if (epoch + 1) % args.image_save_step == 0:
            samples = fake.cpu().data[:8].squeeze().numpy()

            image_path = args.output_dir + args.image_dir + log_param
            if not os.path.exists(image_path):
                os.makedirs(image_path)

            SavePloat_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            pickle_save_path = args.output_dir + args.pickle_dir + log_param
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D, D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
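# NOTE: `plotVoxelVisdom` is a repo helper that is not shown in this section.
# A minimal sketch of how a voxel grid could be visualized through visdom's
# scatter API is given below as an assumption; the repo's helper may render the
# samples differently.
import numpy as np

def plotVoxelVisdom_sketch(voxels, visdom, title):
    # hypothetical helper: scatter-plot the coordinates of occupied voxels
    occupied = np.argwhere(voxels > 0.5)
    visdom.scatter(X=occupied, opts=dict(title=title, markersize=2))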
def train(args):
    # WGAN-GP related params
    lambda_gp = 10
    n_critic = 5

    hyparam_list = [
        ("model", args.model_name),
        ("cube", args.cube_len),
        ("bs", args.batch_size),
        ("g_lr", args.g_lr),
        ("d_lr", args.d_lr),
        ("z", args.z_dis),
        ("bias", args.bias),
    ]
    hyparam_dict = OrderedDict(((arg, value) for arg, value in hyparam_list))
    log_param = make_hyparam_string(hyparam_dict)
    print(log_param)

    # define the different paths
    pickle_path = "." + args.pickle_dir + log_param
    image_path = args.output_dir + args.image_dir + log_param
    pickle_save_path = args.output_dir + args.pickle_dir + log_param

    N = None  # None for the whole dataset
    VOL_SIZE = 64
    train_path = pathlib.Path("../Vert_dataset")
    dataset = VertDataset(train_path,
                          n=N,
                          transform=transforms.Compose(
                              [ResizeTo(VOL_SIZE),
                               transforms.ToTensor()]))
    print('Number of samples: ', len(dataset))
    dset_loaders = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0)
    print('Number of batches: ', len(dset_loaders))

    # Build the model
    D = _D(args)
    G = _G(args)

    # Create the solvers
    D_solver = optim.Adam(D.parameters(), lr=args.d_lr, betas=args.beta)
    G_solver = optim.Adam(G.parameters(), lr=args.g_lr, betas=args.beta)

    if torch.cuda.device_count() > 1:
        D = nn.DataParallel(D)
        G = nn.DataParallel(G)
        print("Using {} GPUs".format(torch.cuda.device_count()))
        D.cuda()
        G.cuda()
    elif torch.cuda.is_available():
        print("using cuda")
        D.cuda()
        G.cuda()

    # Load checkpoint if available
    read_pickle(pickle_path, G, G_solver, D, D_solver)

    G_losses = []
    D_losses = []

    for epoch in range(args.n_epochs):
        epoch_start_time = time.time()
        print("epoch %d started" % (epoch))

        for i, X in enumerate(dset_loaders):
            X = X.view(-1, args.cube_len * args.cube_len * args.cube_len)
            X = var_or_cuda(X)
            X = X.type(torch.cuda.FloatTensor)  # assumes a CUDA device is available

            Z = generateZ(num_samples=X.size(0), z_size=args.z_size)

            # Train the critic
            d_loss, Wasserstein_D, gp = train_critic(X, Z, D, G, D_solver,
                                                     G_solver)

            # Train the generator every n_critic steps
            if i % n_critic == 0:
                Z = generateZ(num_samples=X.size(0), z_size=args.z_size)
                g_loss = train_gen(Z, D, G, D_solver, G_solver)

            # Log each iteration
            iteration = str(G_solver.state_dict()['state'][
                G_solver.state_dict()['param_groups'][0]['params'][0]]['step'])
            print('Iter-{}; D_loss : {:.4}, G_loss : {:.4}, WSdistance : {:.4}, GP : {:.4}'
                  .format(iteration, d_loss.item(), g_loss.item(),
                          Wasserstein_D.item(), gp.item()))

        # End of epoch
        epoch_end_time = time.time()

        # Plot the losses each epoch
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())
        plot_losess(G_losses, D_losses, epoch)

        if (epoch + 1) % args.image_save_step == 0:
            print("Saving voxels")
            Z = generateZ(num_samples=8, z_size=args.z_size)
            gen_output = G(Z)
            samples = gen_output.cpu().data[:8].squeeze().numpy()
            samples = samples.reshape(-1, args.cube_len, args.cube_len,
                                      args.cube_len)
            Save_Voxels(samples, image_path, iteration)

        if (epoch + 1) % args.pickle_step == 0:
            print("Pickling the model")
            save_new_pickle(pickle_save_path, iteration, G, G_solver, D, D_solver)

        print("epoch time", (epoch_end_time - epoch_start_time) / 60)
        print("epoch %d ended" % (epoch))
        print("################################################")
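# NOTE: `train_critic` and `train_gen` are defined elsewhere in the repo and are
# not shown in this section. The sketch below illustrates a standard WGAN-GP
# critic update (critic score maximized on real data, gradient penalty on
# interpolates, lambda_gp = 10 as set above); it is an assumption about what the
# repo's helper does, not its actual code.
import torch
import torch.autograd as autograd

def train_critic_sketch(X, Z, D, G, D_solver, lambda_gp=10):
    # critic scores on real and (detached) generated samples
    d_real = D(X).mean()
    fake = G(Z).detach()
    d_fake = D(fake).mean()

    # gradient penalty on random interpolations between real and fake samples
    alpha_shape = [X.size(0)] + [1] * (X.dim() - 1)
    alpha = torch.rand(alpha_shape, device=X.device)
    interpolates = (alpha * X + (1 - alpha) * fake).requires_grad_(True)
    d_interp = D(interpolates)
    grads = autograd.grad(outputs=d_interp, inputs=interpolates,
                          grad_outputs=torch.ones_like(d_interp),
                          create_graph=True, retain_graph=True)[0]
    gp = ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

    # critic loss: maximize d_real - d_fake (minimize its negative) plus the penalty
    wasserstein_d = d_real - d_fake
    d_loss = -wasserstein_d + lambda_gp * gp

    D_solver.zero_grad()
    d_loss.backward()
    D_solver.step()
    return d_loss, wasserstein_d, gp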