def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is built as a `scan -> take_while -> reduce` pipeline over `ds`.

  Args:
    ds: Dataset whose elements are the loop iterates.
    extra_test: Callable invoked as `extra_test(*state)`; its result is the
      `cond` predicate and the flag consumed by `take_while`.
    body: Callable invoked as `body(iterate, *state)`; returns the new state.
    init_state: Initial state handed to both `scan` and `reduce`.

  Returns:
    The value produced by the final `reduce` over the truncated dataset.
  """

  def scan_body(state, iterate):
    """Runs `body` for one element, but only while `extra_test` holds."""
    keep_going = extra_test(*state)

    def run_body():
      return body(iterate, *state)

    def skip_body():
      return state

    next_state = control_flow_ops.cond(keep_going, run_body, skip_body)
    # The first item is what scan carries forward as its state; the second is
    # the element scan emits downstream (hence next_state appearing twice).
    return next_state, (next_state, keep_going)

  def take_while_predicate(unused_state, extra_cond):
    # Only the flag computed in scan_body decides whether to keep going.
    return extra_cond

  def reduce_body(unused_old_state, aug_state):
    # Keep the newest state and drop the flag that travelled with it.
    latest_state, _ = aug_state
    return latest_state

  scanned = ds.apply(scan_ops.scan(init_state, scan_body))
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  return truncated.reduce(init_state, reduce_body)
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, init_state):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  Implemented by chaining three dataset stages: a scan that conditionally
  applies `body`, a take_while keyed on the stop flag, and a reduce that
  keeps the most recent state.

  Args:
    ds: Dataset providing one iterate per element.
    extra_test: Called with the unpacked state; the result gates `body` and
      is forwarded to `take_while`.
    body: Called as `body(iterate, *state)` to produce the next state.
    init_state: Starting state for `scan` and for `reduce`.

  Returns:
    Whatever `reduce` yields for the scanned, truncated dataset.
  """

  def scan_body(state, iterate):
    """Computes the stop flag and, if it is set, the next state."""
    should_continue = extra_test(*state)
    updated_state = control_flow_ops.cond(
        should_continue,
        lambda: body(iterate, *state),
        lambda: state,
    )
    # scan treats the first return value as its carried state and the second
    # as its output element, so updated_state is returned in both positions.
    emitted = (updated_state, should_continue)
    return updated_state, emitted

  def take_while_predicate(new_state, extra_cond):
    del new_state  # The state is irrelevant to the stopping decision.
    return extra_cond

  def reduce_body(old_state, aug_state):
    del old_state  # Each step simply overwrites the accumulated state.
    return aug_state[0]

  pipeline = ds.apply(scan_ops.scan(init_state, scan_body))
  pipeline = pipeline.apply(take_while_ops.take_while(take_while_predicate))
  result = pipeline.reduce(init_state, reduce_body)
  return result
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is built as `scan -> take_while -> reduce` over `ds`. Both the loop
  variables and the value returned by `get_state` are threaded through the
  pipeline as a pair.

  Args:
    ds: Dataset whose elements are the loop iterates.
    extra_test: Called as `extra_test(*loop_vars)`; gates the body and feeds
      `take_while`.
    body: Called as `body(iterate, *loop_vars)`; returns the new loop vars.
    get_state: Zero-argument callable returning the current state.
    set_state: Callable accepting a state value previously returned by
      `get_state`.
    init_vars: Initial loop variables.
    basic_symbol_names: Concatenated with `composite_symbol_names` and passed
      to `_verify_tf_loop_vars`.
    composite_symbol_names: See `basic_symbol_names`.
    opts: Forwarded unchanged to `_verify_tf_loop_vars`.

  Returns:
    The final loop variables. The final state is applied via `set_state`.
  """
  # TODO(mdan): Simplify this - following it is extremely difficult.

  starting_state = get_state()
  aug_init_vars = (init_vars, starting_state)

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def true_fn():
      """Main path - stop condition is not set."""
      set_state(state)
      updated_vars = body(iterate, *loop_vars)
      # get_state will pull any mutations that body may have made.
      updated_state = get_state()
      _verify_tf_loop_vars(
          init_vars + starting_state,
          loop_vars + state,
          updated_vars + updated_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return updated_vars, updated_state

    def false_fn():
      """Stop condition is set - pass everything through untouched."""
      return loop_vars, state

    extra_cond = extra_test(*loop_vars)
    next_vars, next_state = control_flow_ops.cond(extra_cond, true_fn,
                                                  false_fn)

    # The first return value is the state scan carries forward; the second is
    # the element it emits (hence the redundancy).
    carried = (next_vars, next_state)
    emitted = (next_vars, next_state, extra_cond)
    return carried, emitted

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    latest_vars, latest_state, _ = scan_outputs
    return latest_vars, latest_state

  scanned = _general_purpose_scan(ds, aug_init_vars, scan_body)
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  final_vars, final_state = truncated.reduce(aug_init_vars, reduce_body)
  set_state(final_state)
  return final_vars
def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
  """Checks take_while over a range dataset against `arange(upper_bound)`.

  When `out_of_bounds` is truthy, the comparison is expected to raise
  `OutOfRangeError`; otherwise it is expected to succeed.
  """
  source = dataset_ops.Dataset.range(num_elements)
  dataset = source.apply(take_while_ops.take_while(lambda x: x < upper_bound))

  if not out_of_bounds:
    self.assertDatasetProduces(dataset, np.arange(upper_bound))
    return

  with self.assertRaises(errors.OutOfRangeError):
    self.assertDatasetProduces(dataset, np.arange(upper_bound))
def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
  """Compares `range(num_elements)` cut at `upper_bound` with an arange.

  The expected output is always `np.arange(upper_bound)`. If `out_of_bounds`
  is set, the check itself must fail with `OutOfRangeError`.
  """
  below_bound = take_while_ops.take_while(lambda x: x < upper_bound)
  dataset = dataset_ops.Dataset.range(num_elements).apply(below_bound)

  def check_output():
    self.assertDatasetProduces(dataset, np.arange(upper_bound))

  if out_of_bounds:
    with self.assertRaises(errors.OutOfRangeError):
      check_output()
  else:
    check_output()
def testTakeWhileDataset(self, num_elements, window_size):
  """Checks that take_while stops at the first batch shorter than the window.

  The range is batched by `window_size`, batches are kept while their leading
  dimension exceeds `window_size - 1`, and the kept batches are flattened
  back into individual elements.
  """

  def _predicate_func(elem):
    return array_ops.shape(elem)[0] > (window_size - 1)

  batched = dataset_ops.Dataset.range(num_elements).batch(window_size)
  kept_batches = batched.apply(take_while_ops.take_while(_predicate_func))
  dataset = kept_batches.flat_map(dataset_ops.Dataset.from_tensor_slices)

  # Number of elements covered by whole windows only.
  whole_windows = int(num_elements / window_size)
  expected_num_elements = whole_windows * window_size
  self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
def testTakeWhileDataset(self, num_elements, window_size):
  """Verifies take_while on batched input keeps only full-size batches.

  Expected output is `arange` over the element count spanned by complete
  windows of `window_size`.
  """

  def _predicate_func(elem):
    leading_dim = array_ops.shape(elem)[0]
    return leading_dim > (window_size - 1)

  take_full_batches = take_while_ops.take_while(_predicate_func)
  unbatch = dataset_ops.Dataset.from_tensor_slices

  dataset = dataset_ops.Dataset.range(num_elements).batch(window_size)
  dataset = dataset.apply(take_full_batches)
  dataset = dataset.flat_map(unbatch)

  expected_num_elements = int(num_elements / window_size) * window_size
  expected_output = np.arange(expected_num_elements)
  self.assertDatasetProduces(dataset, expected_output)
def testTakewhileDatasetShortCircuit(self, size, index):
  """Checks that iteration ends at the first element failing the predicate.

  Builds `size` booleans that are all True except position `index`, uses the
  element itself as the predicate, and expects exactly `index` True values
  followed by `OutOfRangeError`.
  """

  def _predicate_func(data_elem):
    return data_elem

  flags = [True] * size
  flags[index] = False

  source = dataset_ops.Dataset.from_tensor_slices(flags)
  dataset = source.apply(take_while_ops.take_while(_predicate_func))
  next_element = self.getNext(dataset)

  remaining = index
  while remaining > 0:
    self.assertTrue(self.evaluate(next_element()))
    remaining -= 1

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element())
def testTakeWhileDatasetString(self):
  """Checks take_while on string elements, stopping at the word "test"."""

  def not_equal(string):
    """Returns a predicate that is true for elements different from `string`."""

    def predicate(x):
      return math_ops.not_equal(x, constant_op.constant(string))

    return predicate

  words = ["this", "is", "the", "test", "for", "strings"]
  source = dataset_ops.Dataset.from_tensor_slices(words)
  dataset = source.apply(take_while_ops.take_while(not_equal("test")))
  next_element = self.getNext(dataset)

  # Everything before "test" must come through, in order.
  for expected in (b"this", b"is", b"the"):
    self.assertEqual(expected, self.evaluate(next_element()))

  with self.assertRaises(errors.OutOfRangeError):
    self.assertEqual(b"test", self.evaluate(next_element()))
def testTakeWhileDatasetString(self):
  """Verifies that string elements are produced until "test" is reached.

  The three words preceding "test" are expected as bytes; fetching the next
  element afterwards must raise `OutOfRangeError`.
  """

  def not_equal(string):
    return lambda x: math_ops.not_equal(x, constant_op.constant(string))

  stop_word = "test"
  string = ["this", "is", "the", stop_word, "for", "strings"]

  until_stop_word = take_while_ops.take_while(not_equal(stop_word))
  dataset = dataset_ops.Dataset.from_tensor_slices(string).apply(
      until_stop_word)
  next_element = self.getNext(dataset)

  first = self.evaluate(next_element())
  self.assertEqual(b"this", first)
  second = self.evaluate(next_element())
  self.assertEqual(b"is", second)
  third = self.evaluate(next_element())
  self.assertEqual(b"the", third)

  with self.assertRaises(errors.OutOfRangeError):
    self.assertEqual(b"test", self.evaluate(next_element()))
def testTakewhileDatasetShortCircuit(self, size, index):
  """Verifies take_while stops as soon as the predicate is False.

  The input is a list of `size` True values with a single False at `index`.
  The first `index` reads must be True and the next read must raise
  `OutOfRangeError`.
  """

  def _predicate_func(data_elem):
    # The element is itself the boolean verdict.
    return data_elem

  boolean_array = [True] * size
  boolean_array[index] = False

  identity_take_while = take_while_ops.take_while(_predicate_func)
  dataset = dataset_ops.Dataset.from_tensor_slices(boolean_array)
  dataset = dataset.apply(identity_take_while)
  get_next = self.getNext(dataset)

  for _ in range(index):
    value = self.evaluate(get_next())
    self.assertTrue(value)

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(get_next())
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is built as `scan -> take_while -> reduce` over `ds`, carrying a
  `(loop_vars, state)` pair through every stage.

  Args:
    ds: Dataset whose elements are the loop iterates.
    extra_test: Called as `extra_test(*loop_vars)`; gates the body and feeds
      `take_while`.
    body: Called as `body(iterate, *loop_vars)`; returns the new loop vars.
    get_state: Zero-argument callable returning the current state.
    set_state: Callable accepting a state value previously returned by
      `get_state`.
    init_vars: Initial loop variables.

  Returns:
    The final loop variables. The final state is applied via `set_state`.
  """
  # TODO(mdan): Simplify this - following it is extremely difficult.

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def run_iteration():
      set_state(state)
      outputs = body(iterate, *loop_vars)
      # get_state will pull any mutations that body may have made.
      state_after_body = get_state()
      return outputs, state_after_body

    def pass_through():
      return loop_vars, state

    extra_cond = extra_test(*loop_vars)
    next_vars, next_state = control_flow_ops.cond(extra_cond, run_iteration,
                                                  pass_through)

    # The first return value is the state scan carries forward; the second is
    # the element it emits (hence the redundancy).
    return (next_vars, next_state), (next_vars, next_state, extra_cond)

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    latest_vars, latest_state, _ = scan_outputs
    return latest_vars, latest_state

  aug_init = (init_vars, get_state())

  scanned = ds.apply(scan_ops.scan(aug_init, scan_body))
  truncated = scanned.apply(take_while_ops.take_while(take_while_predicate))
  final_vars, final_state = truncated.reduce(aug_init, reduce_body)

  set_state(final_state)
  return final_vars
