Beispiel #1
0
def main():
    parser = argparse.ArgumentParser(
        description=
        'Evaluate ProgressiveGAN w/ Hessian Penalty on Various Metrics')

    # Model/Dataset Parameters:
    parser.add_argument(
        '--model',
        required=True,
        help=
        'Either the number of experiment in results directory or a path to a .pkl checkpoint.'
    )
    parser.add_argument('--num_gpus',
                        type=int,
                        required=True,
                        help='Number of GPUs to evaluate with')
    parser.add_argument(
        '--snapshot_kimg',
        default='latest',
        help='network-snapshot-<snapshot_kimg>.pkl to evaluate')
    parser.add_argument(
        '--dataset',
        type=str,
        default='edges_and_shoes',
        help=
        'Name of TFRecords directory in datasets/ folder to run metrics with.')
    parser.add_argument('--resolution',
                        type=int,
                        default=128,
                        help='Resolution of real data to evaluate with.')
    parser.add_argument('--metrics',
                        nargs='+',
                        default=['FID', 'PPL'],
                        help='Metrics to run. Must specify at least one')

    opt = parser.parse_args()
    opt = EasyDict(vars(opt))
    assert opt.num_gpus in [1, 2, 4, 8]
    if os.path.isdir(
            opt.model
    ):  # If you pass a directory organized as DIRECTORY/dataset/model.pkl, it will iterate:git
        paths = sorted(glob(f'{opt.model}/*/*.pkl'))
        for path in paths:
            opt.model = path
            if 'clevr_simple' in path or 'clevr_u' in path:
                opt.dataset = 'clevr_simple'
            elif 'clevr_complex' in path:
                opt.dataset = 'clevr_two_obj'
            elif 'edgeshoes' in path:
                opt.dataset = 'edges_and_shoes'
            elif 'clevr_1fov' in path:
                opt.dataset = 'clevr_1fov'
            else:
                print(f'Couldn\'t find dataset for {path}')
                raise NotImplementedError
            run(opt)
    else:
        run(opt)