Example #1
0
    def build_nested_blocks(self):
        # self.blocks for this cond op is populated in one of two ways:
        #   1. The milproto front end has already deserialized the true/false
        #      blocks into self._existing_blocks; they are reused directly.
        #      In that case _true_fn and _false_fn must be dummy callables:
        #      each is called with an empty list and must return None.
        #   2. Otherwise _true_fn and _false_fn are traced, each inside its own
        #      Block, and the Vars they return become that block's outputs.
        existing = self._existing_blocks
        if existing is not None and existing.val is not None:
            assert self._true_fn.val([]) is None
            assert self._false_fn.val([]) is None
            self.blocks = existing.val
            return

        # The true branch is traced (and appended) before the false branch.
        # The function attribute is read lazily inside each block, exactly
        # when that branch is being traced.
        for suffix, fn_attr in (("_true", "_true_fn"), ("_false", "_false_fn")):
            branch_name = self.name + suffix
            with Block(name=branch_name, outer_op=self) as branch_block:
                branch_fn = getattr(self, fn_attr).val
                returned = branch_fn()
                # Normalize to a list: tuple -> list, list kept as-is,
                # anything else wrapped in a one-element list.
                if isinstance(returned, tuple):
                    outputs = list(returned)
                elif isinstance(returned, list):
                    outputs = returned
                else:
                    outputs = [returned]
                branch_block.set_outputs(outputs)
                self.blocks.append(branch_block)
Example #2
0
def _create_nested_blocks(context, op_spec):
    """
    An utility function that creates nested blocks for control flow ops.
    """
    if not op_spec.blocks:
        return []

    blocks = []

    for block_spec in op_spec.blocks:
        input_vars = [_create_var_from_spec(input) for input in block_spec.inputs]

        # add block input vars to the context
        for v in input_vars:
            context.register_var_with_name(v.name, v)

        # In pymil, the outer_op for a block can only be None if the block is a Functino.
        # As the result, we use a dummy outer_op here for block creation, and set it to
        # the legit op later on in _set_outer_op_for_nested_blocks
        dummy = mb.const(val=0.)
        with Block(block_inputs=input_vars, outer_op=dummy._op,
                   name=Block._get_new_name()) as block:
            _load_block(context, block_spec)

        blocks.append(block)

    return blocks
Example #3
0
    def build_nested_blocks(self):
        """Build the single nested block of this while_loop op.

        ``self.loop_vars`` (a python tuple of Vars) is used as the block
        inputs. Inside the block the body function is traced first, then the
        cond function, both on the block inputs. The block outputs are the
        cond Var(s) followed by the body's exit Vars. The body may return a
        single Var or a list/tuple of Vars.

        Raises:
            ValueError: if the body returns a different number of Vars than
                there are loop_vars, or if an exit Var's type is not a subtype
                of the corresponding loop_var's type.
        """
        block_name = self.name + "_block"
        with Block(block_inputs=self.loop_vars, outer_op=self,
                   name=block_name) as block:
            # Body func
            body_func = self._body.val
            exit_vars = body_func(*block.inputs)
            # Accept a single Var as well as a list/tuple of Vars.
            exit_vars = list(exit_vars) if isinstance(exit_vars, (list, tuple)) \
                else [exit_vars]

            # Cond func
            cond_func = self._cond.val
            cond_var = cond_func(*block.inputs)
            cond_vars = cond_var if isinstance(cond_var, list) else [cond_var]

            # Outputs are laid out as [cond vars..., exit vars...]
            block.set_outputs(cond_vars + exit_vars)
            self.blocks.append(block)

        # zip() below would silently ignore missing or extra Vars, so the
        # counts are checked explicitly first.
        if len(exit_vars) != len(self.loop_vars):
            raise ValueError(
                "while_loop '{}': body returned {} value(s) but there are {} "
                "loop_vars".format(self.name, len(exit_vars),
                                   len(self.loop_vars)))

        # Verify exit_vars has the same types as loop_vars
        for v_in, v_out in zip(self.loop_vars, exit_vars):
            if not is_subtype(v_out.sym_type, v_in.sym_type):
                msg = ("loop_vars '{}' changes in the body of "
                       "while_loop '{}':\n {} -> {}")
                raise ValueError(
                    msg.format(v_in.name, self.name, v_in.sym_type,
                               v_out.sym_type))
