def test_convert_weight(self):
    """Check that `convert_weight` returns the result of the last chained handler.

    Five `WeightHandler` instances are linked with `then()`. The first four
    are stubbed to report "not processed" and the last one to report
    "processed"; calling `convert_weight` on the head of the chain must
    yield the last handler's return value.
    """
    chain = [weights_converter.WeightHandler() for _ in range(5)]
    skipped = weights_converter.WeightHandlerReturn(
        processed=False, matched_source=False, matched_target=False)
    handled = weights_converter.WeightHandlerReturn(
        processed=True, matched_source=True, matched_target=False)

    def _decline(a, b):
        return skipped

    def _accept(a, b):
        return handled

    # Link every handler to its successor and stub it to decline the weight.
    for current, successor in zip(chain, chain[1:]):
        current.then(successor)
        current._process_weights = _decline

    # Only the tail of the chain reports the weight as processed.
    chain[-1]._process_weights = _accept

    result = chain[0].convert_weight(None, None)
    self.assertEqual(result, handled)
def test(self):
    """Check `ConvolutionHandler` on a bias-free 2D convolution.

    A PyTorch `Conv2d` kernel is randomised and passed through the handler
    together with the kernel variable of a channels-first Keras `Conv2D`.
    Both layers must then produce the same output for the same input, and
    the handler must report the weight as processed with source and target
    matched.
    """
    in_channels = 3
    kernel_size = 3
    out_channels = 5
    inp = np.ones((1, in_channels, 15, 15), dtype=np.float32)

    model_tensorflow = tf.keras.layers.Conv2D(out_channels,
                                              kernel_size,
                                              use_bias=False,
                                              data_format='channels_first')
    # Call the layer once so that `weights[0]` exists before conversion.
    model_tensorflow(inp)

    model_pytorch = torch.nn.Conv2d(in_channels,
                                    out_channels,
                                    kernel_size,
                                    bias=False)

    # In-place `uniform_()` on a parameter that requires grad, and `.numpy()`
    # on an output that tracks grad, both raise RuntimeError while autograd
    # is enabled. Disabling it locally keeps the test independent of any
    # global grad-mode setting made elsewhere.
    with torch.no_grad():
        model_pytorch.weight.uniform_()
        converter = (
            weights_converter.PytorchToTensorflowHandlers.ConvolutionHandler())
        res = converter(('', model_pytorch.state_dict()['weight']),
                        model_tensorflow.weights[0])
        np.testing.assert_allclose(model_tensorflow(inp).numpy(),
                                   model_pytorch(torch.tensor(inp)).numpy(),
                                   rtol=1e-5)

    expected_res = weights_converter.WeightHandlerReturn(
        processed=True, matched_source=True, matched_target=True)
    self.assertEqual(res, expected_res)
def test(self):
    """Check `SameShapeHandler` on batch-normalization scale and bias.

    The PyTorch `BatchNorm1d` weight is randomised, then its `weight` and
    `bias` entries are passed through the handler together with the first
    two variables of a Keras `BatchNormalization` layer. Both layers, in
    inference mode, must then produce the same output, and each handler
    call must report processed with source and target matched.
    """
    inp = np.array([[1, 2, 3]], dtype=np.float32)

    model_tensorflow = tf.keras.layers.BatchNormalization(momentum=0.1,
                                                          epsilon=0.001)
    model_tensorflow.trainable = False
    # Call the layer once so that its weight variables exist.
    model_tensorflow(inp)

    model_pytorch = torch.nn.BatchNorm1d(3, momentum=0.1, eps=0.001)
    model_pytorch.train(False)

    # In-place `uniform_()` on a parameter that requires grad, and `.numpy()`
    # on an output that tracks grad, both raise RuntimeError while autograd
    # is enabled. Disabling it locally keeps the test independent of any
    # global grad-mode setting made elsewhere.
    with torch.no_grad():
        # NOTE(review): only the weight is randomised; the bias keeps its
        # initial value, so the bias copy is checked mainly via `res_bias`.
        model_pytorch.weight.uniform_()
        converter = (
            weights_converter.PytorchToTensorflowHandlers.SameShapeHandler())
        # weights[0] / weights[1] are presumably gamma and beta -- confirm
        # against the Keras BatchNormalization variable order.
        res_weight = converter(('', model_pytorch.state_dict()['weight']),
                               model_tensorflow.weights[0])
        res_bias = converter(('', model_pytorch.state_dict()['bias']),
                             model_tensorflow.weights[1])
        np.testing.assert_allclose(model_tensorflow(inp).numpy(),
                                   model_pytorch(torch.tensor(inp)).numpy(),
                                   rtol=1e-5)

    expected_res = weights_converter.WeightHandlerReturn(
        processed=True, matched_source=True, matched_target=True)
    self.assertEqual(res_weight, expected_res)
    self.assertEqual(res_bias, expected_res)
def test(self):
    """Check `BatchNormalizationSkipExtraHandler` on the last state_dict entry.

    The final `(name, tensor)` item of a `BatchNorm1d` state dict is handed
    to the handler with no target weight. The handler must report it as
    processed with the source matched and the target not matched.
    """
    expected_res = weights_converter.WeightHandlerReturn(
        processed=True, matched_source=True, matched_target=False)

    batch_norm = torch.nn.BatchNorm1d(3, momentum=0.1, eps=0.001)
    skip_handler = (weights_converter.PytorchToTensorflowHandlers.
                    BatchNormalizationSkipExtraHandler())

    # Keep only the trailing entry of the state dict (presumably the extra
    # bookkeeping buffer that has no TensorFlow counterpart -- confirm).
    *_, last_entry = batch_norm.state_dict().items()

    res = skip_handler(last_entry, None)
    self.assertEqual(res, expected_res)