Exemplo n.º 1
0
    def test_instance_norm_model_helper_helper(self, N, C, H, W, order, epsilon, seed, is_test):
        """Build an instance-norm net through ``brew`` and check blob shapes.

        A random float32 input of shape ``(N, C, H, W)`` is generated,
        transposed to NHWC when ``order == 'NHWC'``, fed to the workspace as
        ``'input'``, and run through the model's init net and main net.

        Args:
            N, C, H, W: dimensions of the generated input; ``C`` is also
                passed to ``brew.instance_norm`` as its fourth argument.
            order: storage layout; ``'NHWC'`` triggers the transposes.
            epsilon: forwarded to ``brew.instance_norm``.
            seed: seed for NumPy's global random generator.
            is_test: forwarded to ``brew.instance_norm``; when true, the
                ``'output_s'`` and ``'output_b'`` blobs are also checked.
        """
        np.random.seed(seed)
        model = model_helper.ModelHelper(name="test_model")
        brew.instance_norm(
            model,
            'input',
            'output',
            C,
            epsilon=epsilon,
            # The input is fed in NHWC layout below when order == 'NHWC', so
            # the op must be built with the same layout rather than the
            # helper's default.
            order=order,
            is_test=is_test)

        input_blob = np.random.rand(N, C, H, W).astype(np.float32)
        if order == 'NHWC':
            # NCHW -> NHWC
            input_blob = np.transpose(input_blob, axes=(0, 2, 3, 1))

        self.ws.create_blob('input').feed(input_blob)

        # Parameters must be initialised before the main net runs.
        self.ws.create_net(model.param_init_net).run()
        self.ws.create_net(model.net).run()

        if is_test:
            scale = self.ws.blobs['output_s'].fetch()
            assert scale is not None
            assert scale.shape == (C, )
            bias = self.ws.blobs['output_b'].fetch()
            assert bias is not None
            assert bias.shape == (C, )

        output_blob = self.ws.blobs['output'].fetch()
        if order == 'NHWC':
            # NHWC -> NCHW so the shape check below is layout-independent.
            output_blob = np.transpose(output_blob, axes=(0, 3, 1, 2))

        assert output_blob.shape == (N, C, H, W)
Exemplo n.º 2
0
    def test_instance_norm_model_helper(self, N, C, H, W, order, epsilon, seed,
                                        is_test):
        """Run ``brew.instance_norm`` end to end and verify blob shapes.

        The input is generated as ``(N, C, H, W)`` float32 data and converted
        to NHWC when ``order == 'NHWC'``. After running the init net and the
        main net, the ``'output'`` blob (converted back to NCHW if needed)
        must have shape ``(N, C, H, W)``. When ``is_test`` is true, the
        ``'output_s'`` and ``'output_b'`` blobs must each have shape ``(C,)``.
        """
        np.random.seed(seed)
        channels_last = order == 'NHWC'

        net_builder = model_helper.ModelHelper(name="test_model")
        op_kwargs = dict(epsilon=epsilon, order=order, is_test=is_test)
        brew.instance_norm(net_builder, 'input', 'output', C, **op_kwargs)

        data = np.random.rand(N, C, H, W).astype(np.float32)
        fed = utils.NCHW2NHWC(data) if channels_last else data
        self.ws.create_blob('input').feed(fed)

        # Init net first so parameters exist when the main net runs.
        for net in (net_builder.param_init_net, net_builder.net):
            self.ws.create_net(net).run()

        if is_test:
            for param_name in ('output_s', 'output_b'):
                param = self.ws.blobs[param_name].fetch()
                assert param is not None
                assert param.shape == (C, )

        result = self.ws.blobs['output'].fetch()
        if channels_last:
            result = utils.NHWC2NCHW(result)

        assert result.shape == (N, C, H, W)
Exemplo n.º 3
0
 def InstanceNorm(self, *args, **kwargs):
     """Delegate to ``brew.instance_norm`` using this model's ``order``.

     All positional and keyword arguments are forwarded unchanged, with
     ``self`` passed as the first argument and ``order=self.order`` added.
     Because ``order`` is supplied here, also passing it in ``kwargs``
     raises ``TypeError`` (duplicate keyword argument).
     """
     helper = brew.instance_norm
     return helper(self, *args, order=self.order, **kwargs)
Exemplo n.º 4
0
 def InstanceNorm(self, *args, **kwargs):
     """Add instance normalization through ``brew.instance_norm``.

     Forwards ``self`` as the first argument, followed by the caller's
     positional and keyword arguments, with ``order`` taken from
     ``self.order``. Returns whatever ``brew.instance_norm`` returns.
     """
     build_op = brew.instance_norm
     result = build_op(self, *args, order=self.order, **kwargs)
     return result