Example #1
0
    def testVariableName(self):
        """Test that a named bfloat16 scope nests under the enclosing name scope.

        Ops created under ``bfloat16_scope('bf16')`` inside the ``scope1``
        name scope must carry the ``scope1/bf16/`` prefix, while ops created
        under an unnamed ``bfloat16_scope()`` keep the plain ``scope1/`` prefix.
        """
        g = ops.Graph()
        with g.as_default():
            a = variables.Variable(2.2, name='var_a')
            b = variables.Variable(3.3, name='var_b')
            # Distinct name for each variable (was a copy-paste of 'var_b').
            d = variables.Variable(4.4, name='var_d')
            with g.name_scope('scope1'):
                # Named bfloat16 scope: ops should land under scope1/bf16/.
                with bfloat16.bfloat16_scope('bf16'):
                    a = math_ops.cast(a, dtypes.bfloat16)
                    b = math_ops.cast(b, dtypes.bfloat16)
                    c = math_ops.add(a, b, name='addition')
                # Unnamed bfloat16 scope: ops should stay directly under scope1/.
                with bfloat16.bfloat16_scope():
                    d = math_ops.cast(d, dtypes.bfloat16)
                    math_ops.add(c, d, name='addition')

        ops_name = [str(op.name) for op in g.get_operations()]

        self.assertIn('scope1/bf16/addition', ops_name)
        self.assertIn('scope1/bf16/Cast', ops_name)
        self.assertIn('scope1/addition', ops_name)
        self.assertIn('scope1/Cast', ops_name)
Example #2
0
 def tpu_subgraph_predict():
   """Run `tpu.rewrite` on `tpu_subgraph_predict_fn`, optionally in bfloat16.

   When the `use_bfloat16` flag (defined outside this function) is truthy,
   the rewrite is built inside `bfloat16_scope()`; otherwise it is built
   without that scope.
   """
   rewrite_inputs = [preprocessed_inputs, true_image_shapes]
   if not use_bfloat16:
     return tpu.rewrite(tpu_subgraph_predict_fn, rewrite_inputs)
   with bfloat16_scope():
     return tpu.rewrite(tpu_subgraph_predict_fn, rewrite_inputs)
 def testRequestedDType(self):
   """Check that the bfloat16 scope honors an explicitly requested dtype."""
   with bfloat16.bfloat16_scope() as scope:
     # No dtype requested: the variable comes back as float32.
     default_var = variable_scope.get_variable("v1", [])
     self.assertEqual(default_var.dtype.base_dtype, dtypes.float32)
     # bfloat16 requested: the returned value is bfloat16.
     bf16_var = variable_scope.get_variable(
         "v2", [], dtype=dtypes.bfloat16)
     self.assertEqual(bf16_var.dtype.base_dtype, dtypes.bfloat16)
     # Yet both variables registered in the scope are stored as float32.
     stored_dtypes = [var.dtype.base_dtype for var in scope.global_variables()]
     self.assertEqual([dtypes.float32, dtypes.float32], stored_dtypes)
Example #4
0
 def testRequestedDType(self):
   """Verify the dtype requested from get_variable inside bfloat16_scope."""
   with bfloat16.bfloat16_scope() as bf16_scope:
     float_var = variable_scope.get_variable("v1", [])
     self.assertEqual(float_var.dtype.base_dtype, dtypes.float32)

     half_var = variable_scope.get_variable("v2", [],
                                            dtype=dtypes.bfloat16)
     self.assertEqual(half_var.dtype.base_dtype, dtypes.bfloat16)

     # The scope's global variables are float32 regardless of the request.
     scope_dtypes = [
         gv.dtype.base_dtype for gv in bf16_scope.global_variables()
     ]
     self.assertEqual([dtypes.float32, dtypes.float32], scope_dtypes)
Example #5
0
File: ssd.py Project: Asharib90/OCR
    def predict_tpu_subgraph(preprocessed_inputs, true_image_shapes):
        """Wrap the CPU version of `predict()` for the TPU subgraph.

        Builds the same graph as the original `predict()`, transposes the
        result tensors so that their layout is memory efficient on TPU, and
        returns them as a list of tensors.

        Args:
          preprocessed_inputs: A 4D tensor of shape
            (batch, channels, height, width).
          true_image_shapes: True image shapes tensor.

        Returns:
          A Python list of tensors:
            box_encodings: 3D tensor of shape (code_size, batch_size, num_anchors)
            class_predictions_with_background: 3D tensor,
                shape (num_classes + 1, batch_size, num_anchors)
            anchors: 2D tensor of shape (4, num_anchors)
        """
        # (b, c, h, w) -> (b, h, w, c) before handing off to the model.
        nhwc_inputs = tf.transpose(preprocessed_inputs, perm=[0, 2, 3, 1])

        if use_bfloat16:
            with bfloat16_scope():
                prediction_dict = detection_model.predict(
                    nhwc_inputs, true_image_shapes)
        else:
            prediction_dict = detection_model.predict(
                nhwc_inputs, true_image_shapes)

        # (batch, anchors, depth) -> (depth, batch, anchors)
        box_encodings = tf.transpose(
            prediction_dict[BOX_ENCODINGS], perm=[2, 0, 1])
        class_predictions = tf.transpose(
            prediction_dict[CLASS_PREDICTIONS_WITH_BACKGROUND],
            perm=[2, 0, 1])
        # Swap the two axes of the 2D anchors tensor -> (4, num_anchors).
        anchors = tf.transpose(prediction_dict[ANCHORS], perm=[1, 0])
        return [box_encodings, class_predictions, anchors]
 def testScopeName(self):
   """Check that bfloat16_scope() opened without a name reports name ''."""
   with bfloat16.bfloat16_scope() as unnamed_scope:
     self.assertEqual(unnamed_scope.name, "")
Example #7
0
 def testCustomScopeName(self):
   """Test that a custom name passed to bfloat16_scope is reported back."""
   name = 'bfloat16'
   # Pass the same `name` that is asserted on, so the two cannot drift apart.
   with bfloat16.bfloat16_scope(name) as bf:
     self.assertEqual(bf.name, name)
Example #8
0
 def testScopeName(self):
   """Verify that the default (unnamed) bfloat16 scope has an empty name."""
   with bfloat16.bfloat16_scope() as default_scope:
     scope_name = default_scope.name
     self.assertEqual(scope_name, "")