예제 #1
0
 def test_that_adding_gradient_scope_does_no_fancy_renaming(self):
     """Gradient scoping only prefixes names with 'GRADIENTS/'.

     Adding the prefix cannot create collisions, so no de-duplicating
     suffix is expected. Every input, the operator name, the ``shapes``
     keys and the blob name tracker keys simply gain the ``GRADIENTS/``
     prefix. The tracker maps each prefixed name back to the original one.
     """
     original_inputs = ['foo_grad', 'foo_grad_1']
     op = caffe2_pb2.OperatorDef()
     op.name = 'foo_grad'
     op.input.extend(original_inputs)
     shapes = {'foo_grad': [1]}
     tracker = tb._get_blob_names([op])
     tb._add_gradient_scope(shapes, tracker, [op])

     for position, old_name in enumerate(original_inputs):
         self.assertEqual(op.input[position], 'GRADIENTS/' + old_name)
     self.assertEqual(op.name, 'GRADIENTS/foo_grad')
     # `shapes` is mutated in place; only the single prefixed key may remain.
     self.assertEqual(shapes, {'GRADIENTS/foo_grad': [1]})
     self.assertEqual(len(tracker), len(original_inputs))
     for old_name in original_inputs:
         self.assertEqual(tracker['GRADIENTS/' + old_name], old_name)
예제 #2
0
 def test_that_replacing_colons_gives_non_colliding_names(self):
     """Replacing ':' with '$' must not merge two distinct blobs.

     The op reads ``foo:0`` and ``foo$0``. After ``tb._replace_colons``,
     the first input becomes ``foo$0``, and the pre-existing ``foo$0`` is
     pushed to ``foo$0_1``. The ``shapes`` keys follow the rename, and the
     blob name tracker maps each new name back to its original spelling.
     """
     op = caffe2_pb2.OperatorDef()
     op.name = 'foo:0'
     op.input.extend(['foo:0', 'foo$0'])
     shapes = {'foo:0': [1]}
     tracker = tb._get_blob_names([op])
     tb._replace_colons(shapes, tracker, [op], '$')

     # (name after replacement, name before replacement), in input order.
     renamed = [('foo$0', 'foo:0'), ('foo$0_1', 'foo$0')]
     for index, (new_name, _) in enumerate(renamed):
         self.assertEqual(op.input[index], new_name)
     # Collision but blobs and op names are handled later by
     # _fill_missing_operator_names.
     self.assertEqual(op.name, 'foo$0')
     self.assertEqual(shapes, {'foo$0': [1]})
     self.assertEqual(len(tracker), len(renamed))
     for new_name, old_name in renamed:
         self.assertEqual(tracker[new_name], old_name)
예제 #3
0
 def test_that_auto_ssa_gives_non_colliding_names(self):
     """SSA conversion versions re-written blobs without name clashes.

     ``foo`` is written by the first op. The second op reads ``foo``, then
     writes ``foo`` again and also writes a pre-existing blob named
     ``foo_1``. After ``tb._convert_to_ssa``, the second write of ``foo``
     becomes ``foo_1``, so the original ``foo_1`` is expected at
     ``foo_1_1``. The new ``foo_1`` carries the shape of ``foo``, and the
     blob name tracker maps every SSA name back to its source blob.
     """
     producer = caffe2_pb2.OperatorDef()
     producer.output.extend(['foo'])
     rewriter = caffe2_pb2.OperatorDef()
     rewriter.input.extend(['foo'])
     rewriter.output.extend(['foo', 'foo_1'])
     ops = [producer, rewriter]
     shapes = {'foo': [1], 'foo_1': [2]}
     tracker = tb._get_blob_names(ops)
     tb._convert_to_ssa(shapes, tracker, ops)

     self.assertEqual(producer.output[0], 'foo')
     self.assertEqual(rewriter.input[0], 'foo')
     self.assertEqual(rewriter.output[0], 'foo_1')
     # Unfortunate name but we do not parse original `_` for now.
     self.assertEqual(rewriter.output[1], 'foo_1_1')
     self.assertEqual(shapes, {'foo': [1], 'foo_1': [1], 'foo_1_1': [2]})
     # (SSA name, blob it was derived from)
     lineage = [('foo', 'foo'), ('foo_1', 'foo'), ('foo_1_1', 'foo_1')]
     self.assertEqual(len(tracker), len(lineage))
     for ssa_name, source in lineage:
         self.assertEqual(tracker[ssa_name], source)
