Ejemplo n.º 1
0
 def end_pass(self, **kwargs):
     """Finish the pass by writing renamed values back into output tensors.

     The parent ``end_pass`` is run first, with the same keyword arguments.
     Then each ``(tensor decl, latest exop)`` entry of ``self.tensor_map`` is
     examined. A write-back is needed only when both of these hold:

     * the latest exop's first output lives in a tensor other than the
       mapped tensor decl, i.e. the value was moved to a new tensor; and
     * the mapped tensor decl is flagged ``is_output``.

     For each such entry, a ``WriteOp`` exop is appended to
     ``self.exop_block``. It writes the latest value into the first output
     of the tensor decl's own ``exop``.
     """
     super(SSAConversion, self).end_pass(**kwargs)
     for original_decl, latest_exop in iteritems(self.tensor_map):
         latest_output = latest_exop.output_decls[0]
         # Short-circuit keeps the original check order: ``is_output`` is only
         # consulted when the value actually lives in a different tensor.
         if (latest_output.tensor_decl is original_decl
                 or not original_decl.is_output):
             continue
         # create_value=False: presumably so this copy does not allocate a
         # tensor of its own and only targets the write arg below - TODO confirm.
         write_back = ExOp(computation_decl=self.computation_decl,
                           create_value=False,
                           op=WriteOp(axes=[]))
         write_back.add_write_arg(original_decl.exop.output_decls[0])
         write_back.add_input_decl(latest_output)
         self.exop_block.add_exop(write_back)
Ejemplo n.º 2
0
 def visit_exop(self, exop, tensor_input_decl, value_input_decl):
     """Replace a write into a tensor with a write into a fresh tensor.

     ``tensor_input_decl`` identifies the tensor being written (through a
     view). ``value_input_decl`` supplies the value being written. The root
     tensor is found via ``.tensor_decl.source_tensor``, and the exop
     currently holding its value is obtained from ``self.current_exop``.

     A new ``WriteOp`` exop is built with that exop's axes. It does two
     things:

     1. copies the current value into its own (new) output; and
     2. writes the incoming value into that output through the view's
        tensor description.

     The new exop then takes ``exop``'s place in ``self.exop_block``. It is
     also recorded in ``self.tensor_map`` as the latest exop for the root
     tensor.
     """
     target_root = (tensor_input_decl.source_output_decl
                    .tensor_decl.source_tensor)
     latest = self.current_exop(exop, target_root)

     ssa_write = ExOp(computation_decl=self.computation_decl,
                      op=WriteOp(axes=latest.op.axes))
     fresh_output = ssa_write.output_decls[0]
     # Tag the new tensor with the same root, so that a later visit reading
     # from it (via ``.source_tensor``, as above) resolves to ``target_root``.
     fresh_output.tensor_decl.source_tensor = target_root

     # Write arg / input pair #1: no tensor description is given (presumably
     # the whole tensor), fed by the current value -- i.e. a copy.
     ssa_write.add_write_arg(fresh_output)
     ssa_write.add_input_decl(latest.output_decls[0])

     # Write arg / input pair #2: write the incoming value through the view's
     # description. Write args appear to pair positionally with input decls,
     # so this ordering must not change.
     view_description = tensor_input_decl.tensor_view_decl.tensor_description
     ssa_write.add_write_arg(fresh_output, view_description)
     ssa_write.add_input_decl(value_input_decl.source_output_decl)

     self.exop_block.replace_exop(exop, ssa_write)
     self.tensor_map[target_root] = ssa_write