def test_tf_concat_infer(self):
    """Check that CorrelationOp.corr_infer sets the expected output shape.

    NOTE(review): despite its name, this test exercises the correlation op,
    not concat; the name looks like a copy-paste leftover and could be
    renamed (e.g. ``test_corr_infer``) once nothing depends on it.

    Graph under test: node_1, node_2 -> corr -> node_3 -> op_output.
    Both inputs have shape [1, 3, 227, 227]; node_3 starts with shape None
    and is expected to be [1, 441, 227, 227] after inference.
    """
    graph = build_graph(nodes_attributes,
                        [('node_1', 'corr'),
                         ('node_2', 'corr'),
                         ('corr', 'node_3'),
                         ('node_3', 'op_output')],
                        {'node_3': {'shape': None},
                         'node_1': {'shape': np.array([1, 3, 227, 227])},
                         'node_2': {'shape': np.array([1, 3, 227, 227])},
                         'corr': {'pad': 20,
                                  'kernel_size': 1,
                                  'max_displacement': 20,
                                  'stride_1': 1,
                                  'stride_2': 2,
                                  'single_direction': 0,
                                  'do_abs': False,
                                  'correlation_type': 0}})
    corr_node = Node(graph, 'corr')
    CorrelationOp.corr_infer(corr_node)

    exp_shape = np.array([1, 441, 227, 227])
    res_shape = graph.node['node_3']['shape']
    # Give a clear failure if inference never filled in the output shape.
    self.assertIsNotNone(res_shape, 'corr_infer did not set the output shape')
    # Whole-array comparison also catches a rank mismatch; iterating only
    # over exp_shape's indices would miss extra trailing dimensions.
    np.testing.assert_array_equal(res_shape, exp_shape)
def extract(cls, node):
    """Translate a Caffe layer's ``correlation_param`` into node attributes.

    Reads the correlation parameters from ``node.pb.correlation_param``,
    maps ``correlation_type`` to its string form (1 -> SUBTRACT, anything
    else -> MULTIPLY), converts ``do_abs`` to int, and passes the result
    through ``merge_attrs`` together with ``layout_attrs()`` before storing
    it on the node via ``CorrelationOp.update_node_stat``.

    Args:
        node: graph node whose ``pb`` attribute holds the Caffe layer proto.

    Returns:
        ``cls.enabled``.
    """
    param = node.pb.correlation_param

    # Only the value 1 selects SUBTRACT; every other value maps to MULTIPLY.
    corr_type = ('caffe.CorrelationParameter.SUBTRACT'
                 if param.correlation_type == 1
                 else 'caffe.CorrelationParameter.MULTIPLY')

    overrides = dict(
        pad=param.pad,
        kernel_size=param.kernel_size,
        max_displacement=param.max_displacement,
        stride_1=param.stride_1,
        stride_2=param.stride_2,
        single_direction=param.single_direction,
        do_abs=int(param.do_abs),
        correlation_type=corr_type,
    )

    attrs = merge_attrs(param, overrides)
    attrs.update(layout_attrs())

    # Store the collected attributes on the node.
    CorrelationOp.update_node_stat(node, attrs)
    return cls.enabled