def main(config):
    """Evaluate the latest saved checkpoint and run the enabled experiments.

    Loads the hyper network and both encoders from the newest epoch found in
    the weights directory, logs the mean reconstruction / KLD losses over the
    dataset, and then runs every experiment whose ``execute`` flag is set
    under ``config['experiments']``.

    Args:
        config: Experiment configuration mapping. Keys read here include
            ``seed``, ``arch``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``reconstruction_loss``, ``reconstruction_coef``,
            ``z_size`` and ``experiments``.

    Raises:
        SystemExit: With status 1 when no saved epoch is found.
        ValueError: For an unknown dataset or reconstruction loss name, or
            when the data loader yields no batch at all.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config,
                                      config['arch'],
                                      'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        # Logging is not configured yet, so report on stdout.
        print("Invalid 'weights_path' in configuration")
        # ``exit()`` is a helper injected by the ``site`` module for
        # interactive use; raising SystemExit has the same effect and is
        # always available.
        raise SystemExit(1)

    setup_logging(results_dir)
    # ``log`` is published as a module global -- presumably so the experiment
    # helpers called below can use it (verify against their definitions).
    global log
    log = logging.getLogger('aae')

    config_dump_path = join(results_dir, 'experiment_config.json')
    if not exists(config_dump_path):
        with open(config_dump_path, mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=64,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    # map_location lets checkpoints saved on a GPU be restored on whatever
    # device was selected above (e.g. a CPU-only machine).
    for module, suffix in ((hyper_network, 'G'), (encoder_pocket, 'EP'),
                           (encoder_visible, 'EV')):
        module.load_state_dict(
            torch.load(join(weights_path, f'{epoch:05}_{suffix}.pth'),
                       map_location=device))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    # Every 'non-visible' batch is kept; the experiments below consume them.
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)
            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network is built per sample from the weights the
            # hyper network produced for it.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X_whole.shape[2],
                                                             X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        if not x:
            # drop_last=True discards a trailing partial batch, so a dataset
            # smaller than the batch size produces no batches; without this
            # check the averages below would fail on an unbound loop counter.
            raise ValueError('The data loader yielded no batches '
                             '(at least 64 samples are required)')

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # Batches may differ in their last dimension; truncate all of them to
        # the smallest one so they can be concatenated.
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        experiments = config['experiments']

        params = experiments['interpolation']
        if params['execute']:
            interpolation(x, encoder_pocket, hyper_network, device,
                          results_dir, epoch, params['amount'],
                          params['transitions'])

        params = experiments['interpolation_between_two_points']
        if params['execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                params['amount'], params['image_points'],
                params['transitions'])

        params = experiments['reconstruction']
        if params['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch, params['amount'])

        params = experiments['sphere']
        if params['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch, params['amount'], params['image_points'],
                   params['start'], params['end'], params['transitions'])

        params = experiments['sphere_triangles']
        if params['execute']:
            sphere_triangles(encoder_pocket, hyper_network, device, x,
                             results_dir, params['amount'], params['method'],
                             params['depth'], params['start'], params['end'],
                             params['transitions'])

        params = experiments['sphere_triangles_interpolation']
        if params['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                params['amount'], params['method'], params['depth'],
                params['coefficient'], params['transitions'])

        params = experiments['different_number_of_points']
        if params['execute']:
            different_number_of_points(encoder_pocket, hyper_network, x,
                                       device, results_dir, epoch,
                                       params['amount'],
                                       params['image_points'])

        params = experiments['fixed']
        if params['execute']:
            # Take the visible parts from one freshly drawn batch of 10
            # samples (probably should be done using a given object, for
            # example supplied through the argument parser).
            points_dataloader = DataLoader(dataset,
                                           batch_size=10,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=True,
                                           pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            # NOTE(review): unlike the evaluation loop above, this transpose
            # is unconditional -- confirm the batch is always
            # [BATCH, N_POINTS, N_DIM] here.
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            triangulation = params['triangulation']
            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch, params['amount'],
                  config['z_size'] // 2, params['mean'], params['std'],
                  (3, 2048), triangulation['execute'],
                  triangulation['method'], triangulation['depth'])
예제 #2
0
def main(config):
    """Train the hyper network together with the pocket and visible encoders.

    Resumes from the latest saved epoch when one exists, otherwise starts
    from freshly initialised weights. Each epoch optimises the sum of the
    reconstruction loss and the KLD term; every ``save_frequency`` epochs it
    writes sample plots, model / optimizer weights and the loss histories.

    Args:
        config: Training configuration mapping. Keys read here include
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``reconstruction_loss``, ``reconstruction_coef``, ``optimizer``,
            ``target_network_input``, ``max_epochs``, ``save_frequency`` and
            ``clean_weights_dir``.

    Raises:
        ValueError: For an unknown dataset or reconstruction loss name.
        AssertionError: When input normalization is enabled with a type
            other than ``progressive``.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # A single optimizer drives both encoders and the hyper network.
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        last_epoch = starting_epoch - 1

        log.info("Loading weights...")
        # map_location lets checkpoints saved on a GPU be restored on
        # whatever device was selected above.
        for stateful, suffix in ((hyper_network, 'G'),
                                 (encoder_pocket, 'EP'),
                                 (encoder_visible, 'EV'),
                                 (e_hn_optimizer, 'EGo')):
            stateful.load_state_dict(
                torch.load(join(weights_path, f'{last_epoch:05}_{suffix}.pth'),
                           map_location=device))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{last_epoch:05}_E.npy')).tolist()
        losses_kld = np.load(join(metrics_path,
                                  f'{last_epoch:05}_KLD.npy')).tolist()
        losses_eg = np.load(join(metrics_path,
                                 f'{last_epoch:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Reused across batches and epochs when the input is configured as
    # constant; regenerated for every sample otherwise.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network is built per sample from the weights the
            # hyper network produced for it.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld
            # The optimizer owns every parameter of the three modules, so
            # clearing gradients through it alone is sufficient.
            e_hn_optimizer.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        # NOTE: the stored histories hold per-epoch sums, not the per-batch
        # means that are logged above.
        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        if epoch % config['save_frequency'] != 0:
            continue

        #
        # Save intermediate results (taken from the last batch of the epoch)
        #
        X_np = X.cpu().numpy()
        X_whole_np = X_whole.cpu().numpy()
        X_rec_np = X_rec.detach().cpu().numpy()

        for k in range(min(5, X_rec_np.shape[0])):
            # NOTE(review): the cloud saved as '*_visible.png' is the batch
            # read from the 'non-visible' key -- confirm which part is meant.
            for cloud, suffix, plot_kwargs in (
                    (X_rec_np[k], 'reconstructed', {'title': str(epoch)}),
                    (X_whole_np[k], 'real', {'title': str(epoch)}),
                    (X_np[k], 'visible', {})):
                fig = plot_3d_point_cloud(cloud[0],
                                          cloud[1],
                                          cloud[2],
                                          in_u_sphere=True,
                                          show=False,
                                          **plot_kwargs)
                fig.savefig(
                    join(results_dir, 'samples',
                         f'{epoch}_{k}_{suffix}.png'))
                plt.close(fig)

        # Only wipe the weights directory when new weights are about to be
        # written; cleaning on every epoch would delete the latest
        # checkpoint whenever save_frequency > 1.
        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        log.debug('Saving data...')

        torch.save(hyper_network.state_dict(),
                   join(weights_path, f'{epoch:05}_G.pth'))
        torch.save(encoder_visible.state_dict(),
                   join(weights_path, f'{epoch:05}_EV.pth'))
        torch.save(encoder_pocket.state_dict(),
                   join(weights_path, f'{epoch:05}_EP.pth'))
        torch.save(e_hn_optimizer.state_dict(),
                   join(weights_path, f'{epoch:05}_EGo.pth'))

        np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
        np.save(join(metrics_path, f'{epoch:05}_KLD'), np.array(losses_kld))
        np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
예제 #3
0
def all_metrics(config, weights_path, device, epoch, jsd_value):
    """Compute and print all metrics for generated clouds on the test split.

    For each test batch, noise drawn from the configured distribution is fed
    to the hyper network, one target network per sample generates a point
    cloud, and ``compute_all_metrics`` compares the generated batch with the
    real one. Per-metric values are combined as a sample-weighted running
    mean and printed together with ``jsd_value``.

    Args:
        config: Configuration mapping (``dataset``, ``data_dir``,
            ``classes``, ``z_size`` and ``metrics`` are read here).
        weights_path: Directory holding the ``<epoch>_G.pth`` checkpoints.
        device: Device the model and data are moved to.
        epoch: Epoch whose weights are evaluated; ``None`` selects the
            latest epoch found in ``weights_path``.
        jsd_value: Value stored under the ``jsd`` key of the printed
            result; may be ``None``.

    Raises:
        ValueError: If the configured dataset is not ``shapenet``.
        AssertionError: If the metrics distribution is neither ``normal``
            nor ``beta``.
    """
    from utils.metrics import compute_all_metrics
    print('All metrics')
    if epoch is None:
        print('Finding latest epoch...')
        epoch = find_latest_epoch(weights_path)
        print(f'Epoch: {epoch}')

    if jsd_value is not None:
        print(f'Best Epoch selected via minimal JSD: {epoch}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'],
                                  split='test')
    else:
        raise ValueError(
            f'Invalid dataset name. Expected `shapenet`. Got: `{dataset_name}`'
        )
    classes_selected = ('all' if not config['classes'] else ','.join(
        config['classes']))
    print(
        f'Test dataset. Selected {classes_selected} classes. Loaded {len(dataset)} '
        f'samples.')

    metrics_cfg = config['metrics']
    distribution = metrics_cfg['distribution']
    assert distribution in ['normal', 'beta'
                            ], 'Invalid distribution. Choose normal or beta'

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    hyper_network.eval()

    data_loader = DataLoader(dataset,
                             batch_size=32,
                             shuffle=False,
                             num_workers=4,
                             drop_last=False,
                             pin_memory=True)

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device the caller passed in.
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth'),
                   map_location=device))

    result = {}
    size = 0  # number of samples already folded into the running means
    start_clock = datetime.now()
    for point_data in data_loader:

        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        batch_len = X.shape[0]

        with torch.no_grad():
            noise = torch.zeros(batch_len, config['z_size']).to(device)
            if distribution == 'normal':
                noise.normal_(metrics_cfg['normal_mu'],
                              metrics_cfg['normal_std'])
            elif distribution == 'beta':
                noise_np = np.random.beta(metrics_cfg['beta_a'],
                                          metrics_cfg['beta_b'],
                                          noise.shape)
                # The beta samples are rounded before being used as codes.
                noise = torch.tensor(noise_np).float().round().to(device)

            target_networks_weights = hyper_network(noise)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X.shape[2],
                                                             X.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            batch_metrics = compute_all_metrics(
                torch.transpose(X, 1, 2).contiguous(),
                torch.transpose(X_rec, 1, 2).contiguous(), batch_len)
            for k, v in batch_metrics.items():
                # Sample-weighted running mean over the batches seen so far.
                result[k] = (size * result.get(k, 0.0) +
                             batch_len * v.item()) / (size + batch_len)

        size += batch_len

    result['jsd'] = jsd_value
    print(f'Time: {datetime.now() - start_clock}')
    print('Result:')
    for k, v in result.items():
        print(f'{k}: {v}')
