def __init__(self, session, action_dist, obs_vectorizer, make_cell):
    """Build the recurrent actor-critic graph around a user-supplied RNN cell.

    Args:
        session: forwarded unchanged to the parent constructor.
        action_dist: forwarded unchanged to the parent constructor.
        obs_vectorizer: forwarded to the parent constructor; its ``out_shape``
            supplies the trailing dimensions of the observation placeholder.
        make_cell: zero-argument callable returning the RNN cell; it is
            invoked inside the ``'cell'`` variable scope.
    """
    super(RNNCellAC, self).__init__(session, action_dist, obs_vectorizer)

    # Two leading dynamic dimensions (presumably batch and time) followed by
    # the vectorized observation shape.
    self.obs_ph = tf.placeholder(tf.float32, (None, None) + obs_vectorizer.out_shape)
    self.mask_ph = tf.placeholder(tf.float32, (None, None))

    with tf.variable_scope('cell_input'):
        cell_input = self.cell_input_sequences()

    with tf.variable_scope('cell'):
        cell = make_cell()

    with tf.variable_scope('states'):
        # Nested (tuple) state sizes are registered as a flat tuple of sizes.
        state_fields = cell.state_size
        if isinstance(state_fields, tuple):
            state_fields = tuple(nest.flatten(state_fields))
        self.create_state_fields(tf.float32, state_fields)

    init_state = mix_init_states(self.is_init_state_ph, self.init_state_vars,
                                 self.first_state_ph)

    with tf.variable_scope('base'):
        # Restore the cell's nested state layout before running the RNN.
        if isinstance(init_state, tuple):
            init_state = nest.pack_sequence_as(cell.state_size, init_state)
        self.base_out, final_states = tf.nn.dynamic_rnn(
            cell,
            cell_input,
            sequence_length=self.seq_lens_ph,
            initial_state=init_state)

    # Expose the final state flat, mirroring how the state fields were created.
    self.states_out = (tuple(nest.flatten(final_states))
                       if isinstance(final_states, tuple) else final_states)

    with tf.variable_scope('actor'):
        self.actor_out = self.actor(self.base_out)
    with tf.variable_scope('critic'):
        self.critic_out = self.critic(self.base_out)
def _apply_rnn_encoder_output_layer(output_layer, time_major, hparams, mode,
                                    cell_outputs, cell_output_size):
    """Apply the encoder output layer(s) to each component of the RNN output.

    ``cell_outputs`` and ``cell_output_size`` are flattened in parallel. Each
    (output, size) pair goes through ``_forward_output_layers``, which must
    return a ``(new_output, new_size)`` pair.

    Returns:
        ``(outputs, output_size)``. Both are repacked into the structure of
        ``cell_outputs``.
    """
    apply_layer = functools.partial(
        _forward_output_layers,
        output_layer=output_layer,
        time_major=time_major,
        hparams=hparams,
        mode=mode)

    results = []
    for component, component_size in zip(nest.flatten(cell_outputs),
                                          nest.flatten(cell_output_size)):
        results.append(apply_layer(inputs=component, input_size=component_size))

    new_outputs, new_sizes = zip(*results)
    return (nest.pack_sequence_as(cell_outputs, new_outputs),
            nest.pack_sequence_as(cell_outputs, new_sizes))
def debatch_timestep(self, ts):
    """Split one batched timestep into a list of per-example timesteps.

    Every non-None leaf of ``ts`` (up to ``self._traj_spec``) is cut along its
    leading axis, and that axis is squeezed away. None leaves stay None in
    every output.

    Returns:
        A list of length ``bs`` (the shared leading size of the non-None
        leaves). Each entry is packed in the layout of ``self._traj_spec``.
    """
    spec = self._traj_spec

    def _unbatch(leaf):
        # Leave missing fields untouched; otherwise one squeezed slice per row.
        if leaf is None:
            return None
        pieces = np.split(leaf, len(leaf))
        return [np.squeeze(piece, axis=0) for piece in pieces]

    flat = nest.flatten_up_to(spec, nest.map_structure_up_to(spec, _unbatch, ts))

    sizes = [len(leaf) for leaf in flat if leaf is not None]
    batch_size = sizes[0]
    assert all(size == batch_size for size in sizes)

    return [
        nest.pack_sequence_as(
            spec, [None if leaf is None else leaf[idx] for leaf in flat])
        for idx in range(batch_size)
    ]
