def infer(node: Node):
    """Infer the output shape (and, when possible, the value) of a Broadcast node.

    Input 0 is the data to broadcast and input 1 must hold a constant target shape.
    Only the 'numpy' (uni-directional) and 'bidirectional' modes are supported; any
    other mode raises `Error`. The output value is computed only when input 0 has a
    known value and the node is not flagged with `stop_value_propagation`.
    """
    name = node.soft_get('name', node.id)
    data_shape = node.in_port(0).data.get_shape()
    data_value = node.in_port(0).data.get_value()
    requested_shape = node.in_port(1).data.get_value()

    assert requested_shape is not None, 'Output shape is not defined for node "{}"'.format(name)
    assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(name)

    # Pick the shape/value broadcasting pair for the mode up front.
    mode = node.mode
    if mode == 'numpy':
        shape_fn, value_fn = uni_directional_shape_broadcasting, uni_directional_broadcasting
    elif mode == 'bidirectional':
        shape_fn, value_fn = bi_directional_shape_broadcasting, bi_directional_broadcasting
    else:
        raise Error('The node "{}" has unsupported mode "{}"'.format(name, mode))

    node.out_port(0).data.set_shape(shape_fn(data_shape, requested_shape))

    # Input 1 is a shape, so it must follow any layout permutation of the output.
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    if data_value is None or node.has_and_set('stop_value_propagation'):
        return
    node.out_port(0).data.set_value(value_fn(data_value, requested_shape))
def test_bi_directional_shape_broadcasting(self, input_shape, target_shape, expected_shape):
    """Check bi-directional shape broadcasting of `input_shape` with `target_shape`.

    An `expected_shape` of None means the shapes are expected to be non-broadcastable,
    i.e. the helper must return None.
    """
    broadcasted = bi_directional_shape_broadcasting(input_shape, target_shape)
    if expected_shape is not None:
        self.assertTrue(strict_compare_tensors(broadcasted, expected_shape))
        return
    self.assertIsNone(broadcasted)
def infer(node: Node):
    """Infer the output shape and, for a constant condition, the output value of a Select node.

    Inputs: 0 - 'condition', 1 - 'then', 2 - 'else' (all three must be connected).
    The output shape is the bi-directional broadcast of the 'then' and 'else' shapes.
    When the condition value is known, the output value is computed elementwise. It is
    only propagated if every selected element comes from a branch whose value is known.

    Raises:
        AssertionError: if fewer/more than 3 inputs are connected or the 'then'/'else'
            shapes are not broadcastable.
    """
    node_name = node.soft_get('name', node.id)
    assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
        "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

    condition_value = node.in_port(0).data.get_value()
    resulting_tensors = [node.in_port(1).data.get_value(), node.in_port(2).data.get_value()]

    a_shape = node.in_port(1).data.get_shape()
    b_shape = node.in_port(2).data.get_shape()
    output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
    assert output_shape is not None, 'Input shapes for node {} are not broadcast-able'.format(node_name)

    node.out_port(0).data.set_shape(output_shape)

    if condition_value is None:
        return

    # Broadcasting each known branch bi-directionally against the other branch's shape
    # brings both to the common output shape.
    if resulting_tensors[0] is not None:
        resulting_tensors[0] = bi_directional_broadcasting(resulting_tensors[0], b_shape)
    if resulting_tensors[1] is not None:
        resulting_tensors[1] = bi_directional_broadcasting(resulting_tensors[1], a_shape)
    condition_value = bi_directional_broadcasting(condition_value, output_shape)

    output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])

    # An unknown branch is passed to np.ma.where as None, so any None element in the result
    # means an undefined branch was selected and the value cannot be propagated. This check
    # previously ran only for non-scalar conditions, letting an object array containing None
    # through for single-element outputs. The elementwise `== None` is intentional
    # (`is None` would test the array object itself).
    if np.any(output_value == None):  # noqa: E711
        return

    # Restore the numeric dtype of the branch picked by the first condition element;
    # np.ma.where may have produced object dtype when the other branch was None.
    # `bool` replaces the removed NumPy alias `np.bool`.
    selected_tensor = resulting_tensors[not bool(condition_value.item(0))]
    node.out_port(0).data.set_value(output_value.astype(selected_tensor.dtype))
