def forward(self, state: rlt.ServingFeatureData):
    state_with_presence, _, _ = state
    batch_size, state_dim = state_with_presence[0].size()
    state_first_step = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    ).reshape(batch_size, -1)
    # shape: batch_size, seq_len
    step_probability = F.softmax(self.step_model(state_first_step), dim=1)
    # shape: batch_size, seq_len, num_action
    max_acc_reward = torch.cat(
        [
            get_Q(
                self.model,
                state_first_step,
                self.all_permut[i + 1],
            ).unsqueeze(1)
            for i in range(self.seq_len)
        ],
        dim=1,
    )
    # shape: batch_size, num_action
    max_acc_reward_weighted = torch.sum(
        max_acc_reward * step_probability.unsqueeze(2), dim=1
    )
    return max_acc_reward_weighted
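As a quick check of the final weighting step, here is a tiny worked example with hypothetical numbers (batch_size=1, seq_len=2, num_action=2): each horizon's max accumulated reward is scaled by that horizon's probability and then summed over horizons.

# Worked example of the probability-weighted sum above; the numbers are
# illustrative only and not produced by any real model.
import torch

step_probability = torch.tensor([[0.25, 0.75]])            # batch_size x seq_len
max_acc_reward = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # batch_size x seq_len x num_action
weighted = torch.sum(max_acc_reward * step_probability.unsqueeze(2), dim=1)
# weighted == tensor([[2.5, 3.5]]): 0.25 * 1 + 0.75 * 3 and 0.25 * 2 + 0.75 * 4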
def evaluate(self, eval_batch: rlt.MemoryNetworkInput):
    reward_net_prev_mode = self.reward_net.training
    self.reward_net.eval()
    loss = self.trainer.get_loss(eval_batch)
    detached_loss = loss.cpu().detach().item()

    # shape: batch_size, action_dim
    q_values_all_action_all_data = get_Q(
        self.trainer.seq2reward_network, eval_batch, self.trainer.all_permut
    ).cpu()
    q_values = q_values_all_action_all_data.mean(0).tolist()

    action_distribution = torch.bincount(
        torch.argmax(q_values_all_action_all_data, dim=1),
        minlength=len(self.trainer.params.action_names),
    )
    # normalize
    action_distribution = (
        action_distribution.float() / torch.sum(action_distribution)
    ).tolist()

    # pyre-fixme[16]: `Seq2RewardEvaluator` has no attribute `notify_observers`.
    self.notify_observers(
        mse_loss=loss,
        q_values=[q_values],
        action_distribution=[action_distribution],
    )
    self.reward_net.train(reward_net_prev_mode)
    return (detached_loss, q_values, action_distribution)
def test_get_Q(self):
    NUM_ACTION = 2
    MULTI_STEPS = 3
    BATCH_SIZE = 2
    STATE_DIM = 4
    all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
    seq2reward_network = FakeSeq2RewardNetwork()
    batch = rlt.MemoryNetworkInput(
        state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        next_state=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, STATE_DIM)
        ),
        action=rlt.FeatureData(
            float_features=torch.zeros(MULTI_STEPS, BATCH_SIZE, NUM_ACTION)
        ),
        reward=torch.zeros(1),
        time_diff=torch.zeros(1),
        step=torch.zeros(1),
        not_terminal=torch.zeros(1),
    )
    q_values = get_Q(seq2reward_network, batch, all_permut)
    expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
    logger.info(f"q_values: {q_values}")
    assert torch.all(expected_q_values == q_values)
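For context, `gen_permutations(MULTI_STEPS, NUM_ACTION)` is assumed here to enumerate every length-MULTI_STEPS action sequence as one-hot vectors. A minimal sketch under that assumption is shown below; the dimension ordering is a guess, and the actual ReAgent helper may lay the tensor out differently.

# Sketch of the assumed behaviour of gen_permutations, not ReAgent's
# actual implementation: enumerate every length-seq_len action sequence
# and encode each step as a one-hot vector.
import itertools
import torch
import torch.nn.functional as F

def gen_permutations_sketch(seq_len: int, num_action: int) -> torch.Tensor:
    indices = torch.tensor(
        list(itertools.product(range(num_action), repeat=seq_len))
    )
    # shape: num_action ** seq_len, seq_len, num_action
    return F.one_hot(indices, num_action).float()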
def get_loss(self, training_batch: rlt.MemoryNetworkInput):
    compress_model_output = self.compress_model_network(
        training_batch.state.float_features[0]
    )
    target = get_Q(self.seq2reward_network, training_batch, self.all_permut)
    assert (
        compress_model_output.size() == target.size()
    ), f"{compress_model_output.size()}!={target.size()}"
    mse = F.mse_loss(compress_model_output, target)
    return mse
def test_get_Q(self):
    NUM_ACTION = 2
    MULTI_STEPS = 3
    BATCH_SIZE = 2
    STATE_DIM = 4
    all_permut = gen_permutations(MULTI_STEPS, NUM_ACTION)
    seq2reward_network = FakeSeq2RewardNetwork()
    state = torch.zeros(BATCH_SIZE, STATE_DIM)
    q_values = get_Q(seq2reward_network, state, all_permut)
    expected_q_values = torch.tensor([[11.0, 111.0], [11.0, 111.0]])
    logger.info(f"q_values: {q_values}")
    assert torch.all(expected_q_values == q_values)
def get_loss(self, training_batch: rlt.MemoryNetworkInput):
    # shape: batch_size, num_action
    compress_model_output = self.compress_model_network(
        training_batch.state.float_features[0]
    )
    target = get_Q(self.seq2reward_network, training_batch, self.all_permut)
    assert (
        compress_model_output.size() == target.size()
    ), f"{compress_model_output.size()}!={target.size()}"
    mse = F.mse_loss(compress_model_output, target)

    with torch.no_grad():
        # pyre-fixme[16]: `Tuple` has no attribute `indices`.
        target_action = torch.max(target, dim=1).indices
        model_action = torch.max(compress_model_output, dim=1).indices
        accuracy = torch.mean((target_action == model_action).float())

    return mse, accuracy
def train_and_eval_seq2reward_model(
    training_data, eval_data, learning_rate=0.01, num_epochs=5
):
    SEQ_LEN, batch_size, NUM_ACTION = training_data[0].action.shape
    assert SEQ_LEN == 6 and NUM_ACTION == 2

    seq2reward_network = Seq2RewardNetwork(
        state_dim=NUM_ACTION,
        action_dim=NUM_ACTION,
        num_hiddens=64,
        num_hidden_layers=2,
    )
    trainer_param = Seq2RewardTrainerParameters(
        learning_rate=learning_rate,
        multi_steps=SEQ_LEN,
        action_names=["0", "1"],
        batch_size=batch_size,
        gamma=1.0,
        view_q_value=True,
    )
    trainer = Seq2RewardTrainer(
        seq2reward_network=seq2reward_network, params=trainer_param
    )

    for _ in range(num_epochs):
        for batch in training_data:
            trainer.train(batch)

    total_eval_mse_loss = 0
    for batch in eval_data:
        mse_loss, _ = trainer.get_loss(batch)
        total_eval_mse_loss += mse_loss.cpu().detach().item()
    eval_mse_loss = total_eval_mse_loss / len(eval_data)

    initial_state = torch.Tensor([[0, 0]])
    q_values = torch.squeeze(
        get_Q(
            trainer.seq2reward_network,
            initial_state,
            trainer.all_permut,
        )
    )
    return eval_mse_loss, q_values
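A hypothetical call site for the helper above, assuming `training_data` and `eval_data` are lists of `rlt.MemoryNetworkInput` batches whose action tensors have shape (SEQ_LEN=6, batch_size, NUM_ACTION=2), as the assert requires:

# Hypothetical usage sketch; training_data / eval_data are assumed to be
# prepared elsewhere as lists of rlt.MemoryNetworkInput batches.
eval_mse_loss, q_values = train_and_eval_seq2reward_model(
    training_data, eval_data, learning_rate=0.01, num_epochs=5
)
logger.info(f"eval mse: {eval_mse_loss}, initial-state q_values: {q_values}")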
def get_loss(self, batch: rlt.MemoryNetworkInput):
    state_first_step = CompressModelTrainer.extract_state_first_step(batch)
    # shape: batch_size, num_action
    compress_model_output = self.compress_model_network(state_first_step)
    target = get_Q(
        self.seq2reward_network,
        state_first_step.float_features,
        self.all_permut,
    )
    assert (
        compress_model_output.size() == target.size()
    ), f"{compress_model_output.size()}!={target.size()}"
    mse = F.mse_loss(compress_model_output, target)

    with torch.no_grad():
        target_action = torch.max(target, dim=1).indices
        model_action = torch.max(compress_model_output, dim=1).indices
        accuracy = torch.mean((target_action == model_action).float())

    return mse, accuracy
def forward(self, state: rlt.ServingFeatureData):
    """
    This serving module only takes in the current state. We simulate every
    multi-step action sequence, predict the accumulated reward for each of
    them, group the sequences by their first action, and take the maximum
    reward within each group as the predicted categorical reward for that
    first action.
    Return: categorical reward for the first action
    """
    state_with_presence, _, _ = state
    batch_size, state_dim = state_with_presence[0].size()
    state_first_step = self.state_preprocessor(
        state_with_presence[0], state_with_presence[1]
    ).reshape(batch_size, -1)
    # shape: batch_size, num_action
    max_acc_reward = get_Q(
        self.model,
        state_first_step,
        self.all_permut,
    )
    return max_acc_reward
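The grouping described in the docstring (bucket each simulated action sequence by its first action, then keep the best accumulated reward per bucket) can be sketched as below. `acc_rewards`, `first_actions`, and the function name are illustrative assumptions, not ReAgent's get_Q API.

# Sketch of the max-per-first-action aggregation described above.
# acc_rewards: batch_size x num_permutations predicted accumulated rewards.
# first_actions: num_permutations indices of each sequence's first action.
import torch

def max_reward_per_first_action(
    acc_rewards: torch.Tensor, first_actions: torch.Tensor, num_action: int
) -> torch.Tensor:
    # For each candidate first action, take the best accumulated reward over
    # all sequences starting with it; result shape: batch_size x num_action.
    per_action = [
        acc_rewards[:, first_actions == a].max(dim=1).values
        for a in range(num_action)
    ]
    return torch.stack(per_action, dim=1)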
def eval_seq2reward_model(eval_data, seq2reward_trainer):
    SEQ_LEN, batch_size, NUM_ACTION = next(iter(eval_data)).action.shape

    initial_state = torch.Tensor([[0, 0]])
    initial_state_q_values = torch.squeeze(
        get_Q(
            seq2reward_trainer.seq2reward_network,
            initial_state,
            seq2reward_trainer.all_permut,
        )
    )

    total_mse_loss = 0
    total_q_values = torch.zeros(NUM_ACTION)
    total_action_distribution = torch.zeros(NUM_ACTION)
    for idx, batch in enumerate(eval_data):
        (
            mse_loss,
            _,
            q_values,
            action_distribution,
        ) = seq2reward_trainer.validation_step(batch, idx)
        total_mse_loss += mse_loss
        total_q_values += torch.tensor(q_values)
        total_action_distribution += torch.tensor(action_distribution)

    N_eval = len(eval_data)
    eval_mse_loss = total_mse_loss / N_eval
    eval_q_values = total_q_values / N_eval
    eval_action_distribution = total_action_distribution / N_eval

    return (
        initial_state_q_values,
        eval_mse_loss,
        eval_q_values,
        eval_action_distribution,
    )
def evaluate(self, eval_batch: MemoryNetworkInput):
    prev_mode = self.compress_model_network.training
    self.compress_model_network.eval()
    mse, acc = self.trainer.get_loss(eval_batch)
    detached_loss = mse.cpu().detach().item()
    acc = acc.item()

    # shape: batch_size, action_dim
    q_values_all_action_all_data = get_Q(
        self.trainer.seq2reward_network, eval_batch, self.trainer.all_permut
    ).cpu()
    q_values = q_values_all_action_all_data.mean(0).tolist()

    action_distribution = torch.bincount(
        torch.argmax(q_values_all_action_all_data, dim=1),
        minlength=len(self.trainer.params.action_names),
    )
    # normalize
    action_distribution = (
        action_distribution.float() / torch.sum(action_distribution)
    ).tolist()

    self.compress_model_network.train(prev_mode)
    return (detached_loss, q_values, action_distribution, acc)
def evaluate(self, eval_batch: MemoryNetworkInput):
    prev_mode = self.compress_model_network.training
    self.compress_model_network.eval()
    mse, acc = self.trainer.get_loss(eval_batch)
    detached_loss = mse.cpu().detach().item()
    acc = acc.item()

    state_first_step = eval_batch.state.float_features[0]
    # shape: batch_size, action_dim
    q_values_all_action_all_data = get_Q(
        self.trainer.seq2reward_network,
        state_first_step,
        self.trainer.all_permut,
    ).cpu()
    q_values = q_values_all_action_all_data.mean(0).tolist()

    action_distribution = torch.bincount(
        torch.argmax(q_values_all_action_all_data, dim=1),
        minlength=len(self.trainer.params.action_names),
    )
    # normalize
    action_distribution = (
        action_distribution.float() / torch.sum(action_distribution)
    ).tolist()

    self.compress_model_network.train(prev_mode)

    # pyre-fixme[16]: `CompressModelEvaluator` has no attribute `notify_observers`.
    self.notify_observers(
        mse_loss=detached_loss,
        q_values=[q_values],
        action_distribution=[action_distribution],
        accuracy=acc,
    )
    return (detached_loss, q_values, action_distribution, acc)