Exemplo n.º 1
0
 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
   """Builds the dense `Tensor` for this column from `inputs`.

   Args:
     inputs: Object whose `get(self)` yields this column's input tensor.
     weight_collections: Unused.
     trainable: Combined with `self.trainable` via `and` to decide whether
       the instantiated module is trainable.

   Returns:
     Whatever `self._get_dense_tensor_for_input_tensor` returns.
   """
   del weight_collections  # Unused.
   raw_input = inputs.get(self)
   # The module is trainable only when both flags are truthy.
   is_trainable = self.trainable and trainable
   hub_module = module.Module(self.module_spec, trainable=is_trainable)
   return self._get_dense_tensor_for_input_tensor(raw_input, hub_module)
Exemplo n.º 2
0
 def testGetExpectedImageSizeFromImageModuleInfo(self):
     """Checks get_expected_image_size() on a spec and on its Module.

     Both the spec created from `image_module_fn_with_info` and a Module
     instantiated from that spec must report the size [2, 4].
     """
     with tf.Graph().as_default():
         module_spec = native_module.create_module_spec(
             image_module_fn_with_info)
         size_from_spec = image_util.get_expected_image_size(module_spec)
         self.assertAllEqual(size_from_spec, [2, 4])
         hub_module = module.Module(module_spec)
         size_from_module = image_util.get_expected_image_size(hub_module)
         self.assertAllEqual(size_from_module, [2, 4])
Exemplo n.º 3
0
 def testModuleInNestedScope(self):
   """Applies a Module built under variable scope "foo" and checks the result."""
   with tf.Graph().as_default():
     with tf.compat.v1.variable_scope("foo"):
       scoped_module = module.Module(_ModuleSpec())
       output_tensor = scoped_module([1, 2])
     with tf.compat.v1.Session() as sess:
       output_value = sess.run(output_tensor)
       self.assertAllEqual(output_value, [2, 4])
Exemplo n.º 4
0
 def testModuleInterfaceGettersExplicitSignatureAndTags(self):
     """Verifies that tags passed to Module(...) are honored by its getters.

     With tags={"special"}, the module must expose the "default" and "extra"
     signatures, and "extra" must have inputs x, y and outputs z, default.
     """
     tagged_module = module.Module(_ModuleSpec(), tags={"special"})
     signature_names = tagged_module.get_signature_names()
     self.assertItemsEqual(signature_names, ["default", "extra"])
     extra_inputs = tagged_module.get_input_info_dict(signature="extra")
     self.assertItemsEqual(extra_inputs.keys(), ["x", "y"])
     extra_outputs = tagged_module.get_output_info_dict(signature="extra")
     self.assertItemsEqual(extra_outputs.keys(), ["z", "default"])
Exemplo n.º 5
0
 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns the dense tensor computed from this column's images.

     Args:
       inputs: Object whose `get(self)` yields this column's images.
       weight_collections: Unused.
       trainable: Unused.

     Returns:
       The result of `self._get_dense_tensor_for_images`.
     """
     del weight_collections  # Unused.
     del trainable  # Unused.
     column_images = inputs.get(self)
     hub_module = module.Module(self.module_spec)
     return self._get_dense_tensor_for_images(column_images, hub_module)
Exemplo n.º 6
0
 def create_state(self, state_manager):
     """Instantiates the module and registers it as a resource.

     Args:
       state_manager: Receives the Module via `add_resource`, keyed by this
         column and `_MODULE_RESOURCE_STRING`.
     """
     # state_manager._trainable is private, but reading it is the established
     # way to obtain the "trainable" flag that self._get_dense_tensor used to
     # receive as an argument.
     is_trainable = self.trainable and state_manager._trainable  # pylint: disable=protected-access
     hub_module = module.Module(self.module_spec, trainable=is_trainable)
     state_manager.add_resource(self, _MODULE_RESOURCE_STRING, hub_module)
Exemplo n.º 7
0
 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns a `Tensor` to represent this feature in the input_layer().

     The module is always instantiated with trainable=False; the
     `weight_collections` and `trainable` arguments are ignored.

     Args:
       inputs: Object whose `get(self)` yields this column's images.
       weight_collections: Unused.
       trainable: Unused.
     """
     del weight_collections  # Unused.
     del trainable  # Unused.
     frozen_module = module.Module(self.module_spec, trainable=False)
     module_inputs = {"images": inputs.get(self)}
     return frozen_module(module_inputs)
