Example #1
0
    def setup_dataloader(self):
        """Create the train/valid dataloaders for LJSpeech.

        Splits the dataset into a validation subset and a training subset
        (sizes taken from ``config.data``), then wraps each in a DataLoader.
        In data-parallel mode the training loader draws whole batches
        through a DistributedBatchSampler so each process sees its own
        shard; otherwise a plain shuffled loader is used.
        """
        config = self.config
        ljspeech_dataset = LJSpeech(self.args.data)

        valid_set, train_set = dataset.split(ljspeech_dataset,
                                             config.data.valid_size)
        # Pads variable-length examples within a batch.
        collate = LJSpeechCollector(padding_idx=config.data.padding_idx)

        if self.parallel:
            batch_sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True)
            self.train_loader = DataLoader(train_set,
                                           batch_sampler=batch_sampler,
                                           collate_fn=collate)
        else:
            self.train_loader = DataLoader(train_set,
                                           batch_size=config.data.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=collate)

        # Validation keeps every example, in original order.
        self.valid_loader = DataLoader(valid_set,
                                       batch_size=config.data.batch_size,
                                       shuffle=False,
                                       drop_last=False,
                                       collate_fn=collate)
    def setup_dataloader(self):
        """Create dataloaders for clip-based training on LJSpeech.

        Training examples are cropped to a fixed number of spectrogram
        frames: the frames used to compute the loss plus the causal
        context frames required by the dilated convolution stack.
        Validation runs on whole utterances, one per batch.
        """
        config = self.config

        ljspeech_dataset = LJSpeech(self.args.data)
        valid_set, train_set = dataset.split(ljspeech_dataset,
                                             config.data.valid_size)

        # Causal padding (context) size of the convolutional net: each of
        # the n_stack stacks contributes (filter_size - 1) * 2**i samples
        # at dilation level i, plus one for the current step.
        context_size = 1 + config.model.n_stack * sum(
            (config.model.filter_size - 1) * 2**i
            for i in range(config.model.n_loop))
        context_frames = context_size // config.data.hop_length

        # Frames that actually contribute to the loss.
        frames_per_second = config.data.sample_rate // config.data.hop_length
        train_clip_frames = math.ceil(config.data.train_clip_seconds *
                                      frames_per_second)

        num_frames = train_clip_frames + context_frames
        train_collate = LJSpeechClipCollector(num_frames,
                                              config.data.hop_length)
        if self.parallel:
            batch_sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(train_set,
                                      batch_sampler=batch_sampler,
                                      collate_fn=train_collate)
        else:
            train_loader = DataLoader(train_set,
                                      batch_size=config.data.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=train_collate)

        # Full-length utterances for validation, batch of one.
        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  collate_fn=LJSpeechCollector())

        self.train_loader = train_loader
        self.valid_loader = valid_loader
Example #3
0
    def setup_dataloader(self):
        """Create train/valid dataloaders over a transformed LJSpeech.

        Each example is first passed through ``Transform`` (configured
        with the mel start/end values), then the dataset is split and
        wrapped in DataLoaders.  In parallel mode the training loader
        shards batches across processes via DistributedBatchSampler.
        """
        config = self.config

        base_dataset = LJSpeech(self.args.data)
        # Applies mel start/end values to every example before batching.
        transform = Transform(config.data.mel_start_value,
                              config.data.mel_end_value)
        transformed = dataset.TransformDataset(base_dataset, transform)

        valid_set, train_set = dataset.split(transformed,
                                             config.data.valid_size)
        collate = LJSpeechCollector(padding_idx=config.data.padding_idx)

        if self.parallel:
            batch_sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank(),
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(train_set,
                                      batch_sampler=batch_sampler,
                                      collate_fn=collate)
        else:
            train_loader = DataLoader(train_set,
                                      batch_size=config.data.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=collate)

        valid_loader = DataLoader(valid_set,
                                  batch_size=config.data.batch_size,
                                  collate_fn=collate)

        self.train_loader = train_loader
        self.valid_loader = valid_loader
Example #4
0
    def setup_dataloader(self):
        """Create dataloaders: fixed-length clips for training, whole
        utterances (one per batch) for validation.

        Training examples are cropped to ``config.data.clip_frames``
        frames by LJSpeechClipCollector; in parallel mode batches are
        sharded across processes with a DistributedBatchSampler.
        """
        config = self.config

        ljspeech = LJSpeech(self.args.data)
        valid_set, train_set = dataset.split(ljspeech,
                                             config.data.valid_size)

        # Crops every training example to a fixed number of frames.
        clip_collate = LJSpeechClipCollector(config.data.clip_frames,
                                             config.data.hop_length)

        if self.parallel:
            batch_sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank(),
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(train_set,
                                      batch_sampler=batch_sampler,
                                      collate_fn=clip_collate)
        else:
            train_loader = DataLoader(train_set,
                                      batch_size=config.data.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=clip_collate)

        # Validation uses full utterances, one at a time.
        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  collate_fn=LJSpeechCollector())

        self.train_loader = train_loader
        self.valid_loader = valid_loader