def build_nested_blocks(self):
    """Populate ``self.blocks`` with the true-branch and false-branch blocks.

    When ``self._existing_blocks`` carries a value, those blocks are adopted
    as-is and nothing is built. Otherwise one block is built per branch, in
    the order true then false, by calling the branch function inside a
    ``Block`` context named ``<self.name>_true`` / ``<self.name>_false``.
    """
    # Per the original note: with the milproto front end the blocks are
    # already well constructed, so they are assigned to self.blocks directly.
    existing = self._existing_blocks
    if existing is not None and existing.val is not None:
        # In this mode both branch functions are expected to be dummies that
        # return None when called with an empty list.
        assert self._true_fn.val([]) is None
        assert self._false_fn.val([]) is None
        self.blocks = existing.val
        return

    # (block-name suffix, attribute holding the branch function)
    branches = (("_true", "_true_fn"), ("_false", "_false_fn"))
    for suffix, fn_attr in branches:
        branch_block_name = self.name + suffix
        with Block(name=branch_block_name, outer_op=self) as branch_block:
            # The branch function is looked up and invoked inside the block
            # context, exactly once per branch.
            branch_outputs = getattr(self, fn_attr).val()
            # Normalise the return value to a list: tuples are converted,
            # any other non-list value is wrapped.
            if isinstance(branch_outputs, tuple):
                branch_outputs = list(branch_outputs)
            elif not isinstance(branch_outputs, list):
                branch_outputs = [branch_outputs]
            branch_block.set_outputs(branch_outputs)
            self.blocks.append(branch_block)
def _create_nested_blocks(context, op_spec): """ An utility function that creates nested blocks for control flow ops. """ if not op_spec.blocks: return [] blocks = [] for block_spec in op_spec.blocks: input_vars = [_create_var_from_spec(input) for input in block_spec.inputs] # add block input vars to the context for v in input_vars: context.register_var_with_name(v.name, v) # In pymil, the outer_op for a block can only be None if the block is a Functino. # As the result, we use a dummy outer_op here for block creation, and set it to # the legit op later on in _set_outer_op_for_nested_blocks dummy = mb.const(val=0.) with Block(block_inputs=input_vars, outer_op=dummy._op, name=Block._get_new_name()) as block: _load_block(context, block_spec) blocks.append(block) return blocks
def build_nested_blocks(self):
    """Build the single ``<self.name>_block`` block and append it to ``self.blocks``.

    The body function and then the condition function are both called on the
    block's inputs. The block outputs are the condition result(s) followed by
    the body's results.

    Raises:
        ValueError: if a value returned by the body is not a subtype of the
            corresponding entry of ``self.loop_vars``.
    """
    # self.loop_vars is a Python tuple of Vars.
    loop_block_name = self.name + "_block"
    with Block(
        block_inputs=self.loop_vars, outer_op=self, name=loop_block_name
    ) as loop_block:
        # Body first, then condition -- both see the same block inputs.
        exit_vars = self._body.val(*loop_block.inputs)
        cond_result = self._cond.val(*loop_block.inputs)
        if not isinstance(cond_result, list):
            cond_result = [cond_result]

        # Outputs are the condition value(s) followed by the body results.
        loop_block.set_outputs(cond_result + list(exit_vars))
        self.blocks.append(loop_block)

    # Each body result must keep a type compatible with its loop variable.
    for loop_var, exit_var in zip(self.loop_vars, exit_vars):
        if is_subtype(exit_var.sym_type, loop_var.sym_type):
            continue
        msg = ("loop_vars '{}' changes in the body of "
               "while_loop '{}':\n {} -> {}")
        raise ValueError(
            msg.format(
                loop_var.name, self.name, loop_var.sym_type, exit_var.sym_type
            )
        )
def _build_block(self, block_inputs):
    """Build separate condition and body blocks over the same ``block_inputs``.

    Returns:
        A ``(cond_block, body_block, exit_vars)`` tuple, where ``exit_vars``
        is the body function's result normalised to a list.
    """
    # Condition block: its outputs are the condition function's result,
    # wrapped in a list unless it already is one.
    cond_block_name = self.name + '_cond_block'
    with Block(
        block_inputs=block_inputs, outer_op=self, name=cond_block_name
    ) as cond_block:
        cond_result = self._cond.val(*cond_block.inputs)
        if not isinstance(cond_result, list):
            cond_result = [cond_result]
        cond_block.set_outputs(cond_result)

    # Body block: list/tuple results are copied into a list, anything else
    # becomes a one-element list.
    body_block_name = self.name + '_body_block'
    with Block(
        block_inputs=block_inputs, outer_op=self, name=body_block_name
    ) as body_block:
        exit_vars = self._body.val(*body_block.inputs)
        if isinstance(exit_vars, (list, tuple)):
            exit_vars = list(exit_vars)
        else:
            exit_vars = [exit_vars]
        body_block.set_outputs(exit_vars)

    return cond_block, body_block, exit_vars
def build_nested_blocks(self):
    """Build the true-branch and false-branch blocks and append them to ``self.blocks``.

    One block is built per branch, in the order true then false, by calling
    the branch function inside a ``Block`` context named
    ``<self.name>_true`` / ``<self.name>_false``.
    """
    # (block-name suffix, attribute holding the branch function)
    branches = (("_true", "_true_fn"), ("_false", "_false_fn"))
    for suffix, fn_attr in branches:
        branch_block_name = self.name + suffix
        with Block(name=branch_block_name, outer_op=self) as branch_block:
            # The branch function is looked up and invoked inside the block
            # context, exactly once per branch.
            branch_outputs = getattr(self, fn_attr).val()
            # Normalise the return value to a list: tuples are converted,
            # any other non-list value is wrapped.
            if isinstance(branch_outputs, tuple):
                branch_outputs = list(branch_outputs)
            elif not isinstance(branch_outputs, list):
                branch_outputs = [branch_outputs]
            branch_block.set_outputs(branch_outputs)
            self.blocks.append(branch_block)
