def _single_threaded_test_session(self):
    # TODO (andreasst): figure out why SDCALinearRegressor needs a single id:753 gh:754
    # threaded session to pass in tsan mode but SDCALogisticClassifier does not.
    config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                    intra_op_parallelism_threads=1)
    return self.test_session(config=config)
Example #2
def model_builder(features,
                  labels,
                  mode,
                  params,
                  config,
                  output_type=ModelBuilderOutputType.MODEL_FN_OPS):
    """Multi-machine batch gradient descent tree model.

  Args:
    features: `Tensor` or `dict` of `Tensor` objects.
    labels: Labels used to train on.
    mode: Mode we are in. (TRAIN/EVAL/INFER)
    params: A dict of hyperparameters.
      The following hyperparameters are expected:
      * head: A `Head` instance.
      * learner_config: A config for the learner.
      * feature_columns: An iterable containing all the feature columns used by
          the model.
      * examples_per_layer: Number of examples to accumulate before growing a
          layer. It can also be a function that computes the number of examples
          based on the depth of the layer that's being built.
      * weight_column_name: The name of weight column.
      * center_bias: Whether a separate tree should be created for first fitting
          the bias.
      * override_global_step_value: If not None, the global step value will be
        reset to this value after training is done. This is particularly useful
        for hyperparameter tuning, which can't recognize early stopping due to
        the number of trees. If None, no override of global step will happen.
    config: `RunConfig` of the estimator.
    output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
      (new interface).

  Returns:
    A `ModelFnOps` or `EstimatorSpec` object, depending on `output_type`.
  Raises:
    ValueError: if inputs are not valid.
  """
    head = params["head"]
    learner_config = params["learner_config"]
    examples_per_layer = params["examples_per_layer"]
    feature_columns = params["feature_columns"]
    weight_column_name = params["weight_column_name"]
    num_trees = params["num_trees"]
    use_core_libs = params["use_core_libs"]
    logits_modifier_function = params["logits_modifier_function"]
    output_leaf_index = params["output_leaf_index"]
    override_global_step_value = params.get("override_global_step_value", None)
    num_quantiles = params["num_quantiles"]

    if features is None:
        raise ValueError("At least one feature must be specified.")

    if config is None:
        raise ValueError("Missing estimator RunConfig.")
    if config.session_config is not None:
        session_config = config.session_config
        session_config.allow_soft_placement = True
    else:
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    config = config.replace(session_config=session_config)

    center_bias = params["center_bias"]

    if isinstance(features, ops.Tensor):
        features = {features.name: features}

    # Make a shallow copy of features to ensure downstream usage
    # is unaffected by modifications in the model function.
    training_features = copy.copy(features)
    training_features.pop(weight_column_name, None)
    global_step = training_util.get_global_step()

    initial_ensemble = ""
    if learner_config.each_tree_start.nodes:
        if learner_config.each_tree_start_num_layers <= 0:
            raise ValueError("You must provide each_tree_start_num_layers.")
        num_layers = learner_config.each_tree_start_num_layers
        initial_ensemble = """
             trees { %s }
             tree_weights: 0.1
             tree_metadata {
              num_tree_weight_updates: 1
              num_layers_grown: %d
              is_finalized: false
             }
             """ % (text_format.MessageToString(
            learner_config.each_tree_start), num_layers)
        tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
        text_format.Merge(initial_ensemble, tree_ensemble_proto)
        initial_ensemble = tree_ensemble_proto.SerializeToString()

    with ops.device(global_step.device):
        ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0,
            tree_ensemble_config=initial_ensemble,  # Initialize the ensemble.
            name="ensemble_model")

    # Create GBDT model.
    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=config.is_chief,
        num_ps_replicas=config.num_ps_replicas,
        ensemble_handle=ensemble_handle,
        center_bias=center_bias,
        examples_per_layer=examples_per_layer,
        learner_config=learner_config,
        feature_columns=feature_columns,
        logits_dimension=head.logits_dimension,
        features=training_features,
        use_core_columns=use_core_libs,
        output_leaf_index=output_leaf_index,
        num_quantiles=num_quantiles)
    with ops.name_scope("gbdt", "gbdt_optimizer"):
        predictions_dict = gbdt_model.predict(mode)
        logits = predictions_dict["predictions"]
        if logits_modifier_function:
            logits = logits_modifier_function(logits, features, mode)

        def _train_op_fn(loss):
            """Returns the op to optimize the loss."""
            update_op = gbdt_model.train(loss, predictions_dict, labels)
            with ops.control_dependencies(
                [update_op]), (ops.colocate_with(global_step)):
                update_op = state_ops.assign_add(global_step, 1).op
                return update_op

    create_estimator_spec_op = getattr(head, "create_estimator_spec", None)

    training_hooks = []
    if num_trees:
        if center_bias:
            num_trees += 1

        finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor(
        )
        training_hooks.append(
            trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                          finalized_trees,
                                          override_global_step_value))

    if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
        if use_core_libs and callable(create_estimator_spec_op):
            model_fn_ops = head.create_estimator_spec(features=features,
                                                      mode=mode,
                                                      labels=labels,
                                                      train_op_fn=_train_op_fn,
                                                      logits=logits)
            model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
                model_fn_ops)
        else:
            model_fn_ops = head.create_model_fn_ops(features=features,
                                                    mode=mode,
                                                    labels=labels,
                                                    train_op_fn=_train_op_fn,
                                                    logits=logits)

        if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
            model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
                gbdt_batch.LEAF_INDEX]

        model_fn_ops.training_hooks.extend(training_hooks)
        return model_fn_ops
    elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
        assert callable(create_estimator_spec_op)
        estimator_spec = head.create_estimator_spec(features=features,
                                                    mode=mode,
                                                    labels=labels,
                                                    train_op_fn=_train_op_fn,
                                                    logits=logits)

        if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
            estimator_spec.predictions[
                gbdt_batch.LEAF_INDEX] = predictions_dict[
                    gbdt_batch.LEAF_INDEX]

        estimator_spec = estimator_spec._replace(
            training_hooks=training_hooks +
            list(estimator_spec.training_hooks))
        return estimator_spec

    return model_fn_ops
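
A minimal usage sketch of the `params` contract read above. Everything below is
illustrative: `my_head`, `my_learner_config` and `my_feature_columns` are assumed
to be built elsewhere (e.g. by the estimator wrapper), and the numeric values are
placeholders.

# Hedged sketch only; names prefixed with `my_` are assumed/hypothetical.
params = {
    "head": my_head,                        # a `Head` instance (assumed)
    "learner_config": my_learner_config,    # learner config proto (assumed)
    "feature_columns": my_feature_columns,  # feature columns (assumed)
    "examples_per_layer": 1000,
    "weight_column_name": None,
    "num_trees": 100,
    "use_core_libs": False,
    "logits_modifier_function": None,
    "output_leaf_index": False,
    "num_quantiles": 100,
    "center_bias": True,
    "override_global_step_value": None,
}

# `features`, `labels`, `mode` and `config` are normally supplied by the
# Estimator framework when it calls the model function.
spec = model_builder(features, labels, mode, params, config,
                     output_type=ModelBuilderOutputType.ESTIMATOR_SPEC)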
Example #3
    @function.defun
    def train():
      v = resource_variable_ops.ResourceVariable(1.0)
      grad = backprop.implicit_grad(loss)(v)
      optimizer.apply_gradients(grad)
      return v.read_value()

    value = train()
    self.assertEqual(value.numpy(), -1.0)

  def testOptimizerInDefunWithCapturedVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)
    def loss():
      return v**2

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

    @function.defun
    def train():
      grad = backprop.implicit_grad(loss)()
      optimizer.apply_gradients(grad)

    train()
    self.assertEqual(v.numpy(), -1.0)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(device_count={'CPU': 3}))
  test.main()
    def _benchmark_series(self, label, series, benchmark_id):
        """Runs benchmark the given series."""

        # Decides a proper number of iterations according to the inputs.
        def compute_num_iters(map_num_calls, inter_op, element_size,
                              batch_size):
            return 1024 // ((element_size * batch_size) // min(
                12 if map_num_calls == dataset_ops.AUTOTUNE else map_num_calls,
                inter_op))

        # Makes the dataset based on the inputs.
        def make_dataset(map_num_calls, element_size, batch_size,
                         batch_num_calls, apply_fusion):
            k = 1024 * 1024
            x = constant_op.constant(np.random.rand(element_size, 4 * k))
            y = constant_op.constant(np.random.rand(4 * k, 1))
            dataset = dataset_ops.Dataset.range(1000000000000).map(lambda _:
                                                                   (x, y))
            dataset = dataset.map(math_ops.matmul,
                                  num_parallel_calls=map_num_calls)
            dataset = dataset.batch(batch_size=batch_size,
                                    num_parallel_calls=batch_num_calls)
            options = options_lib.Options()
            options.experimental_optimization.apply_default_optimizations = False
            options.experimental_optimization.map_and_batch_fusion = apply_fusion
            dataset = dataset.with_options(options)
            return dataset

        # Makes the name of the dataset based on the inputs.
        def make_name(label, map_num_calls, inter_op, element_size, batch_size,
                      batch_num_calls, apply_fusion):
            map_num_calls_str = ("autotuned"
                                 if map_num_calls == dataset_ops.AUTOTUNE else
                                 str(map_num_calls))
            batch_num_calls_str = (
                "autotuned" if batch_num_calls == dataset_ops.AUTOTUNE else
                str(1 if batch_num_calls is None else batch_num_calls))
            name_str = (
                "%s_id_%s_map_num_calls_%s_batch_num_calls_%s_inter_op_%d"
                "_elem_size_%d_batch_size_%d")
            name = (name_str % (
                "fused" if apply_fusion else "chained",
                hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
                map_num_calls_str,
                batch_num_calls_str,
                inter_op,
                element_size,
                batch_size,
            ))
            return name

        for (map_num_calls, inter_op, element_size, batch_size,
             batch_num_calls, apply_fusion) in series:
            num_iters = compute_num_iters(map_num_calls, inter_op,
                                          element_size, batch_size)
            dataset = make_dataset(map_num_calls, element_size, batch_size,
                                   batch_num_calls, apply_fusion)
            name = make_name(label, map_num_calls, inter_op, element_size,
                             batch_size, batch_num_calls, apply_fusion)

            session_config = config_pb2.ConfigProto(
                inter_op_parallelism_threads=inter_op,
                use_per_session_threads=True)

            self.run_and_report_benchmark(
                dataset=dataset,
                iters=num_iters,
                num_elements=batch_size,
                warmup=True,
                extras={
                    "model_name":
                    "map_and_batch.benchmark.%d" % benchmark_id,
                    "parameters":
                    "%d.%d.%d.%d.%d.%s" %
                    (map_num_calls, inter_op, element_size, batch_size,
                     batch_num_calls, apply_fusion),
                },
                session_config=session_config,
                name=name)
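
For context, a hypothetical driver for `_benchmark_series` might look like the
sketch below; the tuples are illustrative only and follow the
(map_num_calls, inter_op, element_size, batch_size, batch_num_calls,
apply_fusion) order unpacked above.

    def benchmark_map_and_batch_fusion(self):
        # Illustrative series only (a sketch, not the benchmark's real data).
        series = [
            (16, 16, 1, 16, None, False),                          # chained
            (16, 16, 1, 16, None, True),                           # fused
            (dataset_ops.AUTOTUNE, 16, 1, 16, dataset_ops.AUTOTUNE, True),
        ]
        self._benchmark_series("map_and_batch_fusion", series, benchmark_id=1)
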
class MirroredVariableUpdateTest(test.TestCase):
    # The following tests check assign, assign_add and assign_sub on Mirrored
    # variables in tower and cross tower context.
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = True

    def _skip_eager_if_gpus_less_than(self, num_gpus):
        if context.num_gpus() < num_gpus and context.executing_eagerly():
            self.skipTest(
                "Enough GPUs not available for this test in eager mode.")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignMirroredVarTowerContextWithoutAggregationType(self):
        # Test that we always have an aggregation type set on the mirrored variable
        # if we assign to it in tower mode.
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            v = variable_scope.variable(1.0, name="foo")
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())

            def model_fn():
                return mirrored_var.assign(5.0)

            with self.assertRaisesRegexp(
                    ValueError,
                    "You must specify an aggregation method to update a "
                    "MirroredVariable in Tower Context."):
                self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignMirroredVarTowerContextWithSum(self):
        # Test that we don't reduce a non-per-device value with the "sum"
        # aggregation type.
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            v = variable_scope.variable(
                1.0,
                name="foo",
                aggregation=variable_scope.VariableAggregation.SUM)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())

            def model_fn():
                return mirrored_var.assign(5.0)

            with self.assertRaisesRegexp(
                    ValueError,
                    "A non PerDevice value cannot be reduced with the given "
                    "aggregation."):
                self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignMirroredVarCrossTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(1.0, name="foo")

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(1.0, self.evaluate(mirrored_var))
            mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
            self.assertEquals(6.0, mirrored_var_result)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignMirroredVarTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(
                1.0,
                name="foo",
                aggregation=variable_scope.VariableAggregation.MEAN)

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(1.0, self.evaluate(mirrored_var))

            def model_fn():
                value = math_ops.cast(
                    distribute_lib.get_tower_context().tower_id,
                    mirrored_var.dtype)
                return mirrored_var.assign(value)

            self.evaluate(
                dist.unwrap(
                    dist.call_for_each_tower(model_fn,
                                             run_concurrently=False)))
            self.assertEquals(0.5, self.evaluate(mirrored_var))

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignAddMirroredVarCrossTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(1.0, name="foo")

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(1.0, self.evaluate(mirrored_var))
            mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
            self.assertEquals(7.0, mirrored_var_result)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignAddMirroredVarTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(
                1.0,
                name="foo",
                aggregation=variable_scope.VariableAggregation.MEAN)

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(1.0, self.evaluate(mirrored_var))

            def model_fn():
                value = math_ops.cast(
                    distribute_lib.get_tower_context().tower_id,
                    mirrored_var.dtype)
                return mirrored_var.assign_add(value)

            self.evaluate(
                dist.unwrap(
                    dist.call_for_each_tower(model_fn,
                                             run_concurrently=False)))
            self.assertEquals(1.5, self.evaluate(mirrored_var))

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignSubMirroredVarCrossTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(5.0, name="foo")

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(5.0, self.evaluate(mirrored_var))
            mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
            self.assertEquals(3.0, mirrored_var_result)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testAssignSubMirroredVarTowerContext(self):
        self._skip_eager_if_gpus_less_than(1)

        def var_fn():
            return variable_scope.variable(
                5.0,
                name="foo",
                aggregation=variable_scope.VariableAggregation.MEAN)

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            mirrored_var = dist.call_for_each_tower(var_fn,
                                                    run_concurrently=False)
            self.assertIsInstance(mirrored_var, values.MirroredVariable)
            self.evaluate(variables.global_variables_initializer())
            self.assertEquals(5.0, self.evaluate(mirrored_var))

            def model_fn():
                value = math_ops.cast(
                    distribute_lib.get_tower_context().tower_id,
                    mirrored_var.dtype)
                return mirrored_var.assign_sub(value)

            self.evaluate(
                dist.unwrap(
                    dist.call_for_each_tower(model_fn,
                                             run_concurrently=False)))
            self.assertEquals(4.5, self.evaluate(mirrored_var))
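
As a quick sanity check of the expected values in the tower-context tests above
(a sketch assuming exactly two towers, tower_id 0 on GPU:0 and tower_id 1 on
CPU:0), MEAN aggregation averages the per-tower results:

# Two towers assign tower_id-derived values; MEAN aggregation averages them.
assign_expected = (0.0 + 1.0) / 2                      # 0.5 (assign)
assign_add_expected = ((1.0 + 0.0) + (1.0 + 1.0)) / 2  # 1.5 (assign_add)
assign_sub_expected = ((5.0 - 0.0) + (5.0 - 1.0)) / 2  # 4.5 (assign_sub)
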
def _test_session(self, target):
    config = config_pb2.ConfigProto(allow_soft_placement=True)
    config.graph_options.optimizer_options.opt_level = -1
    with session.Session(graph=None, config=config, target=target) as sess:
        yield sess
def optimize_graph(graph,
                   signature_def,
                   output_graph,
                   tf_version,
                   quantization_dtype=None,
                   skip_op_check=False,
                   strip_debug_ops=False):
    """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: The frozen graph to optimize.
    signature_def: the SignatureDef of the inference graph.
    output_graph: The location of the output graph.
    tf_version: Tensorflow version of the input graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """
    fuse_prelu.register_prelu_func(graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    for _, output in signature_def.outputs.items():
        name = output.name.split(':')[0]
        graph.add_to_collection('train_op', graph.get_operation_by_name(name))

    graph_def = graph.as_graph_def()

    unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
    if unsupported:
        raise ValueError('Unsupported Ops in the model before optimization\n' +
                         ', '.join(unsupported))

    # First pass of grappler optimization; this is needed for batch norm folding.
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.optimizers[:] = [
        'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
        'constfold', 'arithmetic', 'dependency'
    ]
    if strip_debug_ops:
        rewriter_config.optimizers.insert(0, 'debug_stripper')

    optimized_graph = _run_grappler(config, graph_def, graph, signature_def)

    # batch norm folding
    optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph)

    # Set the device to CPU for all Conv2D nodes, since the grappler remap
    # optimizer only supports FusedConv2D on CPU.
    for node in optimized_graph.node:
        if node.op == 'Conv2D':
            node.device = '/device:CPU:0'

    # rerun grappler to fuse conv2d
    config.graph_options.rewrite_options.optimizers[:] = [
        'remap', 'constfold', 'arithmetic', 'dependency'
    ]

    optimized_graph = _run_grappler(config, optimized_graph, graph,
                                    signature_def)
    optimized_graph = _remove_unused_control_flow_inputs(optimized_graph)

    # Because TF breaks the Prelu op into 6 ops, for performance we fuse those
    # ops back into a single Prelu op.
    optimized_graph = fuse_prelu.fuse_ops_for_prelu(optimized_graph)

    # Since the grappler remap optimizer does not support Prelu as the activation
    # function for the _FusedConv2D op, we do this fusion manually here.
    optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d(optimized_graph)

    unsupported = validate(optimized_graph.node, skip_op_check,
                           strip_debug_ops)
    if unsupported:
        raise ValueError('Unsupported Ops in the model after optimization\n' +
                         ', '.join(unsupported))

    extract_weights(optimized_graph, output_graph, tf_version, signature_def,
                    quantization_dtype)
    return optimized_graph
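
A hedged usage sketch, assuming the frozen `tf.Graph` and its `SignatureDef`
have already been produced by the surrounding converter code; the variable
names and the output path below are illustrative.

import numpy as np

# Sketch only: `frozen_graph` and `signature_def` are assumed to come from the
# converter's loading/freezing step; the output path is a placeholder.
optimize_graph(frozen_graph,
               signature_def,
               output_graph='/tmp/web_model/model.json',
               tf_version='2.1.0',
               quantization_dtype=np.uint8,  # optional weight quantization
               skip_op_check=False,
               strip_debug_ops=True)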
Example #8
def _no_rewrite_session_config(self):
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)
def _GetConfigProto(self):
    """Get ConfigProto for session creation."""
    config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
        allow_growth=True))
    return config
Example #10
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
    """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their training.
  The cluster consisting of standard servers needs to be brought up either with
  the standard server binary or with a binary running distribute coordinator
  with `task_type` set to non-client type which will then turn into standard
  servers.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are multiple
  ways to configure a distributed TensorFlow cluster, its context object
  provides these configurations so that users or higher-level APIs don't have to
  figure out the configuration for each job by themselves.

  In the between-graph replicated training, this coordinator will create
  multiple threads, each calling the `worker_fn`, which is supposed to create
  its own graph and connect to one worker master given by its context object. In
  the in-graph replicated training, it has only one thread calling this
  `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode where each server runs a
  distribute coordinator which will start a standard server and optionally runs
  `worker_fn` depending on whether it is between-graph training or in-graph
  replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented methods needed by distributed coordinator such as
  `configure(session_config, cluster_spec, task_type, task_id)` which configures
  the strategy object for a specific task and `experimental_should_init`
  property which instructs the distribute coordinator whether to run init ops
  for a task. The distribute coordinator will make a copy of the `strategy`
  object, call its `configure` method and pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be called
  in a thread and possibly multiple times, the caller should be careful when it
  accesses global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for the between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not having
  to worry about the replication and device assignment of variables and
  operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap everything
  in the program after global data definitions (such as command-line flag
  definitions) into the `worker_fn` and get task-specific configurations from
  the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
  `eval_fn` is not defined, fall back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
      in but a "evaluator" task is found in the `cluster_spec`, the `worker_fn`
      will be used for this task.
    eval_strategy: optional DistributionStrategy object for "evaluator" task.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.compat.v1.ConfigProto` object which will be
      passed to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
      a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    if not cluster_spec:
        cluster_spec = tf_config.get("cluster", {})
        task_env = tf_config.get("task", {})
        if task_env:
            task_type = task_env.get("type", task_type)
            task_id = int(task_env.get("index", task_id))

    if cluster_spec:
        cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
        # TODO(yuefengz): validate cluster_spec.

    rpc_layer = tf_config.get("rpc_layer", rpc_layer)
    environment = tf_config.get("environment", None)

    # Setting the session config is necessary for some strategies such as
    # CollectiveAllReduceStrategy.
    session_config = session_config or config_pb2.ConfigProto(
        allow_soft_placement=True)

    if cluster_spec:
        logging.info(
            "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
            "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r",
            mode, cluster_spec.as_dict(), task_type, task_id, environment,
            rpc_layer)

    if not cluster_spec:
        # `mode` is ignored in the local case.
        logging.info("Running local Distribute Coordinator.")
        _run_single_worker(worker_fn, strategy, None, None, None,
                           session_config, rpc_layer)
        if eval_fn:
            _run_single_worker(eval_fn, eval_strategy, None, None, None,
                               session_config, rpc_layer)
        else:
            logging.warning(
                "Skipped evaluation since `eval_fn` is not passed in.")
    elif mode == CoordinatorMode.STANDALONE_CLIENT:
        if not eval_fn:
            logging.warning(
                "`eval_fn` is not passed in. The `worker_fn` will be "
                "used if an \"evaluator\" task exists in the cluster.")
        eval_fn = eval_fn or worker_fn
        if not eval_strategy:
            logging.warning(
                "`eval_strategy` is not passed in. No distribution "
                "strategy will be used for evaluation.")

        # The client must know the cluster but servers in the cluster don't have to
        # know the client.
        if task_type in [_TaskType.CLIENT, None]:
            if strategy.extended.experimental_between_graph:
                return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                                 eval_strategy, cluster_spec,
                                                 session_config, rpc_layer)
            else:
                return _run_in_graph_client(worker_fn, strategy, eval_fn,
                                            eval_strategy, cluster_spec,
                                            session_config, rpc_layer)
        else:
            # If not a client job, run the standard server.
            _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                      session_config,
                                                      cluster_spec, task_type,
                                                      task_id)
            server = _run_std_server(cluster_spec=cluster_spec,
                                     task_type=task_type,
                                     task_id=task_id,
                                     session_config=session_config,
                                     rpc_layer=rpc_layer,
                                     environment=environment)
            server.join()
    else:
        if mode != CoordinatorMode.INDEPENDENT_WORKER:
            raise ValueError("Unexpected coordinator mode: %r" % mode)

        if not eval_fn:
            logging.warning(
                "`eval_fn` is not passed in. The `worker_fn` will be "
                "used if an \"evaluator\" task exists in the cluster.")
        eval_fn = eval_fn or worker_fn
        if not eval_strategy:
            logging.warning(
                "`eval_strategy` is not passed in. No distribution "
                "strategy will be used for evaluation.")

        # Every one starts a standard server, get session config from `configure`
        # method.
        _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                  session_config, cluster_spec,
                                                  task_type, task_id)

        if not getattr(strategy.extended, "_std_server_started", False):
            # Right now, with eager mode, context is configured with a std server at
            # the very beginning while with graph mode the std server is started when
            # distribute coordinator is called. We should consolidate these two paths.
            server = _run_std_server(cluster_spec=cluster_spec,
                                     task_type=task_type,
                                     task_id=task_id,
                                     session_config=session_config,
                                     rpc_layer=rpc_layer,
                                     environment=environment)
        if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
            if strategy.extended.experimental_between_graph:
                # All jobs run `worker_fn` if between-graph.
                return _run_single_worker(worker_fn, strategy, cluster_spec,
                                          task_type, task_id, session_config,
                                          rpc_layer)
            else:
                # Only one node runs `worker_fn` if in-graph.
                context = _WorkerContext(strategy, cluster_spec, task_type,
                                         task_id)
                if context.is_chief:
                    return _run_single_worker(worker_fn, strategy,
                                              cluster_spec, None, None,
                                              session_config, rpc_layer)
                else:
                    server.join()
        elif task_type == _TaskType.EVALUATOR:
            return _run_single_worker(eval_fn, eval_strategy, cluster_spec,
                                      task_type, task_id, session_config,
                                      rpc_layer)
        else:
            if task_type != _TaskType.PS:
                raise ValueError("Unexpected task_type: %r" % task_type)
            server.join()
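
A minimal sketch of the standalone-client flow described in the docstring
above. The cluster mirrors the docstring's TF_CONFIG example; `my_strategy` is
an assumed, already-constructed DistributionStrategy and the body of
`my_worker_fn` is a placeholder.

import json
import os

cluster = {'chief': ['host0:2222'],
           'ps': ['host1:2222', 'host2:2222'],
           'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})


def my_worker_fn(strategy):
    # Training logic goes here; task-specific settings (master target,
    # task_type, task_id, ...) are available from the worker context as
    # described in the docstring above.
    pass


# Sketch only: `my_strategy` is an assumed, configured DistributionStrategy.
run_distribute_coordinator(my_worker_fn,
                           my_strategy,
                           mode=CoordinatorMode.STANDALONE_CLIENT)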
Example #11
def testContextConfig(self):
  if not context.context().num_gpus():
    self.skipTest('No GPUs found')
  ctx = context.Context(config=config_pb2.ConfigProto(
      device_count={'GPU': 0}))
  self.assertEquals(0, ctx.num_gpus())
Example #12
def run_standard_tensorflow_server(session_config=None):
    """Starts a standard TensorFlow server.

  This method parses configurations from "TF_CONFIG" environment variable and
  starts a TensorFlow server. The "TF_CONFIG" is typically a json string and
  must have information of the cluster and the role of the server in the
  cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies there are 3 workers and 2 ps tasks in the cluster
  and the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can have
  at most one "chief" and at most one "evaluator".

  An optional key that can be specified is "rpc_layer". The default value is
  "grpc".

  Args:
    session_config: an optional `tf.compat.v1.ConfigProto` object. Users can
      pass in the session config object to configure server-local devices.

  Returns:
    a `tf.distribute.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment is not complete.
  """
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    if "cluster" not in tf_config:
        raise ValueError("\"cluster\" is not found in TF_CONFIG.")
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        tf_config["cluster"])
    if "task" not in tf_config:
        raise ValueError("\"task\" is not found in TF_CONFIG.")
    task_env = tf_config["task"]
    if "type" not in task_env:
        raise ValueError(
            "\"task_type\" is not found in the `task` part of TF_CONFIG.")
    task_type = task_env["type"]
    task_id = int(task_env.get("index", 0))

    rpc_layer = tf_config.get("rpc_layer", "grpc")

    session_config = session_config or config_pb2.ConfigProto()
    # Set the collective group leader for collective ops to initialize collective
    # ops when server starts.
    if "chief" in cluster_spec.jobs:
        session_config.experimental.collective_group_leader = (
            "/job:chief/replica:0/task:0")
    else:
        if "worker" not in cluster_spec.jobs:
            raise ValueError(
                "You must have `chief` or `worker` jobs in the `cluster_spec`."
            )
        session_config.experimental.collective_group_leader = (
            "/job:worker/replica:0/task:0")

    server = _run_std_server(cluster_spec=cluster_spec,
                             task_type=task_type,
                             task_id=task_id,
                             session_config=session_config,
                             rpc_layer=rpc_layer)
    server.start()
    return server
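
A hedged sketch of launching one worker with the function above; the TF_CONFIG
value mirrors the docstring's example and the host addresses are illustrative.

import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:2222", "host2:2222", "host3:2222"],
        "ps": ["host4:2222", "host5:2222"],
    },
    "task": {"type": "worker", "index": 1},
})

server = run_standard_tensorflow_server()  # Already started on return.
server.join()  # Block until the server terminates.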
Example #13
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

    # Deallocate all TPU buffers by clearing out eager context caches and
    # triggering garbage collection, to avoid keeping invalid TPU buffers around
    # after the TPU system is reinitialized.
    logging.info("Deallocate tpu buffers before initializing tpu system.")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
    gc.collect()

    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the init
        # ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        logging.warning(
            "TPU system %s has already been initialized. "
            "Reinitializing the TPU can cause previously created "
            "variables on TPU to be lost.", tpu_name)

    logging.info("Initializing the TPU system: %s", tpu_name)

    # This function is written this way for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
        # Explicitly place the tpu.initialize_system in the first worker to
        # avoid the 'output node matches multiple devices' error.
        job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    if context.executing_eagerly():

        @function.defun
        def _tpu_init_fn():
            # In TF1, we usually close chips when compilation fails to clear the data
            # in infeed. In TF2, we don't need to do this because infeed is no longer
            # used, so the user can recover from TPU compilation failures more smoothly.
            # Same for the cancellation of a TPU execution.
            return tpu.initialize_system(
                job=job,
                compilation_failure_closes_chips=False,
                tpu_cancellation_closes_chips=False)

        # The TPU_SYSTEM device must match the device used in tpu.initialize_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        try:
            with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
                output = _tpu_init_fn()
            context.async_wait()
        except errors.InvalidArgumentError as e:
            raise errors.NotFoundError(
                None, None,
                "TPUs not found in the cluster. Failed in initialization: " +
                str(e))

        # Clear out the eager context caches since the memory is invalid now.
        context.context()._initialize_logical_devices()  # pylint: disable=protected-access

        serialized_topology = output.numpy()
    elif not ops.executing_eagerly_outside_functions():
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())
    else:
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            serialized_topology = tpu.initialize_system(
                job=job, compilation_failure_closes_chips=False)
            # If initialize_tpu_system is called inside tf.function, we only return
            # the serialized topology object as the tf.tpu.Topology object has to be
            # constructed in eager mode.
            return serialized_topology

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    cluster_resolver.set_tpu_topology(serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
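
A hedged eager-mode usage sketch; the TPU name is illustrative, and
`num_tasks`/`num_tpus_per_task` are standard `tf.tpu.experimental.Topology`
properties.

# Sketch only: 'my-tpu' is an illustrative Cloud TPU name.
resolver = TPUClusterResolver(tpu='my-tpu')
topology = initialize_tpu_system(resolver)
logging.info("TPU system has %d hosts, %d TPU devices per host.",
             topology.num_tasks, topology.num_tpus_per_task)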
Example #14
def shutdown_tpu_system(cluster_resolver=None):
    """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
        tf.function.
  """
    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the init
        # ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
        logging.warning(
            "You are shutting down a TPU system %s that has not been "
            "initialized." % tpu_name)

    logging.info("Shutting down the TPU system: %s", tpu_name)

    if context.executing_eagerly():
        # This function is written this way for the following non-intuitive reasons.
        # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger
        # DistributedTPURewritePass. This pass actually adds real ops that
        # shutdown the TPU system. Thus, we can't simply run tpu.shutdown_system
        # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
        if tpu_name not in _LOCAL_MASTERS:
            # Explicitly place the tpu.shutdown_system in the first worker to
            # avoid the 'output node matches multiple devices' error.
            job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

        @function.defun
        def _tpu_shutdown_fn():
            tpu.shutdown_system(job=job)

        # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            _tpu_shutdown_fn()

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access
        context.context().clear_kernel_cache()
    elif not ops.executing_eagerly_outside_functions():
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                sess.run(tpu.shutdown_system())
    else:
        raise RuntimeError(
            "initialize_tpu_system is not supported within "
            "tf.functions.  You should call initialize_tpu_system outside of your tf.function. "
        )

    logging.info("Finished shutting down TPU system.")
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        del _INITIALIZED_TPU_SYSTEMS[tpu_name]
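
And the matching teardown, typically paired with the initializer above (again a
hedged sketch; the TPU name is illustrative).

# Sketch only: reuse the resolver that was passed to initialize_tpu_system.
resolver = TPUClusterResolver(tpu='my-tpu')
initialize_tpu_system(resolver)
# ... run TPU work ...
shutdown_tpu_system(resolver)  # Also clears the compilation cache.
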
from __future__ import division
from __future__ import print_function

from tensorflow_recommenders_addons.dynamic_embedding.python.ops import math_ops as de_math

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

default_config = config_pb2.ConfigProto(
    allow_soft_placement=False,
    gpu_options=config_pb2.GPUOptions(allow_growth=True))

use_gpu = test_util.is_gpu_available()


class SparseSegmentReductionOpsTest(object):
    def forward_compute(self, data, indices, segment_ids, num_segments=None):
        result = de_math.sparse_segment_sum(data,
                                            indices,
                                            segment_ids,
                                            num_segments=num_segments)
        expected = math_ops.sparse_segment_sum(data,
                                               indices,
                                               segment_ids,
                                               num_segments=num_segments)
Example #16
        def benchmark(series):

            for num_calls, inter_op, element_size, batch_size, num_steps in series:
                dataset = dataset_ops.Dataset.from_tensors(
                    np.random.randint(100, size=element_size)).repeat().map(
                        lambda x: x, num_parallel_calls=num_calls).batch(
                            batch_size=batch_size)
                iterator = dataset.make_one_shot_iterator()
                get_next = iterator.get_next()

                fused_dataset = dataset_ops.Dataset.from_tensors(
                    np.random.randint(100,
                                      size=element_size)).repeat(None).apply(
                                          batching.map_and_batch(
                                              lambda x: x,
                                              num_parallel_calls=num_calls,
                                              batch_size=batch_size))
                fused_iterator = fused_dataset.make_one_shot_iterator()
                fused_get_next = fused_iterator.get_next()

                fused_deltas = []
                with session.Session(config=config_pb2.ConfigProto(
                        inter_op_parallelism_threads=inter_op)) as sess:

                    for _ in range(5):
                        sess.run(fused_get_next)
                    for _ in range(num_iters):
                        start = time.time()
                        for _ in range(num_steps):
                            sess.run(fused_get_next)
                        end = time.time()
                        fused_deltas.append(end - start)

                chained_deltas = []
                with session.Session(config=config_pb2.ConfigProto(
                        inter_op_parallelism_threads=inter_op)) as sess:
                    for _ in range(5):
                        sess.run(get_next)
                    for _ in range(num_iters):
                        start = time.time()
                        for _ in range(num_steps):
                            sess.run(get_next)
                        end = time.time()
                        chained_deltas.append(end - start)

                chained_wall_time = np.median(chained_deltas) / num_iters
                fused_wall_time = np.median(fused_deltas) / num_iters
                print(
                    "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
                    "element size: %d, chained wall time: %f, fused wall time: %f"
                    % (batch_size, num_calls, inter_op, element_size,
                       chained_wall_time, fused_wall_time))

                self.report_benchmark(
                    iters=num_iters,
                    wall_time=chained_wall_time,
                    name=
                    "chained_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
                    % (batch_size, num_calls, inter_op, element_size))

                self.report_benchmark(
                    iters=num_iters,
                    wall_time=fused_wall_time,
                    name=
                    "fused_batch_size_%d_num_calls_%d_inter_op_%d_elem_size_%d"
                    % (batch_size, num_calls, inter_op, element_size))
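
The inner `benchmark` helper above closes over `num_iters` from its enclosing
scope; a hypothetical enclosing driver might look like the sketch below, with
tuples in the (num_calls, inter_op, element_size, batch_size, num_steps) order
unpacked above (all values illustrative).

        num_iters = 5  # Captured by the closure above; illustrative value.
        benchmark([
            (1, 4, 100, 100, 100),
            (4, 8, 100, 100, 100),
        ])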
Example #17
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import gradient_descent

# Global config for the grappler settings used for graph-mode tests.
_rewrites = rewriter_config_pb2.RewriterConfig()
_rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
_customer_optimizer = _rewrites.custom_optimizers.add()
_customer_optimizer.name = 'ExperimentalImplementationSelector'
_rewrites.min_graph_nodes = -1
_graph_options = config_pb2.GraphOptions(rewrite_options=_rewrites)
_config = config_pb2.ConfigProto(graph_options=_graph_options)


@keras_parameterized.run_all_keras_modes(config=_config)
class UnifiedLSTMTest(keras_parameterized.TestCase):
    @parameterized.named_parameters(
        ('non_tan_activation', 'relu', 'sigmoid', 0, False, True),
        ('non_sigmoid_recur_activation', 'tanh', 'relu', 0, False, True),
        ('use_recurrent_dropout', 'tanh', 'sigmoid', 0.1, False, True),
        ('unroll', 'tanh', 'sigmoid', 0, True, True),
        ('not_use_bias', 'tanh', 'sigmoid', 0, False, False),
    )
    def test_could_use_defun_backend(self, activation, recurrent_activation,
                                     recurrent_dropout, unroll, use_bias):
        layer = keras.layers.UnifiedLSTM(
            1,
Example #18
    def __init__(self,
                 master=None,
                 num_cores=0,
                 log_device_placement=False,
                 gpu_memory_fraction=1,
                 tf_random_seed=None,
                 save_summary_steps=100,
                 save_checkpoints_secs=_USE_DEFAULT,
                 save_checkpoints_steps=None,
                 keep_checkpoint_max=5,
                 keep_checkpoint_every_n_hours=10000,
                 log_step_count_steps=100,
                 evaluation_master='',
                 model_dir=None,
                 session_config=None):
        """Constructor.

    The superclass `ClusterConfig` may set properties like `cluster_spec`,
    `is_chief`, `master` (if `None` in the args), `num_ps_replicas`, `task_id`,
    and `task_type` based on the `TF_CONFIG` environment variable. See
    `ClusterConfig` for more details.

    N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,
    `keep_checkpoint_max` might need to be adjusted accordingly, especially in
    distributed training. For example, setting `save_checkpoints_secs` to 60
    without adjusting `keep_checkpoint_max` (which defaults to 5) leads to a
    situation where checkpoints are garbage collected after 5 minutes. In
    distributed training, the evaluation job starts asynchronously and might
    fail to load or find the checkpoint due to a race condition.

    Args:
      master: TensorFlow master. Defaults to empty string for local.
      num_cores: Number of cores to be used. If 0, the system picks an
        appropriate number (default: 0).
      log_device_placement: Log the op placement to devices (default: False).
      gpu_memory_fraction: Fraction of GPU memory used by the process on
        each GPU uniformly on the same machine.
      tf_random_seed: Random seed for TensorFlow initializers.
        Setting this value allows consistency between reruns.
      save_summary_steps: Save summaries every this many steps.
      save_checkpoints_secs: Save checkpoints every this many seconds. Can not
          be specified with `save_checkpoints_steps`.
      save_checkpoints_steps: Save checkpoints every this many steps. Can not be
          specified with `save_checkpoints_secs`.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep. As new files are created, older files are deleted. If None or 0,
        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
        checkpoint files are kept.)
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
      log_step_count_steps: The frequency, in number of global steps, that the
        global step/sec will be logged during training.
      evaluation_master: the master on which to perform evaluation.
      model_dir: directory where model parameters, graph etc are saved. If
        `None`, will use `model_dir` property in `TF_CONFIG` environment
        variable. If both are set, must have same value. If both are `None`, see
        `Estimator` about where the model will be saved.
      session_config: a ConfigProto used to set session parameters, or None.
        Note - using this argument, it is easy to provide settings which break
        otherwise perfectly good models. Use with care.
    """
        super(RunConfig, self).__init__(master=master,
                                        evaluation_master=evaluation_master)

        gpu_options = config_pb2.GPUOptions(
            per_process_gpu_memory_fraction=gpu_memory_fraction)
        self._tf_config = config_pb2.ConfigProto(
            log_device_placement=log_device_placement,
            inter_op_parallelism_threads=num_cores,
            intra_op_parallelism_threads=num_cores,
            gpu_options=gpu_options)

        self._tf_random_seed = tf_random_seed
        self._save_summary_steps = save_summary_steps
        self._save_checkpoints_secs = save_checkpoints_secs
        self._log_step_count_steps = log_step_count_steps
        self._session_config = session_config
        if save_checkpoints_secs == RunConfig._USE_DEFAULT:
            if save_checkpoints_steps is None:
                self._save_checkpoints_secs = 600
            else:
                self._save_checkpoints_secs = None
        self._save_checkpoints_steps = save_checkpoints_steps

        # TODO (weiho): Remove these after ModelFn refactoring, when users can id:746 gh:747
        # create Scaffold and Saver in their model_fn to set these.
        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
        self._model_dir = _get_model_dir(model_dir)
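As a usage sketch of the docstring's N.B. above (assuming this contrib-style RunConfig is importable as tf.contrib.learn.RunConfig; the values and model directory are illustrative), more frequent checkpointing is paired with a larger retention window:

import tensorflow as tf

# Checkpoint every minute; raise keep_checkpoint_max so that more than the
# last ~5 minutes of checkpoints survive garbage collection.
run_config = tf.contrib.learn.RunConfig(
    save_checkpoints_secs=60,
    keep_checkpoint_max=30,
    tf_random_seed=42,
    model_dir='/tmp/example_model')  # illustrative path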
Example #19
  def testContextConfig(self):
    ctx = context.Context(config=config_pb2.ConfigProto(
        device_count={'GPU': 0}))
    self.assertEquals(0, ctx.num_gpus())
  def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir):
    """Train or evaluate the model.

    Args:
      is_training: whether to train or evaluate the model. In training mode,
        quantization will be simulated where the quantize_and_dequantize_v2 are
        placed.
      use_trt: if true, use TRT INT8 mode for evaluation, which will perform
        real quantization. Otherwise use native TensorFlow which will perform
        simulated quantization. Ignored if is_training is True.
      batch_size: batch size.
      num_epochs: how many epochs to train. Ignored if is_training is False.
      model_dir: where to save or load checkpoint.

    Returns:
      The Estimator evaluation result.
    """
    # Get dataset
    train_data, test_data = mnist.load_data()

    def _PreprocessFn(x, y):
      x = math_ops.cast(x, dtypes.float32)
      x = array_ops.expand_dims(x, axis=2)
      x = 2.0 * (x / 255.0) - 1.0
      y = math_ops.cast(y, dtypes.int32)
      return x, y

    def _EvalInputFn():
      mnist_x, mnist_y = test_data
      dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y))
      dataset = dataset.apply(
          data.experimental.map_and_batch(
              map_func=_PreprocessFn,
              batch_size=batch_size,
              num_parallel_calls=8))
      dataset = dataset.repeat(count=1)
      iterator = dataset.make_one_shot_iterator()
      features, labels = iterator.get_next()
      return features, labels

    def _TrainInputFn():
      mnist_x, mnist_y = train_data
      dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y))
      dataset = dataset.shuffle(2 * len(mnist_x))
      dataset = dataset.apply(
          data.experimental.map_and_batch(
              map_func=_PreprocessFn,
              batch_size=batch_size,
              num_parallel_calls=8))
      dataset = dataset.repeat(count=num_epochs)
      iterator = dataset.make_one_shot_iterator()
      features, labels = iterator.get_next()
      return features, labels

    def _ModelFn(features, labels, mode):
      if is_training:
        logits_out = self._BuildGraph(features)
      else:
        graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
        logits_out = importer.import_graph_def(
            graph_def,
            input_map={INPUT_NODE_NAME: features},
            return_elements=[OUTPUT_NODE_NAME + ':0'],
            name='')[0]

      loss = losses.sparse_softmax_cross_entropy(
          labels=labels, logits=logits_out)
      summary.scalar('loss', loss)

      classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out')
      accuracy = metrics.accuracy(
          labels=labels, predictions=classes_out, name='acc_op')
      summary.scalar('accuracy', accuracy[1])

      if mode == ModeKeys.EVAL:
        return EstimatorSpec(
            mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
      elif mode == ModeKeys.TRAIN:
        optimizer = AdamOptimizer(learning_rate=1e-2)
        train_op = optimizer.minimize(loss, global_step=get_global_step())
        return EstimatorSpec(mode, loss=loss, train_op=train_op)

    config_proto = config_pb2.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    estimator = Estimator(
        model_fn=_ModelFn,
        model_dir=model_dir if is_training else None,
        config=RunConfig(session_config=config_proto))

    if is_training:
      estimator.train(_TrainInputFn)
    results = estimator.evaluate(_EvalInputFn)
    logging.info('accuracy: %s', str(results['accuracy']))
    return results
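A hypothetical driver for `_Run` (the test name, batch size, and epoch count are illustrative): train once with simulated quantization, then evaluate the same checkpoints natively and with TRT INT8.

  def testQuantizedMnistAccuracy(self):  # hypothetical test name
    model_dir = self.get_temp_dir()
    self._Run(is_training=True, use_trt=False, batch_size=128,
              num_epochs=2, model_dir=model_dir)
    native = self._Run(is_training=False, use_trt=False, batch_size=128,
                       num_epochs=None, model_dir=model_dir)
    with_trt = self._Run(is_training=False, use_trt=True, batch_size=128,
                         num_epochs=None, model_dir=model_dir)
    logging.info('native accuracy: %s, TRT INT8 accuracy: %s',
                 native['accuracy'], with_trt['accuracy'])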
  def __init__(self,
               master=None,
               num_cores=0,
               log_device_placement=False,
               gpu_memory_fraction=1,
               tf_random_seed=None,
               save_summary_steps=100,
               save_checkpoints_secs=_USE_DEFAULT,
               save_checkpoints_steps=None,
               keep_checkpoint_max=5,
               keep_checkpoint_every_n_hours=10000,
               evaluation_master='',
               model_dir=None):
    """Constructor.

    Note that the superclass `ClusterConfig` may set properties like
    `cluster_spec`, `is_chief`, `master` (if `None` in the args),
    `num_ps_replicas`, `task_id`, and `task_type` based on the `TF_CONFIG`
    environment variable. See `ClusterConfig` for more details.

    Args:
      master: TensorFlow master. Defaults to empty string for local.
      num_cores: Number of cores to be used. If 0, the system picks an
        appropriate number (default: 0).
      log_device_placement: Log the op placement to devices (default: False).
      gpu_memory_fraction: Fraction of GPU memory used by the process on
        each GPU uniformly on the same machine.
      tf_random_seed: Random seed for TensorFlow initializers.
        Setting this value allows consistency between reruns.
      save_summary_steps: Save summaries every this many steps.
      save_checkpoints_secs: Save checkpoints every this many seconds. Can not
          be specified with `save_checkpoints_steps`.
      save_checkpoints_steps: Save checkpoints every this many steps. Can not be
          specified with `save_checkpoints_secs`.
      keep_checkpoint_max: The maximum number of recent checkpoint files to
        keep. As new files are created, older files are deleted. If None or 0,
        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
        checkpoint files are kept.)
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
      evaluation_master: the master on which to perform evaluation.
      model_dir: directory where model parameters, graph etc are saved. If
        `None`, see `Estimator` about where the model will be saved.
    """
    super(RunConfig, self).__init__(
        master=master, evaluation_master=evaluation_master)

    gpu_options = config_pb2.GPUOptions(
        per_process_gpu_memory_fraction=gpu_memory_fraction)
    self._tf_config = config_pb2.ConfigProto(
        log_device_placement=log_device_placement,
        inter_op_parallelism_threads=num_cores,
        intra_op_parallelism_threads=num_cores,
        gpu_options=gpu_options)

    self._tf_random_seed = tf_random_seed
    self._save_summary_steps = save_summary_steps
    self._save_checkpoints_secs = save_checkpoints_secs
    if save_checkpoints_secs == RunConfig._USE_DEFAULT:
      if save_checkpoints_steps is None:
        self._save_checkpoints_secs = 600
      else:
        self._save_checkpoints_secs = None
    self._save_checkpoints_steps = save_checkpoints_steps

    # TODO(weiho): Remove these after ModelFn refactoring, when users can
    # create Scaffold and Saver in their model_fn to set these.
    self._keep_checkpoint_max = keep_checkpoint_max
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    self._model_dir = model_dir
    def testClusterSpecPropagationThreeServers2Graphs(self):
        """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    We ensure that variables on the workers are independent.
    """
        server1 = server_lib.Server.create_local_server()
        server2 = server_lib.Server.create_local_server()
        server3 = server_lib.Server.create_local_server()
        cluster_def1 = cluster_pb2.ClusterDef()
        job1 = cluster_def1.job.add()
        job1.name = 'worker1'
        job1.tasks[0] = server2.target[len('grpc://'):]
        job1.tasks[1] = server1.target[len('grpc://'):]

        cluster_def2 = cluster_pb2.ClusterDef()
        job2 = cluster_def2.job.add()
        job2.name = 'worker2'
        job2.tasks[0] = server2.target[len('grpc://'):]
        job2.tasks[1] = server3.target[len('grpc://'):]

        config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
        config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)

        with ops.Graph().as_default() as g1:
            with ops.device('/job:worker1/task:1'):
                var1 = variables.Variable(array_ops.zeros([2]), name='var1')
                update_op1 = state_ops.assign_add(var1,
                                                  array_ops.ones([2]),
                                                  name='var1_assign_add')
                init1 = variables.global_variables_initializer()

        with ops.Graph().as_default() as g2:
            with ops.device('/job:worker2/task:1'):
                var2 = variables.Variable(array_ops.zeros([2]), name='var2')
                update_op2 = state_ops.assign_add(var2,
                                                  array_ops.ones([2]),
                                                  name='var2_assign_add')
                init2 = variables.global_variables_initializer()

        sess1 = session.Session(server2.target, graph=g1, config=config1)
        sess2 = session.Session(server2.target, graph=g2, config=config2)

        init1.run(session=sess1)
        init2.run(session=sess2)

        expected_zeros = np.zeros([2])
        expected_ones = np.ones([2])

        self.assertAllEqual(expected_zeros, sess1.run(var1))
        self.assertAllEqual(expected_zeros, sess2.run(var2))

        self.assertAllEqual(expected_ones, sess1.run(update_op1))
        self.assertAllEqual(expected_ones, sess1.run(var1))
        self.assertAllEqual(expected_zeros, sess2.run(var2))
        self.assertAllEqual(expected_ones, sess2.run(update_op2))
        self.assertAllEqual(expected_ones + expected_ones,
                            sess1.run(update_op1))
        self.assertAllEqual(expected_ones, sess2.run(var2))
        self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
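The test above boils down to the following pattern (a minimal sketch using the same internal modules; the server names are illustrative): a ClusterDef embedded in the session ConfigProto propagates the cluster spec, so a session at one server can place ops on another server's devices.

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.training import server_lib

server_a = server_lib.Server.create_local_server()
server_b = server_lib.Server.create_local_server()

cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server_a.target[len('grpc://'):]
job.tasks[1] = server_b.target[len('grpc://'):]

# A session created against server_a with this config can place ops on
# '/job:worker/task:1', which physically runs on server_b.
config = config_pb2.ConfigProto(cluster_def=cluster_def)
sess = session.Session(server_a.target, config=config)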
class MirroredStrategyVariableCreationTest(test.TestCase):

    config = config_pb2.ConfigProto()
    config.allow_soft_placement = True

    def _skip_eager_if_gpus_less_than(self, num_gpus):
        if context.num_gpus() < num_gpus and context.executing_eagerly():
            self.skipTest(
                "Enough GPUs not available for this test in eager mode.")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testSingleVariable(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            # This variable should be created only once across the threads because of
            # special variable_creator functions used by `dist.call_for_each_tower`.
            v = variable_scope.variable(1.0, name="foo")
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            self.assertIsInstance(result, values.MirroredVariable)
            self.assertEquals("foo:0", result.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testUnnamedVariable(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            v = variable_scope.variable(1.0)
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            self.assertIsInstance(result, values.MirroredVariable)
            # Default name of "Variable" will be used.
            self.assertEquals("Variable:0", result.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testMultipleVariables(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            vs = []
            for i in range(5):
                vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return vs

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            for i, v in enumerate(result):
                self.assertIsInstance(v, values.MirroredVariable)
                self.assertEquals("foo" + str(i) + ":0", v.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testMultipleVariablesWithSameCanonicalName(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            vs = []
            vs.append(variable_scope.variable(1.0, name="foo/bar"))
            vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
            vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
            vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return vs

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            for v in result:
                self.assertIsInstance(v, values.MirroredVariable)
            self.assertEquals(4, len(result))
            self.assertEquals("foo/bar:0", result[0].name)
            self.assertEquals("foo_1/bar:0", result[1].name)
            self.assertEquals("foo_1/bar_1:0", result[2].name)
            self.assertEquals("foo/bar_1:0", result[3].name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testVariableWithSameCanonicalNameAcrossThreads(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn(device_id):
            v = variable_scope.variable(1.0, name="foo_" + str(device_id))
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn,
                                              dist.worker_device_index,
                                              run_concurrently=False)
            self.assertIsInstance(result, values.MirroredVariable)
            # The resulting mirrored variable will use the name from the first device.
            self.assertEquals("foo_0:0", result.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testWithLayers(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn(features):
            with variable_scope.variable_scope("common"):
                layer1 = core.Dense(1)
                layer1(features)
                layer2 = core.Dense(1)
                layer2(features)
                # This will pause the current thread, and execute the other thread.
                distribute_lib.get_tower_context().merge_call(lambda _: _)
                layer3 = core.Dense(1)
                layer3(features)
                return [(layer1.kernel, layer1.bias),
                        (layer2.kernel, layer2.bias),
                        (layer3.kernel, layer3.bias)]

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])
        features = dist.distribute_dataset(
            lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(
                10)).make_one_shot_iterator().get_next()

        with dist.scope():
            result = dist.call_for_each_tower(model_fn,
                                              features,
                                              run_concurrently=False)
            suffixes = ["", "_1", "_2"]
            for (kernel, bias), suffix in zip(result, suffixes):
                self.assertIsInstance(kernel, values.MirroredVariable)
                self.assertEquals("common/dense" + suffix + "/kernel:0",
                                  kernel.name)
                self.assertIsInstance(bias, values.MirroredVariable)
                self.assertEquals("common/dense" + suffix + "/bias:0",
                                  bias.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testWithVariableAndVariableScope(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
            with variable_scope.variable_scope("common"):
                v1 = variable_scope.variable(1.0, name="var1")
                # This will pause the current thread, and execute the other thread.
                distribute_lib.get_tower_context().merge_call(lambda _: _)
                v2 = variable_scope.variable(
                    1.0,
                    name="var2",
                    synchronization=variable_scope.VariableSynchronization.
                    ON_READ,
                    aggregation=variable_scope.VariableAggregation.SUM)
                v3 = variable_scope.variable(
                    1.0,
                    name="var3",
                    synchronization=variable_scope.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variable_scope.VariableAggregation.MEAN)

            return v0, v1, v2, v3

        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            v = variable_scope.variable(1.0, name="var-main0")
            self.assertEquals("var-main0:0", v.name)

            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            self.assertEquals(4, len(result))
            v0, v1, v2, v3 = result
            self.assertIsInstance(v0, values.MirroredVariable)
            self.assertEquals("var0:0", v0.name)
            self.assertIsInstance(v1, values.MirroredVariable)
            self.assertEquals("common/var1:0", v1.name)
            self.assertIsInstance(v2, values.TowerLocalVariable)
            self.assertEquals("common/var2:0", v2.name)
            self.assertEquals(variable_scope.VariableAggregation.SUM,
                              v2.aggregation)
            self.assertIsInstance(v3, values.MirroredVariable)
            self.assertEquals("common/var3:0", v3.name)
            self.assertEquals(variable_scope.VariableAggregation.MEAN,
                              v3.aggregation)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testWithGetVariableAndVariableScope(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn():
            v0 = variable_scope.get_variable("var0", [1])
            with variable_scope.variable_scope("common"):
                v1 = variable_scope.get_variable("var1", [1])
                # This will pause the current thread, and execute the other thread.
                distribute_lib.get_tower_context().merge_call(lambda _: _)
                v2 = variable_scope.get_variable(
                    "var2", [1],
                    synchronization=variable_scope.VariableSynchronization.
                    ON_READ,
                    aggregation=variable_scope.VariableAggregation.SUM)
                v3 = variable_scope.get_variable(
                    "var3", [1],
                    synchronization=variable_scope.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variable_scope.VariableAggregation.MEAN)

            return v0, v1, v2, v3

        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with variable_scope.variable_scope("main"):
                v = variable_scope.get_variable("var-main0", [1])
                self.assertEquals("main/var-main0:0", v.name)

                result = dist.call_for_each_tower(model_fn,
                                                  run_concurrently=False)
                self.assertEquals(4, len(result))
                v0, v1, v2, v3 = result
                self.assertIsInstance(v0, values.MirroredVariable)
                self.assertEquals("main/var0:0", v0.name)
                self.assertIsInstance(v1, values.MirroredVariable)
                self.assertEquals("main/common/var1:0", v1.name)
                self.assertIsInstance(v2, values.TowerLocalVariable)
                self.assertEquals("main/common/var2:0", v2.name)
                self.assertEquals(variable_scope.VariableAggregation.SUM,
                                  v2.aggregation)
                self.assertIsInstance(v3, values.MirroredVariable)
                self.assertEquals("main/common/var3:0", v3.name)
                self.assertEquals(variable_scope.VariableAggregation.MEAN,
                                  v3.aggregation)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testNoneSynchronizationWithGetVariable(self):
        self._skip_eager_if_gpus_less_than(1)
        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with self.assertRaisesRegexp(
                    ValueError, "`NONE` variable synchronization mode is not "
                    "supported with `Mirrored` distribution strategy. Please change "
                    "the `synchronization` for variable: v"):
                variable_scope.get_variable("v", [1],
                                            synchronization=variable_scope.
                                            VariableSynchronization.NONE)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testNoneSynchronizationWithVariable(self):
        self._skip_eager_if_gpus_less_than(1)
        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with self.assertRaisesRegexp(
                    ValueError, "`NONE` variable synchronization mode is not "
                    "supported with `Mirrored` distribution strategy. Please change "
                    "the `synchronization` for variable: v"):
                variable_scope.variable(1.0,
                                        name="v",
                                        synchronization=variable_scope.
                                        VariableSynchronization.NONE)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testInvalidSynchronizationWithVariable(self):
        self._skip_eager_if_gpus_less_than(1)
        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with self.assertRaisesRegexp(
                    ValueError,
                    "Invalid variable synchronization mode: Invalid for "
                    "variable: v"):
                variable_scope.variable(1.0,
                                        name="v",
                                        synchronization="Invalid")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testInvalidAggregationWithGetVariable(self):
        self._skip_eager_if_gpus_less_than(1)
        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with self.assertRaisesRegexp(
                    ValueError,
                    "Invalid variable aggregation mode: invalid for "
                    "variable: v"):
                variable_scope.get_variable("v", [1],
                                            synchronization=variable_scope.
                                            VariableSynchronization.ON_WRITE,
                                            aggregation="invalid")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testInvalidAggregationWithVariable(self):
        self._skip_eager_if_gpus_less_than(1)
        devices = ["/device:CPU:0", "/device:GPU:0"]
        dist = mirrored_strategy.MirroredStrategy(devices)
        with dist.scope():
            with self.assertRaisesRegexp(
                    ValueError,
                    "Invalid variable aggregation mode: invalid for "
                    "variable: v"):
                variable_scope.variable(1.0,
                                        name="v",
                                        synchronization=variable_scope.
                                        VariableSynchronization.ON_WRITE,
                                        aggregation="invalid")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testThreeDevices(self):
        self._skip_eager_if_gpus_less_than(2)

        def model_fn():
            v = variable_scope.variable(1.0, name="foo")
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"])

        with dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            self.assertIsInstance(result, values.MirroredVariable)
            self.assertEquals("foo:0", result.name)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testNonMatchingVariableCreation(self):
        self._skip_eager_if_gpus_less_than(1)

        def model_fn(name):
            v = variable_scope.variable(1.0, name=name)
            distribute_lib.get_tower_context().merge_call(lambda _: _)
            return v

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            names = values.DistributedValues({
                "/device:CPU:0": "foo",
                "/device:GPU:0": "bar"
            })
            with self.assertRaises(RuntimeError):
                _ = dist.call_for_each_tower(model_fn,
                                             names,
                                             run_concurrently=False)

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testTowerLocalVariable(self):
        self._skip_eager_if_gpus_less_than(1)

        all_v_sum = {}
        all_v_mean = {}
        components_sum = {}
        components_mean = {}

        def model_fn(device_id):
            v_sum = variable_scope.variable(
                1.0,
                synchronization=variable_scope.VariableSynchronization.ON_READ,
                aggregation=variable_scope.VariableAggregation.SUM)
            v_mean = variable_scope.variable(
                4.0,
                synchronization=variable_scope.VariableSynchronization.ON_READ,
                aggregation=variable_scope.VariableAggregation.MEAN)
            self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
            self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
            updates = [
                v_sum.assign_add(2.0 + device_id),
                v_mean.assign(6.0 * device_id)
            ]
            all_v_sum[device_id] = v_sum
            all_v_mean[device_id] = v_mean
            c_sum = v_sum.get()
            c_mean = v_mean.get()
            components_sum[device_id] = c_sum
            components_mean[device_id] = c_mean
            self.assertIsNot(v_sum, c_sum)
            self.assertIsNot(v_mean, c_mean)
            return updates, v_sum, v_mean, c_sum, c_mean

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with dist.scope():
            # Create "sum" and "mean" versions of TowerLocalVariables.
            ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
                dist.call_for_each_tower(model_fn,
                                         dist.worker_device_index,
                                         run_concurrently=False))
            # Should see the same wrapping instance in all towers.
            self.assertIs(all_v_sum[0], ret_v_sum)
            self.assertIs(all_v_mean[0], ret_v_mean)
            self.assertIs(all_v_sum[0], all_v_sum[1])
            self.assertIs(all_v_mean[0], all_v_mean[1])

            # Regroup should recover the same wrapper.
            self.assertIs(ret_v_sum, regrouped_sum)
            self.assertIs(ret_v_mean, regrouped_mean)
            self.assertIsNot(components_sum[0], components_sum[1])
            self.assertIsNot(components_mean[0], components_mean[1])

            # Apply updates
            self.evaluate(variables.global_variables_initializer())
            self.evaluate([y for x in ret_ops for y in dist.unwrap(x)])
            expected_sum = 0.0
            expected_mean = 0.0
            for i, d in enumerate(dist.worker_devices):
                # Should see different values on different devices.
                v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
                v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
                expected = i + 3.0
                self.assertEqual(expected, v_sum_value)
                expected_sum += expected
                expected = i * 6.0
                self.assertEqual(expected, v_mean_value)
                expected_mean += expected
            expected_mean /= len(dist.worker_devices)

            # Without get(device), should return the value you get by
            # applying the reduction across all towers (whether you use
            # read_var(), get(), or nothing).
            self.assertEqual(expected_sum,
                             self.evaluate(dist.read_var(ret_v_sum)))
            self.assertEqual(expected_mean,
                             self.evaluate(dist.read_var(ret_v_mean)))
            self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
            self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
            self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
            self.assertEqual(expected_mean, self.evaluate(ret_v_mean))

    # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
    # testing this in eager mode.

    def testNameScope(self):
        def model_fn():
            with ops.name_scope("foo"):
                a = constant_op.constant(1.0, name="a")
                distribute_lib.get_tower_context().merge_call(lambda _: _)
                b = constant_op.constant(1.0, name="b")
            return a, b

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with context.graph_mode(), dist.scope():
            with ops.name_scope("main"):
                result = dist.call_for_each_tower(model_fn,
                                                  run_concurrently=False)
                self.assertEquals(2, len(result))
                for v, name in zip(result, ["a", "b"]):
                    self.assertIsInstance(v, values.DistributedValues)
                    v0, v1 = dist.unwrap(v)
                    self.assertEquals("main/foo/" + name + ":0", v0.name)
                    self.assertEquals("main/tower_1/foo/" + name + ":0",
                                      v1.name)

    def testWithDefaultName(self):
        def model_fn():
            with ops.name_scope(None, "foo"):
                a = constant_op.constant(1.0, name="a")
                distribute_lib.get_tower_context().merge_call(lambda _: _)
                b = constant_op.constant(2.0, name="b")
            return a, b

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with context.graph_mode(), dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            self.assertEquals(2, len(result))
            for v, name in zip(result, ["a", "b"]):
                self.assertIsInstance(v, values.DistributedValues)
                v0, v1 = dist.unwrap(v)
                self.assertEquals("foo/" + name + ":0", v0.name)
                self.assertEquals("tower_1/foo/" + name + ":0", v1.name)

    # variable_scope.variable() respects name scopes when creating
    # variables. On the other hand variable_scope.get_variable() ignores name
    # scopes when creating variables. We test both methods of creating variables
    # to make sure that we have the same variable names in both cases.
    def testNameScopeWithVariable(self):
        def in_cross_tower(_):
            c = variable_scope.variable(1.0, name="c")
            return c

        def model_fn():
            b = variable_scope.variable(1.0, name="b")
            with ops.name_scope("foo"):
                c = distribute_lib.get_tower_context().merge_call(
                    in_cross_tower)
            return b, c

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with context.graph_mode(), dist.scope():
            with ops.name_scope("main"):
                a = variable_scope.variable(1.0, name="a")
                result = dist.call_for_each_tower(model_fn,
                                                  run_concurrently=False)
            result_b = result[0]
            result_c = result[1]
            self.assertIsInstance(result_b, values.DistributedValues)
            self.assertIsInstance(result_c, values.DistributedValues)
            a0, a1 = dist.unwrap(a)
            b0, b1 = dist.unwrap(result_b)
            c0, c1 = dist.unwrap(result_c)
            self.assertEquals("main/a:0", a0.name)
            self.assertEquals("main/a/replica_1:0", a1.name)
            self.assertEquals("main/b:0", b0.name)
            self.assertEquals("main/b/replica_1:0", b1.name)
            self.assertEquals("main/foo/c:0", c0.name)
            self.assertEquals("main/foo/c/replica_1:0", c1.name)

    def testNameScopeWithGetVariable(self):
        def in_cross_tower(_):
            c = variable_scope.get_variable("c", [1])
            return c

        def model_fn():
            b = variable_scope.get_variable("b", [1])
            with ops.name_scope("foo"):
                c = distribute_lib.get_tower_context().merge_call(
                    in_cross_tower)
            return b, c

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with context.graph_mode(), dist.scope():
            with ops.name_scope("main"):
                a = variable_scope.get_variable("a", [1])
                result = dist.call_for_each_tower(model_fn,
                                                  run_concurrently=False)
            result_b = result[0]
            result_c = result[1]
            self.assertIsInstance(result_b, values.DistributedValues)
            self.assertIsInstance(result_c, values.DistributedValues)
            a0, a1 = dist.unwrap(a)
            b0, b1 = dist.unwrap(result_b)
            c0, c1 = dist.unwrap(result_c)
            self.assertEquals("a:0", a0.name)
            self.assertEquals("a/replica_1:0", a1.name)
            self.assertEquals("b:0", b0.name)
            self.assertEquals("b/replica_1:0", b1.name)
            self.assertEquals("c:0", c0.name)
            self.assertEquals("c/replica_1:0", c1.name)

    def testDynamicRnnVariables(self):
        def model_fn():
            inputs = constant_op.constant(2 *
                                          [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
            cell_fw = rnn_cell_impl.LSTMCell(300)
            cell_bw = rnn_cell_impl.LSTMCell(300)
            (outputs, _) = rnn.bidirectional_dynamic_rnn(cell_fw,
                                                         cell_bw,
                                                         inputs,
                                                         dtype=dtypes.float32)
            return outputs

        dist = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])

        with context.graph_mode(), dist.scope():
            result = dist.call_for_each_tower(model_fn, run_concurrently=False)
            # Two variables are created by the RNN layer.
            self.assertEquals(2, len(result))
            for v in result:
                self.assertIsInstance(v, values.DistributedValues)
                _, v1 = dist.unwrap(v)
                self.assertStartsWith(v1.name, "tower_1/")

    @test_util.run_in_graph_and_eager_modes(config=config)
    def testTowerLocalVariableUpdate(self):
        with context.graph_mode():

            def model_fn():
                v_sum = variable_scope.variable(
                    1.0,
                    synchronization=variable_scope.VariableSynchronization.
                    ON_READ,
                    aggregation=variable_scope.VariableAggregation.SUM)
                self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
                return v_sum

            dist = mirrored_strategy.MirroredStrategy(
                ["/device:GPU:0", "/device:GPU:1"])

            def update(var, value):
                return var.assign(value)

            with dist.scope():
                ret_v_sum = dist.call_for_each_tower(model_fn,
                                                     run_concurrently=False)
                update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))

                # Initialize variables.
                self.evaluate(variables.global_variables_initializer())
                # Assert that the aggregated value of the tower local vars is the sum of
                # the individual values before running the update ops.
                self.assertEquals(
                    1.0,
                    self.evaluate(
                        ret_v_sum.get(dist._devices[0]).read_value()))
                self.assertEquals(2.0, self.evaluate(ret_v_sum))

                # Apply updates.
                self.evaluate(update_ops)
                # Assert that the aggregated value of the tower local vars is the sum of
                # the individual values after running the update ops.
                self.assertEquals(
                    5.0,
                    self.evaluate(
                        ret_v_sum.get(dist._devices[0]).read_value()))
                self.assertEquals(10.0, self.evaluate(ret_v_sum))
    def testClusterSpecPropagationThreeServersOneCluster(self):
        """Boots 3 servers, ensures appropriate communication across workers.

    Additionally, in this cluster, we ensure the master is not the 0-th worker.

    Note: this test only uses one session.
    """
        server1 = server_lib.Server.create_local_server()
        server2 = server_lib.Server.create_local_server()
        server3 = server_lib.Server.create_local_server()
        cluster_def = cluster_pb2.ClusterDef()
        job = cluster_def.job.add()
        job.name = 'worker'
        job.tasks[0] = server3.target[len('grpc://'):]
        job.tasks[1] = server2.target[len('grpc://'):]
        job.tasks[2] = server1.target[len('grpc://'):]
        config = config_pb2.ConfigProto(cluster_def=cluster_def)

        # Add ops to the devices in non-linear order.

        with ops.device('/job:worker/task:1'):
            feed1 = array_ops.placeholder(dtypes.float32, shape=(2))
            const1 = constant_op.constant(2.0)
            mul1 = const1 * feed1

        with ops.device('/job:worker/task:2'):
            feed2 = array_ops.placeholder(dtypes.float32, shape=(2))
            const2 = constant_op.constant(2.0)
            mul2 = const2 * feed2

        with ops.device('/job:worker/task:0'):
            feed0 = array_ops.placeholder(dtypes.float32, shape=(2))
            const0 = constant_op.constant(2.0)
            mul0 = const0 * feed0

        sum_op = mul0 + mul1 + mul2

        ones = np.ones([2])
        run_options = config_pb2.RunOptions(
            trace_level=config_pb2.RunOptions.FULL_TRACE)
        run_metadata = config_pb2.RunMetadata()

        # Run!
        with session.Session(server1.target, config=config) as sess:
            output = sess.run(sum_op,
                              options=run_options,
                              run_metadata=run_metadata,
                              feed_dict={
                                  feed1: ones,
                                  feed2: ones,
                                  feed0: ones
                              })
            self.assertAllEqual(6 * ones, output)

            self.assertEqual(
                3,
                len([
                    dev_stats.device
                    for dev_stats in run_metadata.step_stats.dev_stats
                    for node_stats in dev_stats.node_stats
                    if '/job:worker/replica:0/task:' in dev_stats.device
                    and node_stats.node_name.startswith('Const')
                ]), run_metadata)
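The device-placement assertion above relies on full tracing; here is a minimal standalone sketch of the same pattern (the traced op is illustrative):

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op

run_options = config_pb2.RunOptions(
    trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()

with session.Session() as sess:
    sess.run(constant_op.constant(2.0) * constant_op.constant(3.0),
             options=run_options, run_metadata=run_metadata)

# Every executed node is recorded under the device that ran it.
for dev_stats in run_metadata.step_stats.dev_stats:
    print(dev_stats.device, [ns.node_name for ns in dev_stats.node_stats])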
Example #25
      reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], reduced)

  def testAsFunctionInput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return math_ops.reduce_sum(x, axis=2)

      tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
      reduced = f(tensor)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testAsFunctionOutput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      y = f(3)
      reduced = math_ops.reduce_sum(y, axis=2)
      self.assertAllEqual(100 * [[36.0]], reduced)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
  googletest.main()
Example #26
    def _test_minimize_loss_graph(self,
                                  d,
                                  soft_placement=False,
                                  learning_rate=0.2):
        config = config_pb2.ConfigProto()
        config.allow_soft_placement = soft_placement
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        with context.graph_mode(), \
             ops.Graph().as_default(), \
             self.cached_session(config=config) as sess, \
             d.scope():
            kernel = create_variable_like_keras_layer(name="kernel",
                                                      shape=(1, 1),
                                                      dtype=dtypes.float32)

            def loss(x):
                y = array_ops.reshape(gen_math_ops.mat_mul(x, kernel),
                                      []) - array_ops.identity(1.)
                return y * y

            grad_fn = backprop.implicit_grad(loss)

            def update(v, g):
                return v.assign_sub(learning_rate * g)

            one = array_ops.identity([[1.]])

            def step():
                """Perform one optimization step."""
                # Run forward & backward to get gradients, variables list.
                g_v = d.extended.call_for_each_replica(grad_fn, args=(one, ))

                # Update the variables using the gradients and the update() function.
                before_list = []
                after_list = []
                for g, v in g_v:
                    fetched = d.extended.read_var(v)
                    before_list.append(fetched)
                    with ops.control_dependencies([fetched]):
                        g = d.extended.reduce_to(reduce_util.ReduceOp.SUM,
                                                 g,
                                                 destinations=v)
                        with ops.control_dependencies(
                                d.extended.update(v,
                                                  update,
                                                  args=(g, ),
                                                  group=False)):
                            after_list.append(d.extended.read_var(v))
                return before_list, after_list

            before_out, after_out = step()
            variables.global_variables_initializer().run()
            for i in range(10):
                b, a = sess.run((before_out, after_out))
                if i == 0:
                    before, = b
                after, = a

            error_before = abs(before - 1)
            error_after = abs(after - 1)
            # Error should go down
            self.assertLess(error_after, error_before)
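A hypothetical driver for the helper above (the test name and device list are illustrative): run the graph-mode minimization loop under a two-device mirrored strategy with soft placement enabled.

    def testMinimizeLossGraphMirrored(self):  # hypothetical test name
        distribution = mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0", "/device:CPU:0"])
        self._test_minimize_loss_graph(distribution, soft_placement=True,
                                       learning_rate=0.2)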
Example #27
    def __init__(self, model_fn, model_dir=None, config=None, params=None):
        """Constructs an `Estimator` instance.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same (for multi-head models). If
                 mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
                 the `model_fn`'s signature does not accept `mode`, the
                 `model_fn` must still be able to handle `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
                 prediction. See `ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
                 is passed to Estimator in the `params` parameter. This allows
                 configuring Estimators from hyper parameter tuning.
          * `config`: Optional configuration object. Will receive what is passed
                 to Estimator in `config` parameter, or the default `config`.
                 Allows updating things in your model_fn based on configuration
                 such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `EstimatorSpec`

      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `None`, the model_dir in
        `config` will be used if set. If both are set, they must be the same. If
        both are `None`, a temporary directory will be used.
      config: Configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.

    Raises:
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
        Estimator._assert_members_are_not_overridden(self)

        if config is None:
            self._config = run_config.RunConfig()
            logging.info('Using default config.')
        else:
            if not isinstance(config, run_config.RunConfig):
                raise ValueError(
                    'config must be an instance of RunConfig, but provided %s.'
                    % config)
            self._config = config

        # Model directory.
        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir are set both in constructor and RunConfig, but with "
                    "different values. In constructor: '{}', in RunConfig: "
                    "'{}' ".format(model_dir, self._config.model_dir))
                # pylint: enable=g-doc-exception

        self._model_dir = model_dir or self._config.model_dir
        if self._model_dir is None:
            self._model_dir = tempfile.mkdtemp()
            logging.warning('Using temporary folder as model directory: %s',
                            self._model_dir)
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)
        logging.info('Using config: %s', str(vars(self._config)))

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        self._device_fn = _get_replica_device_setter(self._config)

        if model_fn is None:
            raise ValueError('model_fn must be provided to Estimator.')
        _verify_model_fn_args(model_fn, params)
        self._model_fn = model_fn
        self._params = params or {}
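A minimal sketch of a `model_fn` matching the documented signature and wired into this constructor (public TF 1.x estimator API; the feature key 'x' and the hyperparameters are illustrative):

import tensorflow as tf

def my_model_fn(features, labels, mode, params, config):
    # A single dense layer over the illustrative feature key 'x'.
    logits = tf.layers.dense(features['x'], params['n_classes'])
    if mode == tf.estimator.ModeKeys.PREDICT:
        # labels is None in PREDICT mode, as documented above.
        return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=my_model_fn,
    model_dir=None,                 # falls back to a temporary directory
    params={'n_classes': 10})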
Example #28
    def testClearExtraneousSavers(self):
        export_dir = os.path.join(test.get_temp_dir(),
                                  "test_clear_extraneous_savers")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Create a variable and a Saver.
        with ops.Graph().as_default() as graph:
            with session.Session(target="",
                                 config=config_pb2.ConfigProto(
                                     device_count={"CPU": 2})) as sess:
                self._init_and_validate_variable(sess, "v", 42)

                # Add two Savers, which should be removed in
                # add_meta_graph_and_variables() in favor of the locally added one.
                saver1 = tf_saver.Saver()
                graph.add_to_collection(ops.GraphKeys.SAVERS, saver1)
                saver2 = tf_saver.Saver()
                graph.add_to_collection(ops.GraphKeys.SAVERS, saver2)

                # Confirm there are two SaverDefs.
                savers = graph.get_collection(ops.GraphKeys.SAVERS)
                self.assertEqual(2, len(savers))

                # Confirm there are two Save and two Restore ops.
                save_op_names = set([
                    x.name for x in graph.get_operations()
                    if x.type == "SaveV2"
                ])
                self.assertSetEqual(set(["save/SaveV2", "save_1/SaveV2"]),
                                    save_op_names)

                restore_op_names = set([
                    x.name for x in graph.get_operations()
                    if x.type == "RestoreV2"
                ])
                self.assertSetEqual(
                    set(["save/RestoreV2", "save_1/RestoreV2"]),
                    restore_op_names)

                # The SavedModel builder adds its own Saver, for a total of three.
                builder.add_meta_graph_and_variables(sess,
                                                     [tag_constants.TRAINING],
                                                     clear_devices=True)

        # Save the SavedModel to disk.
        builder.save()

        # Restore the graph.
        with ops.Graph().as_default() as graph:
            with self.test_session(graph=graph) as sess:
                loader.load(sess, [tag_constants.TRAINING], export_dir)
                self.assertEqual(
                    42,
                    ops.get_collection(
                        ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

                # Confirm that the reloaded graph has only one SaverDef.
                savers = ops.get_collection(ops.GraphKeys.SAVERS)
                self.assertEqual(1, len(savers))

                # The reloaded graph should have exactly one Save and one Restore op.
                save_op_names = set([
                    x.name for x in graph.get_operations()
                    if x.type == "SaveV2"
                ])
                self.assertSetEqual(set(["save_2/SaveV2"]), save_op_names)
                restore_op_names = set([
                    x.name for x in graph.get_operations()
                    if x.type == "RestoreV2"
                ])
                self.assertSetEqual(set(["save_2/RestoreV2"]),
                                    restore_op_names)
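
The test above boils down to a save/load round trip through SavedModelBuilder. A condensed, hedged sketch of that round trip, reusing only calls that already appear in the test and assuming export_dir points at a fresh, writable directory:

# Sketch only: the SavedModel round trip exercised by testClearExtraneousSavers.
builder = saved_model_builder.SavedModelBuilder(export_dir)
with ops.Graph().as_default():
    with session.Session() as sess:
        v = variables.Variable(42, name="v")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
builder.save()  # writes the SavedModel to export_dir

with ops.Graph().as_default() as graph:
    with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        # the single remaining Saver restores `v` as part of load()
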
Example #29
class FunctionTest(test.TestCase):

  def testBasic(self):
    matmul = function.defun(math_ops.matmul)
    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq = matmul(t, t, transpose_a=True)
    sq2 = matmul(sq, t, transpose_a=True)
    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])

  def testBasicGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    @function.defun
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out = sq(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedInputsGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @function.defun
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    out = a_times_b(pair({'a': t}, {'b': t}))
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testGraphModeWithGradients(self):
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    @function.defun
    def step():
      def inner():
        return v * v

      return backprop.implicit_grad(inner)()[0][0]

    self.assertAllEqual(step(), 2.0)

  def testGraphGradientVariable(self):
    with ops.Graph().as_default(), self.test_session():
      v = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def f():
        return 2.0 * v

      node = f()
      grads, = gradients_impl.gradients(node, v)
      v.initializer.run()
      self.assertAllEqual(grads.eval(), 2.0)
      self.assertEqual(grads.shape, v.shape)

  def testGraphEagerIsolation(self):

    @function.defun
    def f():
      v = resource_variable_ops.ResourceVariable(1.0)
      return v.read_value()

    self.assertAllEqual(f(), 1.0)

    with ops.Graph().as_default():
      self.assertEqual(f().shape, ())

  def testBasicDefunOpGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = function.make_defun_op(sq, t)

    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedInputsDefunOpGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    inputs = pair({'a': t}, {'b': t})
    sq_op = function.make_defun_op(a_times_b, inputs)

    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(inputs)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputDefunOpGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    def sq(a):
      return (matmul(a, a), {'b': constant_op.constant(1.0)})

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = function.make_defun_op(sq, t)

    self.assertEqual(sq_op.output_shapes,
                     (tensor_shape.TensorShape([2, 2]),
                      {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(sq_op.output_dtypes,
                     (dtypes.float32, {'b': dtypes.float32}))
    (a, b) = sq_op(t)
    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
    self.assertAllEqual(b['b'].numpy(), 1.0)

  def testDefunOpGraphModeWithGradients(self):
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    def step():
      def inner():
        return v * v

      return backprop.implicit_grad(inner)()[0][0]

    step_op = function.make_defun_op(step)

    self.assertEqual(step_op.output_dtypes, dtypes.float32)
    self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))
    self.assertAllEqual(step_op(), 2.0)

  def testDefunOpGraphModeNoneOutput(self):
    def fn(unused_a, unused_b):
      return None

    x = constant_op.constant(1)
    fn_op = function.make_defun_op(fn, x, x)

    self.assertEqual(fn_op.output_dtypes, None)
    self.assertEqual(fn_op.output_shapes, None)
    self.assertAllEqual(fn_op(x, x), None)

  def testDefunCapturedInt32(self):
    x = constant_op.constant(1, dtype=dtypes.int32)

    @function.defun
    def add_int32s():
      return x + x

    self.assertEqual(2, int(add_int32s()))

  def testDefunReadVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @function.defun
    def f():
      return v.read_value()

    self.assertEqual(1.0, float(f()))

  def testDefunAssignAddVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)
    x = constant_op.constant(2.0)

    @function.defun
    def test_assign_add():
      v.assign_add(x)
      return v.read_value()

    self.assertEqual(3.0, float(test_assign_add()))

  def testDefunShapeInferenceWithCapturedResourceVariable(self):
    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

    def f():
      x = constant_op.constant([[1, 2], [3, 4]])
      out = math_ops.matmul(v, x)
      self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))

    compiled = function.defun(f)
    compiled()

  def testVariableInLoopInFunction(self):

    @function.defun
    def test_function():

      def loop_test(_):
        return False

      def loop_body(_):
        return variable_scope.get_variable('a', shape=())

      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])

    self.assertEqual(test_function().shape, [])

  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
    with context.graph_mode():
      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))

      compiled = function.defun(f)
      compiled()

  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
    with context.graph_mode():
      v = variables.Variable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))

      # Check that shape inference works while creating the defun
      compiled = function.defun(f)
      compiled()

  def testDefunDifferentiable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @function.defun
    def f():
      return v * v

    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)

  def testDefunCanBeDifferentiatedTwice(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @function.defun
    def f():
      return v * v

    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
    # Ensure that v is watched again.
    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)

  def testGraphModeCaptureVariable(self):
    with context.graph_mode(), self.test_session() as sess:

      class HasAVar(object):

        def __init__(self):
          self.v = resource_variable_ops.ResourceVariable(1.0)

        def call(self):
          return self.v * 2

      o = HasAVar()
      variables.global_variables_initializer().run()
      call = function.defun(o.call)
      op = call()
      self.assertAllEqual(sess.run(op), 2.0)

  def testGraphModeManyFunctions(self):
    with context.graph_mode(), self.test_session():

      @function.defun
      def f(x):
        return x * x

      @function.defun
      def g(x):
        return f(x) + 1

      self.assertAllEqual(g(constant_op.constant(2.0)).eval(), 5.0)

  def testDict(self):

    @function.defun
    def f(x):
      return {'name': x + 1}

    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)

  def testTensorConversionWithDefun(self):

    @function.defun
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    self.assertAllEqual(5, f(constant_op.constant(2)))

  def testTensorConversionCall(self):

    @function.defun
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    @function.defun
    def g(x):
      return f(f(x))

    self.assertAllEqual(8, g(constant_op.constant(2)))

  def testDefunCallBackprop(self):

    @function.defun
    def f(x):
      return math_ops.add(x, x)

    @function.defun
    def g(x):
      return backprop.gradients_function(f, [0])(x)[0]

    self.assertAllEqual(2, g(constant_op.constant(2.)))

  def testGraphModeEagerGradError(self):
    with context.graph_mode():
      def f():
        x = variable_scope.get_variable(
            'v', initializer=constant_op.constant(1.0))
        return x * constant_op.constant(2.0)

      with self.assertRaisesRegexp(ValueError,
                                   'No trainable variables were accessed'):
        backprop.implicit_val_and_grad(f)()

  def testDefunCallBackpropUsingSameObjectForMultipleArguments(self):

    @function.defun
    def g(x):
      return backprop.gradients_function(math_ops.multiply, [0, 1])(x, x)

    def np_g(x):
      return [d.numpy() for d in g(x)]

    x = constant_op.constant(1.)
    self.assertAllEqual([1., 1.], np_g(x))
    self.assertAllEqual([1., 1.], np_g(1.))

  def testCallShape(self):

    @function.defun
    def f(x):
      return x + 1

    @function.defun
    def g(x):
      x = f(x)
      self.assertEqual(x.shape.as_list(), [])
      return None

    g(constant_op.constant(1.0))

  def testNestedDefunWithNoOutputAndTapedInput(self):
    three = resource_variable_ops.ResourceVariable(3.0, name='v')

    @function.defun
    def f(x):
      # This function intentionally takes a taped variable as input,
      # but does not return any values
      math_ops.add(x, three)

    @function.defun
    def g(x):
      tape.watch_variable(x)
      y = math_ops.add(x, three)
      f(y)

    g(three)

  def testGradientTensorConversionWithDefun(self):
    three = resource_variable_ops.ResourceVariable(3.0, name='v')

    @function.defun
    def f(x):
      return math_ops.add(x, three)

    def g(x):
      tape.watch_variable(three)
      return f(x)

    g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
    self.assertAllEqual(g, 1.0)

  def testGradient(self):
    matmul = function.defun(math_ops.matmul)

    def sq(x):
      return matmul(x, x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])

  def testGradientInFunction(self):

    @function.defun
    def f(x):
      return backprop.gradients_function(lambda y: y * y, [0])(x)[0]

    self.assertAllEqual(f(constant_op.constant(1.0)), 2.0)

  def testGatherResourceWithDefun(self):
    with ops.device('cpu:0'):
      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))

    defined = function.defun(sum_gather)
    self.assertAllEqual(sum_gather(), defined())

  def testGradientOfGatherWithDefun(self):
    v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))

    grad_fn = backprop.implicit_grad(sum_gather)
    gradient = grad_fn()
    defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather))
    defun_gradient = defun_grad_fn()
    self.assertEqual(len(gradient), len(defun_gradient))

    gradient = gradient[0][0]
    defun_gradient = defun_gradient[0][0]
    self.assertAllEqual(gradient.values, defun_gradient.values)
    self.assertAllEqual(gradient.indices, defun_gradient.indices)
    self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape)

  def testReturningIndexedSlicesWithDefun(self):

    def validate(indexed_slice):
      def f():
        return indexed_slice

      output = function.defun(f)()
      self.assertTrue(isinstance(output, ops.IndexedSlices))
      self.assertAllEqual(indexed_slice.values, output.values)
      self.assertAllEqual(indexed_slice.indices, output.indices)
      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)

      self.assertEqual(
          function.make_defun_op(f).output_shapes, indexed_slice.values.shape)

    arg = ops.IndexedSlices(
        values=constant_op.constant([1, 2]),
        indices=constant_op.constant([0, 1]),
        dense_shape=constant_op.constant([2]))
    validate(arg)

    arg = ops.IndexedSlices(
        values=constant_op.constant([1, 2]),
        indices=constant_op.constant([0, 1]),
        dense_shape=None)
    validate(arg)

  def testIndexedSliceAsArgumentWithDefun(self):

    @function.defun
    def f(indexed_slice):
      return indexed_slice

    def validate(arg):
      output = f(arg)
      self.assertTrue(isinstance(output, ops.IndexedSlices))
      self.assertAllEqual(arg.values, output.values)
      self.assertAllEqual(arg.indices, output.indices)
      self.assertAllEqual(arg.dense_shape, output.dense_shape)

    indexed_slice = ops.IndexedSlices(
        values=constant_op.constant([1]),
        indices=constant_op.constant([0]),
        dense_shape=constant_op.constant([1]))
    validate(indexed_slice)

    # Test that `f` works even when `dense_shape` is None.
    indexed_slice = ops.IndexedSlices(
        values=constant_op.constant([1]),
        indices=constant_op.constant([0]),
        dense_shape=None)
    validate(indexed_slice)

  def testFunctionOnDevice(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    x = constant_op.constant([1.]).gpu()
    f = function.defun(math_ops.add)
    y = f(x, x).cpu()
    self.assertAllEqual(y, [2.])

  @test_util.run_in_graph_and_eager_modes
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, gpu_result

    defined = function.defun(sum_gather)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    expected = self.evaluate(sum_gather())
    self.assertAllEqual(expected, self.evaluate(defined()))

  @test_util.run_in_graph_and_eager_modes
  def testOpInFunctionWithConflictingResourceInputs(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='cpu')
      v_also_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='also_cpu')

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='gpu')

    @function.defun
    def resource_apply_adam():
      training_ops.resource_apply_adam(
          v_cpu.handle,
          v_gpu.handle,
          v_also_cpu.handle,
          1.0,  # beta1_power
          1.0,  # beta2_power
          1.0,  # learning_rate
          1.0,  # beta1
          1.0,  # beta2
          1.0,  # epsilon
          [1.0, 1.0, 1.0],  # grad
          False)  # use_locking
      return None

    with self.assertRaisesRegexp(
        errors.InvalidArgumentError, 'Could not colocate node with its '
        'resource and reference inputs.*'):
      if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
      self.evaluate(resource_apply_adam())

  def testFunctionHandlesInputsOnDifferentDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    # The Reshape op requires the shape tensor to be placed in host memory.
    reshape = function.defun(array_ops.reshape)
    value = constant_op.constant([1., 2.]).gpu()
    shape = constant_op.constant([2, 1])
    reshaped = reshape(value, shape).cpu()
    self.assertAllEqual(reshaped, [[1], [2]])

  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    # The Reshape op requires the shape tensor to be placed in host memory.
    reshape = function.defun(array_ops.reshape)
    value = constant_op.constant([1., 2.])
    shape = constant_op.constant([2, 1]).gpu()
    reshape(value, shape)  # No error is raised

  def testDifferentiableFunctionNoneOutputs(self):

    @function.defun
    def my_function(x):
      return x, None

    def wrapper(x):
      return my_function(x)[0]

    g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0))
    self.assertAllEqual(g[0], 1.)

    @function.defun
    def foo(a):
      return None, a * a

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      none, r = foo(x)
    g = tp.gradient(r, x)

    self.assertIs(none, None)
    self.assertAllEqual(r, 25.0)
    self.assertAllEqual(g, 2 * 5.0)

  def testNestedDifferentiableFunction(self):
    @function.defun
    def inner_fn(a, b):
      return a * math_ops.add(a, b)

    @function.defun
    def outer_fn(x):
      return inner_fn(x, 1.0)

    x = constant_op.constant(5.0)
    with backprop.GradientTape() as tp:
      tp.watch(x)
      result = outer_fn(x)
    grad = tp.gradient(result, x)

    self.assertAllEqual(grad, 2 * 5.0 + 1.0)

  def testNestedDifferentiableFunctionNoneOutputs(self):
    @function.defun
    def foo(a, b):
      return None, a * math_ops.add(a, b), None, 2*a

    @function.defun
    def bar(x):
      return foo(x, 1.0)

    x = constant_op.constant(5.0)
    with backprop.GradientTape(persistent=True) as tp:
      tp.watch(x)
      none1, r1, none2, r2 = bar(x)
    g1 = tp.gradient(r1, x)
    g2 = tp.gradient(r2, x)

    self.assertAllEqual(r1, 30.0)
    self.assertAllEqual(r2, 10.0)
    self.assertIs(none1, None)
    self.assertIs(none2, None)
    self.assertAllEqual(g1, 2 * 5.0 + 1.0)
    self.assertAllEqual(g2, 2.0)

  def testNoneOutput(self):

    @function.defun
    def my_function(_):
      return None

    self.assertAllEqual(my_function(1), None)

  def testNestedFunctions(self):
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @tf_function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @function.defun
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  def testVariableCaptureInNestedFunctions(self):
    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)

    @function.defun
    def inner_read():
      return v.read_value()

    @function.defun
    def outer():
      return inner_read()

    self.assertEqual(1, int(outer()))

  def testReturnCapturedEagerTensor(self):
    t = constant_op.constant(1)

    @function.defun
    def read():
      return t

    self.assertEqual(1, int(read()))

  def testReturnCapturedGraphTensor(self):
    with context.graph_mode(), self.test_session():
      t = constant_op.constant(1)

      @function.defun
      def read():
        return t

      self.assertEqual(1, int(self.evaluate(read())))

  def testSequenceInputs(self):
    clip_by_global_norm = function.defun(clip_ops.clip_by_global_norm)
    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
    clipped_list, global_norm = clip_by_global_norm(t_list,
                                                    constant_op.constant(.2))
    for t in clipped_list:
      self.assertTrue(isinstance(t, ops.Tensor))
    self.assertTrue(isinstance(global_norm, ops.Tensor))

  def testNestedSequenceInputs(self):

    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      g, h = e
      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c

    my_eager_op = function.defun(my_op)
    ret = my_eager_op([
        constant_op.constant(1), [(constant_op.constant(2),
                                   constant_op.constant(3)),
                                  constant_op.constant(4)],
        constant_op.constant(5)
    ])
    self.assertEqual(len(ret), 2)
    self.assertAllEqual(ret[0][0], 2)
    self.assertAllEqual(ret[0][1][0][0], 8)
    self.assertAllEqual(ret[0][1][0][1], 4)
    self.assertTrue(isinstance(ret[0][1][0], tuple))
    self.assertAllEqual(ret[0][1][1], 6)
    self.assertAllEqual(ret[0][2], 10)
    self.assertAllEqual(ret[1], 15)

  def testVariableNamesRespectNameScopesWithDefun(self):
    @function.defun
    def create_variable():
      with ops.name_scope('foo'):
        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
      self.assertEqual(v.name, 'foo/bar:0')

    create_variable()

  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
    with context.graph_mode():
      @function.defun
      def create_variable():
        with ops.name_scope('foo'):
          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
        self.assertEqual(v.name, 'foo/bar:0')

      with ops.get_default_graph().as_default():
        create_variable()

  def testLayerInDefun(self):
    conv = convolutional.Conv2D(
        filters=1,
        kernel_size=2,
        kernel_initializer=init_ops.ones_initializer(),
        bias_initializer=init_ops.zeros_initializer())

    @function.defun
    def model(x):
      return conv(x)

    x = array_ops.ones([1, 2, 2, 1])
    y = model(x)
    self.assertAllEqual([[[[4.0]]]], y.numpy())

  @test_util.run_in_graph_and_eager_modes(
      config=config_pb2.ConfigProto(device_count={'CPU': 3}))
  def testDeviceAnnotationsRespected(self):
    @function.defun
    def multi_device_fn():
      with ops.device('/cpu:0'):
        s1 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device('/cpu:1'):
        s2 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device('/cpu:2'):
        s3 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      return s1, s2, s3

    outputs = multi_device_fn()
    self.assertTrue(compat.as_bytes('CPU:0') in self.evaluate(outputs[0]))
    self.assertTrue(compat.as_bytes('CPU:1') in self.evaluate(outputs[1]))
    self.assertTrue(compat.as_bytes('CPU:2') in self.evaluate(outputs[2]))

  def testVariablesAreTracked(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    def foo(x):
      return v * x

    defined = function.defun(foo)

    x = constant_op.constant([1.0])
    self.assertAllEqual(defined.variables, [])
    _ = defined(x)
    self.assertAllEqual(defined.variables, [v])

    x = constant_op.constant([1.0, 2.0])
    _ = defined(x)  # ensure the variables list remains the same
    self.assertAllEqual(defined.variables, [v])

  def testTensorKeywordArguments(self):

    def foo(a, b):
      del a
      return b

    defined = function.defun(foo)
    a = constant_op.constant(2.0)
    b = constant_op.constant([1.0, 2.0])
    one = defined(a, b)
    self.assertEqual(len(defined._arguments_to_functions), 1)

    two = defined(a=a, b=b)
    self.assertEqual(len(defined._arguments_to_functions), 1)

    three = defined(b=b, a=a)
    self.assertEqual(len(defined._arguments_to_functions), 1)

    four = defined(a, b=b)
    self.assertEqual(len(defined._arguments_to_functions), 1)

    # The next call corresponds to a new input signature, hence
    # we expect another function to be defined.
    five = defined(b, a)
    self.assertEqual(len(defined._arguments_to_functions), 2)

    six = defined(a=b, b=a)
    self.assertEqual(len(defined._arguments_to_functions), 2)

    seven = defined(b=a, a=b)
    self.assertEqual(len(defined._arguments_to_functions), 2)

    self.assertAllEqual(one, [1.0, 2.0])
    self.assertAllEqual(two, [1.0, 2.0])
    self.assertAllEqual(three, [1.0, 2.0])
    self.assertAllEqual(four, [1.0, 2.0])
    self.assertAllEqual(five, 2.0)
    self.assertAllEqual(six, 2.0)
    self.assertAllEqual(seven, 2.0)

  def testGradientWithKeywordArguments(self):
    matmul = function.defun(math_ops.matmul)

    def sq(x):
      return matmul(a=x, b=x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])

    with backprop.GradientTape(persistent=True) as gtape:
      gtape.watch(t)
      one = matmul(t, b=t, transpose_a=True)
      two = matmul(b=t, a=t, transpose_a=True)
      three = matmul(a=t, b=t, transpose_a=True)

    for output in [one, two, three]:
      self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])

  def testGradientInFunctionWithKeywordArguments(self):

    @function.defun
    def f(x):
      return backprop.gradients_function(lambda y: y * y, [0])(x)[0]

    self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0)

  def testDecoratingInstanceMethod(self):

    class Foo(object):

      def one(self, tensor):
        return tensor

      @function.defun
      def two(self, tensor):
        return self.one(tensor)

    foo = Foo()
    t = constant_op.constant(1.0)
    out = foo.two(t)
    self.assertEqual(float(out), 1.0)

  def testPythonCallWithSideEffects(self):
    state = []

    @function.defun
    def side_effecting_function():
      state.append(0)

    side_effecting_function()
    self.assertAllEqual(state, [0])

    # The second invocation should call the graph function, which shouldn't
    # trigger the list append.
    side_effecting_function()
    self.assertAllEqual(state, [0])

    # Whereas calling the python function directly should create a side-effect.
    side_effecting_function.call_python_function()
    self.assertAllEqual(state, [0, 0])

    def _no_rewrite_session_config(self):
        rewriter_config = rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True)
        graph_options = config_pb2.GraphOptions(
            rewrite_options=rewriter_config)
        return config_pb2.ConfigProto(graph_options=graph_options)
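
A hedged usage sketch for the helper above: the returned ConfigProto can be handed to a Session so Grappler's model-pruning rewrite stays disabled for that session. The inline construction below simply expands the helper; the session body is an assumption.

# Sketch only: consuming a no-rewrite session config like the one built above.
no_rewrite_config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(
        rewrite_options=rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True)))
with session.Session(config=no_rewrite_config) as sess:
    pass  # run ops here; model pruning is disabled for this session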