Code example #1
0
File: sampler.py  Project: mbencherif/neural_sp
    def reset(self, batch_size=None, epoch=None):
        """Reset the data counter and rebuild the batch indices.

        Args:
            batch_size (int): size of mini-batch; falls back to the
                sampler's configured batch size when None
            epoch (int): current epoch (used as the shuffling seed)

        """
        effective_bs = self.batch_size if batch_size is None else batch_size

        self._offset = 0

        # Rebuild the index buckets according to the configured strategy.
        if self.discourse_aware:
            self.indices_buckets = discourse_bucketing(self.df, effective_bs)
        elif self.longform_xmax > 0:
            self.indices_buckets = longform_bucketing(
                self.df, effective_bs, self.longform_xmax)
        elif self.shuffle_bucket:
            self.indices_buckets = shuffle_bucketing(
                self.df, effective_bs, self.dynamic_batching, seed=epoch)
        else:
            # Plain sequential sampling: keep raw indices and remember the
            # batch size for later batch assembly.
            self.indices = list(self.df.index)
            self.batch_size_tmp = effective_bs
Code example #2
0
    def __init__(self,
                 df,
                 batch_size,
                 dynamic_batching,
                 shuffle_bucket,
                 discourse_aware,
                 sort_stop_epoch,
                 df_sub1=None,
                 df_sub2=None,
                 longform_max_n_frames=0):
        """Custom BatchSampler.

        Args:
            df (pandas.DataFrame): dataframe for the main task
            batch_size (int): size of mini-batch
            dynamic_batching (bool): change batch size dynamically in training
            shuffle_bucket (bool): gather similar length of utterances and shuffle them
            discourse_aware (bool): sort in the discourse order
            sort_stop_epoch (int): After sort_stop_epoch, training will revert
                back to a random order
            df_sub1 (pandas.DataFrame): dataframe for the first sub task
            df_sub2 (pandas.DataFrame): dataframe for the second sub task
            longform_max_n_frames (int): maximum input length for long-form evaluation

        """
        self.df = df
        self.df_sub1 = df_sub1
        self.df_sub2 = df_sub2
        self.batch_size = batch_size

        self.dynamic_batching = dynamic_batching
        self.shuffle_bucket = shuffle_bucket
        self.sort_stop_epoch = sort_stop_epoch
        self.discourse_aware = discourse_aware
        self.longform_max_n_frames = longform_max_n_frames

        self._offset = 0

        # Pre-compute the batch index buckets (and hence the number of
        # iterations per epoch) according to the sampling strategy.
        if discourse_aware:
            self.indices_buckets = discourse_bucketing(self.df, batch_size)
            self._iteration = len(self.indices_buckets)
        elif longform_max_n_frames > 0:
            self.indices_buckets = longform_bucketing(self.df, batch_size,
                                                      longform_max_n_frames)
            self._iteration = len(self.indices_buckets)
        elif shuffle_bucket:
            self.indices_buckets = shuffle_bucketing(self.df, batch_size,
                                                     self.dynamic_batching)
            self._iteration = len(self.indices_buckets)
        else:
            # No bucketing: keep raw indices and calculate #iteration in advance.
            self.indices = list(self.df.index)
            self.calculate_iteration()
Code example #3
0
File: sampler.py  Project: mbencherif/neural_sp
    def __init__(self,
                 df,
                 batch_size,
                 dynamic_batching,
                 shuffle_bucket,
                 discourse_aware,
                 sort_stop_epoch,
                 longform_max_n_frames=0,
                 seed=1):
        """Custom BatchSampler.

        Args:
            df (pandas.DataFrame): dataframe for the main task
            batch_size (int): size of mini-batch
            dynamic_batching (bool): change batch size dynamically in training
            shuffle_bucket (bool): gather similar length of utterances and shuffle them
            discourse_aware (bool): sort in the discourse order
            sort_stop_epoch (int): After sort_stop_epoch, training will revert
                back to a random order
            longform_max_n_frames (int): maximum input length for long-form evaluation
            seed (int): seed for randomization

        """
        # Seed both RNGs so bucket shuffling is reproducible.
        random.seed(seed)
        np.random.seed(seed)

        self.df = df
        self.batch_size = batch_size
        self.batch_size_tmp = None

        self.dynamic_batching = dynamic_batching
        self.shuffle_bucket = shuffle_bucket
        self.sort_stop_epoch = sort_stop_epoch
        self.discourse_aware = discourse_aware
        self.longform_xmax = longform_max_n_frames

        self._offset = 0
        # NOTE: epoch should not be counted in BatchSampler

        # Select the bucketing strategy; None means plain sequential sampling.
        if discourse_aware:
            buckets = discourse_bucketing(df, batch_size)
        elif longform_max_n_frames > 0:
            buckets = longform_bucketing(df, batch_size, longform_max_n_frames)
        elif shuffle_bucket:
            buckets = shuffle_bucketing(df, batch_size,
                                        self.dynamic_batching, seed)
        else:
            buckets = None

        if buckets is not None:
            self.indices_buckets = buckets
            self._iteration = len(buckets)
        else:
            self.indices = list(df.index)
            # calculate #iteration in advance
            self.calculate_iteration()
