def test_shape_utils(self):
    """Check the shape helpers ``merge_shapes``, ``are_shapes_compatible`` and ``are_shapes_equal``.

    The assertions below pin down that:
      * ``merge_shapes`` returns the other operand when one side is ``None``,
        fills unknown dims (``None``, ``-1``, ``"unk"``) from the other shape,
        and returns a list when the first operand is a list, even if the
        second is a tuple;
      * ``are_shapes_compatible`` accepts ``None`` vs ``[]`` and shapes whose
        unknown dims line up with known ones, but rejects rank mismatches and
        conflicting known dims;
      * ``are_shapes_equal`` treats ``None`` as equal only to ``None`` and
        compares a list and a tuple by content.
    """
    # merge_shapes: a None operand defers to the other one.
    self.assertIsNone(utils.merge_shapes(None, None))
    self.assertEqual(utils.merge_shapes([], None), [])
    self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3])
    # merge_shapes: unknown dims on either side are filled from the other.
    self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3])
    self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3])

    # are_shapes_compatible: unknown dims match anything; rank and known dims must agree.
    self.assertTrue(utils.are_shapes_compatible(None, []))
    self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
    self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
    self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))

    # are_shapes_equal: None is distinct from an empty (scalar) shape; container type is ignored.
    self.assertTrue(utils.are_shapes_equal(None, None))
    self.assertFalse(utils.are_shapes_equal(None, []))
    self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
def _get_output_shape_dtype(self, cond_context):
    """Derive the output shapes and dtypes of a cond from its two branches.

    The i-th output of the true branch is paired with the i-th output of the
    false branch. Each pair must have compatible shapes and identical dtypes.

    Args:
        cond_context: object exposing ``true_branch_context.output`` and
            ``false_branch_context.output``, two parallel sequences of outputs
            that are looked up with ``self.g.get_shape`` / ``self.g.get_dtype``.

    Returns:
        A tuple ``(output_shapes, output_dtypes)`` of lists with one entry per
        output pair. Each shape is ``utils.create_vague_shape_like`` applied to
        the merged branch shapes. Each dtype is the dtype shared by both branches.

    Raises:
        RuntimeError: if the branches expose a different number of outputs, or
            if a pair has incompatible shapes or differing dtypes.
    """
    true_outputs = cond_context.true_branch_context.output
    false_outputs = cond_context.false_branch_context.output
    # Outputs are paired positionally. Check the counts explicitly instead of
    # failing with a bare IndexError or silently ignoring surplus false outputs.
    if len(true_outputs) != len(false_outputs):
        raise RuntimeError(
            "the number of outputs of the true branch ({}) and the false branch ({}) mismatch".format(
                len(true_outputs), len(false_outputs)
            )
        )

    output_shapes = []
    output_dtypes = []
    for true_output, false_output in zip(true_outputs, false_outputs):
        true_shape = self.g.get_shape(true_output)
        true_dtype = self.g.get_dtype(true_output)
        false_shape = self.g.get_shape(false_output)
        false_dtype = self.g.get_dtype(false_output)
        if not utils.are_shapes_compatible(true_shape, false_shape):
            raise RuntimeError(
                "the shape of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_shape, false_shape
                )
            )
        if true_dtype != false_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_dtype, false_dtype
                )
            )
        # In tf, the shapes of different branches can differ; for example, the
        # output shape of branch A can be [-1] while branch B's is [1].
        # In that case the output shape should be set to [-1].
        output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
        output_dtypes.append(true_dtype)
    return output_shapes, output_dtypes
def _get_output_shape_dtype(self, cond_context):
    """Derive the output shapes and dtypes of a cond from its two branches.

    The i-th output of the true branch is paired with the i-th output of the
    false branch. Each pair must have compatible shapes and identical dtypes.

    Args:
        cond_context: object exposing ``true_branch_context.output`` and
            ``false_branch_context.output``, two parallel sequences of outputs
            that are looked up with ``self.g.get_shape`` / ``self.g.get_dtype``.

    Returns:
        A tuple ``(output_shapes, output_dtypes)`` of lists with one entry per
        output pair. Each shape is ``utils.merge_shapes`` of the two branch
        shapes. Each dtype is the dtype shared by both branches.

    Raises:
        RuntimeError: if the branches expose a different number of outputs, or
            if a pair has incompatible shapes or differing dtypes.
    """
    true_outputs = cond_context.true_branch_context.output
    false_outputs = cond_context.false_branch_context.output
    # Outputs are paired positionally. Check the counts explicitly instead of
    # failing with a bare IndexError or silently ignoring surplus false outputs.
    if len(true_outputs) != len(false_outputs):
        raise RuntimeError(
            "the number of outputs of the true branch ({}) and the false branch ({}) mismatch".format(
                len(true_outputs), len(false_outputs)))

    output_shapes = []
    output_dtypes = []
    for true_output, false_output in zip(true_outputs, false_outputs):
        true_shape = self.g.get_shape(true_output)
        true_dtype = self.g.get_dtype(true_output)
        false_shape = self.g.get_shape(false_output)
        false_dtype = self.g.get_dtype(false_output)
        if not utils.are_shapes_compatible(true_shape, false_shape):
            raise RuntimeError(
                "the shape of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_shape, false_shape))
        if true_dtype != false_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_dtype, false_dtype))
        # NOTE(review): merge_shapes fills unknown dims from the other shape, so
        # branch shapes [-1] and [1] merge to [1], which is more specific than
        # the first branch guarantees. Confirm downstream consumers accept this.
        output_shapes.append(utils.merge_shapes(true_shape, false_shape))
        output_dtypes.append(true_dtype)
    return output_shapes, output_dtypes