def split_batch(template, tf_structure):
    """Split every leaf of ``tf_structure`` into ``self.batch_size`` pieces.

    The leaves are taken up to ``template``, and each one is cut with
    ``tf.split``. The k-th pieces of all leaves are then repacked into
    ``template``.

    NOTE(review): ``self`` is not a parameter here. This presumably runs as a
    closure inside a method that defines ``self``; confirm.

    Returns:
        A list with one packed structure per split index.
    """
    leaves = nest.flatten_up_to(template, tf_structure)
    per_leaf_splits = [tf.split(leaf, self.batch_size) for leaf in leaves]
    return [nest.pack_sequence_as(template, pieces)
            for pieces in zip(*per_leaf_splits)]
def combine_flat_list(_structure, _flat_list, axis=1):
    """Concatenate matching entries of several flat lists and repack them.

    ``_flat_list`` is a list of flat lists. Entry ``i`` of every inner list is
    joined with ``tf.concat`` along ``axis``. The number of entries comes from
    ``_flat_list[0]``. The concatenated tensors are then packed into
    ``_structure``.
    """
    width = len(_flat_list[0])
    merged = [
        tf.concat([flat[idx] for flat in _flat_list], axis)
        for idx in range(width)
    ]
    return nest.pack_sequence_as(_structure, merged)
def sgdstore_model(in_seqs):
    """
    Apply an sgdstore model to the sequences.

    A 64-unit BasicRNNCell controller is stacked with an sgdstore memory cell.
    The memory cell wraps a 4 -> 32 -> 4 fully connected stack. The initial
    state is a set of learnable variables tiled to ``BATCH_SIZE``, and a final
    linear layer maps each step's output to a single value.
    """
    controller = tf.contrib.rnn.BasicRNNCell(64)
    memory_net = sgdstore.Stack(sgdstore.FC(4, 32), sgdstore.FC(32, 4))
    memory_cell = sgdstore.Cell(memory_net, train_batch=4, query_batch=4)
    cell = tf.contrib.rnn.MultiRNNCell([controller, memory_cell])

    # Learn a batch-size-1 initial state, then repeat it along axis 0.
    start_state = (controller.zero_state(1, tf.float32),
                   memory_cell.random_state(1, tf.float32))
    learned_vars = [tf.Variable(part) for part in nest.flatten(start_state)]
    tiled = []
    for var in learned_vars:
        rank = len(var.get_shape())
        tiled.append(tf.tile(var, multiples=[BATCH_SIZE] + [1] * (rank - 1)))
    batched_state = nest.pack_sequence_as(cell.state_size, tiled)

    outputs, _ = tf.nn.dynamic_rnn(cell, in_seqs, initial_state=batched_state)
    return tf.contrib.layers.fully_connected(outputs, 1, activation_fn=None)
def debatch_and_stack(self):
    """Split every stored trajectory by batch element and stack per element.

    For each trajectory in ``self._trajs``, the non-None leaves (up to
    ``self._traj_spec``) are cut along their leading axis, and that axis is
    squeezed away. Slice ``k`` is appended to element ``k``'s timestep list.
    The number of element lists is fixed by the first trajectory's batch
    size. Each list is finally passed to ``Trajectory._stack``.

    Returns:
        A list with one stacked result per batch element.
    """
    spec = self._traj_spec

    def _unbatch(leaf):
        # Leave missing fields untouched; otherwise one squeezed slice per row.
        if leaf is None:
            return None
        return [np.squeeze(piece, axis=0) for piece in np.split(leaf, len(leaf))]

    per_element = []
    for traj in self._trajs:
        flat = nest.flatten_up_to(
            spec, nest.map_structure_up_to(spec, _unbatch, traj))

        sizes = [len(leaf) for leaf in flat if leaf is not None]
        batch_size = sizes[0]
        assert all(size == batch_size for size in sizes)

        if not per_element:
            per_element = [[] for _ in range(batch_size)]
        for idx in range(batch_size):
            per_element[idx].append(nest.pack_sequence_as(
                spec, [None if leaf is None else leaf[idx] for leaf in flat]))

    stack = functools.partial(Trajectory._stack, traj_spec=self._traj_spec)
    return [stack(steps) for steps in per_element]
