Пример #1
0
def make_train_loader(epoch_size=1100000, batch_size=64, shuffle=20000,
                      augment=True):
    """Build a DataLoader over the WebDataset training shards in ``trainurls``.

    Args:
        epoch_size: Passed to ``wds.Dataset`` as its ``length``.
        batch_size: Samples per batch. Batching happens inside the WebDataset
            pipeline, which is why the DataLoader uses ``batch_size=None``.
        shuffle: Size of the pipeline's shuffle buffer.
        augment: If True (default), use RandomResizedCrop(224) +
            RandomHorizontalFlip; if False, use the deterministic
            Resize(256) + CenterCrop(224) transform.

    Returns:
        A ``data.DataLoader`` with 4 workers yielding batches built from
        ``(image, cls)`` samples.
    """
    transforms = torchvision.transforms
    if augment:
        geometry = [transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip()]
    else:
        geometry = [transforms.Resize(256), transforms.CenterCrop(224)]
    image_transform = transforms.Compose(
        geometry + [transforms.ToTensor(), normalize])

    # A broken shard stops iteration (warn_and_stop); a bad sample is
    # skipped with a warning (warn_and_continue).
    dataset = (
        wds.Dataset(trainurls, handler=wds.warn_and_stop, length=epoch_size)
        .shuffle(shuffle)
        .decode("pil", handler=wds.warn_and_continue)
        .to_tuple("ppm;jpg;jpeg;png", "cls", handler=wds.warn_and_continue)
        .map_tuple(image_transform, identity, handler=wds.warn_and_continue)
        .batched(batch_size))

    loader = data.DataLoader(dataset, batch_size=None, num_workers=4)
    return loader
Пример #2
0
def test_torch_map_dict_decoder():
    """Decode the image and mask fields through ``map_dict`` and check shapes."""

    def _pil_array(payload, mode):
        # Fully load the image while the byte stream is still open, then
        # convert it to the requested PIL mode.
        with io.BytesIO(payload) as buf:
            picture = Image.open(buf)
            picture.load()
            picture = picture.convert(mode)
        return np.asarray(picture)

    def image_decoder(data):
        # HWC -> CHW; np.array makes a contiguous copy of the transposed view.
        chw = np.array(_pil_array(data, "RGB").transpose(2, 0, 1))
        return torch.tensor(chw) / 255.0

    def mask_decoder(data):
        return torch.tensor(_pil_array(data, "L"))

    renamed = wds.Dataset(test_data).rename(image="rgb.png", mask="msk.png")
    decoded = renamed.map_dict(image=image_decoder, mask=mask_decoder)
    pairs = decoded.to_tuple("image", "mask")

    image, mask = next(iter(pairs))
    assert (image.shape, mask.shape) == ((3, 512, 512), (512, 512))
Пример #3
0
    def __init__(self,
                 manifest_path: str,
                 tar_filepaths: Union[str, List[str]],
                 shuffle_n: int = 128):
        """Load the manifest and open a shuffled WebDataset over tarred audio.

        Args:
            manifest_path: Manifest file, indexed by file ID.
            tar_filepaths: A list of tar paths, or a single brace-pattern
                string in which '(' '[' '<' '_OP_' are rewritten to '{' and
                ')' ']' '>' '_CL_' to '}'.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
        """
        text_parser = parsers.make_parser([])
        self._manifest = collections.ASRAudioText(manifest_path,
                                                  parser=text_parser,
                                                  index_by_file_id=True)

        if isinstance(tar_filepaths, str):
            # All opening spellings are substituted first, then the closing ones.
            substitutions = [(key, "{") for key in ('(', '[', '<', '_OP_')]
            substitutions += [(key, "}") for key in (')', ']', '>', '_CL_')]
            for old, new in substitutions:
                tar_filepaths = tar_filepaths.replace(old, new)

        pipeline = wd.Dataset(tar_filepaths).shuffle(shuffle_n)
        pipeline = pipeline.rename(audio='wav', key='__key__')
        self.audio_dataset = pipeline.to_tuple('audio', 'key')
        self.audio_iter = iter(self.audio_dataset)