def testTakeWhileDatasetRange(self, num_elements, upper_bound):
  """Checks take_while over a range against the shorter of the two bounds.

  The expected output is `arange(min(num_elements, upper_bound))`.
  """
  expected_length = min(num_elements, upper_bound)
  below_bound = take_while_ops.take_while(lambda x: x < upper_bound)
  dataset = dataset_ops.Dataset.range(num_elements).apply(below_bound)
  self.assertDatasetProduces(dataset, np.arange(expected_length))
def testTakeWhileDatasetWithRepeat(self):
  """Checks that a take_while dataset can be repeated.

  Takes elements below 2 from `range(10)`, repeats the result five times, and
  expects `[0, 1]` tiled five times.
  """
  first_two = dataset_ops.Dataset.range(10).apply(
      take_while_ops.take_while(lambda x: x < 2))
  dataset = first_two.repeat(5)
  expected_output = np.tile([0, 1], 5)
  self.assertDatasetProduces(dataset, expected_output)
def testStatefulTakeWhileNotCheckpointable(self):
  """Asserts that take_while with `_statefulBoolFunc` is not checkpointable.

  NOTE(review): `_statefulBoolFunc` is presumably a stateful predicate supplied
  by the test base class, going by its name - confirm against its definition.
  """
  source = dataset_ops.Dataset.range(10)
  transformation = take_while_ops.take_while(self._statefulBoolFunc)
  self._assertNotCheckpointable(source.apply(transformation))