def axial_reshape(x, ix):
    """Transpose ``x`` and merge groups of its axes as described by ``ix``.

    Each element of ``ix`` is an axis index or a sequence of axis indices. A
    scalar ``e`` is treated as ``[e]``. Every group becomes one output axis.
    When the flattened order of ``ix`` is not already consecutive ascending,
    ``x`` is first transposed into that order. The size of each output axis is
    then ``reduce(merge_dim, group_sizes, 1)``.

    Args:
        x: Tensor. Its *static* shape (``get_shape().as_list()``) drives the
            reshape. Unknown dimensions reach ``merge_dim`` as ``None``;
            presumably ``merge_dim`` handles that — TODO confirm.
        ix: Sequence of ints / int sequences. Its largest index must equal
            ``rank(x) - 1``.

    Returns:
        The transposed and reshaped tensor.
    """
    with tf.name_scope('axial_reshape', [x, ix]):
        # Normalize scalar entries to single-axis groups.
        ix = [([e] if np.isscalar(e) else e) for e in ix]
        # Static input shape. (A previous dynamic `tf_shape(x)` lookup was
        # dead code, since it was immediately overwritten, and is removed.)
        s = x.get_shape().as_list()
        ix_f = nest.flatten(ix)
        assert (len(s) - 1) == np.max(ix_f), (
            'axial_reshape: largest axis in ix (%s) must equal rank-1 (%d)'
            % (np.max(ix_f), len(s) - 1))
        # Transpose only if the requested axis order is not already 0..n-1.
        if not np.all(np.diff(ix_f) == 1):
            x = tf.transpose(x, ix_f)
        # Group the (permuted) sizes like ix and merge each group.
        tm = nest.pack_sequence_as(ix, [s[i] for i in ix_f])
        s_out = [reduce(merge_dim, e, 1) for e in tm]
        x = tf.reshape(x, s_out)
        return x
def while_fn(*args):
    """Run one step of the loop and record the new persistent values.

    ``args`` is ``(iteration, persistent, transient, tensor_arrays)``. The
    step input is ``inputs[iteration]`` when ``time_major``, and otherwise
    ``inputs[:, iteration]``. ``loop_fn`` produces the next persistent and
    transient values. Each flattened persistent value is written at
    ``iteration`` into its matching TensorArray.

    Returns:
        ``(iteration + 1, new_persistent, new_transient, new_tensor_arrays)``.
    """
    step, persistent, transient, arrays = args[0], args[1], args[2], args[3]

    step_input = inputs[step] if time_major else inputs[:, step]
    next_persistent, next_transient = loop_fn(step_input, persistent, transient)

    written = [
        array.write(step, value)
        for array, value in zip(nest.flatten(arrays), nest.flatten(next_persistent))
    ]
    next_arrays = nest.pack_sequence_as(arrays, written)
    return step + 1, next_persistent, next_transient, next_arrays
def make_structure(self, flat_input):
    """Pack ``flat_input`` into the nested layout of ``self.template_spec``."""
    template = self.template_spec
    return nest.pack_sequence_as(template, flat_input)