def infer(node: Node):
    """Infer the output of a Broadcast node in 'numpy', 'bidirectional' or 'explicit' mode.

    Input 0 is the data to broadcast, input 1 must hold a constant target shape and,
    in 'explicit' mode, input 2 must hold a constant axes mapping.
    - When input 0 has a known value and the node is not flagged with
      `stop_value_propagation`, the broadcast value is computed.
    - Otherwise only the output shape is inferred.

    Any other mode raises `Error`.
    """
    name = node.soft_get('name', node.id)
    data_shape = node.in_port(0).data.get_shape()
    data_value = node.in_port(0).data.get_value()
    target_shape = node.in_port(1).data.get_value()

    assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(name)
    assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(name)

    # Input 1 is a shape, so it must follow any layout permutation of the output.
    PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')

    def read_axes_mapping():
        # Shared by both the value and the shape path of 'explicit' mode: the mapping must
        # be constant, and it is marked as an axis input for layout permutation.
        mapping = node.in_port(2).data.get_value()
        assert mapping is not None, ('Broadcast(mode="explicit") with dynamic axes_mapping input '
                                     'is not supported. Node: `{}`'.format(name))
        PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
        # The value is read again after registering the permutation, as before.
        return node.in_port(2).data.get_value()

    propagate_value = data_value is not None and not node.has_and_set('stop_value_propagation')
    mode = node.mode

    if propagate_value:
        if mode == 'numpy':
            result = uni_directional_broadcasting(data_value, target_shape)
        elif mode == 'bidirectional':
            result = bi_directional_broadcasting(data_value, target_shape)
        elif mode == 'explicit':
            result = explicit_broadcasting(data_value, target_shape, read_axes_mapping())
        else:
            raise Error('The node "{}" has unsupported mode "{}"'.format(name, mode))
        node.out_port(0).data.set_value(result)
        return

    if mode == 'numpy':
        result_shape = uni_directional_shape_broadcasting(data_shape, target_shape)
    elif mode == 'bidirectional':
        result_shape = bi_directional_shape_broadcasting(data_shape, target_shape)
    elif mode == 'explicit':
        result_shape, _ = explicit_shape_broadcasting(data_shape, target_shape, read_axes_mapping())
    else:
        raise Error('The node "{}" has unsupported mode "{}"'.format(name, mode))
    node.out_port(0).data.set_shape(result_shape)
def infer(node: Node):
    """Infer the output shape of an Einsum node from its `equation` attribute.

    Each connected input i is matched with the i-th input subscript of the equation.
    - Every label maps to the size of the dimension it labels; equal labels must have
      equal sizes.
    - The ellipsis label "..." maps to the sub-shape it covers; these sub-shapes are
      combined across inputs by bi-directional broadcasting.
    - The output shape is the concatenation of the shapes of the output-subscript labels.

    Raises:
        AssertionError: on a missing equation, operand/subscript count mismatch, an input
            rank that does not fit its subscript, incompatible label sizes, or an output
            label that does not appear in any input subscript.
    """
    node_name = node.soft_get('name', node.id)
    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    num_inputs = len(connected_in_ports)
    assert node.has_valid('equation'), "Einsum node {} must contain `equation` attribute".format(node_name)
    equation = node.equation

    # parse the equation and extract input and output subscripts
    input_subscripts, output_subscript = Einsum.parse_equation(node_name, equation)

    # check that each operand has the corresponding input subscript
    assert len(input_subscripts) == num_inputs, "The number of input operands of Einsum node {} " \
                                                "must match the number of input subscripts " \
                                                "in `equation`".format(node_name)

    # check compatibility of dimension sizes with the same label and generate a dictionary of shapes for labels
    label_to_shape = {}
    for input_ind, input_subscript in enumerate(input_subscripts):
        input_shape = node.in_port(input_ind).data.get_shape()
        labels = Einsum.extract_subscript_labels(node_name, input_subscript)
        num_dims = len(input_shape)
        num_labels = len(labels)

        # Presumably extract_subscript_labels allows at most one "..." per subscript (TODO confirm);
        # it then covers every dimension not claimed by a regular label, possibly none.
        if "..." in labels:
            num_broadcasted_dims = num_dims - num_labels + 1
            rank_matches = num_broadcasted_dims >= 0
        else:
            num_broadcasted_dims = 0
            rank_matches = num_dims == num_labels
        # Previously a mismatch silently ignored extra dimensions or labels.
        assert rank_matches, "The rank of input {} of Einsum node {} does not match " \
                             "its subscript in `equation`".format(input_ind, node_name)

        # Walk every label (not "until dimensions run out"), so that a trailing ellipsis
        # covering zero dimensions is still recorded.
        dim_ind = 0
        for label in labels:
            if label == "...":
                sub_shape = input_shape[dim_ind:dim_ind + num_broadcasted_dims]
                if label in label_to_shape:
                    common_shape = bi_directional_shape_broadcasting(sub_shape, label_to_shape[label])
                    assert common_shape is not None, "The dimensions labeled of ellipsis must be broadcastable " \
                                                     "for Einsum node {}".format(node_name)
                    label_to_shape[label] = common_shape
                else:
                    label_to_shape[label] = sub_shape
                dim_ind += num_broadcasted_dims
            else:
                sub_shape = shape_array([input_shape[dim_ind]])
                # NOTE(review): np.array_equal ignores masks, so a dynamic vs. static size
                # of the same label is reported as incompatible — confirm this is intended.
                assert label not in label_to_shape or np.array_equal(label_to_shape[label], sub_shape), \
                    "Sizes of dimensions with the same label of Einsum node {} " \
                    "must be compatible".format(node_name)
                label_to_shape[label] = sub_shape
                dim_ind += 1

    # generate output shape based on the output subscript
    output_shape = shape_array([])
    for label in Einsum.extract_subscript_labels(node_name, output_subscript):
        assert label in label_to_shape, "The label in the output subscript must appear" \
                                        " in input subscripts in equation {} " \
                                        "of Einsum node {}".format(equation, node_name)
        # np.ma.concatenate keeps the masks that mark dynamic dimensions.
        output_shape = np.ma.concatenate((output_shape, label_to_shape[label]))

    node.out_port(0).data.set_shape(output_shape)
