예제 #1
0
파일: linpos.py 프로젝트: veds12/genrl
    def __init__(self, bandit: DataBasedBandit, **kwargs):
        """Initialise priors, per-action posterior parameters and the database.

        Args:
            bandit (DataBasedBandit): The bandit to solve.
            **kwargs: Optional settings read here, with their defaults:
                ``device`` ("cpu"), ``init_pulls`` (3), ``lambda_prior`` (0.25),
                ``a0`` (6.0) and ``b0`` (6.0).
        """
        device = kwargs.get("device", "cpu")
        super(LinearPosteriorAgent, self).__init__(bandit, device)

        self.init_pulls = kwargs.get("init_pulls", 3)
        self.lambda_prior = kwargs.get("lambda_prior", 0.25)
        self.a0 = kwargs.get("a0", 6.0)
        self.b0 = kwargs.get("b0", 6.0)

        # self.context_dim / self.n_actions / self.device are read but not set
        # here -- presumably provided by the base-class initialiser.
        # Every per-action parameter vector has one slot more than the context.
        dim = self.context_dim + 1
        opts = {"device": self.device, "dtype": torch.float}

        def scaled_identities(scale):
            # One independent (dim x dim) scaled identity matrix per action.
            return torch.stack(
                [scale * torch.eye(dim, **opts) for _ in range(self.n_actions)]
            )

        self.mu = torch.zeros(size=(self.n_actions, dim), **opts)
        self.cov = scaled_identities(1.0 / self.lambda_prior)
        self.inv_cov = scaled_identities(self.lambda_prior)
        self.a = self.a0 * torch.ones(self.n_actions, **opts)
        self.b = self.b0 * torch.ones(self.n_actions, **opts)

        self.db = TransitionDB(self.device)
        self.t = 0
        self.update_count = 0
예제 #2
0
 def __init__(self, bandit: DataBasedBandit, **kwargs):
     """Initialise the network, the priors and the per-action posterior state.

     Args:
         bandit (DataBasedBandit): The bandit to solve.
         **kwargs: Optional settings read here, with their defaults:
             ``device`` ("cpu"), ``init_pulls`` (3), ``lambda_prior`` (0.25),
             ``a0`` (6.0), ``b0`` (6.0), ``hidden_dims`` ([50, 50]),
             ``nn_update_ratio`` (2), ``init_lr`` (0.1),
             ``max_grad_norm`` (0.5), ``lr_decay`` (0.5), ``lr_reset`` (True),
             ``dropout_p`` (None) and ``eval_with_dropout`` (False).
     """
     device = kwargs.get("device", "cpu")
     super(NeuralLinearPosteriorAgent, self).__init__(bandit, device)

     self.init_pulls = kwargs.get("init_pulls", 3)
     self.lambda_prior = kwargs.get("lambda_prior", 0.25)
     self.a0 = kwargs.get("a0", 6.0)
     self.b0 = kwargs.get("b0", 6.0)
     self.nn_update_ratio = kwargs.get("nn_update_ratio", 2)
     self.eval_with_dropout = kwargs.get("eval_with_dropout", False)

     hidden_dims = kwargs.get("hidden_dims", [50, 50])
     # The width of the last hidden layer is used as the latent dimension.
     self.latent_dim = hidden_dims[-1]

     network = NeuralBanditModel(
         context_dim=self.context_dim,
         hidden_dims=hidden_dims,
         n_actions=self.n_actions,
         init_lr=kwargs.get("init_lr", 0.1),
         max_grad_norm=kwargs.get("max_grad_norm", 0.5),
         lr_decay=kwargs.get("lr_decay", 0.5),
         lr_reset=kwargs.get("lr_reset", True),
         dropout_p=kwargs.get("dropout_p", None),
     )
     # Cast to float32 first, then move to the agent's device.
     self.model = network.to(torch.float).to(self.device)

     # Per-action parameters have one slot more than the latent vector.
     dim = self.latent_dim + 1
     opts = {"device": self.device, "dtype": torch.float}

     def scaled_identities(scale):
         # One independent (dim x dim) scaled identity matrix per action.
         return torch.stack(
             [scale * torch.eye(dim, **opts) for _ in range(self.n_actions)]
         )

     self.mu = torch.zeros(size=(self.n_actions, dim), **opts)
     self.cov = scaled_identities(1.0 / self.lambda_prior)
     self.inv_cov = scaled_identities(self.lambda_prior)
     self.a = self.a0 * torch.ones(self.n_actions, **opts)
     self.b = self.b0 * torch.ones(self.n_actions, **opts)

     # Raw transitions live on the agent's device; the latent database is
     # created with TransitionDB's own default arguments.
     self.db = TransitionDB(self.device)
     self.latent_db = TransitionDB()
     self.t = 0
     self.update_count = 0
