def arg_ops_infer(node: Node):
    """Infer the output shape of an ArgMax-like node.

    If a second input port is connected (the TensorFlow form), its constant value is
    moved into the ``axis`` attribute and that port is disconnected. If the value is
    not known yet, a debug message is logged and the function returns without setting
    an output shape.

    With a valid ``axis`` the output shape is the input shape with the (canonical)
    ``axis`` dimension replaced by ``top_k``. Without one, the output has
    ``max(input_rank, 3)`` dimensions filled with ones, except: dim 0 is copied from
    the input, dim 2 is ``top_k`` and, when ``out_max_val`` is set, dim 1 is ``2``.
    """
    shape = node.in_port(0).data.get_shape()
    node_name = node.soft_get('name', node.id)
    assert shape is not None, "Input shape for the node {} is None".format(node_name)

    # TensorFlow supplies the axis as a second input instead of an attribute.
    connected_count = sum(1 for port in node.in_ports().values() if not port.disconnected())
    if connected_count == 2:
        axis_value = node.in_port(1).data.get_value()
        if axis_value is None:
            log.debug('The second argument to {} is None'.format(node_name))
            return
        node.axis = axis_value
        # the axis now lives in the attribute, so the extra input is dropped
        node.in_port(1).disconnect()

    if node.has_valid('axis'):
        canonical_axis = get_canonical_axis_index(shape, node.axis)
        node.axis = canonical_axis
        out_shape = shape.copy()
        out_shape[canonical_axis] = node.top_k
        PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
    else:
        out_shape = np.ones(max(shape.size, 3), dtype=np.int64)
        out_shape[0] = shape[0]
        out_shape[2] = node.top_k
        if node.has_and_set('out_max_val'):
            out_shape[1] = 2

    node.out_port(0).data.set_shape(out_shape)
def regionyolo_infer(node: Node):
    """Infer the output shape of a RegionYolo node.

    ``axis`` and ``end_axis`` are normalized with ``get_canonical_axis_index`` and
    written back to the node.

    * ``do_softmax`` set: dimensions ``axis..end_axis`` (inclusive) are collapsed into a
      single dimension whose size is their product, or ``dynamic_dimension_value`` if
      any of them is not fully defined.
    * otherwise: the graph layout must have 4 letters; batch, height and width are read
      from the input shape and features is ``(classes + coords + 1) * len(mask)``.
    """
    in_shape = node.in_port(0).data.get_shape()
    first = get_canonical_axis_index(in_shape, node.axis)
    last = get_canonical_axis_index(in_shape, node.end_axis)
    node.axis, node.end_axis = first, last

    if not node.do_softmax:
        layout = node.graph.graph['layout']
        assert len(layout) == 4
        out_shape = shape_for_layout(layout,
                                     batch=in_shape[get_batch_dim(layout, 4)],
                                     features=(node.classes + node.coords + 1) * len(node.mask),
                                     height=in_shape[get_height_dim(layout, 4)],
                                     width=in_shape[get_width_dim(layout, 4)])
        node.out_port(0).data.set_shape(out_shape)
        return

    collapsed = in_shape[first:last + 1]
    flat_dim = np.ma.prod(collapsed) if is_fully_defined(collapsed) else dynamic_dimension_value
    node.out_port(0).data.set_shape([*in_shape[:first], flat_dim, *in_shape[last + 1:]])
def test_lift_down_through_transpose_negative_axis(self):
    """ReverseChannels with a negative axis is pushed below a Transpose."""
    # before: placeholder -> reverse_channels_up -> transpose -> result
    edges = [
        *connect('placeholder', 'reverse_channels_up'),
        *connect('transpose_order', '1:transpose'),
        *connect('reverse_channels_up', '0:transpose'),
        *connect('transpose', 'result'),
    ]
    graph = build_graph(nodes3, edges)

    # expected: placeholder -> transpose -> reverse_channels_up -> result
    ref_edges = [
        *connect('placeholder', '0:transpose'),
        *connect('transpose_order', '1:transpose'),
        *connect('transpose', 'reverse_channels_up'),
        *connect('reverse_channels_up', '0:result'),
    ]
    graph_ref = build_graph(nodes3, ref_edges)

    self.set_graph_attrs(graph, ['placeholder'])

    transpose = Node(graph, 'transpose')
    rc_node = Node(graph, 'reverse_channels_up')
    # a negative axis exercises the normalization against the data rank
    rc_node.axis = int64_array(-1)

    keep_moving_down = ReverseChannelsPropagationDown.pass_rc_through_transpose(transpose, rc_node)
    self.assertIs(keep_moving_down, True)

    self.check_graph_attrs(graph, ['placeholder'])
    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)

    # NOTE(review): this reads 'reverse_channels_down', but the node the pass moved is
    # 'reverse_channels_up' and only that one is wired into `graph` above. Confirm against
    # `nodes3` that checking the other node is intended and not a copy-paste slip.
    rc_down = Node(graph, 'reverse_channels_down')
    self.assertTrue(rc_down.axis == 1)
    self.assertIs(type(rc_down.axis), np.ndarray)
