def forward(self, input_dict, state, seq_lens):
    """Compute (optionally action-masked) policy logits and cache the value output.

    The policy branch of ``self.base_model`` is fed ``obs[self.pi_obs_key]``
    and the value branch ``obs[self.vf_obs_key]``. With ``self.use_lstm`` both
    inputs are time-padded and the model additionally receives ``seq_lens``
    and the incoming recurrent ``state``.

    Side effects: sets ``self._value_out``, ``self.unmasked_policy_logits``
    and ``self.masked_policy_logits``; also ``self.internal_states`` when the
    observation has an ``"internal_state"`` entry and
    ``self.valid_actions_masks`` when ``self.mask_invalid_actions`` is set.

    Returns:
        The (possibly masked) policy logits and the new recurrent state
        (the input ``state`` unchanged when no LSTM is used).
    """
    obs = input_dict["obs"]
    if "internal_state" in obs:
        self.internal_states = obs["internal_state"]
    pi_obs_inputs = obs[self.pi_obs_key]
    vf_obs_inputs = obs[self.vf_obs_key]

    if self.use_lstm:
        policy_out, self._value_out, *state_out = self.base_model([
            add_time_dimension(pi_obs_inputs, seq_lens),
            add_time_dimension(vf_obs_inputs, seq_lens),
            seq_lens,
            *state,
        ])
        # Merge the padded [batch, time] axes back into a single batch axis.
        policy_out = tf.reshape(policy_out, [-1, self.num_outputs])
    else:
        policy_out, self._value_out = self.base_model(
            [pi_obs_inputs, vf_obs_inputs])
        state_out = state

    self.unmasked_policy_logits = policy_out
    if self.mask_invalid_actions:
        # Add log(mask): log(1) = 0 leaves valid actions unchanged, while
        # log(0) = -inf is clipped to the most negative finite float32, so
        # invalid-action logits are pushed to ~float32.min (not set to zero).
        # Assumes a 0/1 float mask - TODO confirm.
        # tf.math.log is used because the TF1-only alias tf.log is absent in TF2.
        self.valid_actions_masks = obs["valid_actions_mask"]
        inf_mask = tf.maximum(
            tf.math.log(self.valid_actions_masks), tf.float32.min)
        self.masked_policy_logits = policy_out + inf_mask
    else:
        self.masked_policy_logits = policy_out
    return self.masked_policy_logits, state_out
def forward(self, input_dict, state, seq_lens):
    """Time-pad both observation components, then run ``forward_rnn()``.

    ``input_dict["obs"]`` is indexed as a two-element sequence: element 0 is
    written back into ``input_dict`` under ``"obs_vision"`` and element 1
    under ``"obs_messages"``, each after ``add_time_dimension``. The
    caller's ``input_dict`` is therefore modified in place before being
    handed to ``forward_rnn()``.

    Returns:
        The RNN output flattened to ``[-1, self.num_outputs]`` and the new
        recurrent state from ``forward_rnn()``.
    """
    observation = input_dict["obs"]
    for target_key, component in (("obs_vision", 0), ("obs_messages", 1)):
        input_dict[target_key] = add_time_dimension(
            observation[component], seq_lens)

    rnn_out, next_state = self.forward_rnn(input_dict, state, seq_lens)
    flat_out = tf.reshape(rnn_out, [-1, self.num_outputs])
    return flat_out, next_state
