예제 #1
0
 def assertReferenceChecks(
     self,
     device_option,
     op,
     inputs,
     reference,
     input_device_options=None,
     threshold=1e-4,
     output_to_grad=None,
     grad_reference=None,
     atol=None,
     outputs_to_check=None,
 ):
     """Run the parent reference check, then the serialized-operator check.

     All arguments are forwarded unchanged to the parent class's
     ``assertReferenceChecks``. Its result is then passed, together with
     ``inputs``, the gradient operators of ``op``, ``op`` and
     ``device_option``, to ``self.assertSerializedOperatorChecks``.

     Returns:
         The value returned by the parent ``assertReferenceChecks`` (the
         outputs handed to the serialized check). It is returned so this
         override stays a drop-in replacement for the parent method.
     """
     outs = super(SerializedTestCase, self).assertReferenceChecks(
         device_option,
         op,
         inputs,
         reference,
         input_device_options,
         threshold,
         output_to_grad,
         grad_reference,
         atol,
         outputs_to_check,
     )
     # _getGradientOrNone yields [] when building the gradient raises, so the
     # serialized check still runs for operators without a usable gradient.
     grad_ops = _getGradientOrNone(op)
     self.assertSerializedOperatorChecks(
         inputs,
         outs,
         grad_ops,
         op,
         device_option,
     )
     return outs
예제 #2
0
    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Check a run's results against the serialized copies on disk.

        Files are read from ``self.get_output_dir()``:

        * ``inputs.npz`` and ``outputs.npz``, with arrays under the
          ``inputs`` / ``outputs`` keys;
        * an ``operator_<device_type>.pb`` OperatorDef;
        * ``gradient_<i>.pb`` files, one per gradient operator.

        When the serialized inputs match ``inputs``, the caller's ``outputs``
        and ``grad_ops`` are compared directly. Otherwise the serialized
        operator is re-run on the serialized inputs, and its outputs and
        gradient operators are compared instead.

        Args:
            inputs: Input arrays of the current run.
            outputs: Output arrays of the current run.
            grad_ops: Gradient operator protos of the current operator.
            atol: Absolute tolerance for the output comparison.
            rtol: Relative tolerance for the output comparison.
        """
        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()

        # Load serialized inputs and outputs. Using the NpzFile as a context
        # manager closes the archive. Indexing it reads the array into memory,
        # so the arrays stay valid after the `with` block.
        # NOTE(review): if these were saved as object arrays, NumPy >= 1.16.3
        # also needs allow_pickle=True here -- confirm against the writer.
        with numpy.load(os.path.join(source_dir, 'inputs.npz'),
                        encoding='bytes') as data:
            loaded_inputs = data['inputs']
        with numpy.load(os.path.join(source_dir, 'outputs.npz'),
                        encoding='bytes') as data:
            loaded_outputs = data['outputs']

        # zip() alone would ignore extra or missing inputs, so the counts must
        # also match before the current run is considered equivalent.
        inputs_equal = len(inputs) == len(loaded_inputs) and all(
            numpy.array_equal(x, y) for x, y in zip(inputs, loaded_inputs))

        if not inputs_equal:
            # The caller's outputs cannot be compared with the serialized ones,
            # so run the serialized op on the serialized inputs instead.
            found_op = False
            for name in os.listdir(source_dir):
                op_file = os.path.join(source_dir, name)
                match = re.search(r'operator_(.+?)\.pb', name)
                if os.path.isfile(op_file) and match:
                    with open(op_file, 'rb') as f:
                        op_proto = parse_proto(f.read())
                    device_option = caffe2_pb2.DeviceOption(
                        device_type=int(match.group(1)))
                    found_op = True
                    break
            self.assertTrue(
                found_op,
                'no serialized operator file found in {}'.format(source_dir))
            outputs = hu.runOpOnInput(device_option, op_proto, loaded_inputs)
            # Only in this fallback path is the caller's grad_ops replaced.
            # The serialized gradients are presumably derived from the
            # serialized op, so compare against that op's gradient here.
            grad_ops = _getGradientOrNone(op_proto)

        # assert outputs are equal
        for x, y in zip(outputs, loaded_outputs):
            numpy.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

        # assert each gradient op matches its serialized counterpart
        for i, grad_op in enumerate(grad_ops):
            grad_path = os.path.join(source_dir, 'gradient_{}.pb'.format(i))
            with open(grad_path, 'rb') as f:
                grad_proto = parse_proto(f.read())
            self.assertEqual(grad_proto, grad_op)
예제 #3
0
def _getGradientOrNone(op_proto):
    try:
        grad_ops, _ = gradient_checker.getGradientForOp(op_proto)
        return grad_ops
    except Exception:
        return []