예제 #1
0
    def __init__(self, cfg, optimizer=None):
        """Create the sub-networks, the optional optimizer and the loss modules.

        Args:
            cfg: Configuration object. ``cfg.CONST.DEVICE`` is passed as
                ``device_ids`` to ``DataParallel`` when CUDA is available.
            optimizer: Optional callable that receives ``self.parameters()``
                and returns an optimizer. When ``None``, ``self.optimizer``
                is set to ``None``.
        """
        super().__init__()
        self.cfg = cfg

        self.pointnet2 = PointNet2(cfg)
        self.img_enc = ImgEncoder(cfg)
        self.displacement_net = DisplacementNet(cfg)

        # The optimizer is built here, before the projector and loss modules
        # below are attached to the instance.
        if optimizer is None:
            self.optimizer = None
        else:
            self.optimizer = optimizer(self.parameters())

        # 2D supervision part, projection loss and EMD loss
        self.projector = Projector(self.cfg)
        self.proj_loss = ProjectLoss(self.cfg)
        self.emd = EMD()

        if not torch.cuda.is_available():
            return

        # displacement_net is not wrapped in DataParallel; only the trailing
        # self.cuda() call applies to it.
        device_ids = cfg.CONST.DEVICE
        for attr in ('img_enc', 'pointnet2', 'projector', 'proj_loss', 'emd'):
            parallel = torch.nn.DataParallel(getattr(self, attr),
                                             device_ids=device_ids)
            setattr(self, attr, parallel.cuda())
        self.cuda()
예제 #2
0
    def __init__(self, cfg, optimizer, scheduler):
        """Create the generator, the refiner, and the optional optimizer/scheduler.

        Args:
            cfg: Configuration object. Reads ``cfg.GRAPHX.NUM_INIT_POINTS``,
                ``cfg.GRAPHX.USE_GRAPHX``, ``cfg.CONST.NUM_POINTS``,
                ``cfg.CONST.DEVICE``, ``cfg.REFINE.USE_PRETRAIN_GENERATOR`` and
                ``cfg.REFINE.GENERATOR_WEIGHTS``.
            optimizer: Callable that receives ``self.parameters()`` and returns
                an optimizer, or ``None`` to skip optimizer creation.
            scheduler: Callable that receives the created optimizer and returns
                a scheduler, or ``None``. A scheduler is only built when both
                ``optimizer`` and ``scheduler`` are given.
        """
        super().__init__()
        self.cfg = cfg

        # Init the generator and refiner
        self.generator = GRAPHX_Generator(cfg=cfg,
                                          in_channels=3,
                                          in_instances=cfg.GRAPHX.NUM_INIT_POINTS,
                                          use_graphx=cfg.GRAPHX.USE_GRAPHX)

        self.refiner = EdgeConv_Refiner(cfg=cfg,
                                        num_points=cfg.CONST.NUM_POINTS,
                                        use_SElayer=True)

        self.optimizer = None if optimizer is None else optimizer(self.parameters())
        # Both operands need an explicit ``is None``: the previous
        # ``scheduler or optimizer is None`` parsed as
        # ``scheduler or (optimizer is None)``, which dropped every supplied
        # scheduler and tried to call ``None`` when only the optimizer was given.
        if scheduler is None or optimizer is None:
            self.scheduler = None
        else:
            self.scheduler = scheduler(self.optimizer)

        self.emd = EMD()

        if torch.cuda.is_available():
            self.generator = torch.nn.DataParallel(self.generator, device_ids=cfg.CONST.DEVICE).cuda()
            self.refiner = torch.nn.DataParallel(self.refiner, device_ids=cfg.CONST.DEVICE).cuda()
            self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
            self.cuda()

        # Load pretrained generator
        if cfg.REFINE.USE_PRETRAIN_GENERATOR:
            print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.REFINE.GENERATOR_WEIGHTS))
            checkpoint = torch.load(cfg.REFINE.GENERATOR_WEIGHTS)
            # NOTE(review): on CUDA the generator is already DataParallel-wrapped
            # here, so the checkpoint keys must match that wrapping - confirm.
            self.generator.load_state_dict(checkpoint['net'])
            print('Best Epoch: %d' % (checkpoint['epoch_idx']))
