def subprocess_fn(rank, args, temp_dir):
    """Per-process worker: set up (optionally distributed) torch state and
    evaluate every metric in ``args.metrics`` on a copy of ``args.G``.

    Args:
        rank:     Index of this process within the process group.
        args:     Options object carrying G, metrics, dataset_kwargs,
                  num_gpus, verbose, run_dir and network_pkl.
        temp_dir: Scratch directory used for the file-based
                  torch.distributed rendezvous when num_gpus > 1.
    """
    dnnlib.util.Logger(should_flush=True)

    # Initialize torch.distributed through a file-based rendezvous.
    if args.num_gpus > 1:
        init_file = os.path.abspath(
            os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows: NCCL is unavailable, so fall back to gloo and use
            # the file:/// URL form with forward slashes.
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(backend=backend,
                                             init_method=init_method,
                                             rank=rank,
                                             world_size=args.num_gpus)

    # Initialize torch_utils helpers; silence custom-op build chatter on
    # every process except a verbose rank 0.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0 or not args.verbose:
        custom_ops.verbosity = 'none'

    # Configure backends (TF32 off for reproducible metrics) and move a
    # frozen copy of the generator onto this process's GPU.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)

    # Only a verbose rank 0 prints anything.
    chatty = (rank == 0 and args.verbose)
    if chatty:
        dummy_z = torch.empty([1, G.z_dim], device=device)
        dummy_c = torch.empty([1, G.c_dim], device=device)
        misc.print_module_summary(G, [dummy_z, dummy_c])

    # Evaluate each requested metric; rank 0 writes the results out.
    for metric in args.metrics:
        if chatty:
            print(f'Calculating {metric}...')
        result_dict = metric_main.calc_metric(
            metric=metric,
            G=G,
            dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=metric_utils.ProgressMonitor(verbose=args.verbose))
        if rank == 0:
            metric_main.report_metric(result_dict,
                                      run_dir=args.run_dir,
                                      snapshot_pkl=args.network_pkl)
        if chatty:
            print()

    if chatty:
        print('Exiting...')
# ---------------------------------------------------------------------------
# Beispiel #2 (Example #2)
# ---------------------------------------------------------------------------
def evaluate(Gs, snapshot_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log,
        logger = None, run_dir = None, print_progress = True):
    """Compute every metric in ``metrics`` for generator ``Gs``.

    When ``log`` is truthy, results are reported to ``run_dir``; when a
    ``logger`` is supplied, its metrics dict is updated with the results.
    NOTE(review): ``print_progress`` is accepted but never read here —
    kept only for caller compatibility.
    """
    for metric in metrics:
        monitor = metric_utils.ProgressMonitor(verbose = log)
        result_dict = metric_main.compute_metric(
            metric = metric,
            max_items = eval_images_num,
            G = Gs,
            dataset_args = dataset_args,
            num_gpus = num_gpus,
            rank = rank,
            device = device,
            progress = monitor)
        if log:
            metric_main.report_metric(result_dict, run_dir = run_dir, snapshot_pkl = snapshot_pkl)
        if logger is not None:
            logger.metrics.update(result_dict.results)