Пример #4
0
    def setup(self, stage=None):
        """Create the train, validation and test WebDataset pipelines.

        ``data_path`` is ``self.args["train_glob"]`` (default
        "/pvc/output/processing"). Each split's shards are addressed as
        ``<data_path>/<split>/<split>-{0..K}.tar``, where K is what
        ``self.get_num_files`` returns for that split directory. Nothing is
        downloaded here; the shards are only referenced.

        Args:
            stage: Stage - training or testing (not used by this method).
        """
        data_path = self.args.get("train_glob", "/pvc/output/processing")

        def split_dir(split):
            return data_path + "/" + split

        def shard_range(split, upper):
            # NOTE(review): "{0..K}" includes K, so this assumes get_num_files
            # returns the last shard index rather than the count — confirm.
            return split_dir(split) + "/" + split + "-{0.." + str(upper) + "}.tar"

        def build(url, length, image_key, transform, batch):
            # PIL decode -> rename fields -> transform image -> tuple -> batch.
            return (wds.Dataset(url, handler=wds.warn_and_continue,
                                length=length)
                    .shuffle(100)
                    .decode("pil")
                    .rename(image=image_key, info="cls")
                    .map_dict(image=transform)
                    .to_tuple("image", "info")
                    .batched(batch))

        # Shard counts are looked up in train, val, test order.
        train_n = self.get_num_files(split_dir("train"))
        val_n = self.get_num_files(split_dir("val"))
        test_n = self.get_num_files(split_dir("test"))

        self.train_dataset = build(shard_range("train", train_n), 40000 // 40,
                                   "ppm;jpg;jpeg;png", self.train_transform,
                                   40)
        # NOTE(review): validation/test accept only ".ppm" images, while
        # training accepts ppm/jpg/jpeg/png — confirm this is intended.
        self.valid_dataset = build(shard_range("val", val_n), 10000 // 20,
                                   "ppm", self.valid_transform, 20)
        self.test_dataset = build(shard_range("test", test_n), 10000 // 20,
                                  "ppm", self.valid_transform, 20)
Пример #5
0
def KineticsSounds(cfg, split):
    """Build the KineticsSounds WebDataset pipeline for one data split.

    Args:
        cfg: Config node. Reads DATASET_ROOT, STORAGE_SAS_KEY, TRAIN.*,
            VAL.DATASET_SIZE, TEST.*, SOLVER.GRADIENT_ACCUMULATION_STEPS and
            DATA_LOADER.NUM_WORKERS.
        split: One of 'train', 'val' or 'test'.

    Returns:
        A ``wds.ResizedDataset`` over the decoded samples of this rank.

    Raises:
        ValueError: If ``split`` is not 'train', 'val' or 'test'.

    Side effect: replaces the module attribute ``wds.filters.batched`` with a
    version that collates using ``COLLATE_FN["kinetics"]`` (process-wide).
    """
    # Index of the last shard of each split; shards are numbered from 000000.
    last_shard = {'train': 19, 'val': 1, 'test': 2}
    if split not in last_shard:
        raise ValueError(
            f"split must be one of 'train', 'val' or 'test', got {split!r}")
    max_idx = last_shard[split]

    dataset_root = cfg.DATASET_ROOT
    if dataset_root.endswith('/'):
        dataset_root = dataset_root[:-1]
    url = f"{dataset_root}/KineticsSounds/shards-{split}/shard-{{000000..{max_idx:06d}}}.tar"
    if cfg.STORAGE_SAS_KEY:
        # Appended verbatim; presumably a storage query string — confirm.
        url += cfg.STORAGE_SAS_KEY

    _decoder = Decoder(cfg, "KineticsSounds", split)
    world_size = du.get_world_size()

    # Per-rank batch size, per-rank sample count (length) and per-rank
    # batch count (nominal).
    if split == 'train':
        batch_size = int(cfg.TRAIN.BATCH_SIZE /
                         cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS)
        batch_size = int(batch_size / world_size)
        length = int(cfg.TRAIN.DATASET_SIZE / world_size)
        nominal = int(length / batch_size)
    elif split == 'val':
        # NOTE(review): validation uses TRAIN.BATCH_SIZE — confirm intended.
        batch_size = int(cfg.TRAIN.BATCH_SIZE / world_size)
        length = int(cfg.VAL.DATASET_SIZE / world_size)
        nominal = int(length / batch_size)
    else:  # 'test': sizes are rounded up instead of truncated.
        batch_size = int(cfg.TEST.BATCH_SIZE / world_size)
        length = math.ceil(cfg.TEST.DATASET_SIZE / world_size)
        nominal = math.ceil(length / batch_size)

    wds.filters.batched = wds.filters.Curried(
        partial(wds.filters.batched_, collation_fn=COLLATE_FN["kinetics"]))

    dataset = wds.Dataset(
        url,
        handler=wds.warn_and_continue,
        shard_selection=du.shard_selection,
        length=length,
    )
    if split == 'train':
        dataset = dataset.shuffle(100)
    dataset = dataset.map_dict(
        handler=wds.warn_and_continue,
        mp4=_decoder.mp4decode,
        json=_decoder.jsondecode,
    )
    # With worker processes both sizes are the batch count; without workers
    # both are the sample count.
    if cfg.DATA_LOADER.NUM_WORKERS > 0:
        length = nominal
    else:
        nominal = length
    dataset = wds.ResizedDataset(
        dataset,
        length=length,
        nominal=nominal,
    )
    return dataset