예제 #3
0
 def __init__(self, bandit: DataBasedBandit, **kwargs):
     """Initialise the neural bandit model and the agent's bookkeeping state.

     Args:
         bandit (DataBasedBandit): The bandit to solve.
         **kwargs: Optional settings read here, with their defaults:
             ``device`` ("cpu"), ``init_pulls`` (3), ``hidden_dims`` ([50, 50]),
             ``init_lr`` (0.1), ``max_grad_norm`` (0.5), ``lr_decay`` (0.5),
             ``lr_reset`` (True), ``dropout_p`` (None),
             ``eval_with_dropout`` (False) and ``epsilon`` (0.0).
     """
     device = kwargs.get("device", "cpu")
     super(NeuralGreedyAgent, self).__init__(bandit, device)

     self.init_pulls = kwargs.get("init_pulls", 3)

     network = NeuralBanditModel(
         context_dim=self.context_dim,
         hidden_dims=kwargs.get("hidden_dims", [50, 50]),
         n_actions=self.n_actions,
         init_lr=kwargs.get("init_lr", 0.1),
         max_grad_norm=kwargs.get("max_grad_norm", 0.5),
         lr_decay=kwargs.get("lr_decay", 0.5),
         lr_reset=kwargs.get("lr_reset", True),
         dropout_p=kwargs.get("dropout_p", None),
     )
     # Cast to float32 first, then move to the agent's device.
     self.model = network.to(torch.float).to(self.device)

     self.eval_with_dropout = kwargs.get("eval_with_dropout", False)
     self.epsilon = kwargs.get("epsilon", 0.0)
     self.db = TransitionDB(self.device)

     # Step counter and number of parameter updates performed so far.
     self.t = 0
     self.update_count = 0
예제 #4
0
 def __init__(self, bandit: DataBasedBandit, **kwargs):
     """Initialise ``n`` neural bandit models, each with its own database.

     Args:
         bandit (DataBasedBandit): The bandit to solve.
         **kwargs: Optional settings read here, with their defaults:
             ``device`` ("cpu"), ``init_pulls`` (3), ``n`` (10),
             ``add_prob`` (0.95), ``hidden_dims`` ([50, 50]), ``init_lr`` (0.1),
             ``max_grad_norm`` (0.5), ``lr_decay`` (0.5), ``lr_reset`` (True),
             ``dropout_p`` (None) and ``eval_with_dropout`` (False).
     """
     device = kwargs.get("device", "cpu")
     super(BootstrapNeuralAgent, self).__init__(bandit, device)

     self.init_pulls = kwargs.get("init_pulls", 3)
     self.n = kwargs.get("n", 10)
     self.add_prob = kwargs.get("add_prob", 0.95)

     def build_model():
         # A fresh float32 model on the agent's device.
         network = NeuralBanditModel(
             context_dim=self.context_dim,
             hidden_dims=kwargs.get("hidden_dims", [50, 50]),
             n_actions=self.n_actions,
             init_lr=kwargs.get("init_lr", 0.1),
             max_grad_norm=kwargs.get("max_grad_norm", 0.5),
             lr_decay=kwargs.get("lr_decay", 0.5),
             lr_reset=kwargs.get("lr_reset", True),
             dropout_p=kwargs.get("dropout_p", None),
         )
         return network.to(torch.float).to(self.device)

     self.models = [build_model() for _ in range(self.n)]
     self.eval_with_dropout = kwargs.get("eval_with_dropout", False)

     # One transition database per model.
     self.dbs = [TransitionDB(self.device) for _ in range(self.n)]

     # Step counter and number of parameter updates performed so far.
     self.t = 0
     self.update_count = 0
