# Example #1
# 0
class TreeGAN():
    """TreeGAN trainer (WGAN-GP) for point-cloud generation.

    ``__init__`` builds the training dataset/loader, the tree-structured
    ``Generator``, the ``Discriminator`` (wrapped in ``nn.DataParallel``), one
    Adam optimizer per network and the gradient-penalty module; ``run`` then
    performs the adversarial training loop.
    """

    def __init__(self, args):
        """Build the data pipeline, networks and optimizers from ``args``.

        Args:
            args: parsed configuration namespace. Fields read here include
                ``dataset``, ``dataset_path``, ``point_num``, ``ratio_base``,
                ``class_choice``, ``batch_size``, ``G_FEAT``, ``DEGREE``,
                ``support``, ``D_FEAT``, ``lr``, ``lambdaGP`` and ``device``.
        """
        self.args = args
        # ------------------------------------------------Dataset---------------------------------------------- #
        # jz: default uniform=True
        if args.dataset == 'ShapeNet_v0':
            class_choice = ['Airplane', 'Car', 'Chair', 'Table']
            ratio = [args.ratio_base] * 4
            self.data = ShapeNet_v0(root=args.dataset_path,
                                    npoints=args.point_num,
                                    uniform=None,
                                    class_choice=class_choice,
                                    ratio=ratio)
        elif args.dataset == 'ShapeNet_v0_rGAN_Chair':
            self.data = ShapeNet_v0_rGAN_Chair()
        else:
            self.data = BenchmarkDataset(root=args.dataset_path,
                                         npoints=args.point_num,
                                         uniform=None,
                                         class_choice=args.class_choice)
        # TODO num workers to change back to 4
        # NOTE(review): the Generator is built for a fixed batch_size and the
        # latent batches below always have batch_size rows, but the loader keeps
        # a final partial batch (drop_last is not set). Confirm GradientPenalty
        # tolerates real/fake batches of different sizes.
        self.dataLoader = torch.utils.data.DataLoader(
            self.data,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=10)
        print("Training Dataset : {} prepared.".format(len(self.data)))
        # ----------------------------------------------------------------------------------------------------- #

        # -------------------------------------------------Module---------------------------------------------- #
        self.G = Generator(batch_size=args.batch_size,
                           features=args.G_FEAT,
                           degrees=args.DEGREE,
                           support=args.support,
                           version=0).to(args.device)
        # jz: default features=0.5*args.D_FEAT
        self.D = Discriminator(batch_size=args.batch_size,
                               features=args.D_FEAT).to(args.device)
        # jz parallel: only the discriminator is wrapped; wrapping G is disabled.
        # self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)

        # beta1 = 0, beta2 = 0.99 for both networks.
        self.optimizerG = optim.Adam(self.G.parameters(),
                                     lr=args.lr,
                                     betas=(0, 0.99))
        self.optimizerD = optim.Adam(self.D.parameters(),
                                     lr=args.lr,
                                     betas=(0, 0.99))
        # jz TODO check if this can be sped up via multi-GPU
        self.GP = GradientPenalty(args.lambdaGP, gamma=1, device=args.device)
        print("Network prepared.")
        # ----------------------------------------------------------------------------------------------------- #

        # ---------------------------------------------Visualization------------------------------------------- #
        # jz TODO visdom
        # self.vis = visdom.Visdom(port=args.visdom_port)
        # assert self.vis.check_connection()
        # print("Visdom connected.")
        # ----------------------------------------------------------------------------------------------------- #

    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """Train the GAN from the current (or resumed) epoch to ``args.epochs``.

        Args:
            save_ckpt: path prefix for checkpoints. Every 5th epoch a file
                ``<save_ckpt><epoch>_<class_name>.pt`` is written, holding
                both state dicts, the loss histories and the FPD history.
                ``None`` disables checkpointing.
            load_ckpt: checkpoint file to resume from. The stored epoch was
                completed before it was saved, so training continues with the
                following epoch.
            result_path: when not ``None``, the Frechet Pointcloud Distance is
                computed every 5th epoch. Generated samples are not written to
                it yet (see the TODO below).
        """
        # Visdom plotting setup. These values are only consumed by the disabled
        # visdom block below and are kept so that block can be re-enabled.
        color_num = self.args.visdom_color
        chunk_size = int(self.args.point_num / color_num)
        # jz TODO???
        colors = np.array([(227, 0, 27), (231, 64, 28), (237, 120, 15),
                           (246, 176, 44), (252, 234, 0), (224, 221, 128),
                           (142, 188, 40), (18, 126, 68), (63, 174, 0),
                           (113, 169, 156), (164, 194, 184), (51, 186, 216),
                           (0, 152, 206), (16, 68, 151), (57, 64, 139),
                           (96, 72, 132), (172, 113, 161), (202, 174, 199),
                           (145, 35, 132), (201, 47, 133), (229, 0, 123),
                           (225, 106, 112), (163, 38, 42), (128, 128, 128)])
        colors = colors[np.random.choice(len(colors), color_num,
                                         replace=False)]
        label = torch.stack([
            torch.ones(chunk_size).type(torch.LongTensor) * inx
            for inx in range(1, int(color_num) + 1)
        ], dim=0).view(-1)

        epoch_log = 0
        loss_log = {'G_loss': [], 'D_loss': []}
        loss_legend = list(loss_log.keys())
        metric = {'FPD': []}

        # Suffix for checkpoint file names. It is computed up front so it is
        # defined even when FPD evaluation (result_path) is disabled.
        class_choice = getattr(self.args, 'class_choice', None)
        class_name = class_choice if class_choice is not None else 'all'

        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)
            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])

            # The checkpoint is written after its epoch has finished, so resume
            # with the next epoch instead of repeating it.
            epoch_log = checkpoint['epoch'] + 1

            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']
            loss_legend = list(loss_log.keys())

            metric['FPD'] = checkpoint['FPD']

            print("Checkpoint loaded.")

        # Per-iteration logging is opt-in via an optional ``args.verbose``.
        verbose = getattr(self.args, 'verbose', None)

        for epoch in range(epoch_log, self.args.epochs):
            epoch_g_loss = []
            epoch_d_loss = []
            epoch_time = time.time()
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()
                point, _ = data
                point = point.to(self.args.device)

                # -------------------- Discriminator -------------------- #
                tic = time.time()
                for d_iter in range(self.args.D_iter):
                    self.D.zero_grad()

                    z = torch.randn(self.args.batch_size, 1,
                                    96).to(self.args.device)
                    tree = [z]

                    # No gradients are needed through G for the critic update.
                    with torch.no_grad():
                        fake_point = self.G(tree)

                    D_realm = self.D(point).mean()
                    D_fakem = self.D(fake_point).mean()

                    gp_loss = self.GP(self.D, point.data, fake_point.data)

                    # Critic loss: mean D(fake) - mean D(real), plus the penalty.
                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                # Only the last critic step's loss is logged (as before).
                d_loss_value = d_loss.item()
                loss_log['D_loss'].append(d_loss_value)
                epoch_d_loss.append(d_loss_value)
                toc = time.time()

                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()

                z = torch.randn(self.args.batch_size, 1,
                                96).to(self.args.device)
                tree = [z]

                fake_point = self.G(tree)
                # The generator is trained to raise the critic's score on fakes.
                g_loss = -self.D(fake_point).mean()
                g_loss.backward()
                self.optimizerG.step()

                g_loss_value = g_loss.item()
                loss_log['G_loss'].append(g_loss_value)
                epoch_g_loss.append(g_loss_value)
                tac = time.time()

                # --------------------- Visualization -------------------- #
                if verbose:
                    print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                          "[ D_Loss ] ", "{: 7.6f}".format(d_loss_value),
                          "[ G_Loss ] ", "{: 7.6f}".format(g_loss_value),
                          "[ Time ] ",
                          "{:4.2f}s".format(time.time() - start_time),
                          "{:4.2f}s".format(toc - tic),
                          "{:4.2f}s".format(tac - toc))

                # jz TODO visdom is disabled
                # if _iter % 10 == 0:
                #     generated_point = self.G.getPointcloud()
                #     plot_X = np.stack([np.arange(len(loss_log[legend])) for legend in loss_legend], 1)
                #     plot_Y = np.stack([np.array(loss_log[legend]) for legend in loss_legend], 1)

                #     self.vis.line(X=plot_X, Y=plot_Y, win=1,
                #                   opts={'title': 'TreeGAN Loss', 'legend': loss_legend, 'xlabel': 'Iteration', 'ylabel': 'Loss'})

                #     self.vis.scatter(X=generated_point[:,torch.LongTensor([2,0,1])], Y=label, win=2,
                #                      opts={'title': "Generated Pointcloud", 'markersize': 2, 'markercolor': colors, 'webgl': True})

                #     if len(metric['FPD']) > 0:
                #         self.vis.line(X=np.arange(len(metric['FPD'])), Y=np.array(metric['FPD']), win=3,
                #                       opts={'title': "Frechet Pointcloud Distance", 'legend': ["{} / FPD best : {:.6f}".format(np.argmin(metric['FPD']), np.min(metric['FPD']))]})

                #     print('Figures are saved.')

            # ---------------- Epoch average loss --------------- #
            # NaN (without numpy's empty-slice warning) if no batch was seen.
            d_loss_mean = (float(np.mean(epoch_d_loss))
                           if epoch_d_loss else float('nan'))
            g_loss_mean = (float(np.mean(epoch_g_loss))
                           if epoch_g_loss else float('nan'))

            print("[Epoch] ", "{:3}".format(epoch), "[ D_Loss ] ",
                  "{: 7.6f}".format(d_loss_mean), "[ G_Loss ] ",
                  "{: 7.6f}".format(g_loss_mean), "[ Time ] ",
                  "{:.2f}s".format(time.time() - epoch_time))

            # ---------------- Frechet Pointcloud Distance --------------- #
            if epoch % 5 == 0 and result_path is not None:
                # About 5000 samples. jz: the batch count adapts to batch size.
                # Batches are collected and concatenated once to avoid
                # re-copying the growing tensor on every iteration.
                test_batch_num = max(1, 5000 // self.args.batch_size)
                samples = []
                for _ in range(test_batch_num):
                    z = torch.randn(self.args.batch_size, 1,
                                    96).to(self.args.device)
                    with torch.no_grad():
                        samples.append(self.G([z]).cpu())
                fake_pointclouds = torch.cat(samples, dim=0)
                del samples

                fpd = calculate_fpd(fake_pointclouds,
                                    statistic_save_path=self.args.FPD_path,
                                    batch_size=100,
                                    dims=1808,
                                    device=self.args.device)
                metric['FPD'].append(fpd)
                print(
                    '-------------------------[{:4} Epoch] Frechet Pointcloud Distance <<< {:.4f} >>>'
                    .format(epoch, fpd))

                # TODO
                # torch.save(fake_pointclouds, result_path+str(epoch)+'_'+class_name+'.pt')
                del fake_pointclouds

            # ---------------------- Save checkpoint --------------------- #
            if epoch % 5 == 0 and save_ckpt is not None:
                torch.save(
                    {
                        'epoch': epoch,
                        'D_state_dict': self.D.state_dict(),
                        'G_state_dict': self.G.state_dict(),
                        'D_loss': loss_log['D_loss'],
                        'G_loss': loss_log['G_loss'],
                        'FPD': metric['FPD']
                    }, save_ckpt + str(epoch) + '_' + class_name + '.pt')
# Example #2
# 0
class TreeGAN():
    """TreeGAN trainer/sampler (WGAN-GP) for point-cloud generation.

    When ``args.train`` is set, a training dataset and loader are built and
    :meth:`run` trains the model. ``run`` periodically saves checkpoints and
    can render latent interpolations. :meth:`interpolation` can also be
    called on its own to render interpolations from a checkpoint.
    """

    def __init__(self, args):
        """Build the (optional) data pipeline, networks and optimizers.

        Args:
            args: parsed configuration namespace. Fields read here include
                ``train``, ``dataset_path``, ``point_num``, ``class_choice``,
                ``batch_size``, ``G_FEAT``, ``DEGREE``, ``support``,
                ``D_FEAT``, ``lr``, ``lambdaGP`` and ``device``.
        """
        self.args = args
        # ------------------------------------------------Dataset---------------------------------------------- #
        print("Self.args.train=", self.args.train)

        # The dataset/loader only exists in training mode.
        if self.args.train:
            self.data = BenchmarkDataset(root=args.dataset_path,
                                         npoints=args.point_num,
                                         uniform=False,
                                         class_choice=args.class_choice)

            self.dataLoader = torch.utils.data.DataLoader(
                self.data,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=4)
            print("Training Dataset : {} prepared.".format(len(self.data)))
        # ----------------------------------------------------------------------------------------------------- #

        # -------------------------------------------------Module---------------------------------------------- #
        self.G = Generator(batch_size=args.batch_size,
                           features=args.G_FEAT,
                           degrees=args.DEGREE,
                           support=args.support).to(args.device)
        self.D = Discriminator(batch_size=args.batch_size,
                               features=args.D_FEAT).to(args.device)

        # beta1 = 0, beta2 = 0.99 for both networks.
        self.optimizerG = optim.Adam(self.G.parameters(),
                                     lr=args.lr,
                                     betas=(0, 0.99))
        self.optimizerD = optim.Adam(self.D.parameters(),
                                     lr=args.lr,
                                     betas=(0, 0.99))

        self.GP = GradientPenalty(args.lambdaGP, gamma=1, device=args.device)
        print("Network prepared.")

    def interpolation(self,
                      load_ckpt=None,
                      save_images=None,
                      save_pts_files=None,
                      epoch=0):
        """Render and export linear interpolations between two latent codes.

        For every seed in ``self.args.seed``, two 96-d standard-normal latent
        vectors are drawn. The generator is then evaluated at six blend
        factors between them. Each generated cloud is written to a ``.pts``
        text file, one space-separated point per line. All clouds for a seed
        are drawn into a single PNG figure.

        In training mode (``self.args.train``), a separate batch-size-1 copy of
        the model is built. Outputs go to the ``Matplot_Images`` / ``Points``
        sub-directories of ``save_images`` / ``save_pts_files``, prefixed with
        ``Epoch_<epoch>_``. Otherwise ``self`` is used and files go directly
        into the given directories.

        Args:
            load_ckpt: optional checkpoint whose ``G_state_dict`` is loaded
                into the generator that is used.
            save_images: directory for the PNG renders.
            save_pts_files: directory for the ``.pts`` files.
            epoch: epoch number used in the file-name prefix (training mode).
        """
        if self.args.train:
            images_dir = os.path.join(save_images, "Matplot_Images")
            if not os.path.isdir(images_dir):
                print("Making a directory!")
            # exist_ok avoids a race between the check and the creation.
            os.makedirs(images_dir, exist_ok=True)
            pts_dir = os.path.join(save_pts_files, "Points")
            os.makedirs(pts_dir, exist_ok=True)
            name_prefix = "Epoch_{}_".format(epoch)

            # The interpolation feeds one latent code at a time, so a
            # batch-size-1 model is built. Only its generator is used, so
            # skip rebuilding the training dataset/loader.
            args_copy = copy.deepcopy(self.args)
            args_copy.batch_size = 1
            args_copy.train = False
            gen = type(self)(args_copy)
        else:
            images_dir = save_images
            pts_dir = save_pts_files
            name_prefix = ''
            gen = self

        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)
            gen.G.load_state_dict(checkpoint['G_state_dict'])
            print("Checkpoint loaded in interpolation")

        alpha = [0, 0.2, 0.4, 0.6, 0.8, 1]
        angles = [90, 120, 210, 270]
        seeds = self.args.seed  # iterable of integer seeds
        print("The seed is===", seeds)

        with torch.no_grad():
            for s in seeds:
                np.random.seed(s)
                z_a, z_b = np.random.normal(size=96), np.random.normal(size=96)

                new_f = plt.figure(figsize=(30, 30))
                # 1-based subplot index; each row of the figure holds one view
                # per angle.
                flag = 1
                for row_no, a in enumerate(alpha):
                    z = torch.tensor((1 - a) * z_a + a * z_b,
                                     dtype=torch.float32).to(self.args.device)
                    z = z.reshape(1, 1, -1)

                    # The forward pass is kept for its effect on the model:
                    # getPointcloud() presumably reads state cached by it
                    # (TODO confirm in Generator).
                    gen.G([z])
                    generated_point = gen.G.getPointcloud().cpu().detach(
                    ).numpy()
                    new_f = visualize_3d(generated_point,
                                         fig=new_f,
                                         num=flag,
                                         angles=angles,
                                         row_no=row_no + 1,
                                         rows=len(alpha))
                    flag += len(angles)

                    # One point cloud per file. Write mode (not append), so a
                    # rerun does not concatenate clouds.
                    pts_name = "{}Seed_{}_PC_{}.pts".format(
                        name_prefix, s, row_no + 1)
                    with open(os.path.join(pts_dir, pts_name), "w") as f:
                        f.writelines(" ".join(map(str, row)) + "\n"
                                     for row in generated_point.tolist())
                    print("Written to {} file".format(pts_name))

                new_f.suptitle('Random Seed={}'.format(s), fontsize=14)
                new_f.savefig(
                    os.path.join(images_dir,
                                 "{}Seed_{}.png".format(name_prefix, s)))
                # Release the figure; otherwise one stays open per seed per call.
                plt.close(new_f)
        return

    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """Train the GAN from the current (or resumed) epoch to ``args.epochs``.

        Args:
            save_ckpt: path prefix for checkpoints. After every
                ``args.save_at_epoch``-th epoch, ``<save_ckpt><epoch+1>.pt``
                is written. ``None`` disables checkpointing.
            load_ckpt: checkpoint file to resume from. The stored epoch was
                completed before it was saved, so training continues with the
                following epoch.
            result_path: output directory for the interpolation images and
                ``.pts`` files rendered after each checkpoint. ``None`` skips
                the rendering.
        """
        epoch_log = 0
        loss_log = {'G_loss': [], 'D_loss': []}
        # FPD evaluation is currently disabled in this trainer, so this history
        # is only carried over from a loaded checkpoint.
        metric = {'FPD': []}

        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)

            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])

            # The checkpoint is written after its epoch has finished, so resume
            # with the next epoch instead of repeating it.
            epoch_log = checkpoint['epoch'] + 1

            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']

            metric['FPD'] = checkpoint['FPD']

            print("Checkpoint loaded.")

        for epoch in range(epoch_log, self.args.epochs):
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()

                # The loader yields point tensors directly (no label).
                point = data.to(self.args.device)

                # -------------------- Discriminator -------------------- #
                for d_iter in range(self.args.D_iter):
                    self.D.zero_grad()

                    z = torch.randn(self.args.batch_size, 1,
                                    96).to(self.args.device)
                    tree = [z]

                    # No gradients are needed through G for the critic update.
                    with torch.no_grad():
                        fake_point = self.G(tree)

                    D_realm = self.D(point).mean()
                    D_fakem = self.D(fake_point).mean()

                    gp_loss = self.GP(self.D, point.data, fake_point.data)

                    # Critic loss: mean D(fake) - mean D(real), plus the penalty.
                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                # Only the last critic step's loss is logged.
                d_loss_value = d_loss.item()
                loss_log['D_loss'].append(d_loss_value)

                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()

                z = torch.randn(self.args.batch_size, 1,
                                96).to(self.args.device)
                tree = [z]

                fake_point = self.G(tree)
                # The generator is trained to raise the critic's score on fakes.
                g_loss = -self.D(fake_point).mean()
                g_loss.backward()
                self.optimizerG.step()

                g_loss_value = g_loss.item()
                loss_log['G_loss'].append(g_loss_value)

                # --------------------- Logging -------------------- #
                print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                      "[ D_Loss ] ", "{: 7.6f}".format(d_loss_value),
                      "[ G_Loss ] ", "{: 7.6f}".format(g_loss_value),
                      "[ Time ] ",
                      "{:4.2f}s".format(time.time() - start_time))

            # ---------------------- Save checkpoint --------------------- #
            if ((epoch + 1) % self.args.save_at_epoch == 0
                    and save_ckpt is not None):
                ckpt_path = save_ckpt + str(epoch + 1) + '.pt'
                torch.save(
                    {
                        'epoch': epoch,
                        'D_state_dict': self.D.state_dict(),
                        'G_state_dict': self.G.state_dict(),
                        'D_loss': loss_log['D_loss'],
                        'G_loss': loss_log['G_loss'],
                        'FPD': metric['FPD']
                    }, ckpt_path)

                print('Checkpoint at {} epoch is saved.'.format(epoch + 1))

                # --------------Saving intermediate images and .pts files----------------------#
                # Rendering needs an output directory; skip it without one.
                if result_path is not None:
                    self.interpolation(load_ckpt=ckpt_path,
                                       save_images=result_path,
                                       save_pts_files=result_path,
                                       epoch=epoch + 1)