def infer_output_shapes_with_partial_inputs(op):
    """Try to infer output shapes for ops that don't need every input shape.

    Handles Concat, Select/SelectV2, Pack and Pow, each of which can derive an
    output shape from a subset of known input shapes.

    Args:
        op: a tf.Operation whose output shapes should be inferred.

    Returns:
        True if an output shape was set, False if this op was recognized but
        inference failed, None if the op type is not handled here (caller then
        falls through to full-input-shape inference).
    """
    # Concat: only the dim val of the concatenated dim changes, so at least one
    # known input shape is enough to infer the output shape.
    if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]  # last input is the axis tensor
        input_shapes = [get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug(
                "all input shapes of concat op %s are None, can't infer its output shape",
                op.name)
            return False
        # copy so we don't mutate the shape list of the first known input
        # (original code aliased it and then wrote new_shape[axis] = -1)
        new_shape = list(input_shapes[0])
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            # axis unknown: only the rank is trustworthy
            op.outputs[0].set_shape([-1] * rank)
            return True
        axis = get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = list(np.array(input_shapes)[:, axis])
            # the concatenated dim's value can only be computed when no input
            # has an unknown (-1) extent on that axis
            if concat_dim_vals.count(-1) == 0:
                new_shape[axis] = sum(concat_dim_vals)
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type in ["Select", "SelectV2"]:
        # both value branches share one shape; either known shape pins all three
        new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[2])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.inputs[2].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s",
                         op.outputs[0].name, new_shape)
            return True
        return False

    if op.type == "Pack":
        # all packed inputs must share one shape; one known shape pins the rest
        axis = op.get_attr("axis")
        input_shape = None
        for i in op.inputs:
            s = get_tf_tensor_shape(i)
            if s is not None:
                input_shape = s
                break
        if input_shape is None:
            return False
        if axis < 0:
            axis += len(input_shape)
        for i in op.inputs:
            # BUG FIX: original used truthiness (`if not get_tf_tensor_shape(i)`),
            # which treats a *known* scalar shape [] as unknown; test for None.
            if get_tf_tensor_shape(i) is None:
                i.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", i.name, input_shape)
        # Pack inserts a new dim of size len(inputs) at `axis`
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s",
                         op.outputs[0].name, new_shape)
            return True
        return False

    # op type not handled here
    return None
def infer_shape_for_op_legacy(op):
    """Infer shapes for `op` using TF inference first, then per-op fallbacks.

    Order of attempts:
      1. built-in TF shape inference (`infer_shape_for_op`)
      2. infer missing *input* shapes from outputs/siblings (`infer_input_shapes`)
      3. ops that can work from partial input shapes
      4. per-op-type rules that require every input shape to be known

    Args:
        op: a tf.Operation.

    Returns:
        True if any shape was newly inferred, False otherwise.
    """
    # invoke tf shape inference first
    infer_shape_for_op(op)
    has_unknown_input_shape = any(
        get_tf_tensor_shape(inp) is None for inp in op.inputs)
    has_unknown_output_shape = any(
        get_tf_tensor_shape(out) is None for out in op.outputs)

    # an input shape may be inferred from op output or other input shapes,
    # so try to infer it first
    if has_unknown_input_shape:
        if infer_input_shapes(op):
            return True

    if not has_unknown_output_shape:
        return False

    # for those ops, we don't expect all input shapes available to infer output shapes.
    ret = infer_output_shapes_with_partial_inputs(op)
    if ret is not None:
        return ret

    # for the remaining ops, we need all input shapes ready to infer output shapes.
    no_shape = [i.name for i in op.inputs if get_tf_tensor_shape(i) is None]
    if no_shape:
        logger.debug("op %s has inputs don't have shape specified, they are: %s",
                     op.name, no_shape)
        return False

    if op.type in direct_ops:
        # output shape == first input shape
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        # only handled when the shape input comes straight from a Shape op
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # output shape rule per https://www.tensorflow.org/api_docs/python/tf/gather:
        # params.shape[:axis] + indices.shape + params.shape[axis+1:]
        shape_params = get_tf_tensor_shape(op.inputs[0])
        shape_indices = get_tf_tensor_shape(op.inputs[1])
        # gather can only have 2 or 3 inputs (the optional third is the axis)
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = get_tf_const_value(axis_op)
        else:
            axis = 0
        # BUG FIX: tf.gather accepts a negative axis; normalize it before
        # slicing, otherwise the computed shape is wrong (e.g. axis=-1).
        if axis < 0:
            axis += len(shape_params)
        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        # reductions: reduced axes collapse to 1 (keep_dims) or are dropped
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = get_tf_tensor_shape(op.inputs[0])
        for i, _ in enumerate(axis):
            if axis[i] < 0:
                axis[i] += len(shape)
        new_shape = []
        for i, _ in enumerate(shape):
            if i in axis:
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(shape[i])
        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s",
                     op.type, op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False
        dim = get_tf_const_value(dim_op)
        if dim < 0:
            # negative dim counts from the *expanded* rank, hence the +1
            dim = dim + len(input_shape) + 1
        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Unpack":
        input_shape = get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False
        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) - 1",
        # from this statement "num" must equal input_shape[axis], and if not
        # tf will throw a runtime error
        # https://www.tensorflow.org/api_docs/python/tf/unstack
        new_shape = input_shape[:axis] + input_shape[axis + 1:]
        for output in op.outputs:
            output.set_shape(new_shape)
            logger.debug("set %s op [%s] with new shape %s",
                         op.type, output.name, new_shape)
        return True

    if op.type in ["Minimum", "Maximum"]:
        # elementwise ops that support broadcasting.
        # (renamed the comprehension variable: the original shadowed `op`)
        input_shapes = [get_tf_tensor_shape(inp) for inp in op.inputs]
        new_shape = broadcast_shape_inference(*input_shapes)
        op.outputs[0].set_shape(new_shape)
        return True

    return False