예제 #5
0
class NeuralLinearPosteriorAgent(DCBAgent):
    """Deep contextual bandit agent using Bayesian regression for posterior inference.

    A neural network is used to transform the context vector into a latent
    representation on which Bayesian linear regression is performed.

    Args:
        bandit (DataBasedBandit): The bandit to solve
        init_pulls (int, optional): Number of times to select each action initially.
            Defaults to 3.
        hidden_dims (List[int], optional): Dimensions of hidden layers of network.
            The last entry is used as the latent dimension. Defaults to [50, 50].
        init_lr (float, optional): Initial learning rate. Defaults to 0.1.
        lr_decay (float, optional): Decay rate for learning rate. Defaults to 0.5.
        lr_reset (bool, optional): Whether to reset learning rate every train interval.
            Defaults to True.
        max_grad_norm (float, optional): Maximum norm of gradients for gradient clipping.
            Defaults to 0.5.
        dropout_p (Optional[float], optional): Probability for dropout. Defaults to None
            which implies dropout is not to be used.
        eval_with_dropout (bool, optional): Whether or not to use dropout at inference.
            Defaults to False.
        nn_update_ratio (int, optional): The network is trained on every
            ``nn_update_ratio``-th call to ``update_params``. Defaults to 2.
        lambda_prior (float, optional): Gaussian prior for linear model. Defaults to 0.25.
        a0 (float, optional): Inverse gamma prior for noise. Defaults to 6.0.
        b0 (float, optional): Inverse gamma prior for noise. Defaults to 6.0.
        device (str): Device to use for tensor operations.
            "cpu" for cpu or "cuda" for cuda. Defaults to "cpu".
    """

    def __init__(self, bandit: DataBasedBandit, **kwargs):
        super(NeuralLinearPosteriorAgent, self).__init__(
            bandit, kwargs.get("device", "cpu")
        )
        self.init_pulls = kwargs.get("init_pulls", 3)
        self.lambda_prior = kwargs.get("lambda_prior", 0.25)
        self.a0 = kwargs.get("a0", 6.0)
        self.b0 = kwargs.get("b0", 6.0)
        hidden_dims = kwargs.get("hidden_dims", [50, 50])
        # The regression runs on the output of the last hidden layer.
        self.latent_dim = hidden_dims[-1]
        self.nn_update_ratio = kwargs.get("nn_update_ratio", 2)
        self.model = (
            NeuralBanditModel(
                context_dim=self.context_dim,
                hidden_dims=hidden_dims,
                n_actions=self.n_actions,
                init_lr=kwargs.get("init_lr", 0.1),
                max_grad_norm=kwargs.get("max_grad_norm", 0.5),
                lr_decay=kwargs.get("lr_decay", 0.5),
                lr_reset=kwargs.get("lr_reset", True),
                dropout_p=kwargs.get("dropout_p", None),
            )
            .to(torch.float)
            .to(self.device)
        )
        self.eval_with_dropout = kwargs.get("eval_with_dropout", False)
        # Per-action posterior parameters; the "+ 1" slot matches the constant 1
        # appended to every latent vector in select_action / update_params.
        self.mu = torch.zeros(
            size=(self.n_actions, self.latent_dim + 1),
            device=self.device,
            dtype=torch.float,
        )
        self.cov = torch.stack(
            [
                (1.0 / self.lambda_prior)
                * torch.eye(self.latent_dim + 1, device=self.device, dtype=torch.float)
                for _ in range(self.n_actions)
            ]
        )
        self.inv_cov = torch.stack(
            [
                self.lambda_prior
                * torch.eye(self.latent_dim + 1, device=self.device, dtype=torch.float)
                for _ in range(self.n_actions)
            ]
        )
        self.a = self.a0 * torch.ones(
            self.n_actions, device=self.device, dtype=torch.float
        )
        self.b = self.b0 * torch.ones(
            self.n_actions, device=self.device, dtype=torch.float
        )
        self.db = TransitionDB(self.device)
        self.latent_db = TransitionDB()
        self.t = 0
        self.update_count = 0

    def select_action(self, context: torch.Tensor) -> int:
        """Select an action based on given context.

        While ``t < n_actions * init_pulls`` the action is chosen round-robin as
        ``t % n_actions``. Afterwards a noise variance and a weight vector are
        sampled for every action from the current posterior, the context is
        passed through the network, and the action with the highest sampled
        value on the latent representation is returned.

        Args:
            context (torch.Tensor): The context vector to select action for.

        Returns:
            int: The action to take (returned as a 0-dim integer tensor).
        """
        self.model.use_dropout = self.eval_with_dropout
        self.t += 1
        if self.t < self.n_actions * self.init_pulls:
            return torch.tensor(
                self.t % self.n_actions, device=self.device, dtype=torch.int
            )
        # scipy/numpy cannot consume non-CPU tensors, so hand them plain floats.
        var = torch.tensor(
            [
                float(self.b[i]) * invgamma.rvs(float(self.a[i]))
                for i in range(self.n_actions)
            ],
            device=self.device,
            dtype=torch.float,
        )
        try:
            beta = (
                torch.tensor(
                    np.stack(
                        [
                            np.random.multivariate_normal(
                                self.mu[i].cpu().numpy(),
                                (var[i] * self.cov[i]).cpu().numpy(),
                            )
                            for i in range(self.n_actions)
                        ]
                    )
                )
                .to(self.device)
                .to(torch.float)
            )
        except np.linalg.LinAlgError:
            # Fall back to a standard normal sample. It must have the same width
            # as the posterior mean (latent_dim + 1), not context_dim + 1,
            # otherwise the product with the latent vector below fails.
            dim = self.latent_dim + 1
            beta = (
                torch.stack(
                    [
                        torch.distributions.MultivariateNormal(
                            torch.zeros(dim), torch.eye(dim)
                        ).sample()
                        for _ in range(self.n_actions)
                    ]
                )
                .to(self.device)
                .to(torch.float)
            )
        results = self.model(context)
        latent_context = results["x"]
        # Append the constant 1 on the same device as the latent vector.
        bias = torch.ones(1, device=latent_context.device)
        values = torch.mv(beta, torch.cat([latent_context.squeeze(0), bias]))
        action = torch.argmax(values).to(torch.int)
        return action

    def update_db(self, context: torch.Tensor, action: int, reward: int):
        """Updates transition database with given transition

        The raw context goes into ``db``; the network's current latent output
        for that context (detached) goes into ``latent_db``.

        Args:
            context (torch.Tensor): Context received
            action (int): Action taken
            reward (int): Reward received
        """
        self.db.add(context, action, reward)
        results = self.model(context)
        # NOTE(review): latent vectors are stored as computed at insertion time
        # and are not refreshed after the network is retrained -- confirm that
        # regressing on stale features is intended.
        self.latent_db.add(results["x"].detach(), action, reward)

    def update_params(self, action: int, batch_size: int = 512, train_epochs: int = 20):
        """Update parameters of the agent.

        Trains the neural network on every ``nn_update_ratio``-th call and
        recomputes the Bayesian regression parameters of ``action`` from the
        stored latent data.

        Args:
            action (int): Action to update the parameters for.
            batch_size (int, optional): Size of batch to update parameters with.
                Defaults to 512
            train_epochs (int, optional): Epochs to train neural network for.
                Defaults to 20
        """
        self.update_count += 1

        if self.update_count % self.nn_update_ratio == 0:
            self.model.train_model(self.db, train_epochs, batch_size)

        z, y = self.latent_db.get_data_for_action(action, batch_size)
        # Build helper tensors on the data's device so this also works off-CPU.
        z = torch.cat([z, torch.ones(z.shape[0], 1, device=z.device)], dim=1)
        inv_cov = torch.mm(z.T, z) + self.lambda_prior * torch.eye(
            self.latent_dim + 1, device=z.device
        )
        cov = torch.inverse(inv_cov)
        mu = torch.mm(cov, torch.mm(z.T, y))
        # NOTE(review): uses the global step count self.t, not the number of
        # samples seen for this action -- confirm this is intended.
        a = self.a0 + self.t / 2
        b = self.b0 + (torch.mm(y.T, y) - torch.mm(mu.T, torch.mm(inv_cov, mu))) / 2
        self.mu[action] = mu.squeeze(1).to(self.device)
        self.cov[action] = cov.to(self.device)
        self.inv_cov[action] = inv_cov.to(self.device)
        self.a[action] = a
        self.b[action] = b.to(self.device)
