    def create_and_check_model(
        self,
        config,
        states,
        actions,
        rewards,
        returns_to_go,
        timesteps,
        attention_mask,
    ):
        model = DecisionTransformerModel(config=config)
        model.to(torch_device)
        model.eval()
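        # Forward pass; each prediction head should match the shape of its corresponding input.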
        result = model(states, actions, rewards, returns_to_go, timesteps,
                       attention_mask)

        self.parent.assertEqual(result.state_preds.shape, states.shape)
        self.parent.assertEqual(result.action_preds.shape, actions.shape)
        self.parent.assertEqual(result.return_preds.shape, returns_to_go.shape)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length * 3, self.hidden_size)
        )  # seq length * 3 because there are 3 modalities: states, returns and actions
    def test_model_from_pretrained(self):
        for model_name in DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DecisionTransformerModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
    def test_autoregressive_prediction(self):
        """
        An integration test that performs autoregressive prediction of state, action and return
        from a sequence of state, actions and returns. Test is performed over two timesteps.

        """

        NUM_STEPS = 2  # number of steps of autoregressive prediction we will perform
        TARGET_RETURN = 10  # defined by the RL environment, may be normalized
        model = DecisionTransformerModel.from_pretrained(
            "edbeeching/decision-transformer-gym-hopper-expert")
        model = model.to(torch_device)
        config = model.config
        torch.manual_seed(0)
        state = torch.randn(1, 1, config.state_dim).to(
            device=torch_device, dtype=torch.float32)  # env.reset()

        expected_outputs = torch.tensor(
            [[0.242793, -0.28693074, 0.8742613],
             [0.67815274, -0.08101085, -0.12952147]],
            device=torch_device)

        returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1)
        states = state
        actions = torch.zeros(1, 0, config.act_dim, device=torch_device, dtype=torch.float32)
        rewards = torch.zeros(1, 0, device=torch_device, dtype=torch.float32)
        timesteps = torch.tensor(0, device=torch_device, dtype=torch.long).reshape(1, 1)

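        # Autoregressive rollout: at each step, append a zero placeholder for the current action and
        # reward, run the model, check the predicted action, then append a (stubbed) environment
        # transition before moving on to the next step.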
        for step in range(NUM_STEPS):
            actions = torch.cat([actions, torch.zeros(1, 1, config.act_dim, device=torch_device)], dim=1)
            rewards = torch.cat([rewards, torch.zeros(1, 1, device=torch_device)], dim=1)

            attention_mask = torch.ones(1, states.shape[1]).to(
                dtype=torch.long, device=states.device)

            with torch.no_grad():
                _, action_pred, _ = model(
                    states=states,
                    actions=actions,
                    rewards=rewards,
                    returns_to_go=returns_to_go,
                    timesteps=timesteps,
                    attention_mask=attention_mask,
                    return_dict=False,
                )

            self.assertEqual(action_pred.shape, actions.shape)
            self.assertTrue(
                torch.allclose(action_pred[0, -1],
                               expected_outputs[step],
                               atol=1e-4))
            state, reward, _, _ = (  # env.step(action)
                torch.randn(1, 1, config.state_dim).to(device=torch_device,
                                                       dtype=torch.float32),
                1.0,
                False,
                {},
            )

            actions[:, -1] = action_pred[0, -1]  # overwrite the zero placeholder with the predicted action
            states = torch.cat([states, state], dim=1)
            pred_return = returns_to_go[0, -1] - reward
            returns_to_go = torch.cat(
                [returns_to_go, pred_return.reshape(1, 1, 1)], dim=1)
            timesteps = torch.cat(
                [timesteps, torch.ones((1, 1), device=torch_device, dtype=torch.long) * (step + 1)], dim=1
            )
    0.05444621, 0.21297139, 0.14530419, 0.6124444, 0.85174465,
    1.4515252, 0.6751696, 1.536239, 1.6160746, 5.6072536,
])
state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)

# Create the decision transformer model
model = DecisionTransformerModel.from_pretrained(
    "edbeeching/decision-transformer-gym-hopper-medium")
model = model.to(device)
model.eval()

for ep in range(10):
    episode_return, episode_length = 0, 0
    state = env.reset()
    target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
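    # The buffers above would typically be consumed by a per-timestep rollout along the lines
    # sketched below. This is a sketch, not part of the original example: the episode-length cap,
    # using model.config.max_length as the context window, normalizing states with
    # state_mean/state_std, and skipping any reward rescaling are assumptions.
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
    for t in range(1000):  # assumed episode-length cap
        # zero placeholders for the current step's action and reward
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        K = model.config.max_length  # assumed context window of the checkpoint
        with torch.no_grad():
            _, action_preds, _ = model(
                states=(((states - state_mean) / state_std).reshape(1, -1, state_dim)[:, -K:]).float(),
                actions=actions.reshape(1, -1, act_dim)[:, -K:].float(),
                rewards=rewards.reshape(1, -1)[:, -K:].float(),
                returns_to_go=target_return.reshape(1, -1, 1)[:, -K:].float(),
                timesteps=timesteps[:, -K:],
                attention_mask=torch.ones((1, min(states.shape[0], K)), device=device, dtype=torch.long),
                return_dict=False,
            )
        action = action_preds[0, -1]
        actions[-1] = action

        state, reward, done, _ = env.step(action.detach().cpu().numpy())
        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward
        target_return = torch.cat([target_return, (target_return[0, -1] - reward).reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1
        if done:
            break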