def infer(node: Node):
    """Infer a LogSoftmax node: normalize ``axis`` and copy the input shape to the output.

    A negative ``axis`` is converted to ``rank + axis`` and written back to the node.
    Permutation attributes are registered so ``axis`` follows layout changes of input 0.

    Raises:
        AssertionError: if the number of connected input ports is not exactly one, or if
            the axis lies outside ``[0, rank)`` after normalization.
    """
    connected = [port for port in node.in_ports().values() if not port.disconnected()]
    # The check fails for zero ports as well as for two or more, so the message says
    # "exactly one" rather than "more than one".
    assert len(connected) == 1, \
        'LogSoftmax node with id {} must have exactly one input port connected, but has {}'.format(
            node.id, len(connected))

    rank = len(node.in_port(0).data.get_shape())
    if node.axis < 0:
        node.axis = rank + node.axis
    assert 0 <= node.axis < rank, \
        'LogSoftmax node with id {} has wrong axis attribute'.format(node.id)

    copy_shape_infer(node)
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
def lift_up_through_transpose(node: Node, reverse_channels: Node):
    """Move a ReverseChannels node above the Transpose ``node`` that feeds it.

    The ReverseChannels axis refers to the Transpose output. It is remapped to the
    Transpose input as ``order[axis]`` and stored back as an int64 array. A negative
    axis is first normalized against the rank of the Transpose data input.

    Returns:
        False if either Transpose input is disconnected or the permutation order is not
        a known constant; otherwise the result of
        ``ReverseChannelsPropagationUp.lift_up_through_zero_port_only``.

    Raises:
        AssertionError: if the normalized ReverseChannels axis is outside ``[0, rank)``.
    """
    if node.in_port(1).disconnected() or node.in_port(0).disconnected():
        return False
    order = node.in_port(1).data.get_value()
    reverse_axis = reverse_channels.axis

    data_rank = len(node.in_port(0).data.get_shape())

    if reverse_axis < 0:
        reverse_axis = data_rank + reverse_axis
    # 0 is a valid dimension index, so the lower bound is inclusive
    assert 0 <= reverse_axis < data_rank, "Incorrect ReverseChannels axis in node {}.".format(reverse_channels)

    if order is None:
        return False
    new_axis = order[reverse_axis]
    reverse_channels.axis = int64_array(new_axis)
    return ReverseChannelsPropagationUp.lift_up_through_zero_port_only(node, reverse_channels)
def infer(node: Node):
    """Infer an AttributedGather node (a gather whose axis is a node attribute).

    Exactly input ports 0 (data) and 1 (indices) must be connected, ``axis`` must be set
    and both input shapes must be known. The axis is normalized with
    ``get_canonical_axis_index`` and written back to the node.

    If both input values are known, the output value is
    ``np.take(data, indices, axis)`` cast to the data dtype. Otherwise the output shape
    is ``data_shape[:axis] + indices_shape + data_shape[axis + 1:]``.
    """
    name = node.soft_get('name', node.id)

    connected_in_ports = {}
    for idx, port in node.in_ports().items():
        if not port.disconnected():
            connected_in_ports[idx] = port
    assert len(connected_in_ports) == 2 and 0 in connected_in_ports and 1 in connected_in_ports, \
        "AttributedGather should have 2 connected input port, but it doesn't for node: `{}`. Ports: {}" \
        "".format(name, connected_in_ports)

    axis = node.soft_get('axis', None)
    assert axis is not None

    data_shape = node.in_port(0).data.get_shape()
    assert data_shape is not None
    indices_shape = node.in_port(1).data.get_shape()
    assert indices_shape is not None

    # map a negative axis onto a non-negative index into data_shape
    axis = get_canonical_axis_index(data_shape, axis)
    node.axis = axis
    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

    data_value = node.in_port(0).data.get_value()
    indices_value = node.in_port(1).data.get_value()
    if data_value is None or indices_value is None:
        # The trailing slice is empty when axis is the last dimension.
        # NOTE(review): np.concatenate does not preserve masks of masked arrays, and
        # int64_array presumably discards them too, so dynamic dimensions may be lost
        # here — confirm whether np.ma.concatenate should be used.
        out_shape = np.concatenate((data_shape[:axis], indices_shape, data_shape[axis + 1:]))
        node.out_port(0).data.set_shape(int64_array(out_shape))
        return

    gathered = np.take(data_value, indices_value, axis)
    node.out_port(0).data.set_value(mo_array(gathered, dtype=data_value.dtype))