def test_add_time_dimension(self):
    """Check that add_time_dimension lays consecutive rows out along time.

    A ``[B * T, F]`` int32 input whose row ``i`` is filled with the value
    ``i`` is folded into ``[B, T, F]`` (batch-major) or ``[T, B, F]``
    (time-major). Neighbouring entries along the time axis must then differ
    by exactly 1. Batch-major folding is checked for tf and torch;
    time-major folding only for torch.
    """
    # Three distinct sizes drawn from [8, 32), kept as int32.
    B, T, F = np.random.choice(
        np.asarray(list(range(8, 32)), dtype=np.int32),
        size=3,
        replace=False,
    )
    row_ids = np.arange(B * T)[:, np.newaxis]
    inputs_numpy = np.repeat(row_ids, repeats=F, axis=-1).astype(np.int32)
    check(inputs_numpy.shape, (B * T, F))

    unit_steps_batch_major = np.ones(shape=(B, T - 1, F), dtype=np.int32)
    unit_steps_time_major = np.ones(shape=(T - 1, B, F), dtype=np.int32)

    if tf is not None:
        # tf, batch-major.
        folded = add_time_dimension(
            tf.constant(inputs_numpy),
            max_seq_len=T,
            framework="tf",
            time_major=False,
        )
        check(folded.shape.as_list(), [B, T, F])
        check(folded[:, 1:] - folded[:, :-1], unit_steps_batch_major)

    if torch is not None:
        torch_inputs = torch.from_numpy(inputs_numpy)

        # torch, batch-major.
        folded = add_time_dimension(
            torch_inputs, max_seq_len=T, framework="torch", time_major=False)
        check(folded.shape, (B, T, F))
        check(folded[:, 1:] - folded[:, :-1], unit_steps_batch_major)

        # torch, time-major.
        folded = add_time_dimension(
            torch_inputs, max_seq_len=T, framework="torch", time_major=True)
        check(folded.shape, (T, B, F))
        check(folded[1:] - folded[:-1], unit_steps_time_major)
def forward(self, input_dict, state, seq_lens):
    """Adds time dimension to batch before sending inputs to forward_rnn().

    You should implement forward_rnn() in your subclass. When
    ``self.use_prev_action`` is set, the previous actions are also
    time-padded and passed to ``forward_rnn()`` as a fourth positional
    argument.

    Returns:
        The ``forward_rnn()`` output reshaped to ``[-1, self.num_outputs]``
        and the new recurrent state.
    """
    padded_obs = add_time_dimension(input_dict["obs"], seq_lens)
    if not self.use_prev_action:
        output, new_state = self.forward_rnn(padded_obs, state, seq_lens)
    else:
        # RLlib's standard input key is "prev_actions"; a custom
        # "prev_action" entry is still honoured when present.
        prev_key = ("prev_action"
                    if "prev_action" in input_dict else "prev_actions")
        padded_prev_actions = add_time_dimension(input_dict[prev_key],
                                                 seq_lens)
        output, new_state = self.forward_rnn(padded_obs, state, seq_lens,
                                             padded_prev_actions)
    return tf.reshape(output, [-1, self.num_outputs]), new_state
