Ejemplo n.º 1
0
    def fuse_conv_and_bias_callback(self, op, label_map_op_list):
        """
        Callback function that handles fusion for Conv + bias pattern.

        For every match, the ConvolutionOp found as argument 0 of the matched
        MapRolesOp is rebuilt with the matched bias folded in. Every op between
        that convolution and the matched op (reached by following argument 0)
        is then recreated on top of the new convolution. Finally the matched op
        is replaced by the top of the rebuilt chain, and the old convolution by
        the new one.

        Only MapRolesOp, TensorSliceOp and ReorderAxes can be rebuilt. A match
        whose intermediate chain contains any other op type is left unfused,
        rather than failing part-way with a KeyError. The graph has not been
        modified at that point, so skipping is safe.

        Arguments:
            op: Not used; the matched ops come from ``label_map_op_list``.
            label_map_op_list: Iterable of ``(label_map, matched_op)`` pairs
                produced by the pattern matcher.
        """
        rebuildable_types = (MapRolesOp, TensorSliceOp, ReorderAxes)
        for (label_map, matched_op) in label_map_op_list:
            map_roles = label_map[self.map_roles_label]
            conv_op = self.op_arg(map_roles, 0)
            bias = label_map[self.conv_bias_label]
            if not isinstance(conv_op, ConvolutionOp):
                continue

            # Collect the ops strictly between the convolution and the matched
            # op by following argument 0 (nearest-to-matched-op first).
            # NOTE(review): assumes the conv chain is argument 0 of the matched
            # op; confirm against the pattern definition.
            upstream_ops = []
            prev_op = self.op_arg(matched_op, 0)
            while prev_op is not conv_op:
                upstream_ops.append(prev_op)
                prev_op = self.op_arg(prev_op, 0)

            # Bail out before creating or replacing anything if some
            # intermediate op cannot be rebuilt below.
            if not all(isinstance(old_op, rebuildable_types) for old_op in upstream_ops):
                continue

            conv_new_op = ConvolutionOp(conv_op.conv_params, self.op_arg(conv_op, 0),
                                        self.op_arg(conv_op, 1), bias, axes=conv_op.axes)
            new_op_map = {conv_op: conv_new_op}

            # Rebuild the chain bottom-up so each op's new argument already exists.
            for old_op in reversed(upstream_ops):
                new_arg = new_op_map[self.op_arg(old_op, 0)]
                if isinstance(old_op, MapRolesOp):
                    new_op_map[old_op] = MapRolesOp(new_arg, old_op.axes_map)
                elif isinstance(old_op, TensorSliceOp):
                    new_op_map[old_op] = TensorSliceOp(new_arg, old_op.slices, old_op.axes)
                else:
                    new_op_map[old_op] = ReorderAxes(new_arg, old_op.axes)

            self.replace_op(matched_op, new_op_map[self.op_arg(matched_op, 0)])
            # Replace the old convolution explicitly as well.
            self.replace_op(conv_op, conv_new_op)
Ejemplo n.º 2
0
 def fuse_conv_and_bias_callback(self, op, label_map_op_list):
     """
     Fold a matched bias into the matched convolution.

     For each ``(label_map, matched)`` pair, a new ConvolutionOp is built from
     the old convolution's ``conv_params``, its first two arguments, the
     matched bias and its axes. The new op is recorded in
     ``self.op_replacement_dict`` under the old convolution, and then
     replaces the matched op. The ``op`` parameter itself is not used.
     """
     for label_map, matched in label_map_op_list:
         old_conv = label_map[self.conv_op_label]
         bias_op = label_map[self.conv_bias_label]
         params = old_conv.conv_params
         conv_inputs = old_conv.args[0]
         conv_filters = old_conv.args[1]
         fused = ConvolutionOp(params, conv_inputs, conv_filters, bias_op,
                               axes=old_conv.axes)
         self.op_replacement_dict[old_conv] = fused
         self.replace_op(matched, fused)
Ejemplo n.º 3
0
 def fuse_conv_and_bias_callback(self, op, label_map_op_list):
     """
     Fold a matched bias into the convolution underneath a MapRolesOp.

     For each match, if argument 0 of the matched MapRolesOp is a
     ConvolutionOp, a new ConvolutionOp carrying the bias is built. It is
     recorded in ``self.op_replacement_dict`` under the old convolution. The
     matched op is then replaced by a fresh MapRolesOp (same ``axes_map``)
     wrapping the new convolution. Non-convolution matches are skipped. The
     ``op`` parameter itself is not used.
     """
     for label_map, matched in label_map_op_list:
         roles_op = label_map[self.map_roles_label]
         candidate = roles_op.args[0]
         bias_op = label_map[self.conv_bias_label]
         if not isinstance(candidate, ConvolutionOp):
             continue
         fused = ConvolutionOp(candidate.conv_params, candidate.args[0],
                               candidate.args[1], bias_op, axes=candidate.axes)
         self.op_replacement_dict[candidate] = fused
         self.replace_op(matched, MapRolesOp(fused, roles_op.axes_map))
Ejemplo n.º 4
0
 def fuse_conv_and_bias_callback(self, op, label_map_op_list):
     """
     Fold a matched bias into the convolution underneath a MapRolesOp.

     For each match whose MapRolesOp argument 0 (via ``self.op_arg``) is a
     ConvolutionOp, the steps are:

     - build a new ConvolutionOp that carries the bias;
     - replace the matched op with a MapRolesOp (same ``axes_map``) wrapping
       the new convolution;
     - replace the old convolution with the new one.

     Non-convolution matches are skipped. The ``op`` parameter itself is
     not used.
     """
     for label_map, matched in label_map_op_list:
         roles_op = label_map[self.map_roles_label]
         old_conv = self.op_arg(roles_op, 0)
         bias_op = label_map[self.conv_bias_label]
         if not isinstance(old_conv, ConvolutionOp):
             continue
         fused_conv = ConvolutionOp(old_conv.conv_params,
                                    self.op_arg(old_conv, 0),
                                    self.op_arg(old_conv, 1),
                                    bias_op,
                                    axes=old_conv.axes)
         self.replace_op(matched, MapRolesOp(fused_conv, roles_op.axes_map))
         # The old convolution is replaced explicitly too, so that its
         # forwarded pointer is updated correctly.
         self.replace_op(old_conv, fused_conv)
Ejemplo n.º 5
0
    def visit(self, op, inputs, filters, bias=None):
        """
        Force the convolution's filters into contiguous layout.

        If ``filters`` is already a ContiguousOp, nothing happens. Otherwise
        ``filters`` is wrapped in a ContiguousOp, and ``op`` is replaced by a
        ConvolutionOp with the same ``conv_params`` and axes, built from
        ``inputs``, the wrapped filters and ``bias``. ``inputs`` is passed
        through unchanged.
        """
        # NOTE(review): the original code also had (commented out) wrapping of
        # ``inputs`` in ContiguousOp; only filters are made contiguous here.
        # The reason is not visible in this code, so confirm before changing.
        if isinstance(filters, ContiguousOp):
            return
        contiguous_filters = ContiguousOp(filters)
        self.replace_op(op, ConvolutionOp(op.conv_params, inputs, contiguous_filters,
                                          bias, axes=op.axes))
Ejemplo n.º 6
0
 def op_from_args(self, op, args):
     """
     Return a ConvolutionOp that reuses ``op``'s ``conv_params`` and axes
     but takes its positional arguments from ``args``.
     """
     conv_params = op.conv_params
     return ConvolutionOp(conv_params, *args, axes=op.axes)