def test_convert_weight(self):
        """Check that a chain of handlers falls through to the tail's result.

        Five handlers are linked head-to-tail with ``then``. Every handler
        except the last is stubbed to return a "not processed" result; the
        last one returns a "processed" result. Calling ``convert_weight`` on the
        head must yield exactly the tail's result.
        """
        chain = [weights_converter.WeightHandler() for _ in range(5)]
        skip_result = weights_converter.WeightHandlerReturn(
            processed=False, matched_source=False, matched_target=False)
        hit_result = weights_converter.WeightHandlerReturn(
            processed=True, matched_source=True, matched_target=False)

        # Pair each handler with its successor (zip stops at the shorter list,
        # so the tail is never paired with anything).
        for current, following in zip(chain, chain[1:]):
            current.then(following)
            current._process_weights = lambda a, b: skip_result

        # Only the tail of the chain reports that it processed the weight.
        chain[-1]._process_weights = lambda a, b: hit_result

        self.assertEqual(chain[0].convert_weight(None, None), hit_result)
    def test(self):
        """Check that ConvolutionHandler copies a Conv2d kernel into a Keras Conv2D.

        After conversion, both layers must produce the same output on the same
        channels-first input. The handler must also report that it processed
        the weight and matched both source and target.
        """
        in_channels = 3
        kernel_size = 3
        out_channels = 5
        inp = np.ones((1, in_channels, 15, 15), dtype=np.float32)

        model_tensorflow = tf.keras.layers.Conv2D(out_channels,
                                                  kernel_size,
                                                  use_bias=False,
                                                  data_format='channels_first')
        # Call the layer once to build it; its weights are read below.
        model_tensorflow(inp)

        model_pytorch = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size,
                                        bias=False)
        # The weight is a leaf Parameter with requires_grad=True, so an
        # in-place ``weight.uniform_()`` under grad mode raises. torch.nn.init
        # applies the same U(0, 1) fill without autograd tracking.
        torch.nn.init.uniform_(model_pytorch.weight)

        converter = weights_converter.PytorchToTensorflowHandlers.ConvolutionHandler(
        )
        res = converter(('', model_pytorch.state_dict()['weight']),
                        model_tensorflow.weights[0])

        # The PyTorch output carries autograd history (the weight requires
        # grad), so it must be detached before it can be turned into NumPy.
        expected_output = model_pytorch(torch.tensor(inp)).detach().numpy()
        np.testing.assert_allclose(model_tensorflow(inp).numpy(),
                                   expected_output,
                                   rtol=1e-5)
        expected_res = weights_converter.WeightHandlerReturn(
            processed=True, matched_source=True, matched_target=True)
        self.assertEqual(res, expected_res)
    def test(self):
        """Check that SameShapeHandler copies BatchNorm1d weight and bias into Keras.

        Both layers run in inference mode (``trainable = False`` in Keras,
        ``train(False)`` in PyTorch). After copying gamma and beta, the outputs
        must match, and the handler must report a full source/target match for
        each parameter.
        """
        inp = np.array([[1, 2, 3]], dtype=np.float32)
        # NOTE(review): Keras and PyTorch define BatchNorm momentum with
        # opposite conventions. This is harmless here because neither layer
        # updates its running statistics in inference mode.
        model_tensorflow = tf.keras.layers.BatchNormalization(momentum=0.1,
                                                              epsilon=0.001)
        model_tensorflow.trainable = False
        # Call the layer once to build it; its weights are read below.
        model_tensorflow(inp)

        model_pytorch = torch.nn.BatchNorm1d(3, momentum=0.1, eps=0.001)
        model_pytorch.train(False)
        # The weight is a leaf Parameter with requires_grad=True, so an
        # in-place ``weight.uniform_()`` under grad mode raises. torch.nn.init
        # applies the same U(0, 1) fill without autograd tracking.
        torch.nn.init.uniform_(model_pytorch.weight)

        converter = weights_converter.PytorchToTensorflowHandlers.SameShapeHandler(
        )
        res_weight = converter(('', model_pytorch.state_dict()['weight']),
                               model_tensorflow.weights[0])
        res_bias = converter(('', model_pytorch.state_dict()['bias']),
                             model_tensorflow.weights[1])

        # The PyTorch output carries autograd history (weight/bias require
        # grad), so it must be detached before it can be turned into NumPy.
        expected_output = model_pytorch(torch.tensor(inp)).detach().numpy()
        np.testing.assert_allclose(model_tensorflow(inp).numpy(),
                                   expected_output,
                                   rtol=1e-5)
        expected_res = weights_converter.WeightHandlerReturn(
            processed=True, matched_source=True, matched_target=True)
        self.assertEqual(res_weight, expected_res)
        self.assertEqual(res_bias, expected_res)
    def test(self):
        """Check that BatchNormalizationSkipExtraHandler consumes the last state entry.

        The final ``(name, tensor)`` pair of a BatchNorm1d ``state_dict`` is
        handed to the handler with no target weight (``None``). The handler
        must report it as processed and source-matched, but not
        target-matched.
        """
        batch_norm = torch.nn.BatchNorm1d(3, momentum=0.1, eps=0.001)
        # Presumably this is the ``num_batches_tracked`` buffer in current
        # PyTorch; verify if the state_dict ordering ever changes.
        *_, last_entry = batch_norm.state_dict().items()

        handler = weights_converter.PytorchToTensorflowHandlers.BatchNormalizationSkipExtraHandler()
        outcome = handler(last_entry, None)

        self.assertEqual(
            outcome,
            weights_converter.WeightHandlerReturn(
                processed=True, matched_source=True, matched_target=False),
        )