示例#1
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Runs a dataset-driven for loop that may stop early.

  Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is expressed as a scan -> take_while -> reduce pipeline: each scan
  step evaluates `extra_test` on the current state and only runs `body` while
  it holds; take_while drops everything from the first element whose test was
  false; reduce keeps the state attached to the last remaining element.

  Args:
    ds: the dataset being iterated over.
    extra_test: callable taking the unpacked loop state; the loop continues
      while it returns true.
    body: callable taking `(iterate, *state)` and returning the new state.
    init_state: initial loop state.

  Returns:
    The loop state after the last iteration that ran, or `init_state` if the
    pipeline produced no elements.
  """

  def _step(current, element):
    keep_going = extra_test(*current)
    updated = control_flow_ops.cond(
        keep_going,
        lambda: body(element, *current),
        lambda: current)
    # The scan carries `updated` as its state and also emits it, paired with
    # the stop flag, as its per-element output.
    return updated, (updated, keep_going)

  def _still_running(unused_state, keep_going):
    return keep_going

  def _keep_latest(unused_previous, emitted):
    latest, _ = emitted
    return latest

  scanned = ds.apply(scan_ops.scan(init_state, _step))
  truncated = scanned.apply(take_while_ops.take_while(_still_running))
  return truncated.reduce(init_state, _keep_latest)
示例#2
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  Builds the pipeline ds -> scan -> take_while -> reduce. Each scan step
  evaluates `extra_test(*state)`; when it is true, `body(iterate, *state)`
  produces the next state, otherwise the state passes through unchanged.
  Elements are dropped from the first one whose test was false, and the
  reduction returns the state attached to the last element that remained
  (or `init_state` if none did).
  """

  def emit_step(loop_state, item):
    should_continue = extra_test(*loop_state)

    def run_body():
      return body(item, *loop_state)

    def skip_body():
      return loop_state

    next_state = control_flow_ops.cond(should_continue, run_body, skip_body)
    # First element: the scan's carried state. Second: the per-element
    # output, which repeats the state alongside the stop flag.
    step_output = (next_state, should_continue)
    return next_state, step_output

  def is_active(_, should_continue):
    return should_continue

  def last_state(_, step_output):
    return step_output[0]

  pipeline = ds.apply(scan_ops.scan(init_state, emit_step))
  pipeline = pipeline.apply(take_while_ops.take_while(is_active))
  return pipeline.reduce(init_state, last_state)
