Example #1
    def testPartitioners(self):
        if tf.executing_eagerly():
            self.skipTest("Eager does not support partitioned variables.")

        partitioners = {
            "w": tf.variable_axis_size_partitioner(10),
            "b": tf.variable_axis_size_partitioner(8),
        }

        module = snt.nets.ConvNet2DTranspose(
            output_channels=self.output_channels,
            output_shapes=self.output_shapes,
            kernel_shapes=self.kernel_shapes,
            strides=self.strides,
            paddings=self.paddings,
            partitioners=partitioners)

        input_shape = [10, 100, 100, 3]
        input_to_net = tf.placeholder(tf.float32, shape=input_shape)

        _ = module(input_to_net)

        for layer in module._layers:
            self.assertEqual(type(layer.w), variables.PartitionedVariable)
            self.assertEqual(type(layer.b), variables.PartitionedVariable)
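The partitioner returned by tf.variable_axis_size_partitioner can also be called directly on a shape to preview how a variable will be split. Below is a minimal sketch using the same TF 1.x API as the examples and a hypothetical 3x3 kernel with 3 input and 16 output channels: with max_shard_bytes=10 every slice along axis 0 already exceeds the budget, so the kernel is split once per entry of its first dimension.

import tensorflow as tf

w_partitioner = tf.variable_axis_size_partitioner(10)
# Hypothetical kernel shape: 3x3 kernel, 3 input channels, 16 output channels.
print(w_partitioner(tf.TensorShape([3, 3, 3, 16]), tf.float32))  # -> [3, 1, 1, 1]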
Example #2
  def testCheckPartitioners(self):
    partitioners = {"key_a": tf.variable_axis_size_partitioner(10),
                    "key_c": tf.variable_axis_size_partitioner(10)}
    keys = ["key_a", "key_b"]
    self.assertRaisesRegexp(KeyError,
                            "Invalid partitioner keys.*",
                            snt.check_partitioners,
                            partitioners=partitioners,
                            keys=keys)

    del partitioners["key_c"]
    partitioners["key_b"] = "not a function"
    self.assertRaisesRegexp(TypeError,
                            "Partitioner for.*",
                            snt.check_partitioners,
                            partitioners=partitioners,
                            keys=keys)

    partitioners["key_b"] = {"key_c": "not a function"}
    self.assertRaisesRegexp(TypeError,
                            "Partitioner for.*",
                            snt.check_partitioners,
                            partitioners=partitioners,
                            keys=keys)

    partitioners["key_b"] = {
        "key_c": tf.variable_axis_size_partitioner(10),
        "key_d": tf.variable_axis_size_partitioner(10),
    }
    snt.check_partitioners(partitioners=partitioners, keys=keys)
Example #3
    def testCheckPartitioners(self):
        partitioners = {
            "key_a": tf.variable_axis_size_partitioner(10),
            "key_c": tf.variable_axis_size_partitioner(10)
        }
        keys = ["key_a", "key_b"]
        self.assertRaisesRegexp(KeyError,
                                "Invalid partitioner keys.*",
                                snt.check_partitioners,
                                partitioners=partitioners,
                                keys=keys)

        del partitioners["key_c"]
        partitioners["key_b"] = "not a function"
        self.assertRaisesRegexp(TypeError,
                                "Partitioner for.*",
                                snt.check_partitioners,
                                partitioners=partitioners,
                                keys=keys)

        partitioners["key_b"] = {"key_c": "not a function"}
        self.assertRaisesRegexp(TypeError,
                                "Partitioner for.*",
                                snt.check_partitioners,
                                partitioners=partitioners,
                                keys=keys)

        partitioners["key_b"] = {
            "key_c": tf.variable_axis_size_partitioner(10),
            "key_d": tf.variable_axis_size_partitioner(10)
        }
        snt.check_partitioners(partitioners=partitioners, keys=keys)
Example #4
  def testPartitioners(self, use_peepholes, use_batch_norm_h, use_batch_norm_x,
                       use_batch_norm_c):
    batch_size = 2
    hidden_size = 4

    keys = _get_possible_initializer_keys(
        use_peepholes, use_batch_norm_h, use_batch_norm_x, use_batch_norm_c)
    partitioners = {
        key: tf.variable_axis_size_partitioner(10) for key in keys
    }

    # Test we can successfully create the LSTM with partitioners.
    lstm, wrapped_lstm = _construct_lstm(hidden_size=hidden_size,
                                         use_peepholes=use_peepholes,
                                         use_batch_norm_h=use_batch_norm_h,
                                         use_batch_norm_x=use_batch_norm_x,
                                         use_batch_norm_c=use_batch_norm_c,
                                         partitioners=partitioners)

    # Test we can build the LSTM
    inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    prev_cell = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    prev_hidden = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    wrapped_lstm(inputs, (prev_hidden, prev_cell))

    # Test that the variables are partitioned.
    var_names = _get_lstm_variable_names(lstm)
    for var_name in var_names:
      self.assertEqual(type(getattr(lstm, "_" + var_name)),
                       variables.PartitionedVariable)
Example #5
  def testPartitioners(self, lstm_class, dim, use_bias):
    keys = snt.Conv2DLSTM.get_possible_initializer_keys(use_bias)
    partitioners = {
        key: tf.variable_axis_size_partitioner(10) for key in keys
    }

    batch_size = 2
    input_shape = (8,) * dim
    input_channels = 3
    output_channels = 5

    input_shape = (batch_size,) + input_shape + (input_channels,)
    output_shape = input_shape[:-1] + (output_channels,)

    inputs = tf.placeholder(tf.float32, shape=input_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=output_shape)
    prev_cell = tf.placeholder(tf.float32, shape=output_shape)

    # Test we can successfully create the LSTM with partitioners.
    lstm = lstm_class(
        input_shape=input_shape[1:],
        output_channels=output_channels,
        kernel_shape=1,
        use_bias=use_bias,
        partitioners=partitioners)
    lstm(inputs, (prev_hidden, prev_cell))

    # Test that the variables are partitioned.
    for convolution in lstm.convolutions.values():
      for key in keys:
        self.assertEqual(type(getattr(convolution, key)),
                         variables.PartitionedVariable)
Example #6
    def testControlDepsNone(self):
        with self.test_session() as session:
            c = tf.constant(1.0)
            with tf.control_dependencies([c]):
                # d gets the control dependency.
                d = tf.constant(2.0)
                # Partitioned variables do not.
                var_x = tf.get_variable(
                    "x",
                    shape=[2],
                    initializer=tf.ones_initializer(),
                    partitioner=tf.variable_axis_size_partitioner(4))

                ops_before_read = session.graph.get_operations()
                var_x.as_tensor()  # Caches the ops for subsequent reads.
                reading_ops = [
                    op for op in session.graph.get_operations()
                    if op not in ops_before_read
                ]

            self.assertEqual([c.op], d.op.control_inputs)
            # Tests that no control dependencies are added when reading a
            # partitioned variable, just as when reading a regular variable.
            for op in reading_ops:
                self.assertEqual([], op.control_inputs)
Example #7
    def testConcat(self):
        with self.test_session() as session:
            var_x = tf.get_variable(
                "x",
                initializer=tf.constant([1., 2.]),
                partitioner=tf.variable_axis_size_partitioner(4))

            c = tf.constant(1.0)
            with tf.control_dependencies([c]):
                ops_before_concat = session.graph.get_operations()
                value = var_x._concat()  # pylint: disable=protected-access
                concat_ops = [
                    op for op in session.graph.get_operations()
                    if op not in ops_before_concat
                ]

            concat_control_inputs = [
                ci for op in concat_ops for ci in op.control_inputs
            ]
            self.assertTrue(
                c.op in concat_control_inputs,
                "var_x._concat() should get control dependencies from its scope."
            )
            tf.global_variables_initializer().run()
            self.assertAllClose(value.eval(), var_x.as_tensor().eval())
Example #8
    def testPartitioners(self, use_peepholes, use_batch_norm_h,
                         use_batch_norm_x, use_batch_norm_c):
        batch_size = 2
        hidden_size = 4

        keys = snt.LSTM.get_possible_initializer_keys(use_peepholes,
                                                      use_batch_norm_h,
                                                      use_batch_norm_x,
                                                      use_batch_norm_c)
        partitioners = {
            key: tf.variable_axis_size_partitioner(10)
            for key in keys
        }

        # Test we can successfully create the LSTM with partitioners.
        lstm = snt.LSTM(hidden_size,
                        use_peepholes=use_peepholes,
                        use_batch_norm_h=use_batch_norm_h,
                        use_batch_norm_x=use_batch_norm_x,
                        use_batch_norm_c=use_batch_norm_c,
                        partitioners=partitioners)

        # Test we can build the LSTM
        inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
        prev_cell = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
        prev_hidden = tf.placeholder(tf.float32,
                                     shape=[batch_size, hidden_size])
        lstm(inputs, (prev_hidden, prev_cell))

        # Test that the variables are partitioned.
        var_names = _get_lstm_variable_names(lstm)
        for var_name in var_names:
            self.assertEqual(type(getattr(lstm, "_" + var_name)),
                             variables.PartitionedVariable)
Example #9
    def testPartitioners(self, lstm_class, dim, use_bias):
        keys = snt.Conv2DLSTM.get_possible_initializer_keys(use_bias)
        partitioners = {
            key: tf.variable_axis_size_partitioner(10)
            for key in keys
        }

        batch_size = 2
        input_shape = (8, ) * dim
        input_channels = 3
        output_channels = 5

        input_shape = (batch_size, ) + input_shape + (input_channels, )
        output_shape = input_shape[:-1] + (output_channels, )

        inputs = tf.placeholder(tf.float32, shape=input_shape)
        prev_hidden = tf.placeholder(tf.float32, shape=output_shape)
        prev_cell = tf.placeholder(tf.float32, shape=output_shape)

        # Test we can successfully create the LSTM with partitioners.
        lstm = lstm_class(input_shape=input_shape[1:],
                          output_channels=output_channels,
                          kernel_shape=1,
                          use_bias=use_bias,
                          partitioners=partitioners)
        lstm(inputs, (prev_hidden, prev_cell))

        # Test that the variables are partitioned.
        for convolution in lstm.convolutions.values():
            for key in keys:
                self.assertEqual(type(getattr(convolution, key)),
                                 variables.PartitionedVariable)
Example #10
  def benchmark_create_1000_partitions_with_100_parameter_servers(self):
    workers, _ = create_local_cluster(num_workers=1, num_ps=100)
    worker_sessions = [tf.Session(w.target) for w in workers]
    worker = worker_sessions[0]
    partition_sizes = (1, 512, 1024*32, 1024*128)

    partitioned = []

    for partition_size in partition_sizes:
      # max_shard_bytes is 4 * partition_size and the shape is
      # 1000 * partition_size float32s, so the variable should partition into
      # 1000 shards, each containing partition_size float32s.
      print("Building partitioned variable with %d floats per partition"
            % partition_size)
      with tf.device(tf.train.replica_device_setter(ps_tasks=100)):
        partitioned_ix = tf.get_variable(
            "partitioned_%d" % partition_size,
            shape=[1000 * partition_size],
            dtype=tf.float32,
            # Each partition holds exactly partition_size float32s.
            partitioner=tf.variable_axis_size_partitioner(
                max_shard_bytes=4 * partition_size))
        # Concatenates along axis 0
        partitioned.append(tf.convert_to_tensor(partitioned_ix))

    tf.global_variables_initializer().run(session=worker)

    for ix, partition_size in enumerate(partition_sizes):
      print("Running benchmark having partitions with %d floats"
            % partition_size)
      self.run_op_benchmark(
          worker,
          partitioned[ix],
          name=("read_concat_1000_partitions_from_"
                "100_parameter_servers_partsize_%d_floats" % partition_size))
Example #11
 def _create_conv(self, partitioned, name):
   hidden = tf.ones(shape=(1, 16, 16, 3))
   if partitioned:
     partitioners = {"w": tf.variable_axis_size_partitioner(4)}
   else:
     partitioners = None
   conv = snt.Conv2D(output_channels=3, kernel_shape=3, stride=1,
                     partitioners=partitioners, name=name)
   conv(hidden)
   return conv
Example #12
 def _create_conv(self, partitioned, name):
   hidden = tf.ones(shape=(1, 16, 16, 3))
   if partitioned:
     partitioners = {"w": tf.variable_axis_size_partitioner(4)}
   else:
     partitioners = None
   conv = snt.Conv2D(output_channels=3, kernel_shape=3, stride=1,
                     partitioners=partitioners, name=name)
   conv(hidden)
   return conv
Example #13
  def testPartitioners(self):
    partitioners = {
        "w": tf.variable_axis_size_partitioner(10),
        "b": tf.variable_axis_size_partitioner(8),
    }

    module = snt.nets.ConvNet2D(output_channels=self.output_channels,
                                kernel_shapes=self.kernel_shapes,
                                strides=self.strides,
                                paddings=self.paddings,
                                partitioners=partitioners)

    input_shape = [10, 100, 100, 3]
    input_to_net = tf.placeholder(tf.float32, shape=input_shape)

    _ = module(input_to_net)

    for layer in module._layers:
      self.assertEqual(type(layer.w), variables.PartitionedVariable)
      self.assertEqual(type(layer.b), variables.PartitionedVariable)
Example #14
    def testPartitioners(self):
        partitioners = {
            "w": tf.variable_axis_size_partitioner(10),
            "b": tf.variable_axis_size_partitioner(8),
        }

        module = snt.nets.ConvNet2D(output_channels=self.output_channels,
                                    kernel_shapes=self.kernel_shapes,
                                    strides=self.strides,
                                    paddings=self.paddings,
                                    partitioners=partitioners)

        input_shape = [10, 100, 100, 3]
        input_to_net = tf.placeholder(tf.float32, shape=input_shape)

        _ = module(input_to_net)

        for layer in module._layers:
            self.assertEqual(type(layer.w), variables.PartitionedVariable)
            self.assertEqual(type(layer.b), variables.PartitionedVariable)
Example #15
  def _testVariableAxisSizePartitioner(self, name, axis, max_shard_bytes,
                                       expected_axis_shards,
                                       expected_partitions,
                                       max_shards=None):
    partitioner = tf.variable_axis_size_partitioner(
        axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)

    with tf.variable_scope("root", partitioner=partitioner):
      v0_list, v0_part = get_partitioned_variable_list(
          name, dtype=tf.float32, shape=(4, 8, 16, 32))
      self.assertEqual(len(v0_list), expected_axis_shards)
      self.assertAllEqual(v0_part, expected_partitions)
Example #16
  def _testVariableAxisSizePartitioner(self, name, axis, max_shard_bytes,
                                       expected_axis_shards,
                                       expected_partitions,
                                       max_shards=None):
    partitioner = tf.variable_axis_size_partitioner(
        axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)

    with tf.variable_scope("root", partitioner=partitioner):
      v0 = tf.get_variable(name, dtype=tf.float32, shape=(4, 8, 16, 32))
      v0_list = v0._get_variable_list()
      v0_part = v0._get_partitions()
      self.assertEqual(len(v0_list), expected_axis_shards)
      self.assertAllEqual(v0_part, expected_partitions)
Example #17
    def module_fn():
        embeddings_shape = [len(keys_vocab), embedding_size]
        embedding_weights = tf.get_variable(
            name=_EMBEDDINGS_VAR_NAME,
            shape=embeddings_shape,
            dtype=tf.float32,
            initializer=tf.zeros_initializer(),
            trainable=is_trainable,
            partitioner=tf.variable_axis_size_partitioner(
                max_shard_bytes=shard_bytes  # 1 GB by default
            ))
        lookup_table = tf.contrib.lookup.index_table_from_tensor(
            mapping=keys_vocab, default_value=0)

        default_keys = tf.placeholder(dtype=tf.string, shape=[None])
        default_ids = lookup_table.lookup(default_keys)
        default_embeddings = tf.nn.embedding_lookup(
            params=embedding_weights,
            ids=default_ids,
            partition_strategy='div',
            max_norm=max_norm,
        )
        hub.add_signature('default', default_keys, default_embeddings)

        context_keys = tf.sparse_placeholder(dtype=tf.string,
                                             shape=(None, None))
        context_ids = lookup_table.lookup(context_keys)
        context_embeddings = tf.nn.safe_embedding_lookup_sparse(
            embedding_weights=embedding_weights,
            sparse_ids=context_ids,
            combiner=combiner,
            default_id=0,
            partition_strategy='div',
            max_norm=max_norm,
        )
        hub.add_signature('context', context_keys, context_embeddings)

        sequence_keys = tf.sparse_placeholder(dtype=tf.string,
                                              shape=(None, None, None))
        sequence_ids = lookup_table.lookup(sequence_keys)
        sequence_embeddings = tf.nn.safe_embedding_lookup_sparse(
            embedding_weights=embedding_weights,
            sparse_ids=sequence_ids,
            combiner=combiner,
            default_id=0,
            partition_strategy='div',
            max_norm=max_norm,
        )
        hub.add_signature('sequence', sequence_keys, sequence_embeddings)
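A minimal sketch (assumed wrapper code, not part of the original snippet) of how a module_fn like the one above is typically exported and consumed with tensorflow_hub; keys_vocab, embedding_size, shard_bytes and the other closure variables are assumed to be defined in the enclosing scope.

import tensorflow as tf
import tensorflow_hub as hub

# Build a module spec from the module_fn defined above (assumed to be in scope).
spec = hub.create_module_spec(module_fn)
embed_module = hub.Module(spec, trainable=True)

# Look up embeddings through the 'default' signature registered by module_fn.
keys = tf.placeholder(dtype=tf.string, shape=[None])
vectors = embed_module(keys, signature='default')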
Example #18
  def testPartitioners(self):
    if tf.executing_eagerly():
      self.skipTest("Eager does not support partitioned variables.")

    partitioners = {
        "w": tf.variable_axis_size_partitioner(10),
        "b": tf.variable_axis_size_partitioner(8),
    }

    module = snt.nets.ConvNet2DTranspose(output_channels=self.output_channels,
                                         output_shapes=self.output_shapes,
                                         kernel_shapes=self.kernel_shapes,
                                         strides=self.strides,
                                         paddings=self.paddings,
                                         partitioners=partitioners)

    input_shape = [10, 100, 100, 3]
    input_to_net = tf.placeholder(tf.float32, shape=input_shape)

    _ = module(input_to_net)

    for layer in module._layers:
      self.assertEqual(type(layer.w), variables.PartitionedVariable)
      self.assertEqual(type(layer.b), variables.PartitionedVariable)
Example #19
  def testVariableMapItems(self):
    hidden = tf.ones(shape=(1, 16, 16, 3))
    partitioner = tf.variable_axis_size_partitioner(4)
    conv = snt.Conv2D(output_channels=3,
                      kernel_shape=3,
                      stride=1,
                      partitioners={"w": partitioner})
    conv(hidden)
    variable_map = snt.get_normalized_variable_map(conv)
    items = snt.variable_map_items(variable_map)

    items_str = sorted((key, var.op.name) for key, var in items)
    self.assertEqual(items_str, [(u"b", u"conv_2d/b"),
                                 ("w", u"conv_2d/w/part_0"),
                                 ("w", u"conv_2d/w/part_1"),
                                 ("w", u"conv_2d/w/part_2")])
Example #20
    def testVariableMapItems(self):
        hidden = tf.ones(shape=(1, 16, 16, 3))
        partitioner = tf.variable_axis_size_partitioner(4)
        conv = snt.Conv2D(output_channels=3,
                          kernel_shape=3,
                          stride=1,
                          partitioners={"w": partitioner})
        conv(hidden)
        variable_map = snt.get_normalized_variable_map(conv)
        items = snt.variable_map_items(variable_map)

        items_str = sorted((key, var.op.name) for key, var in items)
        self.assertEqual(items_str, [(u"b", u"conv_2d/b"),
                                     ("w", u"conv_2d/w/part_0"),
                                     ("w", u"conv_2d/w/part_1"),
                                     ("w", u"conv_2d/w/part_2")])
Example #21
  def testPartitioners(self):
    # Partition embeddings such that there's one variable per vocabulary entry.
    partitioners = {"embeddings": tf.variable_axis_size_partitioner(
        4 * self._embed_dim)}
    embed_mod = snt.Embed(
        vocab_size=self._vocab_size,
        embed_dim=self._embed_dim,
        partitioners=partitioners)
    embeddings = embed_mod(tf.convert_to_tensor(self._ids))
    self.assertEqual(type(embed_mod.embeddings), variables.PartitionedVariable)
    self.assertEqual(len(embed_mod.embeddings), self._vocab_size)

    # Ensure that tf.nn.embedding_lookup() plays nicely with embedding
    # variables.
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(embeddings)
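A minimal sketch (hypothetical sizes) of why a budget of 4 * embed_dim bytes yields one shard per vocabulary row: each float32 row of the [vocab_size, embed_dim] embedding matrix occupies exactly 4 * embed_dim bytes.

import tensorflow as tf

vocab_size, embed_dim = 10, 8  # assumed values for illustration
partitioner = tf.variable_axis_size_partitioner(4 * embed_dim)
print(partitioner(tf.TensorShape([vocab_size, embed_dim]), tf.float32))  # -> [10, 1]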
Example #22
 def testPartitionedVariableMasking(self):
   partitioner = tf.variable_axis_size_partitioner(40)
   with self.cached_session() as session:
     with tf.variable_scope("", partitioner=partitioner):
       sparsity = tf.Variable(0.5, name="Sparsity")
       weights = tf.get_variable(
           "weights", initializer=tf.linspace(1.0, 100.0, 100))
       masked_weights = pruning.apply_mask(
           weights, scope=tf.get_variable_scope())
     p = pruning.Pruning(sparsity=sparsity)
     p._spec.threshold_decay = 0.0
     mask_update_op = p.mask_update_op()
     tf.global_variables_initializer().run()
     masked_weights_val = masked_weights.eval()
     session.run(mask_update_op)
     masked_weights_val = masked_weights.eval()
     self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
Example #23
    def testConcat(self):
        with self.test_session() as session:
            var_x = tf.get_variable(
                "x", initializer=tf.constant([1.0, 2.0]), partitioner=tf.variable_axis_size_partitioner(4)
            )

            c = tf.constant(1.0)
            with tf.control_dependencies([c]):
                ops_before_concat = session.graph.get_operations()
                value = var_x._concat()  # pylint: disable=protected-access
                concat_ops = [op for op in session.graph.get_operations() if op not in ops_before_concat]

            concat_control_inputs = [ci for op in concat_ops for ci in op.control_inputs]
            self.assertTrue(
                c.op in concat_control_inputs, "var_x._concat() should get control dependencies from its scope."
            )
            tf.global_variables_initializer().run()
            self.assertAllClose(value.eval(), var_x.as_tensor().eval())
Example #24
    def testPartitioners(self):
        # Partition embeddings such that there's one variable per vocabulary entry.
        partitioners = {
            "embeddings":
            tf.variable_axis_size_partitioner(4 * self._embed_dim)
        }
        embed_mod = snt.Embed(vocab_size=self._vocab_size,
                              embed_dim=self._embed_dim,
                              partitioners=partitioners)
        embeddings = embed_mod(tf.convert_to_tensor(self._ids))
        self.assertEqual(type(embed_mod.embeddings),
                         variables.PartitionedVariable)
        self.assertEqual(len(embed_mod.embeddings), self._vocab_size)

        # Ensure that tf.nn.embedding_lookup() plays nicely with embedding
        # variables.
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(embeddings)
Example #25
  def testGetNormalizedVariableMapWithPartitionedVariable(self):
    hidden = tf.ones(shape=(1, 16, 16, 3))
    partitioner = tf.variable_axis_size_partitioner(4)
    conv = snt.Conv2D(output_channels=3,
                      kernel_shape=3,
                      stride=1,
                      partitioners={"w": partitioner})
    conv(hidden)
    variable_map = snt.get_normalized_variable_map(conv,
                                                   group_sliced_variables=True)
    self.assertEqual(len(variable_map), 2)
    self.assertEqual(variable_map["b"], conv.b)
    self.assertEqual(len(variable_map["w"]), 3)

    variable_map = snt.get_normalized_variable_map(conv,
                                                   group_sliced_variables=False)
    self.assertEqual(variable_map["b"], conv.b)
    self.assertEqual(
        set(variable_map), set(["b", "w/part_0", "w/part_1", "w/part_2"]))
Example #26
    def testGetNormalizedVariableMapWithPartitionedVariable(self):
        hidden = tf.ones(shape=(1, 16, 16, 3))
        partitioner = tf.variable_axis_size_partitioner(4)
        conv = snt.Conv2D(output_channels=3,
                          kernel_shape=3,
                          stride=1,
                          partitioners={"w": partitioner})
        conv(hidden)
        variable_map = snt.get_normalized_variable_map(
            conv, group_sliced_variables=True)
        self.assertEqual(len(variable_map), 2)
        self.assertEqual(variable_map["b"], conv.b)
        self.assertEqual(len(variable_map["w"]), 3)

        variable_map = snt.get_normalized_variable_map(
            conv, group_sliced_variables=False)
        self.assertEqual(variable_map["b"], conv.b)
        self.assertEqual(set(variable_map),
                         set(["b", "w/part_0", "w/part_1", "w/part_2"]))
Example #27
  def testPartitioners(self):
    batch_size = 2
    hidden_size = 4

    # Test we can successfully create the GRU with partitioners.
    keys = snt.GRU.POSSIBLE_KEYS
    partitioners = {
        key: tf.variable_axis_size_partitioner(10) for key in keys
    }
    gru = snt.GRU(hidden_size, partitioners=partitioners)

    # Test we can build the GRU.
    inputs = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    state = tf.placeholder(tf.float32, shape=[batch_size, hidden_size])
    gru(inputs, state)

    # Test that the variables are partitioned.
    for key in keys:
      self.assertEqual(type(getattr(gru, "_" + key)),
                       variables.PartitionedVariable)
Example #28
    def testControlDepsNone(self):
        with self.test_session() as session:
            c = tf.constant(1.0)
            with tf.control_dependencies([c]):
                # d gets the control dependency.
                d = tf.constant(2.0)
                # Partitioned variables do not.
                var_x = tf.get_variable(
                    "x", shape=[2], initializer=tf.ones_initializer(), partitioner=tf.variable_axis_size_partitioner(4)
                )

                ops_before_read = session.graph.get_operations()
                var_x.as_tensor()  # Caches the ops for subsequent reads.
                reading_ops = [op for op in session.graph.get_operations() if op not in ops_before_read]

            self.assertEqual([c.op], d.op.control_inputs)
            # Tests that no control dependencies are added when reading a
            # partitioned variable, just as when reading a regular variable.
            for op in reading_ops:
                self.assertEqual([], op.control_inputs)
Example #29
  def testVariableAxisSizePartitioner(self):
    with self.test_session():
      # Create a partitioned variable of shape (4, 8, 16, 32) type float32
      # Bytes per slice along the given axes:

      # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0
      # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1
      # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2
      # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3

      # Now partition it in different ways...

      # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes
      self._testVariableAxisSizePartitioner("v0", axis=0,
                                            max_shard_bytes=131072,
                                            expected_axis_shards=1,
                                            expected_partitions=(1, 1, 1, 1))

      # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes
      self._testVariableAxisSizePartitioner("v1", axis=1,
                                            max_shard_bytes=65536,
                                            expected_axis_shards=1,
                                            expected_partitions=(1, 1, 1, 1))

      # Slice into 2 parts:
      # bytes_per_slice = 4096
      # slices_per_shard = 32768 / 4096 = 8
      # axis_shards = 16 / 8 = 2
      self._testVariableAxisSizePartitioner("v2", axis=2,
                                            max_shard_bytes=32768,
                                            expected_axis_shards=2,
                                            expected_partitions=(1, 1, 2, 1))

      # This partitioner makes sure we maximize the number of shards along
      # axis 3. Slice it into 32 parts:
      # bytes_per_slice = 2048
      # slices_per_shard = 2048 / 2048 = 1
      # axis_shards = 32 / 1 = 32
      self._testVariableAxisSizePartitioner("v3a", axis=3,
                                            max_shard_bytes=2048,
                                            expected_axis_shards=32,
                                            expected_partitions=(1, 1, 1, 32))

      # This partitioner makes sure we do not go past the bound of allowable
      # number of shards along axis 3.
      # Slice into 32 parts:
      # bytes_per_slice = 2048
      # slices_per_shard = max(1, 1024 / 2048) = 1
      # axis_shards = 32 / 1 = 32
      # Slice into max of 32 parts because: max_shard_bytes < bytes_per_slice
      self._testVariableAxisSizePartitioner("v3b", axis=3,
                                            max_shard_bytes=1024,
                                            expected_axis_shards=32,
                                            expected_partitions=(1, 1, 1, 32))

      # Specify max_shards so that it won't affect sharding.
      self._testVariableAxisSizePartitioner("v3c", axis=3,
                                            max_shard_bytes=1024,
                                            expected_axis_shards=32,
                                            expected_partitions=(1, 1, 1, 32),
                                            max_shards=33)

      # Specify max_shards so that it will affect sharding.
      self._testVariableAxisSizePartitioner("v3d", axis=3,
                                            max_shard_bytes=1024,
                                            expected_axis_shards=2,
                                            expected_partitions=(1, 1, 1, 2),
                                            max_shards=2)

      # Use the partitioner with strings
      partitioner_axis3_str = tf.variable_axis_size_partitioner(
          axis=3, max_shard_bytes=32768, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis3_str):
        v3str = tf.get_variable(
            "v3str",
            initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32),
            dtype=tf.string,
            shape=(4, 8, 16, 32))
        v3str_list = v3str._get_variable_list()
        v3str_part = v3str._get_partitions()

        # Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element
        # which is equal to 4096.  Setting a max_shard_bytes of 32768
        # and we should get a split of 4.
        # Slice into 4 parts:
        # bytes_per_slice = 4096
        # slices_per_shard = 32768 / 4096 = 8
        # axis_shards = 32 / 8 = 4
        self.assertEqual(len(v3str_list), 4)
        self.assertAllEqual(v3str_part, (1, 1, 1, 4))
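A minimal pure-Python restatement of the arithmetic the comments above walk through (bytes_per_slice, slices_per_shard, axis_shards); this is a sketch of the partitioning rule for float32-sized elements, not the library implementation itself.

import math

def axis_shards(shape, axis, max_shard_bytes, bytes_per_element=4, max_shards=None):
    """Number of shards along `axis` for a variable of `shape` (float32 by default)."""
    total_bytes = bytes_per_element
    for dim in shape:
        total_bytes *= dim
    bytes_per_slice = total_bytes / shape[axis]
    slices_per_shard = max(1, int(max_shard_bytes // bytes_per_slice))
    shards = int(math.ceil(shape[axis] / slices_per_shard))
    return min(shards, max_shards) if max_shards else shards

assert axis_shards((4, 8, 16, 32), axis=2, max_shard_bytes=32768) == 2    # v2 above
assert axis_shards((4, 8, 16, 32), axis=3, max_shard_bytes=1024) == 32    # v3b above
assert axis_shards((4, 8, 16, 32), axis=3, max_shard_bytes=1024, max_shards=2) == 2  # v3d above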
Example #30
    def testVariableAxisSizePartitioner(self):
        with self.test_session():
            # Create a partitioned variable of shape (4, 8, 16, 32) type float32
            # Bytes per slice along the given axes:

            # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0
            # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1
            # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2
            # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3

            # Now partition it in different ways...

            partitioner_axis0 = tf.variable_axis_size_partitioner(
                axis=0, max_shard_bytes=32768, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis0):
                v0_list, v0_part = get_partitioned_variable_list(
                    "v0", dtype=tf.float32, shape=(4, 8, 16, 32))
                # No need to slice: size_per_slice = 16384 < 32768 = max_shard_bytes
                self.assertEqual(len(v0_list), 1)
                self.assertAllEqual(v0_part, (1, 1, 1, 1))

            partitioner_axis1 = tf.variable_axis_size_partitioner(
                axis=1, max_shard_bytes=8192, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis1):
                v1_list, v1_part = get_partitioned_variable_list(
                    "v1", dtype=tf.float32, shape=(4, 8, 16, 32))
                # Slice exactly once: size_per_slice = 8192 == 8192 = max_shard_bytes
                self.assertEqual(len(v1_list), 1)
                self.assertAllEqual(v1_part, (1, 1, 1, 1))

            partitioner_axis2 = tf.variable_axis_size_partitioner(
                axis=2, max_shard_bytes=2048, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis2):
                v2_list, v2_part = get_partitioned_variable_list(
                    "v2", dtype=tf.float32, shape=(4, 8, 16, 32))
                # Slice into 2 parts:
                #   size_per_slice = 4096 == 2 * 2048 = max_shard_bytes
                self.assertEqual(len(v2_list), 2)
                self.assertAllEqual(v2_part, (1, 1, 2, 1))

            # This partitioner makes sure we maximize the number of shards
            # along axis 3
            partitioner_axis3_a = tf.variable_axis_size_partitioner(
                axis=3, max_shard_bytes=64, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis3_a):
                v3a_list, v3a_part = get_partitioned_variable_list(
                    "v3a", dtype=tf.float32, shape=(4, 8, 16, 32))
                # Slice into 32 parts: 2048 == 64 * 32
                self.assertEqual(len(v3a_list), 32)
                self.assertAllEqual(v3a_part, (1, 1, 1, 32))

            # This partitioner makes sure we go past the bound of allowable
            # number of shards along axis 3
            partitioner_axis3_b = tf.variable_axis_size_partitioner(
                axis=3, max_shard_bytes=32, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis3_b):
                v3b_list, v3b_part = get_partitioned_variable_list(
                    "v3b", dtype=tf.float32, shape=(4, 8, 16, 32))
                # Slice into the maximum of 32 parts because: 2048 > 32 * 32
                self.assertEqual(len(v3b_list), 32)
                self.assertAllEqual(v3b_part, (1, 1, 1, 32))

            # Use the partitioner with strings
            partitioner_axis3_str = tf.variable_axis_size_partitioner(
                axis=3, max_shard_bytes=1024, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis3_str):
                v3str_list, v3str_part = get_partitioned_variable_list(
                    "v3str",
                    initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(
                        4, 8, 16, 32),
                    dtype=tf.string,
                    shape=(4, 8, 16, 32))

                # Now the estimated size_per_slice = 4*8*16*bytes_per_string_element
                # which is equal to 4096.  Setting a max_shard_bytes of 1024
                # and we should get a split of 4.
                self.assertEqual(len(v3str_list), 4)
                self.assertAllEqual(v3str_part, (1, 1, 1, 4))
Example #31
    def testVariableAxisSizePartitioner(self):
        with self.test_session():
            # Create a partitioned variable of shape (4, 8, 16, 32) type float32
            # Bytes per slice along the given axes:

            # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0
            # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1
            # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2
            # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3

            # Now partition it in different ways...

            # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes
            self._testVariableAxisSizePartitioner("v0",
                                                  axis=0,
                                                  max_shard_bytes=131072,
                                                  expected_axis_shards=1,
                                                  expected_partitions=(1, 1, 1,
                                                                       1))

            # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes
            self._testVariableAxisSizePartitioner("v1",
                                                  axis=1,
                                                  max_shard_bytes=65536,
                                                  expected_axis_shards=1,
                                                  expected_partitions=(1, 1, 1,
                                                                       1))

            # Slice into 2 parts:
            # bytes_per_slice = 4096
            # slices_per_shard = 32768 / 4096 = 8
            # axis_shards = 16 / 8 = 2
            self._testVariableAxisSizePartitioner("v2",
                                                  axis=2,
                                                  max_shard_bytes=32768,
                                                  expected_axis_shards=2,
                                                  expected_partitions=(1, 1, 2,
                                                                       1))

            # This partitioner makes sure we maximize the number of shards along
            # axis 3. Slice it into 32 parts:
            # bytes_per_slice = 2048
            # slices_per_shard = 2048 / 2048 = 1
            # axis_shards = 32 / 1 = 32
            self._testVariableAxisSizePartitioner("v3a",
                                                  axis=3,
                                                  max_shard_bytes=2048,
                                                  expected_axis_shards=32,
                                                  expected_partitions=(1, 1, 1,
                                                                       32))

            # This partitioner makes sure we do not go past the bound of allowable
            # number of shards along axis 3.
            # Slice into 32 parts:
            # bytes_per_slice = 2048
            # slices_per_shard = max(1, 1024 / 2048) = 1
            # axis_shards = 32 / 1 = 32
            # Slice into max of 32 parts because: max_shard_bytes < bytes_per_slice
            self._testVariableAxisSizePartitioner("v3b",
                                                  axis=3,
                                                  max_shard_bytes=1024,
                                                  expected_axis_shards=32,
                                                  expected_partitions=(1, 1, 1,
                                                                       32))

            # Specify max_shards so that it won't affect sharding.
            self._testVariableAxisSizePartitioner("v3c",
                                                  axis=3,
                                                  max_shard_bytes=1024,
                                                  expected_axis_shards=32,
                                                  expected_partitions=(1, 1, 1,
                                                                       32),
                                                  max_shards=33)

            # Specify max_shards so that it will affect sharding.
            self._testVariableAxisSizePartitioner("v3d",
                                                  axis=3,
                                                  max_shard_bytes=1024,
                                                  expected_axis_shards=2,
                                                  expected_partitions=(1, 1, 1,
                                                                       2),
                                                  max_shards=2)

            # Use the partitioner with strings
            partitioner_axis3_str = tf.variable_axis_size_partitioner(
                axis=3, max_shard_bytes=32768, bytes_per_string_element=8)

            with tf.variable_scope("root", partitioner=partitioner_axis3_str):
                v3str = tf.get_variable("v3str",
                                        initializer=np.array([""] * 4 * 8 *
                                                             16 * 32).reshape(
                                                                 4, 8, 16, 32),
                                        dtype=tf.string,
                                        shape=(4, 8, 16, 32))
                v3str_list = v3str._get_variable_list()
                v3str_part = v3str._get_partitions()

                # Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element
                # which is equal to 4096.  Setting a max_shard_bytes of 32768
                # and we should get a split of 4.
                # Slice into 4 parts:
                # bytes_per_slice = 4096
                # slices_per_shard = 32768 / 4096 = 8
                # axis_shards = 32 / 8 = 4
                self.assertEqual(len(v3str_list), 4)
                self.assertAllEqual(v3str_part, (1, 1, 1, 4))
Example #32
class _DistributedConvolutionalNeuralFields:
    """Implements Liu et al. (2015): Learning Depth from Single Monocular Images
    Using Deep Convolutional Neural Fields.  DOI: 10.1109/TPAMI.2015.2505283
    """

    def __init__(self):
        self.patch_size = (100, 100)  # Liu et al.: 224x224
        self.sp_size = (40, 40)  # Liu et al.: super pixels, not patches
        self.gamma = 1
        self.epsilon = 1e-7  # For numerical stability

    def pair_indices(self, images):
        max_rows, max_cols = self.num_superpixels(images)
        left = []
        right = []
        for row in range(1, max_rows - 1):
            for col in range(2 - (row & 1), max_cols - 1, 2):
                pixel = row * max_cols + col
                for addend in [-max_cols, max_cols, -1, 1]:
                    left.append(pixel)
                    right.append(pixel + addend)
        return left, right

    def num_superpixels(self, images):
        rows = math.ceil(int(images.shape[1]) / self.sp_size[0])
        cols = math.ceil(int(images.shape[2]) / self.sp_size[1])
        return rows, cols

    @tfhelper.variable_scope('extract_superpixel')
    def superpixels(self, images):
        # TODO: over-segmentation instead of simple extraction
        superpixels = tf.extract_image_patches(images=images,
                                               ksizes=[1, *self.sp_size, 1],
                                               strides=[1, *self.sp_size, 1],
                                               rates=[1, 1, 1, 1],
                                               padding='SAME')
        return tf.reshape(superpixels, (int(images.shape[0]),
                                        -1,
                                        self.sp_size[0] * self.sp_size[1],
                                        int(images.shape[-1])))

    @tfhelper.variable_scope('extract_patch')
    def patches(self, images):
        # TODO: use superpixels for patch calculation
        patches = tf.extract_image_patches(images=images,
                                           ksizes=[1, *self.patch_size, 1],
                                           strides=[1, *self.sp_size, 1],
                                           rates=[1, 1, 1, 1],
                                           padding='SAME')
        return tf.reshape(patches, (int(images.shape[0]), -1, *self.patch_size,
                                    int(images.shape[-1])))

    @tfhelper.make_template('unary_layers')
    def unary_part_patch(self, image_patch):
        with tf.device('/job:worker/task:1'):
            temp = tf.layers.conv2d(image_patch, 64, 11, activation=tf.nn.relu)
            temp = tf.layers.max_pooling2d(temp, 2, 2)
        with tf.device('/job:worker/task:2'):
            temp = tf.layers.conv2d(temp, 256, 5, activation=tf.nn.relu)
            temp = tf.layers.max_pooling2d(temp, 2, 2)
            temp = tf.layers.conv2d(temp, 256, 3, activation=tf.nn.relu)
        with tf.device('/job:worker/task:3'):
            temp = tf.layers.conv2d(temp, 256, 3, activation=tf.nn.relu)
            temp = tf.layers.conv2d(temp, 256, 3, activation=tf.nn.relu)
            temp = tf.layers.max_pooling2d(temp, 2, 2)

        # Flatten the result into the 1-D shape expected by the dense layers.
        temp = tf.reshape(temp, [int(image_patch.shape[0]), -1])

        # temp = tf.layers.dense(temp, 4096, activation=tf.nn.relu)
        with tf.device('/job:worker/task:2'):
            temp = tf.layers.dense(temp, 128, activation=tf.nn.relu)
            temp = tf.layers.dense(temp, 16, activation=tf.nn.sigmoid)
            temp = tf.layers.dense(temp, 1, activation=None)
        return temp

    @tfhelper.variable_scope(
        'unary',
        partitioner=tf.variable_axis_size_partitioner((64 << 20) - 1))
    def unary_part(self, images):
        patches = self.patches(images)
        return tf.map_fn(self.unary_part_patch, patches)

    @tfhelper.make_template('pairwise_layers')
    def pairwise_dense(self, similarities):
        return tf.layers.dense(similarities, 1, activation=None)

    @tfhelper.variable_scope('histogram')
    def color_histogram(self, superpixel):
        values = tf.reduce_sum(superpixel * (16777216., 65536., 256.), axis=-1)
        histogram = tf.histogram_fixed_width(values, (0, 16777216.),
                                             256, tf.float32)
        return histogram

    @tfhelper.variable_scope('similarity')
    def similarity(self, features, pairs):
        left = tf.map_fn(lambda batch: tf.gather(batch, pairs[0]), features)
        right = tf.map_fn(lambda batch: tf.gather(batch, pairs[1]), features)
        return tf.exp(-self.gamma * tf.norm(left - right, axis=2))

    @tfhelper.variable_scope(
        'pairwise',
        partitioner=tf.variable_axis_size_partitioner((64 << 20) - 1))
    def pairwise_part(self, images):
        superpixels = self.superpixels(images)
        pairs = self.pair_indices(images)

        histograms = tf.map_fn(
            lambda batch: tf.map_fn(self.color_histogram, batch), superpixels)

        # color difference similarity
        cdiff_sim = self.similarity(tf.reduce_mean(superpixels, axis=-1),
                                    pairs)
        # color histogram similarity
        histdiff_sim = self.similarity(histograms, pairs)
        # TODO: texture disparity

        # gather similarities
        similarities = tf.stack([cdiff_sim, histdiff_sim], axis=-1)

        return tf.map_fn(self.pairwise_dense, similarities)

    @tfhelper.variable_scope('loss')
    def loss_part(self, target, z, r):
        superpixels = self.superpixels(target)
        y = tf.reduce_mean(superpixels, axis=2)
        pairs = self.pair_indices(target)

        # See Liu et al. (2015) p. 5, eq. (9) - (11)
        def get_A(batch_item, R):
            I = tf.eye(int(R.shape[0]))
            R = tf.scatter_nd_update(R, list(zip(*pairs)),
                                     tf.squeeze(batch_item))
            R = tf.scatter_nd_update(R, list(zip(*pairs[::-1])),
                                     tf.squeeze(batch_item))
            D = tf.diag(tf.reduce_sum(R, axis=1))
            return I + D - R

        # Get reused helpers for loss calculation
        with tf.name_scope('calc_A'):
            # Define R outside of get_A to avoid UPDATE_OP being placed inside
            # loop (see https://github.com/tensorflow/tensorflow/issues/6087)
            R = tf.Variable(tf.zeros((int(r.shape[1]), int(r.shape[1]))),
                            trainable=False)
            A = tf.map_fn(lambda y: get_A(y, R), r,
                          parallel_iterations=1)

        zT = tf.transpose(z, [0, 2, 1])
        yT = tf.transpose(y, [0, 2, 1])

        # energy = E(y, x)
        with tf.name_scope('energy'):
            energy = tf.squeeze(yT @ A @ y - 2 * zT @ y + zT @ z)

        # Integral Z(x) = exp( -E(y, x) ) dy
        with tf.name_scope('integral'):
            fac = math.pi ** (int(r.shape[1]) / 2)
            fac /= (tf.matrix_determinant(A) ** .5) + self.epsilon
            inverseA = tf.matrix_inverse(A) + self.epsilon
            exp = tf.squeeze(tf.exp(zT @ inverseA @ z - zT @ z))
            Z = fac * exp + self.epsilon

        # Neg log-likelihood
        with tf.name_scope('nll'):
            loss = -tf.log(tf.exp(-energy) / Z + self.epsilon)

        # Mean over batch
        loss = tf.reduce_mean(loss, name='mean_loss')
        tf.losses.add_loss(loss)

        return loss

    def __call__(self, images, depths, train=True):
        images = tf.image.resize_images(images, [240, 320])
        depths = tf.image.resize_images(depths, [240, 320])

        z = self.unary_part(images)
        r = self.pairwise_part(images)
        loss = self.loss_part(depths, z, r)

        with tf.name_scope('output'):
            rows, cols = self.num_superpixels(images)
            output = tf.image.resize_images(tf.reshape(z, [-1, rows, cols, 1]),
                                            (int(depths.shape[1]),
                                            int(depths.shape[2])))

        with tf.name_scope('summaries'):
            tf.summary.image('Output', output, max_outputs=1)
            tf.summary.image('Input', images, max_outputs=1)
            tf.summary.image('Target', depths, max_outputs=1)

        optimizer = tf.train.GradientDescentOptimizer(0.1)
        return optimizer.minimize(loss,
                                  tf.train.get_or_create_global_step())
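A minimal sketch of what the (64 << 20) - 1 budget used in the variable scopes above means, assuming the same TF 1.x API: shards are capped just below 64 MiB, so a hypothetical 64 MiB float32 matrix gets split in two along its first axis.

import tensorflow as tf

partitioner = tf.variable_axis_size_partitioner((64 << 20) - 1)
# Hypothetical 4096x4096 float32 matrix = exactly 64 MiB.
print(partitioner(tf.TensorShape([4096, 4096]), tf.float32))  # -> [2, 1]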
Example #33
  def testVariableAxisSizePartitioner(self):
    with self.test_session():
      # Create a partitioned variable of shape (4, 8, 16, 32) type float32
      # Bytes per slice along the given axes:

      # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0
      # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1
      # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2
      # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3

      # Now partition it in different ways...

      partitioner_axis0 = tf.variable_axis_size_partitioner(
          axis=0, max_shard_bytes=32768, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis0):
        v0_list, v0_part = get_partitioned_variable_list(
            "v0", dtype=tf.float32, shape=(4, 8, 16, 32))
        # No need to slice: size_per_slice = 16384 < 32768 = max_shard_bytes
        self.assertEqual(len(v0_list), 1)
        self.assertAllEqual(v0_part, (1, 1, 1, 1))

      partitioner_axis1 = tf.variable_axis_size_partitioner(
          axis=1, max_shard_bytes=8192, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis1):
        v1_list, v1_part = get_partitioned_variable_list(
            "v1", dtype=tf.float32, shape=(4, 8, 16, 32))
        # Slice exactly once: size_per_slice = 8192 == 8192 = max_shard_bytes
        self.assertEqual(len(v1_list), 1)
        self.assertAllEqual(v1_part, (1, 1, 1, 1))

      partitioner_axis2 = tf.variable_axis_size_partitioner(
          axis=2, max_shard_bytes=2048, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis2):
        v2_list, v2_part = get_partitioned_variable_list(
            "v2", dtype=tf.float32, shape=(4, 8, 16, 32))
        # Slice into 2 parts:
        #   size_per_slice = 4096 == 2 * 2048 = max_shard_bytes
        self.assertEqual(len(v2_list), 2)
        self.assertAllEqual(v2_part, (1, 1, 2, 1))

      # This partitioner makes sure we maximize the number of shards
      # along axis 3
      partitioner_axis3_a = tf.variable_axis_size_partitioner(
          axis=3, max_shard_bytes=64, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis3_a):
        v3a_list, v3a_part = get_partitioned_variable_list(
            "v3a", dtype=tf.float32, shape=(4, 8, 16, 32))
        # Slice into 32 parts: 2048 == 64 * 32
        self.assertEqual(len(v3a_list), 32)
        self.assertAllEqual(v3a_part, (1, 1, 1, 32))

      # This partitioner makes sure we go past the bound of allowable
      # number of shards along axis 3
      partitioner_axis3_b = tf.variable_axis_size_partitioner(
          axis=3, max_shard_bytes=32, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis3_b):
        v3b_list, v3b_part = get_partitioned_variable_list(
            "v3b", dtype=tf.float32, shape=(4, 8, 16, 32))
        # Slice into the maximum of 32 parts because: 2048 > 32 * 32
        self.assertEqual(len(v3b_list), 32)
        self.assertAllEqual(v3b_part, (1, 1, 1, 32))

      # Use the partitioner with strings
      partitioner_axis3_str = tf.variable_axis_size_partitioner(
          axis=3, max_shard_bytes=1024, bytes_per_string_element=8)

      with tf.variable_scope("root", partitioner=partitioner_axis3_str):
        v3str_list, v3str_part = get_partitioned_variable_list(
            "v3str",
            initializer=np.array([""] * 4*8*16*32).reshape(4, 8, 16, 32),
            dtype=tf.string, shape=(4, 8, 16, 32))

        # Now the estimated size_per_slice = 4*8*16*bytes_per_string_element
        # which is equal to 4096.  Setting a max_shard_bytes of 1024
        # and we should get a split of 4.
        self.assertEqual(len(v3str_list), 4)
        self.assertAllEqual(v3str_part, (1, 1, 1, 4))