def _outer_template():
    """Create inner templates "i1" and "i2" and call them (i2 twice).

    Returns:
      ``((i1, i2), (i1(), i2(), i2()))`` -- the two inner templates followed
      by the results of the three calls, in call order.
    """
    tmpl_i1 = template.make_template("i1", _inner_template)
    tmpl_i2 = template.make_template("i2", _inner_template)
    # Tuple elements are evaluated left to right: i1 once, then i2 twice.
    call_results = (tmpl_i1(), tmpl_i2(), tmpl_i2())
    return (tmpl_i1, tmpl_i2), call_results
  def test_checkpointable_save_restore(self):
    """Values saved from template "s1" are restored into template "s2"."""

    def _templated():
      # Two zero-initialized, [1]-shaped variables plus a derived tensor.
      first = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer())
      second = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer())
      return first, first + 1., second

    # Save side: give the variables known values and write a checkpoint.
    save_template = template.make_template("s1", _templated)
    save_root = checkpointable_utils.Checkpoint(my_template=save_template)
    saved_v, _, saved_v2 = save_template()
    self.evaluate(saved_v.assign([12.]))
    self.evaluate(saved_v2.assign([14.]))
    save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

    # Load side: restore() is requested before the template has been called,
    # i.e. before its variables exist.
    load_template = template.make_template("s2", _templated)
    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
    status = load_root.restore(save_path)
    restored_v, restored_v_plus_one, restored_v2 = load_template()
    deps = load_template._checkpoint_dependencies
    self.assertEqual(2, len(deps))
    self.assertEqual("v", deps[0].name)
    self.assertEqual("v2", deps[1].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(restored_v))
    self.assertAllEqual([13.], self.evaluate(restored_v_plus_one))
    self.assertAllEqual([14.], self.evaluate(restored_v2))
 def test_unique_name_raise_error_in_eager(self):
   """make_template rejects `unique_name_` while eager execution is on."""
   with context.eager_mode():
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest.TestCase in Python 3.12.
     # NOTE(review): "exeuction" is kept verbatim on purpose -- the pattern
     # has to match the library's actual error text; confirm before fixing.
     with self.assertRaisesRegex(
         ValueError,
         "unique_name_ cannot be used when eager exeuction is enabled."):
       template.make_template(
           "_", variable_scoped_function, unique_name_="s1")
  def test_custom_getter(self):
    """A custom getter runs on every template call, creating or reusing."""
    # One-element list used as a mutable cell the nested getter can update.
    calls = [0]

    def custom_getter(getter, name, *args, **kwargs):
      # Record the invocation, then defer to the real getter unchanged.
      calls[0] += 1
      return getter(name, *args, **kwargs)

    # Lazily created scope: construction does not trigger the getter; each
    # call (first one creates the variable, second reuses it) triggers it once.
    lazy_tmpl = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(calls[0], 0)
    for expected in (1, 2):
      lazy_tmpl()
      self.assertEqual(calls[0], expected)

    # Scope created at construction time: still no getter call until the
    # template itself is invoked.
    calls[0] = 0
    eager_scope_tmpl = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(calls[0], 0)
    for expected in (1, 2):
      eager_scope_tmpl()
      self.assertEqual(calls[0], expected)
Example #5
0
 def nested_template():
   """Create two "nested" templates; they must yield different variables.

   Returns the variable produced by the second nested template.
   """
   inner_a = template.make_template("nested", variable_scoped_function)
   inner_b = template.make_template("nested", variable_scoped_function)
   # Evaluated left to right: inner_a is called before inner_b.
   var_a, var_b = inner_a(), inner_b()
   self.assertNotEqual(var_a, var_b)
   return var_b
Example #6
0
 def test_unique_name_raise_error(self):
   """Two templates sharing unique_name_ "s1": the second call raises."""

   def _build():
     # Each call builds a fresh template bound to the same unique name.
     return template.make_template(
         "_", variable_scoped_function, unique_name_="s1")

   _build()()
   duplicate = _build()
   with self.assertRaises(ValueError):
     duplicate()
 def test_same_unique_name_raise_error(self):
   """A second template on unique_name_ "s1" fails when first called.

   The error message must name the clashing variable ``s1/dummy``.
   """
   tmpl1 = template.make_template(
       "_", variable_scoped_function, unique_name_="s1")
   tmpl1()
   tmpl2 = template.make_template(
       "_", variable_scoped_function, unique_name_="s1")
   # assertRaisesRegex replaces the assertRaisesRegexp alias, which was
   # removed from unittest.TestCase in Python 3.12.
   with self.assertRaisesRegex(
       ValueError, "Variable s1/dummy already exists, disallowed.*"):
     tmpl2()
  def test_nested_templates_with_defun(self):
    """Nested graph-function templates keep separate, reusable variables."""

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects, so the
      # variable is created only for its side effect and nothing is returned.
      variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def _make_nested():
      # Both nested templates request the same name "nested".
      return template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)

    def nested_template():
      first, second = _make_nested(), _make_nested()
      first()
      second()
      first_vars = first.variables
      second_vars = second.variables

      # Same requested name, yet the two templates must not share variables.
      self.assertNotEqual(first_vars, second_vars)

      # Each template reports only its own variables, all of them trainable.
      self.assertEqual(first.variables, first_vars)
      self.assertEqual(second.variables, second_vars)
      self.assertEqual(first.trainable_variables, first_vars)
      self.assertEqual(second.trainable_variables, second_vars)
      self.assertEqual(len(first.non_trainable_variables), 0)
      self.assertEqual(len(second.non_trainable_variables), 0)

    outer_a = template.make_template("s1", nested_template)
    outer_b = template.make_template("s1", nested_template)

    outer_a()
    first_call_vars = outer_a.variables
    outer_a()
    second_call_vars = outer_a.variables
    outer_b()
    other_vars = outer_b.variables

    # Calling outer_a again reuses what its first call created.
    self.assertSequenceEqual(first_call_vars, second_call_vars)

    # Distinct outer templates never share variables; scopes are uniquified.
    self.assertNotEqual(first_call_vars, other_vars)
    expected_names = (
        ("s1/nested/dummy:0", first_call_vars[0]),
        ("s1/nested_1/dummy:0", first_call_vars[1]),
        ("s1_1/nested/dummy:0", other_vars[0]),
        ("s1_1/nested_1/dummy:0", other_vars[1]),
    )
    for expected, variable in expected_names:
      self.assertEqual(expected, variable.name)
  def test_template_with_name(self):
    """A template reuses its variable; a same-named twin gets scope "s1_1"."""
    first = template.make_template("s1", variable_scoped_function)
    second = template.make_template("s1", variable_scoped_function)

    # Call order: first, first again, then second.
    created, reused, other = first(), first(), second()
    self.assertEqual(created, reused)
    self.assertNotEqual(created, other)
    self.assertEqual("s1/dummy:0", created.name)
    self.assertEqual("s1_1/dummy:0", other.name)
Example #10
0
    def test_make_template(self):
        """Extra keyword arguments to make_template reach the wrapped function."""
        # scope_name="test" is bound at construction and used on every call,
        # so each variable lands under "<template scope>/test".
        first = template.make_template(
            "s1", internally_var_scoped_function, scope_name="test")
        second = template.make_template(
            "s1", internally_var_scoped_function, scope_name="test")

        created = first()
        reused = first()
        other = second()
        self.assertEqual(created, reused)
        self.assertNotEqual(created, other)
        self.assertEqual("s1/test/dummy:0", created.name)
        self.assertEqual("s1_1/test/dummy:0", other.name)