示例#3
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is lowered to a `scan -> take_while -> reduce` dataset pipeline.
  The scan carries a pair `(loop_vars, state)`: `loop_vars` are the values
  passed to and returned by `body`, and `state` is extra state read with
  `get_state` and written with `set_state`. Each scan step evaluates
  `extra_test(*loop_vars)`; when it is false the pair passes through
  unchanged, take_while truncates the dataset at that element, and reduce
  keeps the pair emitted by the last remaining element.

  Args:
    ds: the dataset being iterated over.
    extra_test: callable taking the unpacked loop vars; the loop runs while it
      returns true.
    body: callable taking `(iterate, *loop_vars)` and returning the new loop
      vars.
    get_state: zero-argument callable returning the current extra state.
    set_state: callable that installs a given extra state.
    init_vars: initial loop vars.
    basic_symbol_names: symbol names forwarded to `_verify_tf_loop_vars`,
      concatenated in front of `composite_symbol_names` (presumably matching
      `init_vars` positionally — confirm against callers).
    composite_symbol_names: symbol names concatenated after
      `basic_symbol_names` (presumably matching the extra state positionally).
    opts: options forwarded to `_verify_tf_loop_vars`.

  Returns:
    The final loop vars. As a side effect, the final extra state is installed
    via `set_state`.
  """

  # TODO(mdan): Simplify this - following it is extremely difficult.

  # Snapshot the extra state once; the scan/reduce carry is (vars, state).
  init_state = get_state()
  aug_init_vars = init_vars, init_state

  def scan_body(aug_vars, iterate):
    """Runs one loop iteration, guarded by the stop condition.

    Returns `(new_aug_vars, scan_outputs)`, where `new_aug_vars` is the
    carried `(vars, state)` pair and `scan_outputs` additionally includes the
    stop flag consumed by take_while.
    """
    loop_vars, state = aug_vars

    def true_fn():
      """Main path, taken when extra_test is true: runs body once."""
      # Order matters: install the carried state before body runs, then read
      # back whatever state body left behind.
      set_state(state)
      new_vars = body(iterate, *loop_vars)
      new_state = get_state()
      # Vars and state are verified together; `+` concatenation assumes both
      # are tuples (NOTE(review): confirm get_state always returns a tuple).
      _verify_tf_loop_vars(
          init_vars + init_state,
          loop_vars + state,
          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return new_vars, new_state

    extra_cond = extra_test(*loop_vars)
    # When the stop condition is hit, pass the carry through unchanged.
    new_vars, new_state = control_flow_ops.cond(
        extra_cond,
        true_fn,
        lambda: (loop_vars, state),
    )

    scan_outputs = new_vars, new_state, extra_cond
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    # Keep elements only while the stop condition has not been hit.
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    # Discards the previous accumulator and keeps the latest (vars, state).
    # NOTE(review): despite its name, `output_aug_vars` holds only the loop
    # vars, not the (vars, state) pair.
    output_aug_vars, output_state, extra_cond = scan_outputs
    del extra_cond
    return output_aug_vars, output_state

  # _general_purpose_scan is defined elsewhere; assumed to apply scan_body as
  # a scan transformation seeded with aug_init_vars.
  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
  ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
  final_vars, final_state = final_aug_vars
  # Publish the final extra state back to the caller's state holder.
  set_state(final_state)
  return final_vars
示例#4
0
  def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
    """take_while(x < upper_bound) over range(num_elements) yields a prefix.

    The dataset must produce exactly `0, 1, ..., min(num_elements,
    upper_bound) - 1` and then end. This is asserted directly, so the test
    also covers the case where the bound exceeds the available elements.

    Args:
      num_elements: size of the source range dataset.
      upper_bound: exclusive bound used by the take_while predicate.
      out_of_bounds: parameterization flag (presumably true when
        `upper_bound` exceeds `num_elements`). The expected output below
        already accounts for that case, so the flag is unused.
    """
    del out_of_bounds  # Kept only for the parameterized-test signature.
    dataset = dataset_ops.Dataset.range(num_elements).apply(
        take_while_ops.take_while(lambda x: x < upper_bound))

    self.assertDatasetProduces(dataset,
                               np.arange(min(num_elements, upper_bound)))
示例#5
0
    def testTakeWhileDatasetRange(self, num_elements, upper_bound,
                                  out_of_bounds):
        """take_while(x < upper_bound) over range(num_elements) yields a prefix.

        The dataset must produce exactly
        `0, 1, ..., min(num_elements, upper_bound) - 1` and then end. Asserting
        this directly pins the produced elements even when the bound exceeds
        the number of available elements.

        Args:
            num_elements: size of the source range dataset.
            upper_bound: exclusive bound used by the take_while predicate.
            out_of_bounds: parameterization flag (presumably true when
                `upper_bound` exceeds `num_elements`). The expected output
                below already accounts for that case, so it is unused.
        """
        del out_of_bounds  # Kept only for the parameterized-test signature.
        dataset = dataset_ops.Dataset.range(num_elements).apply(
            take_while_ops.take_while(lambda x: x < upper_bound))

        self.assertDatasetProduces(dataset,
                                   np.arange(min(num_elements, upper_bound)))
示例#6
0
    def testTakeWhileDataset(self, num_elements, window_size):
        """take_while on batches stops at the first batch that is not full.

        `range(num_elements)` is batched into windows of `window_size`. The
        predicate accepts a batch only when its leading dimension is at least
        `window_size`, so a trailing partial batch ends the dataset. The
        flattened result is therefore `0 .. k - 1`, where `k` is the largest
        multiple of `window_size` not exceeding `num_elements`.
        """
        def _predicate_func(elem):
            return array_ops.shape(elem)[0] > (window_size - 1)

        take_while = take_while_ops.take_while(_predicate_func)

        dataset = dataset_ops.Dataset.range(num_elements).batch(window_size)
        dataset = dataset.apply(take_while).flat_map(
            dataset_ops.Dataset.from_tensor_slices)

        # Integer floor division instead of int(a / b): avoids a float
        # round-trip that can lose precision for large integers.
        expected_num_elements = (num_elements // window_size) * window_size
        self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
示例#7
0
  def testTakeWhileDataset(self, num_elements, window_size):
    """take_while on batches stops at the first batch that is not full.

    `range(num_elements)` is batched into windows of `window_size`. The
    predicate accepts a batch only when its leading dimension is at least
    `window_size`, so a trailing partial batch ends the dataset. The flattened
    result is therefore `0 .. k - 1`, where `k` is the largest multiple of
    `window_size` not exceeding `num_elements`.
    """

    def _predicate_func(elem):
      return array_ops.shape(elem)[0] > (window_size - 1)

    take_while = take_while_ops.take_while(_predicate_func)

    dataset = dataset_ops.Dataset.range(num_elements).batch(window_size)
    dataset = dataset.apply(take_while).flat_map(
        dataset_ops.Dataset.from_tensor_slices)

    # Integer floor division instead of int(a / b): avoids a float round-trip
    # that can lose precision for large integers.
    expected_num_elements = (num_elements // window_size) * window_size
    self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
示例#8
0
    def testTakewhileDatasetShortCircuit(self, size, index):
        """Iteration ends at the first False element of a boolean dataset."""
        def _identity_predicate(flag):
            return flag

        flags = [True] * size
        flags[index] = False
        source = dataset_ops.Dataset.from_tensor_slices(flags)
        dataset = source.apply(take_while_ops.take_while(_identity_predicate))

        get_next = self.getNext(dataset)

        # Every element before `index` is True and must be produced.
        for _ in range(index):
            self.assertTrue(self.evaluate(get_next()))

        # The False element at `index` terminates the dataset.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
示例#9
0
    def testTakeWhileDatasetString(self):
        """take_while on strings stops before the first element equal to "test"."""
        def not_equal(target):
            return lambda x: math_ops.not_equal(
                x, constant_op.constant(target))

        # Named `words` so it is not confused with not_equal's parameter.
        words = ["this", "is", "the", "test", "for", "strings"]
        dataset = dataset_ops.Dataset.from_tensor_slices(words).apply(
            take_while_ops.take_while(not_equal("test")))

        next_element = self.getNext(dataset)
        self.assertEqual(b"this", self.evaluate(next_element()))
        self.assertEqual(b"is", self.evaluate(next_element()))
        self.assertEqual(b"the", self.evaluate(next_element()))

        # "test" fails the predicate, so the dataset is exhausted here and the
        # element itself is never produced.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())
示例#10
0
  def testTakeWhileDatasetString(self):
    """take_while on strings stops before the first element equal to "test"."""

    def not_equal(target):
      return lambda x: math_ops.not_equal(x, constant_op.constant(target))

    # Named `words` so it is not confused with not_equal's parameter.
    words = ["this", "is", "the", "test", "for", "strings"]
    dataset = dataset_ops.Dataset.from_tensor_slices(words).apply(
        take_while_ops.take_while(not_equal("test")))

    next_element = self.getNext(dataset)
    self.assertEqual(b"this", self.evaluate(next_element()))
    self.assertEqual(b"is", self.evaluate(next_element()))
    self.assertEqual(b"the", self.evaluate(next_element()))

    # "test" fails the predicate, so the dataset is exhausted here and the
    # element itself is never produced.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
示例#11
0
  def testTakewhileDatasetShortCircuit(self, size, index):
    """A single False at `index` cuts the boolean dataset off at that point."""

    def _passthrough(value):
      return value

    values = [True] * size
    values[index] = False
    truncated = dataset_ops.Dataset.from_tensor_slices(values).apply(
        take_while_ops.take_while(_passthrough))

    fetch = self.getNext(truncated)

    # Elements 0 .. index-1 are all True and must come through.
    for _ in range(index):
      self.assertTrue(self.evaluate(fetch()))

    # Nothing is produced once the False element is reached.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(fetch())
示例#12
0
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars):
    """Overload of _dataset_for_stmt with early stopping. See for_stmt.

    The loop becomes a scan -> take_while -> reduce dataset pipeline whose
    carry is the pair `(loop_vars, state)`. `state` is extra state read with
    `get_state` and installed with `set_state`. A scan step runs `body` only
    while `extra_test(*loop_vars)` is true; otherwise the carry passes through
    unchanged. take_while truncates the dataset at that point, and reduce keeps
    the last emitted pair.

    Args:
        ds: the dataset being iterated over.
        extra_test: callable taking the unpacked loop vars; the loop runs
            while it returns true.
        body: callable taking `(iterate, *loop_vars)` and returning the new
            loop vars.
        get_state: zero-argument callable returning the current extra state.
        set_state: callable that installs a given extra state.
        init_vars: initial loop vars.

    Returns:
        The final loop vars. The final extra state is installed via
        `set_state` before returning.
    """

    # TODO(mdan): Simplify this - following it is extremely difficult.

    start = (init_vars, get_state())

    def step(carry, element):
        """One scan step: runs `body` only while `extra_test` holds."""
        current_vars, current_state = carry
        keep_going = extra_test(*current_vars)

        def run_iteration():
            # Install the carried state so `body` observes it, then capture
            # whatever state `body` left behind.
            set_state(current_state)
            produced_vars = body(element, *current_vars)
            return produced_vars, get_state()

        def skip_iteration():
            return current_vars, current_state

        next_vars, next_state = control_flow_ops.cond(
            keep_going, run_iteration, skip_iteration)
        # The carry is (vars, state); the emitted element repeats it and adds
        # the stop flag for take_while.
        return (next_vars, next_state), (next_vars, next_state, keep_going)

    def still_running(_vars, _state, keep_going):
        return keep_going

    def take_last(_previous, emitted):
        emitted_vars, emitted_state, _ = emitted
        return emitted_vars, emitted_state

    pipeline = ds.apply(scan_ops.scan(start, step))
    pipeline = pipeline.apply(take_while_ops.take_while(still_running))
    end_vars, end_state = pipeline.reduce(start, take_last)
    set_state(end_state)
    return end_vars
示例#13
0
    def testTakeWhileDatasetRange(self, num_elements, upper_bound):
        """range(num_elements) cut off at upper_bound yields a leading prefix."""
        def _below_bound(x):
            return x < upper_bound

        source = dataset_ops.Dataset.range(num_elements)
        dataset = source.apply(take_while_ops.take_while(_below_bound))

        # Whichever limit is reached first determines the prefix length.
        expected = np.arange(min(num_elements, upper_bound))
        self.assertDatasetProduces(dataset, expected)
示例#14
0
 def testTakeWhileDatasetWithRepeat(self):
     """take_while(x < 2) followed by repeat(5) yields [0, 1] five times."""
     prefix = dataset_ops.Dataset.range(10).apply(
         take_while_ops.take_while(lambda x: x < 2))
     repeated = prefix.repeat(5)
     self.assertDatasetProduces(repeated, np.tile([0, 1], 5))
示例#15
0
 def testStatefulTakeWhileNotCheckpointable(self):
   """take_while driven by `self._statefulBoolFunc` must not be checkpointable."""
   source = dataset_ops.Dataset.range(10)
   truncated = source.apply(take_while_ops.take_while(self._statefulBoolFunc))
   self._assertNotCheckpointable(truncated)
 def _build_dataset(self, num_elements, upper_bound):
   """Returns range(num_elements) truncated by take_while(x < upper_bound)."""
   source = dataset_ops.Dataset.range(num_elements)
   return source.apply(
       take_while_ops.take_while(lambda value: value < upper_bound))
示例#17
0
def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state,
                         symbol_names, opts):
    """Overload of _dataset_for_stmt that iterates over a dataset. See for_stmt.

    The loop is lowered to a `scan [-> take_while] -> reduce` pipeline. When
    `extra_test` is given, iteration stops early: a scan step skips `body` once
    `extra_test()` is false and take_while truncates the dataset there. When
    `extra_test` is None, every element is visited and no take_while is added.

    Args:
        ds: the dataset being iterated over.
        extra_test: optional zero-argument callable; the loop continues while
            it returns true. May be None.
        body: callable taking the current dataset element. Its return value
            is ignored; loop variables are communicated through
            `get_state`/`set_state`.
        get_state: zero-argument callable returning the current loop vars.
        set_state: callable that installs a given set of loop vars.
        symbol_names: names of the loop vars, passed to the verification
            helpers.
        opts: options forwarded to `_verify_tf_loop_vars`.

    Returns:
        None. The final loop vars are installed via `set_state`.
    """
    # Note: This is easier to follow with the insight that the computations in
    # a dataset pipeline are transposed (aka fused).
    # For example, given a pipeline input -> scan -> take_while -> reduce,
    # and a dataset with input [1, 2, 3], the computations occur in the following
    # order:
    #  reduce(take_while(scan(1)))
    #  reduce(take_while(scan(2)))
    #  reduce(take_while(scan(3)))

    init_vars = get_state()
    _verify_loop_init_vars(init_vars, symbol_names)

    # Workaround for Dataset.reduce not allowing empty state tensors - create
    # a dummy state variable that remains unused.
    # TODO(mdan): reduce should allow and match empty structures.
    if not init_vars:
        init_vars = (constant_op.constant(0), )
        symbol_names = ('<internal dummy>', )

        # With the dummy accessors in place, every set_state call below
        # (including the final one) is a no-op.
        def dummy_set_state(unused_dummy):
            pass

        def dummy_get_state():
            return (constant_op.constant(0), )

        get_state, set_state = dummy_get_state, dummy_set_state

    def scan_body(scan_state, scan_inputs):
        """Main body of the Dataset.scan: one guarded loop iteration.

        Returns `(new_scan_state, scan_outputs)`, where `scan_outputs` pairs
        the new loop vars with the stop condition consumed by take_while.
        """
        loop_vars, iterate = scan_state, scan_inputs
        # Install the carried vars first so both extra_test and body see them.
        set_state(loop_vars)

        def main_path():
            # body communicates through the state; read it back afterwards.
            body(iterate)
            new_loop_vars = get_state()
            _verify_tf_loop_vars(init_vars,
                                 loop_vars,
                                 new_loop_vars,
                                 symbol_names,
                                 opts,
                                 check_shapes=False)
            return new_loop_vars

        if extra_test is not None:
            extra_cond = extra_test()
            # Once the stop condition is hit, pass the vars through unchanged.
            new_loop_vars = control_flow_ops.cond(extra_cond, main_path,
                                                  lambda: loop_vars)
        else:
            # TODO(mdan): the optimizer should be able to remove an invariant cond?
            extra_cond = (constant_op.constant(True), )  # dummy value, unused
            new_loop_vars = main_path()

        scan_outputs = new_loop_vars, extra_cond
        new_scan_state = new_loop_vars
        return new_scan_state, scan_outputs

    def take_while_predicate(unused_loop_vars, extra_cond):
        # Keep elements only while the stop condition has not been hit.
        return extra_cond

    def reduce_body(unused_reduce_state, scan_outputs):
        # Discard the previous accumulator; keep the latest emitted loop vars.
        output_loop_vars, unused_extra_cond = scan_outputs
        new_reduce_state = output_loop_vars
        return new_reduce_state

    # _general_purpose_scan is defined elsewhere; assumed to apply scan_body as
    # a scan transformation seeded with init_vars.
    ds = _general_purpose_scan(ds, init_vars, scan_body)
    if extra_test is not None:
        ds = ds.apply(take_while_ops.take_while(take_while_predicate))
    final_loop_vars = ds.reduce(init_vars, reduce_body)
    # Publish the final loop vars back to the caller (no return value).
    set_state(final_loop_vars)
示例#18
0
 def testTakeWhileDatasetWithRepeat(self):
   """Repeating the take_while(x < 2) prefix five times yields [0, 1] x 5."""
   expected = np.tile([0, 1], 5)
   head = dataset_ops.Dataset.range(10).apply(
       take_while_ops.take_while(lambda x: x < 2))
   self.assertDatasetProduces(head.repeat(5), expected)