Example #1
def arg_ops_infer(node: Node):
    shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    assert shape is not None, "Input shape for the node {} is None".format(node_name)

    # there are two inputs in TensorFlow. The second input is the axis for ArgMax
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_in_ports) == 2:
        axis = node.in_port(1).data.get_value()
        if axis is None:
            log.debug('The second argument to {} is None'.format(node_name))
            return
        node.axis = axis
        # remove the unnecessary input
        node.in_port(1).disconnect()

    num_top_axes = shape.size
    if num_top_axes < 3:
        num_top_axes = 3

    out_shape = np.ones(num_top_axes, dtype=np.int64)

    if node.has_valid('axis'):
        axis = get_canonical_axis_index(shape, node.axis)
        node.axis = axis
        out_shape = shape.copy()
        out_shape[axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape[0] = shape[0]
        out_shape[2] = node.top_k
        if node.has_and_set('out_max_val'):
            out_shape[1] = 2

    node.out_port(0).data.set_shape(out_shape)
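
Most of these snippets first map a possibly negative axis onto a non-negative index via get_canonical_axis_index. A minimal sketch of that convention, assuming the usual len(shape) + axis rule; the helper below is illustrative, not the Model Optimizer source:

def canonical_axis_index_sketch(shape, axis):
    # Map a negative axis to the equivalent non-negative index,
    # e.g. axis=-1 for a rank-4 shape becomes 3.
    rank = len(shape)
    assert -rank <= axis < rank, 'axis {} is out of range for rank {}'.format(axis, rank)
    return axis + rank if axis < 0 else axis

# canonical_axis_index_sketch([1, 3, 224, 224], -1) == 3
# canonical_axis_index_sketch([1, 3, 224, 224], 1) == 1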
Example #2
    def regionyolo_infer(node: Node):
        input_shape = node.in_port(0).data.get_shape()
        axis = get_canonical_axis_index(input_shape, node.axis)
        end_axis = get_canonical_axis_index(input_shape, node.end_axis)
        node.axis = axis
        node.end_axis = end_axis
        if node.do_softmax:
            dims_to_flatten = input_shape[axis:end_axis + 1]
            if is_fully_defined(dims_to_flatten):
                flat_dim = np.ma.prod(dims_to_flatten)
            else:
                flat_dim = dynamic_dimension_value
            node.out_port(0).data.set_shape(
                [*input_shape[:axis], flat_dim, *input_shape[end_axis + 1:]])
        else:
            layout = node.graph.graph['layout']
            assert len(layout) == 4

            node.out_port(0).data.set_shape(
                shape_for_layout(layout,
                                 batch=input_shape[get_batch_dim(layout, 4)],
                                 features=(node.classes + node.coords + 1) *
                                 len(node.mask),
                                 height=input_shape[get_height_dim(layout, 4)],
                                 width=input_shape[get_width_dim(layout, 4)]))
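
When do_softmax is set, the dimensions from axis to end_axis inclusive are collapsed into a single one. A rough numeric illustration of that shape computation with plain NumPy; the example shape values are assumptions, not taken from the snippet:

import numpy as np

input_shape = np.array([1, 255, 26, 26], dtype=np.int64)
axis, end_axis = 1, 3  # already canonicalized
flat_dim = np.prod(input_shape[axis:end_axis + 1])  # 255 * 26 * 26 = 172380
out_shape = [*input_shape[:axis], flat_dim, *input_shape[end_axis + 1:]]
# out_shape == [1, 172380]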
Example #3
    def test_lift_down_through_transpose_negative_axis(self):
        graph = build_graph(nodes3, [
            *connect('placeholder', 'reverse_channels_up'),
            *connect('transpose_order', '1:transpose'),
            *connect('reverse_channels_up', '0:transpose'),
            *connect('transpose', 'result')
        ])
        graph_ref = build_graph(nodes3, [
            *connect('placeholder', '0:transpose'),
            *connect('transpose_order', '1:transpose'),
            *connect('transpose', 'reverse_channels_up'),
            *connect('reverse_channels_up', '0:result')
        ])
        self.set_graph_attrs(graph, ['placeholder'])

        node = Node(graph, 'transpose')
        reverse_channels = Node(graph, 'reverse_channels_up')
        reverse_channels.axis = int64_array(-1)

        keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(
            node, reverse_channels)

        self.assertTrue(keep_moving_down is True)
        self.check_graph_attrs(graph, ['placeholder'])
        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # check the node the pass actually moved: its negative axis should be canonicalized
        reverse_channels = Node(graph, 'reverse_channels_up')
        self.assertTrue(reverse_channels.axis == 1)
        self.assertTrue(type(reverse_channels.axis) == np.ndarray)
Example #4
    def infer(node: Node):
        assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 1, \
            'LogSoftmax node with id {} has more than one input port connected'.format(node.id)
        if node.axis < 0:
            node.axis = len(node.in_port(0).data.get_shape()) + node.axis
        assert 0 <= node.axis < len(node.in_port(0).data.get_shape()), \
            'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)
        copy_shape_infer(node)
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
Example #5
    def lift_up_through_transpose(node: Node, reverse_channels: Node):
        if node.in_port(1).disconnected() or node.in_port(0).disconnected():
            return False
        order = node.in_port(1).data.get_value()
        reverse_axis = reverse_channels.axis

        data_rank = len(list(node.in_port(0).data.get_shape()))

        if reverse_axis < 0:
            reverse_axis = data_rank + reverse_axis
        assert 0 < reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)

        if order is None:
            return False
        new_axis = order[reverse_axis]
        reverse_channels.axis = int64_array(new_axis)
        return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
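
The new_axis = order[reverse_axis] step translates the channel axis from the Transpose output layout back to its input layout. A small illustration with a hypothetical NCHW-to-NHWC permutation:

order = [0, 2, 3, 1]    # Transpose order producing NHWC output from NCHW input
reverse_axis = 3        # channels dimension in the NHWC output
new_axis = order[reverse_axis]
# new_axis == 1: the channels dimension of the NCHW input,
# so the lifted ReverseChannels node gets axis = 1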
Example #6
    def infer(node: Node):
        name = node.soft_get('name', node.id)

        connected_in_ports = {
            idx: port
            for idx, port in node.in_ports().items()
            if not port.disconnected()
        }
        assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
            "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
            "".format(name, connected_in_ports)

        axis = node.soft_get('axis', None)
        assert axis is not None

        data_shape = node.in_port(0).data.get_shape()
        assert data_shape is not None
        indices_shape = node.in_port(1).data.get_shape()
        assert indices_shape is not None

        # Convert negative axis
        axis = get_canonical_axis_index(data_shape, axis)
        node.axis = axis

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

        data_value = node.in_port(0).data.get_value()
        indices_value = node.in_port(1).data.get_value()
        if data_value is not None and indices_value is not None:
            node.out_port(0).data.set_value(
                mo_array(np.take(data_value, indices_value, axis),
                         dtype=data_value.dtype))
            return

        shape = np.concatenate((data_shape[:axis], indices_shape))
        if axis < len(data_shape) - 1:
            shape = np.concatenate((shape, data_shape[axis + 1:]))

        node.out_port(0).data.set_shape(int64_array(shape))
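
The resulting shape is data_shape[:axis] + indices_shape + data_shape[axis + 1:], i.e. the gathered dimension is replaced by the whole indices shape. A tiny sanity check with plain NumPy; the shapes are illustrative only:

import numpy as np

data_shape = np.array([4, 10, 7], dtype=np.int64)
indices_shape = np.array([2, 3], dtype=np.int64)
axis = 1

out_shape = np.concatenate((data_shape[:axis], indices_shape, data_shape[axis + 1:]))
# out_shape == [4, 2, 3, 7], the same as np.take(data, indices, axis=1).shape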
Example #7
    def _two_inputs_infer(node: Node):
        N = len(node.in_nodes())
        node_name = node.soft_get('name', node.id)

        shapes = [node.in_port(i).data.get_shape() for i in range(N)]
        if any(s is None for s in shapes):
            raise Error('Not all input shapes were defined for {} node'.format(node_name))

        if not node.has_valid('axis'):
            raise Error('axis attribute is missing for {} node; it should be set in the Crop extractor'.format(node_name))

        if not node.has_valid('offset'):
            raise Error('offset attribute is missing for {} node; it should be set in the Crop extractor'.format(node_name))

        input_shape = shapes[0].copy()
        start_axis = get_canonical_axis_index(input_shape, node.axis)
        node.axis = start_axis

        reference_shape = shapes[1].copy()
        if node.has_valid('axes'):
            # The axes parameter contains shape indexes of the second input and shows which shape indexes
            # we need to use for the dim attribute.
            input_dim = node.axes
            node.in_port(1).disconnect()
        else:
            input_dim = list(range(0, input_shape.size))

        # set new shape to current shape
        new_shape = input_shape.copy()
        ir_axis = []
        ir_offset = []
        dim = []

        for i in input_dim:
            if i < start_axis:
                new_shape[i] = input_shape[i]
                continue

            crop_offset = 0
            if len(node.offset) == 1:
                crop_offset = node.offset[0]
            elif len(node.offset) > 1:
                crop_offset = node.offset[i - start_axis]

            if input_shape[i] - crop_offset < reference_shape[i]:
                raise Error('The crop for dimension {} is out of bounds in node {}'.format(i, node_name))

            dim.append(reference_shape[i])
            ir_axis.append(i)
            ir_offset.append(crop_offset)
            new_shape[i] = reference_shape[i]

        node.axis = ir_axis
        node.offset = ir_offset
        node['dim'] = dim
        node.out_port(0).data.set_shape(new_shape)

        if node.in_node(0).has_valid('value') and \
                not getattr(node.graph.graph['cmd_params'], 'enable_ssd_gluoncv', False):
            out_value = np.copy(node.in_node(0).value)

            slice_indexes = []
            for s in out_value.shape:
                slice_indexes.append(slice(0, s))

            for axis in input_dim:
                slice_indexes[axis] = slice(0, new_shape[axis])
                out_value = out_value[tuple(slice_indexes)]
            node.out_port(0).data.set_value(out_value)

        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
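
For each cropped dimension the output extent comes from the reference (second) input and the offset from the offset attribute, shared across axes when only one value is given. A rough, self-contained re-statement of the loop above with hypothetical Caffe-style crop shapes:

input_shape = [1, 3, 227, 227]
reference_shape = [1, 3, 224, 224]
start_axis, offset = 2, [1]

new_shape, ir_axis, ir_offset, dim = list(input_shape), [], [], []
for i in range(len(input_shape)):
    if i < start_axis:
        continue  # axes before start_axis keep the input extent
    crop_offset = offset[0] if len(offset) == 1 else offset[i - start_axis]
    assert input_shape[i] - crop_offset >= reference_shape[i], 'crop out of bounds'
    ir_axis.append(i)
    ir_offset.append(crop_offset)
    dim.append(reference_shape[i])
    new_shape[i] = reference_shape[i]
# new_shape == [1, 3, 224, 224], ir_axis == [2, 3], ir_offset == [1, 1], dim == [224, 224]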