Example #11
0
    def test_unique_name_and_reuse(self):
        """In reuse mode a second template on unique_name_ "s1" shares it."""
        original = template.make_template(
            "_", var_scoped_function, unique_name_="s1")
        created = original()
        repeated = original()

        # Switch the current variable scope to reuse before building the twin.
        tf.get_variable_scope().reuse_variables()
        twin = template.make_template(
            "_", var_scoped_function, unique_name_="s1")
        shared = twin()

        self.assertEqual(created, repeated)
        self.assertEqual(created, shared)
        self.assertEqual("s1/dummy:0", created.name)
  def test_template_with_internal_reuse(self):
    """A template whose function opens its own inner scope is reusable."""
    first = template.make_template("s1", internally_variable_scoped_function)
    second = template.make_template("s1", internally_variable_scoped_function)

    created = first("test")
    reused = first("test")
    other = second("test")
    self.assertEqual(created, reused)
    self.assertNotEqual(created, other)
    self.assertEqual("s1/test/dummy:0", created.name)
    self.assertEqual("s1_1/test/dummy:0", other.name)

    # Calling the same template with a different inner scope name raises.
    with self.assertRaises(ValueError):
      first("not_test")
  def test_unique_name_and_reuse(self):
    """With reuse enabled, a second unique_name_ "s1" template shares vars."""
    original = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    created = original()
    repeated = original()

    # Put the current scope into reuse mode before building the twin.
    variable_scope.get_variable_scope().reuse_variables()
    twin = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    shared = twin()

    self.assertEqual(created, repeated)
    self.assertEqual(created, shared)
    self.assertEqual("s1/dummy:0", created.name)
  def test_template_with_internal_reuse(self):
    """Inner-scoped variables are reused per template and unique across them."""
    tmpl_a = template.make_template("s1", internally_variable_scoped_function)
    tmpl_b = template.make_template("s1", internally_variable_scoped_function)

    # Evaluated left to right: tmpl_a twice, then tmpl_b.
    var_a, var_a_again, var_b = (
        tmpl_a("test"), tmpl_a("test"), tmpl_b("test"))
    self.assertEqual(var_a, var_a_again)
    self.assertNotEqual(var_a, var_b)
    self.assertEqual("s1/test/dummy:0", var_a.name)
    self.assertEqual("s1_1/test/dummy:0", var_b.name)

    # A different inner scope name on an already-used template is an error.
    with self.assertRaises(ValueError):
      tmpl_a("not_test")
  def test_make_template(self):
    """Keyword arguments given to make_template are forwarded on each call."""
    # scope_name="test" is captured at construction; the template calls
    # below pass no arguments, yet the variable lands under ".../test".
    tmpl_a = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl_b = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    var_a, var_a_again, var_b = tmpl_a(), tmpl_a(), tmpl_b()
    self.assertEqual(var_a, var_a_again)
    self.assertNotEqual(var_a, var_b)
    self.assertEqual("s1/test/dummy:0", var_a.name)
    self.assertEqual("s1_1/test/dummy:0", var_b.name)
    def test_trackable_save_restore(self):
        """Template vars, a nested _ManualScope and optimizer state round-trip."""

        def _templated():
            def _zeros_resource_var(var_name):
                # Fresh zeros initializer per variable, resource-backed.
                return variable_scope.get_variable(
                    var_name,
                    shape=[1],
                    initializer=init_ops.zeros_initializer(),
                    use_resource=True)

            first = _zeros_resource_var("v")
            second = _zeros_resource_var("v2")
            manual = _ManualScope()
            return first, first + 1., second, manual, manual()

        save_template = template.make_template("s1", _templated)
        saved_v, _, saved_v2, manual_scope, manual_scope_v = save_template()
        # Objects are compared by identity, since variables may not hash/compare
        # in a way assertCountEqual can use directly.
        expected_ids = [
            id(obj) for obj in
            [saved_v, saved_v2, manual_scope, manual_scope_v, save_template]
        ]
        listed_ids = [
            id(obj) for obj in trackable_utils.list_objects(save_template)
        ]
        six.assertCountEqual(self, expected_ids, listed_ids)
        self.assertDictEqual({"in_manual_scope": manual_scope_v},
                             manual_scope._trackable_children())

        optimizer = adam.AdamOptimizer(0.0)
        save_root = trackable_utils.Checkpoint(my_template=save_template,
                                               optimizer=optimizer)
        # minimize() runs before initialization so the optimizer's own
        # variables exist and can be initialized and saved.
        optimizer.minimize(saved_v.read_value)
        self.evaluate([var.initializer for var in save_template.variables])
        self.evaluate([var.initializer for var in optimizer.variables()])
        self.evaluate(saved_v.assign([12.]))
        self.evaluate(saved_v2.assign([14.]))
        save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = trackable_utils.Checkpoint(my_template=load_template,
                                               optimizer=load_optimizer)
        status = load_root.restore(save_path)
        restored_v, restored_v_plus_one, restored_v2, _, _ = load_template()
        load_optimizer.minimize(restored_v.read_value)
        self.assertEqual(3, len(load_template._trackable_children()))
        self.assertEqual({"v", "v2", "ManualScope"},
                         load_template._trackable_children().keys())
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(restored_v))
        self.assertAllEqual([13.], self.evaluate(restored_v_plus_one))
        self.assertAllEqual([14.], self.evaluate(restored_v2))
  def test_checkpointable_save_restore(self):
    """Template vars, a nested _ManualScope and Adam state survive restore."""

    def _templated():
      def _zeros_resource_var(var_name):
        # Fresh zeros initializer per variable, resource-backed.
        return variable_scope.get_variable(
            var_name, shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)

      first = _zeros_resource_var("v")
      second = _zeros_resource_var("v2")
      manual = _ManualScope()
      return first, first + 1., second, manual, manual()

    save_template = template.make_template("s1", _templated)
    saved_v, _, saved_v2, manual_scope, manual_scope_v = save_template()
    # Reachable objects: both variables, the manual scope, the value returned
    # by calling it, and the template itself.
    six.assertCountEqual(
        self,
        [saved_v, saved_v2, manual_scope, manual_scope_v, save_template],
        checkpointable_utils.list_objects(save_template))
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)

    optimizer = adam.AdamOptimizer(0.0)
    save_root = checkpointable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    # minimize() runs before initialization so the optimizer's own variables
    # exist and can be initialized and saved.
    optimizer.minimize(saved_v.read_value)
    self.evaluate([variable.initializer for variable in save_template.variables])
    self.evaluate([variable.initializer for variable in optimizer.variables()])
    self.evaluate(saved_v.assign([12.]))
    self.evaluate(saved_v2.assign([14.]))
    save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = checkpointable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    restored_v, restored_v_plus_one, restored_v2, _, _ = load_template()
    load_optimizer.minimize(restored_v.read_value)
    deps = load_template._checkpoint_dependencies
    self.assertEqual(3, len(deps))
    self.assertEqual(["v", "v2", "ManualScope"], [dep.name for dep in deps])
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(restored_v))
    self.assertAllEqual([13.], self.evaluate(restored_v_plus_one))
    self.assertAllEqual([14.], self.evaluate(restored_v2))
Example #18
0
    def test_nested_templates(self):
        """Nested same-named templates get unique scopes and are reused."""

        def nested_template():
            inner_a = template.make_template("nested",
                                             variable_scoped_function)
            inner_b = template.make_template("nested",
                                             variable_scoped_function)
            var_a = inner_a()
            var_b = inner_b()

            # Same requested name, yet separate variables.
            self.assertNotEqual(var_a, var_b)

            # Each inner template tracks only the variable it created, and
            # that variable is trainable.
            self.assertEqual(inner_a.variables, [var_a])
            self.assertEqual(inner_b.variables, [var_b])
            self.assertEqual(inner_a.trainable_variables, [var_a])
            self.assertEqual(inner_b.trainable_variables, [var_b])
            self.assertEqual(len(inner_a.non_trainable_variables), 0)
            self.assertEqual(len(inner_b.non_trainable_variables), 0)
            return var_a, var_b

        outer_a = template.make_template("s1", nested_template)
        outer_b = template.make_template("s1", nested_template)

        a_first, a_second = outer_a()
        a_first_again, a_second_again = outer_a()
        b_first, b_second = outer_b()

        # A second call to outer_a reuses everything from its first call.
        self.assertEqual([a_first, a_second], [a_first_again, a_second_again])
        self.assertEqual(outer_a.variables, [a_first, a_second])
        self.assertEqual(outer_a.trainable_variables, [a_first, a_second])
        self.assertEqual(len(outer_a.non_trainable_variables), 0)

        # outer_a and outer_b are fully independent.
        self.assertNotEqual([a_first, a_second], [b_first, b_second])
        self.assertSequenceEqual(outer_b.variables, [b_first, b_second])
        self.assertSequenceEqual(outer_b.trainable_variables,
                                 [b_first, b_second])
        self.assertEqual(len(outer_b.non_trainable_variables), 0)
        self.assertEqual("s1/nested/dummy:0", a_first.name)
        self.assertEqual("s1/nested_1/dummy:0", a_second.name)
        self.assertEqual("s1_1/nested/dummy:0", b_first.name)
        self.assertEqual("s1_1/nested_1/dummy:0", b_second.name)

        # The inner templates are checkpoint children of the outer one,
        # keyed by their uniquified scope names.
        deps = outer_a._checkpoint_dependencies
        self.assertEqual(2, len(deps))
        self.assertEqual("nested", deps[0].name)
        self.assertEqual("nested_1", deps[1].name)