def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                enc_cell,
                                dec_cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
    """Embedding sequence-to-sequence model with attention.

    This model first embeds encoder_inputs by a newly created embedding (of
    shape [num_encoder_symbols x input_size]). Then it runs an RNN to encode
    embedded encoder_inputs into a state vector. It keeps the outputs of this
    RNN at every step to use for attention later. Next, it embeds
    decoder_inputs by another newly created embedding (of shape
    [num_decoder_symbols x input_size]). Then it runs attention decoder,
    initialized with the last encoder state, on embedded decoder_inputs and
    attending to encoder outputs.

    Separate cells are used for the encoder (`enc_cell`) and the decoder
    (`dec_cell`). The decoder is initialized with the final encoder state.
    When `feed_previous` is a Tensor, the decoder's final state is also
    repacked using the encoder state's structure. So `dec_cell`'s state
    structure is expected to match `enc_cell`'s.

    Warning: when output_projection is None, the size of the attention vectors
    and variables will be made proportional to num_decoder_symbols, can be
    large.

    Args:
      encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
      decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
      enc_cell: tf.nn.rnn_cell.RNNCell used by the encoder; it is wrapped in
        an EmbeddingWrapper.
      dec_cell: tf.nn.rnn_cell.RNNCell used by the attention decoder; it is
        wrapped in an OutputProjectionWrapper when output_projection is None.
      num_encoder_symbols: Integer; number of symbols on the encoder side.
      num_decoder_symbols: Integer; number of symbols on the decoder side.
      embedding_size: Integer, the length of the embedding vector for each
        symbol.
      num_heads: Number of attention heads that read from attention_states.
      output_projection: None or a pair (W, B) of output projection weights and
        biases; W has shape [output_size x num_decoder_symbols] and B has
        shape [num_decoder_symbols]; if provided and feed_previous=True, each
        fed previous output will first be multiplied by W and added B.
      feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
        of decoder_inputs will be used (the "GO" symbol), and all other decoder
        inputs will be taken from previous outputs (as in
        embedding_rnn_decoder). If False, decoder_inputs are used as given (the
        standard decoder case).
      dtype: The dtype of the initial RNN state (default: tf.float32).
      scope: VariableScope for the created subgraph; defaults to
        "embedding_attention_seq2seq".
      initial_state_attention: If False (default), initial attentions are zero.
        If True, initialize the attentions from the initial state and attention
        states.

    Returns:
      A tuple of the form (outputs, state), where:
        outputs: A list of the same length as decoder_inputs of 2D Tensors with
          shape [batch_size x num_decoder_symbols] containing the generated
          outputs.
        state: The state of the decoder cell at the final time-step.
          It is a 2D Tensor of shape [batch_size x dec_cell.state_size].
    """
    with tf.variable_scope(scope or "embedding_attention_seq2seq",
                           dtype=dtype) as scope:
        dtype = scope.dtype
        # Encoder.
        encoder_cell = rnn.EmbeddingWrapper(
            enc_cell,
            embedding_classes=num_encoder_symbols,
            embedding_size=embedding_size)
        encoder_outputs, encoder_state = rnn.static_rnn(
            encoder_cell, encoder_inputs, dtype=dtype)

        # First calculate a concatenation of encoder outputs to put attention on.
        top_states = [
            tf.reshape(e, [-1, 1, encoder_cell.output_size])
            for e in encoder_outputs
        ]
        attention_states = tf.concat(top_states, 1)

        # Decoder.
        output_size = None
        if output_projection is None:
            dec_cell = rnn.OutputProjectionWrapper(dec_cell, num_decoder_symbols)
            output_size = num_decoder_symbols

        if isinstance(feed_previous, bool):
            return embedding_attention_decoder(
                decoder_inputs,
                encoder_state,
                attention_states,
                dec_cell,
                num_decoder_symbols,
                embedding_size,
                num_heads=num_heads,
                output_size=output_size,
                output_projection=output_projection,
                feed_previous=feed_previous,
                initial_state_attention=initial_state_attention)

        # If feed_previous is a Tensor, we construct 2 graphs and use cond.
        def decoder(feed_previous_bool):
            # The True branch builds the decoder variables (reuse=None); the
            # False branch reuses them so both graphs share weights.
            reuse = None if feed_previous_bool else True
            with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
                outputs, state = embedding_attention_decoder(
                    decoder_inputs,
                    encoder_state,
                    attention_states,
                    dec_cell,
                    num_decoder_symbols,
                    embedding_size,
                    num_heads=num_heads,
                    output_size=output_size,
                    output_projection=output_projection,
                    feed_previous=feed_previous_bool,
                    update_embedding_for_previous=False,
                    initial_state_attention=initial_state_attention)
                # tf.cond needs a flat list, so append the flattened state.
                state_list = [state]
                if nest.is_sequence(state):
                    state_list = nest.flatten(state)
                return outputs + state_list

        outputs_and_state = tf.cond(feed_previous,
                                    lambda: decoder(True),
                                    lambda: decoder(False))
        outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
        state_list = outputs_and_state[outputs_len:]
        state = state_list[0]
        if nest.is_sequence(encoder_state):
            state = nest.pack_sequence_as(structure=encoder_state,
                                          flat_sequence=state_list)
        return outputs_and_state[:outputs_len], state
