class TreeGAN():
    """WGAN-GP-style trainer for the TreeGAN point-cloud generator.

    NOTE(review): this file defines ``class TreeGAN`` twice; the later definition
    rebinds the module-level name, so this one is shadowed once the module has
    finished importing.
    """

    def __init__(self, args):
        """Build the training data pipeline, networks, optimizers and gradient penalty.

        Args:
            args: configuration namespace. Reads ``dataset``, ``dataset_path``,
                ``point_num``, ``ratio_base`` (``'ShapeNet_v0'`` only),
                ``class_choice`` (benchmark dataset only), ``batch_size``,
                ``G_FEAT``, ``D_FEAT``, ``DEGREE``, ``support``, ``lr``,
                ``lambdaGP`` and ``device``.
        """
        self.args = args
        # ------------------------------------------------Dataset---------------------------------------------- #
        # jz: default uniform=True
        if args.dataset == 'ShapeNet_v0':
            class_choice = ['Airplane', 'Car', 'Chair', 'Table']
            # The same sampling ratio is used for each of the four classes.
            ratio = [args.ratio_base] * 4
            self.data = ShapeNet_v0(root=args.dataset_path, npoints=args.point_num, uniform=None,
                                    class_choice=class_choice, ratio=ratio)
        elif args.dataset == 'ShapeNet_v0_rGAN_Chair':
            self.data = ShapeNet_v0_rGAN_Chair()
        else:
            self.data = BenchmarkDataset(root=args.dataset_path, npoints=args.point_num,
                                         uniform=None, class_choice=args.class_choice)
        # TODO num workers to change back to 4
        self.dataLoader = torch.utils.data.DataLoader(self.data, batch_size=args.batch_size,
                                                      shuffle=True, pin_memory=True,
                                                      num_workers=10)
        print("Training Dataset : {} prepared.".format(len(self.data)))
        # -------------------------------------------------Module---------------------------------------------- #
        self.G = Generator(batch_size=args.batch_size, features=args.G_FEAT, degrees=args.DEGREE,
                           support=args.support, version=0).to(args.device)
        # jz: default features=0.5*args.D_FEAT
        self.D = Discriminator(batch_size=args.batch_size, features=args.D_FEAT).to(args.device)
        # jz: only D is wrapped in DataParallel (G stays unwrapped), so D's state_dict keys
        # carry the DataParallel 'module.' prefix while G's do not.
        self.D = nn.DataParallel(self.D)

        self.optimizerG = optim.Adam(self.G.parameters(), lr=args.lr, betas=(0, 0.99))
        self.optimizerD = optim.Adam(self.D.parameters(), lr=args.lr, betas=(0, 0.99))

        # jz TODO check if this can be sped up via multi-GPU
        self.GP = GradientPenalty(args.lambdaGP, gamma=1, device=args.device)
        print("Network prepared.")
        # ---------------------------------------------Visualization------------------------------------------- #
        # jz TODO visdom
        # self.vis = visdom.Visdom(port=args.visdom_port)
        # assert self.vis.check_connection()
        # print("Visdom connected.")

    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """Train G and D for epochs ``[start, args.epochs)``.

        Each data batch gets ``args.D_iter`` critic updates (``-mean(D(real)) +
        mean(D(fake))`` plus the gradient penalty) followed by one generator update
        (``-mean(D(G(z)))``). Every 5 epochs, FPD is computed when ``result_path`` is
        given and a checkpoint is written when ``save_ckpt`` is given.

        Args:
            save_ckpt: path prefix; checkpoints are written to
                ``<save_ckpt><epoch>_<class_name>.pt`` where ``class_name`` is
                ``args.class_choice`` or ``'all'``. ``None`` disables checkpointing.
            load_ckpt: checkpoint to resume from (weights, loss and FPD history).
                Training continues with the epoch after the one stored in it.
            result_path: when not ``None``, enables the periodic FPD evaluation
                (saving the generated samples themselves is currently disabled).
        """
        # Only consumed by the (currently disabled) visdom scatter plot below.
        color_num = self.args.visdom_color
        chunk_size = int(self.args.point_num / color_num)
        # jz TODO???
        colors = np.array([(227, 0, 27), (231, 64, 28), (237, 120, 15), (246, 176, 44),
                           (252, 234, 0), (224, 221, 128), (142, 188, 40), (18, 126, 68),
                           (63, 174, 0), (113, 169, 156), (164, 194, 184), (51, 186, 216),
                           (0, 152, 206), (16, 68, 151), (57, 64, 139), (96, 72, 132),
                           (172, 113, 161), (202, 174, 199), (145, 35, 132), (201, 47, 133),
                           (229, 0, 123), (225, 106, 112), (163, 38, 42), (128, 128, 128)])
        colors = colors[np.random.choice(len(colors), color_num, replace=False)]
        label = torch.stack([
            torch.ones(chunk_size).type(torch.LongTensor) * inx
            for inx in range(1, int(color_num) + 1)
        ], dim=0).view(-1)

        # Computed once up front from self.args. Previously it was assigned only inside the
        # FPD branch (from a global `args`), so checkpointing without result_path raised
        # UnboundLocalError.
        class_name = self.args.class_choice if self.args.class_choice is not None else 'all'

        epoch_log = 0
        loss_log = {'G_loss': [], 'D_loss': []}
        loss_legend = list(loss_log.keys())  # used by the disabled visdom plot
        metric = {'FPD': []}
        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)
            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])
            # The checkpoint stores the epoch that had just finished when it was saved;
            # resume with the next one instead of training that epoch a second time.
            epoch_log = checkpoint['epoch'] + 1
            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']
            loss_legend = list(loss_log.keys())
            metric['FPD'] = checkpoint['FPD']
            print("Checkpoint loaded.")

        # Set to True for per-iteration loss/timing output.
        verbose = False

        for epoch in range(epoch_log, self.args.epochs):
            epoch_g_loss = []
            epoch_d_loss = []
            epoch_time = time.time()
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()
                point, _ = data
                point = point.to(self.args.device)

                # -------------------- Discriminator -------------------- #
                tic = time.time()
                for d_iter in range(self.args.D_iter):
                    self.D.zero_grad()
                    z = torch.randn(self.args.batch_size, 1, 96).to(self.args.device)
                    tree = [z]
                    # G is frozen for the critic step.
                    with torch.no_grad():
                        fake_point = self.G(tree)

                    D_real = self.D(point)
                    D_realm = D_real.mean()
                    D_fake = self.D(fake_point)
                    D_fakem = D_fake.mean()

                    gp_loss = self.GP(self.D, point.data, fake_point.data)

                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                # Only the last critic step's loss (without GP) is logged.
                loss_log['D_loss'].append(d_loss.item())
                epoch_d_loss.append(d_loss.item())
                toc = time.time()

                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()
                z = torch.randn(self.args.batch_size, 1, 96).to(self.args.device)
                tree = [z]
                fake_point = self.G(tree)
                G_fake = self.D(fake_point)
                G_fakem = G_fake.mean()
                g_loss = -G_fakem
                g_loss.backward()
                self.optimizerG.step()

                loss_log['G_loss'].append(g_loss.item())
                epoch_g_loss.append(g_loss.item())
                tac = time.time()

                # --------------------- Visualization -------------------- #
                if verbose:
                    print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                          "[ D_Loss ] ", "{: 7.6f}".format(d_loss),
                          "[ G_Loss ] ", "{: 7.6f}".format(g_loss),
                          "[ Time ] ", "{:4.2f}s".format(time.time() - start_time),
                          "{:4.2f}s".format(toc - tic), "{:4.2f}s".format(tac - toc))

                # jz TODO visdom is disabled
                # if _iter % 10 == 0:
                #     generated_point = self.G.getPointcloud()
                #     plot_X = np.stack([np.arange(len(loss_log[legend])) for legend in loss_legend], 1)
                #     plot_Y = np.stack([np.array(loss_log[legend]) for legend in loss_legend], 1)
                #     self.vis.line(X=plot_X, Y=plot_Y, win=1,
                #                   opts={'title': 'TreeGAN Loss', 'legend': loss_legend, 'xlabel': 'Iteration', 'ylabel': 'Loss'})
                #     self.vis.scatter(X=generated_point[:,torch.LongTensor([2,0,1])], Y=label, win=2,
                #                      opts={'title': "Generated Pointcloud", 'markersize': 2, 'markercolor': colors, 'webgl': True})
                #     if len(metric['FPD']) > 0:
                #         self.vis.line(X=np.arange(len(metric['FPD'])), Y=np.array(metric['FPD']), win=3,
                #                       opts={'title': "Frechet Pointcloud Distance", 'legend': ["{} / FPD best : {:.6f}".format(np.argmin(metric['FPD']), np.min(metric['FPD']))]})
                #     print('Figures are saved.')

            # ---------------- Epoch average loss --------------- #
            d_loss_mean = np.array(epoch_d_loss).mean()
            g_loss_mean = np.array(epoch_g_loss).mean()
            print("[Epoch] ", "{:3}".format(epoch),
                  "[ D_Loss ] ", "{: 7.6f}".format(d_loss_mean),
                  "[ G_Loss ] ", "{: 7.6f}".format(g_loss_mean),
                  "[ Time ] ", "{:.2f}s".format(time.time() - epoch_time))

            # ---------------- Frechet Pointcloud Distance --------------- #
            if epoch % 5 == 0 and result_path is not None:
                # jz: adjust for different batch sizes; aim for ~5000 samples, at least one batch.
                test_batch_num = max(1, 5000 // self.args.batch_size)
                samples = []
                for _ in range(test_batch_num):
                    z = torch.randn(self.args.batch_size, 1, 96).to(self.args.device)
                    tree = [z]
                    with torch.no_grad():
                        samples.append(self.G(tree).cpu())
                # One concatenation instead of re-copying the growing tensor every batch.
                fake_pointclouds = torch.cat(samples, dim=0)
                del samples

                fpd = calculate_fpd(fake_pointclouds, statistic_save_path=self.args.FPD_path,
                                    batch_size=100, dims=1808, device=self.args.device)
                metric['FPD'].append(fpd)
                print('-------------------------[{:4} Epoch] Frechet Pointcloud Distance <<< {:.4f} >>>'
                      .format(epoch, fpd))

                # TODO
                # torch.save(fake_pointclouds, result_path+str(epoch)+'_'+class_name+'.pt')
                del fake_pointclouds

            # ---------------------- Save checkpoint --------------------- #
            if epoch % 5 == 0 and save_ckpt is not None:
                torch.save({
                    'epoch': epoch,
                    'D_state_dict': self.D.state_dict(),
                    'G_state_dict': self.G.state_dict(),
                    'D_loss': loss_log['D_loss'],
                    'G_loss': loss_log['G_loss'],
                    'FPD': metric['FPD'],
                }, save_ckpt + str(epoch) + '_' + class_name + '.pt')
class TreeGAN():
    """WGAN-GP-style trainer for TreeGAN that also dumps latent-interpolation snapshots."""

    def __init__(self, args):
        """Build the (optional) training data pipeline, networks, optimizers and GP term.

        Args:
            args: configuration namespace. Always reads ``train``, ``batch_size``,
                ``G_FEAT``, ``D_FEAT``, ``DEGREE``, ``support``, ``lr``, ``lambdaGP`` and
                ``device``; when ``train`` is truthy also ``dataset_path``, ``point_num``
                and ``class_choice``. No dataset/loader is built when ``train`` is falsy.
        """
        self.args = args
        # ------------------------------------------------Dataset---------------------------------------------- #
        print("Self.args.train=", self.args.train)
        if self.args.train:
            self.data = BenchmarkDataset(root=args.dataset_path, npoints=args.point_num,
                                         uniform=False, class_choice=args.class_choice)
            self.dataLoader = torch.utils.data.DataLoader(self.data, batch_size=args.batch_size,
                                                          shuffle=True, pin_memory=True,
                                                          num_workers=4)
            print("Training Dataset : {} prepared.".format(len(self.data)))
        # -------------------------------------------------Module---------------------------------------------- #
        self.G = Generator(batch_size=args.batch_size, features=args.G_FEAT, degrees=args.DEGREE,
                           support=args.support).to(args.device)
        self.D = Discriminator(batch_size=args.batch_size, features=args.D_FEAT).to(args.device)

        self.optimizerG = optim.Adam(self.G.parameters(), lr=args.lr, betas=(0, 0.99))
        self.optimizerD = optim.Adam(self.D.parameters(), lr=args.lr, betas=(0, 0.99))

        self.GP = GradientPenalty(args.lambdaGP, gamma=1, device=args.device)
        print("Network prepared.")

    def interpolation(self, load_ckpt=None, save_images=None, save_pts_files=None, epoch=0):
        """Generate point clouds along a straight line between two latent codes.

        For every seed in ``self.args.seed``, numpy is seeded and two 96-d Gaussian
        latents ``z_a``/``z_b`` are drawn. The generator is run on
        ``(1 - a) * z_a + a * z_b`` for ``a`` in 0, 0.2, ..., 1. Each cloud is written
        as a ``.pts`` text file (one whitespace-separated point per line, overwriting any
        existing file), and all clouds of one seed are drawn into a single figure (one
        row per ``a``, one column per view angle) saved as PNG.

        In training mode a separate batch-size-1 model is built (latents are shaped
        ``(1, 1, 96)``), outputs go into ``Matplot_Images`` / ``Points`` subdirectories
        and filenames carry the epoch. Otherwise ``self`` is used and files are written
        directly into ``save_images`` / ``save_pts_files``.

        Args:
            load_ckpt: checkpoint whose ``G_state_dict`` is loaded first, or ``None``.
            save_images: directory for the PNG figures.
            save_pts_files: directory for the ``.pts`` files.
            epoch: epoch number used in filenames (training mode only).
        """
        args = self.args
        if args.train:
            SAVE_IMAGES = os.path.join(save_images, "Matplot_Images")
            if not os.path.isdir(SAVE_IMAGES):
                print("Making a directory!")
                os.makedirs(SAVE_IMAGES, exist_ok=True)
            SAVE_PTS_FILES = os.path.join(save_pts_files, "Points")
            os.makedirs(SAVE_PTS_FILES, exist_ok=True)
            epoch = str(epoch)
            args_copy = copy.deepcopy(args)
            args_copy.batch_size = 1
            # Only the helper's generator is used, so skip rebuilding the training
            # dataset/DataLoader every time a checkpoint is visualised.
            args_copy.train = False
            Gen = TreeGAN(args_copy)
        else:
            SAVE_IMAGES = save_images
            SAVE_PTS_FILES = save_pts_files
            epoch = ''
            Gen = self

        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=args.device)
            Gen.G.load_state_dict(checkpoint['G_state_dict'])
            print("Checkpoint loaded in interpolation")

        Gen.G.zero_grad()
        alpha = [0, 0.2, 0.4, 0.6, 0.8, 1]
        angles = [90, 120, 210, 270]
        seeds = args.seed  # iterable of ints
        print("The seed is===", seeds)
        with torch.no_grad():
            for s in seeds:
                np.random.seed(s)
                z_a, z_b = np.random.normal(size=96), np.random.normal(size=96)
                new_f = plt.figure(figsize=(30, 30))
                flag = 1  # 1-based subplot index of the first panel in the current row
                for row_no, a in enumerate(alpha):
                    z = torch.tensor((1 - a) * z_a + a * z_b, dtype=torch.float32).to(args.device)
                    z = z.reshape(1, 1, -1)
                    # The cloud is read back through getPointcloud() after this forward pass
                    # (presumably cached by G during the call — as in the original code).
                    Gen.G([z])
                    generated_point = Gen.G.getPointcloud().cpu().detach().numpy()
                    new_f = visualize_3d(generated_point, fig=new_f, num=flag, angles=angles,
                                         row_no=row_no + 1, rows=len(alpha))
                    flag += len(angles)

                    # ---- One .pts file per interpolated latent ----
                    if args.train:
                        f_name = "Epoch_{}_Seed_{}_PC_{}.pts".format(epoch, s, row_no + 1)
                    else:
                        f_name = "Seed_{}_PC_{}.pts".format(s, row_no + 1)
                    f_path = os.path.join(SAVE_PTS_FILES, f_name)
                    # "w": appending would merge points from earlier runs into the same cloud.
                    with open(f_path, "w") as f:
                        f.writelines(" ".join(map(str, line)) + "\n"
                                     for line in generated_point.tolist())
                    print("Written to {} file".format(f_name))

                new_f.suptitle('Random Seed={}'.format(s), fontsize=14)
                if args.train:
                    img_name = 'Epoch_' + epoch + '_' + 'Seed_' + str(s) + '.png'
                else:
                    img_name = 'Seed_' + str(s) + '.png'
                new_f.savefig(os.path.join(SAVE_IMAGES, img_name))
                # Release the figure; otherwise one leaks per seed per call.
                plt.close(new_f)
        return

    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """Train G and D for epochs ``[start, args.epochs)``.

        Each batch gets ``args.D_iter`` critic updates (``-mean(D(real)) +
        mean(D(fake))`` plus the gradient penalty) followed by one generator update.
        After every ``args.save_at_epoch``-th epoch a checkpoint is written (if
        ``save_ckpt``) and interpolation snapshots are rendered into ``result_path``
        (if given).

        Args:
            save_ckpt: checkpoint path prefix; files are ``<save_ckpt><epoch + 1>.pt``.
            load_ckpt: checkpoint to resume from; training continues with the epoch
                after the one stored in it.
            result_path: output directory for interpolation images and ``.pts`` files.
        """
        epoch_log = 0
        loss_log = {'G_loss': [], 'D_loss': []}
        metric = {'FPD': []}
        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)
            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])
            # 'epoch' holds the 0-based epoch that had just finished when saved; resume
            # with the next one instead of repeating it.
            epoch_log = checkpoint['epoch'] + 1
            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']
            metric['FPD'] = checkpoint['FPD']
            print("Checkpoint loaded.")

        for epoch in range(epoch_log, self.args.epochs):
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()
                point = data.to(self.args.device)

                # -------------------- Discriminator -------------------- #
                for d_iter in range(self.args.D_iter):
                    self.D.zero_grad()
                    z = torch.randn(self.args.batch_size, 1, 96).to(self.args.device)
                    tree = [z]
                    # G is frozen for the critic step.
                    with torch.no_grad():
                        fake_point = self.G(tree)

                    D_real = self.D(point)
                    D_realm = D_real.mean()
                    D_fake = self.D(fake_point)
                    D_fakem = D_fake.mean()

                    gp_loss = self.GP(self.D, point.data, fake_point.data)

                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                # Only the last critic step's loss (without GP) is logged.
                loss_log['D_loss'].append(d_loss.item())

                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()
                z = torch.randn(self.args.batch_size, 1, 96).to(self.args.device)
                tree = [z]
                fake_point = self.G(tree)
                G_fake = self.D(fake_point)
                G_fakem = G_fake.mean()
                g_loss = -G_fakem
                g_loss.backward()
                self.optimizerG.step()

                loss_log['G_loss'].append(g_loss.item())

                # --------------------- Visualization -------------------- #
                print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                      "[ D_Loss ] ", "{: 7.6f}".format(d_loss),
                      "[ G_Loss ] ", "{: 7.6f}".format(g_loss),
                      "[ Time ] ", "{:4.2f}s".format(time.time() - start_time))

            # NOTE: FPD evaluation is currently disabled in this variant; metric['FPD'] is only
            # carried through checkpoints.

            # ---------------------- Save checkpoint --------------------- #
            if (epoch + 1) % self.args.save_at_epoch == 0 and save_ckpt is not None:
                ckpt_path = save_ckpt + str(epoch + 1) + '.pt'
                torch.save({
                    'epoch': epoch,
                    'D_state_dict': self.D.state_dict(),
                    'G_state_dict': self.G.state_dict(),
                    'D_loss': loss_log['D_loss'],
                    'G_loss': loss_log['G_loss'],
                    'FPD': metric['FPD'],
                }, ckpt_path)
                print('Checkpoint at {} epoch is saved.'.format(epoch + 1))

                # --------------Saving intermediate images and .pts files----------------------#
                # interpolation() joins paths onto result_path, so it needs a real directory.
                if result_path is not None:
                    self.interpolation(load_ckpt=ckpt_path, save_images=result_path,
                                       save_pts_files=result_path, epoch=epoch + 1)