Example #4
0
    def _build_block(self, block_inputs):
        """Build separate cond and body blocks for this loop op.

        Both blocks take ``block_inputs`` as their inputs. The cond block is
        built first; its outputs are the value(s) returned by the cond
        function, wrapped in a list unless already a list. The body block's
        outputs are the body function's return value: a list/tuple is
        converted to a list, any other value is wrapped in a one-element list.

        Returns:
            Tuple ``(cond_block, body_block, exit_vars)`` where ``exit_vars``
            is the list of body outputs.
        """
        with Block(block_inputs=block_inputs, outer_op=self,
                   name=self.name + "_cond_block") as cond_block:
            cond_fn = self._cond.val
            cond_out = cond_fn(*cond_block.inputs)
            if not isinstance(cond_out, list):
                cond_out = [cond_out]
            cond_block.set_outputs(cond_out)

        with Block(block_inputs=block_inputs, outer_op=self,
                   name=self.name + "_body_block") as body_block:
            body_fn = self._body.val
            body_out = body_fn(*body_block.inputs)
            if isinstance(body_out, (list, tuple)):
                exit_vars = list(body_out)
            else:
                exit_vars = [body_out]
            body_block.set_outputs(exit_vars)

        return cond_block, body_block, exit_vars
Example #5
0
    def build_nested_blocks(self):
        """Trace the true/false branch functions of this cond op into Blocks.

        Each branch function is called with no arguments inside a Block named
        ``<op name>_true`` / ``<op name>_false``, whose outer op is this op.
        The returned value becomes the block's outputs: a tuple is converted
        to a list, a list is used as-is, and any other value is wrapped in a
        one-element list. The true block is appended to ``self.blocks``
        before the false block.
        """

        def as_output_list(ret):
            if isinstance(ret, tuple):
                return list(ret)
            return ret if isinstance(ret, list) else [ret]

        for suffix, fn_attr in (("_true", "_true_fn"), ("_false", "_false_fn")):
            with Block(name=self.name + suffix, outer_op=self) as blk:
                branch_fn = getattr(self, fn_attr).val
                blk.set_outputs(as_output_list(branch_fn()))
                self.blocks.append(blk)
Example #6
0
    def build_block(self, block_inputs):
        """Build the single nested block of this loop op.

        The body function is traced first, then the cond function, both on
        the block inputs. The block outputs are the cond value(s) (wrapped in
        a list unless already a list) followed by the body's exit Vars.

        Returns:
            Tuple ``(block, exit_vars)`` where ``exit_vars`` is the body
            function's return value, unmodified.
        """
        with Block(block_inputs=block_inputs, outer_op=self,
                   name=self.name + "_block") as block:
            body_fn = self._body.val
            exit_vars = body_fn(*block.inputs)

            cond_fn = self._cond.val
            cond_result = cond_fn(*block.inputs)
            if isinstance(cond_result, list):
                cond_list = cond_result
            else:
                cond_list = [cond_result]

            # Output layout: [cond..., exit vars...]
            block.set_outputs(cond_list + list(exit_vars))
        return block, exit_vars
def _is_per_channel(shape, channel_axis, rank=4):
    """Return True if ``shape``, right-aligned against ``rank`` dims (numpy
    broadcasting), is 1 on every axis except possibly ``channel_axis``."""
    if len(shape) > rank:
        return False
    padded = (1,) * (rank - len(shape)) + tuple(shape)
    return all(dim == 1 for i, dim in enumerate(padded) if i != channel_axis)