예제 #3
0
    def __init__(self, cfg, in_channels, in_instances, activation=nn.ReLU(), optimizer=None, scheduler=None, use_graphx=True, **kwargs):
        """Assemble the image encoder, point-cloud encoder and deformation decoder.

        Args:
            cfg: Configuration object; ``cfg.CONST.DEVICE`` supplies the
                ``device_ids`` for ``DataParallel`` when CUDA is available.
            in_channels: Passed to ``CNN18Encoder`` as its first argument.
            in_instances: Forwarded to the decoder as ``in_instances``.
            activation: Activation module shared by encoder and decoder.
            optimizer: Accepted but not used by this constructor.
            scheduler: Accepted but not used by this constructor.
            use_graphx: Selects ``PointCloudGraphXDecoder`` when true,
                ``PointCloudDecoder`` otherwise.
            **kwargs: Stored unchanged on ``self.kwargs``.
        """
        super().__init__()
        self.cfg = cfg

        # Graphx
        self.img_enc = CNN18Encoder(in_channels, activation)

        # One width per encoder child, read from the second-to-last element
        # of each child.
        out_features = []
        for enc_block in self.img_enc.children():
            out_features.append(enc_block[-2].out_channels)

        self.pc_enc = PointCloudEncoder(
            3,
            out_features,
            cat_pc=True,
            use_adain=True,
            use_proj=True,
            activation=activation,
        )

        if use_graphx:
            decoder_cls = PointCloudGraphXDecoder
        else:
            decoder_cls = PointCloudDecoder
        decoder_in = 2 * sum(out_features) + 3
        self.pc = decoder_cls(decoder_in, in_instances=in_instances, activation=activation)

        self.kwargs = kwargs

        # emd loss
        self.emd = EMD()

        if not torch.cuda.is_available():
            return

        device_ids = cfg.CONST.DEVICE
        for attr in ('img_enc', 'pc_enc', 'pc', 'emd'):
            parallel = torch.nn.DataParallel(getattr(self, attr),
                                             device_ids=device_ids)
            setattr(self, attr, parallel.cuda())
        self.cuda()
    def __init__(self, cfg, in_channels, activation=nn.ReLU(), optimizer=None):
        """Build the image encoder, the configurable point-cloud modules and the loss.

        Args:
            cfg: Configuration object. ``cfg.UPDATER.PC_ENCODE_MODULE`` must be
                ``'Pointnet++'`` or ``'EdgeRes'``;
                ``cfg.UPDATER.PC_DECODE_MODULE`` must be ``'Linear'`` or
                ``'Graphx'``. ``cfg.CONST.DEVICE`` supplies ``device_ids`` for
                ``DataParallel`` when CUDA is available.
            in_channels: Passed to ``CNN18Encoder`` as its first argument.
            activation: Activation module handed to ``CNN18Encoder``.
            optimizer: Optional callable that receives ``self.parameters()``
                and returns an optimizer; ``None`` leaves ``self.optimizer``
                as ``None``.

        Raises:
            ValueError: If either module name in ``cfg.UPDATER`` is not one of
                the supported values. Previously this left the attribute unset
                and failed later with an ``AttributeError``.
        """
        super().__init__()
        self.cfg = cfg

        self.img_enc = CNN18Encoder(in_channels, activation)
        self.transform_pc = TransformPC(cfg)
        self.feature_projection = FeatureProjection(cfg)

        encode_module = cfg.UPDATER.PC_ENCODE_MODULE
        if encode_module == 'Pointnet++':
            self.pc_encode = PointNet2(cfg)
        elif encode_module == 'EdgeRes':
            self.pc_encode = EdgeRes(use_SElayer=True)
        else:
            raise ValueError(
                f'Invalid cfg.UPDATER.PC_ENCODE_MODULE. Expected `Pointnet++` '
                f'or `EdgeRes`. Got: `{encode_module}`')

        decode_module = cfg.UPDATER.PC_DECODE_MODULE
        if decode_module == 'Linear':
            self.displacement_net = LinearDisplacementNet(cfg)
        elif decode_module == 'Graphx':
            self.displacement_net = GraphxDisplacementNet(cfg)
        else:
            raise ValueError(
                f'Invalid cfg.UPDATER.PC_DECODE_MODULE. Expected `Linear` '
                f'or `Graphx`. Got: `{decode_module}`')

        self.optimizer = None if optimizer is None else optimizer(self.parameters())

        # emd loss
        self.emd = EMD()

        if torch.cuda.is_available():
            self.img_enc = torch.nn.DataParallel(self.img_enc, device_ids=cfg.CONST.DEVICE).cuda()
            self.transform_pc = torch.nn.DataParallel(self.transform_pc, device_ids=cfg.CONST.DEVICE).cuda()
            self.feature_projection = torch.nn.DataParallel(self.feature_projection, device_ids=cfg.CONST.DEVICE).cuda()
            self.pc_encode = torch.nn.DataParallel(self.pc_encode, device_ids=cfg.CONST.DEVICE).cuda()
            self.displacement_net = torch.nn.DataParallel(self.displacement_net, device_ids=cfg.CONST.DEVICE).cuda()
            self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
            self.cuda()
    def __init__(self,
                 cfg,
                 optimizer,
                 num_points: int = 2048,
                 use_SElayer: bool = True):
        """Set up the residual network, the optional optimizer and the EMD loss.

        Args:
            cfg: Configuration object; ``cfg.CONST.DEVICE`` supplies the
                ``device_ids`` for ``DataParallel`` when CUDA is available.
            optimizer: Callable that receives ``self.parameters()`` and returns
                an optimizer, or ``None`` to leave ``self.optimizer`` unset.
            num_points: Stored on ``self.num_points``.
            use_SElayer: Forwarded to ``EdgeRes``.
        """
        super().__init__()
        self.cfg = cfg
        self.num_points = num_points
        self.residual = EdgeRes(use_SElayer=use_SElayer)

        if optimizer is None:
            self.optimizer = None
        else:
            self.optimizer = optimizer(self.parameters())

        self.emd = EMD()

        if not torch.cuda.is_available():
            return

        device_ids = cfg.CONST.DEVICE
        for attr in ('residual', 'emd'):
            parallel = torch.nn.DataParallel(getattr(self, attr),
                                             device_ids=device_ids)
            setattr(self, attr, parallel.cuda())
        self.cuda()