예제 #6
0
파일: linpos.py 프로젝트: veds12/genrl
class LinearPosteriorAgent(DCBAgent):
    """Contextual bandit agent using Bayesian linear regression for posterior inference.

    Args:
        bandit (DataBasedBandit): The bandit to solve
        init_pulls (int, optional): Number of times to select each action initially.
            Defaults to 3.
        lambda_prior (float, optional): Gaussian prior for linear model. Defaults to 0.25.
        a0 (float, optional): Inverse gamma prior for noise. Defaults to 6.0.
        b0 (float, optional): Inverse gamma prior for noise. Defaults to 6.0.
        device (str): Device to use for tensor operations.
            "cpu" for cpu or "cuda" for cuda. Defaults to "cpu".
    """

    def __init__(self, bandit: DataBasedBandit, **kwargs):
        super(LinearPosteriorAgent, self).__init__(bandit, kwargs.get("device", "cpu"))

        self.init_pulls = kwargs.get("init_pulls", 3)
        self.lambda_prior = kwargs.get("lambda_prior", 0.25)
        self.a0 = kwargs.get("a0", 6.0)
        self.b0 = kwargs.get("b0", 6.0)
        # Per-action posterior parameters; the "+ 1" slot matches the constant 1
        # appended to every context in select_action / update_params.
        self.mu = torch.zeros(
            size=(self.n_actions, self.context_dim + 1),
            device=self.device,
            dtype=torch.float,
        )
        self.cov = torch.stack(
            [
                (1.0 / self.lambda_prior)
                * torch.eye(self.context_dim + 1, device=self.device, dtype=torch.float)
                for _ in range(self.n_actions)
            ]
        )
        self.inv_cov = torch.stack(
            [
                self.lambda_prior
                * torch.eye(self.context_dim + 1, device=self.device, dtype=torch.float)
                for _ in range(self.n_actions)
            ]
        )
        self.a = self.a0 * torch.ones(
            self.n_actions, device=self.device, dtype=torch.float
        )
        self.b = self.b0 * torch.ones(
            self.n_actions, device=self.device, dtype=torch.float
        )
        self.db = TransitionDB(self.device)
        self.t = 0
        self.update_count = 0

    def select_action(self, context: torch.Tensor) -> int:
        """Select an action based on given context.

        While ``t < n_actions * init_pulls`` the action is chosen round-robin as
        ``t % n_actions``. Afterwards the action with the highest predicted
        reward, computed through betas sampled from the posterior, is selected.

        Args:
            context (torch.Tensor): The context vector to select action for.

        Returns:
            int: The action to take (returned as a 0-dim integer tensor).
        """
        self.t += 1
        if self.t < self.n_actions * self.init_pulls:
            return torch.tensor(
                self.t % self.n_actions, device=self.device, dtype=torch.int
            )
        # scipy/numpy cannot consume non-CPU tensors, so hand them plain floats.
        var = torch.tensor(
            [
                float(self.b[i]) * invgamma.rvs(float(self.a[i]))
                for i in range(self.n_actions)
            ],
            device=self.device,
            dtype=torch.float,
        )
        try:
            beta = (
                torch.tensor(
                    np.stack(
                        [
                            np.random.multivariate_normal(
                                self.mu[i].cpu().numpy(),
                                (var[i] * self.cov[i]).cpu().numpy(),
                            )
                            for i in range(self.n_actions)
                        ]
                    )
                )
                .to(self.device)
                .to(torch.float)
            )
        except np.linalg.LinAlgError:
            # Posterior sampling failed; fall back to a standard normal sample
            # of the same width for every action.
            dim = self.context_dim + 1
            beta = (
                torch.stack(
                    [
                        torch.distributions.MultivariateNormal(
                            torch.zeros(dim), torch.eye(dim)
                        ).sample()
                        for _ in range(self.n_actions)
                    ]
                )
                .to(self.device)
                .to(torch.float)
            )
        # Append the constant 1 on the same device as the context.
        bias = torch.ones(1, device=context.device)
        values = torch.mv(beta, torch.cat([context.view(-1), bias]))
        action = torch.argmax(values).to(torch.int)
        return action

    def update_db(self, context: torch.Tensor, action: int, reward: int):
        """Updates transition database with given transition

        Args:
            context (torch.Tensor): Context received
            action (int): Action taken
            reward (int): Reward received
        """
        self.db.add(context, action, reward)

    def update_params(
        self, action: int, batch_size: int = 512, train_epochs: Optional[int] = None
    ):
        """Update parameters of the agent.

        Updates the posterior over beta for ``action`` through Bayesian
        regression on data drawn from the transition database.

        Args:
            action (int): Action to update the parameters for.
            batch_size (int, optional): Size of batch to update parameters with.
                Defaults to 512
            train_epochs (Optional[int], optional): Epochs to train neural network for.
                Not applicable in this agent. Defaults to None
        """
        self.update_count += 1

        x, y = self.db.get_data_for_action(action, batch_size)
        # Build helper tensors on the data's device so this also works off-CPU.
        x = torch.cat([x, torch.ones(x.shape[0], 1, device=x.device)], dim=1)
        inv_cov = torch.mm(x.T, x) + self.lambda_prior * torch.eye(
            self.context_dim + 1, device=x.device
        )
        cov = torch.pinverse(inv_cov)
        mu = torch.mm(cov, torch.mm(x.T, y))
        # NOTE(review): uses the global step count self.t, not the number of
        # samples seen for this action -- confirm this is intended.
        a = self.a0 + self.t / 2
        b = self.b0 + (torch.mm(y.T, y) - torch.mm(mu.T, torch.mm(inv_cov, mu))) / 2
        self.mu[action] = mu.squeeze(1).to(self.device)
        self.cov[action] = cov.to(self.device)
        self.inv_cov[action] = inv_cov.to(self.device)
        self.a[action] = a
        self.b[action] = b.to(self.device)