def dynamic_deconv(encoder, embd, embd_T, max_time):
    """Greedily decode a token sequence with a `tf.while_loop` over `deconv`.

    The loop starts from the "SOS" token. Each iteration:
      * embeds every token emitted so far (via `embd`);
      * runs `deconv` against `encoder`;
      * scores the result with `wordclf` (using `embd_T`);
      * appends the argmax token.
    The loop stops once `output_time` reaches `max_time`.

    Args:
        encoder: Tensor whose dynamic shape is unpacked into
            (batch_size, max_len, encoder_dim), so it must be rank 3.
        embd: Embedding matrix for `tf.nn.embedding_lookup`; its static first
            dimension must be known (it is passed through `int(...)`).
        embd_T: Forwarded to `wordclf` (presumably the transposed embedding
            for weight tying -- TODO confirm).
        max_time: Number of decoding steps (bounds `output_time`).

    Returns:
        `(emit_input, emit_ta, emit_score)`:
          emit_input: generated token ids, transposed to [batch, steps] with
            the leading "SOS" column dropped.
          emit_ta: per-step logits, stacked and transposed with `[1, 0, 2]`
            (presumably [batch, steps, vocab]).
          emit_score: per-step `attn_s` from `deconv`, transposed with
            `[1, 0, 2, 3]` and reshaped to `[batch_size, -1, max_len]`.
    """
    # Dynamic shape of `encoder`; encoder_dim is unpacked but not used below.
    batch_size, max_len, encoder_dim = tf.unstack(tf.shape(encoder))
    # Per-step logits buffer.
    # NOTE(review): the structure argument is a plain int, so this
    # pack_sequence_as call appears to just return the single TensorArray --
    # confirm and simplify.
    emit_ta = nest.pack_sequence_as(int(embd.shape[0]),
                                    [tensor_array_ops.TensorArray(tf.float32, clear_after_read=False, size=0,
                                                                  dynamic_size=True, element_shape=tensor_shape.\
                                                                  TensorShape([None,len(idx2word)]))])
    # Emitted token ids, one [None]-shaped int32 entry per step (same
    # pack_sequence_as note as above).
    emit_input = nest.pack_sequence_as(int(1),
                                       [tensor_array_ops.TensorArray(tf.int32, clear_after_read=False, size=0,
                                                                     dynamic_size=True, element_shape=tensor_shape.\
                                                                     TensorShape([None]))])
    # Per-step attention scores returned by `deconv`.
    emit_score = tensor_array_ops.TensorArray(tf.float32, clear_after_read=False, size=0,
                                              dynamic_size=True, element_shape=tensor_shape.\
                                              TensorShape([None,None,1]))
    # `time` indexes emit_input (it also counts the SOS write); `output_time`
    # indexes emit_ta / emit_score and drives the loop condition.
    time = tf.constant(0, dtype=tf.int32)
    output_time = tf.constant(0, dtype=tf.int32)

    def initialize(batch_size, time, emit_input):
        # Write the "SOS" id, tiled across the batch, as the first token.
        for w in ["SOS"]:
            idx = tf.reshape(tf.constant(word2idx[w], dtype=tf.int32), [-1])
            idx = tf.tile(idx, [batch_size])
            emit_input = nest.map_structure(lambda ta, em: ta.write(time, em), emit_input, idx)
            time += 1
        return emit_input, time

    emit_input, time = initialize(batch_size, time, emit_input)

    def body(output_time, time, emit_input, emit_ta, emit_score):
        # Re-embed the whole prefix emitted so far: stack gives
        # [steps, batch], transposed here to [batch, steps].
        # inputs_idx = tf.transpose(emit_input.gather([time-1]),[1,0])
        inputs_idx = tf.transpose(emit_input.stack(), [1, 0])
        # NOTE(review): leftover debug print; in graph mode it presumably
        # fires once when the loop body is built, not per iteration.
        print("input idx", inputs_idx)
        inputs_vec = tf.nn.embedding_lookup(embd, inputs_idx)
        output_vec, attn_s = deconv(inputs_vec, encoder, max_len, dim=256, reuse_flag=False)
        output_logits = wordclf(output_vec, 300, embd_T, reuse_flag=False)
        # Greedy choice of the next token id.
        next_idx = tf.argmax(output_logits, axis=-1, output_type=tf.int32)
        emit_input = nest.map_structure(lambda ta, em: ta.write(time, em), emit_input, next_idx)
        time += 1
        emit_ta = emit_ta.write(output_time, output_logits)
        emit_score = emit_score.write(output_time, attn_s)
        output_time += 1
        return output_time, time, emit_input, emit_ta, emit_score

    def condition(t, *_):
        # `t` is output_time (the first loop var): run max_time steps.
        return t < max_time

    _, _, emit_input, emit_ta, emit_score = tf.while_loop(
        condition, body,
        loop_vars=[output_time, time, emit_input, emit_ta, emit_score],
        swap_memory=False)
    # Make results batch-major and drop the SOS column from the token ids.
    emit_input = tf.transpose(emit_input.stack(), [1, 0])[:, 1::]
    emit_ta = tf.transpose(emit_ta.stack(), [1, 0, 2])
    emit_score = tf.transpose(emit_score.stack(), [1, 0, 2, 3])
    emit_score = tf.reshape(emit_score, [batch_size, -1, max_len])
    return emit_input, emit_ta, emit_score