Пример #6
0
def main(args):
    """Fine-tune a checkpointed VQ-VAE together with an MLP entropy coder.

    Loads ``two_q_vqvae_017.pt``, builds both models on CUDA (wrapped in
    ``_CustomDataParallel`` when more than one GPU is visible), and runs
    ``train`` for ``args.epoch`` epochs.

    Args:
        args: Namespace providing at least ``lr`` and ``epoch``.
    """
    device = "cuda"

    preproc = tf.Compose([tf.Resize(256), tf.CenterCrop(256), tf.ToTensor()])

    url = '../data_celeba_tar/train_{0..162}.tar'
    dataset = (wds.Dataset(
        url, length=162000 // 32).shuffle(500).decode("pil").to_tuple(
            "input.jpg", "sensitive.cls").map_tuple(preproc,
                                                    identity).batched(32))

    # Batching happens inside the WebDataset pipeline, hence batch_size=None.
    loader = DataLoader(dataset, batch_size=None, num_workers=16)

    model = VQVAE(cout=30)

    checkpoints = "/scratch/xgitiaux/checkpoint/vqvae/two_q_vqvae_017.pt"

    logger.info(f'Loading checkpoint {checkpoints}')
    checkpoint = torch.load(checkpoints, map_location='cpu')

    # State dicts saved from a DataParallel-wrapped model prefix every key
    # with "module."; strip it only where present so unwrapped checkpoints
    # load unchanged instead of losing their first 7 characters.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint.items())

    model.load_state_dict(new_state_dict)

    model = model.to(device)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        model = _CustomDataParallel(model)

    entropy_coder = MLP(32 * 32, depth=3, width=256).to(device)

    if n_gpus > 1:
        logger.info(f'Number of gpu is {n_gpus}')
        entropy_coder = _CustomDataParallel(entropy_coder)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    poptimizer = optim.Adam(entropy_coder.parameters(), lr=args.lr)
    scheduler = None

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device, entropy_coder,
              poptimizer)
Пример #7
0
def make_val_loader(epoch_size=50000, batch_size=64):
    """Build a DataLoader over the validation shards listed in ``valurls``.

    Args:
        epoch_size: Not used by this function.
        batch_size: Samples per batch. Batching is done inside the WebDataset
            pipeline, so the DataLoader itself uses ``batch_size=None``.

    Returns:
        A ``data.DataLoader`` with 4 workers over batches of ``(image, cls)``.
    """
    tv = torchvision.transforms
    val_transform = tv.Compose(
        [tv.Resize(256), tv.CenterCrop(224), tv.ToTensor(), normalize])

    # Shard errors stop iteration; per-sample errors are skipped with a warning.
    skip_bad = wds.warn_and_continue
    pipeline = wds.Dataset(valurls, handler=wds.warn_and_stop,
                           prepare_for_worker=False)
    pipeline = pipeline.decode("pil", handler=skip_bad)
    pipeline = pipeline.to_tuple("ppm;jpg;jpeg;png", "cls", handler=skip_bad)
    pipeline = pipeline.map_tuple(val_transform, identity, handler=skip_bad)
    pipeline = pipeline.batched(batch_size)

    return data.DataLoader(pipeline, batch_size=None, num_workers=4)
