def create_and_check_model(
    self,
    config,
    states,
    actions,
    rewards,
    returns_to_go,
    timesteps,
    attention_mask,
):
    """Run one forward pass of a freshly built model and check output shapes.

    A ``DecisionTransformerModel`` is instantiated from ``config``, moved to
    ``torch_device`` and switched to eval mode. The prediction tensors are
    then compared, shape-wise, with the inputs they correspond to.
    """
    model = DecisionTransformerModel(config=config)
    model.to(torch_device)
    model.eval()

    result = model(
        states,
        actions,
        rewards,
        returns_to_go,
        timesteps,
        attention_mask,
    )

    check_equal = self.parent.assertEqual

    # Each prediction must have the same shape as the matching input tensor.
    check_equal(result.state_preds.shape, states.shape)
    check_equal(result.action_preds.shape, actions.shape)
    check_equal(result.return_preds.shape, returns_to_go.shape)

    # The sequence length is tripled because there are 3 modalities
    # (states, returns and actions).
    check_equal(
        result.last_hidden_state.shape,
        (self.batch_size, self.seq_length * 3, self.hidden_size),
    )
def test_model_from_pretrained(self):
    """Check that the first checkpoint of the archive list can be loaded."""
    # Only the leading entry is exercised; an empty list runs no iterations.
    first_checkpoints = DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
    for checkpoint_name in first_checkpoints:
        loaded = DecisionTransformerModel.from_pretrained(checkpoint_name)
        self.assertIsNotNone(loaded)
def test_autoregressive_prediction(self):
    """
    An integration test that performs autoregressive prediction of state, action and return
    from a sequence of state, actions and returns. Test is performed over two timesteps.

    The environment is simulated: observations are drawn with ``torch.randn``
    after seeding the generator, and every step yields a reward of 1.0. The
    action predicted for the most recent timestep is compared against
    hard-coded reference values at each step.
    """

    NUM_STEPS = 2  # number of steps of autoregressive prediction we will perform
    TARGET_RETURN = 10  # defined by the RL environment, may be normalized
    model = DecisionTransformerModel.from_pretrained(
        "edbeeching/decision-transformer-gym-hopper-expert"
    )
    model = model.to(torch_device)
    config = model.config

    # Seed before any sampling so the fake observations match the references.
    torch.manual_seed(0)
    state = torch.randn(1, 1, config.state_dim).to(
        device=torch_device, dtype=torch.float32
    )  # env.reset()

    # Reference action for the latest timestep, one row per prediction step.
    expected_outputs = torch.tensor(
        [
            [0.242793, -0.28693074, 0.8742613],
            [0.67815274, -0.08101085, -0.12952147],
        ],
        device=torch_device,
    )

    # All histories are laid out as (batch=1, time, ...) and grow along dim 1.
    returns_to_go = torch.tensor(
        TARGET_RETURN, device=torch_device, dtype=torch.float32
    ).reshape(1, 1, 1)
    states = state
    actions = torch.zeros(
        1, 0, config.act_dim, device=torch_device, dtype=torch.float32
    )
    rewards = torch.zeros(1, 0, device=torch_device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=torch_device, dtype=torch.long).reshape(1, 1)

    for step in range(NUM_STEPS):
        # Append a zero placeholder for the action/reward about to be predicted.
        actions = torch.cat(
            [actions, torch.zeros(1, 1, config.act_dim, device=torch_device)],
            dim=1,
        )
        rewards = torch.cat(
            [rewards, torch.zeros(1, 1, device=torch_device)], dim=1
        )

        attention_mask = torch.ones(1, states.shape[1]).to(
            dtype=torch.long, device=states.device
        )

        with torch.no_grad():
            _, action_pred, _ = model(
                states=states,
                actions=actions,
                rewards=rewards,
                returns_to_go=returns_to_go,
                timesteps=timesteps,
                attention_mask=attention_mask,
                return_dict=False,
            )

        self.assertEqual(action_pred.shape, actions.shape)
        self.assertTrue(
            torch.allclose(action_pred[0, -1], expected_outputs[step], atol=1e-4)
        )

        state, reward, _, _ = (  # env.step(action)
            torch.randn(1, 1, config.state_dim).to(
                device=torch_device, dtype=torch.float32
            ),
            1.0,
            False,
            {},
        )

        # Fill only the newest slot of the action history. Indexing with a
        # bare ``-1`` would select the (single) batch entry and broadcast the
        # prediction over every earlier timestep as well.
        actions[0, -1] = action_pred[0, -1]
        states = torch.cat([states, state], dim=1)

        # The remaining return shrinks by the reward that was just received.
        pred_return = returns_to_go[0, -1] - reward
        returns_to_go = torch.cat(
            [returns_to_go, pred_return.reshape(1, 1, 1)], dim=1
        )
        timesteps = torch.cat(
            [
                timesteps,
                torch.ones((1, 1), device=torch_device, dtype=torch.long)
                * (step + 1),
            ],
            dim=1,
        )
0.05444621, 0.21297139, 0.14530419, 0.6124444, 0.85174465, 1.4515252, 0.6751696, 1.536239, 1.6160746, 5.6072536, ]) state_mean = torch.from_numpy(state_mean).to(device=device) state_std = torch.from_numpy(state_std).to(device=device) # Create the decision transformer model model = DecisionTransformerModel.from_pretrained( "edbeeching/decision-transformer-gym-hopper-medium") model = model.to(device) model.eval() for ep in range(10): episode_return, episode_length = 0, 0 state = env.reset() target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1) states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32) actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32) rewards = torch.zeros(0, device=device, dtype=torch.float32)