def forward(self, input_dict, state, seq_lens):
    """Adds time dimension to batch before sending inputs to forward_rnn().

    If ``input_dict["obs"]`` is a dict, its ``"obs"`` entry is used as the
    observation. When ``self.use_prev_action`` is set, the previous actions
    are time-padded too and concatenated onto the observation features
    along the last axis.

    Returns:
        The ``forward_rnn()`` output reshaped to ``[-1, self.num_outputs]``
        and the new recurrent state.
    """
    obs = input_dict["obs"]
    if isinstance(obs, dict):
        obs = obs["obs"]
    padded_obs = add_time_dimension(obs, seq_lens)
    if self.use_prev_action:
        padded_action = add_time_dimension(input_dict["prev_actions"],
                                           seq_lens)
        # Previous actions may be integer-typed (e.g. discrete action
        # spaces) while observations are float; tf.concat requires equal
        # dtypes. Assumes prev_actions has the same rank as obs - confirm.
        padded_action = tf.cast(padded_action, padded_obs.dtype)
        padded_obs = tf.concat([padded_obs, padded_action], axis=-1)
    output, new_state = self.forward_rnn(padded_obs, state, seq_lens)
    return tf.reshape(output, [-1, self.num_outputs]), new_state
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    """Run the RNN over time-padded flat observations, then mask actions.

    You should implement forward_rnn() in your subclass. The number of time
    steps per sequence is derived as ``rows(obs_flat) // len(seq_lens)``.
    After the RNN, ``log(input_dict["obs"]["action_mask"])`` (clipped from
    below at ``tf.float32.min``) is added to the logits.

    Returns:
        Masked logits of shape ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    assert seq_lens is not None
    flat_obs = input_dict["obs_flat"]
    steps_per_seq = tf.shape(flat_obs)[0] // tf.shape(seq_lens)[0]
    rnn_in = add_time_dimension(
        flat_obs, max_seq_len=steps_per_seq, framework="tf")

    logits, new_state = self.forward_rnn(rnn_in, state, seq_lens)
    logits = tf.reshape(logits, [-1, self.num_outputs])

    # log(1) = 0 keeps valid actions; log(0) = -inf is clipped to the most
    # negative finite float32 (assumes a 0/1 mask - confirm).
    log_mask = tf.math.log(input_dict["obs"]["action_mask"])
    return logits + tf.maximum(log_mask, tf.float32.min), new_state
def forward(self, input_dict, state, seq_lens):
    """Encode the observation plus previous reward/action, then run the RNN.

    ``input_dict["obs"]["conv_features"]`` goes through
    ``self.shared_layers`` and is concatenated (along dim 1) with these
    inputs:

    - ``input_dict["obs"]["fc_features"]``;
    - the previous reward, as a ``[B, 1]`` float column;
    - a one-hot encoding of the previous action.

    The result gets a time dimension and is passed to ``forward_rnn()``.

    ``prev_rewards`` / ``prev_actions`` may be tensors or array-likes. They
    are converted locally, and ``input_dict`` is not modified.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    # NOTE(review): hard-coded size of the previous-action one-hot;
    # presumably the env's number of discrete actions - confirm.
    num_actions = 6

    x = self.shared_layers(input_dict["obs"]["conv_features"])
    # Build auxiliary inputs on the encoder output's device. Picking "cuda"
    # merely because it is available breaks models that live on the CPU.
    device = x.device

    last_reward = torch.as_tensor(input_dict["prev_rewards"], device=device)
    last_reward = torch.reshape(last_reward, [-1, 1]).float()

    # .long() truncates toward zero like the former np.int cast (np.int was
    # removed in NumPy 1.24). Flatten so each batch row has one action id.
    prev_actions = torch.as_tensor(input_dict["prev_actions"], device=device)
    prev_actions = prev_actions.reshape(-1).long()
    one_hot_prev_actions = nn.functional.one_hot(
        prev_actions, num_actions).float()

    x = torch.cat(
        (x, input_dict["obs"]["fc_features"], last_reward,
         one_hot_prev_actions),
        dim=1,
    )
    output, new_state = self.forward_rnn(
        add_time_dimension(x.float(), seq_lens, framework="torch"),
        state, seq_lens)
    return torch.reshape(output, [-1, self.num_outputs]), new_state
def forward(self, input_dict, state, seq_lens):
    """Run ``self._base_model`` on one observation entry, optionally via its LSTM.

    Side effect: stores ``input_dict["obs"]["valid_actions_mask"]`` on
    ``self.valid_actions_masks``.

    Without ``self.use_lstm`` the base model output and the incoming
    ``state`` are returned as-is. With it, the observation is time-padded,
    the base model also receives ``seq_lens`` and ``state``, and its first
    output is reshaped to ``[-1, *self._base_model_out_shape]`` while the
    remaining outputs form the new state.
    """
    obs = input_dict["obs"]
    obs_inputs = obs[self._obs_key]
    self.valid_actions_masks = obs["valid_actions_mask"]

    if not self.use_lstm:
        return self._base_model([obs_inputs]), state

    padded_obs = add_time_dimension(obs_inputs, seq_lens)
    base_model_out, *state_out = self._base_model(
        [padded_obs, seq_lens, *state])
    flat_out = tf.reshape(base_model_out,
                          [-1, *self._base_model_out_shape])
    return flat_out, state_out
