Ejemplo n.º 1
0
    def sample_index(self, batch_size):
        """Return the dataframe indices that form the next mini-batch.

        Bucket mode (``self.discourse_aware`` or ``self.shuffle_bucket``):
        the next pre-built bucket is popped from the front of
        ``self.indices_buckets``. Its order is randomized only in pure
        ``shuffle_bucket`` mode; ``discourse_aware`` takes precedence and
        keeps the bucket order.

        Slicing mode (otherwise): a contiguous slice of ``self.df`` starting
        at ``self._offset`` is taken. Its size is adjusted by
        ``set_batch_size`` from the first remaining row's ``xlen``/``ylen``.
        The slice is shuffled and removed from ``self.indices``.

        Args:
            batch_size (int or None): requested mini-batch size; ``None``
                falls back to ``self.batch_size``. Ignored in bucket mode.
        Returns:
            indices (list): dataframe indices in the current mini-batch
            is_new_epoch (bool): True when this mini-batch ends the epoch

        """
        if self.discourse_aware or self.shuffle_bucket:
            bucket = self.indices_buckets.pop(0)
            self._offset += len(bucket)
            epoch_done = not self.indices_buckets
            if not self.discourse_aware:
                # Pure shuffle_bucket mode: randomize order inside the bucket.
                bucket = random.sample(bucket, len(bucket))
            return bucket, epoch_done

        size = self.batch_size if batch_size is None else batch_size

        # The first remaining row drives dynamic batching (the lengths are
        # presumably minimal because df is length-sorted -- confirm).
        head = self.df[self._offset:self._offset + 1]
        size = set_batch_size(size, head['xlen'].values[0],
                              head['ylen'].values[0], self.dynamic_batching)

        if len(self.indices) > size:
            batch = list(self.df[self._offset:self._offset + size].index)
            self._offset += len(batch)
            epoch_done = False
        else:
            # Final mini-batch: take everything that is left.
            batch = self.indices[:]
            self._offset = len(self.df)
            epoch_done = True

        # Randomize utterance order inside the mini-batch.
        batch = random.sample(batch, len(batch))

        # Consume the sampled indices so they are not drawn again this epoch.
        for idx in batch:
            self.indices.remove(idx)

        return batch, epoch_done
Ejemplo n.º 2
0
    def sample_index(self):
        """Return the dataframe indices that form the next mini-batch.

        Bucket mode (``self.discourse_aware``, ``self.longform_xmax > 0`` or
        ``self.shuffle_bucket``): the next pre-built bucket is popped from
        the front of ``self.indices_buckets``. It is shuffled whenever
        ``self.shuffle_bucket`` is set.

        Slicing mode (otherwise): the base size is ``self.batch_size_tmp``
        when it is not None, else ``self.batch_size``. It is adjusted by
        ``set_batch_size`` from the first remaining row's ``xlen``/``ylen``.
        A contiguous slice of ``self.df`` from ``self._offset`` is taken,
        shuffled, and removed from ``self.indices``.

        Returns:
            indices (list): dataframe indices in the current mini-batch
                (in unshuffled bucket mode, the popped bucket object as-is)
            is_new_epoch (bool): True when this mini-batch ends the epoch

        """
        bucketed = (self.discourse_aware or self.longform_xmax > 0
                    or self.shuffle_bucket)
        if bucketed:
            chunk = self.indices_buckets.pop(0)
            self._offset += len(chunk)
            exhausted = not self.indices_buckets
            if self.shuffle_bucket:
                # Randomize utterance order inside the bucket.
                chunk = random.sample(chunk, len(chunk))
            return chunk, exhausted

        size = self.batch_size_tmp
        if size is None:
            size = self.batch_size

        # The first remaining row drives dynamic batching (the lengths are
        # presumably minimal because df is length-sorted -- confirm).
        head = self.df[self._offset:self._offset + 1]
        size = set_batch_size(size, head['xlen'].values[0],
                              head['ylen'].values[0], self.dynamic_batching)
        exhausted = (len(self.indices) <= size)

        if not exhausted:
            chunk = list(self.df[self._offset:self._offset + size].index)
            self._offset += len(chunk)
        else:
            # Final mini-batch: take everything that is left.
            chunk = self.indices[:]
            self._offset = len(self.df)

        # Randomize utterance order inside the mini-batch.
        chunk = random.sample(chunk, len(chunk))

        # Consume the sampled indices so they are not drawn again this epoch.
        for idx in chunk:
            self.indices.remove(idx)

        return chunk, exhausted