def infer(node: Node):
    """Infer the output shape and, when possible, the output value of a Select node.

    Inputs: 0 - 'condition', 1 - 'then', 2 - 'else' (all three must be connected).
    The `auto_broadcast` attribute (default 'numpy') selects the shape rule:
    - 'numpy': condition, then and else must be mutually broadcastable.
    - 'pdpd': not implemented; raises `Error`.
    - anything else: all shapes must be compatible, with no broadcasting.

    For nodes converted from TF (`format == 'tf'`), a 1-D condition is matched against
    the first output dimension.

    Value propagation:
    - If the condition is fully defined and uniform, the selected branch is broadcast
      to the output shape.
    - If the condition is mixed, the value is propagated only when both branches are known.
    """
    node_name = node.soft_get('name', node.id)
    assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
        "Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

    condition_value = node.in_port(0).data.get_value()
    condition_shape = node.in_port(0).data.get_shape()
    resulting_tensors = [
        node.in_port(1).data.get_value(),
        node.in_port(2).data.get_value()
    ]

    a_shape = node.in_port(1).data.get_shape()
    b_shape = node.in_port(2).data.get_shape()
    broadcast_rule = node.soft_get('auto_broadcast', 'numpy')

    if broadcast_rule == 'numpy':
        msg = "In Select node '{}' condition and then/else shapes must be broadcastable. " \
              "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                  node_name, condition_shape, a_shape, b_shape)

        output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
        assert output_shape is not None, msg
        output_is_scalar = len(output_shape) == 0

        # if Select was created from TF Where operations then 1D condition must have the same size
        # as 0-index dimension of output_shape. This condition is different from being numpy compatible
        # but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
        if node.has_valid('format') and node['format'] == 'tf' and len(condition_shape) == 1:
            # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
            msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
                     "must be matching with the first dimension of then/else branches. " \
                     "But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
                         node_name, condition_shape, a_shape, b_shape)

            # A scalar output has no first dimension to compare against; dynamic dimensions
            # cannot be compared until they are known.
            if not output_is_scalar and is_fully_defined(condition_shape[0]) \
                    and is_fully_defined(output_shape[0]):
                assert condition_shape[0] == output_shape[0], msg_tf
            num_trailing_ones = 0 if output_is_scalar else len(output_shape) - 1
            # np.ma.concatenate keeps the dynamic-dimension mask (np.concatenate drops it), and
            # integer ones keep the shape integral (float ones made it float64).
            condition_shape = np.ma.concatenate(
                (condition_shape, np.ones(num_trailing_ones, dtype=np.int64)))

        output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
        assert output_shape is not None, msg

    elif broadcast_rule == 'pdpd':
        # todo: add pdpd broadcasting rule
        # note that additionally to output_shape resulting_tensors must be broadcasted as well
        raise Error("PDPD broadcasting rule is not implemented yet")
    else:
        # broadcasting is not allowed
        assert compatible_shapes(a_shape, b_shape) and compatible_shapes(condition_shape, a_shape), \
            'In node \'{}\' for Select operation when broadcasting is off all inputs must be of the same shape. ' \
            'But instead got: cond_shape={}, then_shape={}, else_shape={}'.format(
                node_name, condition_shape, a_shape, b_shape)
        # Prefer the static dimension from whichever branch has one.
        output_shape = shape_array([i if i is not dynamic_dimension else j for i, j in zip(a_shape, b_shape)])

    node.out_port(0).data.set_shape(output_shape)

    if condition_value is None:
        return

    if is_fully_defined(condition_value) and np.all(condition_value == condition_value.item(0)):
        # in some graphs Select condition is always True[False] and
        # one of the branches is None (which is not selected)
        # if we use np.where for such cases then dtype of output_value will be object (non numeric type)
        # and subsequent numpy operation on such tensors will fail
        # (`bool` replaces the NumPy alias `np.bool`, removed in NumPy 1.24)
        output_value = resulting_tensors[not bool(condition_value.item(0))]
        if output_value is None:
            return
        if broadcast_rule == 'numpy':
            output_value = bi_directional_broadcasting(output_value, output_shape)
        elif broadcast_rule == 'pdpd':
            # todo: add pdpd broadcasting rule
            raise Error("PDPD broadcasting rule is not implemented yet")
        node.out_port(0).data.set_value(output_value)
    elif resulting_tensors[0] is not None and resulting_tensors[1] is not None:
        output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])
        node.out_port(0).data.set_value(output_value)