def main(config):
    """Evaluate the most recent saved weights and run the enabled experiments.

    The function loads the hyper-network and the two encoders for the latest
    epoch found under the weights directory, computes the reconstruction and
    KL losses over the whole dataset, logs their per-batch averages, and then
    runs every experiment whose ``config['experiments'][<name>]['execute']``
    flag is set.

    Args:
        config: Experiment configuration mapping. Keys read here include
            ``seed``, ``arch``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``z_size`` and ``experiments``.

    Raises:
        SystemExit: With status 1 when no saved epoch is found.
        ValueError: For an unknown dataset name, an unknown reconstruction
            loss, or a dataloader that yields no batches.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config,
                                      config['arch'],
                                      'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        # ``exit`` is injected by the ``site`` module and is not guaranteed to
        # exist; raising SystemExit directly has the same effect.
        raise SystemExit(1)

    setup_logging(results_dir)
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        # The message lists the names this function actually accepts.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=64,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    # With drop_last=True a dataset smaller than one batch yields nothing;
    # fail here instead of hitting an undefined loop counter further down.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Dataloader yields no batches: dataset has {len(dataset)} '
            f'samples, fewer than the batch size of 64 (drop_last=True).')

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    # Collects the 'non-visible' batches; they are the input to the
    # experiments below.
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)
            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network is instantiated per sample from the weights
            # produced by the hyper-network.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X_whole.shape[2],
                                                             X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # KL divergence term computed from the pocket encoder's
            # mu / logvar outputs.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        # ``i`` is the number of processed batches (enumerate starts at 1).
        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # Batches may differ in their last dimension; crop all of them to the
        # smallest one so they can be concatenated.
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(
                x, encoder_pocket, hyper_network, device, results_dir, epoch,
                config['experiments']['interpolation']['amount'],
                config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points'][
                'execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                config['experiments']['interpolation_between_two_points']
                ['amount'], config['experiments']
                ['interpolation_between_two_points']['image_points'],
                config['experiments']['interpolation_between_two_points']
                ['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch, config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles']['amount'],
                config['experiments']['sphere_triangles']['method'],
                config['experiments']['sphere_triangles']['depth'],
                config['experiments']['sphere_triangles']['start'],
                config['experiments']['sphere_triangles']['end'],
                config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles_interpolation']
                ['amount'], config['experiments']
                ['sphere_triangles_interpolation']['method'],
                config['experiments']['sphere_triangles_interpolation']
                ['depth'], config['experiments']
                ['sphere_triangles_interpolation']['coefficient'],
                config['experiments']['sphere_triangles_interpolation']
                ['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                config['experiments']['different_number_of_points']['amount'],
                config['experiments']['different_number_of_points']
                ['image_points'])

        if config['experiments']['fixed']['execute']:
            # The visible part is taken from a freshly shuffled loader; it
            # should probably come from a specific object chosen by the caller.
            points_dataloader = DataLoader(dataset,
                                           batch_size=10,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=True,
                                           pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch, config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'], (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
예제 #7
0
def main(config):
    """Train the encoder, generator and discriminator and save periodic results.

    Each batch first updates the discriminator with a critic loss plus a
    gradient penalty (WGAN-GP style), then updates the encoder and generator
    with the reconstruction loss plus the adversarial term. After every epoch
    a loss plot and up to five real / fixed-noise / reconstructed sample plots
    are written to ``<results_dir>/samples``; weights and optimizer states are
    saved every ``config['save_frequency']`` epochs. Training resumes from the
    latest epoch found in the results directory.

    Args:
        config: Training configuration mapping. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``arch``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``distribution`` (with ``p`` or ``z_beta_a``/``z_beta_b``),
            ``z_size``, ``optimizer``, ``gp_lambda``, ``max_epochs`` and
            ``save_frequency``.

    Raises:
        ValueError: For an unknown dataset name, reconstruction loss or
            distribution, or a dataloader that yields no batches.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only `shapenet` is handled above, so the message names only that.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    # With drop_last=True a dataset smaller than one batch yields nothing;
    # fail here instead of hitting undefined loop variables after the epoch.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Dataloader yields no batches: dataset has {len(dataset)} '
            f'samples, fewer than batch_size={config["batch_size"]} '
            f'(drop_last=True).')

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')
    #
    # Float Tensors
    #
    # ``fixed_noise`` is sampled once and reused every epoch for the
    # "fixed" sample plots.
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(torch.Size([config['batch_size'],
                                                 config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(config['z_beta_a'],
                                        config['z_beta_b'],
                                        size=(config['batch_size'],
                                              config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        raise ValueError('Invalid distribution for binaray model.')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    # Resume from the checkpoint of the previous epoch, if there is one.
    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        D.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-batch loss histories for the loss plot. Plain floats are stored:
    # keeping the loss tensors would retain every iteration's autograd graph
    # and hand matplotlib tensors that still require grad.
    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)
            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(config['z_beta_a'],
                                          config['z_beta_b'],
                                          size=(config['batch_size'],
                                                config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)
            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)
            # Recorded before the gradient penalty is added, so this curve is
            # the critic term only; the penalty has its own curve.
            loss_d_tot.append(loss_d.item())

            # Gradient Penalty
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())
            ###

            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()

            # The graph is retained because ``codes`` is reused for the
            # encoder/generator update below.
            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            D_optim.step()

            # EG part of training
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())

            synth_logit = D(codes)

            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())

            loss_eg = loss_e + loss_g
            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        # ``i`` is the number of batches processed in this epoch.
        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            # ``X`` is the last batch of the epoch.
            fake = G(fixed_noise).data.cpu().numpy()
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

        plt.figure(figsize=(16, 9))
        plt.plot(loss_d_tot, 'r-', label="loss_d")
        plt.plot(loss_gp_tot, 'g-', label="loss_gp")
        plt.plot(loss_e_tot, 'b-', label="loss_e")
        plt.plot(loss_g_tot, 'k-', label="loss_g")
        plt.legend()
        plt.xlabel("Batch number")
        # Was a second xlabel call, which overwrote the x label and left the
        # y axis unlabeled.
        plt.ylabel("Loss value")
        plt.savefig(
            join(results_dir, 'samples', 'loss_plot.png'))
        plt.close()

        # Plot up to five samples of each kind; the cap keeps batches smaller
        # than five from raising IndexError.
        for samples, suffix in ((X, 'real'),
                                (fake, 'fixed'),
                                (X_rec, 'reconstructed')):
            for k in range(min(5, len(samples))):
                fig = plot_3d_point_cloud(samples[k][0], samples[k][1],
                                          samples[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch:05}_{k}_{suffix}.png'))
                plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
예제 #8
0
def main(config):
    """Train the pocket encoder, visible encoder and hyper network as a VAE.

    Every batch is a mapping with the keys ``'non-visible'``, ``'visible'``
    and ``'cloud'``. The pocket encoder turns the ``'non-visible'`` cloud into
    ``(codes, mu, logvar)``, the visible encoder turns the ``'visible'`` cloud
    into a second code, and the hyper network maps the concatenation of both
    into one set of target-network weights per sample. Each target network is
    then applied to generated input points to reconstruct the whole cloud.

    The optimised loss is the reconstruction loss (scaled by
    ``config['reconstruction_coef']``) plus a KL term computed from ``mu`` and
    ``logvar``. Training starts at the epoch following the one reported by
    ``find_latest_epoch`` and reloads weights, optimizer state and loss
    histories when that epoch is greater than 1.

    Args:
        config: Experiment configuration mapping. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``reconstruction_loss``, ``reconstruction_coef``, ``optimizer``,
            ``target_network_input``, ``max_epochs``, ``save_frequency`` and
            ``clean_weights_dir``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss is not
            recognised, or if the data loader would yield no batches.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Persist the configuration once so a resumed run does not overwrite it.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        # The message lists exactly the names handled above.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    # With drop_last=True a dataset smaller than one batch yields nothing,
    # which would otherwise surface later as an unbound-variable error.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Data loader is empty: {len(dataset)} samples with batch size '
            f'{config["batch_size"]} and drop_last=True')

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # A single optimizer drives both encoders and the hyper network.
    optimizer_cls = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = optimizer_cls(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        prev = f'{starting_epoch - 1:05}'
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. trained on GPU, resumed on CPU).
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{prev}_G.pth'),
                       map_location=device))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{prev}_EP.pth'),
                       map_location=device))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{prev}_EV.pth'),
                       map_location=device))

        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{prev}_EGo.pth'),
                       map_location=device))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{prev}_E.npy')).tolist()
        losses_kld = np.load(join(metrics_path, f'{prev}_KLD.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{prev}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Reused across batches and epochs when the input is configured constant.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # Non-visible part of the point cloud (pocket encoder input).
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # Visible part of the point cloud (visible encoder input).
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Whole point cloud (reconstruction target).
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network per sample; each fills its own slot.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld

            # The optimizer owns every parameter of the three modules, so a
            # single zero_grad() clears all gradients.
            e_hn_optimizer.zero_grad()
            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        # The histories store per-epoch sums, not per-batch means.
        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        if epoch % config['save_frequency'] != 0:
            continue

        #
        # Save intermediate results (last batch of the epoch)
        #
        X_np = X.cpu().numpy()
        X_whole_np = X_whole.cpu().numpy()
        X_rec_np = X_rec.detach().cpu().numpy()

        for k in range(min(5, X_rec_np.shape[0])):
            fig = plot_3d_point_cloud(X_rec_np[k][0],
                                      X_rec_np[k][1],
                                      X_rec_np[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples',
                     f'{epoch}_{k}_reconstructed.png'))
            plt.close(fig)

            fig = plot_3d_point_cloud(X_whole_np[k][0],
                                      X_whole_np[k][1],
                                      X_whole_np[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
            plt.close(fig)

            # NOTE(review): this plots the 'non-visible' batch entry but the
            # file is named "_visible"; confirm which one is intended.
            fig = plot_3d_point_cloud(X_np[k][0],
                                      X_np[k][1],
                                      X_np[k][2],
                                      in_u_sphere=True,
                                      show=False)
            fig.savefig(
                join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
            plt.close(fig)

        # Only wipe old checkpoints when a fresh one is about to be written;
        # cleaning on non-saving epochs would leave no checkpoint to resume
        # from.
        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        log.debug('Saving data...')

        torch.save(hyper_network.state_dict(),
                   join(weights_path, f'{epoch:05}_G.pth'))
        torch.save(encoder_visible.state_dict(),
                   join(weights_path, f'{epoch:05}_EV.pth'))
        torch.save(encoder_pocket.state_dict(),
                   join(weights_path, f'{epoch:05}_EP.pth'))
        torch.save(e_hn_optimizer.state_dict(),
                   join(weights_path, f'{epoch:05}_EGo.pth'))

        np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
        np.save(join(metrics_path, f'{epoch:05}_KLD'),
                np.array(losses_kld))
        np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
예제 #9
0
def main(config):
    """Train an encoder/generator pair as a point-cloud autoencoder.

    Each batch ``X`` is encoded by ``E`` and decoded by ``G``; the optimised
    loss is the mean reconstruction loss between ``X`` and ``G(E(X))`` scaled
    by ``config['reconstruction_coef']``. After every epoch a few samples of
    the last batch and their reconstructions are rendered to
    ``<results_dir>/samples``, and every ``config['save_frequency']`` epochs
    the model and optimizer states are written to ``<results_dir>/weights``.

    Training starts at the epoch following the one reported by
    ``find_latest_epoch`` and reloads the saved states when that epoch is
    greater than 1.

    Args:
        config: Experiment configuration mapping. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``arch``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``optimizer``, ``max_epochs`` and ``save_frequency``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss is not
            recognised, or if the data loader would yield no batches.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Persist the configuration once so a resumed run does not overwrite it.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only `shapenet` is handled above, so the message names only that.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True)

    # With drop_last=True a dataset smaller than one batch yields nothing,
    # which would otherwise surface later as an unbound-variable error.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Data loader is empty: {len(dataset)} samples with batch size '
            f'{config["batch_size"]} and drop_last=True')

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # A single optimizer drives both the encoder and the generator.
    optimizer_cls = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = optimizer_cls(chain(E.parameters(), G.parameters()),
                             **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        prev = f'{starting_epoch - 1:05}'
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. trained on GPU, resumed on CPU).
        G.load_state_dict(
            torch.load(join(weights_path, f'{prev}_G.pth'),
                       map_location=device))
        E.load_state_dict(
            torch.load(join(weights_path, f'{prev}_E.pth'),
                       map_location=device))

        EG_optim.load_state_dict(
            torch.load(join(weights_path, f'{prev}_EGo.pth'),
                       map_location=device))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(config['reconstruction_coef'] *
                              reconstruction_loss(
                                  X.permute(0, 2, 1) + 0.5,
                                  X_rec.permute(0, 2, 1) + 0.5))

            # The optimizer owns every parameter of E and G, so a single
            # zero_grad() clears all gradients.
            EG_optim.zero_grad()
            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(f'[{epoch}/{config["max_epochs"]}] '
                  f'Loss: {total_loss / i:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        #
        # Save intermediate results (last batch of the epoch)
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).cpu().numpy()

        # Matplotlib cannot consume tensors living on the training device,
        # so the real samples are moved to host memory before plotting.
        X_real = X.cpu().numpy()

        # Never index past the batch when it holds fewer than 5 samples.
        n_samples = min(5, X_real.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_real[k][0],
                                      X_real[k][1],
                                      X_real[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0],
                                      X_rec[k][1],
                                      X_rec[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples',
                     f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
예제 #10
0
def main(config):
    """Train an encoder/generator pair with a Cramer-Wold latent penalty.

    Each batch ``X`` is encoded by ``E`` into ``(codes, mu, logvar)`` and
    ``codes`` is decoded by ``G``. The optimised loss is the mean
    reconstruction loss between ``X`` and ``G(codes)`` plus
    ``config['reconstruction_coef'] * cw_distance(mu)``.

    After every epoch a few real samples, reconstructions and samples
    generated from a fixed noise batch are rendered under
    ``<results_dir>/samples``. Every ``config['save_frequency']`` epochs the
    per-epoch mean losses are written to ``<results_dir>/losses`` and the
    model and optimizer states to ``<results_dir>/weights``.

    Args:
        config: Experiment configuration mapping. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``arch``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``z_size``, ``optimizer``, ``max_epochs`` and ``save_frequency``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss is not
            recognised, or if the data loader would yield no batches.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Persist the configuration once so a resumed run does not overwrite it.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only `shapenet` is handled above, so the message names only that.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    # With drop_last=True a dataset smaller than one batch yields nothing,
    # which would otherwise surface later as an unbound-variable error.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Data loader is empty: {len(dataset)} samples with batch size '
            f'{config["batch_size"]} and drop_last=True')

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif loss_name == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        # The message lists exactly the names handled above.
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'`earth_mover` or `cramer_wold`, got: '
                         f'{config["reconstruction_loss"]}')

    #
    # Fixed noise batch, decoded after every epoch to track the generator
    #
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    fixed_noise = fixed_noise.to(device)

    #
    # Optimizers
    #
    # A single optimizer drives both the encoder and the generator.
    optimizer_cls = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = optimizer_cls(chain(E.parameters(), G.parameters()),
                             **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        prev = f'{starting_epoch - 1:05}'
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. trained on GPU, resumed on CPU).
        G.load_state_dict(torch.load(
            join(weights_path, f'{prev}_G.pth'), map_location=device))
        E.load_state_dict(torch.load(
            join(weights_path, f'{prev}_E.pth'), map_location=device))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{prev}_EGo.pth'), map_location=device))

    # One [reconstruction, latent penalty, total] row of mean losses per
    # epoch; starts empty on every invocation, including resumed runs.
    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []

            for i, point_data in enumerate(points_dataloader, 1):
                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                codes, mu, logvar = E(X)
                X_rec = G(codes)

                # The reconstruction term is deliberately left unscaled;
                # the coefficient weights the latent penalty instead.
                loss_e = torch.mean(
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                # Latent regulariser computed on the posterior means
                # (kept under the historical "kld" name).
                loss_kld = config['reconstruction_coef'] * cw_distance(mu)

                loss_eg = loss_e + loss_kld

                # The optimizer owns every parameter of E and G, so a single
                # zero_grad() clears all gradients.
                EG_optim.zero_grad()
                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results (last batch of the epoch)
            #
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).cpu().numpy()

            X_numpy = X.cpu().numpy()

            # Never index past the batch when it holds fewer than 5 samples.
            n_samples = min(5, X_numpy.shape[0])

            for k in range(n_samples):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1],
                                          X_numpy[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(n_samples):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', 'fixed',
                         f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(n_samples):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses,
                                  columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))

                # The plot name historically carried the index of the last
                # plotted sample (a leaked loop variable); it is computed
                # explicitly here so existing file names stay unchanged.
                loss_plot_suffix = n_samples - 1
                fig = df.plot.line().get_figure()
                fig.savefig(join(results_dir, 'losses',
                                 f'{epoch:05}_{loss_plot_suffix}.png'))
                # Close the figure so repeated saves do not pile up open
                # matplotlib figures.
                plt.close(fig)

                torch.save(G.state_dict(),
                           join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(),
                           join(weights_path, f'{epoch:05}_E.pth'))

                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
예제 #11
0
def test_net(cfg):
    """Evaluate point-cloud reconstruction and view estimation on the test set.

    Loads the reconstruction network, the view encoder and the view estimater
    from the checkpoints named in ``cfg.TEST``, runs them over the test split
    and reports:

    * the batch-averaged Chamfer distance, total EMD and point-wise EMD;
    * the median rotation error and the share of samples whose rotation
      error is at most 30 (reported as ``Acc_pi/6``).

    The results are printed and appended to ``cfg.TEST.RESULT_PATH``.

    NOTE: the loss modules and the result accumulators are moved to CUDA
    unconditionally, so a CUDA device is required.

    Args:
        cfg: Experiment configuration object; the ``CONST``, ``DATASET``,
            ``GRAPHX``, ``TRAIN`` and ``TEST`` sections are read here.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                           test_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    # The parameters here need to be set in cfg
    net = Pixel2Pointcloud(
        3,
        cfg.GRAPHX.NUM_INIT_POINTS,
        optimizer=lambda x: torch.optim.Adam(
            x, lr=cfg.TRAIN.LEARNING_RATE, weight_decay=cfg.TRAIN.WEIGHT_DECAY
        ),
        scheduler=lambda x: MultiStepLR(
            x, milestones=cfg.TRAIN.MILESTONES, gamma=cfg.TRAIN.GAMMA),
        use_graphx=cfg.GRAPHX.USE_GRAPHX)

    view_encoder = Encoder(cfg)

    azi_classes, ele_classes = int(360 / cfg.CONST.BIN_SIZE), int(
        180 / cfg.CONST.BIN_SIZE)
    view_estimater = ViewEstimater(cfg,
                                   azi_classes=azi_classes,
                                   ele_classes=ele_classes)

    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net).cuda()
        view_encoder = torch.nn.DataParallel(view_encoder).cuda()
        view_estimater = torch.nn.DataParallel(view_estimater).cuda()

    # Load weight for encoder, decoder
    print('[INFO] %s Loading reconstruction weights from %s ...' %
          (dt.now(), cfg.TEST.RECONSTRUCTION_WEIGHTS))
    rec_checkpoint = torch.load(cfg.TEST.RECONSTRUCTION_WEIGHTS)
    net.load_state_dict(rec_checkpoint['net'])
    print('[INFO] Best reconstruction result at epoch %d ...' %
          rec_checkpoint['epoch_idx'])

    # Load weight for view encoder
    print('[INFO] %s Loading view encoder weights from %s ...' %
          (dt.now(), cfg.TEST.VIEW_ENCODER_WEIGHTS))
    view_enc_checkpoint = torch.load(cfg.TEST.VIEW_ENCODER_WEIGHTS)
    view_encoder.load_state_dict(view_enc_checkpoint['encoder_state_dict'])
    print('[INFO] Best view encode result at epoch %d ...' %
          view_enc_checkpoint['epoch_idx'])

    # Load weight for view estimater
    print('[INFO] %s Loading view estimation weights from %s ...' %
          (dt.now(), cfg.TEST.VIEW_ESTIMATION_WEIGHTS))
    view_est_checkpoint = torch.load(cfg.TEST.VIEW_ESTIMATION_WEIGHTS)
    view_estimater.load_state_dict(
        view_est_checkpoint['view_estimator_state_dict'])
    print('[INFO] Best view estimation result at epoch %d ...' %
          view_est_checkpoint['epoch_idx'])

    # Set up loss functions (only the Chamfer module is called directly;
    # the EMD value comes from the network's own `loss` method below).
    emd = EMD().cuda()
    cd = ChamferLoss().cuda()

    # Batch average metrics
    cd_distances = utils.network_utils.AverageMeter()
    emd_distances = utils.network_utils.AverageMeter()
    pointwise_emd_distances = utils.network_utils.AverageMeter()

    # Per-batch predictions and labels, concatenated once after the loop
    # instead of growing a tensor with torch.cat on every iteration.
    pred_batches = []
    ground_truth_view_batches = []

    # Switch models to evaluation mode
    net.eval()
    view_encoder.eval()
    view_estimater.eval()

    # `loss` lives on the underlying model; unwrap DataParallel only when
    # the wrapper is actually present.
    reconstruction_model = (net.module
                            if isinstance(net, torch.nn.DataParallel)
                            else net)

    n_batches = len(test_data_loader)

    # Testing loop
    for sample_idx, (taxonomy_names, sample_names, rendering_images,
                     init_point_clouds, ground_truth_point_clouds,
                     ground_truth_views) in enumerate(test_data_loader):
        with torch.no_grad():
            # Only one image per sample
            rendering_images = torch.squeeze(rendering_images, 1)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            init_point_clouds = utils.network_utils.var_or_cuda(
                init_point_clouds)
            ground_truth_views = utils.network_utils.var_or_cuda(
                ground_truth_views)
            ground_truth_point_clouds = utils.network_utils.var_or_cuda(
                ground_truth_point_clouds)

            #=================================================#
            #           Test the encoder, decoder             #
            #=================================================#
            emd_loss, generated_point_clouds = reconstruction_model.loss(
                rendering_images, init_point_clouds, ground_truth_point_clouds,
                'mean')

            # Compute CD, EMD.
            # The loader keeps a short final batch (no drop_last), so the
            # Chamfer value is normalised by the number of samples actually
            # in this batch rather than by the configured batch size.
            # NOTE(review): assumes `cd` returns a sum over the batch, as
            # the original division by BATCH_SIZE implies; confirm.
            batch_size = ground_truth_point_clouds.size(0)
            cd_distance = cd(generated_point_clouds, ground_truth_point_clouds
                             ) / batch_size / cfg.CONST.NUM_POINTS
            emd_distance = emd_loss
            pointwise_emd_distance = emd_loss / cfg.CONST.NUM_POINTS

            #=================================================#
            #          Test the view estimater                #
            #=================================================#
            # output[0]:prediction of azi class
            # output[1]:prediction of ele class
            # output[2]:prediction of azi regression
            # output[3]:prediction of ele regression
            vgg_features, _ = view_encoder(rendering_images)
            output = view_estimater(vgg_features)

            #=================================================#
            #              Get predict view                   #
            #=================================================#
            preds_cls = utils.view_pred_utils.get_pred_from_cls_output(
                [output[0], output[1]])

            preds = []
            for n, cls_pred in enumerate(preds_cls):
                # The regression head for angle n sits two slots after its
                # classification head.
                pred_delta = output[n + 2]
                # Offset predicted for the chosen bin, squashed by tanh and
                # halved before being added to the bin index.
                delta_value = pred_delta[torch.arange(pred_delta.size(0)),
                                         cls_pred.long()].tanh() / 2
                preds.append((cls_pred.float() + delta_value + 0.5) *
                             cfg.CONST.BIN_SIZE)

            # In this experiment every object is expected to have a fixed
            # in-plane rotation of 0 and only the view point differs, so a
            # zero in-plane term is added for the rotation error computation.
            zero_inplane = torch.zeros_like(preds[0]).unsqueeze(1)
            test_pred = torch.cat(
                (preds[0].unsqueeze(1), preds[1].unsqueeze(1), zero_inplane),
                1)
            ground_truth_views = torch.cat(
                (ground_truth_views, zero_inplane.long()), 1)

            # Append loss and accuracy to average metrics
            cd_distances.update(cd_distance.item())
            emd_distances.update(emd_distance.item())
            pointwise_emd_distances.update(pointwise_emd_distance.item())

            # Keep results and labels for the view estimation metrics
            pred_batches.append(test_pred)
            ground_truth_view_batches.append(ground_truth_views)

            print(
                "Test on [%d/%d] data, CD: %.4f Point EMD: %.4f Total EMD %.4f"
                % (sample_idx + 1, n_batches, cd_distance.item(),
                   pointwise_emd_distance.item(), emd_distance.item()))

    if pred_batches:
        test_preds = torch.cat(pred_batches, 0)
        test_ground_truth_views = torch.cat(ground_truth_view_batches, 0)
    else:
        # An empty loader still yields well-formed (0, 3) tensors.
        test_preds = torch.zeros([0, 3], dtype=torch.float).cuda()
        test_ground_truth_views = torch.zeros([0, 3], dtype=torch.long).cuda()

    # calculate the rotation errors between prediction and ground truth
    test_errs = utils.rotation_eval.rotation_err(
        test_preds, test_ground_truth_views.float()).cpu().numpy()
    Acc = 100. * np.mean(test_errs <= 30)
    Med = np.median(test_errs)

    # print result
    print("Reconstruction result:")
    print("CD result: ", cd_distances.avg)
    print("Pointwise EMD result: ", pointwise_emd_distances.avg)
    print("Total EMD result", emd_distances.avg)
    print("View estimation result:")
    print('Med_Err is %.2f, and Acc_pi/6 is %.2f \n \n' % (Med, Acc))
    logname = cfg.TEST.RESULT_PATH
    with open(logname, 'a') as f:
        f.write('Reconstruction result: \n')
        f.write("CD result: %.8f \n" % cd_distances.avg)
        f.write("Pointwise EMD result: %.8f \n" % pointwise_emd_distances.avg)
        f.write("Total EMD result: %.8f \n" % emd_distances.avg)
        f.write('View estimation result: \n')
        f.write('Med_Err is %.2f, and Acc_pi/6 is %.2f \n \n' % (Med, Acc))