Beispiel #1
0
    def test_sp_transform_with_output_params_infer(self):
        """Check that explicit output_H/output_W on 'st' drive the inferred shape.

        Two [1, 3, 227, 227] inputs feed the 'st' node, which carries
        output_H=200 and output_W=15. After SpatialTransformOp.sp_infer runs,
        node_3 (the consumer of 'st') is expected to have shape [1, 3, 200, 15].
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'st'), ('node_2', 'st'),
                               ('st', 'node_3'), ('node_3', 'op_output')], {
                                   'node_3': {
                                       'shape': None
                                   },
                                   'node_1': {
                                       'shape': np.array([1, 3, 227, 227])
                                   },
                                   'node_2': {
                                       'shape': np.array([1, 3, 227, 227])
                                   },
                                   'st': {
                                       'output_H': 200,
                                       'output_W': 15
                                   }
                               })

        st_node = Node(graph, 'st')
        SpatialTransformOp.sp_infer(st_node)
        exp_shape = np.array([1, 3, 200, 15])
        res_shape = graph.node['node_3']['shape']
        # Fail with a clear message (not a TypeError) if inference set nothing.
        self.assertIsNotNone(res_shape, 'sp_infer did not set the shape of node_3')
        # Compare whole arrays: this also catches a result with a different
        # number of dimensions, which a prefix-only element loop would miss.
        np.testing.assert_array_equal(res_shape, exp_shape)
Beispiel #2
0
    def test_sp_transform_concat_infer(self):
        """Check the inferred shape when 'st' has no output_H/output_W attrs.

        Two [1, 3, 227, 227] inputs feed the 'st' node, which is given no extra
        attributes. After SpatialTransformOp.sp_infer runs, node_3 (marked as
        an output) is expected to keep the input shape [1, 3, 227, 227].
        """
        graph = build_graph(
            nodes_attributes, [('node_1', 'st'), ('node_2', 'st'),
                               ('st', 'node_3')], {
                                   'node_3': {
                                       'is_output': True,
                                       'shape': None
                                   },
                                   'node_1': {
                                       'shape': np.array([1, 3, 227, 227])
                                   },
                                   'node_2': {
                                       'shape': np.array([1, 3, 227, 227])
                                   },
                                   'st': {}
                               })

        st_node = Node(graph, 'st')
        SpatialTransformOp.sp_infer(st_node)
        exp_shape = np.array([1, 3, 227, 227])
        res_shape = graph.node['node_3']['shape']
        # Fail with a clear message (not a TypeError) if inference set nothing.
        self.assertIsNotNone(res_shape, 'sp_infer did not set the shape of node_3')
        # Compare whole arrays: this also catches a result with a different
        # number of dimensions, which a prefix-only element loop would miss.
        np.testing.assert_array_equal(res_shape, exp_shape)
    def extract(cls, node):
        """Copy the layer's ``st_param`` fields onto ``node`` as attributes.

        Reads ``node.pb.st_param``. Builds a dict with these entries, in order:
        transform_type, sampler_type, output_H, output_W, to_compute_dU (cast
        to int) and theta_1_1 .. theta_2_3. It merges that dict with the param
        via ``merge_attrs`` and applies the result through
        ``SpatialTransformOp.update_node_stat``.

        Returns:
            ``cls.enabled``.
        """
        st_param = node.pb.st_param

        leading_fields = ('transform_type', 'sampler_type', 'output_H', 'output_W')
        theta_fields = ('theta_1_1', 'theta_1_2', 'theta_1_3',
                        'theta_2_1', 'theta_2_2', 'theta_2_3')

        attrs = {field: getattr(st_param, field) for field in leading_fields}
        # to_compute_dU is stored as a plain int rather than the raw field value.
        attrs['to_compute_dU'] = int(st_param.to_compute_dU)
        for field in theta_fields:
            attrs[field] = getattr(st_param, field)

        SpatialTransformOp.update_node_stat(node, merge_attrs(st_param, attrs))
        return cls.enabled