def infer(node: Node):
    """Infer the output of a Broadcast node (``numpy`` / ``bidirectional`` modes).

    The data input comes from port 0 and the constant target shape from port 1.
    The output shape is always written. The output value is also written when
    the data input carries a value and the node does not have
    ``stop_value_propagation`` set.

    Raises:
        AssertionError: if the target shape value is unknown or ``mode`` is unset.
        Error: if ``mode`` is neither ``'numpy'`` nor ``'bidirectional'``.
    """
    name = node.soft_get('name', node.id)

    data_port = node.in_port(0).data
    input_shape = data_port.get_shape()
    input_value = data_port.get_value()
    target_shape = node.in_port(1).data.get_value()

    assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(
        name)
    assert node.has_and_set(
        'mode'), 'Broadcasting mode is not defined for node "{}"'.format(
        name)

    # Select the shape- and value-level broadcasting helpers for this mode.
    if node.mode == 'numpy':
        shape_fn, value_fn = uni_directional_shape_broadcasting, uni_directional_broadcasting
    elif node.mode == 'bidirectional':
        shape_fn, value_fn = bi_directional_shape_broadcasting, bi_directional_broadcasting
    else:
        raise Error('The node "{}" has unsupported mode "{}"'.format(
            name, node.mode))

    node.out_port(0).data.set_shape(shape_fn(input_shape, target_shape))

    # Port 1 holds a shape, so it must be permuted together with the output layout.
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    propagate_value = input_value is not None and not node.has_and_set(
        'stop_value_propagation')
    if propagate_value:
        node.out_port(0).data.set_value(value_fn(input_value, target_shape))
def test_uni_directional_shape_broadcasting(self, input_shape, target_shape, expected_shape):
    """Check ``uni_directional_shape_broadcasting`` against ``expected_shape``.

    An ``expected_shape`` of ``None`` means the helper must return ``None``.
    """
    broadcasted_shape = uni_directional_shape_broadcasting(input_shape, target_shape)
    if expected_shape is not None:
        self.assertTrue(strict_compare_tensors(broadcasted_shape, expected_shape))
    else:
        self.assertIsNone(broadcasted_shape)
def infer(node: Node):
    """Infer the output of a Broadcast node (``numpy``, ``bidirectional`` or ``explicit`` mode).

    Inputs: port 0 is the data, port 1 is the constant target shape, and (for
    ``explicit`` mode only) port 2 is the constant axes mapping.

    If the data input has a value and the node does not have
    ``stop_value_propagation`` set, the broadcast value is written to the output.
    Otherwise only the output shape is written.

    Raises:
        AssertionError: if the target shape is unknown, ``mode`` is unset, or
            ``mode == 'explicit'`` and the axes mapping value is unknown.
        Error: if ``mode`` is not one of the supported modes.
    """
    node_name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    input_value = node.in_port(0).data.get_value()
    target_shape = node.in_port(1).data.get_value()
    assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
    assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)

    # Port 1 holds a shape, so it must be permuted together with the output layout.
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    mode = node.mode
    if mode not in ('numpy', 'bidirectional', 'explicit'):
        raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, mode))

    # Explicit mode needs a constant axes mapping. It is fetched, validated and
    # marked for permutation once here, instead of separately in the value and
    # shape paths.
    axes_mapping = None
    if mode == 'explicit':
        axes_mapping = node.in_port(2).data.get_value()
        assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
                                         'is not supported. Node: `{}`'.format(node_name)
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')

    if input_value is not None and not node.has_and_set('stop_value_propagation'):
        # Constant data: compute the broadcast value directly.
        if mode == 'numpy':
            value = uni_directional_broadcasting(input_value, target_shape)
        elif mode == 'bidirectional':
            value = bi_directional_broadcasting(input_value, target_shape)
        else:
            value = explicit_broadcasting(input_value, target_shape, axes_mapping)
        node.out_port(0).data.set_value(value)
    else:
        # Non-constant data (or propagation disabled): infer the shape only.
        if mode == 'numpy':
            new_shape = uni_directional_shape_broadcasting(input_shape, target_shape)
        elif mode == 'bidirectional':
            new_shape = bi_directional_shape_broadcasting(input_shape, target_shape)
        else:
            # explicit_shape_broadcasting returns a pair; only the shape is used here.
            new_shape, _ = explicit_shape_broadcasting(input_shape, target_shape, axes_mapping)
        node.out_port(0).data.set_shape(new_shape)
def test_uni_directional_broadcasting(self, input_shape, target_shape, expected_shape):
    """Check uni-directional broadcasting at both the shape and the value level.

    When ``expected_shape`` is ``None``, value broadcasting must raise an
    exception whose message matches "cannot be uni-directionally broadcasted".
    Otherwise the result must equal ``np.broadcast_to`` of a random input.
    """
    inferred_shape = uni_directional_shape_broadcasting(input_shape, target_shape)
    self.assertTrue(np.array_equal(inferred_shape, expected_shape))

    # np.array() keeps a 0-d ndarray when input_shape is empty (rand() then returns a float).
    input_value = np.array(np.random.rand(*input_shape))

    if expected_shape is None:
        with self.assertRaisesRegex(
                Exception, '.*cannot be uni-directionally broadcasted.*'):
            uni_directional_broadcasting(input_value, int64_array(target_shape))
        return

    target = int64_array(target_shape)
    reference = np.broadcast_to(input_value, target)
    actual = uni_directional_broadcasting(input_value, target)
    self.assertTrue(np.array_equal(actual, reference))