def calc_jsd_valid(model, config, prior_std=1.0, split='valid'):
    """Estimate the JSD between a dataset split and clouds decoded from noise.

    The requested split is fetched as one single batch and compared, through
    ``jsd_between_point_cloud_sets(..., voxels=28)``, with three times as many
    point clouds decoded from scaled standard-normal noise. The value returned
    is the mean over three independent noise draws.

    Args:
        model: Generative model. It must provide ``eval()``, a ``zdim``
            attribute and a ``decode(noise)`` method. The model is switched
            to evaluation mode and is left in that mode on return.
        config: Mapping with the keys ``cuda``, ``gpu``, ``dataset``,
            ``data_dir`` and ``classes``.
        prior_std: Factor the standard-normal noise is multiplied by.
        split: Split name forwarded to the dataset constructor.

    Returns:
        The ``np.mean`` of the three JSD values.

    Raises:
        ValueError: If ``config['dataset']`` is anything but ``shapenet``
            (case-insensitive).
    """
    model.eval()
    device = cuda_setup(config['cuda'], config['gpu'])

    dataset_name = config['dataset'].lower()
    if dataset_name != 'shapenet':
        # Only ShapeNet is wired up in this function, so the message must not
        # advertise other datasets.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    from datasets.shapenet import ShapeNetDataset
    dataset = ShapeNetDataset(root_dir=config['data_dir'],
                              classes=config['classes'],
                              split=split)

    # Size the batch from the dataset that was actually built, so the count
    # follows the `split` argument instead of always reading the validation
    # name list.
    num_samples = len(dataset)
    data_loader = DataLoader(dataset,
                             batch_size=num_samples,
                             shuffle=False,
                             num_workers=4,
                             drop_last=False,
                             pin_memory=True)

    # A single batch holds the whole split.
    x, _ = next(iter(data_loader))
    x = x.to(device)

    # We average JSD computation from 3 independent trials.
    js_results = []
    for _ in range(3):
        # We take 3 times as many generated samples as there are reference
        # samples in order to perform the JSD calculation in the same manner
        # as in the reference publication.
        noise = prior_std * torch.randn(3 * num_samples, model.zdim)
        noise = noise.to(device)

        with torch.no_grad():
            x_g = model.decode(noise)
        # Bring a (..., 3, 2048) output into (..., 2048, 3) before comparing.
        if x_g.shape[-2:] == (3, 2048):
            x_g.transpose_(1, 2)

        js_results.append(jsd_between_point_cloud_sets(x, x_g, voxels=28))

    return np.mean(js_results)
