def forward(self, input_dict, state, seq_lens):
    """Adds time dimension to batch before sending inputs to forward_rnn().

    You should implement forward_rnn() in your subclass."""
    output, new_state = self.forward_rnn(
        add_time_dimension(input_dict["obs_flat"], seq_lens), state, seq_lens)
    return tf.reshape(output, [-1, self.num_outputs]), new_state
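# A minimal sketch of the subclass contract described in the docstring above:
# forward_rnn() receives the batch with the time dimension restored
# ([B, T, D]) plus the state list, and returns (outputs, new_state). The LSTM
# layer and get_initial_state() below are illustrative assumptions (e.g. a
# tf.keras.layers.LSTM built in __init__ with return_sequences=True and
# return_state=True), not code from the snippets in this section; np and tf
# are assumed imported as elsewhere.
def forward_rnn(self, inputs, state, seq_lens):
    # inputs: [B, T, D] produced by add_time_dimension() in forward().
    h_in, c_in = state
    output, h_out, c_out = self.lstm(inputs, initial_state=[h_in, c_in])
    return output, [h_out, c_out]

def get_initial_state(self):
    # Zero-filled hidden and cell states; self.cell_size is assumed to be
    # set when the LSTM layer is built.
    return [
        np.zeros(self.cell_size, np.float32),
        np.zeros(self.cell_size, np.float32),
    ]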
def _build(self, input_dict, num_outputs, options):
    # Recurrent state: a single GRU hidden state when use_lstm is enabled.
    if options.get("use_lstm"):
        cell_size = options.get("lstm_cell_size")
        self.state_init = (np.zeros([cell_size], np.float32),)
        self.state_in = (
            tf.placeholder(tf.float32, [None, cell_size], name="state_in"),)
    else:
        self.state_init = ()
        self.state_in = ()

    # Embed observations, previous actions, and previous rewards into a
    # single input vector.
    if self._imitation:
        obs_embed = ssbm_spaces.slippi_conv_list[0].embed(input_dict["obs"])
        prev_actions = input_dict["prev_actions"]
    else:
        obs_embed = input_dict["obs"]
        prev_actions = tf.unstack(input_dict["prev_actions"], axis=-1)

    action_config = nest.flatten(ssbm_actions.repeated_simple_controller_config)
    prev_actions_embed = tf.concat([
        conv.embed(action)
        for conv, action in zip(action_config, prev_actions)], -1)

    prev_rewards_embed = tf.expand_dims(input_dict["prev_rewards"], -1)

    inputs = tf.concat([obs_embed, prev_actions_embed, prev_rewards_embed], -1)

    trunk = snt.nets.MLP(
        output_sizes=options["fcnet_hiddens"],
        activation=getattr(tf.nn, options["fcnet_activation"]),
        activate_final=True)

    # RL batches arrive flat, so fold them back into [B, T, ...]; imitation
    # batches are already time-major sequences.
    if not self._imitation:
        inputs = lstm.add_time_dimension(inputs, self.seq_lens)

    trunk_outputs = snt.BatchApply(trunk)(inputs)

    if options.get("use_lstm"):
        gru = snt.GRU(cell_size)
        core_outputs, state_out = tf.nn.dynamic_rnn(
            gru,
            trunk_outputs,
            initial_state=self.state_in[0],
            sequence_length=None if self._imitation else self.seq_lens,
            time_major=self._imitation)
        self.state_out = [state_out]
    else:
        core_outputs = trunk_outputs
        self.state_out = []

    self._logit_head = snt.Linear(num_outputs, name="logit_head")
    logits = snt.BatchApply(self._logit_head)(core_outputs)
    # self._value_head is assumed to be constructed elsewhere (e.g. in
    # __init__).
    self.values = tf.squeeze(snt.BatchApply(self._value_head)(core_outputs), -1)
    return logits, core_outputs
def _build_layers_v2(self, input_dict, num_outputs, options):
    options = options["custom_options"]
    rnn_units = options["rnn_units"]
    rnn_output_units = options["rnn_output_units"]
    mlp_hidden_units = options["mlp_hidden_units"]
    vf_share_layers = options["vf_share_layers"]
    use_linear_baseline = options["linear_baseline"]
    rnn_output_activation = options.get("rnn_output_activation", "tanh")
    mlp_activation = options.get("mlp_activation", "tanh")

    with tf.name_scope("rnn_inputs"):
        rnn_inputs = tf.concat([
            input_dict["obs_buffer"],
            input_dict["logits_buffer"],
            tf.expand_dims(input_dict["reward_buffer"], axis=1)
        ], axis=1)
    with tf.name_scope("add_time_dimension"):
        rnn_inputs = add_time_dimension(rnn_inputs, self.seq_lens)

    mlp_inputs = input_dict["obs"]

    model_inputs = {
        "rnn": rnn_inputs,
        "seq_lens": self.seq_lens,
        "mlp": mlp_inputs
    }

    model = KerasTESPWithAdapPolicy(
        rnn_units,
        rnn_output_units,
        rnn_output_activation,
        mlp_hidden_units + [num_outputs],
        mlp_activation,
        vf_share_layers=vf_share_layers,
        custom_params=self.custom_params["policy"],
        use_linear_baseline=use_linear_baseline)
    self.policy_model = model

    if use_linear_baseline:
        output = model(model_inputs)
    else:
        output, self._value_function = model(model_inputs)

    # last_layer = model.mlp.layers[-1]

    # self.rnn_state_in = self.keras_model.initial_state
    # if not nest.is_sequence(self.rnn_state_in):
    #     self.rnn_state_in = [self.rnn_state_in]
    self.rnn_state_out = nest.flatten(self.policy_model.state)
    self.rnn_state_out_init = [
        np.zeros(state.shape.as_list(), dtype=state.dtype.as_numpy_dtype)
        for state in self.rnn_state_out
    ]
    return output, None
def call(self, inputs, seqlens=None, initial_state=None):
    features = inputs['obs']
    # Optional convolutional front end before the recurrent core.
    if self._use_conv:
        features = self.conv_layer(features)
    # Fold the flat batch into [B, T, ...] for the RNN, then flatten the
    # outputs back to [B * T, units] for the output head.
    features = add_time_dimension(features, seqlens)
    self.features = features
    latent, *rnn_state = self.rnn(features, initial_state=initial_state)
    self.latent = latent
    latent = tf.reshape(latent, [-1, latent.shape[-1]])
    state_out = list(rnn_state)
    # latent = self.dense_layer(latent)
    logits = self.output_layer(latent)
    output = {'latent': latent, 'logits': logits, 'state_out': state_out}
    self.output_tensors = output
    return output
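# Every snippet in this section funnels its inputs through
# add_time_dimension() before the RNN. As a rough illustration of the effect
# (not the library implementation), assuming the flat batch is already padded
# so that each of the len(seq_lens) sequences occupies the same number of
# rows:
def add_time_dimension_sketch(padded_inputs, seq_lens):
    # [B * max_T, ...] -> [B, max_T, ...]
    num_seqs = tf.shape(seq_lens)[0]
    max_seq_len = tf.shape(padded_inputs)[0] // num_seqs
    new_shape = tf.concat(
        [[num_seqs, max_seq_len], tf.shape(padded_inputs)[1:]], axis=0)
    return tf.reshape(padded_inputs, new_shape)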
def _build_layers_v2(self, input_dict, num_outputs, options):
    def spy(sequences, state_in, state_out, seq_lens):
        if len(sequences) == 1:
            return 0  # don't capture inference inputs
        # TF runs this function in an isolated context, so we have to use
        # redis to communicate back to our suite
        ray.experimental.internal_kv._internal_kv_put(
            "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
            pickle.dumps({
                "sequences": sequences,
                "state_in": state_in,
                "state_out": state_out,
                "seq_lens": seq_lens
            }),
            overwrite=True)
        RNNSpyModel.capture_index += 1
        return 0

    features = input_dict["obs"]
    cell_size = 3
    last_layer = add_time_dimension(features, self.seq_lens)

    # Setup the LSTM cell
    lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
    self.state_init = [
        np.zeros(lstm.state_size.c, np.float32),
        np.zeros(lstm.state_size.h, np.float32)
    ]

    # Setup LSTM inputs
    if self.state_in:
        c_in, h_in = self.state_in
    else:
        c_in = tf.placeholder(
            tf.float32, [None, lstm.state_size.c], name="c")
        h_in = tf.placeholder(
            tf.float32, [None, lstm.state_size.h], name="h")
        self.state_in = [c_in, h_in]

    # Setup LSTM outputs
    state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_out, lstm_state = tf.nn.dynamic_rnn(
        lstm,
        last_layer,
        initial_state=state_in,
        sequence_length=self.seq_lens,
        time_major=False,
        dtype=tf.float32)
    self.state_out = list(lstm_state)

    spy_fn = tf.py_func(
        spy, [
            last_layer,
            self.state_in,
            self.state_out,
            self.seq_lens,
        ],
        tf.int64,
        stateful=True)

    # Compute outputs
    with tf.control_dependencies([spy_fn]):
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
    return logits, last_layer
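# The spy() helper above pushes each training batch into Ray's internal
# key-value store. On the test side the captured payload can be read back
# roughly like this; fetch_spied_batch() is a hypothetical helper, and
# _internal_kv_get() is assumed to be available in the same
# ray.experimental.internal_kv module used above.
def fetch_spied_batch(index):
    # Read back and unpickle the batch that spy() stored under this index.
    payload = ray.experimental.internal_kv._internal_kv_get(
        "rnn_spy_in_{}".format(index))
    return pickle.loads(payload)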