Пример #8
0
def load_gqn_dataset(name,
                     batch_size,
                     seed=42,
                     shuffle=False,
                     target_transform=transform_viewpoint,
                     max_samples_per_environment=-1):
    """Build a batched WebDataset pipeline for one GQN dataset split.

    Args:
        name: Dataset name including the split; parsed by ``split_name``.
        batch_size: Number of samples per batch.
        seed: Seed of the single ``random.Random`` shared by the shard
            shuffler, ``sample_environment``, the sample shuffle and
            ``transform_batch``.
        shuffle: If True, shuffle shards and samples (buffer of 10000,
            starting after 1000 samples).
        target_transform: Transform passed to ``transform_batch`` for every
            batch.
        max_samples_per_environment: Forwarded to ``sample_environment`` as
            ``max_samples_per_env``.

    Returns:
        The resulting pipeline object.

    Raises:
        AssertionError: For an unsupported dataset name or a split other than
            'train'/'test'.
    """
    dataset_name, split = split_name(name)
    assert dataset_name in _DATASET_INFO, f'Dataset {dataset_name} is not supported'
    assert split in ['test', 'train'], f'Split {split} is not supported'
    dataset_info = _DATASET_INFO[dataset_name]
    size = dataset_info[f'{split}_size']
    url = os.path.join(
        DATASET_PATH, f'{dataset_name}-wd',
        f'{dataset_name}-{split}-{{000001..{size:06d}}}-of-{size:06d}.tar')
    sample_size = dataset_info['max_num_views'] + 1
    environment_size = dataset_info['sequence_size']
    dataset = wds.Dataset(url)
    # One seeded RNG drives every random step so the pipeline is reproducible.
    rng = random.Random(seed)
    dataset.rng = rng
    dataset.reseed_hook = dataset.reseed_rng
    if shuffle:
        dataset.shard_shuffle = wds.dataset.Shuffler(rng)
    dataset = dataset.pipe(
        partial(sample_environment,
                sample_size=sample_size,
                environment_size=environment_size,
                shuffle=shuffle,
                rng=rng,
                max_samples_per_env=max_samples_per_environment))
    if shuffle:
        dataset = dataset.pipe(
            wds.filters.shuffle(10000, rng=rng, initial=1000))
    dataset = dataset.to_tuple('camera.pth', 'image.jpg')
    dataset = dataset.batched(batch_size)
    # Honour the caller's target_transform instead of the hard-coded default.
    dataset = dataset.pipe(
        partial(transform_batch, rng=rng,
                target_transform=target_transform))
    return dataset
Пример #9
0
def distribute_remaining_data(data_path, subset, total_instance, chunks):
    """Re-save the trailing shards of ``subset`` for ``total_instance`` instances.

    Takes the last ``len(chunks) % total_instance`` chunk ids (or the last
    ``total_instance`` ids when that remainder is 0), reads them through the
    brace pattern ``<subset>-{first..last}.tar``, decodes them with the
    "format" stored in ``<data_path>/metadata.json``, and passes the result to
    ``save_remaining_data`` with the output pattern
    ``<data_path>/distributed/<total_instance>-instances/<subset>-%06d.tar``.

    Args:
        data_path: Directory holding the shards and ``metadata.json``.
        subset: Shard file-name prefix.
        total_instance: Number of instances the data is distributed over.
        chunks: Ordered sequence of shard ids; must be non-empty (an empty
            sequence raises IndexError).
    """
    remaining_chunks = len(chunks) % total_instance
    if remaining_chunks == 0:
        remaining_chunks = total_instance
    remaining_start, remaining_end = chunks[-remaining_chunks], chunks[-1]
    chunk_str = f"{remaining_start}..{remaining_end}"
    dataset = wds.Dataset(
        os.path.join(data_path, subset + "-{" + chunk_str + "}.tar"))

    # Determine the saving format; JSON is UTF-8, independent of the locale.
    with open(os.path.join(data_path, "metadata.json"),
              encoding="utf-8") as metadata_file:
        data_format = json.load(metadata_file)["format"]
    dataset = decode_webdataset(dataset,
                                data_format,
                                identity,
                                use_bbox_info=True)

    folder = os.path.join(data_path, "distributed",
                          str(total_instance) + "-instances")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(folder, exist_ok=True)
    distributed_save_path = os.path.join(folder, subset + "-%06d.tar")
    save_remaining_data(dataset, distributed_save_path, total_instance)
