Example #1
    def test_json_write_and_load(self):
        dict_to_write = dict(a=1, b=2, c=3)
        output_file = os.path.join(self.log_dir, 'output.json')
        common.write_to_json(dict_to_write, output_file)
        check = common.load_from_json(output_file)

        assert dict_to_write == check
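
The helpers under test are not shown on this page. A minimal sketch of what common.write_to_json and common.load_from_json might look like, assuming they are thin wrappers around the standard json module (the real implementations may differ):

import json

def write_to_json(dict_to_write, output_file):
    # Serialise a dictionary to a JSON file.
    with open(output_file, 'w') as json_file:
        json.dump(dict_to_write, json_file, indent=4)

def load_from_json(json_file):
    # Load a dictionary back from a JSON file.
    with open(json_file, 'r') as file:
        return json.load(file)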
Example #2
    def _log_params(self, params):
        """
        Saves the argument options into a json file, raising an error if they
        differ from a previously saved training configuration.
        """
        params_file = os.path.join(self.log_dir, 'params.json')

        # Check for discrepancy with previous training config.
        if os.path.exists(params_file):
            check = common.load_from_json(params_file)

            if params != check:
                diffs = []
                for k in params:
                    if k in check and params[k] != check[k]:
                        diffs.append('{}: Expected {} but got {}.'.format(
                            k, check[k], params[k]))

                diff_string = '\n'.join(diffs)
                raise ValueError(
                    "Current hyperparameter configuration is different from previously:\n{}"
                    .format(diff_string))

        common.write_to_json(params, params_file)
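
As a hypothetical illustration of the check above (trainer is an assumed object exposing _log_params and a log_dir), a changed value between runs surfaces as a per-key diff in the ValueError:

trainer._log_params({'lr': 2e-4, 'beta1': 0.0})  # first run: writes params.json
trainer._log_params({'lr': 2e-4, 'beta1': 0.0})  # same config: passes silently
trainer._log_params({'lr': 1e-4, 'beta1': 0.0})
# ValueError: Current hyperparameter configuration is different from previously:
# lr: Expected 0.0002 but got 0.0001.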
Example #3
def evaluate(metric,
             netG,
             log_dir,
             evaluate_range=None,
             evaluate_step=None,
             num_runs=3,
             start_seed=0,
             overwrite=False,
             write_to_json=True,
             device=None,
             **kwargs):
    """
    Evaluates a generator over several runs.

    Args:
        metric (str): The name of the metric for evaluation.
        netG (Module): Torch generator model to evaluate.
        log_dir (str): The path to the log directory.
        evaluate_range (tuple): A (start, end, step) tuple of checkpoint steps to evaluate.
        evaluate_step (int): The specific checkpoint to load. Used in place of evaluate_range.
        device (str): Device identifier to use for computation.
        num_runs (int): The number of runs to compute the metric for each checkpoint.
        start_seed (int): Starting random seed to use.
        write_to_json (bool): If True, writes to an output json file in log_dir.
        overwrite (bool): If True, then overwrites previous metric score.

    Returns:
        dict: Dictionary of scores for each checkpoint step.
    """
    # Check evaluation range/steps
    if (evaluate_range and evaluate_step) or not (evaluate_range
                                                  or evaluate_step):
        raise ValueError(
            "Exactly one of evaluate_step or evaluate_range must be defined.")

    if evaluate_range:
        if (type(evaluate_range) is not tuple
                or not all(type(x) is int for x in evaluate_range)
                or len(evaluate_range) != 3):
            raise ValueError(
                "evaluate_range must be a tuple of ints (start, end, step).")

    # Check metric arguments
    if metric == 'kid':
        if 'num_samples' not in kwargs:
            raise ValueError(
                "num_samples must be provided for KID computation.")

        output_file = os.path.join(
            log_dir, 'kid_{}k.json'.format(kwargs['num_samples'] // 1000))

    elif metric == 'fid':
        if 'num_real_samples' not in kwargs or 'num_fake_samples' not in kwargs:
            raise ValueError(
                "num_real_samples and num_fake_samples must be provided for FID computation."
            )

        output_file = os.path.join(
            log_dir,
            'fid_{}k_{}k.json'.format(kwargs['num_real_samples'] // 1000,
                                      kwargs['num_fake_samples'] // 1000))

    elif metric == 'inception_score':
        if 'num_samples' not in kwargs:
            raise ValueError(
                "num_samples must be provided for IS computation.")

        output_file = os.path.join(
            log_dir,
            'inception_score_{}k.json'.format(kwargs['num_samples'] // 1000))

    else:
        choices = ['fid', 'kid', 'inception_score']
        raise ValueError("Invalid metric {} selected. Choose from {}.".format(
            metric, choices))

    # Check checkpoint dir
    ckpt_dir = os.path.join(log_dir, 'checkpoints', 'netG')
    if not os.path.exists(ckpt_dir):
        raise ValueError(
            "Checkpoint directory {} cannot be found in log_dir.".format(
                ckpt_dir))

    # Check device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Setup output file
    if os.path.exists(output_file):
        scores_dict = common.load_from_json(output_file)
        scores_dict = {int(k): v for k, v in scores_dict.items()}

    else:
        scores_dict = {}

    # Decide naming convention
    names_dict = {
        'fid': 'FID',
        'inception_score': 'Inception Score',
        'kid': 'KID',
    }

    # Evaluate across a range
    start, end, interval = evaluate_range or (evaluate_step, evaluate_step,
                                              evaluate_step)
    for step in range(start, end + 1, interval):
        # Skip computed scores
        if step in scores_dict and write_to_json and not overwrite:
            print("INFO: {} at step {} has been computed. Skipping...".format(
                names_dict[metric], step))
            continue

        # Load and restore the model checkpoint
        ckpt_file = os.path.join(ckpt_dir, 'netG_{}_steps.pth'.format(step))
        if not os.path.exists(ckpt_file):
            print("INFO: Checkpoint at step {} does not exist. Skipping...".
                  format(step))
            continue
        netG.restore_checkpoint(ckpt_file=ckpt_file, optimizer=None)

        # Compute score for each seed
        scores = []
        for seed in range(start_seed, start_seed + num_runs):
            print("INFO: Computing {} in memory...".format(names_dict[metric]))

            # Obtain only the raw score without var
            if metric == "fid":
                score = compute_fid.fid_score(netG=netG,
                                              seed=seed,
                                              device=device,
                                              split='train',
                                              log_dir=log_dir,
                                              **kwargs
                                              )

            elif metric == "inception_score":
                score, _ = compute_is.inception_score(netG=netG,
                                                      seed=seed,
                                                      device=device,
                                                      log_dir=log_dir,
                                                      **kwargs)

            elif metric == "kid":
                score, _ = compute_kid.kid_score(netG=netG,
                                                 device=device,
                                                 seed=seed,
                                                 split='train',
                                                 log_dir=log_dir,
                                                 **kwargs)

            scores.append(score)
            print("INFO: {} (step {}) [seed {}]: {}".format(
                names_dict[metric], step, seed, score))

        scores_dict[step] = scores

        # Save scores every step
        if write_to_json:
            common.write_to_json(scores_dict, output_file)

    # Print the scores in order
    for step in range(start, end + 1, interval):
        if step in scores_dict:
            scores = scores_dict[step]
            mean = np.mean(scores)
            std = np.std(scores)

            print("INFO: {} (step {}): {} (± {}) ".format(
                names_dict[metric], step, mean, std))

    # Save to output file
    if write_to_json:
        common.write_to_json(scores_dict, output_file)

    print("INFO: {} Evaluation completed!".format(names_dict[metric]))

    return scores_dict
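
A hypothetical call of evaluate for FID over a range of checkpoints; the paths and sample counts below are placeholders, and any extra keyword arguments are forwarded to compute_fid.fid_score:

scores = evaluate(metric='fid',
                  netG=netG,  # an already constructed generator (placeholder)
                  log_dir='./log/example_run',
                  evaluate_range=(10000, 100000, 10000),
                  num_runs=3,
                  num_real_samples=50000,
                  num_fake_samples=50000)
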
def evaluate_pr(netG,
                log_dir,
                evaluate_range=None,
                evaluate_step=None,
                num_runs=1,
                start_seed=0,
                overwrite=False,
                write_to_json=True,
                device=None,
                **kwargs):
    """
    Computes precision and recall.

    Args:
        netG (Module): Torch generator model to evaluate.
        log_dir (Path): The path to the log directory.
        evaluate_range (tuple): A (start, end, step) tuple of checkpoint steps to evaluate.
        evaluate_step (int): The specific checkpoint to load. Used in place of evaluate_range.
        device (str): Device identifier to use for computation.
        num_runs (int): The number of runs to compute precision and recall for each checkpoint.
        start_seed (int): Starting random seed to use.
        write_to_json (bool): If True, writes to an output json file in log_dir.
        overwrite (bool): If True, then overwrites previous metric score.

    Returns:
        dictionary: precision, recall score dictionary.
    """
    # Check evaluation range/steps
    if (evaluate_range and evaluate_step) or not (evaluate_range
                                                  or evaluate_step):
        raise ValueError(
            "Exactly one of evaluate_step or evaluate_range must be defined.")

    if evaluate_range:
        if (type(evaluate_range) is not tuple
                or not all(type(x) is int for x in evaluate_range)
                or len(evaluate_range) != 3):
            raise ValueError(
                "evaluate_range must be a tuple of ints (start, end, step).")

    output_log_dir = log_dir / 'evaluate' / f'step-{evaluate_step}'
    output_log_dir.mkdir(parents=True, exist_ok=True)
    # Check metric arguments
    if 'num_real_samples' not in kwargs or 'num_fake_samples' not in kwargs:
        raise ValueError(
            "num_real_samples and num_fake_samples must be provided for PR computation."
        )

    output_file = os.path.join(
        output_log_dir,
        'pr_{}k_{}k.json'.format(kwargs['num_real_samples'] // 1000,
                                 kwargs['num_fake_samples'] // 1000))

    # Check checkpoint dir
    ckpt_dir = os.path.join(log_dir, 'checkpoints', 'netG')
    if not os.path.exists(ckpt_dir):
        raise ValueError(
            "Checkpoint directory {} cannot be found in log_dir.".format(
                ckpt_dir))

    # Check device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Setup output file
    if os.path.exists(output_file):
        scores_dict = common.load_from_json(output_file)
        scores_dict = {int(k): v for k, v in scores_dict.items()}

    else:
        scores_dict = {}

    # Evaluate across a range
    start, end, interval = evaluate_range or (evaluate_step, evaluate_step,
                                              evaluate_step)
    for step in range(start, end + 1, interval):
        # Skip computed scores
        if step in scores_dict and write_to_json and not overwrite:
            print("INFO: PR at step {} has been computed. Skipping...".format(
                step))
            continue

        # Load and restore the model checkpoint
        ckpt_file = os.path.join(ckpt_dir, 'netG_{}_steps.pth'.format(step))
        if not os.path.exists(ckpt_file):
            print("INFO: Checkpoint at step {} does not exist. Skipping...".
                  format(step))
            continue
        netG.restore_checkpoint(ckpt_file=ckpt_file, optimizer=None)

        # Compute score for each seed
        scores = defaultdict(list)
        for seed in range(start_seed, start_seed + num_runs):
            print("INFO: Computing PR in memory...")

            # Obtain only the raw score without var
            score = pr_score(netG=netG,
                             seed=seed,
                             device=device,
                             log_dir=log_dir,
                             **kwargs)
            for key in score:
                scores[key].append(score[key])
                print("INFO: {} (step {}) [seed {}]: {}".format(
                    key, step, seed, score[key]))

        scores_dict[step] = scores

        # Save scores every step
        if write_to_json:
            common.write_to_json(scores_dict, output_file)

    # Print the scores in order
    for step in range(start, end + 1, interval):
        if step in scores_dict:
            for key in scores_dict[step]:
                scores = scores_dict[step][key]
                mean = np.mean(scores)
                std = np.std(scores)

                print("INFO: {} (step {}): {} (± {}) ".format(
                    key, step, mean, std))

    # Save to output file
    if write_to_json:
        common.write_to_json(scores_dict, output_file)

    print("INFO: PR Evaluation completed!")

    return scores_dict
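
evaluate_pr is analogous, except that log_dir must be a pathlib.Path (it is combined with the / operator) and each checkpoint's entry holds one list per precision/recall key. A hypothetical call with placeholder values:

from pathlib import Path

pr_scores = evaluate_pr(netG=netG,  # an already constructed generator (placeholder)
                        log_dir=Path('./log/example_run'),
                        evaluate_step=100000,
                        num_real_samples=10000,
                        num_fake_samples=10000)
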
def evaluate_custom(metric,
                    netG,
                    log_dir,
                    num_runs=1,
                    start_seed=0,
                    overwrite=False,
                    write_to_json=True,
                    device=None,
                    **kwargs):
    """
    Evaluates a generator as currently loaded, without restoring any
    checkpoint. Takes the same arguments as evaluate(), minus the checkpoint
    range, except that log_dir must be a pathlib.Path. All scores are stored
    under a single pseudo-step 0.

    Returns:
        dict: Metric scores keyed by step 0.
    """
    output_log_dir = log_dir / 'evaluate' / 'custom'
    output_log_dir.mkdir(parents=True, exist_ok=True)
    # Check metric arguments
    if metric == 'kid':
        if 'num_samples' not in kwargs:
            raise ValueError(
                "num_samples must be provided for KID computation.")

        output_file = os.path.join(
            output_log_dir,
            'kid_{}k.json'.format(kwargs['num_samples'] // 1000))

    elif metric == 'fid':
        if 'num_real_samples' not in kwargs or 'num_fake_samples' not in kwargs:
            raise ValueError(
                "num_real_samples and num_fake_samples must be provided for FID computation."
            )

        output_file = os.path.join(
            output_log_dir,
            'fid_{}k_{}k.json'.format(kwargs['num_real_samples'] // 1000,
                                      kwargs['num_fake_samples'] // 1000))

    elif metric == 'inception_score':
        if 'num_samples' not in kwargs:
            raise ValueError(
                "num_samples must be provided for IS computation.")

        output_file = os.path.join(
            output_log_dir,
            'inception_score_{}k.json'.format(kwargs['num_samples'] // 1000))

    else:
        choices = ['fid', 'kid', 'inception_score']
        raise ValueError("Invalid metric {} selected. Choose from {}.".format(
            metric, choices))

    # Check device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Setup output file
    if os.path.exists(output_file):
        scores_dict = common.load_from_json(output_file)
        scores_dict = {int(k): v for k, v in scores_dict.items()}

    else:
        scores_dict = {}

    # Decide naming convention
    names_dict = {
        'fid': 'FID',
        'inception_score': 'Inception Score',
        'kid': 'KID',
    }

    step = 0
    # Compute score for each seed
    scores = []
    for seed in range(start_seed, start_seed + num_runs):
        print("INFO: Computing {} in memory...".format(names_dict[metric]))

        # Obtain only the raw score without var
        if metric == "fid":
            score = compute_fid.fid_score(netG=netG,
                                          seed=seed,
                                          device=device,
                                          log_dir=log_dir,
                                          **kwargs)

        elif metric == "inception_score":
            score, _ = compute_is.inception_score(netG=netG,
                                                  seed=seed,
                                                  device=device,
                                                  log_dir=log_dir,
                                                  **kwargs)

        elif metric == "kid":
            score, _ = compute_kid.kid_score(netG=netG,
                                             device=device,
                                             seed=seed,
                                             log_dir=log_dir,
                                             **kwargs)

        scores.append(score)
        print("INFO: {} (step {}) [seed {}]: {}".format(
            names_dict[metric], step, seed, score))

    scores_dict[step] = scores

    # Save scores every step
    if write_to_json:
        common.write_to_json(scores_dict, output_file)

    # Print the scores in order
    if step in scores_dict:
        scores = scores_dict[step]
        mean = np.mean(scores)
        std = np.std(scores)

        print("INFO: {} (step {}): {} (± {}) ".format(names_dict[metric], step,
                                                      mean, std))

    # Save to output file
    if write_to_json:
        common.write_to_json(scores_dict, output_file)

    print("INFO: {} Evaluation completed!".format(names_dict[metric]))

    return scores_dict
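
evaluate_custom scores the generator as currently loaded rather than restoring checkpoints, storing everything under step 0. A hypothetical call, again with placeholder values and log_dir as a pathlib.Path:

from pathlib import Path

is_scores = evaluate_custom(metric='inception_score',
                            netG=netG,  # generator with weights already loaded (placeholder)
                            log_dir=Path('./log/example_run'),
                            num_runs=3,
                            num_samples=50000)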