def output_size(self):
    """Output size of the stacked cells, taken from the last (topmost) cell.

    Prefers the cell's explicitly declared `output_size`; otherwise falls
    back to its state size (the first entry when the state is a structure),
    which is the output dimension for standard RNN cells.
    """
    last_cell = self.cells[-1]
    declared = getattr(last_cell, 'output_size', None)
    if declared is not None:
        return declared
    if rnn_utils.is_multiple_state(last_cell.state_size):
        return last_cell.state_size[0]
    return last_cell.state_size
def build(self, input_shape):
    """Build every cell in the stack.

    Each cell is built with the output shape of the cell below it: the
    incoming `input_shape` feeds the first cell, and after each build the
    cell's inferred output dimension (plus the batch dimension) becomes the
    input shape of the next cell.
    """
    if isinstance(input_shape, list):
        input_shape = input_shape[0]

    def _with_batch(batch, dim):
        # Prepend the batch dimension to a single output/state dim.
        return tuple([batch] + tf.TensorShape(dim).as_list())

    for cell in self.cells:
        # Layer-like cells are built under their own name scope.
        if isinstance(cell, base_layer.Layer) and not cell.built:
            with backend.name_scope(cell.name):
                cell.build(input_shape)
                cell.built = True

        # Resolve this cell's output dimension: an explicit `output_size`
        # wins; otherwise fall back to the (first) state size.
        declared = getattr(cell, 'output_size', None)
        if declared is not None:
            out_dim = declared
        elif rnn_utils.is_multiple_state(cell.state_size):
            out_dim = cell.state_size[0]
        else:
            out_dim = cell.state_size

        batch = tf.nest.flatten(input_shape)[0]
        if tf.nest.is_nested(out_dim):
            # Nested output: rebuild a matching nest of per-entry shapes.
            input_shape = tuple(tf.nest.map_structure(
                functools.partial(_with_batch, batch), out_dim))
        else:
            input_shape = _with_batch(batch, out_dim)
    self.built = True
def compute_output_shape(self, input_shape):
    """Compute the output (and, optionally, state) shapes of the RNN layer.

    Honors `time_major`, `return_sequences` and `return_state`. When
    `return_state` is true, the result is a list of the output shape
    followed by one shape per state entry.
    """
    if isinstance(input_shape, list):
        input_shape = input_shape[0]
    # The input shape may itself contain nested shapes — e.g.
    # (tensor_shape(1, 2), tensor_shape(3, 4)) — or be a plain tuple such as
    # (1, 2, 3) coming from numpy inputs.
    try:
        input_shape = tf.TensorShape(input_shape)
    except (ValueError, TypeError):
        # A nested tensor input: key off the first flattened entry.
        input_shape = tf.nest.flatten(input_shape)[0]

    batch, time_step = input_shape[0], input_shape[1]
    if self.time_major:
        batch, time_step = time_step, batch

    if rnn_utils.is_multiple_state(self.cell.state_size):
        state_size = self.cell.state_size
    else:
        state_size = [self.cell.state_size]

    def _output_shape_for(flat_size):
        # Shape of the output for a single (flat) output dimension.
        dims = tf.TensorShape(flat_size).as_list()
        if not self.return_sequences:
            return tf.TensorShape([batch] + dims)
        if self.time_major:
            return tf.TensorShape([time_step, batch] + dims)
        return tf.TensorShape([batch, time_step] + dims)

    if getattr(self.cell, 'output_size', None) is not None:
        # cell.output_size could be a nested structure.
        flat = tf.nest.flatten(
            tf.nest.map_structure(_output_shape_for, self.cell.output_size))
        output_shape = flat[0] if len(flat) == 1 else flat
    else:
        # Note that state_size[0] could be a tensor_shape or int.
        output_shape = _output_shape_for(state_size[0])

    if not self.return_state:
        return output_shape

    def _state_shape_for(flat_state):
        # Each state entry gets the batch dimension prepended.
        return tf.TensorShape([batch] + tf.TensorShape(flat_state).as_list())

    state_shape = tf.nest.map_structure(_state_shape_for, state_size)
    return generic_utils.to_list(output_shape) + tf.nest.flatten(state_shape)
