def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
    """Lower a FakeQuantWithMinMaxVars-style node into a 5-input FakeQuantize.

    The limits arriving on inputs 1 (min) and 2 (max) are detached and replaced by
    "nudged" limits computed inside the graph:

      1. the two limits are ordered so that min <= max;
      2. a strictly positive range is shifted down to start at zero, a strictly
         negative range is shifted up to end at zero;
      3. scale = (max - min) / (2 ** num_bits - 1 - int(narrow_range));
      4. min_adj = scale * round(min / scale) and max_adj = max + (min_adj - min).

    min_adj/max_adj then feed inputs 1/2 and 3/4 of the node, which is finally
    re-stamped as FakeQuantize with its existing ``levels`` attribute.
    """
    op_node = match['op']
    prefix = op_node.name

    def source_of(port_idx):
        # (producer node, output index) currently feeding input `port_idx`
        src = op_node.in_port(port_idx).get_source()
        return src.node, src.idx

    def build(op_cls, suffix, inputs):
        # create `op_cls` named "<node name><suffix>" and wire `inputs` into it
        return op_cls(graph, {'name': prefix + suffix}).create_node(inputs)

    lo_src = source_of(1)
    hi_src = source_of(2)
    op_node.in_port(1).disconnect()
    op_node.in_port(2).disconnect()

    # order the limits: lo <= hi
    lo_is_smaller = build(Less, '/if_min_less_max', [lo_src, hi_src])
    lo = build(Select, '/minimum', [lo_is_smaller, lo_src, hi_src])
    hi = build(Select, '/maximum', [lo_is_smaller, hi_src, lo_src])

    # a zero with the limits' data type, obtained as lo * 0 (integer zero)
    zero = create_op_node_with_second_input(graph,
                                            Mul,
                                            int64_array(0),
                                            {'name': prefix + '/zero'},
                                            input_node=lo)

    # 0 < lo < hi  ->  lo := 0, hi := hi - lo
    lo_positive = build(Greater, '/if_minimum_greater_zero', [lo, zero])
    width = build(Sub, '/max_minus_min', [hi, lo])
    lo = build(Select, '/first_adj_min', [lo_positive, zero, lo])
    hi = build(Select, '/first_adj_max', [lo_positive, width, hi])

    # lo < hi < 0  ->  lo := lo - hi, hi := 0
    hi_negative = build(Less, '/if_max_less_zero', [hi, zero])
    neg_width = build(Sub, '/min_minus_max', [lo, hi])
    lo = build(Select, '/second_adj_min', [hi_negative, neg_width, lo])
    hi = build(Select, '/second_adj_max', [hi_negative, zero, hi])

    # scale = (hi - lo) / (quant_hi - quant_lo)
    float_range = build(Sub, '/float_range', [hi, lo])
    quant_lo = int(op_node.narrow_range)
    quant_hi = 2**op_node.num_bits - 1
    int_range = Const(graph, {'name': prefix + '/int_range',
                              'value': quant_hi - quant_lo}).create_node()
    scale = build(Div, '/scale', [float_range, int_range])

    # snap lo onto the quantization grid: lo_adj = scale * round(lo / scale)
    steps = build(Div, '/descaled_min', [lo, scale])
    rounded_steps = build(Round, '/rounded_descaled_min', [steps])
    lo_adj = build(Mul, '/min_adj', [scale, rounded_steps])
    # move hi by the same offset: hi_adj = hi + (lo_adj - lo)
    shift = build(Sub, '/limits_adjustment', [lo_adj, lo])
    hi_adj = build(Add, '/max_adj', [hi, shift])

    # FakeQuantize has 5 inputs instead of the 3 inputs of the TensorFlow op
    for port_idx in (3, 4):
        op_node.add_input_port(port_idx, skip_if_exist=True)
    for port_idx, limit in ((1, lo_adj), (2, hi_adj), (3, lo_adj), (4, hi_adj)):
        op_node.in_port(port_idx).connect(limit.out_port(0))

    FakeQuantize.update_node_stat(op_node, {'levels': op_node['levels']})
# Example #2
    def replace_pattern(self, graph: Graph, match: dict):
        """Turn a convolution over binarized input into a BinaryConvolution ('xnor-popcount').

        Expected ``match`` keys: ``operator`` (the convolution; it must carry
        ``multiplication_transparent_ports``), ``quantize`` (the FakeQuantize node) and
        ``quantized`` (the data node produced by the FakeQuantize and consumed by the
        convolution).

        The rewrite happens only when every weight rounds to -1 or +1 and the FakeQuantize
        output range (the constants on its inputs 3 and 4) is either
          * [-1, +1] -- the convolution is retyped as-is, or
          * [0, +1]  -- the convolution output is replaced by
            ``(conv + per-output-channel weight sum) * 0.5`` and ``pad_value`` is set to -1.
        In both cases the FakeQuantize output_low/output_high constants are then overwritten
        with exact 0/1. In every other case the graph is left unchanged.
        """
        operator = match['operator']
        assert operator.has('multiplication_transparent_ports')

        quantize = match['quantize']

        ports = operator.input_ports_with(match['quantized'])
        assert len(ports) >= 1
        if len(ports) > 1:
            log.debug(
                'ConvToBinaryConv: cannot apply transformation for data {} because it is consumed more'
                ' than once'.format(match['quantized'].name))
            return

        port = ports[0]
        if not any(pair[0] == port for pair in operator.multiplication_transparent_ports):
            return

        # Look at 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        output_low_node = quantize.in_node(3)
        output_high_node = quantize.in_node(4)
        assert len(output_low_node.out_nodes()) == 1
        assert len(output_high_node.out_nodes()) == 1

        # Both bounds are compared numerically below, so BOTH must be constants.
        if not output_low_node.has_valid('value') or not output_high_node.has_valid('value'):
            return

        output_low = output_low_node.value
        output_high = output_high_node.value

        weights_node = operator.in_node(1)
        if not weights_node.has_valid('value'):
            log.debug('ConvToBinaryConv: cannot apply transformation because weights are not constant.')
            return
        weights = weights_node.value

        weights_rounded = np.round(weights)
        weights_consistent = np.all(np.isclose(weights, weights_rounded)) and \
            set(np.unique(weights_rounded)).issubset({-1, 1})

        range_0_p1 = weights_consistent and np.all(np.isclose(output_low, 0)) and \
            np.all(np.isclose(output_high, 1))
        range_m1_p1 = weights_consistent and np.all(np.isclose(output_low, -1)) and \
            np.all(np.isclose(output_high, +1))

        if not (range_0_p1 or range_m1_p1):
            log.debug(
                'ConvToBinaryConv: cannot apply transformation because input range is neither in [0, +1] nor '
                'in [-1, +1].')
            return

        # Validate the scalar assumption before anything in the graph is modified.
        assert output_low.size == 1
        assert output_high.size == 1

        if range_0_p1:
            # Sum weights over every axis except the output-feature one -> one value per output channel.
            reduction_indices = set(range(len(weights.shape))) - {operator.output_feature_channel}
            weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
            weights_reduced = weights_reduced.reshape(
                [len(weights_reduced), 1, 1])  # FIXME: works for NCHW only

            add_term = Const(graph, {'value': weights_reduced}).create_node()
            add = Add(graph, {}).create_node()
            add.in_port(1).connect(add_term.out_port(0))
            mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
            mul = Mul(graph, {}).create_node()
            mul.in_port(1).connect(mul_term.out_port(0))
            add.out_port(0).connect(mul.in_port(0))

            # Hand the convolution's consumers over to `mul`, then feed the convolution into `add`.
            operator.out_port(0).get_connection().set_source(mul.out_port(0))
            add.in_port(0).connect(operator.out_port(0))

            operator['pad_value'] = float(-1.0)

        operator['type'] = 'BinaryConvolution'
        operator['mode'] = 'xnor-popcount'
        operator['pad_value'] = operator.soft_get('pad_value', float(0))
        operator['input'] = operator.in_node(0).shape[1]
        # Weights are not bit-packed yet; there should be a separate transformation to do that

        # Make sure that low/high values are exactly 0/1
        output_low_node.value = np.zeros(output_low_node.shape)
        output_high_node.value = np.ones(output_high_node.shape)
