Ejemplo n.º 1
0
 def extend(op: Node):
     """Normalize the ``equation`` attribute of an Einsum node to a single string.

     A list-valued equation is joined with commas and written back to the
     node; a string-valued equation is left untouched.

     :param op: node whose ``equation`` attribute is normalized in place
     :raises AssertionError: if ``equation`` is neither a list nor a string
     """
     einsum_name = op.soft_get('name', op.id)
     equation = op['equation']
     if isinstance(equation, list):
         op['equation'] = ','.join(equation)
     elif not isinstance(equation, str):
         # Raise explicitly rather than via ``assert False``: assert statements
         # are removed when Python runs with -O, which would let a malformed
         # equation pass silently. The exception type and message are unchanged.
         raise AssertionError(
             "Equation of Einsum node {} has incorrect format.".format(
                 einsum_name))
Ejemplo n.º 2
0
    def extend(op: Node):
        """Normalize the mask attributes of a StridedSlice node, then invert begin/end masks.

        For every name returned by ``StridedSlice.get_mask_names()``:

        * if the attribute is present and is not the empty string, it is passed
          to ``Extender.attr_to_list``;
        * otherwise it is set to ``int64_array([0])``, unless it is
          ``begin_mask`` or ``end_mask``, which are mandatory.

        Finally every element ``i`` of ``begin_mask`` and ``end_mask`` is
        replaced by ``1 - i``.
        NOTE(review): the inversion presumably converts between two mask
        conventions -- confirm against the StridedSlice specification.

        :param op: StridedSlice node updated in place
        :raises AssertionError: if ``begin_mask`` or ``end_mask`` is missing or empty
        """
        for attr in StridedSlice.get_mask_names():
            # We can not use op.has_and_set(attr) here as a condition, because it will return False if begin/end is
            # 1D tensor and begin_mask/end_mask is equal to 0
            if op.has(attr) and op[attr] != '':
                Extender.attr_to_list(op, attr)
                continue

            # Raise explicitly instead of using ``assert``: assertions are removed under
            # ``python -O``, which would silently default a missing mandatory mask to [0].
            if attr in ['begin_mask', 'end_mask']:
                raise AssertionError(
                    '{} is not defined for the node {}'.format(attr, op.soft_get('name', op.id)))
            op[attr] = int64_array([0])

        op.begin_mask = int64_array([1 - i for i in op.begin_mask])
        op.end_mask = int64_array([1 - i for i in op.end_mask])
Ejemplo n.º 3
0
def replace_with_hsigmoid(graph: Graph, first_node: Node, last_node: Node):
    """Insert an HSigmoid node in place of the chain bounded by two nodes.

    The new node reads from the producer feeding ``first_node``'s non-constant
    input, takes over the consumers of ``last_node``'s output port 0, and
    inherits ``last_node``'s name.

    :param graph: graph in which the HSigmoid node is created
    :param first_node: node whose input (port 0 or 1) supplies the data for HSigmoid
    :param last_node: node whose output port 0 connection is re-sourced to HSigmoid
    """
    # Capture the name now: last_node is renamed at the end of this function.
    final_name = last_node.soft_get('name', last_node.id)

    # If port 0 of first_node is fed by a Const, the data input sits on port 1;
    # otherwise it is port 0.
    port0_producer = first_node.in_port(0).get_connection().get_source().node
    data_port_idx = 1 if port0_producer.soft_get('op') == 'Const' else 0

    hsigmoid = HSigmoid(graph, {}).create_node()

    # Feed the new node from the same source that fed first_node.
    hsigmoid_input = hsigmoid.in_port(0)
    data_source = first_node.in_port(data_port_idx).get_source()
    hsigmoid_input.connect(data_source)

    # Redirect everything that consumed last_node's output to the new node.
    outgoing = last_node.out_port(0).get_connection()
    outgoing.set_source(hsigmoid.out_port(0))

    # Hand the original name over to HSigmoid; the old node gets a '/TBR' suffix.
    renames = [
        (last_node, final_name + '/TBR'),
        (hsigmoid, final_name),
    ]
    rename_nodes(renames)
Ejemplo n.º 4
0
 def attr_to_list(node: Node, attribute: str):
     """Wrap a node attribute in a list unless it already is one.

     If ``node.has_valid(attribute)`` is false, a warning is logged and the
     node is left unchanged.

     :param node: node whose attribute is updated in place
     :param attribute: name of the attribute to wrap
     """
     if node.has_valid(attribute):
         current = node[attribute]
         if not isinstance(current, list):
             node[attribute] = [current]
         return

     # Attribute is absent or not valid: report it, do not modify the node.
     message = 'Attribute {} missed in node {} with type {}!'.format(
         attribute, node.soft_get('name'), node.soft_get('type'))
     log.warning(message)