예제 #4
0
 def test_renaming_tensorflow_style(self):
     """Caffe2 blob suffixes are rewritten into TensorFlow-style scopes.

     ``tb._rename_tensorflow_style`` is expected to turn ``_w``, ``_bn``,
     ``_b``, ``_s``, ``_sum`` and ``_branch`` into ``/weight``,
     ``/batchnorm``, ``/bias``, ``/scale``, ``/sum`` and ``/branch``. The
     new names must appear consistently in the blob name tracker, the
     ``shapes`` keys and the operators' inputs and outputs.

     NOTE: '_w', '_bn', etc without the postfix '_' are only renamed when
     they are at the very end of the name. ``_branch`` is also exercised
     with two inputs and two outputs, including ``test_branch4``, where no
     underscore follows the suffix.
     """
     # (caffe2 suffix, expected scope name) for the five operators that each
     # read 'foo_<suffix>' and write 'foo_<suffix>_2'.
     suffix_table = [
         ('w', 'weight'),
         ('bn', 'batchnorm'),
         ('b', 'bias'),
         ('s', 'scale'),
         ('sum', 'sum'),
     ]
     ops = []
     for suffix, _ in suffix_table:
         single_op = caffe2_pb2.OperatorDef()
         single_op.input.extend(['foo_' + suffix])
         single_op.output.extend(['foo_' + suffix + '_2'])
         ops.append(single_op)
     # The '_branch' operator has multiple inputs and outputs.
     branch_op = caffe2_pb2.OperatorDef()
     branch_op.input.extend(['foo_branch', 'test_branch_2'])
     branch_op.output.extend(['foo_branch_3', 'test_branch4'])
     ops.append(branch_op)

     shapes = {
         'foo_w': [1], 'foo_w_2': [2], 'foo_bn': [3], 'foo_bn_2': [4],
         'foo_b': [5], 'foo_b_2': [6], 'foo_s': [7], 'foo_s_2': [8],
         'foo_sum': [9], 'foo_sum_2': [10], 'foo_branch': [11],
         'test_branch_2': [12], 'foo_branch_3': [13], 'test_branch4': [14],
     }
     blob_name_tracker = tb._get_blob_names(ops)
     tb._rename_tensorflow_style(shapes, blob_name_tracker, ops)

     # Blob name tracker: renamed key -> original caffe2 name.
     expected_tracker = []
     for suffix, scope in suffix_table:
         expected_tracker.append(('foo/' + scope, 'foo_' + suffix))
         expected_tracker.append(
             ('foo/' + scope + '_2', 'foo_' + suffix + '_2'))
     expected_tracker.extend([
         ('foo/branch', 'foo_branch'),
         ('test/branch_2', 'test_branch_2'),
         ('foo/branch_3', 'foo_branch_3'),
         ('test/branch4', 'test_branch4'),
     ])
     for new_name, old_name in expected_tracker:
         self.assertEqual(blob_name_tracker[new_name], old_name)

     # Shapes: a sample of keys must have moved to their renamed form.
     sampled_shapes = [
         ('foo/weight', [1]),
         ('foo/batchnorm_2', [4]),
         ('foo/sum', [9]),
         ('test/branch_2', [12]),
     ]
     for new_name, shape in sampled_shapes:
         self.assertEqual(shapes[new_name], shape)

     # Operators: zip stops after the five single-suffix ops.
     for single_op, (_, scope) in zip(ops, suffix_table):
         self.assertEqual(single_op.input[0], 'foo/' + scope)
         self.assertEqual(single_op.output[0], 'foo/' + scope + '_2')
     self.assertEqual(branch_op.input[0], 'foo/branch')
     self.assertEqual(branch_op.input[1], 'test/branch_2')
     self.assertEqual(branch_op.output[0], 'foo/branch_3')
     self.assertEqual(branch_op.output[1], 'test/branch4')