Exemplo n.º 8
0
 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns the module's output for this column's flattened input.

     Args:
       inputs: Object whose `get(self)` yields this column's input tensor.
       weight_collections: Unused.
       trainable: Combined with `self.trainable` via `and` to decide whether
         the instantiated module is trainable.

     Returns:
       The result of calling the module on the input reshaped to rank 1.
     """
     del weight_collections  # Unused.
     flat_text = tf.reshape(inputs.get(self), shape=[-1])
     is_trainable = self.trainable and trainable
     hub_module = module.Module(self.module_spec, trainable=is_trainable)
     return hub_module(flat_text)
Exemplo n.º 9
0
  def _get_dense_tensor_for_inputs(self, text_batch, trainable):
    """Embeds the values of a sparse text batch and combines them per row.

    Args:
      text_batch: Sparse input; its `values`, `indices` and `dense_shape`
        attributes are read.
      trainable: Combined with `self.trainable` via `and` to decide whether
        the instantiated module is trainable.

    Returns:
      The result of `tf.nn.embedding_lookup_sparse` using `self.combiner`.
    """
    is_trainable = self.trainable and trainable
    hub_module = module.Module(self.module_spec, trainable=is_trainable)

    if self.default_value is not None:
      # Only element 0 of the fill_empty_rows result is kept.
      filled = tf.sparse.fill_empty_rows(text_batch, self.default_value)
      text_batch = filled[0]

    token_embeddings = hub_module(text_batch.values)
    # One consecutive id per sparse entry, so entry i selects row i of
    # token_embeddings.
    num_entries = tf.shape(text_batch.indices)[0]
    row_ids = tf.SparseTensor(
        indices=text_batch.indices,
        values=tf.range(num_entries, dtype=tf.int32),
        dense_shape=text_batch.dense_shape)

    return tf.nn.embedding_lookup_sparse(
        params=token_embeddings,
        sp_ids=row_ids,
        sp_weights=None,
        combiner=self.combiner)
Exemplo n.º 10
0
 def testModuleDictInput(self):
   """Calls the module with a dict input and checks the evaluated result."""
   with tf.Graph().as_default():
     hub_module = module.Module(_ModuleSpec())
     output_tensor = hub_module({"x": [1, 2]})
     with tf.compat.v1.Session() as sess:
       output_value = sess.run(output_tensor)
       self.assertAllEqual(output_value, [2, 4])
Exemplo n.º 11
0
 def create_state(self, state_manager):
     """Instantiates the module and registers it as a resource.

     Args:
       state_manager: Receives the Module via `add_resource`, keyed by this
         column and `_MODULE_RESOURCE_STRING`.
     """
     # No `trainable` argument is passed: the module is not trainable by
     # default.
     hub_module = module.Module(self.module_spec)
     state_manager.add_resource(
         self, _MODULE_RESOURCE_STRING, hub_module)
Exemplo n.º 12
0
 def testModuleInterfaceGettersDefaultSignatureAndTags(self):
   """Checks the getters of a Module created without explicit tags.

   Only the "default" signature must be listed, with input "x" and output
   "default".
   """
   with tf.Graph().as_default():
     hub_module = module.Module(_ModuleSpec())
     signature_names = hub_module.get_signature_names()
     self.assertItemsEqual(signature_names, ["default"])
     input_names = hub_module.get_input_info_dict().keys()
     self.assertItemsEqual(input_names, ["x"])
     output_names = hub_module.get_output_info_dict().keys()
     self.assertItemsEqual(output_names, ["default"])
Exemplo n.º 13
0
 def testModuleSingleInput(self):
     """Calls the module with one positional input and checks the result."""
     hub_module = module.Module(_ModuleSpec())
     output_tensor = hub_module([1, 2])
     with tf.Session() as sess:
         output_value = sess.run(output_tensor)
         self.assertAllEqual(output_value, [2, 4])
Exemplo n.º 14
0
 def testModuleDictInput(self):
     """Calls the module with a dict input and checks the evaluated result."""
     hub_module = module.Module(_ModuleSpec())
     output_tensor = hub_module({"x": [1, 2]})
     with tf.Session() as sess:
         output_value = sess.run(output_tensor)
         self.assertAllEqual(output_value, [2, 4])
Exemplo n.º 15
0
 def _get_dense_tensor_for_input_tensor(self, input_tensor, trainable):
   """Reshapes `input_tensor` to rank 1 and feeds it to the module.

   Args:
     input_tensor: Tensor to flatten with shape [-1] before the module call.
     trainable: Combined with `self.trainable` via `and` to decide whether
       the instantiated module is trainable.

   Returns:
     The result of calling the module on the flattened tensor.
   """
   flat_text = tf.reshape(input_tensor, shape=[-1])
   is_trainable = self.trainable and trainable
   hub_module = module.Module(self.module_spec, trainable=is_trainable)
   return hub_module(flat_text)
Exemplo n.º 16
0
 def testGetExpectedImageSizeFromShape(self):
     """Checks get_expected_image_size() on a spec and on its Module.

     Both the spec created from `image_module_fn` and a Module instantiated
     from that spec must report the size [2, 4].
     """
     module_spec = native_module.create_module_spec(image_module_fn)
     size_from_spec = image_util.get_expected_image_size(module_spec)
     self.assertAllEqual(size_from_spec, [2, 4])
     hub_module = module.Module(module_spec)
     size_from_module = image_util.get_expected_image_size(hub_module)
     self.assertAllEqual(size_from_module, [2, 4])
Exemplo n.º 17
0
 def testModuleDictOutput(self):
     """Checks that calling the module with as_dict=True returns a dict.

     The returned dict must have exactly one key, "default".
     """
     m = module.Module(_ModuleSpec())
     result = m([1, 2], as_dict=True)
     # assertIsInstance reports the offending object and the expected type
     # on failure; assertTrue(isinstance(...)) only says "False is not true".
     self.assertIsInstance(result, dict)
     self.assertAllEqual(list(result.keys()), ["default"])
Exemplo n.º 18
0
 def _get_dense_tensor_for_images(self, images):
   """Feeds `images` to a module instantiated with trainable=False.

   Args:
     images: Value passed to the module under the "images" input key.

   Returns:
     The result of calling the module.
   """
   frozen_module = module.Module(self.module_spec, trainable=False)
   module_inputs = {"images": images}
   return frozen_module(module_inputs)
Exemplo n.º 19
0
 def testModuleDictOutput(self):
   """Checks that as_dict=True yields a dict whose only key is "default"."""
   with tf.Graph().as_default():
     hub_module = module.Module(_ModuleSpec())
     outputs = hub_module([1, 2], as_dict=True)
     self.assertIsInstance(outputs, dict)
     output_names = list(outputs.keys())
     self.assertAllEqual(output_names, ["default"])
Exemplo n.º 20
0
 def testGetNumImageChannels(self):
     """Checks get_num_image_channels() on a spec and on its Module.

     Both the spec created from `image_module_fn` and a Module instantiated
     from that spec must report 3 channels.
     """
     with tf.Graph().as_default():
         module_spec = native_module.create_module_spec(image_module_fn)
         channels_from_spec = image_util.get_num_image_channels(module_spec)
         self.assertEqual(channels_from_spec, 3)
         hub_module = module.Module(module_spec)
         channels_from_module = image_util.get_num_image_channels(hub_module)
         self.assertEqual(channels_from_module, 3)
Exemplo n.º 21
0
 def _hub_module(self):
     """Returns this object's Module, creating and caching it on first use.

     The Module is built from `self.module_spec` with
     `trainable=self.trainable` and stored on the instance as `_module`.
     """
     if hasattr(self, '_module'):
         return self._module
     # First access: instantiate once and keep it for later calls.
     self._module = module.Module(
         self.module_spec, trainable=self.trainable)
     return self._module