Exemplo n.º 1
0
def checkpoint_eval(G_net,
                    device,
                    n_samples=5000,
                    batch_size=100,
                    conditional=False,
                    ratio='even',
                    FPD_path=None,
                    class_choices=None):
    """
    Evaluate a generator's Frechet Pointcloud Distance during training.

    Generates ``n_samples`` point clouds with ``G_net`` and compares them
    against the reference statistics stored at ``FPD_path``. The generator's
    previous train/eval mode is restored afterwards, so the caller can keep
    training without re-enabling train mode itself.

    Args:
        G_net: generator module to evaluate.
        device: torch device used for generation and FPD feature extraction.
        n_samples: number of point clouds to generate.
        batch_size: batch size used when generating samples.
        conditional, ratio, class_choices: accepted for interface
            compatibility; not used by this implementation.
        FPD_path: path of the pre-computed reference statistics file.

    Returns:
        The computed FPD value (it is also printed).
    """
    was_training = G_net.training
    G_net.eval()
    try:
        fake_pcs = generate_pcs(G_net,
                                n_pcs=n_samples,
                                batch_size=batch_size,
                                device=device)
        fpd = calculate_fpd(fake_pcs,
                            statistic_save_path=FPD_path,
                            batch_size=100,
                            dims=1808,
                            device=device)
    finally:
        # This helper is called mid-training: restore the original mode so
        # mode-dependent layers behave as the training loop expects.
        G_net.train(was_training)
    print(
        '----------------------------------------- Frechet Pointcloud Distance <<< {:.2f} >>>'
        .format(fpd))
    return fpd
Exemplo n.º 2
0
def test(args, mode='FPD', verbose=True):
    '''
    Load a trained generator checkpoint, generate samples and evaluate them.

    Args:
        args: namespace providing G_FEAT, DEGREE, support, device,
            model_pathname, n_samples and batch_size, plus save_sample_path
            (mode 'save'), FPD_path (mode 'FPD') or whatever ``CRNShapeNet``
            reads (mode 'MMD').
        mode: 'save' writes the generated clouds to text files, 'FPD'
            computes the Frechet Pointcloud Distance, 'MMD' computes MMD
            (with EMD) against the ground-truth dataset.
        verbose: print the computed metric when True.

    Returns:
        The FPD or MMD value for those modes; ``None`` for 'save' or an
        unrecognized mode.
    '''
    G_net = Generator(features=args.G_FEAT,
                      degrees=args.DEGREE,
                      support=args.support,
                      args=args).to(args.device)
    checkpoint = torch.load(args.model_pathname, map_location=args.device)
    G_net.load_state_dict(checkpoint['G_state_dict'])
    G_net.eval()
    fake_pcs = generate_pcs(G_net,
                            n_pcs=args.n_samples,
                            batch_size=args.batch_size,
                            device=args.device)
    if mode == 'save':
        save_pcs_to_txt(args.save_sample_path, fake_pcs)
        return None
    if mode == 'FPD':
        fpd = calculate_fpd(fake_pcs,
                            statistic_save_path=args.FPD_path,
                            batch_size=100,
                            dims=1808,
                            device=args.device)
        if verbose:
            print('-----FPD: {:3.2f}'.format(fpd))
        return fpd
    if mode == 'MMD':
        use_EMD = True
        batch_size = 50  # batch size used inside MMD_batch
        normalize = True
        gt_dataset = CRNShapeNet(args)

        dataLoader = torch.utils.data.DataLoader(gt_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 pin_memory=True,
                                                 num_workers=10)
        # Gather every batch first and concatenate once; concatenating inside
        # the loop re-copies the growing tensor on each iteration.
        gt_batches = [point for point, _partial, _index in dataLoader]
        gt_data = torch.cat(gt_batches, 0) if gt_batches else torch.Tensor([])
        ref_pcs = gt_data.detach().cpu().numpy()
        sample_pcs = fake_pcs.detach().cpu().numpy()

        mmd, _matched_dists, _dist_mat = MMD_batch(sample_pcs,
                                                   ref_pcs,
                                                   batch_size,
                                                   normalize=normalize,
                                                   use_EMD=use_EMD,
                                                   device=args.device)
        if verbose:
            print('-----MMD-EMD: {:5.3f}'.format(mmd * 100))
        return mmd
    return None
