def _ct_ht_shared_variable_finder(self, context, i):
    """Find the state variable behind a concat op shared by the "ct" and "ht" ops.

    Args:
        context: match context; ``context.cell_match[i]`` is the matched cell
            that gets inspected.
        i: index into ``context.cell_match``.

    Returns:
        Whatever ``self._find_state_variable_with_select`` returns for the
        shared concat output and the two slice ops, or None when the cell type
        is ``RNNUnitType.LSTMBlockCell`` or when the expected concat / slice
        ops cannot be identified unambiguously.
    """
    if self.lstm_cell_type == RNNUnitType.LSTMBlockCell:
        return None

    cell = context.cell_match[i]
    ct_output = cell.get_op("ct").output[0]
    ht_output = cell.get_op("ht").output[0]

    # Both outputs must be consumed by exactly one concat op, and it has to be
    # the very same op for the two of them.
    ct_concats = list(filter(is_tf_concat_op, self.g.find_output_consumers(ct_output)))
    ht_concats = list(filter(is_tf_concat_op, self.g.find_output_consumers(ht_output)))
    one_concat_each = len(ct_concats) == 1 and len(ht_concats) == 1
    if not one_concat_each or ct_concats[0] != ht_concats[0]:
        logger.debug("failed to find ct-ht concat")
        return None
    shared_output = ct_concats[0].output[0]

    # Exactly one slice op is expected among the inputs of each matched consumer.
    ct_consumer = cell.get_op("ct_identity_consumer")
    ht_consumer = cell.get_op("xh")
    ct_slices = list(filter(is_tf_slice_op, ct_consumer.inputs))
    ht_slices = list(filter(is_tf_slice_op, ht_consumer.inputs))
    if len(ct_slices) != 1 or len(ht_slices) != 1:
        logger.debug("failed to find slice op before identity consumers")
        return None

    return self._find_state_variable_with_select(
        context,
        shared_output,
        [ct_slices[0], ht_slices[0]],
    )
def infer_output_shapes_with_partial_inputs(op):
    """Infer and set the output shape of `op` from whichever input shapes are known.

    Supported ops are concat ops (as recognised by ``utils.is_tf_concat_op``),
    "Select"/"SelectV2", "Pack" and "Pow". For "Select"/"SelectV2" and "Pack"
    the inferred shape is also written back onto the op's inputs.

    Args:
        op: the op to process. Its ``inputs`` and ``outputs`` are tensors
            exposing ``set_shape``.

    Returns:
        True when an output shape was set, False when the op type is supported
        but no usable input shape is known, and None when the op type is not
        handled by this function.
    """
    # output shape of concat op: only the dim val of concatenated dim will be changed
    # so only partial(at least one) input shapes need to be known to infer output shape of concat op
    if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        all_shapes = [get_tf_tensor_shape(inp) for inp in data_inputs]
        known_shapes = [shape for shape in all_shapes if shape is not None]
        if not known_shapes:
            logger.debug("all input shapes of concat op %s are None, can't infer its output shape", op.name)
            return False

        # Work on a copy: the concat dim is overwritten below, and the original
        # first shape is still needed to read the real value of that dim.
        new_shape = list(known_shapes[0])
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            # The concat axis is not a constant, so no dim can be pinned down.
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(known_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = [shape[axis] for shape in known_shapes]
            # only when inputs' shape are known, then val of concat dim can be calculated
            # (an unknown dim may be reported as -1 or None; treat both as unknown)
            if all(val is not None and val >= 0 for val in concat_dim_vals):
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type in ["Select", "SelectV2"]:
        # Use the shape of the first value input, falling back to the second one.
        new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[2])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.inputs[2].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        return False

    if op.type == "Pack":
        axis = op.get_attr("axis")
        # The first input with a known shape serves as the shape of every input.
        input_shape = next(
            (shape for shape in map(get_tf_tensor_shape, op.inputs) if shape is not None),
            None,
        )
        if input_shape is None:
            return False
        if axis < 0:
            # The packed output has one more dim than each input, so a negative
            # axis is relative to rank + 1 (axis == -1 appends the new dim).
            axis += len(input_shape) + 1
        for inp in op.inputs:
            # Only fill in shapes that are really unknown; a known scalar shape
            # ([]) is falsy but must not be treated as missing.
            if get_tf_tensor_shape(inp) is None:
                inp.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", inp.name, input_shape)
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        return False

    return None