예제 #4
0
def jsd(config, weights_path, device):
    """Evaluate the Jensen-Shannon divergence for saved epochs.

    For every selected epoch the generator weights are loaded, point clouds
    are generated from noise in three independent trials, and the mean JSD
    between the validation clouds and the generated clouds is recorded.

    Args:
        config: Configuration mapping (``dataset``, ``data_dir``,
            ``classes``, ``z_size`` and ``metrics`` are read here).
        weights_path: Directory holding the ``<epoch>_E.pth`` /
            ``<epoch>_G.pth`` checkpoints; only epochs having both files
            are evaluated.
        device: Device the model and data are moved to.

    Returns:
        A ``(best_epoch, best_jsd)`` pair: the epoch with the minimum mean
        JSD and that JSD value.

    Raises:
        ValueError: If the configured dataset is not ``shapenet``, or if no
            epoch produced a result (no matching checkpoints, or the run
            was interrupted during the first epoch).
        AssertionError: If the metrics distribution is neither ``normal``
            nor ``beta``.
    """
    print(
        'Evaluating Jensen-Shannon divergences on validation set on all saved epochs.'
    )

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'],
                                  split='valid')
    else:
        raise ValueError(
            f'Invalid dataset name. Expected `shapenet`. Got: `{dataset_name}`'
        )

    classes_selected = ('all' if not config['classes'] else ','.join(
        config['classes']))
    print(
        f'Valid dataset. Selected {classes_selected} classes. Loaded {len(dataset)} '
        f'samples.')

    metrics_cfg = config['metrics']
    distribution = metrics_cfg['distribution']
    assert distribution in ['normal', 'beta'
                            ], 'Invalid distribution. Choose normal or beta'

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    hyper_network.eval()

    # The whole validation split is loaded as a single batch.
    num_samples = len(dataset.point_clouds_names_valid)
    data_loader = DataLoader(dataset,
                             batch_size=num_samples,
                             shuffle=False,
                             num_workers=4,
                             drop_last=False,
                             pin_memory=True)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    # We take 3 times as many samples as there are in the loaded validation
    # data in order to perform JSD calculation in the same manner as in the
    # reference publication
    noise = torch.zeros(3 * X.shape[0], config['z_size']).to(device)

    results = {}

    n_last_epochs = metrics_cfg.get('jsd_how_many_last_epochs', -1)
    epochs = epochs[-n_last_epochs:] if n_last_epochs > 0 else epochs
    print(f'Testing epochs: {epochs}')

    for epoch in reversed(epochs):
        try:
            # map_location lets a checkpoint saved on a GPU be restored on
            # whatever device the caller passed in.
            hyper_network.load_state_dict(
                torch.load(join(weights_path, f'{epoch:05}_G.pth'),
                           map_location=device))

            start_clock = datetime.now()

            # We average JSD computation from 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(metrics_cfg['normal_mu'],
                                  metrics_cfg['normal_std'])
                elif distribution == 'beta':
                    noise_np = np.random.beta(metrics_cfg['beta_a'],
                                              metrics_cfg['beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    target_networks_weights = hyper_network(noise)

                    X_rec = torch.zeros(3 * X.shape[0], X.shape[1],
                                        X.shape[2]).to(device)

                    # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                    if X_rec.size(-1) == 3:
                        X_rec.transpose_(X_rec.dim() - 2, X_rec.dim() - 1)

                    for j, target_network_weights in enumerate(
                            target_networks_weights):
                        target_network = aae.TargetNetwork(
                            config, target_network_weights).to(device)

                        target_network_input = generate_points(
                            config=config,
                            epoch=epoch,
                            size=(X_rec.shape[2], X_rec.shape[1]))

                        X_rec[j] = torch.transpose(
                            target_network(target_network_input.to(device)), 0,
                            1)

                # Named ``trial_jsd`` so the local does not shadow this
                # function's own name.
                trial_jsd = jsd_between_point_cloud_sets(
                    X.cpu().numpy(),
                    torch.transpose(X_rec, 1, 2).cpu().numpy())
                js_results.append(trial_jsd)

            js_result = np.mean(js_results)
            print(f'Epoch: {epoch} JSD: {js_result: .6f} '
                  f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            # Stop early but still report the epochs finished so far.
            print(f'Interrupted during epoch: {epoch}')
            break

    if not results:
        # Without this check the minimum lookup below fails with an opaque
        # pandas error on the empty frame.
        raise ValueError(
            f'No JSD results were computed for weights in: {weights_path}')

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    best_epoch = results.idxmin()['jsd']
    best_jsd = results.min()['jsd']
    print(f"Minimum JSD at epoch {best_epoch}: "
          f"{best_jsd: .6f}")

    return best_epoch, best_jsd
예제 #5
0
def minimum_matching_distance(config, weights_path, device):
    """Print the EMD/CD matching distances of reconstructions on the test split.

    Loads the encoder and hyper network weights of the latest epoch, encodes
    the test clouds, reconstructs each of them with its own target network
    and prints the values returned by ``EMD_CD`` summed over the batches.

    Args:
        config: Configuration mapping (``dataset``, ``data_dir`` and
            ``classes`` are read here; the full mapping is passed on to the
            models).
        weights_path: Directory holding the ``<epoch>_E.pth`` /
            ``<epoch>_G.pth`` checkpoints.
        device: Device the models and data are moved to.

    Raises:
        ValueError: If the configured dataset is not ``shapenet``.
    """
    from utils.metrics import EMD_CD
    print('Minimum Matching Distance (MMD) Test split')
    epoch = find_latest_epoch(weights_path)
    print(f'Last Epoch: {epoch}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'],
                                  split='test')
    else:
        raise ValueError(
            f'Invalid dataset name. Expected `shapenet`. Got: `{dataset_name}`'
        )
    classes_selected = ('all' if not config['classes'] else ','.join(
        config['classes']))
    print(
        f'Test dataset. Selected {classes_selected} classes. Loaded {len(dataset)} '
        f'samples.')

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder = aae.Encoder(config).to(device)

    hyper_network.eval()
    encoder.eval()

    # The batch size equals the number of test clouds, so the loader is
    # expected to yield the whole split at once.
    num_samples = len(dataset.point_clouds_names_test)
    data_loader = DataLoader(dataset,
                             batch_size=num_samples,
                             shuffle=False,
                             num_workers=4,
                             drop_last=False,
                             pin_memory=True)

    # map_location lets checkpoints saved on a GPU be restored on whatever
    # device the caller passed in.
    encoder.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_E.pth'),
                   map_location=device))
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth'),
                   map_location=device))

    result = {}

    for point_data in data_loader:

        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        with torch.no_grad():
            z_a, _, _ = encoder(X)
            target_networks_weights = hyper_network(z_a)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(X.shape[2],
                                                             X.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            batch_metrics = EMD_CD(
                torch.transpose(X, 1, 2).contiguous(),
                torch.transpose(X_rec, 1, 2).contiguous(), X.shape[0])
            for k, v in batch_metrics.items():
                # Values are summed (not averaged) across batches.
                result[k] = result.get(k, 0.0) + v.item()

    print(result)
예제 #6
0
def main(config):
    """Train the hypernetwork adversarial autoencoder described by ``config``.

    The function seeds the RNGs, prepares the results directory, builds the
    dataset, the models (hyper network, encoder, discriminator) and their
    optimizers, optionally resumes from the latest checkpoint found in the
    results directory, and then runs the training loop up to
    ``config['max_epochs']``. Samples, weights, optimizer states and loss
    histories are written periodically.

    Args:
        config: Mapping with the experiment configuration. Keys read here:
            ``seed``, ``cuda``, ``gpu``, ``dataset``, ``data_dir``,
            ``classes``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``pointnet`` (optional, default ``False``),
            ``reconstruction_loss``, ``optimizer`` (``E_HN`` and ``D``
            entries, each with ``type`` and ``hyperparams``),
            ``target_network_input``, ``max_epochs``, ``z_size``,
            ``normal_mu``, ``normal_std``, ``wasserstein`` (optional,
            default ``True``), ``gradient_penalty_coef``,
            ``reconstruction_coef``, ``feature_regularization_coef``
            (PointNet only), ``save_samples_frequency``,
            ``clean_weights_dir`` and ``save_weights_frequency``.

    Raises:
        ValueError: If the dataset name or the reconstruction loss name is
            not supported.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    # Resume right after the most recent epoch found in the results directory.
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Persist the configuration once, on the first run only.
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)
    use_wasserstein = config.get('wasserstein', True)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes its own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    reconstruction_loss_name = config['reconstruction_loss'].lower()
    if reconstruction_loss_name == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif reconstruction_loss_name == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # The encoder and the hyper network share a single optimizer.
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(), hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(discriminator.parameters(),
                                                      **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s", starting_epoch)
    if starting_epoch > 1:
        previous_epoch = starting_epoch - 1

        log.info("Loading weights...")
        # map_location keeps resuming possible when the checkpoint was written
        # from a different device than the one selected for this run.
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{previous_epoch:05}_G.pth'), map_location=device))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{previous_epoch:05}_E.pth'), map_location=device))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{previous_epoch:05}_D.pth'), map_location=device))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{previous_epoch:05}_EGo.pth'), map_location=device))

        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{previous_epoch:05}_Do.pth'), map_location=device))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{previous_epoch:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{previous_epoch:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{previous_epoch:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{previous_epoch:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Cached target network input; reused across iterations when
    # config['target_network_input']['constant'] is set.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s", epoch)
        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        # Enumerate from 1 so that `i` is the batch count after the loop.
        for i, point_data in enumerate(points_dataloader, 1):

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            # Prior samples are treated as "real", encoder codes as "synthetic".
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(mean=config['normal_mu'],
                                                                          std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if use_wasserstein:
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                # Random points on the segments between prior samples and codes.
                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # Gradient penalty: push the norm of the discriminator's
                # gradient at the interpolated points towards 1.
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                # Reduce to a scalar: backward() and item() below require it.
                loss_discriminator = torch.mean(
                    0.5 * ((real_logit - b) ** 2 + (synth_logit - a) ** 2))

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            # The graph is retained because `codes` is reused below for the
            # hyper network and encoder losses.
            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            # One target network is instantiated per sample in the batch.
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch, size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                                      reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                                          torch.transpose(X_rec, 1, 2).contiguous(),
                                                          batch_size=X.shape[0]).mean()
            else:
                # NOTE(review): both clouds are shifted by 0.5 before the loss;
                # the reason is not visible in this function - confirm.
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            # Re-evaluated because the discriminator was just updated.
            synth_logit = discriminator(codes)
            if use_wasserstein:
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                # Reduce to a scalar so it can be summed into loss_all.
                loss_encoder = torch.mean(0.5 * (synth_logit - c) ** 2)

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                                      feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        # Logged values are per-batch averages over the epoch.
        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        # The histories store per-epoch sums (not averages): losses_e holds the
        # reconstruction loss, losses_g the encoder loss, losses_eg the total
        # loss and losses_d the discriminator loss.
        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            # Plots use the last batch of the epoch.
            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2], in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if epoch % config['save_weights_frequency'] == 0:
            # Only wipe old checkpoints when a new one is about to be written;
            # cleaning on every epoch would delete the latest checkpoint
            # between saves whenever save_weights_frequency > 1.
            if config['clean_weights_dir']:
                log.debug('Cleaning weights path: %s', weights_path)
                shutil.rmtree(weights_path, ignore_errors=True)
                os.makedirs(weights_path, exist_ok=True)

            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))