Exemplo n.º 3
0
def FPD(opt, save_gen=False):
    '''
    Compute the Frechet Pointcloud Distance of a trained (single-class or
    conditional) generator checkpoint.

    Args:
        opt: namespace providing conditional, batch_size, G_FEAT, DEGREE,
            support, version, device, model_pathname, num_samples,
            conditional_ratio, FPD_path, plus n_classes when conditional and
            gen_path when ``save_gen`` is True.
        save_gen: also write the generated clouds (and labels) to
            ``opt.gen_path``.

    Returns:
        The FPD value returned by ``calculate_fpd``.
    '''
    # Both branches must use opt.device; the unconditional branch previously
    # relied on an undefined/global ``device``.
    if not opt.conditional:
        G_net = Generator(batch_size=opt.batch_size,
                          features=opt.G_FEAT,
                          degrees=opt.DEGREE,
                          support=opt.support,
                          version=opt.version).to(opt.device)
    else:
        G_net = ConditionalGenerator_v0(batch_size=opt.batch_size,
                                        features=opt.G_FEAT,
                                        degrees=opt.DEGREE,
                                        support=opt.support,
                                        n_classes=opt.n_classes,
                                        version=opt.version).to(opt.device)
    checkpoint = torch.load(opt.model_pathname, map_location=opt.device)
    G_net.load_state_dict(checkpoint['G_state_dict'])
    G_net.eval()
    # When not conditional, the returned labels are dummies.
    fake_pcs, labels = generate_pcs(G_net,
                                    n_pcs=opt.num_samples,
                                    batch_size=opt.batch_size,
                                    conditional=opt.conditional,
                                    device=opt.device,
                                    ratio=opt.conditional_ratio)

    if save_gen:
        save_pcs_to_txt(opt.gen_path, fake_pcs, labels=labels)
    # TODO check all-chair only scenario
    fpd = calculate_fpd(fake_pcs,
                        statistic_save_path=opt.FPD_path,
                        batch_size=100,
                        dims=1808,
                        device=opt.device)
    return fpd
