def assertReferenceChecks(
    self,
    device_option,
    op,
    inputs,
    reference,
    input_device_options=None,
    threshold=1e-4,
    output_to_grad=None,
    grad_reference=None,
    atol=None,
    outputs_to_check=None,
):
    """Run the parent reference check, then the serialized-operator check.

    All arguments are forwarded positionally to the parent class's
    ``assertReferenceChecks``. The outputs it returns are then passed,
    together with the operator's gradient ops, to
    ``assertSerializedOperatorChecks``.

    Args:
        device_option: Device option the operator is run with.
        op: The operator definition under test.
        inputs: Input values fed to the operator.
        reference: Reference implementation forwarded to the parent check.
        input_device_options: Forwarded unchanged to the parent check.
        threshold: Forwarded unchanged to the parent check.
        output_to_grad: Forwarded unchanged to the parent check.
        grad_reference: Forwarded unchanged to the parent check.
        atol: Forwarded unchanged to the parent check.
        outputs_to_check: Forwarded unchanged to the parent check.
    """
    outs = super(SerializedTestCase, self).assertReferenceChecks(
        device_option,
        op,
        inputs,
        reference,
        input_device_options,
        threshold,
        output_to_grad,
        grad_reference,
        atol,
        outputs_to_check,
    )
    # gradient_checker.getGradientForOp may raise (presumably for ops that
    # have no usable gradient). _getGradientOrNone turns that into an empty
    # list, so such ops can still be serialized instead of crashing after
    # the reference check has already passed.
    grad_ops = _getGradientOrNone(op)
    self.assertSerializedOperatorChecks(
        inputs,
        outs,
        grad_ops,
        op,
        device_option,
    )
def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
    """Compare the current test run against its serialized snapshot.

    Reads ``inputs.npz``, ``outputs.npz``, an ``operator_<device_type>.pb``
    file and ``gradient_<i>.pb`` files from ``self.get_output_dir()``.

    If the current inputs differ from the serialized ones (in value or in
    count), the serialized operator is re-run on the serialized inputs. This
    keeps the output comparison meaningful.

    Args:
        inputs: Inputs used by the current test run.
        outputs: Outputs produced by the current test run.
        grad_ops: Gradient ops of the current operator. They are replaced by
            the serialized operator's gradient ops when an operator file is
            found.
        atol: Absolute tolerance for the output comparison.
        rtol: Relative tolerance for the output comparison.

    Raises:
        AssertionError: if the outputs or gradient ops do not match, or if
            the inputs differ and no serialized operator file exists.
        Errors from ``numpy.load``/``open`` propagate if a required file is
        missing.
    """

    def parse_proto(x):
        proto = caffe2_pb2.OperatorDef()
        proto.ParseFromString(x)
        return proto

    source_dir = self.get_output_dir()

    # An NpzFile keeps its underlying file open, so close it via `with`.
    # allow_pickle is required to read object arrays on numpy >= 1.16.3.
    # NOTE(review): this assumes the .npz fixtures are produced by this test
    # harness and are trusted; never load untrusted files with pickling on.
    with numpy.load(os.path.join(source_dir, 'inputs.npz'),
                    encoding='bytes', allow_pickle=True) as data:
        loaded_inputs = data['inputs']
    with numpy.load(os.path.join(source_dir, 'outputs.npz'),
                    encoding='bytes', allow_pickle=True) as data:
        loaded_outputs = data['outputs']

    # zip() silently drops extra items on either side, so also compare the
    # number of inputs.
    inputs_equal = len(inputs) == len(loaded_inputs) and all(
        numpy.array_equal(x, y) for x, y in zip(inputs, loaded_inputs)
    )

    # Load the serialized operator. Sorting makes the choice deterministic
    # if several operator_<device>.pb files are present.
    found_op = False
    for name in sorted(os.listdir(source_dir)):
        op_file = os.path.join(source_dir, name)
        match = re.search(r'operator_(.+?)\.pb', name)
        if os.path.isfile(op_file) and match:
            with open(op_file, 'rb') as f:
                op_proto = parse_proto(f.read())
            device_type = int(match.group(1))
            device_option = caffe2_pb2.DeviceOption(device_type=device_type)
            grad_ops = _getGradientOrNone(op_proto)
            found_op = True
            break

    # If the inputs changed, recompute outputs from the serialized op and
    # inputs. op_proto/device_option are only bound when found_op is True,
    # which the assertion guarantees before they are used.
    if not inputs_equal:
        self.assertTrue(
            found_op,
            'inputs differ from the serialized inputs but no serialized '
            'operator file was found in {}'.format(source_dir),
        )
        outputs = hu.runOpOnInput(device_option, op_proto, loaded_inputs)

    for x, y in zip(outputs, loaded_outputs):
        numpy.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

    for i, grad_op in enumerate(grad_ops):
        grad_file = os.path.join(source_dir, 'gradient_{}.pb'.format(i))
        with open(grad_file, 'rb') as f:
            grad_proto = parse_proto(f.read())
        self.assertEqual(grad_proto, grad_op)
def _getGradientOrNone(op_proto): try: grad_ops, _ = gradient_checker.getGradientForOp(op_proto) return grad_ops except Exception: return []