def visit_exop(self, exop, tensor_input_decl, value_input_decl):
    """Replace ``exop`` with a new ``WriteOp`` exop for the tensor it targets.

    The tensor is found through
    ``tensor_input_decl.source_output_decl.tensor_decl.source_tensor``.  The
    exop currently tracked for it (see ``self.current_exop``) and the source
    of ``value_input_decl`` become the inputs of the new exop, which is then
    recorded in ``self.tensor_map`` as the tensor's current exop.

    Arguments:
        exop: The exop being visited; it is replaced in ``self.exop_block``.
        tensor_input_decl: Input decl identifying the tensor being written.
        value_input_decl: Input decl whose source output supplies the value.
    """
    assigned_tensor = tensor_input_decl.source_output_decl.tensor_decl.source_tensor
    latest_exop = self.current_exop(exop, assigned_tensor)

    writer = ExOp(computation_decl=self.computation_decl,
                  op=WriteOp(axes=latest_exop.op.axes))
    # Tie the new exop's first output back to the tensor it stands for.
    writer.output_decls[0].tensor_decl.source_tensor = assigned_tensor

    # The order of the following calls is kept exactly as written; the
    # write-arg / input-decl registrations are interleaved on purpose.
    # NOTE(review): the first output decl is registered as a write arg twice,
    # once bare and once with the view's tensor description -- confirm that
    # both registrations are required.
    writer.add_write_arg(writer.output_decls[0])
    writer.add_input_decl(latest_exop.output_decls[0])
    writer.add_write_arg(
        writer.output_decls[0],
        tensor_input_decl.tensor_view_decl.tensor_description)
    writer.add_input_decl(value_input_decl.source_output_decl)

    self.exop_block.replace_exop(exop, writer)
    # From here on the new write exop is the current one for this tensor.
    self.tensor_map[assigned_tensor] = writer
def end_pass(self, **kwargs):
    """Finish the pass and add copy exops for rewritten output tensors.

    After delegating to the parent ``end_pass``, every entry of
    ``self.tensor_map`` is examined.  An entry is skipped when its current
    exop's first output still refers to the tensor decl itself, or when the
    tensor decl is not flagged ``is_output``.  For each remaining entry a
    ``WriteOp`` exop is built that takes the current exop's first output as
    input and the first output of the tensor decl's own ``exop`` as write
    arg, and that exop is added to ``self.exop_block``.

    Arguments:
        **kwargs: Forwarded unchanged to the parent ``end_pass``.
    """
    super(SSAConversion, self).end_pass(**kwargs)

    for tensor_decl, latest_exop in iteritems(self.tensor_map):
        # Short-circuit order matters: ``is_output`` is only consulted when
        # the current exop's output is no longer the tensor decl itself.
        unchanged = latest_exop.output_decls[0].tensor_decl is tensor_decl
        if unchanged or not tensor_decl.is_output:
            continue

        copier = ExOp(computation_decl=self.computation_decl,
                      create_value=False,
                      op=WriteOp(axes=[]))
        copier.add_write_arg(tensor_decl.exop.output_decls[0])
        copier.add_input_decl(latest_exop.output_decls[0])
        # No position argument is given; placement is left to ``add_exop``.
        self.exop_block.add_exop(copier)
def current_exop(self, exop, source_tensor):
    """Return the exop currently tracked for ``source_tensor``.

    When ``self.tensor_map`` already holds a non-``None`` entry for the
    tensor, that entry is returned untouched.  Otherwise a ``ReadOp`` exop is
    created for it, added to ``self.exop_block``, stored in
    ``self.tensor_map`` and returned.

    Arguments:
        exop: The exop being visited; supplies the ref op and the
            ``prev_exop`` passed to ``add_exop``.
        source_tensor: The tensor whose current exop is wanted.

    Returns:
        The cached or newly created exop for ``source_tensor``.
    """
    known = self.tensor_map.get(source_tensor)
    if known is not None:
        return known

    reader = ExOp(
        computation_decl=self.computation_decl,
        create_value=False,
        op=ReadOp(axes=source_tensor.tensor_description_base.axes))
    reader.add_ref_op(exop.op)
    reader.add_output_decl(source_tensor)
    # NOTE(review): the tensor's ``exop`` is set to the visited exop, not to
    # the read exop created above -- confirm this is intended.
    source_tensor.exop = exop
    # ``exop.prev_exop`` is passed as the position argument; presumably this
    # places the read ahead of the visited exop -- verify against add_exop.
    self.exop_block.add_exop(reader, exop.prev_exop)
    self.tensor_map[source_tensor] = reader
    return reader