Example #1
0
    def BuildDataSource(self, data_source_from_file_pattern_fn):
        """Build a step-dependent input batch that switches between datasources.

        The global step is compared against each entry of `p.boundaries` in
        order; the first boundary the step is below picks the datasource at the
        same position, and the last datasource is the fallback once every
        boundary has been passed.

        Args:
          data_source_from_file_pattern_fn: a function to read and return input
            batch from a string file_pattern. It is forwarded unchanged to each
            sub-datasource's `BuildDataSource`.

        Returns:
          The result of `tf.case` over the sub-datasource outputs (per the
          original documentation, a NestedMap whose `data` is a tuple of
          tf.Tensor or `.NestedMap` of tf.Tensor), with `bprop_variable_filters`
          set from this datasource's own params.

        Raises:
          ValueError: inconsistent sizes between boundaries and
            datasource_params, a sub-datasource specifying non-empty
            bprop_variable_filters, or out of order boundaries.
        """
        hparams = self.params
        source_params = hparams.datasource_params
        thresholds = hparams.boundaries

        # N boundaries split the step axis into N + 1 stages.
        num_sources = len(source_params)
        num_thresholds = len(thresholds)
        if num_sources != num_thresholds + 1:
            raise ValueError(
                'Expected p.datasource_params to have one more entry than '
                'p.boundaries. Found %d datasource_params, and %d boundaries' %
                (num_sources, num_thresholds))

        # Per-stage filters are rejected; only the top-level ones are used.
        for stage_params in source_params:
            if ('bprop_variable_filters' in stage_params and
                    any(stage_params.bprop_variable_filters)):
                raise ValueError(
                    'CurriculumDataSource does not support distinct '
                    'bprop_variable_filters per stage.')

        # Adjacent boundaries may be equal but must never decrease.
        for position, (current, following) in enumerate(
                zip(thresholds, thresholds[1:])):
            if current > following:
                raise ValueError(
                    'Expected p.boundaries to monotonically increase, but '
                    'found %d > %d at position %d' %
                    (current, following, position))

        step = py_utils.GetGlobalStep()
        stages = [stage_params.Instantiate() for stage_params in source_params]

        def _MakeBranch(stage):
            """Returns a zero-arg callable building `stage`'s batch."""

            def _Branch():
                batch = stage.BuildDataSource(data_source_from_file_pattern_fn)
                # Drop any per-stage filters from the branch output.
                batch.pop('bprop_variable_filters', None)
                return batch

            return _Branch

        # zip() stops at the last boundary, leaving the final stage as default.
        branches = [
            (tf.less(step, tf.constant(limit, dtype=step.dtype)),
             _MakeBranch(stage))
            for limit, stage in zip(thresholds, stages)
        ]

        selected = tf.case(branches, default=_MakeBranch(stages[-1]))
        selected.bprop_variable_filters = hparams.bprop_variable_filters
        return selected
Example #2
0
    def GetNext(self):
        """Return the next batch from the sub-datasource for the current step.

        Each entry of `p.boundaries` is paired with the sub-datasource at the
        same index; the first boundary that the global step is below picks that
        sub-datasource's `GetNext`. When the step is not below any boundary,
        the last sub-datasource is used.

        Returns:
          The output of `tf.case` over the sub-datasources' `GetNext` callables.
        """
        hparams = self.params
        step = py_utils.GetGlobalStep()

        branches = [
            (tf.less(step, tf.constant(limit, dtype=step.dtype)),
             self.sub[position].GetNext)
            for position, limit in enumerate(hparams.boundaries)
        ]

        return tf.case(branches, default=self.sub[-1].GetNext)