def _build_dataset(self, num_elements, upper_bound):
  """Builds `range(num_elements)` truncated by `take_while(x < upper_bound)`.

  Args:
    num_elements: Size of the source range dataset.
    upper_bound: Elements are kept while they are strictly below this value.

  Returns:
    The range dataset with the take_while transformation applied.
  """
  source = dataset_ops.Dataset.range(num_elements)
  below_bound = take_while_ops.take_while(lambda x: x < upper_bound)
  return source.apply(below_bound)
def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state,
                         symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  Args:
    ds: Dataset whose elements are the loop iterates.
    extra_test: Optional zero-argument callable. When not None, its result
      gates the loop body and drives a `take_while` stage.
    body: Callable invoked as `body(iterate)`; its return value is ignored
      and the new loop variables are read back through `get_state`.
    get_state: Zero-argument callable returning the current loop variables.
    set_state: Callable accepting loop variables as returned by `get_state`.
    symbol_names: Names passed to the loop-variable verification helpers.
    opts: Forwarded unchanged to `_verify_tf_loop_vars`.

  Returns:
    None. The final loop variables are applied through `set_state`.
  """
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #  reduce(take_while(scan(1)))
  #  reduce(take_while(scan(2)))
  #  reduce(take_while(scan(3)))

  init_vars = get_state()
  _verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    # State is restored before anything else in this step runs.
    set_state(loop_vars)

    def main_path():
      body(iterate)
      updated_vars = get_state()
      _verify_tf_loop_vars(
          init_vars,
          loop_vars,
          updated_vars,
          symbol_names,
          opts,
          check_shapes=False)
      return updated_vars

    def skip_path():
      return loop_vars

    if extra_test is None:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      next_loop_vars = main_path()
    else:
      extra_cond = extra_test()
      next_loop_vars = control_flow_ops.cond(extra_cond, main_path, skip_path)

    # The first return value is the state scan carries forward; the second is
    # the element it emits.
    return next_loop_vars, (next_loop_vars, extra_cond)

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    output_loop_vars, unused_extra_cond = scan_outputs
    return output_loop_vars

  pipeline = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    pipeline = pipeline.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = pipeline.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)
def testTakeWhileDatasetWithRepeat(self):
  """Verifies that repeating a take_while dataset restarts the truncation.

  `range(10)` limited to values below 2 and repeated five times must yield
  `[0, 1]` five times over.
  """
  below_two = take_while_ops.take_while(lambda x: x < 2)
  dataset = dataset_ops.Dataset.range(10)
  dataset = dataset.apply(below_two)
  dataset = dataset.repeat(5)
  self.assertDatasetProduces(dataset, np.tile([0, 1], 5))