예제 #7
0
class NeuralNoiseSamplingAgent(DCBAgent):
    """Deep contextual bandit agent with noise sampling for neural network parameters.

    Args:
        bandit (DataBasedBandit): The bandit to solve
        init_pulls (int, optional): Number of times to select each action initially.
            Defaults to 3.
        hidden_dims (List[int], optional): Dimensions of hidden layers of network.
            Defaults to [50, 50].
        init_lr (float, optional): Initial learning rate. Defaults to 0.1.
        lr_decay (float, optional): Decay rate for learning rate. Defaults to 0.5.
        lr_reset (bool, optional): Whether to reset learning rate every train interval.
            Defaults to True.
        max_grad_norm (float, optional): Maximum norm of gradients for gradient clipping.
            Defaults to 0.5.
        dropout_p (Optional[float], optional): Probability for dropout. Defaults to None
            which implies dropout is not to be used.
        eval_with_dropout (bool, optional): Whether or not to use dropout at inference.
            Defaults to False.
        noise_std_dev (float, optional): Standard deviation of sampled noise.
            Defaults to 0.05.
        eps (float, optional): Small constant for bounding KL divergence of noise.
            Defaults to 0.1.
        noise_update_batch_size (int, optional): Batch size for updating noise parameters.
            Defaults to 256.
        device (str): Device to use for tensor operations.
            "cpu" for cpu or "cuda" for cuda. Defaults to "cpu".
    """
    def __init__(self, bandit: DataBasedBandit, **kwargs):
        super(NeuralNoiseSamplingAgent,
              self).__init__(bandit, kwargs.get("device", "cpu"))
        self.init_pulls = kwargs.get("init_pulls", 3)
        self.model = (NeuralBanditModel(
            context_dim=self.context_dim,
            hidden_dims=kwargs.get("hidden_dims", [50, 50]),
            n_actions=self.n_actions,
            init_lr=kwargs.get("init_lr", 0.1),
            max_grad_norm=kwargs.get("max_grad_norm", 0.5),
            lr_decay=kwargs.get("lr_decay", 0.5),
            lr_reset=kwargs.get("lr_reset", True),
            dropout_p=kwargs.get("dropout_p", None),
        ).to(torch.float).to(self.device))
        self.eval_with_dropout = kwargs.get("eval_with_dropout", False)
        self.noise_std_dev = kwargs.get("noise_std_dev", 0.05)
        self.eps = kwargs.get("eps", 0.1)
        self.db = TransitionDB(self.device)
        self.noise_update_batch_size = kwargs.get("noise_update_batch_size",
                                                  256)
        self.t = 0
        self.update_count = 0

    def select_action(self, context: torch.Tensor) -> int:
        """Select an action based on given context.

        While ``t < n_actions * init_pulls`` the action is chosen round-robin as
        ``t % n_actions``. Afterwards noise is added to the neural network
        parameters and a forward pass is computed with the context vector as
        input; the action with the highest predicted reward is returned.

        Args:
            context (torch.Tensor): The context vector to select action for.

        Returns:
            int: The action to take (returned as a 0-dim integer tensor).
        """
        self.model.use_dropout = self.eval_with_dropout
        self.t += 1
        if self.t < self.n_actions * self.init_pulls:
            return torch.tensor(self.t % self.n_actions,
                                device=self.device,
                                dtype=torch.int)

        _, predicted_rewards = self._noisy_pred(context)
        action = torch.argmax(predicted_rewards).to(torch.int)
        return action

    def update_db(self, context: torch.Tensor, action: int, reward: int):
        """Updates transition database with given transition

        Args:
            context (torch.Tensor): Context received
            action (int): Action taken
            reward (int): Reward received
        """
        self.db.add(context, action, reward)

    def update_params(
        self,
        action: Optional[int] = None,
        batch_size: int = 512,
        train_epochs: int = 20,
    ):
        """Update parameters of the agent.

        Trains the neural network on the transition database and then adapts
        the noise standard deviation.

        Args:
            action (Optional[int], optional): Action to update the parameters for.
                Not applicable in this agent. Defaults to None.
            batch_size (int, optional): Size of batch to update parameters with.
                Defaults to 512
            train_epochs (int, optional): Epochs to train neural network for.
                Defaults to 20
        """
        self.update_count += 1
        self.model.train_model(self.db, train_epochs, batch_size)
        self._update_noise()

    def _noisy_pred(self, context: torch.Tensor) -> torch.Tensor:
        """Run a forward pass with temporarily perturbed network parameters.

        Zero-mean normal noise with standard deviation ``noise_std_dev`` is
        added to every parameter, the model is evaluated, and the parameters
        are restored to their exact previous values, even if the forward pass
        raises.

        Args:
            context (torch.Tensor): Input passed to the model.

        Returns:
            The pair ``(results["x"], results["pred_rewards"])`` taken from the
            model output (a tuple, despite the annotation).
        """
        params = list(self.model.parameters())
        with torch.no_grad():
            # Keep exact copies: subtracting the noise again would not restore
            # the weights bit-for-bit because of floating-point rounding.
            saved = [p.clone() for p in params]
            for p in params:
                # Sample on the parameter's own device and dtype.
                p.add_(torch.empty_like(p).normal_(0, self.noise_std_dev))

        try:
            results = self.model(context)
        finally:
            with torch.no_grad():
                for p, original in zip(params, saved):
                    p.copy_(original)

        return results["x"], results["pred_rewards"]

    def _update_noise(self):
        """Adapt ``noise_std_dev`` from the KL divergence of clean vs noisy outputs.

        A batch is drawn from the database and the model's ``"x"`` output with
        and without parameter noise is interpreted as categorical logits. If
        the mean KL divergence is below a threshold derived from ``eps`` and
        ``n_actions`` the noise scale grows by 1%, otherwise it shrinks by 1%.
        ``eps`` is then decayed by a factor of 0.99.
        """
        x, _, _ = self.db.get_data(self.noise_update_batch_size)
        with torch.no_grad():
            results = self.model(x)
            # NOTE(review): the first element of _noisy_pred is results["x"],
            # not the predicted rewards -- confirm this is the intended signal.
            y_pred_noisy, _ = self._noisy_pred(x)

        p = torch.distributions.Categorical(logits=results["x"])
        q = torch.distributions.Categorical(logits=y_pred_noisy)
        kl = torch.distributions.kl.kl_divergence(p, q).mean()

        delta = -np.log1p(-self.eps + self.eps / self.n_actions)
        if kl < delta:
            self.noise_std_dev *= 1.01
        else:
            self.noise_std_dev /= 1.01

        self.eps *= 0.99
