def BuildDataSource(self, data_source_from_file_pattern_fn):
  """Read and return input batch, switching datasources by global step.

  Stage ``i`` (for ``i < len(p.boundaries)``) is selected while
  ``global_step < p.boundaries[i]``; predicates are checked in order and the
  first true one wins. Once the global step has reached every boundary, the
  last entry of ``p.datasource_params`` is used.

  Args:
    data_source_from_file_pattern_fn: a function to read and return input
      batch from a string file_pattern

  Returns:
    A NestedMap containing:
      data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor
    Any per-stage ``bprop_variable_filters`` entry is removed, and the
    result's ``bprop_variable_filters`` attribute is set to
    ``p.bprop_variable_filters``.

  Raises:
    ValueError: inconsistent sizes between boundaries and datasource_params,
      specification of unsupported datasources, or out of order boundaries.
  """
  p = self.params
  num_sources = len(p.datasource_params)
  num_boundaries = len(p.boundaries)
  if num_sources != num_boundaries + 1:
    raise ValueError(
        'Expected p.datasource_params to have one more entry than '
        'p.boundaries. Found %d datasource_params, and %d boundaries' %
        (num_sources, num_boundaries))

  for ds_p in p.datasource_params:
    # Only the curriculum-level p.bprop_variable_filters is honored; a
    # non-empty filter on an individual stage would be silently dropped when
    # the stage's output is built (see DatasourceFn below), so reject it.
    if 'bprop_variable_filters' in ds_p and any(ds_p.bprop_variable_filters):
      raise ValueError(
          'CurriculumDataSource does not support distinct '
          'bprop_variable_filters per stage.')

  # Boundaries must be non-decreasing (equal neighbours are accepted).
  for idx, (lower, upper) in enumerate(zip(p.boundaries, p.boundaries[1:])):
    if lower > upper:
      raise ValueError(
          'Expected p.boundaries to monotonically increase, but '
          'found %d > %d at position %d' % (lower, upper, idx))

  global_step = py_utils.GetGlobalStep()
  datasources = [ds_p.Instantiate() for ds_p in p.datasource_params]

  def GetDatasourceFn(idx):
    # A factory (rather than a lambda in a loop) binds idx eagerly, avoiding
    # the late-binding closure pitfall.
    def DatasourceFn():
      datasource = datasources[idx].BuildDataSource(
          data_source_from_file_pattern_fn)
      # Per-stage filters are discarded; the curriculum-level filters are
      # attached to the merged result instead.
      datasource.pop('bprop_variable_filters', None)
      return datasource

    return DatasourceFn

  # tf.case (exclusive=False by default) evaluates predicates in order and
  # takes the first true one, i.e. the first boundary still ahead of
  # global_step.
  cases = [(tf.less(global_step,
                    tf.constant(boundary, dtype=global_step.dtype)),
            GetDatasourceFn(idx))
           for idx, boundary in enumerate(p.boundaries)]
  ret = tf.case(cases, default=GetDatasourceFn(-1))
  ret.bprop_variable_filters = p.bprop_variable_filters
  return ret
def GetNext(self):
  """Returns the next batch from the sub-generator of the current stage.

  Sub-generator ``i`` is chosen while ``global_step < p.boundaries[i]``, with
  the predicates checked in order so the first match wins. ``self.sub[-1]``
  is the fallback once the global step has passed every boundary.
  """
  p = self.params
  step = py_utils.GetGlobalStep()
  branches = []
  for stage, boundary in enumerate(p.boundaries):
    threshold = tf.constant(boundary, dtype=step.dtype)
    branches.append((tf.less(step, threshold), self.sub[stage].GetNext))
  return tf.case(branches, default=self.sub[-1].GetNext)