def _try_apply_transform(
    reduce_op: Operation,
    block: Block,
    gamma_var: Var,
    beta_var: Var,
    epsilon_var: Var,
    end_op: Operation,
    ops_to_remove: List[Operation],
) -> bool:
    """
    Insert instance_norm / layer_norm and delete all ops.

    :param reduce_op: Start operation of the pattern.
    :param block: Block
    :param gamma_var: Gamma variable.
    :param beta_var: Beta variable.
    :param epsilon_var: Epsilon variable.
    :param end_op: End operation of the pattern.
    :param ops_to_remove: Operations to remove.
    :return: True if the pattern was replaced by a normalization op, False if
        the block was left unchanged.
    """
    if not _check_no_output_connection(block, ops_to_remove):
        return False

    axes = reduce_op.axes.val
    gamma_val = gamma_var.val
    beta_val = beta_var.val
    # Axes, gamma and beta must be compile-time constants to be folded into a
    # normalization op; decline the fusion instead of crashing on None.
    if axes is None or gamma_val is None or beta_val is None:
        return False

    rank = len(reduce_op.x.shape)
    gamma_shape = np.shape(gamma_val)
    beta_shape = np.shape(beta_val)

    # check whether the pattern is instance_norm or layer_norm
    is_layernorm = False
    is_instancenorm = False
    is_require_rank4_transpose = False

    negative_axes = sorted(a - rank if a >= 0 else a for a in axes)

    if len(gamma_shape) == len(axes) and len(beta_shape) == len(axes):
        # axes for layer_norm must be [-1] or [-2, -1] or [-3, -2, -1] and so on
        if negative_axes == list(range(-len(negative_axes), 0)):
            is_layernorm = True

    if rank == 4 and negative_axes in ([-2, -1], [-3, -2]):
        # [-2, -1]: input is treated as NCHW, channel axis 1.
        # [-3, -2]: input is treated as NHWC, channel axis 3; it is transposed
        #           to NCHW before instance_norm and back afterwards.
        is_require_rank4_transpose = negative_axes == [-3, -2]
        channel_axis = 3 if is_require_rank4_transpose else 1
        # instance_norm scales/shifts per channel, so gamma/beta must vary only
        # along the channel axis when broadcast against the rank-4 input.
        # A gamma such as (C,) on NCHW (broadcasts along W) or (1, W) would
        # otherwise be mistaken for a per-channel parameter.
        if (
            len(np.squeeze(gamma_val).shape) == 1
            and len(np.squeeze(beta_val).shape) == 1
            and _is_per_channel(gamma_shape, channel_axis)
            and _is_per_channel(beta_shape, channel_axis)
        ):
            is_instancenorm = True

    if not (is_instancenorm or is_layernorm):
        return False

    # remove all the ops, and replace with a layer_norm or instance_norm op
    out_name = end_op.outputs[0].name

    x = reduce_op.x
    if is_require_rank4_transpose:
        x = mb.transpose(
            x=x,
            perm=[0, 3, 1, 2],
            name=out_name + "_transpose_nhwc_nchw",
            before_op=end_op,
        )
    if is_instancenorm:
        x = mb.instance_norm(
            x=x,
            gamma=np.squeeze(gamma_val),
            beta=np.squeeze(beta_val),
            epsilon=epsilon_var,
            name=(out_name + "_instancenorm") if is_require_rank4_transpose else out_name,
            before_op=end_op,
        )
    else:  # is_layernorm
        x = mb.layer_norm(
            x=x,
            axes=axes,
            gamma=gamma_var,
            beta=beta_var,
            epsilon=epsilon_var,
            name=(out_name + "_layernorm") if is_require_rank4_transpose else out_name,
            before_op=end_op,
        )
    if is_require_rank4_transpose:
        # Restore the original NHWC layout so downstream consumers see the
        # same shape as end_op's output.
        x = mb.transpose(
            x=x,
            perm=[0, 2, 3, 1],
            name=out_name + "_transpose_nchw_nhwc",
            before_op=end_op,
        )

    end_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=end_op, old_var=end_op.outputs[0], new_var=x)
    # Remove all the ops at once
    block.remove_ops(ops_to_remove)
    return True