Esempio n. 1
0
  def test_monkey_patch_default_variable_placement_strategy(
      self, num_tasks, op_names, before_want_ps, after_want_ps):
    """Verifies that the patched strategy assigns ps tasks by variable name.

    Device functions built before and after the patch context must yield
    `before_want_ps`. The one built inside the context must yield
    `after_want_ps`, even when it is evaluated after the context has exited.
    """

    def make_device_fn():
      # Build a fresh setter so it picks up whichever strategy is active now.
      return tf.compat.v1.train.replica_device_setter(ps_tasks=num_tasks)

    variable_ops = []
    for name in op_names:
      variable_ops.append(tf.Variable(0., name=name).op)

    def placements(device_fn):
      return [device_fn(var_op) for var_op in variable_ops]

    self.assertEqual(before_want_ps, placements(make_device_fn()))

    with monkey_patch_default_variable_placement_strategy():
      patched_device_fn = make_device_fn()
    self.assertEqual(after_want_ps, placements(patched_device_fn))

    # The patch must not leak past the end of the `with` block.
    self.assertEqual(before_want_ps, placements(make_device_fn()))
Esempio n. 2
0
 def test_monkey_patch_default_variable_placement_strategy_no_ps(self):
   """Checks that the patched setter still returns None when ps_tasks is 0."""
   # `replica_device_setter` is only exposed under `tf.compat.v1.train` in
   # TF 2.x; the bare `tf.train` path raises AttributeError there.
   with monkey_patch_default_variable_placement_strategy():
     device_fn = tf.compat.v1.train.replica_device_setter(ps_tasks=0)
   self.assertIsNone(device_fn)