def build(self, input_shape):
    """Build the RNN layer: set `input_spec`, build the wrapped cell, and
    set or validate `state_spec`.

    `input_shape` may be a single shape, a list (first entry used), or a
    nested structure of shapes for cells with nested inputs.
    """
    if isinstance(input_shape, list):
        input_shape = input_shape[0]
    # The input_shape here could be a nest structure.

    # do the tensor_shape to shapes here. The input could be single tensor,
    # or a nested structure of tensors.
    def get_input_spec(shape):
        """Convert input shape to InputSpec."""
        if isinstance(shape, tf.TensorShape):
            input_spec_shape = shape.as_list()
        else:
            input_spec_shape = list(shape)
        # Which axis is batch vs. time depends on time_major.
        batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
        if not self.stateful:
            # Stateful layers need a fixed batch size; otherwise leave it free.
            input_spec_shape[batch_index] = None
        # The number of timesteps is never constrained by the spec.
        input_spec_shape[time_step_index] = None
        return InputSpec(shape=tuple(input_spec_shape))

    def get_step_input_shape(shape):
        """Shape of the input the cell sees at a single timestep."""
        if isinstance(shape, tf.TensorShape):
            shape = tuple(shape.as_list())
        # remove the timestep from the input_shape
        return shape[1:] if self.time_major else (shape[0], ) + shape[2:]

    def get_state_spec(shape):
        """InputSpec for one state entry, with a free batch dimension."""
        state_spec_shape = tf.TensorShape(shape).as_list()
        # append batch dim
        state_spec_shape = [None] + state_spec_shape
        return InputSpec(shape=tuple(state_spec_shape))

    # Check whether the input shape contains any nested shapes. It could be
    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
    # numpy inputs.
    try:
        input_shape = tf.TensorShape(input_shape)
    except (ValueError, TypeError):
        # A nested tensor input; keep the nested structure as-is.
        pass

    if not tf.nest.is_nested(input_shape):
        # This indicates that there is only one input.
        if self.input_spec is not None:
            self.input_spec[0] = get_input_spec(input_shape)
        else:
            self.input_spec = [get_input_spec(input_shape)]
        step_input_shape = get_step_input_shape(input_shape)
    else:
        # Nested input: build a matching nest of specs / step shapes.
        if self.input_spec is not None:
            self.input_spec[0] = tf.nest.map_structure(
                get_input_spec, input_shape)
        else:
            self.input_spec = generic_utils.to_list(
                tf.nest.map_structure(get_input_spec, input_shape))
        step_input_shape = tf.nest.map_structure(get_step_input_shape,
                                                 input_shape)

    # allow cell (if layer) to build before we set or validate state_spec.
    if isinstance(self.cell, base_layer.Layer) and not self.cell.built:
        with backend.name_scope(self.cell.name):
            self.cell.build(step_input_shape)
            self.cell.built = True

    # set or validate state_spec
    if rnn_utils.is_multiple_state(self.cell.state_size):
        state_size = list(self.cell.state_size)
    else:
        state_size = [self.cell.state_size]

    if self.state_spec is not None:
        # initial_state was passed in call, check compatibility
        self._validate_state_spec(state_size, self.state_spec)
    else:
        if tf.nest.is_nested(state_size):
            self.state_spec = tf.nest.map_structure(
                get_state_spec, state_size)
        else:
            self.state_spec = [
                InputSpec(shape=[None] + tf.TensorShape(dim).as_list())
                for dim in state_size
            ]
        # ensure the generated state_spec is correct.
        self._validate_state_spec(state_size, self.state_spec)
    if self.stateful:
        # Stateful layers allocate their state variables up front.
        self.reset_states()
    self.built = True