def build_block(self, block_inputs):
    """Build the single ``<self.name>_block`` block over ``block_inputs``.

    The body function and then the condition function are both called on the
    block's inputs. The block outputs are the condition result(s) followed by
    the body's results.

    Returns:
        A ``(block, exit_vars)`` tuple, where ``exit_vars`` is the body
        function's return value exactly as returned (not converted).
    """
    loop_block_name = self.name + '_block'
    with Block(
        block_inputs=block_inputs, outer_op=self, name=loop_block_name
    ) as loop_block:
        # Body first, then condition -- both see the same block inputs.
        exit_vars = self._body.val(*loop_block.inputs)
        cond_result = self._cond.val(*loop_block.inputs)
        if not isinstance(cond_result, list):
            cond_result = [cond_result]

        # Outputs are the condition value(s) followed by the body results.
        loop_block.set_outputs(cond_result + list(exit_vars))

    return loop_block, exit_vars
def _try_apply_transform(
    reduce_op: Operation,
    block: Block,
    gamma_var: Var,
    beta_var: Var,
    epsilon_var: Var,
    end_op: Operation,
    ops_to_remove: List[Operation],
) -> bool:
    """
    Replace a matched op pattern with a single instance_norm / layer_norm op.

    :param reduce_op: Start operation of the pattern.
    :param block: Block
    :param gamma_var: Gamma variable.
    :param beta_var: Beta variable.
    :param epsilon_var: Epsilon variable.
    :param end_op: End operation of the pattern.
    :param ops_to_remove: Operations to remove.
    :return: True if the pattern was replaced and ``ops_to_remove`` were
        removed from ``block``; False if the output-connection check failed
        or the pattern is neither layer_norm nor instance_norm.
    """
    if not _check_no_output_connection(block, ops_to_remove):
        return False

    axes = reduce_op.axes.val
    rank = len(reduce_op.x.shape)

    # Express every axis in negative form and sort, so that the axis sets
    # can be compared against fixed lists below.
    negative_axes = sorted(a if a < 0 else a - rank for a in axes)

    # layer_norm: gamma and beta must have as many dimensions as there are
    # axes, and the axes must be [-1] or [-2, -1] or [-3, -2, -1] and so on.
    num_axes = len(axes)
    is_layernorm = (
        len(gamma_var.val.shape) == num_axes
        and len(beta_var.val.shape) == num_axes
        and negative_axes == list(range(-len(negative_axes), 0))
    )

    # instance_norm: rank-4 input reduced over [-2, -1] or over [-3, -2],
    # with gamma and beta that are 1-D once squeezed. The [-3, -2] case is
    # handled by transposing around the norm op.
    is_instancenorm = False
    is_require_rank4_transpose = False
    if rank == 4 and negative_axes in ([-2, -1], [-3, -2]):
        if (
            len(np.squeeze(gamma_var.val).shape) == 1
            and len(np.squeeze(beta_var.val).shape) == 1
        ):
            is_instancenorm = True
            is_require_rank4_transpose = negative_axes == [-3, -2]

    if not (is_instancenorm or is_layernorm):
        return False

    # Insert the replacement op(s) before end_op; the final op takes over
    # the original output name unless transposes are wrapped around it.
    out_name = end_op.outputs[0].name

    if is_require_rank4_transpose:
        norm_input = mb.transpose(
            x=reduce_op.x,
            perm=[0, 3, 1, 2],
            name=out_name + "_transpose_nhwc_nchw",
            before_op=end_op,
        )
    else:
        norm_input = reduce_op.x

    # instance_norm takes precedence when both patterns match.
    if is_instancenorm:
        norm_name = (
            out_name + "_instancenorm" if is_require_rank4_transpose else out_name
        )
        x = mb.instance_norm(
            x=norm_input,
            gamma=np.squeeze(gamma_var.val),
            beta=np.squeeze(beta_var.val),
            epsilon=epsilon_var,
            name=norm_name,
            before_op=end_op,
        )
    else:  # is_layernorm
        # NOTE: the transpose flag is only ever set together with
        # is_instancenorm, so in this branch the name is always out_name.
        norm_name = (
            out_name + "_layernorm" if is_require_rank4_transpose else out_name
        )
        x = mb.layer_norm(
            x=norm_input,
            axes=axes,
            gamma=gamma_var,
            beta=beta_var,
            epsilon=epsilon_var,
            name=norm_name,
            before_op=end_op,
        )

    if is_require_rank4_transpose:
        # Undo the earlier transpose so the result has the original layout.
        x = mb.transpose(
            x=x,
            perm=[0, 2, 3, 1],
            name=out_name + "_transpose_nchw_nhwc",
            before_op=end_op,
        )

    # Redirect the uses of the old pattern output (after end_op) to the new var.
    end_op.enclosing_block.replace_uses_of_var_after_op(
        anchor_op=end_op, old_var=end_op.outputs[0], new_var=x
    )

    # Remove all the ops at once
    block.remove_ops(ops_to_remove)
    return True