Пример #10
0
    def train_dataloader(self):
        """Return the training DataLoader.

        When ``self.imagenet`` is None or "", samples come from the WebDataset
        shards in ``self.trainurls``; otherwise from torchvision's ImageNet
        "train" split rooted at ``self.imagenet``. That choice is the only
        difference between the two paths. Either source is wrapped in
        ``wds.ResizedDataset`` with length ``self.epoch``, and the DataLoader
        does the batching.
        """
        if self.imagenet not in [None, ""]:
            source = torchvision.datasets.ImageNet(
                self.imagenet, split="train", transform=image_transform)
            resized = wds.ResizedDataset(source, self.epoch)
        else:
            skip_bad = wds.warn_and_continue
            source = wds.Dataset(self.trainurls, handler=skip_bad)
            source = source.shuffle(5000)
            source = source.decode("pil", handler=skip_bad)
            source = source.to_tuple("ppm;jpg;jpeg;png", "cls")
            source = source.map_tuple(image_transform, identity)
            # Ceiling division: batches needed to cover self.epoch samples.
            # NOTE(review): this is a batch count, yet the DataLoader below
            # batches again — confirm what ResizedDataset's nominal expects.
            batches_per_epoch = (self.epoch + self.batch_size - 1) // self.batch_size
            resized = wds.ResizedDataset(source, self.epoch,
                                         nominal=batches_per_epoch)

        return data.DataLoader(
            resized, batch_size=self.batch_size, num_workers=self.num_workers
        )
Пример #11
0
import torch
import torchvision
import webdataset as wds
from itertools import islice
import tempfile

# WebDataset tar file read by the demo loop of this script.
url = "testoutput.tar"


def mp4decode(data):
    """Decode raw MP4 bytes using ``torchvision.io.read_video``.

    The bytes are written to ``sample.mp4`` in a temporary directory (removed
    again before returning) and decoded from that path.

    Returns:
        The ``(vframes, aframes, info)`` triple returned by ``read_video``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        video_path = workdir + "/sample.mp4"
        with open(video_path, "wb") as sink:
            sink.write(data)
        decoded = torchvision.io.read_video(video_path)
    vframes, aframes, info = decoded
    return vframes, aframes, info


# Default decoding, then the "mp4" field is replaced by its decoded frames.
dataset = wds.Dataset(url).decode().map_dict(mp4=mp4decode)

# Show field names, frame-tensor shapes and metadata for the first 3 samples.
for record in islice(dataset, 3):
    print("---")
    print(list(record.keys()))
    video, audio, meta = record["mp4"]
    print(video.shape, audio.shape)
    print(meta)
Пример #12
0
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        load_audio: bool = True,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Build a WebDataset pipeline over tarred audio paired with a label manifest.

        Args:
            audio_tar_filepaths: List of tar paths, or one brace-pattern string
                in which '(' '[' '<' '_OP_' become '{' and ')' ']' '>' '_CL_'
                become '}'.
            manifest_filepath: One or more manifest paths, comma-separated.
            labels: Label vocabulary; if empty/None the manifest collection's
                ``uniq_labels`` are used.
            featurizer: Stored as ``self.featurizer``.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
            min_duration: Forwarded to ``ASRSpeechLabel``.
            max_duration: Forwarded to ``ASRSpeechLabel``.
            trim: Stored as ``self.trim``.
            load_audio: Stored as ``self.load_audio``.
            shard_strategy: 'scatter' splits shards across ranks, 'replicate'
                gives each rank every shard. Only applied when world_size > 1.
            global_rank: Rank of this process (used by 'scatter').
            world_size: Number of distributed processes.

        Raises:
            ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
        """
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=
            True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim
        self.load_audio = load_audio

        # Fall back to the labels found in the manifest.
        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        # Bidirectional label <-> index mappings, indexed by position in labels.
        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        # Log (at debug level) the mapping of at most the first 5 labels.
        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                # Contiguous equal-sized slice per rank; the trailing
                # len % world_size shards are not assigned to any rank.
                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                # Unreachable in practice: shard_strategy was validated above.
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset: shuffle, expose 'wav' as 'audio' and the
        # sample key as 'key', then filter and build samples via self methods.
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
Пример #13
0
    def __init__(
        self,
        audio_tar_filepaths,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        shuffle_n=0,
        num_workers=0,
        augmentor: Optional[Union[AudioAugmentor,
                                  Dict[str, Dict[str, Any]]]] = None,
    ):
        """Build a WebDataset pipeline over tarred audio with manifest transcripts.

        Args:
            audio_tar_filepaths: Tar path pattern string or list of tar paths.
            manifest_filepath: One or more manifest paths, comma-separated.
            labels: Character labels passed to ``make_parser``.
            batch_size: Stored as ``self._batch_size``.
            sample_rate: Sample rate given to the ``WaveformFeaturizer``.
            int_values: Forwarded to the ``WaveformFeaturizer``.
            bos_id: Stored as ``self.bos_id``.
            eos_id: Stored as ``self.eos_id``.
            pad_id: Token pad value for ``seq_collate_fn``; None means 0.
            min_duration: Forwarded to ``ASRAudioText``.
            max_duration: Forwarded to ``ASRAudioText``.
            normalize_transcripts: Forwarded to ``make_parser`` as
                ``do_normalize``.
            trim_silence: Stored as ``self.trim``.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
            num_workers: Stored as ``self._num_workers``.
            augmentor: Augmentor instance or config dict; when not None it is
                passed through ``_process_augmentations``.

        When ``torch.distributed`` is initialized, the shard list is brace
        expanded and each rank keeps a contiguous, equal-sized slice of it.
        """
        super().__init__()
        self._sample_rate = sample_rate

        if augmentor is not None:
            augmentor = _process_augmentations(augmentor)

        self.collection = ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=make_parser(labels=labels,
                               name='en',
                               do_normalize=normalize_transcripts),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=
            True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=self._sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)

        self.trim = trim_silence
        self.eos_id = eos_id
        self.bos_id = bos_id

        # Used in creating a sampler (in Actions).
        self._batch_size = batch_size
        self._num_workers = num_workers
        pad_id = 0 if pad_id is None else pad_id
        self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id)

        # Check for distributed and partition shards accordingly
        if torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

            if isinstance(audio_tar_filepaths, str):
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            # The trailing len % world_size shards are not assigned to any rank.
            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]

        # Put together WebDataset: shuffle, expose 'wav' as 'audio' and the
        # sample key as 'key', then filter and build samples via self methods.
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
Пример #14
0
def test_dataset_shuffle_extract():
    """Shuffling plus tuple extraction still yields all 64 samples."""
    shuffled = wds.Dataset(test_data).shuffle(5)
    pairs = shuffled.to_tuple("msk.png rgb.png")
    assert count_samples_tuple(pairs) == 64