def forward(
        self,
        input_dict: Dict[str, torch.Tensor],
        state: List[torch.Tensor],
        seq_lens: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """CNN encoder followed by ``self.lstm`` over time-padded features.

    Args:
        input_dict: ``input_dict['obs']`` is cast to float, permuted with
            ``(0, 3, 1, 2)`` (presumably channels-last to channels-first -
            confirm) and divided by 255.
        state: Two tensors given to ``self.lstm`` as its initial state
            after adding a leading axis of size 1.
        seq_lens: Sequence lengths used by ``add_time_dimension``.

    Returns:
        ``self._forward_helper`` applied to the LSTM output flattened to
        ``[-1, self.lstm_cell_size]``, and the two final state tensors with
        the leading axis removed.
    """
    images = input_dict['obs'].float().permute(0, 3, 1, 2) / 255.0
    features = self.cnn(images)
    features_flat = features.view(features.shape[0], -1)
    # pylint: disable=too-many-function-args,missing-kwoa
    sequences = add_time_dimension(features_flat, seq_lens, "torch")

    # pylint: disable=no-member
    first_state = torch.unsqueeze(state[0], 0)
    second_state = torch.unsqueeze(state[1], 0)
    lstm_out, lstm_state = self.lstm(sequences, (first_state, second_state))
    # pylint: disable=no-member
    lstm_out = torch.reshape(lstm_out, [-1, self.lstm_cell_size])

    next_state = [
        torch.squeeze(lstm_state[0], 0),
        torch.squeeze(lstm_state[1], 0),
    ]
    return self._forward_helper(lstm_out), next_state
def forward(self, input_dict, state, seq_lens):
    """Time-pad float observations and call ``forward_rnn()`` with prev actions.

    A numpy ``seq_lens`` is converted to an int torch tensor first. Note
    that ``forward_rnn()`` receives ``input_dict["prev_actions"]`` as its
    second argument without a time dimension added here.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    if isinstance(seq_lens, np.ndarray):
        seq_lens = torch.Tensor(seq_lens).int()

    padded_obs = add_time_dimension(
        input_dict["obs"].float(), seq_lens, framework="torch")
    rnn_out, next_state = self.forward_rnn(
        padded_obs, input_dict["prev_actions"], state, seq_lens)
    return torch.reshape(rnn_out, [-1, self.num_outputs]), next_state
def _build_layers_v2(self, input_dict, num_outputs, options):
    """Build the (deprecated) TF1 LSTM graph and return ``(logits, features)``.

    The features are ``input_dict["obs"]``, optionally concatenated with the
    flattened previous actions and the previous reward when
    ``options["lstm_use_prev_action_reward"]`` is truthy. They are
    time-padded with ``self.seq_lens`` and fed to an ``LSTMCell`` of size
    ``options["lstm_cell_size"]`` via ``dynamic_rnn``.

    Side effects: sets ``self.state_init`` (zero numpy arrays),
    ``self.state_in`` (existing tensors or new ``c``/``h`` placeholders)
    and ``self.state_out``.
    """
    # Hard deprecate this class. All Models should use the ModelV2
    # API from here on.
    deprecation_warning("Model->LSTM", "RecurrentNetwork", error=False)
    cell_size = options.get("lstm_cell_size")
    if options.get("lstm_use_prev_action_reward"):
        # Flattened size of one previous-action entry (1 for scalar
        # actions). np.prod replaces np.product, removed in NumPy 2.0.
        action_dim = int(
            np.prod(input_dict["prev_actions"].get_shape().as_list()[1:]))
        features = tf.concat(
            [
                input_dict["obs"],
                tf.reshape(
                    tf.cast(input_dict["prev_actions"], tf.float32),
                    [-1, action_dim]),
                tf.reshape(input_dict["prev_rewards"], [-1, 1]),
            ],
            axis=1)
    else:
        features = input_dict["obs"]
    last_layer = add_time_dimension(features, self.seq_lens)

    # Setup the LSTM cell
    lstm = tf1.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
    self.state_init = [
        np.zeros(lstm.state_size.c, np.float32),
        np.zeros(lstm.state_size.h, np.float32)
    ]

    # Setup LSTM inputs: reuse existing state tensors, else create
    # placeholders.
    if self.state_in:
        c_in, h_in = self.state_in
    else:
        c_in = tf1.placeholder(
            tf.float32, [None, lstm.state_size.c], name="c")
        h_in = tf1.placeholder(
            tf.float32, [None, lstm.state_size.h], name="h")
    self.state_in = [c_in, h_in]

    # Setup LSTM outputs
    state_in = tf1.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
    lstm_out, lstm_state = tf1.nn.dynamic_rnn(
        lstm,
        last_layer,
        initial_state=state_in,
        sequence_length=self.seq_lens,
        time_major=False,
        dtype=tf.float32)
    self.state_out = list(lstm_state)

    # Compute outputs: merge [batch, time] back into one axis.
    last_layer = tf.reshape(lstm_out, [-1, cell_size])
    logits = linear(last_layer, num_outputs, "action",
                    normc_initializer(0.01))
    return logits, last_layer
def forward(self, input_dict, state, seq_lens):
    """Time-pad ``obs_flat`` (tf) and run ``forward_rnn()`` on it.

    You should implement forward_rnn() in your subclass.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    padded = add_time_dimension(
        input_dict["obs_flat"], seq_lens, framework="tf")
    rnn_out, next_state = self.forward_rnn(padded, state, seq_lens)
    flat_out = tf.reshape(rnn_out, [-1, self.num_outputs])
    return flat_out, next_state
def forward(self, input_dict, state, seq_lens):
    """Time-pad ``obs_flat`` as float (torch) and run ``forward_rnn()`` on it.

    You should implement forward_rnn() in your subclass. A numpy
    ``seq_lens`` is converted to an int torch tensor first.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    if isinstance(seq_lens, np.ndarray):
        seq_lens = torch.Tensor(seq_lens).int()

    padded = add_time_dimension(
        input_dict["obs_flat"].float(), seq_lens, framework="torch")
    rnn_out, next_state = self.forward_rnn(padded, state, seq_lens)
    return torch.reshape(rnn_out, [-1, self.num_outputs]), next_state
def forward(self, input_dict, state, seq_lens):
    """Time-pad ``obs_flat`` using an inferred sequence length, then run the RNN.

    You should implement forward_rnn() in your subclass. The number of time
    steps per sequence is ``rows(obs_flat) // len(seq_lens)``.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    assert seq_lens is not None
    flat_obs = input_dict["obs_flat"]
    steps_per_seq = tf.shape(flat_obs)[0] // tf.shape(seq_lens)[0]
    rnn_in = add_time_dimension(
        flat_obs, max_seq_len=steps_per_seq, framework="tf")
    rnn_out, next_state = self.forward_rnn(rnn_in, state, seq_lens)
    return tf.reshape(rnn_out, [-1, self.num_outputs]), next_state
def forward(self, input_dict, state, seq_lens):
    """Run the shared conv and fc encoders, then the RNN over padded features.

    ``conv_features`` go through ``self.shared_conv_layers`` and are
    concatenated (dim 1) with ``fc_features`` before
    ``self.shared_fc_layers``. The float result is time-padded and passed
    to ``forward_rnn()``.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    obs = input_dict["obs"]
    conv_out = self.shared_conv_layers(obs["conv_features"])
    fc_in = torch.cat((conv_out, obs["fc_features"]), dim=1)
    encoded = self.shared_fc_layers(fc_in)

    padded = add_time_dimension(encoded.float(), seq_lens, framework="torch")
    rnn_out, next_state = self.forward_rnn(padded, state, seq_lens)
    return torch.reshape(rnn_out, [-1, self.num_outputs]), next_state
def forward(self, input_dict, state, seq_lens):
    """
    Encode the current observation, add a time dimension and run the RNN.
    :param input_dict: The input tensors; ``input_dict["obs"]["curr_obs"]``
        is passed through ``self.encoder_model``.
    :param state: The model state.
    :param seq_lens: LSTM sequence lengths.
    :return: The policy logits (``[-1, self.num_outputs]``) and new state.
    """
    encoded = self.encoder_model(input_dict["obs"]["curr_obs"])
    rnn_inputs = {"curr_obs": add_time_dimension(encoded, seq_lens)}
    rnn_out, next_state = self.forward_rnn(rnn_inputs, state, seq_lens)
    return tf.reshape(rnn_out, [-1, self.num_outputs]), next_state
def call(
    self, input_dict: SampleBatch
) -> (TensorType, List[TensorType], Dict[str, TensorType]):
    """Run the wrapped Keras model, then the LSTM head on time-folded features.

    Steps:
      1. Feed ``input_dict`` to ``self.wrapped_keras_model`` and keep the
         first of its three outputs as features.
      2. Optionally append the previous action (one-hot encoded for
         ``Discrete``/``MultiDiscrete`` spaces) and/or the previous reward,
         both cast to float32, along axis 1.
      3. Fold rows into ``[B, T, ...]`` with ``T = rows // len(seq_lens)``
         and run ``self._rnn_model`` with the sequence lengths and
         ``state_in_0`` / ``state_in_1``.

    Returns:
        ``(model_out, [h, c], {VF_PREDS: value_out})`` with the time axis
        of ``model_out`` merged into the batch axis and the value
        predictions flattened to 1-D.
    """
    assert input_dict.get(SampleBatch.SEQ_LENS) is not None

    features, _, _ = self.wrapped_keras_model(input_dict)

    extra_inputs = []
    if self.lstm_use_prev_action:
        prev_actions = input_dict[SampleBatch.PREV_ACTIONS]
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            prev_actions = one_hot(prev_actions, self.action_space)
        prev_actions = tf.cast(prev_actions, tf.float32)
        extra_inputs.append(tf.reshape(prev_actions, [-1, self.action_dim]))
    if self.lstm_use_prev_reward:
        prev_rewards = tf.cast(input_dict[SampleBatch.PREV_REWARDS],
                               tf.float32)
        extra_inputs.append(tf.reshape(prev_rewards, [-1, 1]))
    if extra_inputs:
        features = tf.concat([features, *extra_inputs], axis=1)

    seq_lens = input_dict[SampleBatch.SEQ_LENS]
    max_seq_len = tf.shape(features)[0] // tf.shape(seq_lens)[0]
    time_folded = add_time_dimension(
        features, max_seq_len=max_seq_len, framework="tf")

    model_out, value_out, h, c = self._rnn_model([
        time_folded,
        seq_lens,
        input_dict["state_in_0"],
        input_dict["state_in_1"],
    ])

    # Merge [B, T] back into one batch axis, keeping any trailing dims.
    flat_shape = tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
    return (
        tf.reshape(model_out, flat_shape),
        [h, c],
        {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
    )
def forward(self, input_dict, state, seq_lens):
    """Time-pad float ``obs_flat`` (torch) with an inferred length, run the RNN.

    You should implement forward_rnn() in your subclass. A numpy
    ``seq_lens`` is converted to an int torch tensor. Steps per sequence
    are ``rows // len(seq_lens)``.

    Side effect: caches ``self.model_config.get("_time_major", False)`` on
    ``self.time_major`` and uses it for the padding layout.

    Returns:
        The RNN output reshaped to ``[-1, self.num_outputs]`` and the new
        recurrent state.
    """
    flat_obs = input_dict["obs_flat"].float()
    if isinstance(seq_lens, np.ndarray):
        seq_lens = torch.Tensor(seq_lens).int()
    steps_per_seq = flat_obs.shape[0] // seq_lens.shape[0]

    self.time_major = self.model_config.get("_time_major", False)
    rnn_in = add_time_dimension(
        flat_obs,
        max_seq_len=steps_per_seq,
        framework="torch",
        time_major=self.time_major,
    )
    rnn_out, next_state = self.forward_rnn(rnn_in, state, seq_lens)
    return torch.reshape(rnn_out, [-1, self.num_outputs]), next_state
def forward(self, input_dict, state, seq_lens):
    """Run the RNN on ``obs["state"]`` and mask logits with ``obs["action_mask"]``.

    A numpy ``seq_lens`` is converted to an int torch tensor. Steps per
    sequence are ``rows // len(seq_lens)``. Side effect: caches
    ``self.model_config.get("_time_major", False)`` on ``self.time_major``.

    Returns:
        Logits of shape ``[-1, self.num_outputs]`` plus
        ``clamp(log(action_mask), FLOAT_MIN, FLOAT_MAX)``, and the new
        recurrent state.
    """
    obs = input_dict["obs"]
    flat_obs = obs["state"].float()
    if isinstance(seq_lens, np.ndarray):
        seq_lens = torch.Tensor(seq_lens).int()
    steps_per_seq = flat_obs.shape[0] // seq_lens.shape[0]

    self.time_major = self.model_config.get("_time_major", False)
    rnn_in = add_time_dimension(
        flat_obs,
        max_seq_len=steps_per_seq,
        framework="torch",
        time_major=self.time_major,
    )
    logits, next_state = self.forward_rnn(rnn_in, state, seq_lens)
    logits = torch.reshape(logits, [-1, self.num_outputs])

    # log(1) = 0 keeps valid actions; log(0) = -inf is clamped to FLOAT_MIN
    # (assumes a 0/1 mask - confirm).
    log_mask = torch.clamp(torch.log(obs["action_mask"]), FLOAT_MIN,
                           FLOAT_MAX)
    return logits + log_mask, next_state
def forward(self, input_dict, state, seq_lens):
    """
    Run the non-recurrent encoder, then time-pad the RNN inputs and call
    forward_rnn(), which evaluates the recurrent part of the model.
    :param input_dict: The input tensors.
    :param state: The model state; ``state[5]`` is fed to the RNN as
        ``prev_moa_trunk`` and ``state[4]`` is passed to
        ``compute_influence_reward``.
    :param seq_lens: LSTM sequence lengths.
    :return: The agent's own action logits and the new model state, which
        is extended with the action logits and the MOA fc output.
    """
    obs = input_dict["obs"]
    ac_trunk, moa_fc_output = self.moa_encoder_model(obs["curr_obs"])

    unpadded_inputs = {
        "ac_trunk": ac_trunk,
        "prev_moa_trunk": state[5],
        "other_agent_actions": obs["other_agent_actions"],
        "visible_agents": obs["visible_agents"],
        "prev_actions": input_dict["prev_actions"],
    }
    rnn_input_dict = {
        name: add_time_dimension(tensor, seq_lens)
        for name, tensor in unpadded_inputs.items()
    }

    output, new_state = self.forward_rnn(rnn_input_dict, state, seq_lens)
    action_logits = tf.reshape(output, [-1, self.num_outputs])

    # self._counterfactuals is read after forward_rnn(); keep its last two
    # dimensions and flatten everything before them.
    cf = self._counterfactuals
    counterfactuals = tf.reshape(cf, [-1, cf.shape[-2], cf.shape[-1]])

    new_state.extend([action_logits, moa_fc_output])
    self.compute_influence_reward(input_dict, state[4], counterfactuals)
    return action_logits, new_state
def _build_layers_v2(self, input_dict, num_outputs, options):
    """Build a TF1 LSTM graph (cell size 3) that also records its inputs.

    Besides the usual ``(logits, features)`` outputs, a stateful
    ``tf.py_func`` runs ``spy`` whenever the output ops are evaluated. For
    each call whose padded input holds more than one sequence, ``spy``
    pickles the padded inputs, input/output states and ``seq_lens`` into
    Ray's internal KV store under ``rnn_spy_in_<capture_index>``.

    Side effects: resets ``RNNSpyModel.capture_index`` to 0 and sets
    ``self.state_init``, ``self.state_in`` and ``self.state_out``.
    """
    # Previously, a new class object was created during
    # deserialization and this `capture_index`
    # variable would be refreshed between class instantiations.
    # This behavior is no longer the case, so we manually refresh
    # the variable.
    RNNSpyModel.capture_index = 0

    def spy(sequences, state_in, state_out, seq_lens):
        # Invoked through tf.py_func below; must return an int64-compatible
        # value (the declared Tout).
        if len(sequences) == 1:
            return 0  # don't capture inference inputs
        # TF runs this function in an isolated context, so we have to use
        # redis to communicate back to our suite
        ray.experimental.internal_kv._internal_kv_put(
            "rnn_spy_in_{}".format(RNNSpyModel.capture_index),
            pickle.dumps({
                "sequences": sequences,
                "state_in": state_in,
                "state_out": state_out,
                "seq_lens": seq_lens
            }),
            overwrite=True)
        RNNSpyModel.capture_index += 1
        return 0

    features = input_dict["obs"]
    cell_size = 3
    last_layer = add_time_dimension(features, self.seq_lens)

    # Setup the LSTM cell
    lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
    self.state_init = [
        np.zeros(lstm.state_size.c, np.float32),
        np.zeros(lstm.state_size.h, np.float32)
    ]

    # Setup LSTM inputs: reuse existing state tensors, else create
    # placeholders.
    if self.state_in:
        c_in, h_in = self.state_in
    else:
        c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c], name="c")
        h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h], name="h")
    self.state_in = [c_in, h_in]

    # Setup LSTM outputs
    state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
    lstm_out, lstm_state = tf.nn.dynamic_rnn(lstm,
                                             last_layer,
                                             initial_state=state_in,
                                             sequence_length=self.seq_lens,
                                             time_major=False,
                                             dtype=tf.float32)
    self.state_out = list(lstm_state)

    # Stateful py_func so TF does not constant-fold or dedupe the capture.
    spy_fn = tf.py_func(spy, [
        last_layer,
        self.state_in,
        self.state_out,
        self.seq_lens,
    ],
                        tf.int64,
                        stateful=True)

    # Compute outputs. The control dependency forces spy_fn to run whenever
    # these output ops are evaluated.
    with tf.control_dependencies([spy_fn]):
        last_layer = tf.reshape(lstm_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
    return logits, last_layer
def forward(self, input_dict, state, seq_lens):
    """Add a time dimension to the observation and run the recurrent model.

    ``self.rnn_model`` receives the time-padded ``obs["obs"]`` and the message
    input (zeros unless ``self.use_comm``). It also gets the first action
    column as int32, the previous rewards, ``seq_lens`` and ``state[:2]``.
    With ``self.use_receiver_bias`` the model is run a second time with a
    zeroed message and ``state[2:4]``. That run's outputs are stored on
    ``self.no_message_outputs``, and its states ``[2:4]`` are appended to the
    returned state.

    With ``self.use_inference_policy``, the last ``self.message_size``
    logits are treated as message logits and have a prior term subtracted.
    For ``"moving_avg"`` the prior is ``logit(self._avg_message_p)``; for
    ``"hyper_nn"`` it is ``self._pm_logits``.

    Side effects: sets ``self._value_out``, ``self._unscaled_message_p``,
    ``self._model_out`` and, when enabled, ``self._cpc_ins`` /
    ``self._cpc_preds`` / ``self.no_message_outputs``.

    Raises:
        NotImplementedError: if ``self.pm_type`` is not a supported
            inference-policy type.
    """
    prev_actions = tf.cast(input_dict["prev_actions"][:, 0], dtype=tf.int32)
    prev_rewards = input_dict["prev_rewards"]
    lstm_state = state[:2]
    if self.use_receiver_bias:
        receiver_bias_state = state[2:4]

    obs_dict = input_dict["obs"]
    inputs = add_time_dimension(obs_dict["obs"], seq_lens)
    if self.use_comm:
        extra_inputs = obs_dict["message"]
    else:
        extra_inputs = tf.zeros_like(obs_dict["message"])

    outputs = self.rnn_model(
        [inputs, extra_inputs, prev_actions, prev_rewards, seq_lens] +
        lstm_state)
    if self.use_receiver_bias:
        # Counterfactual pass with the message zeroed out.
        extra_inputs = tf.zeros_like(obs_dict["message"])
        self.no_message_outputs = self.rnn_model(
            [inputs, extra_inputs, prev_actions, prev_rewards, seq_lens] +
            receiver_bias_state)

    if self.use_cpc:
        (
            model_out,
            self._value_out,
            h,
            c,
            self._cpc_ins,
            self._cpc_preds,
            *self._unscaled_message_p,
        ) = outputs
    else:
        model_out, self._value_out, h, c, *self._unscaled_message_p = outputs

    next_states = [h, c]
    if self.use_receiver_bias:
        next_states.extend(self.no_message_outputs[2:4])

    if self.use_inference_policy:
        if self.pm_type == "moving_avg":
            # logit(p) = log(p) - log(1 - p). tf.math.log is used because the
            # TF1-only alias tf.log does not exist in TF2.
            prior_message_logits = (tf.math.log(self._avg_message_p) -
                                    tf.math.log(1 - self._avg_message_p))
        elif self.pm_type == "hyper_nn":
            prior_message_logits = self._pm_logits
        else:
            raise NotImplementedError("Wrong type for inference_policy")
        # The trailing `message_size` logits are message logits.
        action_logits = model_out[..., :-self.message_size]
        unscaled_message_logits = model_out[..., -self.message_size:]
        scaled_message_logits = unscaled_message_logits - prior_message_logits
        model_out = tf.keras.layers.Concatenate()(
            [action_logits, scaled_message_logits])

    self._model_out = tf.reshape(model_out, [-1, self.num_outputs])
    return self._model_out, next_states