Esempio n. 1
0
    def infer(node: Node):
        """Infer the output shape and, when possible, the output value of a Select node.

        The output shape is always set to ``broadcast_shape`` of the shapes read
        from input ports 1 ('then') and 2 ('else').  When the value on input
        port 0 ('condition') is known, the output value is propagated as well:

        * multi-element condition: element-wise ``np.where`` over both branches;
          the value is dropped if any selected element is ``None``;
        * single-element condition: the whole 'then' or 'else' value is taken.

        :param node: the Select node being inferred.
        :raises AssertionError: if the node does not have exactly 3 connected inputs.
        """
        connected_inputs = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_inputs) == 3, \
            "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors"

        condition_value = node.in_port(0).data.get_value()
        then_value = node.in_port(1).data.get_value()
        else_value = node.in_port(2).data.get_value()

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape))

        # Unknown condition: only the shape can be inferred.
        if condition_value is None:
            return

        # Known condition: try to propagate the value.
        if condition_value.size != 1:
            output_value = np.where(condition_value, then_value, else_value)
            # Element-wise comparison against None is intentional here (not an identity check).
            if np.any(output_value == None):  # noqa: E711
                # If any element of output value is None that means that we use the value from
                # 'then' or 'else' tensor which is not defined, so value propagation is impossible.
                output_value = None
        else:
            # Builtin bool() is used instead of np.bool: that alias was deprecated in
            # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
            output_value = then_value if bool(condition_value.item(0)) else else_value

        if output_value is not None:
            node.out_port(0).data.set_value(np.array(output_value))
Esempio n. 2
0
    def infer(node: Node):
        """Infer the output shape and value of a TensorFlow-style Select node.

        The output shape is ``broadcast_shape`` of the shapes read from input
        ports 1 ('then') and 2 ('else').  If the 'condition' input node has no
        valid ``value``, only the shape is set on output port 0.  Otherwise the
        condition must be a single boolean; the value of the selected branch
        (``None`` if that branch has no value) and the output shape are written
        to every node reached by an out edge of this node.

        :param node: the Select node being inferred.
        :raises AssertionError: if the node does not have 3 inputs, the known
            condition is not a single element, or it is not boolean.
        """
        assert len(node.in_nodes()) == 3, \
            "Select operation must have 3 inputs by TensorFlow reference: 'condition', 'then' and 'else' tensors"
        condition_node = node.in_node(0)
        then_node = node.in_node(1)
        else_node = node.in_node(2)

        a_shape = node.in_port(1).data.get_shape()
        b_shape = node.in_port(2).data.get_shape()
        output_shape = broadcast_shape(a_shape, b_shape)

        # Case with unknown condition
        if not condition_node.has_valid('value'):
            # infer only shapes
            node.out_port(0).data.set_shape(output_shape)
            return

        assert condition_node.value.size == 1
        condition_value = condition_node.value.item(0)

        # ndarray.item() returns a builtin Python scalar, so check against builtin bool
        # (np.bool_ is accepted too).  The former `np.bool` alias was removed in NumPy 1.24,
        # and in NumPy 2.x it names the NumPy scalar type, which a builtin bool is not.
        assert isinstance(condition_value, (bool, np.bool_)), \
            "TensorFlow 'Select' operation has 3 inputs: 'condition', 'then' and 'else' tensors. " \
            "Value of 'condition' tensor must be boolen by TensorFlow reference"

        selected_node = then_node if condition_value else else_node
        output_value = selected_node.value
        for _, out_node in node.graph.out_edges(node.id):
            out_attrs = node.graph.node[out_node]
            # A fresh array per output node, so outputs never share one buffer.
            out_attrs['shape'] = np.array(output_shape)
            out_attrs['value'] = None if output_value is None else np.array(output_value)