def _two_inputs_infer(node: Node):
    """Infer a two-input Crop: crop input 0 to the shape of input 1.

    The dimensions considered are ``node.axes`` if it is set (input 1 is then
    disconnected), otherwise all dimensions of input 0. Considered dimensions before the
    canonical ``node.axis`` are kept. Every other considered dimension ``i`` is cropped
    to ``reference_shape[i]``, starting at an offset taken from ``node.offset``: either
    one value shared by all cropped dimensions, or one value per dimension counted from
    ``axis``.

    Afterwards ``axis`` and ``offset`` are rewritten as per-cropped-dimension lists and
    ``dim`` holds the cropped sizes. When input 0 has a constant value and
    ``enable_ssd_gluoncv`` is not set, the cropped value is propagated too.

    Raises:
        Error: if an input shape is unknown, ``axis``/``offset`` is missing, or a crop
            window does not fit inside the input.
    """
    N = len(node.in_nodes())
    node_name = node.soft_get('name', node.id)

    shapes = [node.in_port(i).data.get_shape() for i in range(N)]
    if any(s is None for s in shapes):
        raise Error('Not all input shapes were defined for {} node'.format(node_name))

    if not node.has_valid('axis'):
        raise Error('axis attribute is missing for {} node. should be set in crop extractor'.format(node_name))

    if not node.has_valid('offset'):
        raise Error('offset attribute is missing for {} node. should be set in crop extractor'.format(node_name))

    input_shape = shapes[0].copy()
    start_axis = get_canonical_axis_index(input_shape, node.axis)
    node.axis = start_axis

    reference_shape = shapes[1].copy()
    if node.has_valid('axes'):
        # `axes` holds shape indexes of the second input that select which reference
        # dimensions are used for the `dim` attribute.
        input_dim = node.axes
        node.in_port(1).disconnect()
    else:
        input_dim = list(range(0, input_shape.size))

    # start from the input shape and overwrite the cropped dimensions
    new_shape = input_shape.copy()
    ir_axis = []
    ir_offset = []
    dim = []
    # crop-window start for each cropped dimension, reused for constant folding below
    axis_offsets = {}

    for i in input_dim:
        if i < start_axis:
            new_shape[i] = input_shape[i]
            continue

        crop_offset = 0
        if len(node.offset) == 1:
            crop_offset = node.offset[0]
        elif len(node.offset) > 1:
            crop_offset = node.offset[i - start_axis]

        if input_shape[i] - crop_offset < reference_shape[i]:
            raise Error('The crop for dimension is out of bounds in node {}'.format(node_name))

        dim.append(reference_shape[i])
        ir_axis.append(i)
        ir_offset.append(crop_offset)
        axis_offsets[i] = crop_offset
        new_shape[i] = reference_shape[i]

    node.axis = ir_axis
    node.offset = ir_offset
    node['dim'] = dim
    node.out_port(0).data.set_shape(new_shape)

    if node.in_node(0).has_valid('value') and \
            not getattr(node.graph.graph['cmd_params'], 'enable_ssd_gluoncv', False):
        in_value = node.in_node(0).value
        slice_indexes = [slice(0, s) for s in in_value.shape]
        for axis in input_dim:
            # The crop window begins at the crop offset, not at 0. The slice is applied once,
            # after all axes are set, so no offset is applied twice.
            begin = axis_offsets.get(axis, 0)
            slice_indexes[axis] = slice(begin, begin + new_shape[axis])
        # copy so the folded output does not alias the input constant
        out_value = np.copy(in_value[tuple(slice_indexes)])
        node.out_port(0).data.set_value(out_value)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])