Example #19
0
  def test_template_in_scope(self):
    """A template's scope is fixed by the scope active at its first call."""
    first = template.make_template("s1", variable_scoped_function)
    second = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      created = first()
      other = second()

    # Per the template contract, this later call ignores "scope2" and reuses
    # the variable created under "scope".
    with variable_scope.variable_scope("scope2"):
      reused = first()
    self.assertEqual(created, reused)
    self.assertNotEqual(created, other)
    self.assertEqual("scope/s1/dummy:0", created.name)
    self.assertEqual("scope/s1_1/dummy:0", other.name)
Example #20
0
  def test_nested_templates(self):
    """Templates built inside a template are uniquified and reused."""

    def nested_template():
      inner_a = template.make_template("nested", variable_scoped_function)
      inner_b = template.make_template("nested", variable_scoped_function)
      var_a = inner_a()
      var_b = inner_b()

      # The two inner templates must not share variables.
      self.assertNotEqual(var_a, var_b)

      # Each inner template sees only its own (trainable) variable.
      self.assertEqual(inner_a.variables, [var_a])
      self.assertEqual(inner_b.variables, [var_b])
      self.assertEqual(inner_a.trainable_variables, [var_a])
      self.assertEqual(inner_b.trainable_variables, [var_b])
      self.assertEqual(len(inner_a.non_trainable_variables), 0)
      self.assertEqual(len(inner_b.non_trainable_variables), 0)
      return var_a, var_b

    outer_a = template.make_template("s1", nested_template)
    outer_b = template.make_template("s1", nested_template)

    a_first, a_second = outer_a()
    a_first_again, a_second_again = outer_a()
    b_first, b_second = outer_b()

    # outer_a's second call reuses its first call's variables.
    self.assertEqual([a_first, a_second], [a_first_again, a_second_again])
    self.assertEqual(outer_a.variables, [a_first, a_second])
    self.assertEqual(outer_a.trainable_variables, [a_first, a_second])
    self.assertEqual(len(outer_a.non_trainable_variables), 0)

    # outer_b owns a disjoint set of variables.
    self.assertNotEqual([a_first, a_second], [b_first, b_second])
    self.assertSequenceEqual(outer_b.variables, [b_first, b_second])
    self.assertSequenceEqual(outer_b.trainable_variables, [b_first, b_second])
    self.assertEqual(len(outer_b.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", a_first.name)
    self.assertEqual("s1/nested_1/dummy:0", a_second.name)
    self.assertEqual("s1_1/nested/dummy:0", b_first.name)
    self.assertEqual("s1_1/nested_1/dummy:0", b_second.name)

    # Inner templates appear as checkpoint children under their scope names.
    deps = outer_a._checkpoint_dependencies
    self.assertEqual(2, len(deps))
    self.assertEqual("nested", deps[0].name)
    self.assertEqual("nested_1", deps[1].name)
Example #21
0
    def create_template_fn(
        self,
        name: str,
    ) -> Callable[[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]:
        """Create a shallow conv network wrapped in a TensorFlow template.

        The returned callable maps an input tensor ``x`` to
        ``(shift, log_scale)``, each with as many channels as ``x``. The
        channel count is read from axis 3, so a 4-D channels-last input is
        presumably expected -- TODO confirm with callers.

        Args:
            name: scope name given to the template.

        Returns:
            The template-wrapped network function.
        """

        def _shift_and_log_scale_fn(x: tf.Tensor):
            out_channels = 2 * K.int_shape(x)[3]

            with tf.variable_scope("BlockNN"):
                hidden = self.activation_fn(ops.conv2d("l_1", x, self.width))
                hidden = self.activation_fn(
                    ops.conv2d("l_2", hidden, self.width, filter_size=[1, 1]))
                # Zero-initialized last layer, so shift and log_scale both
                # start out at zero.
                shift_log_scale = ops.conv2d_zeros("l_last", hidden,
                                                   out_channels)
                # Channels are interleaved: even -> shift, odd -> log_scale.
                shift = shift_log_scale[:, :, :, 0::2]
                # Bound log_scale to [-15, 15].
                log_scale = tf.clip_by_value(shift_log_scale[:, :, :, 1::2],
                                             -15.0, 15.0)
                return shift, log_scale

        return template_ops.make_template(name, _shift_and_log_scale_fn)
Example #22
0
    def test_merge_call(self):
        """A template whose body calls merge_call works under MirroredStrategy."""
        with ops.Graph().as_default():
            # Exercises a v1-only function; the second replica needs a GPU.
            if not test.is_gpu_available():
                self.skipTest("No GPU available")

            def fn():
                # merge_call sits between the two variable creations.
                first = variable_scope.get_variable(
                    "var1",
                    shape=[],
                    initializer=init_ops.constant_initializer(21.))
                ds_context.get_replica_context().merge_call(lambda _: ())
                second = variable_scope.get_variable(
                    "var2",
                    shape=[],
                    initializer=init_ops.constant_initializer(2.))
                return first * second

            merge_template = template.make_template("my_template", fn)

            strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"])
            per_replica = strategy.run(merge_template)
            local_results = strategy.experimental_local_results(per_replica)

            self.evaluate(variables.global_variables_initializer())
            # 21 * 2 on each of the two replicas.
            self.assertAllEqual([42., 42.], self.evaluate(local_results))
Example #23
0
  def test_trainable_variables(self):
    """trainable_variables is empty before the first call and filled after."""
    with variable_scope.variable_scope("foo2"):
      # Two same-named templates whose scopes are created immediately
      # (create_scope_now_); the scopes must be made unique.
      ta = template.make_template(
          "bar", variable_scoped_function, create_scope_now_=True)
      tb = template.make_template(
          "bar", variable_scoped_function, create_scope_now_=True)

    # The property can be read before either template is called, and no
    # variables have been created yet.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # Calling each template creates exactly one trainable variable.
    ta()
    tb()
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
Example #24
0
    def test_end_to_end(self):
        """This test shows a very simple line model with test_loss.

        The template is used to share parameters between a training and test
        model: one gradient step on the training loss must lower the test
        loss, because both predictions read the same variables.
        """
        # Data roughly follows y = 2x + 1.
        training_input = [1.0, 2.0, 3.0, 4.0]
        training_output = [2.8, 5.1, 7.2, 8.7]
        test_input = [5.0, 6.0, 7.0, 8.0]
        test_output = [11, 13, 15, 17]

        tf.set_random_seed(1234)

        def test_line(x):
            # Slope first, then intercept (creation order is kept fixed).
            slope = tf.get_variable(
                "w", shape=[], initializer=tf.truncated_normal_initializer())
            intercept = tf.get_variable(
                "b", shape=[], initializer=tf.truncated_normal_initializer())
            return x * slope + intercept

        line_template = template.make_template("line", test_line)

        train_prediction = line_template(training_input)
        test_prediction = line_template(test_input)

        train_loss = tf.reduce_mean(
            tf.square(train_prediction - training_output))
        test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))

        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(train_loss)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            loss_before = sess.run(test_loss)
            sess.run(train_op)
            loss_after = sess.run(test_loss)

        # Tied parameters: training on the train set also improves test loss.
        self.assertLess(loss_after, loss_before)
  def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
    """Check that template-shared weights are fake-quantized exactly once.

    Args:
      rewrite_fn: graph rewrite to apply. It is called with no arguments when
        ``quant_delay`` is None, otherwise with ``quant_delay=quant_delay``.
      quant_delay: optional delay forwarded to ``rewrite_fn``. When truthy,
        the Conv2D ops are expected to consume the output of a single
        ``Merge`` op (the delayed quantization) instead of the FakeQuant op.
    """
    with ops.Graph().as_default() as g:
      # Calling the template twice shares the layer's variables between the
      # two invocations.
      conv = template.make_template('shared_weights_conv', self._ConvLayer)
      conv()
      conv()
      if quant_delay is None:
        rewrite_fn()
      else:
        rewrite_fn(quant_delay=quant_delay)

    # The graph is no longer being modified, so one snapshot suffices.
    graph_ops = g.get_operations()
    conv_ops = [op for op in graph_ops if op.type == 'Conv2D']
    weights_quants = [
        op for op in graph_ops
        if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
    ]
    # The shared weights variable must not be quantized multiple times.
    # assertEqual reports the actual count on failure (assertTrue would not).
    self.assertEqual(1, len(weights_quants))
    weights_quant_tensor = weights_quants[0].outputs[0]
    if quant_delay:
      delayed_weights_quants = [
          op for op in graph_ops
          if 'weights_quant' in op.name and op.type == 'Merge'
      ]
      self.assertEqual(1, len(delayed_weights_quants))
      weights_quant_tensor = delayed_weights_quants[0].outputs[0]
    # Guard against a vacuous pass: all() over an empty list is True.
    self.assertTrue(conv_ops, 'No Conv2D ops found in the graph.')
    # Every Conv2D op must consume the (possibly delayed) quantized weights.
    self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
Example #26
0
def make_dense_function_G(units, name, activation=tf.nn.sigmoid):
    '''Build a template for a dense network with one hidden layer.

    Architecture: input (width D) -> Dense(``units``, ``activation``) ->
    Dense(1, linear). The output layer has no activation.

    Args:
        units: number of neurons in the hidden layer.
        name: name scope for the ops, also used as the ``name`` of both Dense
            layers. The hidden layer's ``_scope`` is ``name`` and the output
            layer's is ``name + '2'``.
        activation: activation of the hidden layer (default: sigmoid).

    Returns:
        A template named "trial_function" wrapping the network.
    '''
    def _fn(x):
        with ops.name_scope(name, "trial_function"):
            hidden = layers.Dense(units,
                                  activation=activation,
                                  name=name,
                                  _scope=name)
            # Linear single-unit output layer.
            output = layers.Dense(units=1,
                                  name=name,
                                  _scope=name + '2')
            return output.apply(hidden.apply(x))

    return template_ops.make_template("trial_function", _fn)
Example #27
0
    def test_trainable_variables(self):
        """trainable_variables is empty before the first call and filled after."""
        with variable_scope.variable_scope("foo2"):
            # Two same-named templates whose scopes are created immediately
            # (create_scope_now_); the scopes must be made unique.
            ta = template.make_template(
                "bar", variable_scoped_function, create_scope_now_=True)
            tb = template.make_template(
                "bar", variable_scoped_function, create_scope_now_=True)

        # The property can be read before either template is called, and no
        # variables have been created yet.
        self.assertEqual([], ta.trainable_variables)
        self.assertEqual([], tb.trainable_variables)
        # Calling each template creates exactly one trainable variable.
        ta()
        tb()
        self.assertEqual(1, len(ta.trainable_variables))
        self.assertEqual(1, len(tb.trainable_variables))
Example #28
0
    def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
        """Check that template-shared weights are fake-quantized exactly once.

        Args:
            rewrite_fn: graph rewrite to apply. It is called with no arguments
                when ``quant_delay`` is None, else with ``quant_delay=...``.
            quant_delay: optional delay forwarded to ``rewrite_fn``. When
                truthy, the Conv2D ops are expected to consume the output of a
                single ``Merge`` op instead of the FakeQuant op.
        """
        with ops.Graph().as_default() as g:
            # Calling the template twice shares the layer's variables between
            # the two invocations.
            conv = template.make_template('shared_weights_conv',
                                          self._ConvLayer)
            conv()
            conv()
            if quant_delay is None:
                rewrite_fn()
            else:
                rewrite_fn(quant_delay=quant_delay)

        # The graph is no longer being modified, so one snapshot suffices.
        graph_ops = g.get_operations()
        conv_ops = [op for op in graph_ops if op.type == 'Conv2D']
        weights_quants = [
            op for op in graph_ops if 'weights_quant' in op.name
            and op.type == 'FakeQuantWithMinMaxVars'
        ]
        # The shared weights variable must not be quantized multiple times.
        # assertEqual reports the actual count on failure.
        self.assertEqual(1, len(weights_quants))
        weights_quant_tensor = weights_quants[0].outputs[0]
        if quant_delay:
            delayed_weights_quants = [
                op for op in graph_ops
                if 'weights_quant' in op.name and op.type == 'Merge'
            ]
            self.assertEqual(1, len(delayed_weights_quants))
            weights_quant_tensor = delayed_weights_quants[0].outputs[0]
        # Guard against a vacuous pass: all() over an empty list is True.
        self.assertTrue(conv_ops, 'No Conv2D ops found in the graph.')
        # Every Conv2D op must consume the (possibly delayed) quantized weights.
        self.assertTrue(
            all(weights_quant_tensor in op.inputs for op in conv_ops))
def real_nvp_spectral_template(hidden_layers,
                               shift_only=False,
                               activation=nn_ops.relu,
                               name=None,
                               *args,
                               **kwargs):
    """Build a shift-and-log-scale function from a spectral-dense MLP.

    Args:
        hidden_layers: iterable of hidden-layer widths, one layer per entry.
        shift_only: if True, only a shift is produced and the second output
            of the template is None.
        activation: activation of the hidden layers; the final layer is
            linear.
        name: name scope used while building the template. Default:
            "real_nvp_spectral_template".
        *args, **kwargs: forwarded to every ``spectral_dense`` call.
            NOTE(review): ``*args`` is unpacked after ``inputs=``/``units=``
            keywords, so positional extras may collide with those parameters
            -- confirm against ``spectral_dense``'s signature.

    Returns:
        A template mapping ``(x, output_units)`` to
        ``(shift, tanh(log_scale))``, or ``(shift, None)`` when
        ``shift_only`` is True.
    """
    with ops.name_scope(name, "real_nvp_spectral_template"):

        def _fn(x, output_units):
            """Fully connected MLP parameterized via `real_nvp_template`."""
            net = x
            for width in hidden_layers:
                net = spectral_dense(inputs=net,
                                     units=width,
                                     activation=activation,
                                     *args,
                                     **kwargs)
            # Linear head: output_units for shift only, twice that otherwise.
            net = spectral_dense(inputs=net,
                                 units=(1 if shift_only else 2) * output_units,
                                 activation=None,
                                 *args,
                                 **kwargs)
            if shift_only:
                return net, None
            shift, raw_log_scale = array_ops.split(net, 2, axis=-1)
            # tanh bounds log_scale to (-1, 1).
            return shift, tf.tanh(raw_log_scale)

        return template_ops.make_template("real_nvp_spectral_template", _fn)
Example #30
0
def orthogonal_flow_template(hidden_layers,
                             activation=nn_ops.relu,
                             name=None,
                             *args,
                             **kwargs):
    """Build a template returning two equally sized outputs from a dense MLP.

    Args:
        hidden_layers: iterable of hidden-layer widths, one layer per entry.
        activation: activation of the hidden layers; the final layer is
            linear.
        name: name scope used while building the template. Default:
            "orthogonal_flow_template".
        *args, **kwargs: forwarded to every ``layers.dense`` call.

    Returns:
        A template mapping ``(x, output_units)`` to a pair of tensors, each
        with ``output_units`` features on the last axis.
    """
    with ops.name_scope(name, "orthogonal_flow_template"):

        def _fn(x, output_units):
            net = x
            for width in hidden_layers:
                net = layers.dense(inputs=net,
                                   units=width,
                                   activation=activation,
                                   *args,
                                   **kwargs)
            # Linear head producing both halves at once.
            net = layers.dense(inputs=net,
                               units=2 * output_units,
                               activation=None,
                               *args,
                               **kwargs)
            first_half, second_half = array_ops.split(net, 2, axis=-1)
            return first_half, second_half

        return template_ops.make_template("orthogonal_flow_template", _fn)
Example #31
0
  def test_enforces_no_extra_trainable_variables_eager(self):
    """A later call with a different `name` argument raises ValueError.

    Presumably the second call would create an extra trainable variable,
    which the template forbids -- see function_with_side_create to confirm.
    """
    side_create_tmpl = template.make_template(
        "s", function_with_side_create, trainable=True)

    side_create_tmpl(name="1")
    with self.assertRaises(ValueError):
      side_create_tmpl(name="2")
    def test_trackable_save_restore_nested(self):
        """Nested templates checkpoint and restore their variables by name."""

        def _inner_template():
            return variable_scope.get_variable(
                "v", shape=[1], initializer=init_ops.zeros_initializer())

        def _outer_template():
            first_inner = template.make_template("i1", _inner_template)
            second_inner = template.make_template("i2", _inner_template)
            # Evaluated left to right: i1 once, then i2 twice.
            values = (first_inner(), second_inner(), second_inner())
            return (first_inner, second_inner), values

        # The enclosing "ignored" scope must not affect what is restored: the
        # load side below is built outside of it.
        with variable_scope.variable_scope("ignored"):
            save_template = template.make_template("s1", _outer_template)
            save_root = trackable_utils.Checkpoint(my_template=save_template)
            (saved_one, saved_two), _ = save_template()
        self.evaluate(saved_one.variables[0].assign([20.]))
        self.evaluate(saved_two.variables[0].assign([25.]))
        save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

        load_template = template.make_template("s2", _outer_template)
        load_root = trackable_utils.Checkpoint(my_template=load_template)
        status = load_root.restore(save_path)
        (loaded_one, loaded_two), (v1, v2, v3) = load_template()

        outer_deps = load_root.my_template._checkpoint_dependencies
        self.assertLen(outer_deps, 2)
        for dep, expected_name, expected_ref in zip(
                outer_deps, ("i1", "i2"), (loaded_one, loaded_two)):
            self.assertEqual(expected_name, dep.name)
            self.assertIs(expected_ref, dep.ref)
        for inner in (loaded_one, loaded_two):
            inner_deps = inner._checkpoint_dependencies
            self.assertLen(inner_deps, 1)
            self.assertEqual("v", inner_deps[0].name)

        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([20.], self.evaluate(v1))
        self.assertAllEqual([25.], self.evaluate(v2))
        self.assertAllEqual([25.], self.evaluate(v3))
Example #33
0
    def nested_template():
      """Build two "nested" templates and check they stay isolated.

      Returns the variables created by the first and second nested template.
      """
      inner_a = template.make_template("nested", variable_scoped_function)
      inner_b = template.make_template("nested", variable_scoped_function)
      var_a = inner_a()
      var_b = inner_b()

      # Same requested name, but no shared variables.
      self.assertNotEqual(var_a, var_b)

      # Each template reports only its own variable, and it is trainable.
      self.assertEqual(inner_a.variables, [var_a])
      self.assertEqual(inner_b.variables, [var_b])
      self.assertEqual(inner_a.trainable_variables, [var_a])
      self.assertEqual(inner_b.trainable_variables, [var_b])
      self.assertEqual(len(inner_a.non_trainable_variables), 0)
      self.assertEqual(len(inner_b.non_trainable_variables), 0)
      return var_a, var_b
Example #34
0
def real_nvp_default_template(
    hidden_layers,
    shift_only=False,
    activation=nn_ops.relu,
    name=None,
    *args,
    **kwargs):
  """Build a scale-and-shift function using a multi-layer neural network.

  The network is wrapped in a make_template to ensure the variables are only
  created once. The returned template takes the `d`-dimensional input x[0:d]
  and returns the `D-d` dimensional outputs `loc` ("mu") and `log_scale`
  ("alpha").

  Arguments:
    hidden_layers: Python `list`-like of non-negative integer, scalars
      indicating the number of units in each hidden layer (e.g.
      `[512, 512]`).
    shift_only: Python `bool` indicating if only the `shift` term shall be
      computed (i.e. NICE bijector). Default: `False`.
    activation: Activation function (callable). Explicitly setting to `None`
      implies a linear activation.
    name: A name for ops managed by this function. Default:
      "real_nvp_default_template".
      NOTE(review): the name scope only wraps construction of the template;
      ops created when the template is later called are presumably not placed
      under it -- confirm.
    *args: Additional positional `tf.layers.dense` arguments. They are
      forwarded after `inputs`, `units` and `activation`, i.e. starting at
      `dense`'s `use_bias` parameter.
    **kwargs: `tf.layers.dense` keyword arguments.

  Returns:
    A template callable `fn(x, output_units)` returning the pair
    `(shift, log_scale)`:
      shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]).
      log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in
        [2]), or `None` when `shift_only=True`.
  """

  with ops.name_scope(name, "real_nvp_default_template"):
    def _fn(x, output_units):
      """Fully connected MLP parameterized via `real_nvp_template`."""
      # `inputs`, `units` and `activation` are passed positionally so that
      # extra positional `*args` continue at the next `dense` parameter.
      # Passing them by keyword made any non-empty `*args` re-bind `inputs`
      # and fail with "got multiple values for argument".
      for units in hidden_layers:
        x = layers.dense(x, units, activation, *args, **kwargs)
      # Output layer is linear; it emits both halves unless shift_only.
      x = layers.dense(
          x, (1 if shift_only else 2) * output_units, None, *args, **kwargs)
      if shift_only:
        return x, None
      shift, log_scale = array_ops.split(x, 2, axis=-1)
      return shift, log_scale
    return template_ops.make_template(
        "real_nvp_default_template", _fn)
Example #35
0
  def test_nested_templates(self):
    """Checks variable sharing and naming for templates built inside templates.

    Re-calling one outer template must reuse its variables, while a second
    outer template with the same name must get its own variables.
    """
    def nested_template():
      # Two inner templates named "nested"; they must not share variables.
      # Only the second one's variable is returned.
      nested1 = template.make_template("nested", var_scoped_function)
      nested2 = template.make_template("nested", var_scoped_function)
      v1 = nested1()
      v2 = nested2()
      self.assertNotEqual(v1, v2)
      return v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    # Call order matters: the expected scope names below depend on it.
    v1 = tmpl1()
    v2 = tmpl1()  # Same template again: must reuse its variables.
    v3 = tmpl2()  # Different template, same name: must get new variables.
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    # The returned variable belongs to the second inner template, which was
    # uniquified to "nested_1" within each outer scope.
    self.assertEqual("s1/nested_1/dummy:0", v1.name)
    self.assertEqual("s1_2/nested_1/dummy:0", v3.name)
Example #36
0
  def test_internal_variable_reuse(self):
    """Explicit scope reuse inside a template function returns one variable."""
    def nested():
      # Create "x" in a "nested" scope, then re-open that same scope object
      # with reuse=True and fetch "x" again: both lookups must match.
      with tf.variable_scope("nested") as vs:
        v1 = tf.get_variable("x", initializer=tf.zeros_initializer, shape=[])
      with tf.variable_scope(vs, reuse=True):
        v2 = tf.get_variable("x")
      self.assertEqual(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()  # Same template again: must reuse its variables.
    v3 = tmpl2()  # Second template named "s1": must get its own variables.
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    # NOTE(review): the expected uniquified suffix ("s1_2") depends on the
    # TF version's scope-naming behavior and on the surrounding graph; verify
    # if this assertion fails.
    self.assertEqual("s1_2/nested/x:0", v3.name)
  def test_checkpointable_save_restore_nested(self):
    """Saves and restores a template that contains two inner templates.

    Also checks the checkpoint dependency structure: the outer template
    depends on "i1" and "i2", and each inner template depends on "v".
    """

    def _inner_template():
      # Single zero-initialized variable "v" of shape [1].
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer())
      return v

    def _outer_template():
      # second_inner is called twice, so v3 must reuse v2's variable.
      first_inner = template.make_template("i1", _inner_template)
      second_inner = template.make_template("i2", _inner_template)
      v1 = first_inner()
      v2 = second_inner()
      v3 = second_inner()
      return (first_inner, second_inner), (v1, v2, v3)

    # Save side: built under an extra "ignored" scope and with a different
    # template name ("s1") than the load side ("s2"). Restoring successfully
    # is meant to show the checkpoint matches by dependency names rather
    # than by variable-scope names.
    with variable_scope.variable_scope("ignored"):
      save_template = template.make_template("s1", _outer_template)
      save_root = checkpointable_utils.Checkpoint(my_template=save_template)
      (inner_template_one, inner_template_two), _ = save_template()
    self.evaluate(inner_template_one.variables[0].assign([20.]))
    self.evaluate(inner_template_two.variables[0].assign([25.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Load side: restore() is requested before the template is called, so
    # the variables are created afterwards and restored on creation/run.
    load_template = template.make_template("s2", _outer_template)
    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
    status = load_root.restore(save_path)
    (inner_template_one, inner_template_two), (v1, v2, v3) = load_template()
    outer_template_dependencies = load_root.my_template._checkpoint_dependencies
    self.assertEqual(2, len(outer_template_dependencies))
    self.assertEqual("i1", outer_template_dependencies[0].name)
    self.assertIs(inner_template_one, outer_template_dependencies[0].ref)
    self.assertEqual("i2", outer_template_dependencies[1].name)
    self.assertIs(inner_template_two, outer_template_dependencies[1].ref)
    self.assertEqual(1, len(inner_template_one._checkpoint_dependencies))
    self.assertEqual("v", inner_template_one._checkpoint_dependencies[0].name)
    self.assertEqual(1, len(inner_template_two._checkpoint_dependencies))
    self.assertEqual("v", inner_template_two._checkpoint_dependencies[0].name)
    status.assert_consumed().run_restore_ops()
    # v2 and v3 are the same variable, so both see the restored 25.
    self.assertAllEqual([20.], self.evaluate(v1))
    self.assertAllEqual([25.], self.evaluate(v2))
    self.assertAllEqual([25.], self.evaluate(v3))
Example #38
0
    def test_immediate_scope_creation(self):
        # Create templates in "ctor_scope", then call them in "call_scope".
        # The third positional argument of make_template selects when the
        # scope is captured: True -> at construction time, False (the
        # default) -> the first time the template is called.
        with tf.variable_scope("ctor_scope"):
            tmpl_immed = template.make_template("a", var_scoped_function, True)  # create scope here
            tmpl_defer = template.make_template("b", var_scoped_function, False)  # default: create scope at __call__
        with tf.variable_scope("call_scope"):
            inner_imm_var = tmpl_immed()
            inner_defer_var = tmpl_defer()
        # Calling again outside any scope must return the same variables.
        outer_imm_var = tmpl_immed()
        outer_defer_var = tmpl_defer()

        self.assertNotEqual(inner_imm_var, inner_defer_var)
        self.assertEqual(outer_imm_var, inner_imm_var)
        self.assertEqual(outer_defer_var, inner_defer_var)

        # Immediate template lives under the construction scope, deferred
        # template under the scope of its first call.
        self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
        self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)
    def test_nested_templates(self):
        """Checks variable sharing and naming for templates inside templates.

        Re-calling one outer template must reuse its variables, while a second
        outer template with the same name must get its own variables.
        """
        def nested_template():
            # Two inner templates named "nested"; they must not share
            # variables. Only the second one's variable is returned.
            nested1 = template.make_template("nested", var_scoped_function)
            nested2 = template.make_template("nested", var_scoped_function)
            v1 = nested1()
            v2 = nested2()
            self.assertNotEqual(v1, v2)
            return v2

        tmpl1 = template.make_template("s1", nested_template)
        tmpl2 = template.make_template("s1", nested_template)

        # Call order matters: the expected scope names below depend on it.
        v1 = tmpl1()
        v2 = tmpl1()  # Same template again: must reuse its variables.
        v3 = tmpl2()  # Different template, same name: new variables.
        self.assertEqual(v1, v2)
        self.assertNotEqual(v1, v3)
        # The returned variable belongs to the second inner template, which
        # was uniquified to "nested_1" within each outer scope.
        self.assertEqual("s1/nested_1/dummy:0", v1.name)
        self.assertEqual("s1_2/nested_1/dummy:0", v3.name)
Example #40
0
  def test_internal_variable_reuse(self):
    """Explicit scope reuse inside a template function returns one variable."""
    def nested():
      # Create "x" in a "nested" scope, then re-open that same scope object
      # with reuse=True and fetch "x" again: both lookups must match.
      with tf.variable_scope("nested") as vs:
        v1 = tf.get_variable("x", initializer=tf.zeros_initializer, shape=[])
      with tf.variable_scope(vs, reuse=True):
        v2 = tf.get_variable("x")
      self.assertEqual(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()  # Same template again: must reuse its variables.
    v3 = tmpl2()  # Second template named "s1": gets a uniquified scope.
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)
Example #41
0
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates ("bar" and "s") inside the "foo" scope. `tb` is
      # built with `trainable=False`, which is presumably forwarded to the
      # wrapped function -- confirm against make_template's kwargs handling.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        # Eager mode uses `function_with_side_create` instead of
        # `function_with_create`; both are defined outside this view.
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially there are no variables created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling there are variables created.
    ta()
    tb()
    # `ta` must now own one global variable and `tb` two.
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))
def real_nvp_default_template(hidden_layers,
                              shift_only=False,
                              activation=nn_ops.relu,
                              name=None,
                              *args,
                              **kwargs):
    """Build a scale-and-shift function using a multi-layer neural network.

    The network is wrapped in a make_template to ensure the variables are
    only created once. The returned template takes the `d`-dimensional input
    x[0:d] and returns the `D-d` dimensional outputs `loc` ("mu") and
    `log_scale` ("alpha").

    Arguments:
      hidden_layers: Python `list`-like of non-negative integer, scalars
        indicating the number of units in each hidden layer (e.g.
        `[512, 512]`).
      shift_only: Python `bool` indicating if only the `shift` term shall be
        computed (i.e. NICE bijector). Default: `False`.
      activation: Activation function (callable). Explicitly setting to
        `None` implies a linear activation.
      name: A name for ops managed by this function. Default:
        "real_nvp_default_template".
        NOTE(review): the name scope only wraps construction of the
        template; ops created when the template is later called are
        presumably not placed under it -- confirm.
      *args: Additional positional `tf.layers.dense` arguments. They are
        forwarded after `inputs`, `units` and `activation`, i.e. starting at
        `dense`'s `use_bias` parameter.
      **kwargs: `tf.layers.dense` keyword arguments.

    Returns:
      A template callable `fn(x, output_units)` returning the pair
      `(shift, log_scale)`:
        shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]).
        log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in
          [2]), or `None` when `shift_only=True`.
    """

    with ops.name_scope(name, "real_nvp_default_template"):

        def _fn(x, output_units):
            """Fully connected MLP parameterized via `real_nvp_template`."""
            # `inputs`, `units` and `activation` are passed positionally so
            # that extra positional `*args` continue at the next `dense`
            # parameter. Passing them by keyword made any non-empty `*args`
            # re-bind `inputs` and fail with "got multiple values".
            for units in hidden_layers:
                x = layers.dense(x, units, activation, *args, **kwargs)
            # Output layer is linear; it emits both halves unless shift_only.
            x = layers.dense(x,
                             (1 if shift_only else 2) * output_units,
                             None,
                             *args,
                             **kwargs)
            if shift_only:
                return x, None
            shift, log_scale = array_ops.split(x, 2, axis=-1)
            return shift, log_scale

        return template_ops.make_template("real_nvp_default_template", _fn)
    def test_checkpointable_save_restore(self):
        """Saves a template plus an Adam optimizer and restores them.

        The load side uses a differently named template ("s2") and requests
        the restore before the template has created its variables.
        """
        def _templated():
            # Two zero-initialized resource variables and a derived tensor.
            v = variable_scope.get_variable(
                "v",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            v2 = variable_scope.get_variable(
                "v2",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            return v, v + 1., v2

        save_template = template.make_template("s1", _templated)
        v1_save, _, v2_save = save_template()
        # Learning rate 0.0: presumably so minimize() only creates the
        # optimizer's variables without changing v1's value.
        optimizer = adam.AdamOptimizer(0.0)
        save_root = checkpointable_utils.Checkpoint(my_template=save_template,
                                                    optimizer=optimizer)
        optimizer.minimize(v1_save.read_value)
        self.evaluate([v.initializer for v in optimizer.variables()])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = checkpointable_utils.Checkpoint(my_template=load_template,
                                                    optimizer=load_optimizer)
        # Restore is requested before load_template() creates variables.
        status = load_root.restore(save_path)
        var, var_plus_one, var2 = load_template()
        load_optimizer.minimize(var.read_value)
        self.assertEqual(2, len(load_template._checkpoint_dependencies))
        self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
        self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
Example #44
0
  def test_immediate_scope_creation(self):
    # Create templates in "ctor_scope", then call them in "call_scope".
    # The third positional argument of make_template selects when the scope
    # is captured: True -> at construction time, False (the default) -> the
    # first time the template is called.
    with tf.variable_scope("ctor_scope"):
      tmpl_immed = template.make_template(
          "a", var_scoped_function, True)  # create scope here
      tmpl_defer = template.make_template(
          "b", var_scoped_function, False)  # default: create scope at __call__
    with tf.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    # Calling again outside any scope must return the same variables.
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertNotEqual(inner_imm_var, inner_defer_var)
    self.assertEqual(outer_imm_var, inner_imm_var)
    self.assertEqual(outer_defer_var, inner_defer_var)

    # Immediate template lives under the construction scope, deferred
    # template under the scope of its first call.
    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)
Example #45
0
def masked_autoregressive_default_template2(hidden_layers,
                                            shift_only=False,
                                            activation=tf.nn.sigmoid,
                                            log_scale_min_clip=-5.,
                                            log_scale_max_clip=3.,
                                            log_scale_clip_gradient=False,
                                            name=None,
                                            *args,
                                            **kwargs):
    """Build a MADE-style shift-and-log-scale function as a template.

    The returned template maps an input `x` to `(shift, log_scale)` through a
    stack of `masked_dense` layers (defined outside this view; presumably it
    applies the autoregressive masks).

    Args:
        hidden_layers: Iterable of integers, the number of units in each
            hidden `masked_dense` layer.
        shift_only: If `True`, only a shift term is computed and the returned
            `log_scale` is `None`.
        activation: Activation for the hidden layers. Default:
            `tf.nn.sigmoid`.
        log_scale_min_clip: Lower clip bound applied to `log_scale`.
        log_scale_max_clip: Upper clip bound applied to `log_scale`.
        log_scale_clip_gradient: If `True`, clip with
            `math_ops.clip_by_value`; otherwise with
            `_clip_by_value_preserve_grad` (presumably passes gradients
            through unclipped -- confirm).
        name: Name scope used while building the template. Default:
            "masked_autoregressive_default_template".
        *args: Extra positional arguments forwarded to `masked_dense`.
            NOTE(review): they follow keyword arguments in the call, so they
            bind to `masked_dense`'s leading parameters and will likely
            collide with `inputs=`; confirm against its signature.
        **kwargs: Extra keyword arguments forwarded to `masked_dense`.

    Returns:
        A template callable `fn(x)` returning `(shift, log_scale)`, each
        reshaped to the shape of `x`; `log_scale` is `None` when
        `shift_only=True`.

    Raises:
        NotImplementedError: When the template is called and the rightmost
            dimension of `x` is unknown prior to graph execution.
    """
    with ops.name_scope(name,
                        "masked_autoregressive_default_template",
                        values=[log_scale_min_clip, log_scale_max_clip]):

        def _fn(x):
            """MADE parameterized via `masked_autoregressive_default_template`."""
            # TODO(b/67594795): Better support of dynamic shape.
            input_depth = x.shape.with_rank_at_least(1)[-1].value
            if input_depth is None:
                raise NotImplementedError(
                    "Rightmost dimension must be known prior to graph execution."
                )
            # Static shape when fully known, otherwise the dynamic shape.
            input_shape = (np.int32(x.shape.as_list()) if
                           x.shape.is_fully_defined() else array_ops.shape(x))
            for i, units in enumerate(hidden_layers):
                x = masked_dense(inputs=x,
                                 units=units,
                                 num_blocks=input_depth,
                                 # Only the first hidden layer is exclusive.
                                 exclusive=(i == 0),
                                 activation=activation,
                                 *args,
                                 **kwargs)
            x = masked_dense(inputs=x,
                             units=(1 if shift_only else 2) * input_depth,
                             num_blocks=input_depth,
                             activation=None,
                             *args,
                             **kwargs)
            if shift_only:
                x = array_ops.reshape(x, shape=input_shape)
                # Return a pair here too, so callers can always unpack
                # `shift, log_scale = fn(x)` (previously a bare tensor).
                return x, None
            # Last axis of size 2 holds (shift, log_scale) per input element.
            x = array_ops.reshape(x,
                                  shape=array_ops.concat([input_shape, [2]],
                                                         axis=0))
            shift, log_scale = array_ops.unstack(x, num=2, axis=-1)
            which_clip = (math_ops.clip_by_value if log_scale_clip_gradient
                          else _clip_by_value_preserve_grad)
            log_scale = which_clip(log_scale, log_scale_min_clip,
                                   log_scale_max_clip)
            return shift, log_scale

        return template_ops.make_template(
            "masked_autoregressive_default_template", _fn)
Example #46
0
  def __init__(self,
               f,
               g,
               num_layers=1,
               f_side_input=None,
               g_side_input=None,
               use_efficient_backprop=True):
    """Create the per-layer `f` and `g` templates of a reversible block.

    Args:
      f: A callable, or a list/tuple of `num_layers` callables, used as the
        `f` function of each layer. A single callable is reused for every
        layer, but each layer still gets its own template and scope.
      g: Same as `f`, for the `g` function of each layer.
      num_layers: Number of reversible layers.
      f_side_input: Optional side inputs for `f`; stored as `[]` when falsy.
      g_side_input: Optional side inputs for `g`; stored as `[]` when falsy.
      use_efficient_backprop: Stored as `self._use_efficient_backprop`.

    Raises:
      ValueError: If `f` or `g` is a list/tuple whose length differs from
        `num_layers`.
    """

    def _per_layer(fns, label):
      # Broadcast a single callable to every layer, or validate an explicit
      # per-layer sequence. Raise (rather than `assert`) so the check still
      # runs under `python -O`.
      if isinstance(fns, (list, tuple)):
        if len(fns) != num_layers:
          raise ValueError("Expected %d %s functions, got %d." %
                           (num_layers, label, len(fns)))
        return list(fns)
      return [fns] * num_layers

    f = _per_layer(f, "f")
    g = _per_layer(g, "g")

    # Each layer's functions get their own template, with the scope created
    # immediately under "revblock/revlayer_<i>/f" and ".../g".
    scope_prefix = "revblock/revlayer_%d/"
    f_scope = scope_prefix + "f"
    g_scope = scope_prefix + "g"

    self.f = [
        template.make_template(f_scope % i, fn, create_scope_now_=True)
        for i, fn in enumerate(f)
    ]
    self.g = [
        template.make_template(g_scope % i, fn, create_scope_now_=True)
        for i, fn in enumerate(g)
    ]

    self.num_layers = num_layers
    self.f_side_input = f_side_input or []
    self.g_side_input = g_side_input or []

    self._use_efficient_backprop = use_efficient_backprop
Example #47
0
 def test_template_with_empty_name(self):
     """An empty template name yields a scope that ends in a bare slash."""
     empty_tmpl = template.make_template("", variable_scoped_function)
     with variable_scope.variable_scope("outer"):
         # Created in "outer" before the template first runs there.
         outer_x = variable_scope.get_variable("x", [])
         dummy_var = empty_tmpl()
     self.assertEqual("outer/", empty_tmpl.variable_scope_name)
     self.assertEqual("outer//dummy:0", dummy_var.name)
     # In eager mode `outer_x` is not visible to the template since the
     # template does not rely on global collections.
     eager = context.executing_eagerly()
     expected_vars = [dummy_var] if eager else [outer_x, dummy_var]
     self.assertEqual(expected_vars, empty_tmpl.variables)
Example #48
0
  def test_non_trainable_variables(self):
    """`trainable=False` templates expose their variable as non-trainable."""
    with variable_scope.variable_scope("foo2"):
      trainable_tmpl = template.make_template(
          "a", variable_scoped_function, trainable=True)
      frozen_tmpl = template.make_template(
          "b", variable_scoped_function, trainable=False)

    # Nothing exists until the templates are invoked.
    for tmpl in (trainable_tmpl, frozen_tmpl):
      self.assertEqual([], list(tmpl.variables))

    trainable_tmpl()
    frozen_tmpl()

    # The trainable template's single variable is trainable...
    self.assertEqual(1, len(trainable_tmpl.trainable_variables))
    self.assertEqual([], list(trainable_tmpl.non_trainable_variables))
    # ...and the frozen template's single variable is not.
    self.assertEqual([], list(frozen_tmpl.trainable_variables))
    self.assertEqual(1, len(frozen_tmpl.non_trainable_variables))

    # `variables` reports both kinds.
    for tmpl in (trainable_tmpl, frozen_tmpl):
      self.assertEqual(1, len(tmpl.variables))
    def __init__(self,
                 f,
                 g,
                 num_layers=1,
                 f_side_input=None,
                 g_side_input=None,
                 use_efficient_backprop=True):
        """Create the per-layer `f` and `g` templates of a reversible block.

        Args:
            f: A callable, or a list/tuple of `num_layers` callables, used as
                the `f` function of each layer. A single callable is reused
                for every layer, but each layer still gets its own template
                and scope.
            g: Same as `f`, for the `g` function of each layer.
            num_layers: Number of reversible layers.
            f_side_input: Optional side inputs for `f`; stored as `[]` when
                falsy.
            g_side_input: Optional side inputs for `g`; stored as `[]` when
                falsy.
            use_efficient_backprop: Stored as `self._use_efficient_backprop`.

        Raises:
            ValueError: If `f` or `g` is a list/tuple whose length differs
                from `num_layers`.
        """

        def _per_layer(fns, label):
            # Broadcast a single callable to every layer, or validate an
            # explicit per-layer sequence. Raise (rather than `assert`) so
            # the check still runs under `python -O`.
            if isinstance(fns, (list, tuple)):
                if len(fns) != num_layers:
                    raise ValueError("Expected %d %s functions, got %d." %
                                     (num_layers, label, len(fns)))
                return list(fns)
            return [fns] * num_layers

        f = _per_layer(f, "f")
        g = _per_layer(g, "g")

        # Each layer's functions get their own template, with the scope
        # created immediately under "revblock/revlayer_<i>/f" and ".../g".
        scope_prefix = "revblock/revlayer_%d/"
        f_scope = scope_prefix + "f"
        g_scope = scope_prefix + "g"

        self.f = [
            template.make_template(f_scope % i, fn, create_scope_now_=True)
            for i, fn in enumerate(f)
        ]
        self.g = [
            template.make_template(g_scope % i, fn, create_scope_now_=True)
            for i, fn in enumerate(g)
        ]

        self.num_layers = num_layers
        self.f_side_input = f_side_input or []
        self.g_side_input = g_side_input or []

        self._use_efficient_backprop = use_efficient_backprop
  def test_trackable_save_restore(self):
    """Saves a template plus an Adam optimizer and restores them.

    The load side uses a differently named template ("s2") and requests the
    restore before the template has created its variables.
    """

    def _templated():
      # Two zero-initialized resource variables and a derived tensor.
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      return v, v + 1., v2

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save = save_template()
    # Learning rate 0.0: presumably so minimize() only creates the
    # optimizer's variables without changing v1's value.
    optimizer = adam.AdamOptimizer(0.0)
    save_root = util.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = util.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    # Restore is requested before load_template() creates variables.
    status = load_root.restore(save_path)
    var, var_plus_one, var2 = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(2, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
Example #51
0
  def test_scope_access(self):
    """Checks `variable_scope` for immediate and deferred templates."""
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertEqual(tc.variable_scope, None)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")
Example #52
0
  def test_nested_eager_templates_raises_error(self):
    """Calling a template that builds nested templates fails under eager."""

    def nested_template():
      first = template.make_template("nested", variable_scoped_function)
      second = template.make_template("nested", variable_scoped_function)
      first_out = first()
      second_out = second()
      self.assertNotEqual(first_out, second_out)
      return second_out

    with context.eager_mode():
      outer_tmpl = template.make_template("s1", nested_template)
      # NOTE(review): the regex (including its spelling) presumably mirrors
      # the library's error text verbatim -- check before changing it.
      with self.assertRaisesRegexp(
          ValueError, "Nested EagerTemaplates are not currently supported."):
        outer_tmpl()
Example #53
0
    def create_template_fn(
            self,
            name: str,
    ) -> Callable[..., Tuple[tf.Tensor, tf.Tensor]]:
        """Create a shallow, label-conditioned shift/log-scale network.

        Note that this function returns a TensorFlow template, so the
        network's variables are created on the first call and reused after.

        Args:
            name: a scope name of the network

        Returns:
            a template function ``fn(x, y_label)`` returning
            ``(shift, log_scale)``.
        """
        # Local import: the file-level import block is not visible here.
        from typing import Optional

        def _shift_and_log_scale_fn(x: tf.Tensor,
                                    y_label: Optional[tf.Tensor] = None):
            """Map ``x``, conditioned on ``y_label``, to (shift, log_scale).

            Three convolutions of ``self.width`` filters ("l_1" with
            ``ops.conv2d``'s default kernel size, "l_2" 1x1, and a
            zero-initialized "l_last"); a dense projection of ``y_label`` is
            added after the first convolution.

            Raises:
                ValueError: if ``y_label`` is ``None``; the network always
                    conditions on it.
            """
            # Fail before any variables are created; previously this
            # surfaced as an obscure error from K.int_shape(None) after
            # "l_1" had already been built.
            if y_label is None:
                raise ValueError(
                    "y_label is required: the network is conditioned on it.")

            # NOTE(review): indexing shape[3] and the [:, None, None, :]
            # broadcast below assume rank-4, channels-last inputs.
            num_channels = K.int_shape(x)[3]

            with tf.variable_scope("BlockNN"):
                h = ops.conv2d("l_1", x, self.width)
                depth = K.int_shape(h)[-1]
                label_size = K.int_shape(y_label)[-1]
                # Project the labels to `depth` features and add them at
                # every spatial position.
                dense_w = tf.get_variable(
                    "dense_w",
                    shape=(label_size, depth),
                    initializer=tf.contrib.layers.xavier_initializer())
                dense_b = tf.get_variable(
                    "dense_b",
                    shape=(depth, ),
                    initializer=tf.contrib.layers.xavier_initializer())

                conditioning_y = tf.nn.xw_plus_b(y_label, dense_w, dense_b)
                h = h + conditioning_y[:, None, None, :]
                h = self.activation_fn(h)
                h = self.activation_fn(
                    ops.conv2d("l_2", h, self.width, filter_size=[1, 1]))
                # create shift and log_scale with zero initialization
                shift_log_scale = ops.conv2d_zeros("l_last", h,
                                                   2 * num_channels)
                # Even channels -> shift, odd channels -> log_scale.
                shift = shift_log_scale[:, :, :, 0::2]
                log_scale = shift_log_scale[:, :, :, 1::2]
                log_scale = tf.clip_by_value(log_scale, -15.0, 15.0)
                return shift, log_scale

        return template_ops.make_template(name, _shift_and_log_scale_fn)
Example #54
0
def _shift_and_log_scale_fn_template(name):
    """Return a template computing a convolutional (shift, log_scale) pair.

    Args:
        name: Name passed to ``make_template`` for the network's scope.

    Returns:
        A template callable ``fn(x)`` returning ``(shift, log_scale)``, each
        with ``x``'s channel count; ``log_scale`` is clipped to
        ``[-15, 15]``.
    """
    def _shift_and_log_scale_fn(x: tf.Tensor):
        # NOTE(review): shape[3] as the channel count assumes rank-4,
        # channels-last input.
        num_channels = K.int_shape(x)[3]
        # nn definition (hidden layers keep tf_layers.conv2d's default
        # activation)
        h = tf_layers.conv2d(x, num_outputs=num_channels, kernel_size=3)
        h = tf_layers.conv2d(h, num_outputs=num_channels // 2, kernel_size=3)
        # Both heads are linear. `shift` previously inherited conv2d's default
        # ReLU, which forced shift >= 0, unlike the explicitly linear
        # log_scale head.
        shift = tf_layers.conv2d(h,
                                 num_outputs=num_channels,
                                 kernel_size=3,
                                 activation_fn=None)
        log_scale = tf_layers.conv2d(h,
                                     num_outputs=num_channels,
                                     kernel_size=3,
                                     activation_fn=None)
        log_scale = tf.clip_by_value(log_scale, -15.0, 15.0)
        return shift, log_scale

    return template_ops.make_template(name, _shift_and_log_scale_fn)