# Example #3
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace a Crop node with an equivalent StridedSlice.

        Three Crop forms are recognised, checked in this order:
          1. two inputs and ``offset``: begin = offset, end = ShapeOf(input 1) + offset
             (input 1 is disconnected from the Crop);
          2. ``dim`` and ``offset``: begin = offset, end = offset + dim;
          3. ``crop_begin`` and ``crop_end``: begin = crop_begin,
             end = ShapeOf(input 0) - crop_end.
        begin/end values go through ``self.mask_normalizer(shape_rank, node.axis, ...)``
        (assumed to place them into a rank-sized vector at the cropped axes -- confirm).
        begin_mask and end_mask are 1 exactly on the axes listed in ``node.axis``; every
        stride is 1. The StridedSlice takes over the Crop's input 0 and output connection.

        Raises:
            Exception: if the node matches none of the three forms.
        """
        node = match['crop']
        assert node.has_valid('axis')
        node.axis = self.list_to_ndarray(node.axis)

        in_shape = node.in_port(0).data.get_shape()
        shape_rank = in_shape.size
        # 1 on cropped axes, 0 elsewhere; shared by begin_mask and end_mask.
        axis_mask = int64_array(
            [1 if i in node.axis else 0 for i in range(shape_rank)])
        begin_mask = axis_mask.copy()
        end_mask = axis_mask.copy()

        if len(node.in_nodes()) == 2 and node.has_valid('offset'):
            # Crop Type 1: end = shape of the reference (second) input + offset
            begin = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, node.offset)
            }).create_node()
            shape = Shape(graph, {
                'name': node.name + '/shape_of_crop'
            }).create_node()
            end = Add(graph, {'name': node.name + '/end'}).create_node()
            node.in_port(1).get_connection().get_source().connect(
                shape.in_port(0))
            node.in_port(1).disconnect()
            shape.out_port(0).connect(end.in_port(0))
            begin.out_port(0).connect(end.in_port(1))
        elif node.has_valid('dim') and node.has_valid('offset'):
            # Crop Type 2: end = offset + dim
            node.dim = self.list_to_ndarray(node.dim)
            node.offset = self.list_to_ndarray(node.offset)
            assert node.dim.size == node.offset.size == node.axis.size

            begin = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, node.offset)
            }).create_node()
            end_values = np.add(node.offset, node.dim)
            end = Const(graph, {
                'value':
                self.mask_normalizer(shape_rank, node.axis, end_values)
            }).create_node()
        elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
            # Crop Type 3: end = shape of the input - crop_end
            node.crop_begin = self.list_to_ndarray(node.crop_begin)
            node.crop_end = self.list_to_ndarray(node.crop_end)
            assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)

            begin = Const(
                graph, {
                    'value':
                    self.mask_normalizer(shape_rank, node.axis,
                                         node.crop_begin)
                }).create_node()
            shape = Shape(graph, {
                'name': node.name + '/shape_of_crop'
            }).create_node()
            const = Const(
                graph, {
                    'value':
                    -1 *
                    self.mask_normalizer(shape_rank, node.axis, node.crop_end)
                }).create_node()
            end = Add(graph, {'name': node.name + '/end'}).create_node()

            node.in_port(0).get_connection().get_source().connect(
                shape.in_port(0))
            shape.out_port(0).connect(end.in_port(0))
            const.out_port(0).connect(end.in_port(1))

        else:
            raise Exception("Unknown type of Crop")

        source = node.in_port(0).get_connection().get_source()

        stride = Const(graph, {
            'value': np.ones(shape_rank, dtype=np.int64)
        }).create_node()
        ss = StridedSlice(
            graph, {
                'name': 'Crop_',
                'begin_mask': begin_mask,
                'end_mask': end_mask,
                'new_axis_mask': np.array([0]),
                'shrink_axis_mask': np.array([0]),
                'ellipsis_mask': np.array([0])
            }).create_node()

        source.connect(ss.in_port(0))
        begin.out_port(0).connect(ss.in_port(1))
        end.out_port(0).connect(ss.in_port(2))
        stride.out_port(0).connect(ss.in_port(3))

        node.in_port(0).disconnect()
        node.out_port(0).get_connection().set_source(ss.out_port(0))

        # begin/end/stride inputs are forced to int64
        ss['force_precision_in_ports'] = {1: 'int64', 2: 'int64', 3: 'int64'}
# Example #4
    def replace_op(self, graph: Graph, node: Node):
        """Expand an LSTM non-linearity node into elementary operations.

        The single input is split along axis 1 into five parts
        (i_part, f_part, c_part, o_part, ct_1) and the cell is computed as
            i_t = Sigmoid(i_part + w_ic*ct_1)
            f_t = Sigmoid(f_part + w_fc*ct_1)
            c_t = f_t*ct_1 + i_t*Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc*c_t)
            m_t = o_t*Tanh(c_t)
        The peephole products are bias-free ScaleShift nodes whose weights are taken from
        node.i_weights, node.f_weights and node.o_weights. c_t and m_t are concatenated.

        Returns:
            A one-element list with the id of the Concat node producing [c_t, m_t].
        """
        base = node.soft_get('name', node.id)

        splitter = create_op_with_const_inputs(
            graph, Split, {1: np.int64(1)}, {
                'name': base + '/split_lstm_input',
                'num_splits': 5
            })
        node.in_port(0).get_connection().set_destination(splitter.in_port(0))

        # Sources are given as (producer node, output port index) pairs.
        def peephole(suffix, weights, src):
            """Bias-free ScaleShift holding `weights`, fed from `src`."""
            attrs = {'name': base + suffix, 'bias_term': False}
            scaler = ScaleShiftOp(graph, attrs).create_node()
            input_as_const(scaler, attrs, 1, 'weights', weights)
            src[0].out_port(src[1]).connect(scaler.in_port(0))
            return scaler

        def unary(op_cls, suffix, src):
            """Single-input op fed from `src`."""
            result = op_cls(graph, {'name': base + suffix}).create_node()
            src[0].out_port(src[1]).connect(result.in_port(0))
            return result

        def binary(op_cls, suffix, lhs, rhs):
            """Two-input op with `lhs` on port 0 and `rhs` on port 1."""
            result = op_cls(graph, {'name': base + suffix}).create_node()
            lhs[0].out_port(lhs[1]).connect(result.in_port(0))
            rhs[0].out_port(rhs[1]).connect(result.in_port(1))
            return result

        c_prev = (splitter, 4)

        # input gate
        i_peephole = peephole('/i_scaleshift', node.i_weights, c_prev)
        i_sum = binary(Add, '/sum_i_c_', (splitter, 0), (i_peephole, 0))
        i_gate = unary(Sigmoid, '/i_sigmoid', (i_sum, 0))

        # forget gate
        f_peephole = peephole('/f_scaleshift', node.f_weights, c_prev)
        f_sum = binary(Add, '/sum_f_c_', (splitter, 1), (f_peephole, 0))
        f_gate = unary(Sigmoid, '/f_sigmoid', (f_sum, 0))

        # new cell state
        candidate = unary(Tanh, '/c_tanh', (splitter, 2))
        input_term = binary(Mul, '/prod_i_c_tanh_', (i_gate, 0), (candidate, 0))
        forget_term = binary(Mul, '/prod_f_ct_1_', (f_gate, 0), c_prev)
        c_t = binary(Add, '/sum_f_i_', (forget_term, 0), (input_term, 0))

        # output gate (peephole on the new cell state)
        o_peephole = peephole('/o_scaleshift', node.o_weights, (c_t, 0))
        o_sum = binary(Add, '/sum_o_c_', (splitter, 3), (o_peephole, 0))
        o_gate = unary(Sigmoid, '/o_sigmoid', (o_sum, 0))

        # cell output
        c_t_act = unary(Tanh, '/c_t_tanh', (c_t, 0))
        m_t = binary(Mul, '/prod_o_c_t_tanh_', (o_gate, 0), (c_t_act, 0))

        # a single output carrying both c_t and m_t
        joined = Concat(graph, {'name': base + '/concat_c_m'}).create_node()
        joined.add_sequence_of_ports('in', range(2))
        c_t.out_port(0).connect(joined.in_port(0))
        m_t.out_port(0).connect(joined.in_port(1))

        return [joined.id]
# Example #5
    def replace_op(self, graph: Graph, node: Node):
        """Expand an LSTM cell node into FullyConnected, elementwise and ReadValue/Assign nodes.

        Recurrent values live in two variables (ids from ``unique_id``):
          * ``memory_pair_input``  -- the previous cell output: read by ``prev_memory_output``,
            written by ``next_lstm_output``;
          * ``memory_pair_output`` -- the previous cell state: read by ``prev_memory_state``,
            written by ``next_lstm_state``.
        Both ReadValue nodes get their initial value from
        ``create_zero_value_with_batch_from_input``.

        With the fused pre-activations split along axis 1 into parts p0..p3:
            gifo = FC(x; gifo_x_weights, gifo_biases) + FC(m_prev; gifo_r_weights)
            i    = Sigmoid(p1 + input_gate_weights * c_prev)
            f    = Sigmoid(p2 + forget_gate_weights * c_prev)
            c    = Clamp(f * c_prev + Tanh(p0) * i, -clip_value, clip_value)
            o    = Sigmoid(p3 + output_gate_weights * c)
            m    = FC(Tanh(c) * o; projection_weights)
        c is written to the state variable and m to the output variable.

        NOTE(review): node names are fixed literals, not derived from ``node.name``, so
        replacing several LSTM nodes in one graph produces duplicate names -- confirm this
        is acceptable downstream.

        Returns:
            A one-element list with the id of the output FullyConnected node; the consumers
            of ``node`` are attached to it after replacement.
        """
        input_out_port = node.in_port(0).get_source()

        # Variable ids pairing each ReadValue with its Assign (output vs. cell state).
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        fc_layer_after_input_attrs = {
            'name': 'input_fullyconnected',
            'out-size': node.gifo_x_weights_shape[0],
            'transpose_weights': True,
            'bias_term': True,
        }

        fc_layer_after_input = FullyConnected(
            graph, fc_layer_after_input_attrs).create_node()
        fc_layer_after_input.in_port(0).connect(input_out_port)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1,
                       'weights', node.gifo_x_weights)
        input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2,
                       'biases', node.gifo_biases)

        # Previous cell output, read from the variable; the initial value takes its batch
        # from the LSTM input.
        init_value_prev_lstm_output = create_zero_value_with_batch_from_input(
            input_out_port, node.gifo_r_weights_shape[1])
        prev_lstm_output = ReadValue(graph, {
            'name': 'prev_memory_output',
            'variable_id': memory_pair_input
        }).create_node()
        prev_lstm_output.in_port(0).connect(
            init_value_prev_lstm_output.out_port(0))

        # *Memory(output) -> FullyConnected
        fc_layer_from_prev_state_attrs = {
            'name': 'prev_memory_output_fullyconnected',
            'out-size': node.gifo_r_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False,
        }

        fc_layer_from_prev_state = FullyConnected(
            graph, fc_layer_from_prev_state_attrs).create_node()
        fc_layer_from_prev_state.in_port(0).connect(
            prev_lstm_output.out_port(0))
        input_as_const(fc_layer_from_prev_state,
                       fc_layer_from_prev_state_attrs, 1, 'weights',
                       node.gifo_r_weights)

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Add(graph, {
            'name': 'join_input_eltwise'
        }).create_node()
        join_input_prev_state_sum.in_port(0).connect(
            fc_layer_from_prev_state.out_port(0))
        join_input_prev_state_sum.in_port(1).connect(
            fc_layer_after_input.out_port(0))

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input_axis = Const(graph, {
            'value': np.int64(1)
        }).create_node()
        split_joined_input = Split(graph, {
            'name': 'join_input_split',
            'num_splits': 4,
            'out_ports_count': 4
        }).create_node()
        split_joined_input.in_port(0).connect(
            join_input_prev_state_sum.out_port(0))
        split_joined_input.in_port(1).connect(
            split_joined_input_axis.out_port(0))

        # Previous cell state (formerly a Memory node, now a ReadValue on memory_pair_output);
        # the initial value takes its batch from split output 0.
        init_value_prev_lstm_state = create_zero_value_with_batch_from_input(
            split_joined_input.out_port(0), node.input_gate_weights.shape[0])
        prev_lstm_state = ReadValue(graph, {
            'name': 'prev_memory_state',
            'variable_id': memory_pair_output
        }).create_node()
        prev_lstm_state.in_port(0).connect(
            init_value_prev_lstm_state.out_port(0))

        # *Memory(state) -> *ScaleShift(input)
        state_input_scaleshift_attrs = {
            'name': 'input_scaleshift',
            'bias_term': False
        }
        state_input_scaleshift = ScaleShiftOp(
            graph, state_input_scaleshift_attrs).create_node()
        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1,
                       'weights', node.input_gate_weights)

        # *Memory(state) -> *ScaleShift(forget)
        state_forget_scaleshift_attrs = {
            'name': 'forget_scaleshift',
            'bias_term': False
        }
        state_forget_scaleshift = ScaleShiftOp(
            graph, state_forget_scaleshift_attrs).create_node()
        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
        input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs,
                       1, 'weights', node.forget_gate_weights)

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_input_eltwise'
            }).create_node()
        join_prev_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(1))
        join_prev_lstm_input_joined_input_sum.in_port(1).connect(
            state_input_scaleshift.out_port(0))
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Add(
            graph, {
                'name': 'join_prev_lstm_input_joined_forget_sum',
            }).create_node()
        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(
            split_joined_input.out_port(2))
        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(
            state_forget_scaleshift.out_port(0))

        # Split -> Tanh (candidate values from split output 0)
        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
        remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

        # Split -> (2)Eltwise(sum) -> *Sigmoid (input gate)
        remember_sigmoid = Sigmoid(graph, {
            'name': 'remember_sigmoid'
        }).create_node()
        remember_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_input_sum.out_port(0))

        # Split -> (3)Eltwise(sum) -> **Sigmoid (forget gate)
        forget_sigmoid = Sigmoid(graph, {
            'name': 'forget_sigmoid'
        }).create_node()
        forget_sigmoid.in_port(0).connect(
            join_prev_lstm_input_joined_forget_sum.out_port(0))

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Mul(graph, {
            'name': 'join_forget_prev_state_mul'
        }).create_node()
        join_forget_prev_state_mul.in_port(0).connect(
            forget_sigmoid.out_port(0))
        join_forget_prev_state_mul.in_port(1).connect(
            prev_lstm_state.out_port(0))

        # Split -> Tanh                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Mul(
            graph, {
                'name': 'join_remember_candidates_mul'
            }).create_node()
        join_remember_candidates_mul.in_port(0).connect(
            remember_tahn.out_port(0))
        join_remember_candidates_mul.in_port(1).connect(
            remember_sigmoid.out_port(0))

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        join_forget_remember_sum = Add(graph, {
            'name': 'join_forget_remember_sum'
        }).create_node()
        join_forget_remember_sum.in_port(0).connect(
            join_forget_prev_state_mul.out_port(0))
        join_forget_remember_sum.in_port(1).connect(
            join_remember_candidates_mul.out_port(0))

        # (7)Eltwise(sum) -> Clamp to [-clip_value, clip_value]
        join_forget_clamp = create_op_with_const_inputs(
            graph, Clamp, {
                1: np.array(-node.clip_value, dtype=np.float32),
                2: np.array(node.clip_value, dtype=np.float32)
            }, {'name': 'join_forget_clamp'}, join_forget_remember_sum)
        #
        # Clamp -> (2)Memory(state)
        next_lstm_state = Assign(graph, {
            'name': 'next_lstm_state',
            'variable_id': memory_pair_output
        }).create_node()
        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

        # The Assign output goes to a Result, presumably so the Assign is not removed as a
        # dead end -- TODO confirm.
        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
        res_node.in_port(0).connect(next_lstm_state.out_port(0))

        # Clamp -> (2)Tanh
        state_filtered_tahn = Tanh(graph, {
            'name': 'state_filtered_tahn'
        }).create_node()
        state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

        # Clamp -> (2)ScaleShift (output-gate peephole on the new state)
        clamp_scaleshift_attrs = {
            'name': 'clamp_scaleshift',
            'bias_term': False
        }
        clamp_scaleshift = ScaleShiftOp(graph,
                                        clamp_scaleshift_attrs).create_node()
        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
        input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights',
                       node.output_gate_weights)

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Add(
            graph, {
                'name': 'join_next_lstm_input_joined_input_sum',
            }).create_node()
        join_next_lstm_input_joined_input_sum.in_port(0).connect(
            split_joined_input.out_port(3))
        join_next_lstm_input_joined_input_sum.in_port(1).connect(
            clamp_scaleshift.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid (output gate)
        output_sigmoid = Sigmoid(graph, {
            'name': 'output_sigmoid'
        }).create_node()
        output_sigmoid.in_port(0).connect(
            join_next_lstm_input_joined_input_sum.out_port(0))

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (8)Eltwise(mul)
        # Clamp -> (2)Tanh                      /
        joined_output_mul = Mul(graph, {
            'name': 'joined_output_mul'
        }).create_node()
        joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

        # (8)Eltwise(mul) -> (3)FullyConnected (projection)
        fc_output_attrs = {
            'name': 'FullyConnected',
            'out-size': node.projection_weights_shape[0],
            'transpose_weights': True,
            'bias_term': False
        }
        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
        input_as_const(fc_output, fc_output_attrs, 1, 'weights',
                       node.projection_weights)

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        next_lstm_output = Assign(graph, {
            'name': 'next_lstm_output',
            'variable_id': memory_pair_input
        }).create_node()
        next_lstm_output.in_port(0).connect(fc_output.out_port(0))

        res_node_lstm_output = Result(graph, {
            'name': 'next_lstm_output_out'
        }).create_node()
        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

        return [fc_output.id]
# Example #6
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['op']

        if 1 not in node.in_ports() or node.in_port(1).disconnected():

            if node.has_valid('factor') and not node.has_valid('width') and not node.has_valid('height'):
                factor = Const(graph, {'value': np.array(node.factor)}).create_node()

                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                mul = Mul(graph, {'name': node.name + '/factor_mul_'}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(mul.in_port(0))
                factor.out_port(0).connect(mul.in_port(1))

                node.add_input_port(1, skip_if_exist=True)
                assert node.in_port(1).disconnected()
                mul.out_port(0).connect(node.in_port(1))

            else:
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(0).get_connection().get_source()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))

                pads_value = node.pads_begin + node.pads_end
                pads_const = Const(graph, {'value': np.array(pads_value)}).create_node()
                add = Add(graph, {'name': node.name + '/pad_add'}).create_node()
                ss.out_port(0).connect(add.in_port(0))
                add.in_port(1).connect(pads_const.out_port(0))

                if node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') == 1:
                    shrink_factor = node.shrink_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'name': node.name + '/pre_shrink_sub_const',
                                          'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/pre_shrink_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    sub.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'value': np.array(1 / shrink_factor),
                                          'name': node.name + 'shrink_factor_div_const'}).create_node()
                    div = Mul(graph, {'name': node.name + 'shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    div.in_port(1).connect(const.out_port(0))

                    const = Const(graph, {'name': node.name + '/shrink_factor_add_one_const', 'value': np.array(1)
                                          }).create_node()
                    add = Add(graph, {'name': node.name + '/shrink_factor_add_one'}).create_node()
                    div.out_port(0).connect(add.in_port(0))
                    const.out_port(0).connect(add.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    add.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') == 1 and node.soft_get('zoom_factor') != 1:
                    zoom_factor = node.zoom_factor
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    node['debug_message'] = 'Interpolate layer replacer may be wrong, please, try to update it in the' \
                                            ' file (extensions/front/InterpolateNormalizer.py at the line {}).' \
                                            ''.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)

                    # Reshape methods can be different in some cases
                    # Commented out section represents reshape that used in deeplab-caffe
                    # Uncomment the following lines, if your model was trained with deeplab-caffe
                    # or have the same reshape method
                    # const = Const(graph, {'value': np.array(-1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_sub_const'}).create_node()
                    # sub = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sub'}).create_node()
                    # add.out_port(0).connect(sub.in_port(0))
                    # const.out_port(0).connect(sub.in_port(1))
                    #
                    # const = Const(graph, {'value': np.array(zoom_factor - 1),
                    #                       'name': node.name + 'zoom_factor_deeplab-caffe_mul_const'}).create_node()
                    # mul = Mul(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_mul'}).create_node()
                    # sub.out_port(0).connect(mul.in_port(0))
                    # const.out_port(0).connect(mul.in_port(1))
                    #
                    # sum = Add(graph, {'name': node.name + 'zoom_factor_deeplab-caffe_sum'}).create_node()
                    # add.out_port(0).connect(sum.in_port(0))
                    # mul.out_port(0).connect(sum.in_port(1))
                    #
                    # node.add_input_port(1, skip_if_exist=True)
                    # assert node.in_port(1).disconnected()
                    # sum.out_port(0).connect(node.in_port(1))

                    # Comment out the following lines if you use the reshape method from previous section
                    const = Const(graph, {'value': np.array(zoom_factor),
                                          'name': node.name + '/zoom_factor_mul_const'}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()

                    add.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    mul.out_port(0).connect(node.in_port(1))

                elif node.soft_get('width') != 0 and node.soft_get('height') != 0:
                    const = Const(graph, {'value': np.array([node.height, node.width])}).create_node()
                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    const.out_port(0).connect(node.in_port(1))

                elif node.soft_get('shrink_factor') != 1 and node.soft_get('zoom_factor') != 1:
                    shrink_factor = node.shrink_factor
                    zoom_factor = node.zoom_factor
                    if shrink_factor < 1:
                        log.error('Shrink factor should be positive in node {}'.format(node.id))
                        return None
                    if zoom_factor < 1:
                        log.error('Zoom factor should be positive in node {}'.format(node.id))
                        return None

                    const = Const(graph, {'value': np.array(-1)}).create_node()
                    sub = Add(graph, {'name': node.name + '/shrink_zoom_factor_sub'}).create_node()
                    add.out_port(0).connect(sub.in_port(0))
                    const.out_port(0).connect(sub.in_port(1))

                    const = Const(graph, {'value': np.array(1 / (shrink_factor + 1))}).create_node()
                    div = Mul(graph, {'name': node.name + '/shrink_factor_div'}).create_node()
                    sub.out_port(0).connect(div.in_port(0))
                    const.out_port(0).connect(div.in_port(1))

                    const = Const(graph, {'value': np.array(-1),
                                          'name': node.name + 'shrink_zoom_factor_sum_const'}).create_node()
                    sum = Add(graph, {'name': node.name + '/shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    const.out_port(0).connect(sum.in_port(1))

                    const = Const(graph, {'value': np.array(zoom_factor - 1)}).create_node()
                    mul = Mul(graph, {'name': node.name + '/zoom_factor_mul'}).create_node()
                    sum.out_port(0).connect(mul.in_port(0))
                    const.out_port(0).connect(mul.in_port(1))

                    sum = Add(graph, {'name': node.name + '/final_shrink_zoom_factor_sum'}).create_node()
                    div.out_port(0).connect(sum.in_port(0))
                    mul.out_port(0).connect(sum.in_port(1))

                    node.add_input_port(1, skip_if_exist=True)
                    assert node.in_port(1).disconnected()
                    sum.out_port(0).connect(node.in_port(1))
        else:
            if node.soft_get('fw') == 'caffe':
                shape = Shape(graph, {'name': node.name + '/shape'}).create_node()

                begin = Const(graph, {'value': np.array([2])}).create_node()
                end = Const(graph, {'value': np.array([4])}).create_node()
                stride = Const(graph, {'value': np.array([1])}).create_node()
                ss = StridedSlice(graph, {'name': node.name + '/ss_0_port', 'begin_mask': np.array([1]),
                                          'end_mask': np.array([0]), 'new_axis_mask': np.array([0]),
                                          'shrink_axis_mask': np.array([0]),
                                          'ellipsis_mask': np.array([0])}).create_node()

                source = node.in_port(1).get_connection().get_source()
                node.in_port(1).disconnect()
                source.connect(shape.in_port(0))
                shape.out_port(0).connect(ss.in_port(0))
                begin.out_port(0).connect(ss.in_port(1))
                end.out_port(0).connect(ss.in_port(2))
                stride.out_port(0).connect(ss.in_port(3))
                ss.out_port(0).connect(node.in_port(1))
예제 #7
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Decompose a GroupNorm node into Reshape -> MVN -> Reshape -> Mul -> Add.

        The input is reshaped to [N * G, C / G, *spatial] (G is the node's
        ``num_groups``), normalized by an MVN node with ``across_channels`` set,
        reshaped back to the original input shape, then multiplied by gamma
        (input 1) and shifted by beta (input 2). Gamma and beta must have
        constant values; those values are reshaped in place from [C] to
        [1, C], [1, C, 1], ... so they broadcast against the restored tensor.

        :param graph: graph being transformed
        :param match: pattern match; ``match['op']`` is the GroupNorm node
        """
        gn = match['op']
        input_rank = len(gn.in_port(0).data.get_shape())

        # Shape of the original GroupNorm input; also reused by the final Reshape.
        input_shape = Shape(graph, {'name': gn.name + '/Shape'}).create_node()
        input_shape.in_port(0).connect(gn.in_port(0).get_source())

        batch_dim = node_to_get_batch_value(input_shape)
        features_dim = node_to_get_features_dimension_value(input_shape)
        spatial_dims = node_to_get_spatial_dimensions_value(input_shape)
        groups = Const(graph, {'value': int64_array([gn.num_groups]),
                               'name': gn.name + '/GroupSize'}).create_node()

        # "features // groups" is computed as features * (1 / groups).
        # NOTE(review): this multiplies an int64 shape value by a float constant;
        # confirm the resulting dtype is acceptable as a Reshape target shape.
        inv_groups = Const(graph, {'value': np.array([1.0 / gn.num_groups]),
                                   'name': gn.name + '/ReciprocalGroupSize'}).create_node()

        features_per_group = Mul(graph, {}).create_node()
        features_per_group.in_port(0).connect(features_dim.out_port(0))
        features_per_group.in_port(1).connect(inv_groups.out_port(0))

        batch_times_groups = Mul(graph, {}).create_node()
        batch_times_groups.in_port(0).connect(batch_dim.out_port(0))
        batch_times_groups.in_port(1).connect(groups.out_port(0))

        # [N * G, C / G, *spatial] concatenated into a single shape tensor.
        grouped_shape = new_shape_node_from_shape_nodes(
            [batch_times_groups, features_per_group, spatial_dims])

        to_grouped = Reshape(graph, {}).create_node()
        gn.in_port(0).get_connection().set_destination(to_grouped.in_port(0))
        to_grouped.in_port(1).connect(grouped_shape.out_port(0))

        # Gamma/beta layout: [C] -> [1, C], [1, C, 1], [1, C, 1, 1], ...
        affine_shape = np.ones([input_rank], dtype=np.int64)
        affine_shape[1] = -1

        gamma = gn.in_port(1).get_source().data.get_value()
        beta = gn.in_port(2).get_source().data.get_value()
        assert gamma is not None, 'The gamma should be constant'
        assert beta is not None, 'The beta should be constant'
        gn.in_port(1).get_source().data.set_value(np.reshape(gamma, affine_shape))
        gn.in_port(2).get_source().data.set_value(np.reshape(beta, affine_shape))

        mvn = MVN(graph, {'name': gn.name + '/MVN',
                          'across_channels': 1,
                          'normalize_variance': 1,
                          'eps': gn.eps}).create_node()
        mvn.in_port(0).connect(to_grouped.out_port(0))

        # Restore the original input shape before applying gamma and beta.
        to_original = Reshape(graph, {}).create_node()
        to_original.in_port(0).connect(mvn.out_port(0))
        to_original.in_port(1).connect(input_shape.out_port(0))

        scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
        scale.in_port(0).connect(to_original.out_port(0))
        gn.in_port(1).get_connection().set_destination(scale.in_port(1))

        shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
        shift.in_port(0).connect(scale.out_port(0))
        gn.in_port(2).get_connection().set_destination(shift.in_port(1))

        # Consumers of the GroupNorm output now read from the final Add.
        gn.out_port(0).get_connection().set_source(shift.out_port(0))

    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Decompose a GroupNorm node into Reshape -> MVN -> Reshape -> Mul -> Add.

        The input is reshaped to [N * G, C / G, *spatial] (G is the node's
        ``num_groups``). That target shape is computed in the floating-point
        type named by ``cmd_params.data_type`` (C / G is taken as a
        multiplication by 1 / G) and cast back to int64. MVN normalizes over
        axes [1, rank) of the reshaped tensor; the result is reshaped back to
        the original input shape, multiplied by gamma (input 1) and shifted by
        beta (input 2). Gamma and beta must have constant values; those values
        are reshaped in place from [C] to [1, C], [1, C, 1], ...

        :param graph: graph being transformed
        :param match: pattern match; ``match['op']`` is the GroupNorm node
        """
        gn = match['op']
        input_rank = len(gn.in_port(0).data.get_shape())

        # int64 shape of the original input; reused for the spatial dims and
        # for the final Reshape back to the input shape.
        input_shape = Shape(graph, {'name': gn.name + '/Shape'}).create_node()
        input_shape.in_port(0).connect(gn.in_port(0).get_source())

        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        input_shape_float = Cast(graph, {'name': input_shape.name + '/to_float',
                                         'dst_type': float_type}).create_node()
        input_shape.out_port(0).connect(input_shape_float.in_port(0))

        batch_dim = node_to_get_batch_value(input_shape_float)
        features_dim = node_to_get_features_dimension_value(input_shape_float)
        spatial_dims_int = node_to_get_spatial_dimensions_value(input_shape)
        spatial_dims = Cast(graph, {'name': spatial_dims_int.name + '/to_float',
                                    'dst_type': float_type}).create_node()
        spatial_dims_int.out_port(0).connect(spatial_dims.in_port(0))

        # NOTE(review): batch_dim is floating point here while this constant is
        # int64; confirm the Mul below accepts the mixed element types.
        groups = Const(graph, {'value': int64_array([gn.num_groups]),
                               'name': gn.name + '/GroupSize'}).create_node()

        # "features // groups" is computed as features * (1 / groups).
        inv_groups = Const(graph, {'value': np.array([1.0 / gn.num_groups]),
                                   'name': gn.name + '/ReciprocalGroupSize'}).create_node()

        features_per_group = Mul(graph, {}).create_node()
        features_per_group.in_port(0).connect(features_dim.out_port(0))
        features_per_group.in_port(1).connect(inv_groups.out_port(0))

        batch_times_groups = Mul(graph, {}).create_node()
        batch_times_groups.in_port(0).connect(batch_dim.out_port(0))
        batch_times_groups.in_port(1).connect(groups.out_port(0))

        # [N * G, C / G, *spatial] as one shape tensor, converted back to int64.
        # NOTE(review): the float -> int64 Cast relies on C * (1 / G) landing
        # on an exact integer; confirm rounding cannot truncate it downwards.
        grouped_shape_float = new_shape_node_from_shape_nodes(
            [batch_times_groups, features_per_group, spatial_dims])
        grouped_shape = Cast(graph, {'name': grouped_shape_float.name + '/to_int64',
                                     'dst_type': np.int64}).create_node()
        grouped_shape_float.out_port(0).connect(grouped_shape.in_port(0))

        to_grouped = Reshape(graph, {}).create_node()
        gn.in_port(0).get_connection().set_destination(to_grouped.in_port(0))
        to_grouped.in_port(1).connect(grouped_shape.out_port(0))

        # Gamma/beta layout: [C] -> [1, C], [1, C, 1], [1, C, 1, 1], ...
        affine_shape = np.ones([input_rank], dtype=np.int64)
        affine_shape[1] = -1

        gamma = gn.in_port(1).get_source().data.get_value()
        beta = gn.in_port(2).get_source().data.get_value()
        assert gamma is not None, 'The gamma should be constant'
        assert beta is not None, 'The beta should be constant'
        gn.in_port(1).get_source().data.set_value(np.reshape(gamma, affine_shape))
        gn.in_port(2).get_source().data.set_value(np.reshape(beta, affine_shape))

        mvn = MVN(graph, {'name': gn.name + '/MVN',
                          'normalize_variance': 1,
                          'eps': gn.eps,
                          'eps_mode': 'inside_sqrt'}).create_node()
        mvn.in_port(0).connect(to_grouped.out_port(0))

        # MVN reduction axes = Range(1, rank of the grouped tensor, 1).
        _, grouped_rank = get_shape_and_rank_nodes_by_port(
            mvn.in_port(0).get_connection().get_source(), return_as_a_scalar=True)
        axes = create_op_with_const_inputs(graph, Range,
                                           {0: int64_array(1), 2: int64_array(1)},
                                           {'name': gn.name + '/Range', 'output_type': np.int64})
        mvn.in_port(1).connect(axes.out_port(0))
        axes.in_port(1).connect(grouped_rank.out_port(0))

        # Restore the original (int64) input shape before applying gamma and beta.
        to_original = Reshape(graph, {}).create_node()
        to_original.in_port(0).connect(mvn.out_port(0))
        to_original.in_port(1).connect(input_shape.out_port(0))

        scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
        scale.in_port(0).connect(to_original.out_port(0))
        gn.in_port(1).get_connection().set_destination(scale.in_port(1))

        shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
        shift.in_port(0).connect(scale.out_port(0))
        gn.in_port(2).get_connection().set_destination(shift.in_port(1))

        # Consumers of the GroupNorm output now read from the final Add.
        gn.out_port(0).get_connection().set_source(shift.out_port(0))
예제 #9
0
    def replace_op(self, graph: Graph, node: Node):
        """Expand an LSTM non-linearity node into elementary operations.

        Input 0 is split along axis 1 into five equal parts
        (i_part, f_part, c_part, o_part, ct_1). When ``use_dropout`` is set,
        the input is first split into [rest, 1, 1, 1] columns along axis 1 and
        the last three columns scale the i, f and o gates respectively::

            i_t = Sigmoid(i_part + w_ic * ct_1)   [* i_drop_scale]
            f_t = Sigmoid(f_part + w_fc * ct_1)   [* f_drop_scale]
            c_t = f_t * ct_1 + i_t * Tanh(c_part)
            o_t = Sigmoid(o_part + w_oc * c_t)    [* o_drop_scale]
            m_t = o_t * Tanh(c_t)

        :param graph: graph being transformed
        :param node: the non-linearity node; its ``i_weights``, ``f_weights``
            and ``o_weights`` attributes are used as the ScaleShift weights
        :return: single-element list with the id of the Concat of [c_t, m_t]
        """
        node_name = node.soft_get('name', node.id)
        # Read the flag once with has_and_set: indexing node['use_dropout']
        # directly fails when the attribute is absent, which has_and_set
        # treats as "no dropout".
        use_dropout = node.has_and_set('use_dropout')
        i_drop_scale = f_drop_scale = o_drop_scale = None

        input_port = node.in_port(0)
        if use_dropout:
            split_dropout = AttributedVariadicSplit(
                graph, {
                    'name': node_name + '/split_dropout',
                    'size_splits': int64_array([-1, 1, 1, 1]),
                    'axis': int64_array(1)
                }).create_node()
            input_port.get_connection().set_destination(
                split_dropout.in_port(0))
            input_port = split_dropout.out_port(0)
            i_drop_scale = split_dropout.out_port(1)
            f_drop_scale = split_dropout.out_port(2)
            o_drop_scale = split_dropout.out_port(3)

        # split input to (i_part, f_part, c_part, o_part, ct_1)
        split_node = create_op_with_const_inputs(
            graph, Split, {1: np.int64(1)}, {
                'name': node_name + '/split_lstm_input',
                'num_splits': 5
            })
        input_port.get_connection().set_destination(split_node.in_port(0))
        split_name = split_node.soft_get('name', split_node.id)

        i_part = split_node.out_port(0)
        f_part = split_node.out_port(1)
        c_part = split_node.out_port(2)
        o_part = split_node.out_port(3)
        ct_1 = split_node.out_port(4)

        # i_t = Sigmoid(i_part + w_ic*ct_1)
        i_sigmoid = self._peephole_gate(graph, node_name, split_name, 'i',
                                        i_part, ct_1, node.i_weights,
                                        i_drop_scale)

        # f_t = Sigmoid(f_part + w_fc*ct_1)
        f_sigmoid = self._peephole_gate(graph, node_name, split_name, 'f',
                                        f_part, ct_1, node.f_weights,
                                        f_drop_scale)

        # c_t = f_t*ct_1 + i_t * tanh(c_part)
        c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
        c_part.connect(c_tanh.in_port(0))

        prod_i_c_tanh = Mul(graph, {
            'name': node_name + '/prod_i_c_tanh_'
        }).create_node()
        i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
        c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

        prod_f_ct_1 = Mul(graph, {
            'name': node_name + '/prod_f_ct_1_'
        }).create_node()
        f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
        ct_1.connect(prod_f_ct_1.in_port(1))

        sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
        prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
        prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

        # o_t = Sigmoid(o_part + w_oc*c_t); the peephole here reads c_t, not ct_1
        o_sigmoid = self._peephole_gate(graph, node_name, split_name, 'o',
                                        o_part, sum_f_i.out_port(0),
                                        node.o_weights, o_drop_scale)

        # m_t = o_t * Tanh(c_t)
        c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
        sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

        prod_o_c_t_tanh = Mul(graph, {
            'name': node_name + '/prod_o_c_t_tanh_'
        }).create_node()
        o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
        c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

        # add concat to create 1 output: [c_t, m_t]
        concat = Concat(graph, {
            'name': node_name + '/concat_c_m'
        }).create_node()
        concat.add_sequence_of_ports('in', range(2))
        sum_f_i.out_port(0).connect(concat.in_port(0))
        prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

        return [concat.id]

    @staticmethod
    def _peephole_gate(graph: Graph, node_name: str, split_name: str, gate: str,
                       gate_part, peephole, weights, drop_scale):
        """Build Sigmoid(gate_part + weights * peephole), optionally times drop_scale.

        Node names follow the '<node>/<gate>_scaleshift', '<node>/sum_<gate>_c_',
        '<node>/<gate>_sigmoid' and '<split>/mul_<gate>' pattern.

        :param gate_part: output port carrying this gate's slice of the input
        :param peephole: output port multiplied by ``weights`` (ScaleShift input)
        :param weights: ScaleShift weights value
        :param drop_scale: output port with the dropout scale, or None for no dropout
        :return: the last node of the gate (Sigmoid, or the dropout Mul)
        """
        scale_attrs = {
            'name': node_name + '/' + gate + '_scaleshift',
            'bias_term': False
        }
        scale = ScaleShiftOp(graph, scale_attrs).create_node()
        input_as_const(scale, scale_attrs, 1, 'weights', weights)
        peephole.connect(scale.in_port(0))

        gate_sum = Add(graph, {
            'name': node_name + '/sum_' + gate + '_c_'
        }).create_node()
        gate_part.connect(gate_sum.in_port(0))
        scale.out_port(0).connect(gate_sum.in_port(1))

        sigmoid = Sigmoid(graph, {
            'name': node_name + '/' + gate + '_sigmoid'
        }).create_node()
        gate_sum.out_port(0).connect(sigmoid.in_port(0))

        if drop_scale is None:
            return sigmoid

        mul_dropout = Mul(graph, {
            'name': split_name + '/mul_' + gate
        }).create_node()
        mul_dropout.in_port(0).connect(sigmoid.out_port(0))
        mul_dropout.in_port(1).connect(drop_scale)
        return mul_dropout
예제 #10
0
def _fused_batch_norm_decomposition(graph: Graph,
                                    tinput: Port,
                                    toutput: Port,
                                    gamma: Port,
                                    beta: Port,
                                    mean: np.ndarray,
                                    variance: np.ndarray,
                                    can_be_fused=True):
    """
    This is common function for TF, Caffe and MXNet
    It creates Mul->Add->Mul->Add sub graph

    The produced computation is::

        toutput = ((tinput * mean) + variance) * gamma + beta

    i.e. ``mean`` is used as a multiplier and ``variance`` as an additive
    term. Presumably callers pass precomputed scale/shift values under these
    names; verify against callers.

    :param graph: graph being transformed
    :param tinput: input port of the batch-norm node whose connection is
        moved onto the first Mul; the batch-norm node name is taken from it
    :param toutput: output port whose consumers are rewired to the last Add
    :param gamma: port whose connection is moved onto the second Mul
    :param beta: port whose connection is moved onto the second Add
    :param mean: constant value for the first Mul
    :param variance: constant value for the first Add
    :param can_be_fused: ``can_be_fused`` attribute set on all four new nodes
    """
    batch_norm_name = tinput.get_connection().get_destination().node.name

    # Create first Mul & Add operations
    mul1_node = Mul(
        graph, dict(name=batch_norm_name + "/mean",
                    can_be_fused=can_be_fused)).create_node()
    add1_node = Add(
        graph,
        dict(name=batch_norm_name + "/variance",
             can_be_fused=can_be_fused)).create_node()

    const_mul1_node = Const(graph, dict(name="data_mul_",
                                        value=np.array(mean))).create_node()
    const_add1_node = Const(graph,
                            dict(name="data_add_",
                                 value=np.array(variance))).create_node()

    # Broadcast const from scalar
    # We can broadcast only when const.value is scalar (and known at all).
    gamma_shape = gamma.data.get_shape()
    gamma_value = gamma.data.get_value()
    if gamma_value is not None and gamma_shape[0] != gamma_value.shape[0]:
        # ndarray.resize() works in place and returns None, so it cannot be
        # chained with .fill(); build the broadcast array directly instead.
        gamma.data.set_value(
            np.full(tuple(gamma_shape), gamma_value[0], dtype=gamma_value.dtype))

    # Create second Mul & Add
    mul2_node = Mul(
        graph, dict(name=batch_norm_name + "/gamma",
                    can_be_fused=can_be_fused)).create_node()
    add2_node = Add(
        graph, dict(name=batch_norm_name + "/beta",
                    can_be_fused=can_be_fused)).create_node()

    # Connect edges Mul1->Add1->Mul2->Add2
    tinput.get_connection().set_destination(mul1_node.in_port(0))
    mul1_node.in_port(1).get_connection().set_source(
        const_mul1_node.out_port(0))

    add1_node.in_port(0).get_connection().set_source(mul1_node.out_port(0))
    add1_node.in_port(1).get_connection().set_source(
        const_add1_node.out_port(0))

    mul2_node.in_port(0).get_connection().set_source(add1_node.out_port(0))
    gamma.get_connection().set_destination(mul2_node.in_port(1))

    add2_node.in_port(0).get_connection().set_source(mul2_node.out_port(0))
    beta.get_connection().set_destination(add2_node.in_port(1))

    toutput.get_connection().set_source(add2_node.out_port(0))
예제 #11
0
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        """Turn a matched Gemm into a matrix product followed by a bias Add.

        Without ``cmd_params.generate_experimental_IR_V10`` the Gemm node is
        retyped in place to ``FullyConnected``; if the preconditions fail, a
        warning is logged and the method returns before touching the graph.
        With V10 enabled the node is retyped to ``MatMul``. In both cases
        input 2 (C) is moved onto a new Add that consumes the Gemm output, and
        Mul nodes are inserted for ``alpha``/``beta`` when they are not close
        to 1 (the attributes are then deleted from the node).

        :param graph: graph being transformed
        :param match: pattern match; ``match['gemm']`` is the Gemm node and
            ``match['output']`` (legacy IR path only) is read for its ``shape``
        """
        log.debug('GemmToFullyConnected is triggered')
        gemm = match['gemm']
        # TODO nGraph remove BEGIN
        if not graph.graph['cmd_params'].generate_experimental_IR_V10:
            # Legacy (pre-V10) IR path: convert the node to FullyConnected.
            A = gemm.in_node(0)
            B = gemm.in_node(1)
            B_consumers = graph.out_edges(B.node)
            C = gemm.in_node(2)

            # Requires known values for B and C, a known A shape and no
            # transpose of A.
            # NOTE(review): B's value is transposed in place below only when
            # transpose_b is 0, yet the single-consumer requirement applies
            # only when transpose_b is set; confirm this condition is not
            # inverted when B is shared with other consumers.
            if not (B.value is not None and C.value is not None
                    and A.shape is not None and not gemm.transpose_a and
                    (len(B_consumers) == 1 or not gemm.transpose_b)):
                log.warning('Cannot convert Gemm to FullyConnected')
                return

            if gemm.transpose_b:
                # B is left untouched; only the transpose flag is cleared.
                # B.value = B.value.transpose()
                # B.shape = np.array(B.value.shape, dtype=np.int64)
                gemm.transpose_b = 0
            else:
                # Transpose B's value (and shape) in place on the data node.
                B.value = B.value.transpose()
                B.shape = np.array(B.value.shape, dtype=np.int64)

            gemm['out-size'] = gemm.out_port(0).data.get_shape()[-1]
            gemm['type'] = 'FullyConnected'
            gemm['channel_dims'] = len(match['output'].shape) - 1
            gemm['bias_addable'] = True
            gemm['input_channel_dim'] = 1  # MatMul weights in IO
            gemm['output_channel_dim'] = 0
            gemm['layout'] = 'NCHW'

            # NOTE(review): this sets a Python attribute on the Port wrapper
            # object; confirm it actually reaches the edge's 'bin' attribute.
            gemm.in_port(1).bin = 'weights'

        else:

            # V10 path: B must have a known value; it is transposed in place
            # when transpose_b is set (the flag itself is not cleared here).
            B = gemm.in_node(1)
            assert B.value is not None

            if gemm.transpose_b:
                B.value = B.value.transpose()
                B.shape = np.array(B.value.shape, dtype=np.int64)

        # Insert the bias Add after the Gemm: former consumers of the Gemm now
        # read from bias_node, whose inputs are (Gemm output, former input 2).
        bias_node = Add(graph, {'name': 'MatMulBias_'}).create_node()
        gemm.out_port(0).get_connection().set_source(bias_node.out_port(0))
        gemm.in_port(2).get_connection().set_destination(bias_node.in_port(1))
        gemm.out_port(0).connect(bias_node.in_port(0))
        if graph.graph['cmd_params'].generate_experimental_IR_V10:
            gemm.type = 'MatMul'

        if gemm.has_valid('alpha'):
            if not math.isclose(gemm.alpha, 1):
                # Scale the matrix product: Gemm -> Mul(alpha) -> bias Add.
                mul_node = Mul(graph, {'name': 'MatMulAlpha_'}).create_node()
                const = Const(graph, {
                    'value': np.array(gemm.alpha)
                }).create_node()
                bias_node.in_port(0).get_connection().set_destination(
                    mul_node.in_port(0))
                bias_node.in_port(0).connect(mul_node.out_port(0))
                mul_node.in_port(1).connect(const.out_port(0))
            del gemm['alpha']

        if gemm.has_valid('beta'):
            if not math.isclose(gemm.beta, 1):
                # Scale the bias input: C -> Mul(beta) -> bias Add.
                mul_node = Mul(graph, {'name': 'MatMulBeta_'}).create_node()
                const = Const(graph, {
                    'value': np.array(gemm.beta)
                }).create_node()
                bias_node.in_port(1).get_connection().set_destination(
                    mul_node.in_port(0))
                bias_node.in_port(1).connect(mul_node.out_port(0))
                mul_node.in_port(1).connect(const.out_port(0))
            del gemm['beta']

        if not graph.graph['cmd_params'].generate_experimental_IR_V10:
            # Legacy IR: annotate the weights input's dimensions (argument
            # meaning is defined by assign_dims_to_weights — not visible here).
            assign_dims_to_weights(gemm.in_node(1), None, 1, 0, 2)