Example #1
0
def get_data(config, split):
    """Build one split of the ``CachedNTUGems`` dataset with standard preprocessing.

    Pipeline, in order: restrict to the indices of ``split``, apply
    ``TargetProcessing`` (generic preprocessing, in particular resizing and
    masking), center-crop, and finally attach extra labels via ``add_labels``.

    Args:
        config: Dict-like configuration; only ``"spatial_size"`` is read,
            falling back to 256 when it is absent.
        split: Key selecting an index collection from the result of
            ``get_split_indices``.

    Returns:
        An ``ExtraLabelsDataset`` wrapping the processed, cropped subset.
    """
    full_dataset = CachedNTUGems()
    split_indices = get_split_indices(full_dataset)[split]
    subset = SubDataset(full_dataset, split_indices)

    spatial_size = config.get("spatial_size", 256)
    # generic preprocessing (in particular resizing and masking)
    target_processing = TargetProcessing(spatial_size, 1)
    processed = ProcessedDataset(subset, target_processing)

    cropped = ProcessedDataset(processed, center_crop)

    # attach the labels required downstream
    return ExtraLabelsDataset(cropped, add_labels)
Example #2
0
    def __init__(self, root, show_bar=False):
        """Load the model-output csv and the labels of an evaluation folder.

        Args:
            root: Either the output directory, or a path whose string contains
                ``"model_output.csv"``. In the latter case it is used as the
                csv path, and its parent directory is used as the folder.
                Labels are always loaded from the ``model_outputs``
                subdirectory of that folder.
            show_bar: Accepted for interface compatibility; it is not used by
                this method.

        Sets:
            self.labels: Result of ``load_labels`` on the ``model_outputs``
                subdirectory.
            self.data: ``ProcessedDataset`` over the csv rows. If reading the
                csv raises ``pd.errors.EmptyDataError``, it is an
                ``EmptyDataset`` sized from the first label entry instead.
            self.append_labels: Always ``True``.
        """
        # NOTE(review): ``EvalReader`` is built from ``root`` *before* ``root``
        # is reduced to a directory below. When a csv path is passed in, the
        # reader therefore sees the csv path. Confirm that this is intended.
        er = EvalReader(root)

        if "model_output.csv" not in root:
            csv_path = os.path.join(root, "model_output.csv")
        else:
            csv_path = root
            root = os.path.dirname(root)

        self.labels = load_labels(os.path.join(root, "model_outputs"))

        # Capture the case that only labels have been written out. Presumably
        # ``CsvDataset`` raises pandas' ``EmptyDataError`` for an empty csv.
        # Keep only that call inside the ``try``.
        try:
            csv_data = CsvDataset(csv_path, comment="#")
        except pd.errors.EmptyDataError:
            # Size the placeholder dataset from the label entry whose key
            # sorts first. This raises IndexError if no labels were loaded.
            exemplar_labels = self.labels[sorted(self.labels.keys())[0]]
            self.data = EmptyDataset(len(exemplar_labels), self.labels)
        else:
            self.data = ProcessedDataset(csv_data, er)

        self.append_labels = True
Example #3
0
    if 'model_output.csv' in root:
        root = root[:-len('model_output.csv')]
    savename_score = os.path.join(save_dir, 'score.txt')
    #savename_std = os.path.join(save_dir, 'std.txt')

    fid_score = np.array(fids).mean()
    #fid_std = np.array(fids).std()
    with open(savename_score, 'w+') as f:
        f.write(str(fid_score))
    #with open(savename_std, 'w+') as f:
    #    f.write(str(fid_std))
    print('\nFID SCORE: {:.2f}'.format(fid_score))
    return {"scalars": {"fid": fid_score}}


if __name__ == "__main__":
    # Smoke test: run ``fid`` on two debug datasets whose examples are
    # replaced by the same constant image.
    from edflow.debug import DebugDataset
    from edflow.data.dataset import ProcessedDataset

    D1 = DebugDataset(size=100)
    D2 = DebugDataset(size=100)

    def constant_image(*args, **kwargs):
        """Ignore the incoming example; return an all-ones 256x256x3 image."""
        return {'image': np.ones([256, 256, 3])}

    D1 = ProcessedDataset(D1, constant_image)
    D2 = ProcessedDataset(D2, constant_image)

    print(D1[0])

    # NOTE(review): the meaning of these positional arguments depends on the
    # ``fid`` signature, which is not visible here.
    fid('.', D1, D2, {})