def main(eval_config):
    """Compute the validation-set JSD for every saved epoch of a training run.

    The training configuration is read from ``config.json`` inside the
    experiment directory. For every epoch that has both an ``_E.pth`` and a
    ``_G.pth`` checkpoint (newest first), the generator is fed three
    independent noise draws, the JSD against the validation clouds is averaged
    over those draws, and the result is logged. Finally the epoch with the
    smallest JSD is logged. A ``KeyboardInterrupt`` stops the epoch loop and
    the results gathered so far are still summarised.

    Args:
        eval_config: Mapping with the keys ``results_root``, ``arch``,
            ``experiment_name``, ``cuda`` and ``gpu``; ``distribution`` is
            read from it only when the training configuration has none.

    Raises:
        ValueError: If the dataset named in the training configuration is not
            ``shapenet``, ``faust`` or ``mcgill``, or if the noise
            distribution is neither ``normal`` nor ``beta``.
    """
    # Load hyperparameters as they were during training
    train_results_path = join(eval_config['results_root'], eval_config['arch'],
                              eval_config['experiment_name'])
    with open(join(train_results_path, 'config.json')) as f:
        train_config = json.load(f)

    random.seed(train_config['seed'])
    torch.manual_seed(train_config['seed'])
    torch.cuda.manual_seed_all(train_config['seed'])
    # NOTE(review): NumPy's global RNG, which the `beta` branch below draws
    # from, is not seeded here -- confirm whether that is intended.

    setup_logging(join(train_results_path, 'results'))
    log = logging.getLogger(__name__)

    log.debug('Evaluating JensenShannon divergences on validation set on all '
              'saved epochs.')

    weights_path = join(train_results_path, 'weights')

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))
    log.debug(f'Testing epochs: {epochs}')

    device = cuda_setup(eval_config['cuda'], eval_config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = train_config['dataset'].lower()
    if dataset_name == 'shapenet':
        # Imported locally like the other dataset classes so this branch does
        # not depend on a module-level import.
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=train_config['data_dir'],
                                  classes=train_config['classes'],
                                  split='valid')
    elif dataset_name == 'faust':
        from datasets.dfaust import DFaustDataset
        dataset = DFaustDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'],
                                split='valid')
    elif dataset_name == 'mcgill':
        from datasets.mcgill import McGillDataset
        dataset = McGillDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'],
                                split='valid')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`faust` or `mcgill`. Got: `{dataset_name}`')
    classes_selected = ('all' if not train_config['classes'] else ','.join(
        train_config['classes']))
    log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} '
              f'samples.')

    # The training configuration takes precedence over the evaluation one.
    if 'distribution' in train_config:
        distribution = train_config['distribution']
    elif 'distribution' in eval_config:
        distribution = eval_config['distribution']
    else:
        log.warning(
            'No distribution type specified. Assumed normal = N(0, 0.2)')
        distribution = 'normal'

    # Fail fast: the noise buffer below is allocated without initialisation,
    # so any other distribution would push arbitrary memory through the
    # generator and yield meaningless JSD values.
    if distribution not in ('normal', 'beta'):
        raise ValueError(f'Invalid distribution. Expected `normal` or '
                         f'`beta`. Got: `{distribution}`')

    #
    # Models
    #
    arch = import_module(f"model.architectures.{train_config['arch']}")
    E = arch.Encoder(train_config).to(device)
    G = arch.Generator(train_config).to(device)

    E.eval()
    G.eval()

    num_samples = len(dataset.point_clouds_names_valid)
    data_loader = DataLoader(dataset,
                             batch_size=num_samples,
                             shuffle=False,
                             num_workers=4,
                             drop_last=False,
                             pin_memory=True)

    # We take 3 times as many samples as there are in test data in order to
    # perform JSD calculation in the same manner as in the reference publication
    # The buffer is uninitialised here and is filled on every trial below.
    noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1)
    noise = noise.to(device)

    # A single batch holds the whole validation split.
    X, _ = next(iter(data_loader))
    X = X.to(device)

    results = {}

    for epoch in reversed(epochs):
        try:
            # map_location lets checkpoints saved on a GPU be evaluated on
            # whichever device was selected above (including the CPU).
            E.load_state_dict(
                torch.load(join(weights_path, f'{epoch:05}_E.pth'),
                           map_location=device))
            G.load_state_dict(
                torch.load(join(weights_path, f'{epoch:05}_G.pth'),
                           map_location=device))

            start_clock = datetime.now()

            # We average JSD computation from 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(0, 0.2)
                else:  # 'beta' -- validated before the loop
                    noise_np = np.random.beta(train_config['z_beta_a'],
                                              train_config['z_beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    X_g = G(noise)
                # Bring a (..., 3, 2048) output into (..., 2048, 3).
                if X_g.shape[-2:] == (3, 2048):
                    X_g.transpose_(1, 2)

                jsd = jsd_between_point_cloud_sets(X, X_g, voxels=28)
                js_results.append(jsd)

            js_result = np.mean(js_results)
            log.debug(f'Epoch: {epoch} JSD: {js_result: .6f} '
                      f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            log.debug(f'Interrupted during epoch: {epoch}')
            break

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    if results.empty:
        # No checkpoints were found, or the run was interrupted before the
        # first epoch finished; idxmin() would raise on an empty frame.
        log.warning('No epoch was evaluated; skipping the minimum JSD report.')
    else:
        log.debug(f"Minimum JSD at epoch {results.idxmin()['jsd']}: "
                  f"{results.min()['jsd']: .6f}")