예제 #8
0
class VariationalAgent(DCBAgent):
    """Deep contextual bandit agent using variational inference.

    Args:
        bandit (DataBasedBandit): The bandit to solve
        init_pulls (int, optional): Number of times to select each action initially.
            Defaults to 3.
        hidden_dims (List[int], optional): Dimensions of hidden layers of network.
            Defaults to [50, 50].
        init_lr (float, optional): Initial learning rate. Defaults to 0.1.
        lr_decay (float, optional): Decay rate for learning rate. Defaults to 0.5.
        lr_reset (bool, optional): Whether to reset learning rate every train interval.
            Defaults to True.
        max_grad_norm (float, optional): Maximum norm of gradients for gradient clipping.
            Defaults to 0.5.
        dropout_p (Optional[float], optional): Probability for dropout. Defaults to None
            which implies dropout is not to be used.
        eval_with_dropout (bool, optional): Whether or not to use dropout at inference.
            Defaults to False.
        noise_std (float, optional): Standard deviation of noise in bayesian neural network.
            Defaults to 0.1.
        device (str): Device to use for tensor operations.
            "cpu" for cpu or "cuda" for cuda. Defaults to "cpu".
    """

    def __init__(self, bandit: DataBasedBandit, **kwargs):
        device = kwargs.get("device", "cpu")
        super(VariationalAgent, self).__init__(bandit, device)

        self.init_pulls = kwargs.get("init_pulls", 3)

        network = BayesianNNBanditModel(
            context_dim=self.context_dim,
            hidden_dims=kwargs.get("hidden_dims", [50, 50]),
            n_actions=self.n_actions,
            init_lr=kwargs.get("init_lr", 0.1),
            max_grad_norm=kwargs.get("max_grad_norm", 0.5),
            lr_decay=kwargs.get("lr_decay", 0.5),
            lr_reset=kwargs.get("lr_reset", True),
            dropout_p=kwargs.get("dropout_p", None),
            noise_std=kwargs.get("noise_std", 0.1),
        )
        # Cast to float32 first, then move to the agent's device.
        self.model = network.to(torch.float).to(self.device)

        self.eval_with_dropout = kwargs.get("eval_with_dropout", False)
        self.db = TransitionDB(self.device)

        # Step counter and number of parameter updates performed so far.
        self.t = 0
        self.update_count = 0

    def select_action(self, context: torch.Tensor) -> int:
        """Select an action based on given context.

        While ``t < n_actions * init_pulls`` the action is chosen round-robin as
        ``t % n_actions``. Afterwards the context is passed through the bayesian
        neural network and the action with the highest predicted reward is
        returned.

        Args:
            context (torch.Tensor): The context vector to select action for.

        Returns:
            int: The action to take (returned as a 0-dim integer tensor).
        """
        self.model.use_dropout = self.eval_with_dropout
        self.t += 1

        in_initial_phase = self.t < self.n_actions * self.init_pulls
        if in_initial_phase:
            return torch.tensor(
                self.t % self.n_actions, device=self.device, dtype=torch.int
            )

        predicted_rewards = self.model(context)["pred_rewards"]
        return torch.argmax(predicted_rewards).to(torch.int)

    def update_db(self, context: torch.Tensor, action: int, reward: int):
        """Record one transition in the transition database.

        Args:
            context (torch.Tensor): Context received
            action (int): Action taken
            reward (int): Reward received
        """
        self.db.add(context, action, reward)

    def update_params(self, action: int, batch_size: int = 512, train_epochs: int = 20):
        """Update parameters of the agent.

        Trains the bayesian neural network on the transition database.

        Args:
            action (int): Action to update the parameters for.
                Not used by this agent.
            batch_size (int, optional): Size of batch to update parameters with.
                Defaults to 512
            train_epochs (int, optional): Epochs to train neural network for.
                Defaults to 20
        """
        self.update_count += 1
        self.model.train_model(self.db, train_epochs, batch_size)