def _process_dataset(self, dataset, hparams, data_spec):
    name_prefix = PairedTextData._get_name_prefix(
        hparams["source_dataset"], hparams["target_dataset"])
    tran_fn, data_spec = self._make_processor(
        hparams["source_dataset"], hparams["target_dataset"],
        data_spec, name_prefix=name_prefix)

    num_parallel_calls = hparams["num_parallel_calls"]
    dataset = dataset.map(
        lambda *args: tran_fn(dsutils.maybe_tuple(args)),
        num_parallel_calls=num_parallel_calls)

    # Filters by length
    src_length_name = dsutils._connect_name(
        data_spec.name_prefix[0],
        data_spec.decoder[0].length_tensor_name)
    tgt_length_name = dsutils._connect_name(
        data_spec.name_prefix[1],
        data_spec.decoder[1].length_tensor_name)
    filter_fn = self._make_length_filter(
        hparams["source_dataset"], hparams["target_dataset"],
        src_length_name, tgt_length_name,
        data_spec.decoder[0], data_spec.decoder[1])
    if filter_fn:
        dataset = dataset.filter(filter_fn)

    # Truncates data count
    dataset = dataset.take(hparams["max_dataset_size"])

    return dataset, data_spec
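
# `dsutils._connect_name` above derives the flattened feature names for the
# source and target length tensors (e.g. a "source_text" prefix joined with a
# "length" tensor name). The helper below is only a hedged sketch of that
# naming convention for illustration; the exact joining rule lives in dsutils.
def _connect_name_sketch(prefix, name):
    """Joins a prefix and a field name, e.g. ("src", "length") -> "src_length"."""
    if not prefix:
        return name
    if not name:
        return prefix
    return "{}_{}".format(prefix, name)
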
def _process_dataset(self, dataset, hparams, data_spec):
    name_prefix = self._get_name_prefix(hparams["datasets"])
    # pylint: disable=attribute-defined-outside-init
    self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

    tran_fn, data_spec = self._make_processor(
        hparams["datasets"], data_spec, name_prefix)

    num_parallel_calls = hparams["num_parallel_calls"]
    dataset = dataset.map(
        lambda *args: tran_fn(dsutils.maybe_tuple(args)),
        num_parallel_calls=num_parallel_calls)

    # Filters by length
    def _get_length_name(i):
        if not _is_text_data(hparams["datasets"][i]["data_type"]):
            return None
        name = dsutils._connect_name(
            data_spec.name_prefix[i],
            data_spec.decoder[i].length_tensor_name)
        return name

    filter_fn = self._make_length_filter(
        hparams["datasets"],
        [_get_length_name(i) for i in range(len(hparams["datasets"]))],
        data_spec.decoder)
    if filter_fn:
        dataset = dataset.filter(filter_fn)

    # Truncates data count
    dataset = dataset.take(hparams["max_dataset_size"])

    return dataset, data_spec
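
# `_get_length_name` returns None for non-text sub-datasets, so the combined
# length filter has to skip those entries. A minimal sketch of such a
# predicate, assuming TF 1.x and hypothetical per-dataset maximum lengths; the
# actual filter is built internally by `_make_length_filter` from the hparams.
import tensorflow as tf

def _multi_length_filter_sketch(length_names, max_lengths):
    """ANDs a length bound for every text sub-dataset; ignores None entries."""
    def _filter_fn(data):
        keep = tf.constant(True)
        for name, max_len in zip(length_names, max_lengths):
            if name is None or max_len is None:
                continue
            keep = tf.logical_and(keep, tf.less_equal(data[name], max_len))
        return keep
    return _filter_fn
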
def _process_dataset(self, dataset, hparams, data_spec):
    chained_tran, data_spec = self._make_processor(
        hparams["dataset"], data_spec,
        name_prefix=hparams["dataset"]["data_name"])
    num_parallel_calls = hparams["num_parallel_calls"]
    dataset = dataset.map(
        lambda *args: chained_tran(dsutils.maybe_tuple(args)),
        num_parallel_calls=num_parallel_calls)

    # Truncates data count
    dataset = dataset.take(hparams["max_dataset_size"])

    return dataset, data_spec
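
# `Dataset.map` calls its function with the element unpacked into positional
# arguments, which is why the lambda re-packs `args`. `dsutils.maybe_tuple` is
# assumed to collapse a single-element tuple back to the bare element; a
# hedged sketch of that behavior:
def _maybe_tuple_sketch(args):
    """("x",) -> "x";  ("x", "y") -> ("x", "y")."""
    args = tuple(args)
    return args if len(args) > 1 else args[0]
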
def _make_batch(dataset, hparams, element_length_func,
                padded_shapes=None, padding_values=None):
    dataset = dataset.repeat(hparams.num_epochs)
    batch_size = hparams["batch_size"]
    bucket_boundaries = hparams["bucket_boundaries"]
    if padded_shapes is None:
        padded_shapes = dataset.output_shapes

    if len(bucket_boundaries) == 0:
        if hparams["allow_smaller_final_batch"]:
            dataset = dataset.padded_batch(
                batch_size, padded_shapes, padding_values=padding_values)
        else:
            dataset = dataset.apply(
                tf.contrib.data.padded_batch_and_drop_remainder(
                    batch_size, padded_shapes,
                    padding_values=padding_values))
    else:
        bucket_batch_size = hparams["bucket_batch_sizes"]
        if bucket_batch_size is None:
            bucket_batch_size = [batch_size] * (len(bucket_boundaries) + 1)
        dataset = dataset.apply(
            tf.contrib.data.bucket_by_sequence_length(
                element_length_func, bucket_boundaries, bucket_batch_size,
                padded_shapes=padded_shapes,
                padding_values=padding_values))
        if not hparams["allow_smaller_final_batch"]:
            if len(set(bucket_batch_size)) > 1:
                raise ValueError(
                    "Batch size of every bucket must be the same if "
                    "smaller final batch is not allowed.")
            batch_size = bucket_batch_size[0]
            filter_fn = dsutils._make_smaller_batch_filter_fn(batch_size)
            dataset = dataset.filter(
                lambda *args: filter_fn(dsutils.maybe_tuple(args)))

    return dataset
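
# A toy, self-contained illustration of the bucketing branch above, assuming
# TF 1.x (tf.contrib is used in the function itself): elements are routed into
# buckets by length and each bucket is padded-batched separately. The
# boundaries and batch sizes below are made up for the example.
import tensorflow as tf

def _bucketing_example():
    dataset = tf.data.Dataset.from_generator(
        lambda: ([1] * n for n in [3, 4, 7, 8, 12, 13]),
        output_types=tf.int32, output_shapes=[None])
    dataset = dataset.apply(
        tf.contrib.data.bucket_by_sequence_length(
            element_length_func=lambda x: tf.size(x),
            bucket_boundaries=[5, 10],      # three buckets: <5, 5..9, >=10
            bucket_batch_sizes=[2, 2, 2],   # one batch size per bucket
            padded_shapes=[None]))
    return dataset
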
def _process_dataset(self, dataset, hparams, data_spec):
    chained_tran, data_spec = self._make_processor(
        hparams["dataset"], data_spec,
        name_prefix=hparams["dataset"]["data_name"])
    num_parallel_calls = hparams["num_parallel_calls"]
    dataset = dataset.map(
        lambda *args: chained_tran(dsutils.maybe_tuple(args)),
        num_parallel_calls=num_parallel_calls)

    # Filters by length
    length_name = dsutils._connect_name(
        data_spec.name_prefix,
        data_spec.decoder.length_tensor_name)
    filter_fn = self._make_length_filter(
        hparams["dataset"], length_name, data_spec.decoder)
    if filter_fn:
        dataset = dataset.filter(filter_fn)

    # Truncates data count
    dataset = dataset.take(hparams["max_dataset_size"])

    return dataset, data_spec
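
# As in the paired and multi-aligned cases, the filter returned by
# `_make_length_filter` is a predicate over the feature dict. A minimal sketch
# of one possible single-dataset form, assuming TF 1.x and a hypothetical
# `max_seq_length` bound taken from the dataset hparams:
import tensorflow as tf

def _length_filter_sketch(length_name, max_seq_length):
    """Keeps an example only if its length tensor does not exceed the bound."""
    def _filter_fn(data):
        return tf.less_equal(data[length_name], max_seq_length)
    return _filter_fn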