Exemplo n.º 4
0
    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """
        Train the unconditional generator / discriminator pair (WGAN-GP).

        For every batch of ``self.dataLoader`` the discriminator is updated
        ``self.args.D_iter`` times, then the generator once. Every 5 epochs
        the FPD of roughly 5000 generated samples is computed (when
        ``result_path`` is given) and a checkpoint is written (when
        ``save_ckpt`` is given).

        Args:
            save_ckpt: filename prefix for checkpoints; ``None`` disables saving.
            load_ckpt: checkpoint file to resume from; ``None`` starts fresh.
            result_path: enables periodic FPD evaluation when not ``None``.
        """
        # Visdom plotting setup; the plotting itself is currently disabled
        # (see the commented-out block inside the iteration loop).
        color_num = self.args.visdom_color
        chunk_size = int(self.args.point_num / color_num)
        colors = np.array([(227, 0, 27), (231, 64, 28), (237, 120, 15),
                           (246, 176, 44), (252, 234, 0), (224, 221, 128),
                           (142, 188, 40), (18, 126, 68), (63, 174, 0),
                           (113, 169, 156), (164, 194, 184), (51, 186, 216),
                           (0, 152, 206), (16, 68, 151), (57, 64, 139),
                           (96, 72, 132), (172, 113, 161), (202, 174, 199),
                           (145, 35, 132), (201, 47, 133), (229, 0, 123),
                           (225, 106, 112), (163, 38, 42), (128, 128, 128)])
        colors = colors[np.random.choice(len(colors), color_num,
                                         replace=False)]
        label = torch.stack([
            torch.ones(chunk_size).type(torch.LongTensor) * inx
            for inx in range(1,
                             int(color_num) + 1)
        ],
                            dim=0).view(-1)

        epoch_log = 0

        loss_log = {'G_loss': [], 'D_loss': []}
        loss_legend = list(loss_log.keys())

        metric = {'FPD': []}
        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=self.args.device)
            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])

            # Checkpoints are written at the end of an epoch, so the stored
            # epoch is already complete; resume from the next one instead of
            # retraining it and duplicating its log entries.
            epoch_log = checkpoint['epoch'] + 1

            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']
            loss_legend = list(loss_log.keys())

            metric['FPD'] = checkpoint['FPD']

            print("Checkpoint loaded.")

        # Used in checkpoint filenames. Computed once up front, from self.args
        # (not a module-level ``args``), so checkpoint saving also works when
        # FPD evaluation is disabled.
        class_name = (self.args.class_choice
                      if self.args.class_choice is not None else 'all')

        for epoch in range(epoch_log, self.args.epochs):
            epoch_g_loss = []
            epoch_d_loss = []
            epoch_time = time.time()
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()
                point, _ = data
                point = point.to(self.args.device)

                # -------------------- Discriminator -------------------- #
                tic = time.time()
                for d_iter in range(self.args.D_iter):
                    self.D.zero_grad()

                    z = torch.randn(self.args.batch_size, 1,
                                    96).to(self.args.device)
                    tree = [z]

                    with torch.no_grad():
                        fake_point = self.G(tree)

                    D_real = self.D(point)
                    D_realm = D_real.mean()

                    D_fake = self.D(fake_point)
                    D_fakem = D_fake.mean()

                    gp_loss = self.GP(self.D, point.data, fake_point.data)

                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                loss_log['D_loss'].append(d_loss.item())
                epoch_d_loss.append(d_loss.item())
                toc = time.time()
                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()

                z = torch.randn(self.args.batch_size, 1,
                                96).to(self.args.device)
                tree = [z]

                fake_point = self.G(tree)
                G_fake = self.D(fake_point)
                G_fakem = G_fake.mean()

                g_loss = -G_fakem
                g_loss.backward()
                self.optimizerG.step()

                loss_log['G_loss'].append(g_loss.item())
                epoch_g_loss.append(g_loss.item())
                tac = time.time()
                # --------------------- Visualization -------------------- #
                verbose = None
                if verbose is not None:
                    print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                          "[ D_Loss ] ", "{: 7.6f}".format(d_loss),
                          "[ G_Loss ] ", "{: 7.6f}".format(g_loss),
                          "[ Time ] ",
                          "{:4.2f}s".format(time.time() - start_time),
                          "{:4.2f}s".format(toc - tic),
                          "{:4.2f}s".format(tac - toc))

                # jz TODO visdom is disabled
                # if _iter % 10 == 0:
                #     generated_point = self.G.getPointcloud()
                #     plot_X = np.stack([np.arange(len(loss_log[legend])) for legend in loss_legend], 1)
                #     plot_Y = np.stack([np.array(loss_log[legend]) for legend in loss_legend], 1)

                #     self.vis.line(X=plot_X, Y=plot_Y, win=1,
                #                   opts={'title': 'TreeGAN Loss', 'legend': loss_legend, 'xlabel': 'Iteration', 'ylabel': 'Loss'})

                #     self.vis.scatter(X=generated_point[:,torch.LongTensor([2,0,1])], Y=label, win=2,
                #                      opts={'title': "Generated Pointcloud", 'markersize': 2, 'markercolor': colors, 'webgl': True})

                #     if len(metric['FPD']) > 0:
                #         self.vis.line(X=np.arange(len(metric['FPD'])), Y=np.array(metric['FPD']), win=3,
                #                       opts={'title': "Frechet Pointcloud Distance", 'legend': ["{} / FPD best : {:.6f}".format(np.argmin(metric['FPD']), np.min(metric['FPD']))]})

                #     print('Figures are saved.')
            # ---------------- Epoch average loss   --------------- #
            d_loss_mean = np.array(epoch_d_loss).mean()
            g_loss_mean = np.array(epoch_g_loss).mean()

            print("[Epoch] ", "{:3}".format(epoch), "[ D_Loss ] ",
                  "{: 7.6f}".format(d_loss_mean), "[ G_Loss ] ",
                  "{: 7.6f}".format(g_loss_mean), "[ Time ] ",
                  "{:.2f}s".format(time.time() - epoch_time))
            epoch_time = time.time()
            # ---------------- Frechet Pointcloud Distance --------------- #
            if epoch % 5 == 0 and result_path is not None:
                # Roughly 5000 samples, rounded down to whole batches.
                test_batch_num = int(5000 / self.args.batch_size)
                samples = []
                for _ in range(test_batch_num):
                    z = torch.randn(self.args.batch_size, 1,
                                    96).to(self.args.device)
                    with torch.no_grad():
                        samples.append(self.G([z]).cpu())
                # Concatenate once rather than re-copying inside the loop.
                fake_pointclouds = (torch.cat(samples, dim=0)
                                    if samples else torch.Tensor([]))

                fpd = calculate_fpd(fake_pointclouds,
                                    statistic_save_path=self.args.FPD_path,
                                    batch_size=100,
                                    dims=1808,
                                    device=self.args.device)
                metric['FPD'].append(fpd)
                print(
                    '-------------------------[{:4} Epoch] Frechet Pointcloud Distance <<< {:.4f} >>>'
                    .format(epoch, fpd))

                # TODO
                # torch.save(fake_pointclouds, result_path+str(epoch)+'_'+class_name+'.pt')
                del fake_pointclouds, samples

            # ---------------------- Save checkpoint --------------------- #
            if epoch % 5 == 0 and save_ckpt is not None:
                torch.save(
                    {
                        'epoch': epoch,
                        'D_state_dict': self.D.state_dict(),
                        'G_state_dict': self.G.state_dict(),
                        'D_loss': loss_log['D_loss'],
                        'G_loss': loss_log['G_loss'],
                        'FPD': metric['FPD']
                    }, save_ckpt + str(epoch) + '_' + class_name + '.pt')