Пример #15
0
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        add_misc: bool = False,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Build a WebDataset pipeline over tarred audio with manifest transcripts.

        Args:
            audio_tar_filepaths: List of tar paths, or one brace-pattern string
                in which '(' '[' '<' '_OP_' become '{' and ')' ']' '>' '_CL_'
                become '}'.
            manifest_filepath: One or more manifest paths, comma-separated.
            parser: Transcript parser passed to ``ASRAudioText``.
            sample_rate: Sample rate given to the ``WaveformFeaturizer``.
            int_values: Forwarded to the ``WaveformFeaturizer``.
            augmentor: Forwarded to the ``WaveformFeaturizer``.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
            min_duration: Forwarded to ``ASRAudioText``.
            max_duration: Forwarded to ``ASRAudioText``.
            max_utts: Forwarded to ``ASRAudioText`` as ``max_number``.
            trim: Stored as ``self.trim``.
            bos_id: Stored as ``self.bos_id``.
            eos_id: Stored as ``self.eos_id``.
            add_misc: Stored as ``self._add_misc``.
            pad_id: Stored as ``self.pad_id``.
            shard_strategy: 'scatter' splits shards across ranks, 'replicate'
                gives each rank every shard. Only applied when world_size > 1.
            global_rank: Rank of this process (used by 'scatter').
            world_size: Number of distributed processes.

        Raises:
            ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
        """
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            index_by_file_id=
            True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self._add_misc = add_misc

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                # Contiguous equal-sized slice per rank; the trailing
                # len % world_size shards are not assigned to any rank.
                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                # Unreachable in practice: shard_strategy was validated above.
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset: shuffle, expose 'wav' as 'audio' and the
        # sample key as 'key', then filter and build samples via self methods.
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
Пример #16
0
def _file_iter_to_line_iter(jsonl_iter):
    for jsonl in jsonl_iter:
        lines = jsonl["jsonl"].split(b"\n")
        for line in lines:
            if not line:
                continue
            json_line = json.loads(line)
            json_line["binary"] = jsonl["__key__"]
            yield json_line


if __name__ == "__main__":
    print(sys.argv[1])
    urls = sorted(glob.glob(sys.argv[1]))
    dataset = wds.Dataset(urls).pipe(_file_iter_to_line_iter)
    dataset = torch.utils.data.DataLoader(dataset,
                                          num_workers=8,
                                          batch_size=None)
    uniq_code = set()
    uniq_binary = set()
    token_len = []
    num_vars = []

    def tokenlen(example):
        return len(example["code_tokens"])

    def num_var(example):
        return len(example["source"])

    def name(example):
Пример #17
0
def main(args):
    """Train a VQ-VAE and an MLP entropy coder on the CelebA data in ``../data_vae``.

    Both models are built on CUDA and wrapped in ``_CustomDataParallel`` when
    more than one GPU is visible. After each epoch the VQ-VAE state dict is
    saved to ``/scratch/xgitiaux/checkpoint/vqvae/two_q_vqvae_<NNN>.pt``.

    Args:
        args: Namespace providing at least ``lr`` and ``epoch``.
    """
    device = "cuda"
    checkpoint_dir = "/scratch/xgitiaux/checkpoint/vqvae"

    dataset = CelebA('../data_vae')
    loader = DataLoader(dataset, batch_size=64, num_workers=16, drop_last=True)

    model = VQVAE(cout=30).to(device)

    if torch.cuda.device_count() > 1:
        logger.info(f'Number of gpu is {torch.cuda.device_count()}')
        model = _CustomDataParallel(model)

    entropy_coder = MLP(32 * 32, depth=3, width=256).to(device)

    if torch.cuda.device_count() > 1:
        logger.info(f'Number of gpu is {torch.cuda.device_count()}')
        entropy_coder = _CustomDataParallel(entropy_coder)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    poptimizer = optim.Adam(entropy_coder.parameters(), lr=args.lr)
    scheduler = None

    for i in range(args.epoch):
        # NOTE(review): entropy_coder fills two consecutive positional slots
        # of train(); confirm against train()'s signature.
        train(i, loader, model, optimizer, scheduler, device, entropy_coder,
              entropy_coder, poptimizer)

        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
            model.state_dict(),
            f"{checkpoint_dir}/two_q_vqvae_{str(i + 1).zfill(3)}.pt"
        )
Пример #18
0
def sample_batch():
    """Return element 0 (the images) of the first batch of ``bs`` samples."""
    pipeline = wds.Dataset(test_data).map(semsegment_decoder)
    pipeline = pipeline.rename(image="rgb.png", mask="msk.png")
    batches = pipeline.to_tuple("image", "mask").batched(bs)
    first_batch = next(iter(batches))
    return first_batch[0]
Пример #19
0
def test_rename():
    """After renaming, a sample holds exactly the ``image`` and ``mask`` keys."""
    renamed = wds.Dataset(test_data).rename(image="rgb.png", mask="msk.png")
    first = next(iter(renamed))
    expected_keys = {"image", "mask"}
    assert set(first.keys()) == expected_keys
Пример #20
0
def test_slice():
    """``slice(10)`` limits the dataset to 10 samples."""
    limited = wds.Dataset(test_data).slice(10)
    assert count_samples_tuple(limited) == 10
Пример #21
0
def test_dataset_pipe_cat():
    """Reading the shard through a ``pipe:cat`` URL yields all 64 samples."""
    source = f"pipe:cat {test_data}"
    shuffled = wds.Dataset(source).shuffle(5)
    pairs = shuffled.to_tuple("msk.png rgb.png")
    assert count_samples_tuple(pairs) == 64
Пример #22
0
def test_multi():
    """MultiDataset over k copies of the local shard yields 47 * k samples."""
    for copies in (1, 4, 17):
        # The trailing shell comment "# i" keeps each copy's URL distinct.
        shard_urls = [f"pipe:cat {local_data} # {idx}"
                      for idx in range(copies)]
        pipeline = wds.Dataset(shard_urls).decode().shuffle(5)
        pipeline = pipeline.to_tuple("png;jpg cls")
        combined = multi.MultiDataset(pipeline, workers=4)
        assert count_samples_tuple(combined) == 47 * copies
Пример #23
0
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        encoder_tokenizer: str,
        decoder_tokenizer: str,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        reverse_lang_direction: bool = False,
    ):
        """Build a WebDataset pipeline over tarred, pickled translation data.

        Args:
            text_tar_filepaths: List of tar paths, or one brace-pattern string
                in which '(' '[' '<' '_OP_' become '{' and ')' ']' '>' '_CL_'
                become '}' before brace expansion.
            metadata_path: JSON file loaded into ``self.metadata``.
            encoder_tokenizer: Source tokenizer; its ``pad_id`` is read.
            decoder_tokenizer: Target tokenizer; its ``pad_id`` is read.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
            shard_strategy: 'scatter' gives each rank a contiguous,
                equal-sized slice of the shards (trailing shards that do not
                divide evenly are dropped); 'replicate' gives every rank all
                shards.
            global_rank: Rank of this process (used by 'scatter').
            world_size: Number of distributed processes; sharding is applied
                only when it is greater than 1.
            reverse_lang_direction: Stored as ``self.reverse_lang_direction``.

        Raises:
            ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
        """
        super(TarredTranslationDataset, self).__init__()

        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.reverse_lang_direction = reverse_lang_direction
        self.src_pad_id = encoder_tokenizer.pad_id
        self.tgt_pad_id = decoder_tokenizer.pad_id

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Map alternative bracket spellings to braces: all openers
            # ('(', '[', '<', '_OP_') first, then all closers.
            for bkey in ['(', '[', '<', '_OP_']:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
            for bkey in [')', ']', '>', '_CL_']:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")
            # Brace expand into an explicit list of shard paths.
            text_tar_filepaths = list(
                braceexpand.braceexpand(text_tar_filepaths))

        # Partition only in a real distributed run. With world_size 0 (the
        # default) the modulo/division below would raise ZeroDivisionError,
        # and with world_size 1 every shard belongs to the single rank anyway.
        if world_size > 1:
            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )
                if len(text_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")
                begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
                end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
                logging.info('Begin Index : %d', begin_idx)
                logging.info('End Index : %d', end_idx)
                text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)
            else:  # 'replicate' (shard_strategy was validated above)
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

        self.tarpath = text_tar_filepaths

        # Put together WebDataset: (pickled payload, sample key) pairs mapped
        # through self._build_sample.
        self._dataset = (
            wd.Dataset(text_tar_filepaths).shuffle(shuffle_n).rename(
                pkl='pkl',
                key='__key__').to_tuple('pkl',
                                        'key').map(f=self._build_sample))
Пример #24
0
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        tokenizer,
        max_seq_length: int = 512,
        batch_step: int = None,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Build a WebDataset pipeline over tarred ``.npy`` language-model data.

        Args:
            text_tar_filepaths: List of tar paths, or one brace-pattern string
                in which '(' '[' '<' '_OP_' become '{' and ')' ']' '>' '_CL_'
                become '}'.
            metadata_path: JSON file loaded into ``self.metadata``.
            tokenizer: Stored as ``self.tokenizer``.
            max_seq_length: Stored as ``self.max_seq_length``.
            batch_step: Stored as ``self.batch_step``; a falsy value falls back
                to ``max_seq_length``.
            shuffle_n: Shuffle buffer size of the WebDataset pipeline.
            shard_strategy: 'scatter' gives each rank a contiguous,
                equal-sized slice of the shards (trailing shards that do not
                divide evenly are dropped); 'replicate' gives every rank all
                shards.
            global_rank: Rank of this process (used by 'scatter').
            world_size: Number of distributed processes; sharding is applied
                only when it is greater than 1.

        Raises:
            ValueError: If ``shard_strategy`` is not 'scatter' or 'replicate'.
        """
        super(TarredL2RLanguageModelingDataset, self).__init__()

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.batch_step = batch_step or self.max_seq_length

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Map alternative bracket spellings to braces: all openers
            # ('(', '[', '<', '_OP_') first, then all closers.
            for bkey in ['(', '[', '<', '_OP_']:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")
            for bkey in [')', ']', '>', '_CL_']:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

        # Partition only in a real distributed run. With world_size 0 (the
        # default) the modulo/division below would raise ZeroDivisionError.
        if world_size > 1:
            if isinstance(text_tar_filepaths, str):
                # Expand the pattern so shards, not characters, are counted
                # and sliced below.
                import braceexpand

                text_tar_filepaths = list(
                    braceexpand.braceexpand(text_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(text_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
                end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
                text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)
            else:  # 'replicate' (shard_strategy was validated above)
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

        self.tarpath = text_tar_filepaths

        # Put together WebDataset: (npy payload, sample key) pairs mapped
        # through self._build_sample.
        self._dataset = (
            wd.Dataset(text_tar_filepaths).shuffle(shuffle_n).rename(
                npy='npy',
                key='__key__').to_tuple('npy',
                                        'key').map(f=self._build_sample))
Пример #25
0
def test_dataset():
    """A plain Dataset over ``test_data`` yields all 64 samples."""
    plain = wds.Dataset(test_data)
    assert count_samples_tuple(plain) == 64