def _get_shape(tensor):
    """Returns the shape of `tensor`, preferring a statically known value.

    Args:
      tensor: The tensor whose shape is requested.

    Returns:
      The constant-folded shape value when `contrib.get_static_value` can
      compute it, otherwise the dynamic shape tensor from `tf.shape`.
    """
    dynamic_shape = tf.shape(input=tensor)
    # NOTE(review): assumes contrib.get_static_value mirrors
    # tf.get_static_value, i.e. returns None when the value cannot be
    # determined without running the graph — confirm.
    known_shape = contrib.get_static_value(dynamic_shape)
    return dynamic_shape if known_shape is None else known_shape
def _used_weight(weights_list):
    """Returns the static value of the first non-None entry in `weights_list`.

    Args:
      weights_list: An iterable of candidate weights; `None` entries are
        skipped.

    Returns:
      The result of `contrib.get_static_value` on the first non-None weight
      (after `tf.convert_to_tensor`), or None if the iterable is empty or
      contains only None entries.
    """
    chosen = next((candidate for candidate in weights_list
                   if candidate is not None), None)
    if chosen is None:
        return None
    # NOTE(review): may itself be None if the weight is not constant-foldable,
    # assuming get_static_value semantics match tf.get_static_value.
    return contrib.get_static_value(tf.convert_to_tensor(value=chosen))
def _static_or_dynamic_batch_size(tensor, batch_axis):
    """Returns the static or dynamic batch size.

    Args:
      tensor: The tensor whose batch dimension is inspected.
      batch_axis: Index of the batch dimension in `tensor`'s shape.

    Returns:
      The statically known size of dimension `batch_axis` when
      `contrib.get_static_value` can compute it, otherwise the dynamic
      scalar tensor `tf.shape(tensor)[batch_axis]`.
    """
    batch_size = tf.shape(input=tensor)[batch_axis]
    static_batch_size = contrib.get_static_value(batch_size)
    # Compare against None explicitly: a statically known batch size of 0 is
    # falsy, so `static_batch_size or batch_size` would wrongly fall back to
    # the dynamic tensor even though the size is known.
    if static_batch_size is None:
        return batch_size
    return static_batch_size