def evaluate(self, tdp: TrainingBatch): """ Calculate state feature sensitivity due to actions: randomly permutating actions and see how much the prediction of next state feature deviates. """ self.trainer.mdnrnn.mdnrnn.eval() batch_size, seq_len, state_dim = tdp.training_input.next_state.size() state_feature_num = self.feature_extractor.state_feature_num feature_sensitivity = torch.zeros(state_feature_num) mdnrnn_input = tdp.training_input mdnrnn_output = self.trainer.mdnrnn(mdnrnn_input) predicted_next_state_means = mdnrnn_output.mus shuffled_mdnrnn_input = MemoryNetworkInput( state=tdp.training_input.state, # shuffle the actions action=FeatureVector(float_features=tdp.training_input.action. float_features[torch.randperm(batch_size)]), next_state=tdp.training_input.next_state, reward=tdp.training_input.reward, not_terminal=tdp.training_input.not_terminal, ) shuffled_mdnrnn_output = self.trainer.mdnrnn(shuffled_mdnrnn_input) shuffled_predicted_next_state_means = shuffled_mdnrnn_output.mus assert (predicted_next_state_means.size() == shuffled_predicted_next_state_means.size() == (batch_size, seq_len, self.trainer.params.num_gaussians, state_dim)) state_feature_boundaries = ( self.feature_extractor.sorted_state_feature_start_indices + [state_dim]) for i in range(state_feature_num): boundary_start, boundary_end = ( state_feature_boundaries[i], state_feature_boundaries[i + 1], ) abs_diff = torch.mean( torch.sum( torch.abs( shuffled_predicted_next_state_means[:, :, :, boundary_start: boundary_end] - predicted_next_state_means[:, :, :, boundary_start:boundary_end] ), dim=3, )) feature_sensitivity[i] = abs_diff.cpu().detach().item() self.trainer.mdnrnn.mdnrnn.train() logger.info("**** Debug tool feature sensitivity ****: {}".format( feature_sensitivity)) return {"feature_sensitivity": feature_sensitivity.numpy()}
def test_forward_pass(self):
    state_dim = 1
    action_dim = 2
    input = StateInput(state=FeatureVector(float_features=torch.tensor([[2.0]])))
    bcq_drop_threshold = 0.20

    q_network = FullyConnectedDQN(
        state_dim, action_dim, sizes=[2], activations=["relu"]
    )
    # Set weights of q-network to make it deterministic
    q_net_layer_0_w = torch.tensor([[1.2], [0.9]])
    q_network.state_dict()["fc.layers.0.weight"].data.copy_(q_net_layer_0_w)
    q_net_layer_0_b = torch.tensor([0.0, 0.0])
    q_network.state_dict()["fc.layers.0.bias"].data.copy_(q_net_layer_0_b)
    q_net_layer_1_w = torch.tensor([[0.5, -0.5], [1.0, 1.0]])
    q_network.state_dict()["fc.layers.1.weight"].data.copy_(q_net_layer_1_w)
    q_net_layer_1_b = torch.tensor([0.0, 0.0])
    q_network.state_dict()["fc.layers.1.bias"].data.copy_(q_net_layer_1_b)

    imitator_network = FullyConnectedNetwork(
        layers=[state_dim, 2, action_dim], activations=["relu", "linear"]
    )
    # Set weights of imitator network to make it deterministic
    im_net_layer_0_w = torch.tensor([[1.2], [0.9]])
    imitator_network.state_dict()["layers.0.weight"].data.copy_(im_net_layer_0_w)
    im_net_layer_0_b = torch.tensor([0.0, 0.0])
    imitator_network.state_dict()["layers.0.bias"].data.copy_(im_net_layer_0_b)
    im_net_layer_1_w = torch.tensor([[0.5, 1.5], [1.0, 2.0]])
    imitator_network.state_dict()["layers.1.weight"].data.copy_(im_net_layer_1_w)
    im_net_layer_1_b = torch.tensor([0.0, 0.0])
    imitator_network.state_dict()["layers.1.bias"].data.copy_(im_net_layer_1_b)

    imitator_probs = torch.nn.functional.softmax(
        imitator_network(input.state.float_features), dim=1
    )
    # Only the first action falls below the drop threshold, so only it is masked
    bcq_mask = imitator_probs < bcq_drop_threshold
    assert bcq_mask[0][0] == 1
    assert bcq_mask[0][1] == 0

    model = BatchConstrainedDQN(
        state_dim=state_dim,
        q_network=q_network,
        imitator_network=imitator_network,
        bcq_drop_threshold=bcq_drop_threshold,
    )
    final_q_values = model(input)
    # The masked action gets -1e10; the kept action keeps its Q-network value:
    # relu([2.4, 1.8]) . [1.0, 1.0] = 4.2
    assert final_q_values.q_values[0][0] == -1e10
    assert abs(final_q_values.q_values[0][1] - 4.2) < 0.0001
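# A minimal sketch of the BCQ masking rule the test above exercises. The helper
# name `apply_bcq_mask` and the standalone tensors are illustrative, not part
# of the library: Q-values of actions whose imitator probability falls below
# the drop threshold are replaced by a large negative constant.
import torch

def apply_bcq_mask(q_values, imitator_probs, drop_threshold, mask_value=-1e10):
    # True where the imitator assigns too little probability to the action
    drop_mask = imitator_probs < drop_threshold
    return q_values.masked_fill(drop_mask, mask_value)

q_values = torch.tensor([[0.3, 4.2]])
imitator_probs = torch.tensor([[0.11, 0.89]])
print(apply_bcq_mask(q_values, imitator_probs, 0.20))
# tensor([[-1.0000e+10,  4.2000e+00]])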
def preprocess(self, batch) -> TrainingBatch:
    state_features_dense, state_features_dense_presence = self.sparse_to_dense_processor(
        batch["state_features"]
    )
    next_state_features_dense, next_state_features_dense_presence = self.sparse_to_dense_processor(
        batch["next_state_features"]
    )

    mdp_ids = np.array(batch["mdp_id"]).reshape(-1, 1)
    sequence_numbers = torch.tensor(
        batch["sequence_number"], dtype=torch.int32
    ).reshape(-1, 1)
    rewards = torch.tensor(batch["reward"], dtype=torch.float32).reshape(-1, 1)
    time_diffs = torch.tensor(batch["time_diff"], dtype=torch.int32).reshape(-1, 1)
    if "action_probability" in batch:
        propensities = torch.tensor(
            batch["action_probability"], dtype=torch.float32
        ).reshape(-1, 1)
    else:
        propensities = torch.ones(rewards.shape, dtype=torch.float32)

    return TrainingBatch(
        training_input=BaseInput(
            state=FeatureVector(
                float_features=ValuePresence(
                    value=state_features_dense,
                    presence=state_features_dense_presence,
                )
            ),
            next_state=FeatureVector(
                float_features=ValuePresence(
                    value=next_state_features_dense,
                    presence=next_state_features_dense_presence,
                )
            ),
            reward=rewards,
            time_diff=time_diffs,
        ),
        extras=ExtraData(
            mdp_id=mdp_ids,
            sequence_number=sequence_numbers,
            action_probability=propensities,
        ),
    )
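# A minimal sketch of what a sparse-to-dense step like the one used above
# produces. The `sparse_to_dense` function below is illustrative, not the
# library's processor: each row is a {feature_id: value} dict, and the output
# is a dense value tensor plus a presence mask over a fixed feature order.
import torch

def sparse_to_dense(rows, sorted_feature_ids):
    value = torch.zeros(len(rows), len(sorted_feature_ids))
    presence = torch.zeros(len(rows), len(sorted_feature_ids), dtype=torch.bool)
    for r, row in enumerate(rows):
        for c, fid in enumerate(sorted_feature_ids):
            if fid in row:
                value[r, c] = row[fid]
                presence[r, c] = True
    return value, presence

value, presence = sparse_to_dense([{10: 0.5}, {10: 1.0, 42: -2.0}], [10, 42])
print(value)     # tensor([[ 0.5,  0.0], [ 1.0, -2.0]])
print(presence)  # tensor([[ True, False], [ True,  True]])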
def input_prototype(self):
    return StateAction(
        state=FeatureVector(float_features=torch.randn([1, 4])),
        action=FeatureVector(float_features=torch.randn([1, 4])),
    )
def evaluate(self, tdp: TrainingBatch): """ Calculate feature importance: setting each state/action feature to the mean value and observe loss increase. """ self.trainer.mdnrnn.mdnrnn.eval() state_features = tdp.training_input.state.float_features action_features = tdp.training_input.action.float_features # type: ignore batch_size, seq_len, state_dim = state_features.size() # type: ignore action_dim = action_features.size()[2] # type: ignore action_feature_num = self.action_feature_num state_feature_num = self.state_feature_num feature_importance = torch.zeros(action_feature_num + state_feature_num) orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True) orig_loss = orig_losses["loss"].cpu().detach().item() del orig_losses action_feature_boundaries = self.sorted_action_feature_start_indices + [ action_dim ] state_feature_boundaries = self.sorted_state_feature_start_indices + [state_dim] for i in range(action_feature_num): action_features = tdp.training_input.action.float_features.reshape( # type: ignore (batch_size * seq_len, action_dim) ).data.clone() # if actions are discrete, an action's feature importance is the loss # increase due to setting all actions to this action if self.discrete_action: assert action_dim == action_feature_num action_vec = torch.zeros(action_dim) action_vec[i] = 1 action_features[:] = action_vec # type: ignore # if actions are continuous, an action's feature importance is the loss # increase due to masking this action feature to its mean value else: boundary_start, boundary_end = ( action_feature_boundaries[i], action_feature_boundaries[i + 1], ) action_features[ # type: ignore :, boundary_start:boundary_end ] = self.compute_median_feature_value( # type: ignore action_features[:, boundary_start:boundary_end] # type: ignore ) action_features = action_features.reshape( # type: ignore (batch_size, seq_len, action_dim) ) # type: ignore new_tdp = TrainingBatch( training_input=MemoryNetworkInput( # type: ignore state=tdp.training_input.state, action=FeatureVector( # type: ignore float_features=action_features ), # type: ignore next_state=tdp.training_input.next_state, reward=tdp.training_input.reward, not_terminal=tdp.training_input.not_terminal, ), extras=ExtraData(), ) losses = self.trainer.get_loss( new_tdp, state_dim=state_dim, batch_first=True ) feature_importance[i] = losses["loss"].cpu().detach().item() - orig_loss del losses for i in range(state_feature_num): state_features = tdp.training_input.state.float_features.reshape( # type: ignore (batch_size * seq_len, state_dim) ).data.clone() boundary_start, boundary_end = ( state_feature_boundaries[i], state_feature_boundaries[i + 1], ) state_features[ # type: ignore :, boundary_start:boundary_end ] = self.compute_median_feature_value( state_features[:, boundary_start:boundary_end] # type: ignore ) state_features = state_features.reshape( # type: ignore (batch_size, seq_len, state_dim) ) # type: ignore new_tdp = TrainingBatch( training_input=MemoryNetworkInput( # type: ignore state=FeatureVector(float_features=state_features), # type: ignore action=tdp.training_input.action, next_state=tdp.training_input.next_state, reward=tdp.training_input.reward, not_terminal=tdp.training_input.not_terminal, ), extras=ExtraData(), ) losses = self.trainer.get_loss( new_tdp, state_dim=state_dim, batch_first=True ) feature_importance[i + action_feature_num] = ( losses["loss"].cpu().detach().item() - orig_loss ) del losses self.trainer.mdnrnn.mdnrnn.train() logger.info( "**** Debug tool feature importance 
****: {}".format(feature_importance) ) return {"feature_loss_increase": feature_importance.numpy()}
def evaluate(self, tdp: TrainingBatch): """ Calculate feature importance: setting each state/action feature to the mean value and observe loss increase. """ self.trainer.mdnrnn.mdnrnn.eval() state_features = tdp.training_input.state.float_features action_features = tdp.training_input.action.float_features batch_size, seq_len, state_dim = state_features.size() action_dim = action_features.size()[2] action_feature_num = self.feature_extractor.action_feature_num state_feature_num = self.feature_extractor.state_feature_num feature_importance = torch.zeros(action_feature_num + state_feature_num) orig_losses = self.trainer.get_loss(tdp, state_dim=state_dim, batch_first=True) orig_loss = orig_losses["loss"].cpu().detach().item() del orig_losses action_feature_boundaries = ( self.feature_extractor.sorted_action_feature_start_indices + [action_dim]) state_feature_boundaries = ( self.feature_extractor.sorted_state_feature_start_indices + [state_dim]) for i in range(action_feature_num): action_features = tdp.training_input.action.float_features.reshape( (batch_size * seq_len, action_dim)).data.clone() boundary_start, boundary_end = ( action_feature_boundaries[i], action_feature_boundaries[i + 1], ) action_features[:, boundary_start: boundary_end] = action_features[:, boundary_start: boundary_end].mean( dim=0) action_features = action_features.reshape( (batch_size, seq_len, action_dim)) new_tdp = TrainingBatch( training_input=MemoryNetworkInput( state=tdp.training_input.state, action=FeatureVector(float_features=action_features), next_state=tdp.training_input.next_state, reward=tdp.training_input.reward, not_terminal=tdp.training_input.not_terminal, ), extras=ExtraData(), ) losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True) feature_importance[i] = losses["loss"].cpu().detach().item( ) - orig_loss del losses for i in range(state_feature_num): state_features = tdp.training_input.state.float_features.reshape( (batch_size * seq_len, state_dim)).data.clone() boundary_start, boundary_end = ( state_feature_boundaries[i], state_feature_boundaries[i + 1], ) state_features[:, boundary_start: boundary_end] = state_features[:, boundary_start: boundary_end].mean( dim=0) state_features = state_features.reshape( (batch_size, seq_len, state_dim)) new_tdp = TrainingBatch( training_input=MemoryNetworkInput( state=FeatureVector(float_features=state_features), action=tdp.training_input.action, next_state=tdp.training_input.next_state, reward=tdp.training_input.reward, not_terminal=tdp.training_input.not_terminal, ), extras=ExtraData(), ) losses = self.trainer.get_loss(new_tdp, state_dim=state_dim, batch_first=True) feature_importance[i + action_feature_num] = ( losses["loss"].cpu().detach().item() - orig_loss) del losses self.trainer.mdnrnn.mdnrnn.train() logger.info("**** Debug tool feature importance ****: {}".format( feature_importance)) return {"feature_loss_increase": feature_importance.numpy()}
def preprocess(self, batch) -> TrainingBatch:
    training_batch = super().preprocess(batch)

    actions, actions_presence = self.action_sparse_to_dense(batch["action"])
    next_actions, next_actions_presence = self.action_sparse_to_dense(
        batch["next_action"]
    )

    max_action_size = max(len(pna) for pna in batch["possible_next_actions"])

    # pad each row's possible next actions to max_action_size and build a 0/1 mask
    pnas_mask = torch.Tensor(
        [
            ([1] * len(l) + [0] * (max_action_size - len(l)))
            for l in batch["possible_next_actions"]
        ]
    )
    flat_pnas: List[Dict[int, float]] = []
    for pa in batch["possible_next_actions"]:
        flat_pnas.extend(pa)
        for _ in range(max_action_size - len(pa)):
            flat_pnas.append({})

    # a transition is terminal when it has no possible next actions
    not_terminal = torch.from_numpy(
        np.array([len(pna) > 0 for pna in batch["possible_next_actions"]]).astype(
            np.float32
        )
    ).reshape(-1, 1)
    pnas, pnas_presence = self.action_sparse_to_dense(flat_pnas)

    base_input = cast(BaseInput, training_batch.training_input)
    # repeat each next state once per possible-action slot so that the tiled
    # states line up row by row with the flattened possible next actions
    tiled_next_state = torch.repeat_interleave(
        base_input.next_state.float_features.value, max_action_size, dim=0
    )
    tiled_next_state_presence = torch.repeat_interleave(
        base_input.next_state.float_features.presence, max_action_size, dim=0
    )

    # same padding scheme for the possible actions of the current state
    pas_mask = torch.Tensor(
        [
            ([1] * len(l) + [0] * (max_action_size - len(l)))
            for l in batch["possible_actions"]
        ]
    )
    flat_pas: List[Dict[int, float]] = []
    for pa in batch["possible_actions"]:
        flat_pas.extend(pa)
        for _ in range(max_action_size - len(pa)):
            flat_pas.append({})
    pas, pas_presence = self.action_sparse_to_dense(flat_pas)

    return training_batch._replace(
        training_input=ParametricDqnInput(
            state=base_input.state,
            reward=base_input.reward,
            time_diff=base_input.time_diff,
            action=FeatureVector(
                float_features=ValuePresence(value=actions, presence=actions_presence)
            ),
            next_action=FeatureVector(
                float_features=ValuePresence(
                    value=next_actions, presence=next_actions_presence
                )
            ),
            not_terminal=not_terminal,
            next_state=base_input.next_state,
            tiled_next_state=FeatureVector(
                float_features=ValuePresence(
                    value=tiled_next_state, presence=tiled_next_state_presence
                )
            ),
            possible_actions=FeatureVector(
                float_features=ValuePresence(value=pas, presence=pas_presence)
            ),
            possible_actions_mask=pas_mask,
            possible_next_actions=FeatureVector(
                float_features=ValuePresence(value=pnas, presence=pnas_presence)
            ),
            possible_next_actions_mask=pnas_mask,
        )
    )
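# A minimal sketch of the padding/tiling scheme used above for possible next
# actions. The variable names and toy data are illustrative assumptions: each
# row's action set is padded to the batch-wide maximum with empty actions and
# a 0/1 mask, and the next state is repeated once per possible-action slot so
# that state/action pairs line up row by row.
import torch

possible_next_actions = [[{0: 1.0}, {1: 1.0}], [{0: 1.0}]]  # rows of sparse actions
max_action_size = max(len(a) for a in possible_next_actions)

pnas_mask = torch.tensor(
    [[1] * len(a) + [0] * (max_action_size - len(a)) for a in possible_next_actions],
    dtype=torch.float32,
)
flat_pnas = []
for a in possible_next_actions:
    flat_pnas.extend(a)
    flat_pnas.extend({} for _ in range(max_action_size - len(a)))

next_state = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
tiled_next_state = torch.repeat_interleave(next_state, max_action_size, dim=0)
print(pnas_mask)         # tensor([[1., 1.], [1., 0.]])
print(flat_pnas)         # [{0: 1.0}, {1: 1.0}, {0: 1.0}, {}]
print(tiled_next_state)  # each next state repeated max_action_size times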