Code example #4
0
File: asr.py  Project: qwjaskzxl/neural_sp
    def _reset(self, batch_size=None):
        """Reset the data counter and offset.

        Args:
            batch_size (int): size of mini-batch; defaults to the
                sampler's configured batch size

        """
        bs = batch_size if batch_size is not None else self.batch_size

        # Rebuild the batch indices for the selected sampling strategy.
        if self.discourse_aware:
            self.indices_buckets = discourse_bucketing(self.df, bs)
        elif self.shuffle_bucket:
            self.indices_buckets = shuffle_bucketing(
                self.df, bs, self.dynamic_batching)
        else:
            self.indices = list(self.df.index)
        self._offset = 0
Code example #5
0
File: sampler.py  Project: ishine/neural_sp
    def reset(self, batch_size=None, batch_size_type=None, epoch=0):
        """Reset the data counter and rebuild the batch index buckets.

        Args:
            batch_size (int): size of mini-batch; defaults to the configured one
            batch_size_type (str): type of batch size counting; defaults to
                the configured one
            epoch (int): current epoch (offsets the shuffling seed)

        """
        bs = self.batch_size if batch_size is None else batch_size
        bs_type = (self.batch_size_type if batch_size_type is None
                   else batch_size_type)

        self._offset = 0

        if self.shuffle_bucket:
            # Seed depends on the epoch so each epoch gets a new shuffle.
            buckets = shuffle_bucketing(self.df,
                                        bs,
                                        bs_type,
                                        self.dynamic_batching,
                                        seed=self.seed + epoch,
                                        num_replicas=self.num_replicas)
        elif self.discourse_aware:
            buckets = discourse_bucketing(self.df, bs)
        elif self.longform_xmax > 0:
            buckets = longform_bucketing(self.df, bs, self.longform_xmax)
        else:
            buckets = sort_bucketing(self.df,
                                     bs,
                                     bs_type,
                                     self.dynamic_batching,
                                     num_replicas=self.num_replicas)
        self.indices_buckets = buckets
        self._iteration = len(buckets)
Code example #6
0
File: sampler.py  Project: ishine/neural_sp
    def __init__(self,
                 dataset,
                 distributed,
                 batch_size,
                 batch_size_type,
                 dynamic_batching,
                 shuffle_bucket,
                 discourse_aware,
                 longform_max_n_frames=0,
                 seed=1,
                 resume_epoch=0):
        """Custom BatchSampler.

        Args:
            dataset (Dataset): pytorch Dataset class
            distributed (bool): sample across multiple replicas (delegates
                replica bookkeeping to the parent sampler class)
            batch_size (int): size of mini-batch
            batch_size_type (str): type of batch size counting
            dynamic_batching (bool): change batch size dynamically in training
            shuffle_bucket (bool): gather similar length of utterances and shuffle them
            discourse_aware (bool): sort in the discourse order
            longform_max_n_frames (int): maximum input length for long-form evaluation
            seed (int): seed for randomization
            resume_epoch (int): epoch to resume training

        """
        if distributed:
            # Parent __init__ sets rank/num_replicas/total_size from the
            # process group (presumably a DistributedSampler base — the class
            # header is not visible here; confirm).
            super().__init__(dataset=dataset,
                             num_replicas=dist.get_world_size(),
                             rank=dist.get_rank())
        else:
            # Single-process fallback: mimic the attributes the parent
            # would have set.
            self.rank = 0
            self.num_replicas = 1
            self.total_size = len(dataset.df.index) * self.num_replicas

        self.seed = seed
        # Seed both RNGs so bucket shuffling is reproducible.
        random.seed(seed)
        np.random.seed(seed)

        self.df = dataset.df
        # Scale the global batch size by the number of replicas so each
        # replica keeps the requested per-process batch size.
        self.batch_size = batch_size * self.num_replicas
        if self.num_replicas > 1 and self.rank == 0:
            logger.info(
                f"Batch size is automatically increased from {batch_size} to {self.batch_size}."
            )
        self.batch_size_type = batch_size_type
        self.dynamic_batching = dynamic_batching
        self.shuffle_bucket = shuffle_bucket
        self.discourse_aware = discourse_aware
        self.longform_xmax = longform_max_n_frames

        self._offset = 0
        # NOTE: epoch should not be counted in BatchSampler

        if shuffle_bucket:
            # Offset the seed by resume_epoch so resumed training does not
            # replay the first epoch's bucket order.
            self.indices_buckets = shuffle_bucketing(
                self.df,
                self.batch_size,
                batch_size_type,
                self.dynamic_batching,
                seed=seed + resume_epoch,
                num_replicas=self.num_replicas)
        elif discourse_aware:
            # NOTE(review): this branch asserts distributed while the
            # long-form branch below asserts the opposite — confirm intended.
            assert distributed
            self.indices_buckets = discourse_bucketing(self.df,
                                                       self.batch_size)
        elif longform_max_n_frames > 0:
            assert not distributed
            self.indices_buckets = longform_bucketing(self.df, self.batch_size,
                                                      longform_max_n_frames)
        else:
            self.indices_buckets = sort_bucketing(
                self.df,
                self.batch_size,
                batch_size_type,
                self.dynamic_batching,
                num_replicas=self.num_replicas)
        self._iteration = len(self.indices_buckets)