Exemplo n.º 5
0
    def run(self, save_ckpt=None, load_ckpt=None, result_path=None):
        """
        Train the class-conditional generator / discriminator pair (WGAN-GP).

        Real batches are paired with their one-hot labels; fake batches use
        uniformly random class labels. Batches whose size differs from
        ``self.opt.batch_size`` are skipped. Every 5 epochs the FPD of roughly
        2000 generated samples is computed (when ``result_path`` is given), and
        a checkpoint is written after every epoch (when ``save_ckpt`` is given).

        Args:
            save_ckpt: filename prefix for checkpoints; ``None`` disables saving.
            load_ckpt: checkpoint file to resume from; ``None`` starts fresh.
            result_path: enables periodic FPD evaluation when not ``None``.
        """
        # All configuration comes from self.opt; the previous version read a
        # module-level ``opt`` for n_classes / class_choice.
        device = self.opt.device
        n_classes = self.opt.n_classes

        def one_hot(index):
            # (B, 1) integer class ids -> (B, 1, n_classes) float32 one-hot.
            onehot = torch.zeros(index.shape[0], n_classes,
                                 dtype=torch.float32, device=device)
            onehot.scatter_(1, index.long(), 1)
            return onehot.unsqueeze(1)

        def random_labels(batch):
            # (batch, 1) tensor of uniformly random class ids on ``device``.
            return torch.from_numpy(
                np.random.randint(0, n_classes, batch).reshape(-1, 1)).to(device)

        loss_log = {'G_loss': [], 'D_loss': []}
        loss_legend = list(loss_log.keys())

        metric = {'FPD': []}
        epoch_log = 0
        if load_ckpt is not None:
            checkpoint = torch.load(load_ckpt, map_location=device)
            self.D.load_state_dict(checkpoint['D_state_dict'])
            self.G.load_state_dict(checkpoint['G_state_dict'])

            # Checkpoints are written at the end of an epoch, so the stored
            # epoch is already complete; resume from the next one instead of
            # retraining it and duplicating its log entries.
            epoch_log = checkpoint['epoch'] + 1

            loss_log['G_loss'] = checkpoint['G_loss']
            loss_log['D_loss'] = checkpoint['D_loss']
            loss_legend = list(loss_log.keys())

            metric['FPD'] = checkpoint['FPD']

            print("Checkpoint loaded.")

        class_name = (self.opt.class_choice
                      if self.opt.class_choice is not None else 'all')

        for epoch in range(epoch_log, self.opt.epochs):
            epoch_g_loss = []
            epoch_d_loss = []
            epoch_time = time.time()
            for _iter, data in enumerate(self.dataLoader):
                start_time = time.time()
                point, labels = data
                # TODO to work on treeGCN batch size issue: the generator
                # currently needs full batches, so the ragged last one is skipped.
                if labels.shape[0] != self.opt.batch_size:
                    continue
                n_batch = labels.shape[0]
                point = point.to(device)
                labels = labels.to(device)
                labels_onehot = one_hot(labels)
                # -------------------- Discriminator -------------------- #
                tic = time.time()
                for d_iter in range(self.opt.D_iter):
                    self.D.zero_grad()
                    # tree-gan samples z ~ N(0, 1); r-gan and GCN-GAN use
                    # sigma 0.2:
                    # https://github.com/optas/latent_3d_points/blob/master/notebooks/train_raw_gan.ipynb
                    # https://github.com/diegovalsesia/GraphCNN-GAN/blob/master/gconv_up_aggr_code/main.py
                    z = torch.randn(n_batch, 1, 96).to(device)
                    gen_labels_onehot = one_hot(random_labels(n_batch))
                    tree = [z]

                    with torch.no_grad():
                        fake_point = self.G(tree, gen_labels_onehot)

                    D_real = self.D(point, labels_onehot)
                    D_realm = D_real.mean()

                    D_fake = self.D(fake_point, gen_labels_onehot)
                    D_fakem = D_fake.mean()

                    # Removing the gradient penalty made the loss explode.
                    gp_loss = self.GP(self.D,
                                      point.data,
                                      fake_point.data,
                                      conditional=True,
                                      yreal=labels_onehot)

                    d_loss = -D_realm + D_fakem
                    d_loss_gp = d_loss + gp_loss
                    d_loss_gp.backward()
                    self.optimizerD.step()

                loss_log['D_loss'].append(d_loss.item())
                epoch_d_loss.append(d_loss.item())
                toc = time.time()
                # ---------------------- Generator ---------------------- #
                self.G.zero_grad()

                z = torch.randn(n_batch, 1, 96).to(device)
                gen_labels_onehot = one_hot(random_labels(n_batch))
                tree = [z]

                fake_point = self.G(tree, gen_labels_onehot)
                G_fake = self.D(fake_point, gen_labels_onehot)
                G_fakem = G_fake.mean()

                g_loss = -G_fakem
                g_loss.backward()
                self.optimizerG.step()

                loss_log['G_loss'].append(g_loss.item())
                epoch_g_loss.append(g_loss.item())
                tac = time.time()
                # --------------------- Visualization -------------------- #
                verbose = None
                if verbose is not None:
                    print("[Epoch/Iter] ", "{:3} / {:3}".format(epoch, _iter),
                          "[ D_Loss ] ", "{: 7.6f}".format(d_loss),
                          "[ G_Loss ] ", "{: 7.6f}".format(g_loss),
                          "[ Time ] ",
                          "{:4.2f}s".format(time.time() - start_time),
                          "{:4.2f}s".format(toc - tic),
                          "{:4.2f}s".format(tac - toc))

            # ---------------- Epoch average loss   --------------- #
            d_loss_mean = np.array(epoch_d_loss).mean()
            g_loss_mean = np.array(epoch_g_loss).mean()

            print("[Epoch] ", "{:3}".format(epoch), "[ D_Loss ] ",
                  "{: 7.3f}".format(d_loss_mean), "[ G_Loss ] ",
                  "{: 7.3f}".format(g_loss_mean), "[ Time ] ",
                  "{:.2f}s".format(time.time() - epoch_time))
            epoch_time = time.time()
            # ---------------- Frechet Pointcloud Distance --------------- #
            if epoch % 5 == 0 and result_path is not None:
                batch_size = self.opt.batch_size
                # TODO change back to 5000 samples; currently ~2000, rounded
                # down to whole batches.
                test_batch_num = int(2000 / batch_size)
                samples = []
                for _ in range(test_batch_num):
                    z = torch.randn(batch_size, 1, 96).to(device)
                    gen_labels_onehot = one_hot(random_labels(batch_size))
                    with torch.no_grad():
                        samples.append(self.G([z], gen_labels_onehot).cpu())
                # Concatenate once rather than re-copying inside the loop.
                fake_pointclouds = (torch.cat(samples, dim=0)
                                    if samples else torch.Tensor([]))

                fpd = calculate_fpd(fake_pointclouds,
                                    statistic_save_path=self.opt.FPD_path,
                                    batch_size=100,
                                    dims=1808,
                                    device=device)
                metric['FPD'].append(fpd)
                print(
                    '-------------------------[{:4} Epoch] Frechet Pointcloud Distance <<< {:.2f} >>>'
                    .format(epoch, fpd))

                # TODO to enable save fake points
                # torch.save(fake_pointclouds, result_path+str(epoch)+'_'+class_name+'.pt')
                del fake_pointclouds, samples

            # ---------------------- Save checkpoint --------------------- #
            # A checkpoint is written after every epoch.
            if save_ckpt is not None:
                torch.save(
                    {
                        'epoch': epoch,
                        'D_state_dict': self.D.state_dict(),
                        'G_state_dict': self.G.state_dict(),
                        'D_loss': loss_log['D_loss'],
                        'G_loss': loss_log['G_loss'],
                        'FPD': metric['FPD']
                    }, save_ckpt + str(epoch) + '_' + class_name + '.pt')