    def learn(self, n_steps):
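        # Contextual-bandit training loop: sample a context, pick an action (self.act also
        # returns its uncertainty and sampled-value estimates), record reward and regret,
        # and add the context/reward pair to the replay buffer only when the agent chooses
        # to eat the sampled mushroom (action == 1).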

        self.train_parameters['n_steps'] = n_steps

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        cumulative_regret = 0

        for timestep in range(n_steps):

            x = self.env.sample()

            action, uncertainty, sampled_value = self.act(x.float())

            reward = self.env.hit(action)
            regret = self.env.regret(action)

            cumulative_regret += regret

            reward = torch.as_tensor([reward], dtype=torch.float)

            if self.logging:
                self.logger.add_scalar('Cumulative_Regret', cumulative_regret,
                                       timestep)
                self.logger.add_scalar('Mushrooms_Eaten', self.n_samples,
                                       timestep)
                if self.env.y_sample == 1:
                    self.logger.add_scalar('Uncertainty_Good', uncertainty,
                                           timestep)
                    self.logger.add_scalar('Sampled_Value_Good', sampled_value,
                                           timestep)
                else:
                    self.logger.add_scalar('Uncertainty_Bad', uncertainty,
                                           timestep)
                    self.logger.add_scalar('Sampled_Value_Bad', sampled_value,
                                           timestep)

            if action == 1:
                self.replay_buffer.add(x, reward)
                self.n_samples += 1

            if timestep % self.train_freq == 0 and self.n_samples > self.start_train_step:

                if self.verbose:
                    print('Timestep: {}, cumulative regret {}'.format(
                        str(timestep), str(cumulative_regret)))

                for update in range(self.updates_per_train):

                    samples = self.replay_buffer.sample(
                        np.min([self.n_samples, self.batch_size]))
                    self.train_step(samples)

        if self.logging:
            self.logger.save()
Example #2
class UADQN:
    """
    # Required parameters
    env : Environment to use.
    network : Choice of neural network.

    # Environment parameter
    gamma : Discount factor

    # Replay buffer
    replay_start_size : The capacity of the replay buffer at which training can begin.
    replay_buffer_size : Maximum buffer capacity.

    # QR-DQN parameters
    n_quantiles : Number of quantiles to estimate.
    kappa : Smoothing parameter for the Huber loss.
    weight_scale : Scale of the prior neural network weights at initialization.
    noise_scale : Scale of the aleatoric noise.
    epistemic_factor : Multiplier for the epistemic uncertainty, used for Thompson sampling.
    aleatoric_factor : Multiplier for the aleatoric uncertainty, used to adjust the mean Q values.
    update_target_frequency : Frequency at which the target network is updated.
    minibatch_size : Minibatch size.
    update_frequency : Number of environment steps taken between parameter update steps.
    learning_rate : Learning rate used for the Adam optimizer.
    seed : The global seed to set. None means randomly selected.
    adam_epsilon : Epsilon parameter for the Adam optimizer.
    biased_aleatoric : Whether to use the empirical std of the quantiles instead of the unbiased estimator.

    # Logging and Saving
    logging : Whether to create logs when training
    log_folder_details : Additional information to put into the name of the log folder
    save_period : Periodicity with which the network weights are checkpointed
    notes : Notes to add to the log folder

    # Rendering
    render : Whether to render the environment during training. This slows down training.
    """

    def __init__(
        self,
        env,
        network,
        gamma=0.99,
        replay_start_size=50000,
        replay_buffer_size=1000000,
        n_quantiles=50,
        kappa=1,
        weight_scale=3,
        noise_scale=0.1,
        epistemic_factor=1,
        aleatoric_factor=1,
        update_target_frequency=10000,
        minibatch_size=32,
        update_frequency=1,
        learning_rate=1e-3,
        seed=None,
        adam_epsilon=1e-8,
        biased_aleatoric=False,
        logging=False,
        log_folder_details=None,
        save_period=250000,
        notes=None,
        render=False,
    ):

        # Agent parameters
        self.env = env
        self.gamma = gamma
        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.n_quantiles = n_quantiles
        self.kappa = kappa
        self.weight_scale = weight_scale
        self.noise_scale = noise_scale
        self.epistemic_factor = epistemic_factor
        self.aleatoric_factor = aleatoric_factor
        self.update_target_frequency = update_target_frequency
        self.minibatch_size = minibatch_size
        self.update_frequency = update_frequency
        self.learning_rate = learning_rate
        self.seed = random.randint(0, int(1e6)) if seed is None else seed
        self.adam_epsilon = adam_epsilon
        self.biased_aleatoric = biased_aleatoric
        self.logging = logging
        self.log_folder_details = log_folder_details
        self.save_period = save_period
        self.render = render
        self.notes = notes

        # Set global seed before creating network
        set_global_seed(self.seed, self.env)

        # Initialize agent
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger = None
        self.loss = quantile_huber_loss
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)

        # Initialize main Q learning network
        n_outputs = self.env.action_space.n*self.n_quantiles
        self.network = network(self.env.observation_space, n_outputs).to(self.device)
        self.target_network = network(self.env.observation_space, n_outputs).to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())

        # Initialize anchored networks
        self.posterior1 = network(self.env.observation_space, n_outputs, weight_scale=weight_scale).to(self.device)
        self.posterior2 = network(self.env.observation_space, n_outputs, weight_scale=weight_scale).to(self.device)
        self.anchor1 = [p.data.clone() for p in list(self.posterior1.parameters())]
        self.anchor2 = [p.data.clone() for p in list(self.posterior2.parameters())]

        # Initialize optimizer
        params = list(self.network.parameters()) + list(self.posterior1.parameters()) + list(self.posterior2.parameters())
        self.optimizer = optim.Adam(params, lr=self.learning_rate, eps=self.adam_epsilon)

        # Figure out what the scale of the prior is from empirical std of network weights
        with torch.no_grad():
            std_list = []
            for i, p in enumerate(self.posterior1.parameters()):
                std_list.append(torch.std(p))
        self.prior_scale = torch.stack(std_list).mean().item()
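        # The mean per-tensor std of the freshly initialized posterior weights is used as the
        # prior standard deviation in the anchor loss (see calc_anchor_loss).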

        # Parameters to save to log file
        self.train_parameters = {
                    'Notes': notes,
                    'env': str(env),
                    'network': str(self.network),
                    'n_quantiles': n_quantiles,
                    'replay_start_size': replay_start_size,
                    'replay_buffer_size': replay_buffer_size,
                    'gamma': gamma,
                    'update_target_frequency': update_target_frequency,
                    'minibatch_size': minibatch_size,
                    'learning_rate': learning_rate,
                    'update_frequency': update_frequency,
                    'weight_scale': weight_scale,
                    'noise_scale': noise_scale,
                    'epistemic_factor': epistemic_factor,
                    'aleatoric_factor': aleatoric_factor,
                    'biased_aleatoric': biased_aleatoric,
                    'adam_epsilon': adam_epsilon,
                    'seed': self.seed
                    }

    def learn(self, timesteps, verbose=False):

        self.non_greedy_actions = 0
        self.timestep = 0
        self.this_episode_time = 0
        self.n_events = 0  # Number of times an important event is flagged in the info

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details, self.train_parameters)

        # Initialize the state
        state = torch.as_tensor(self.env.reset())
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            is_training_ready = timestep >= self.replay_start_size

            if self.render:
                self.env.render()

            # Select action
            action = self.act(state.to(self.device).float(), is_training_ready=is_training_ready)

            # Perform action in environment
            state_next, reward, done, info = self.env.step(action)

            if (info == "The agent fell!") and self.logging:  # For gridworld experiments
                self.n_events += 1
                self.logger.add_scalar('Agent falls', self.n_events, timestep)

            # Store transition in replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()
            self.this_episode_time += 1

            if done:

                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(timestep, score, round(time.time() - t1, 3)))

                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)

                # Reinitialize the state
                state = torch.as_tensor(self.env.reset())
                score = 0

                if self.logging:
                    non_greedy_fraction = self.non_greedy_actions/self.this_episode_time
                    self.logger.add_scalar('Non Greedy Fraction', non_greedy_fraction, timestep)

                self.non_greedy_actions = 0
                self.this_episode_time = 0
                t1 = time.time()

            else:
                state = state_next

            if is_training_ready:

                # Update main network
                if timestep % self.update_frequency == 0:

                    # Sample batch of transitions
                    transitions = self.replay_buffer.sample(self.minibatch_size, self.device)

                    # Train on selected batch
                    loss, anchor_loss = self.train_step(transitions)

                    if self.logging and timesteps < 50000:
                        self.logger.add_scalar('Loss', loss, timestep)
                        self.logger.add_scalar('Anchor Loss', anchor_loss, timestep)

                # Periodically update target Q network
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())

            if (timestep+1) % self.save_period == 0:
                self.save(timestep=timestep+1)

            self.timestep += 1

        if self.logging:
            self.logger.save()
            self.save()

        if self.render:
            self.env.close()

    def train_step(self, transitions):
        """
        Performs gradient descent step on a batch of transitions
        """

        states, actions, rewards, states_next, dones = transitions

        # Calculate target Q
        with torch.no_grad():
            target = self.target_network(states_next.float())
            target = target.view(self.minibatch_size, self.env.action_space.n, self.n_quantiles)

        # Calculate max of target Q values
        best_action_idx = torch.mean(target, dim=2).max(1, True)[1].unsqueeze(2)
        q_value_target = target.gather(1, best_action_idx.repeat(1, 1, self.n_quantiles))
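        # best_action_idx has shape (batch, 1, 1): the action with the highest mean over
        # quantiles; the gather above selects that action's full set of target quantiles.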

        # Calculate TD target
        rewards = rewards.unsqueeze(2).repeat(1, 1, self.n_quantiles)
        dones = dones.unsqueeze(2).repeat(1, 1, self.n_quantiles)
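        # Distributional Bellman target, applied quantile-wise: r + gamma * (1 - done) * theta_target(s', a*)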
        td_target = rewards + (1 - dones) * self.gamma * q_value_target

        # Calculate Q value of actions played
        outputs = self.network(states.float())
        outputs = outputs.view(self.minibatch_size, self.env.action_space.n, self.n_quantiles)
        actions = actions.unsqueeze(2).repeat(1, 1, self.n_quantiles)
        q_value = outputs.gather(1, actions)

        # TD loss for main network
        loss = self.loss(q_value.squeeze(), td_target.squeeze(), self.device, kappa=self.kappa)

        # Calculate predictions of posterior networks
        posterior1 = self.posterior1(states.float())
        posterior1 = posterior1.view(self.minibatch_size, self.env.action_space.n, self.n_quantiles)
        posterior1 = posterior1.gather(1, actions)

        posterior2 = self.posterior2(states.float())
        posterior2 = posterior2.view(self.minibatch_size, self.env.action_space.n, self.n_quantiles)
        posterior2 = posterior2.gather(1, actions)

        # Regression loss for the posterior networks
        loss_posterior1 = self.loss(posterior1.squeeze(), td_target.squeeze(), self.device, kappa=self.kappa)
        loss_posterior2 = self.loss(posterior2.squeeze(), td_target.squeeze(), self.device, kappa=self.kappa)
        loss += loss_posterior1 + loss_posterior2

        # Anchor loss for the posterior networks
        anchor_loss = self.calc_anchor_loss()        
        loss += anchor_loss

        # Update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item(), anchor_loss.mean().item()

    def calc_anchor_loss(self):
        """
        Returns loss from anchoring
        """

        diff1 = []
        for i, p in enumerate(self.posterior1.parameters()):
            diff1.append(torch.sum((p - self.anchor1[i])**2))
        diff1 = torch.stack(diff1).sum()

        diff2 = []
        for i, p in enumerate(self.posterior2.parameters()):
            diff2.append(torch.sum((p-self.anchor2[i])**2))
        diff2 = torch.stack(diff2).sum()

        diff = diff1 + diff2

        num_data = np.min([self.timestep, self.replay_buffer_size])
        anchor_loss = self.noise_scale**2*diff/(self.prior_scale**2*num_data)
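        # The squared L2 distance of the posterior weights from their randomly drawn anchors,
        # scaled by noise_scale^2 / (prior_scale^2 * num_data), acts as a prior-like penalty
        # relative to the amount of observed data.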

        return anchor_loss

    @torch.no_grad()
    def get_q(self, state):
        net = self.network(state).view(self.env.action_space.n, self.n_quantiles)
        action_means = torch.mean(net, dim=1)
        q = action_means
        return q

    @torch.no_grad()
    def act(self, state, is_training_ready=True):
        """
        Returns action to be performed using Thompson sampling
        with estimates provided by the two posterior networks
        """

        net = self.network(state).view(self.env.action_space.n, self.n_quantiles)

        posterior1 = self.posterior1(state).view(self.env.action_space.n, self.n_quantiles)
        posterior2 = self.posterior2(state).view(self.env.action_space.n, self.n_quantiles)

        mean_action_values = torch.mean(net, dim=1)

        # Calculate aleatoric uncertainty
        if self.biased_aleatoric:
            uncertainties_aleatoric = torch.std(net, dim=1)
        else:
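            # Unbiased estimate: the covariance between the two posterior heads' quantile
            # predictions captures the spread they agree on (aleatoric), and the relu clips
            # small negative estimates before taking the square root.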
            covariance = torch.mean((posterior1-torch.mean(posterior1))*(posterior2-torch.mean(posterior2)), dim=1)
            uncertainties_aleatoric = torch.sqrt(F.relu(covariance))

        # Aleatoric-adjusted Q values
        aleatoric_factor = torch.tensor(self.aleatoric_factor, dtype=torch.float, device=self.device)
        adjusted_action_values = mean_action_values - aleatoric_factor*uncertainties_aleatoric

        # Calculate epistemic uncertainty
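        # For two independent estimates with the same mean, E[(p1 - p2)^2] / 2 equals their
        # common variance, so this is a two-sample estimate of the epistemic variance per action.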
        uncertainties_epistemic = torch.mean((posterior1-posterior2)**2, dim=1)/2 + 1e-8
        epistemic_factor = torch.tensor(self.epistemic_factor, dtype=torch.float, device=self.device)**2
        uncertainties_cov = epistemic_factor*torch.diagflat(uncertainties_epistemic)

        # Draw samples using Thompson sampling
        epistemic_distrib = torch.distributions.multivariate_normal.MultivariateNormal
        samples = epistemic_distrib(adjusted_action_values, covariance_matrix=uncertainties_cov).sample()
        action = samples.argmax().item()

        #print(mean_action_values, torch.sqrt(uncertainties_epistemic), torch.sqrt(uncertainties_aleatoric))
        
        if self.logging and self.this_episode_time == 0:
            self.logger.add_scalar('Epistemic Uncertainty 0', torch.sqrt(uncertainties_epistemic)[0], self.timestep)
            self.logger.add_scalar('Epistemic Uncertainty 1', torch.sqrt(uncertainties_epistemic)[1], self.timestep)
            self.logger.add_scalar('Aleatoric Uncertainty 0', uncertainties_aleatoric[0], self.timestep)
            self.logger.add_scalar('Aleatoric Uncertainty 1', uncertainties_aleatoric[1], self.timestep)
            self.logger.add_scalar('Q0', mean_action_values[0], self.timestep)
            self.logger.add_scalar('Q1', mean_action_values[1], self.timestep)

        if action != mean_action_values.argmax().item():
            self.non_greedy_actions += 1
        
        return action

    @torch.no_grad()
    def predict(self, state):
        """
        Returns action with the highest Q-value
        """
        net = self.network(state).view(self.env.action_space.n, self.n_quantiles)
        mean_action_values = torch.mean(net, dim=1)
        action = mean_action_values.argmax().item()

        return action

    def save(self, timestep=None):
        """
        Saves network weights
        """
        if timestep is not None:
            filename = 'network_' + str(timestep) + '.pth'
            filename_posterior1 = 'network_posterior1_' + str(timestep) + '.pth'
            filename_posterior2 = 'network_posterior2_' + str(timestep) + '.pth'
        else:
            filename = 'network.pth'
            filename_posterior1 = 'network_posterior1.pth'
            filename_posterior2 = 'network_posterior2.pth'

        save_path = self.logger.log_folder + '/' + filename
        save_path_posterior1 = self.logger.log_folder + '/' + filename_posterior1
        save_path_posterior2 = self.logger.log_folder + '/' + filename_posterior2

        torch.save(self.network.state_dict(), save_path)
        torch.save(self.posterior1.state_dict(), save_path_posterior1)
        torch.save(self.posterior2.state_dict(), save_path_posterior2)

    def load(self, path):
        """
        Loads network weights
        """
        self.network.load_state_dict(torch.load(path + 'network.pth', map_location='cpu'))
        self.posterior1.load_state_dict(torch.load(path + 'network_posterior1.pth', map_location='cpu'))
        self.posterior2.load_state_dict(torch.load(path + 'network_posterior2.pth', map_location='cpu'))
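
A minimal usage sketch (not part of the original listing): it assumes the pre-0.26 Gym reset/step API that the agent code relies on, and that the module-level imports and helpers used by the class above (torch, numpy, pprint, ReplayBuffer, Logger, set_global_seed, quantile_huber_loss) are available. SimpleMLP is a hypothetical network factory matching the network(observation_space, n_outputs, weight_scale=...) signature UADQN expects.

import gym
import torch
import torch.nn as nn


class SimpleMLP(nn.Module):
    """Hypothetical network factory compatible with UADQN's expected signature."""

    def __init__(self, observation_space, n_outputs, weight_scale=1.0):
        super().__init__()
        n_inputs = observation_space.shape[0]
        self.layers = nn.Sequential(
            nn.Linear(n_inputs, 64),
            nn.ReLU(),
            nn.Linear(64, n_outputs),
        )
        # Rescale the randomly initialized weights, mirroring how UADQN uses weight_scale
        with torch.no_grad():
            for p in self.parameters():
                p.mul_(weight_scale)

    def forward(self, x):
        return self.layers(x)


env = gym.make("CartPole-v1")
agent = UADQN(env, SimpleMLP, replay_start_size=1000, update_target_frequency=500)
agent.learn(timesteps=5000, verbose=True)

state = torch.as_tensor(env.reset(), dtype=torch.float).to(agent.device)
print("Greedy action:", agent.predict(state))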
Example #4
class DQN:
    def __init__(
        self,
        env,
        network,
        replay_start_size=50000,
        replay_buffer_size=1000000,
        gamma=0.99,
        update_target_frequency=10000,
        minibatch_size=32,
        learning_rate=1e-3,
        update_frequency=1,
        initial_exploration_rate=1,
        final_exploration_rate=0.1,
        final_exploration_step=1000000,
        adam_epsilon=1e-8,
        logging=False,
        log_folder_details=None,
        seed=None,
        render=False,
        loss="huber",
        notes=None
    ):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.gamma = gamma
        self.update_target_frequency = update_target_frequency
        self.minibatch_size = minibatch_size
        self.learning_rate = learning_rate
        self.update_frequency = update_frequency
        self.initial_exploration_rate = initial_exploration_rate
        self.epsilon = self.initial_exploration_rate
        self.final_exploration_rate = final_exploration_rate
        self.final_exploration_step = final_exploration_step
        self.adam_epsilon = adam_epsilon
        self.logging = logging
        self.render = render
        self.log_folder_details = log_folder_details
        if callable(loss):
            self.loss = loss
        else:
            try:
                self.loss = {'huber': F.smooth_l1_loss, 'mse': F.mse_loss}[loss]
            except KeyError:
                raise ValueError("loss must be 'huber', 'mse' or a callable")

        self.env = env
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)
        self.seed = random.randint(0, int(1e6)) if seed is None else seed
        self.logger = None

        set_global_seed(self.seed, self.env)

        self.network = network(self.env.observation_space, self.env.action_space.n).to(self.device)
        self.target_network = network(self.env.observation_space, self.env.action_space.n).to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate, eps=self.adam_epsilon)

        self.train_parameters = {'Notes':notes,
                'env':env.unwrapped.spec.id,
                'network':str(self.network),
                'replay_start_size':replay_start_size,
                'replay_buffer_size':replay_buffer_size,
                'gamma':gamma,
                'update_target_frequency':update_target_frequency,
                'minibatch_size':minibatch_size,
                'learning_rate':learning_rate,
                'update_frequency':update_frequency,
                'initial_exploration_rate':initial_exploration_rate,
                'final_exploration_rate':final_exploration_rate,
                'weight_scale':self.network.weight_scale,
                'final_exploration_step':final_exploration_step,
                'adam_epsilon':adam_epsilon,
                'loss':loss,
                'seed':self.seed}

    def learn(self, timesteps, verbose=False):

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details,self.train_parameters)

        # Initialize the state
        state = torch.as_tensor(self.env.reset())
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            is_training_ready = timestep >= self.replay_start_size

            if self.render:
                self.env.render()

            # Select an action
            action = self.act(state.to(self.device).float(), is_training_ready=is_training_ready)

            # Update epsilon
            self.update_epsilon(timestep)

            # Perform the action in the environment
            state_next, reward, done, _ = self.env.step(action)

            # Store the transition in the replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()

            if done:
                # Reinitialize the state
                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar('Q_at_start', self.get_max_q(state.to(self.device).float()), timestep)

                t1 = time.time()
            else:
                state = state_next

            if is_training_ready:

                # Update the main network
                if timestep % self.update_frequency == 0:

                    # Sample a minibatch of transitions
                    transitions = self.replay_buffer.sample(self.minibatch_size, self.device)

                    # Train on the sampled transitions
                    loss = self.train_step(transitions)

                    if self.logging and timesteps < 100000:
                        self.logger.add_scalar('Loss', loss, timestep)

                # Periodically update the target Q network by copying the main network's weights
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())

            if (timestep+1) % 250000 == 0:
                    self.save(timestep=timestep+1)

        if self.logging:
            self.logger.save()
            self.save()

        if self.render:
            self.env.close()

    def train_step(self, transitions):
        # huber = torch.nn.SmoothL1Loss()
        states, actions, rewards, states_next, dones = transitions

        # Compute the maximizing Q value using the target network
        with torch.no_grad():
            q_value_target = self.target_network(states_next.float()).max(1, True)[0]

        # Compute the TD target
        td_target = rewards + (1 - dones) * self.gamma * q_value_target

        # Compute the Q value of the action that was taken
        q_value = self.network(states.float()).gather(1, actions)

        loss = self.loss(q_value, td_target, reduction='mean')

        # Update the weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def get_max_q(self,state):
        return self.network(state).max().item()

    def act(self, state, is_training_ready=True):
        if is_training_ready and random.uniform(0, 1) >= self.epsilon:
            # Action that maximizes the Q function
            action = self.predict(state)
        else:
            # Random action (sampled uniformly from the half-open interval [a, b))
            action = np.random.randint(0, self.env.action_space.n)
        return action

    def update_epsilon(self, timestep):
        eps = self.initial_exploration_rate - (self.initial_exploration_rate - self.final_exploration_rate) * (
            timestep / self.final_exploration_step
        )
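        # Linear annealing: with the defaults (1.0 -> 0.1 over 1e6 steps), eps is 0.55 at
        # timestep 500000; the max() below then clamps epsilon at the final rate.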
        self.epsilon = max(eps, self.final_exploration_rate)

    @torch.no_grad()
    def predict(self, state):
        action = self.network(state).argmax().item()
        return action

    def save(self,timestep=None):
        if not self.logging:
            raise NotImplementedError('Cannot save without log folder.')

        if timestep is not None:
            filename = 'network_' + str(timestep) + '.pth'
        else:
            filename = 'network.pth'

        save_path = self.logger.log_folder + '/' + filename

        torch.save(self.network.state_dict(), save_path)

    def load(self,path):
        self.network.load_state_dict(torch.load(path,map_location='cpu'))
Example #6
class IDE():
    def __init__(self,
                 env,
                 network,
                 n_quantiles=20,
                 kappa=1,
                 lamda=0.1,
                 replay_start_size=50000,
                 replay_buffer_size=1000000,
                 gamma=0.99,
                 update_target_frequency=10000,
                 epsilon_12=0.00001,
                 minibatch_size=32,
                 learning_rate=1e-4,
                 update_frequency=1,
                 prior=0.01,
                 adam_epsilon=1e-8,
                 logging=False,
                 log_folder_details=None,
                 seed=None,
                 notes=None):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.gamma = gamma
        self.epsilon_12 = epsilon_12
        self.lamda = lamda
        self.update_target_frequency = update_target_frequency
        self.minibatch_size = minibatch_size
        self.learning_rate = learning_rate
        self.update_frequency = update_frequency
        self.adam_epsilon = adam_epsilon
        self.logging = logging
        self.logger = None
        self.timestep = 0
        self.log_folder_details = log_folder_details

        self.env = env
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)
        self.seed = random.randint(0, int(1e6)) if seed is None else seed
        set_global_seed(self.seed, self.env)

        self.n_quantiles = n_quantiles

        self.network = network(self.env.observation_space,
                               self.env.action_space.n * self.n_quantiles,
                               self.env.action_space.n * self.n_quantiles).to(
                                   self.device)
        self.target_network = network(
            self.env.observation_space,
            self.env.action_space.n * self.n_quantiles,
            self.env.action_space.n * self.n_quantiles).to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(),
                                    lr=self.learning_rate,
                                    eps=self.adam_epsilon)

        self.anchor1 = [
            p.data.clone() for p in list(self.network.output_1.parameters())
        ]
        self.anchor2 = [
            p.data.clone() for p in list(self.network.output_2.parameters())
        ]

        self.loss = quantile_huber_loss
        self.kappa = kappa
        self.prior = prior

        self.train_parameters = {
            'Notes': notes,
            'env': env.unwrapped.spec.id,
            'network': str(self.network),
            'replay_start_size': replay_start_size,
            'replay_buffer_size': replay_buffer_size,
            'gamma': gamma,
            'lambda': lamda,
            'epsilon_1and2': epsilon_12,
            'update_target_frequency': update_target_frequency,
            'minibatch_size': minibatch_size,
            'learning_rate': learning_rate,
            'update_frequency': update_frequency,
            'kappa': kappa,
            'weight_scale': self.network.weight_scale,
            'n_quantiles': n_quantiles,
            'prior': prior,
            'adam_epsilon': adam_epsilon,
            'seed': self.seed
        }

        self.n_greedy_actions = 0
        self.timestep = 0

    def learn(self, timesteps, verbose=False):

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        # Initialize the state
        state = torch.as_tensor(self.env.reset())
        this_episode_time = 0
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            self.timestep = timestep

            is_training_ready = timestep >= self.replay_start_size

            # Select an action
            action = self.act(state.to(self.device).float(),
                              directed_exploration=True)

            # Perform the action in the environment
            state_next, reward, done, _ = self.env.step(action)

            # Store the transition in the replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()
            this_episode_time += 1

            if done:
                # Reinitialize the state
                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(
                        timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                    self.logger.add_scalar(
                        'Non_greedy_fraction',
                        1 - self.n_greedy_actions / this_episode_time,
                        timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar(
                        'Q_at_start',
                        self.get_max_q(state.to(self.device).float()),
                        timestep)
                t1 = time.time()
                self.n_greedy_actions = 0
                this_episode_time = 0
            else:
                state = state_next

            if is_training_ready:

                # Update the main network
                if timestep % self.update_frequency == 0:

                    # Sample a minibatch of transitions
                    transitions = self.replay_buffer.sample(
                        self.minibatch_size, self.device)

                    # Train on the sampled transitions
                    loss = self.train_step(transitions)
                    if self.logging and timesteps < 1000000:
                        self.logger.add_scalar('Loss', loss, timestep)

                # Periodically update the target Q network by copying the main network's weights
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(
                        self.network.state_dict())

            if (timestep + 1) % 250000 == 0:
                self.save(timestep=timestep + 1)

        if self.logging:
            self.logger.save()
            self.save()

    def train_step(self, transitions):
        # huber = torch.nn.SmoothL1Loss()
        states, actions, rewards, states_next, dones = transitions

        # Compute the maximizing Q value using the target network
        with torch.no_grad():

            target1, target2 = self.target_network(states_next.float())
            target1 = target1.view(self.minibatch_size,
                                   self.env.action_space.n, self.n_quantiles)
            target2 = target2.view(self.minibatch_size,
                                   self.env.action_space.n, self.n_quantiles)

            # Used to determine what action the current policy would have chosen in the next state
            target1_onpolicy, target2_onpolicy = self.network(
                states_next.float())
            target1_onpolicy = target1_onpolicy.view(self.minibatch_size,
                                                     self.env.action_space.n,
                                                     self.n_quantiles)
            target2_onpolicy = target2_onpolicy.view(self.minibatch_size,
                                                     self.env.action_space.n,
                                                     self.n_quantiles)

        best_action_idx = torch.mean((target1 + target2) / 2,
                                     dim=2).max(1, True)[1].unsqueeze(2)
        target1_gathered = target1.gather(
            1, best_action_idx.repeat(1, 1, self.n_quantiles))
        target2_gathered = target2.gather(
            1, best_action_idx.repeat(1, 1, self.n_quantiles))

        #uncertainty_target = uncertainty_output.gather(1,best_action_idx.squeeze(2))
        q_value_target = 0.5*target1_gathered\
            + 0.5*target2_gathered

        # Compute the TD target
        td_target = rewards.unsqueeze(2).repeat(1,1,self.n_quantiles) \
            + (1 - dones.unsqueeze(2).repeat(1,1,self.n_quantiles)) * self.gamma * q_value_target

        # Compute the Q value of the action that was taken
        out1, out2 = self.network(states.float())
        out1 = out1.view(self.minibatch_size, self.env.action_space.n,
                         self.n_quantiles)
        out2 = out2.view(self.minibatch_size, self.env.action_space.n,
                         self.n_quantiles)

        q_value1 = out1.gather(
            1,
            actions.unsqueeze(2).repeat(1, 1, self.n_quantiles))
        q_value2 = out2.gather(
            1,
            actions.unsqueeze(2).repeat(1, 1, self.n_quantiles))

        #Calculate quantile losses
        loss1 = self.loss(q_value1.squeeze(),
                          td_target.squeeze(),
                          self.device,
                          kappa=self.kappa)
        loss2 = self.loss(q_value2.squeeze(),
                          td_target.squeeze(),
                          self.device,
                          kappa=self.kappa)

        quantile_loss = loss1 + loss2

        diff1 = []
        for i, p in enumerate(self.network.output_1.parameters()):
            diff1.append(torch.sum((p - self.anchor1[i])**2))

        diff2 = []
        for i, p in enumerate(self.network.output_2.parameters()):
            diff2.append(torch.sum((p - self.anchor2[i])**2))

        diff1 = torch.stack(diff1).sum()
        diff2 = torch.stack(diff2).sum()

        anchor_loss = self.prior * (diff1 + diff2)

        loss = quantile_loss + anchor_loss
        #print(anchor_loss/loss)

        # Update the weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def act(self, state, directed_exploration=False):

        action = self.predict(state, directed_exploration=directed_exploration)

        return action

    @torch.no_grad()
    def predict(self, state, directed_exploration=False):
        if not directed_exploration:  #Choose greedily
            net1, net2 = self.network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            action = action_means.argmax().item()
        else:
            net1, net2 = self.network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)

            target_net1, target_net2 = self.target_network(state)
            target_net1 = target_net1.view(self.env.action_space.n,
                                           self.n_quantiles)
            target_net2 = target_net2.view(self.env.action_space.n,
                                           self.n_quantiles)

            uncertainties_epistemic = torch.mean(
                (target_net1 - target_net2)**2, dim=1) / 2

            #Calculate regret
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            delta_regret = torch.max(
                action_means + self.lamda * torch.sqrt(uncertainties_epistemic)
            ) - (action_means -
                 torch.sqrt(self.lamda * uncertainties_epistemic))

            #Calculate and normalize aleatoric uncertainties (variance of quantile - local epistemic)
            aleatoric = torch.abs(
                (torch.var(net1, dim=1) + torch.var(net2, dim=1)) / 2 -
                uncertainties_epistemic)
            uncertainties_aleatoric = aleatoric / (self.epsilon_12 +
                                                   torch.mean(aleatoric))

            #Use learned uncertainty and aleatoric uncertainty to calculate regret to information ratio and select actions
            information = torch.log(1 + uncertainties_epistemic /
                                    uncertainties_aleatoric) + self.epsilon_12
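            # Information-directed selection: pick the action with the smallest squared
            # regret per unit of (approximate) information gain about the optimal action.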
            regret_info_ratio = delta_regret**2 / information
            action = regret_info_ratio.argmin().item()

            if action == action_means.argmax().item():
                self.n_greedy_actions += 1

            if self.logging:
                self.logger.add_scalar('Epistemic_Uncertainty',
                                       uncertainties_epistemic[action].sqrt(),
                                       self.timestep)
                self.logger.add_scalar('Aleatoric_Uncertainty',
                                       uncertainties_aleatoric[action].sqrt(),
                                       self.timestep)

        return action

    @torch.no_grad()
    def get_max_q(self, state):
        net1, net2 = self.network(state)
        net1 = net1.view(self.env.action_space.n, self.n_quantiles)
        net2 = net2.view(self.env.action_space.n, self.n_quantiles)
        action_means = torch.mean((net1 + net2) / 2, dim=1)
        max_q = action_means.max().item()
        return max_q

    def save(self, timestep=None):
        if not self.logging:
            raise NotImplementedError('Cannot save without log folder.')

        if timestep is not None:
            filename = 'network_' + str(timestep) + '.pth'
        else:
            filename = 'network.pth'

        save_path = self.logger.log_folder + '/' + filename

        torch.save(self.network.state_dict(), save_path)

    def load(self, path):
        self.network.load_state_dict(torch.load(path, map_location='cpu'))
        self.target_network.load_state_dict(
            torch.load(path, map_location='cpu'))
Example #7
    def learn(self, timesteps, verbose=False):

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        # Initialize state
        state = torch.as_tensor(self.env.reset())
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            is_training_ready = timestep >= self.replay_start_size

            # Pick action
            action = self.act(state.to(self.device).float(),
                              is_training_ready=is_training_ready)

            # Update epsilon
            self.update_epsilon(timestep)

            # Perform action in environment
            state_next, reward, done, _ = self.env.step(action)

            # Store transition in replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()

            if done:
                # Reinitialize the state
                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(
                        timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar(
                        'Q_at_start',
                        self.get_max_q(state.to(self.device).float()),
                        timestep)

                t1 = time.time()
            else:
                state = state_next

            if is_training_ready:

                # Update network
                if timestep % self.update_frequency == 0:

                    # Sample batch of transitions
                    transitions = self.replay_buffer.sample(
                        self.minibatch_size, self.device)

                    # Train on selected batch
                    loss = self.train_step(transitions)

                    if self.logging and timesteps < 100000:
                        self.logger.add_scalar('Loss', loss, timestep)

                # Update target network
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(
                        self.network.state_dict())

            if (timestep + 1) % 250000 == 0:
                self.save(timestep=timestep + 1)

        if self.logging:
            self.logger.save()
            self.save()
Example #8
class BOOTSTRAPPED:
    """
    # Required parameters
    env : Environment to use.
    network : Choice of neural network.

    # Environment parameter
    gamma : Discount factor

    # Replay buffer
    replay_start_size : The capacity of the replay buffer at which training can begin.
    replay_buffer_size : Maximum buffer capacity.

    # Bootstrapped DQN parameters
    n_heads: Number of bootstrap heads
    update_target_frequency: Frequency at which target network is updated
    minibatch_size : Minibatch size.
    update_frequency : Number of environment steps taken between parameter update steps.
    learning_rate : Learning rate used for the Adam optimizer
    loss : Type of loss function to use. Can be "huber" or "mse"
    seed : The global seed to set.  None means randomly selected.
    adam_epsilon: Epsilon parameter for Adam optimizer

    # Exploration
    initial_exploration_rate : Initial exploration rate.
    final_exploration_rate : Final exploration rate.
    final_exploration_step : Timestep at which the final exploration rate is reached.

    # Logging and Saving
    logging : Whether to create logs when training
    log_folder_details : Additional information to put into the name of the log folder
    save_period : Periodicity with which the network weights are checkpointed
    notes : Notes to add to the log folder

    # Rendering
    render : Whether to render the environment during training. This slows down training.
    """

    def __init__(
        self,
        env,
        network,
        gamma=0.99,
        replay_start_size=50000,
        replay_buffer_size=1000000,
        n_heads=10,
        update_target_frequency=10000,
        minibatch_size=32,
        update_frequency=1,
        learning_rate=1e-3,
        loss="huber",
        seed=None,
        adam_epsilon=1e-8,
        initial_exploration_rate=1,
        final_exploration_rate=0.1,
        final_exploration_step=1000000,
        logging=False,
        log_folder_details=None,
        save_period=250000,
        notes=None,
        render=False,
    ):

        # Agent parameters
        self.env = env
        self.gamma = gamma
        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.n_heads = n_heads
        self.update_target_frequency = update_target_frequency
        self.minibatch_size = minibatch_size
        self.update_frequency = update_frequency
        self.learning_rate = learning_rate
        if callable(loss):
            self.loss = loss
        else:
            try:
                self.loss = {'huber': F.smooth_l1_loss, 'mse': F.mse_loss}[loss]
            except KeyError:
                raise ValueError("loss must be 'huber', 'mse' or a callable")
        self.seed = random.randint(0, int(1e6)) if seed is None else seed
        self.adam_epsilon = adam_epsilon
        self.initial_exploration_rate = initial_exploration_rate
        self.final_exploration_rate = final_exploration_rate
        self.final_exploration_step = final_exploration_step
        self.logging = logging
        self.log_folder_details = log_folder_details
        self.save_period = save_period
        self.render = render
        self.notes = notes

        # Set global seed before creating network
        set_global_seed(self.seed, self.env)

        # Initialize agent
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)
        self.logger = None
        self.epsilon = self.initial_exploration_rate
        self.network = network(self.env.observation_space, self.env.action_space.n, self.n_heads).to(self.device)
        self.target_network = network(self.env.observation_space, self.env.action_space.n, self.n_heads).to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate, eps=self.adam_epsilon)
        self.current_head = None
        self.timestep = 0

        # Parameters to save to log file
        self.train_parameters = {
                    'Notes': notes,
                    'env': str(env),
                    'network': str(self.network),
                    'replay_start_size': replay_start_size,
                    'replay_buffer_size': replay_buffer_size,
                    'gamma': gamma,
                    'n_heads': n_heads,
                    'update_target_frequency': update_target_frequency,
                    'minibatch_size': minibatch_size,
                    'learning_rate': learning_rate,
                    'update_frequency': update_frequency,
                    'initial_exploration_rate': initial_exploration_rate,
                    'final_exploration_rate': final_exploration_rate,
                    'weight_scale': self.network.weight_scale,
                    'final_exploration_step': final_exploration_step,
                    'adam_epsilon': adam_epsilon,
                    'loss': loss,
                    'seed': self.seed
                    }

    def learn(self, timesteps, verbose=False):

        self.current_head = np.random.randint(self.n_heads)

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details, self.train_parameters)

        # Initialize the state
        state = torch.as_tensor(self.env.reset())
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            self.timestep = timestep

            is_training_ready = timestep >= self.replay_start_size

            if self.render:
                self.env.render()

            # Select action
            action = self.act(state.to(self.device).float(), is_training_ready=is_training_ready)

            # Update epsilon
            self.update_epsilon(timestep)

            # Perform action in environment
            state_next, reward, done, _ = self.env.step(action)

            # Store transition in replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()

            if done:
                self.current_head = np.random.randint(self.n_heads)
                if self.logging:
                    self.logger.add_scalar('Acting_Head', self.current_head, self.timestep)

                # Reinitialize the state
                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar('Q_at_start', self.get_max_q(state.to(self.device).float()), timestep)

                t1 = time.time()
            else:
                state = state_next

            if is_training_ready:

                # Update main network
                if timestep % self.update_frequency == 0:

                    # Sample batch of transitions
                    transitions = self.replay_buffer.sample(self.minibatch_size, self.device)

                    # Train on selected batch
                    loss = self.train_step(transitions)

                # Periodically update target Q network
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())

            if self.logging and (timestep+1) % self.save_period == 0:
                self.save(timestep=timestep+1)

        if self.logging:
            self.logger.save()
            self.save()

        if self.render:
            self.env.close()

    def train_step(self, transitions):
        """
        Performs a gradient descent step on a batch of transitions
        """

        states, actions, rewards, states_next, dones = transitions

        # Calculate target Q
        with torch.no_grad():

            # Target shape : batch x n_heads x n_actions
            targets = self.target_network(states_next.float())
            q_value_target = targets.max(2, True)[0]

        # Calculate TD target
        td_target = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * q_value_target

        # Calculate Q value of actions played
        output = self.network(states.float())
        q_value = output.gather(2, actions.unsqueeze(1).repeat(1,self.n_heads,1))
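        # q_value has shape (batch, n_heads, 1): the chosen action is repeated
        # across heads so that every bootstrapped head is regressed toward the
        # shared TD target computed above.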

        loss = self.loss(q_value, td_target, reduction='mean')

        # Update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def get_max_q(self, state):
        """
        Returns the largest Q value at the given state
        """
        return self.network(state).max().item()

    def act(self, state, is_training_ready=True):
        """
        Returns the action to be performed, following an epsilon-greedy policy
        """

        if is_training_ready and random.uniform(0, 1) >= self.epsilon:
            # Action that maximizes Q
            action = self.predict(state)
        else:
            # Random action
            action = np.random.randint(0, self.env.action_space.n)

        return action

    def update_epsilon(self, timestep):
        """
        Update the exploration parameter
        """

        eps = self.initial_exploration_rate - (self.initial_exploration_rate - self.final_exploration_rate) * (
            timestep / self.final_exploration_step
        )
        self.epsilon = max(eps, self.final_exploration_rate)

    @torch.no_grad()
    def predict(self, state, train=True):
        """
        Returns the action with the highest Q-value
        """

        if train:
            out = self.network(state).squeeze()
            action = out[self.current_head, :].argmax().item()
            """
            if self.logging:
                self.logger.add_scalar('Uncertainty', out.max(1,True)[0].std(), self.timestep)
            """
        else:
            out = self.network(state).squeeze()  # Shape: n_heads x n_actions for a single state
            # The heads vote on the best action
            actions, count = torch.unique(out.argmax(1), return_counts=True)
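            # Majority vote: each head's greedy action is tallied and the most
            # frequently chosen action is returned.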
            action = actions[count.argmax().item()].item()

        return action

    def save(self, timestep=None):
        """
        Saves network weights
        """

        if not self.logging:
            raise NotImplementedError('Cannot save without log folder.')

        if timestep is not None:
            filename = 'network_' + str(timestep) + '.pth'
        else:
            filename = 'network.pth'

        save_path = self.logger.log_folder + '/' + filename

        torch.save(self.network.state_dict(), save_path)

    def load(self, path):
        """
        Loads network weights
        """

        self.network.load_state_dict(torch.load(path, map_location='cpu'))
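
# A minimal usage sketch for the agent above (commented out because the class,
# env and network names are illustrative placeholders, not defined in this
# listing; the constructor is assumed to take (env, network, ...)):
#
#   import gym
#   env = gym.make('CartPole-v1')
#   agent = BootstrappedAgent(env, MLP, n_heads=10, logging=False)
#   agent.learn(timesteps=50000, verbose=True)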
Example #9
class DropoutAgent():
    def __init__(self,
                 env,
                 network,
                 dropout=0.1,
                 std_prior=0.01,
                 logging=True,
                 train_freq=10,
                 updates_per_train=100,
                 weight_decay=1e-5,
                 batch_size=32,
                 start_train_step=10,
                 log_folder_details=None,
                 learning_rate=1e-3,
                 verbose=False):

        self.env = env
        self.network = network(env.n_features, std_prior, dropout=dropout)
        self.logging = logging
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        self.log_folder_details = log_folder_details
        self.train_freq = train_freq
        self.start_train_step = start_train_step
        self.updates_per_train = updates_per_train
        self.verbose = verbose
        self.dropout = dropout
        self.weight_decay = weight_decay

        self.n_samples = 0
        self.optimizer = optim.Adam(self.network.parameters(),
                                    lr=learning_rate,
                                    eps=1e-8,
                                    weight_decay=self.weight_decay)

        self.train_parameters = {
            'dropout': dropout,
            'weight_decay': weight_decay,
            'std_prior': std_prior,
            'train_freq': train_freq,
            'updates_per_train': updates_per_train,
            'batch_size': batch_size,
            'start_train_step': start_train_step,
            'learning_rate': learning_rate
        }

    def learn(self, n_steps):

        self.train_parameters['n_steps'] = n_steps

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        cumulative_regret = 0

        for timestep in range(n_steps):

            self.dropout = 1000 * self.dropout / (self.n_samples + 1000)
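            # Anneal the dropout rate: each step it is multiplied by
            # 1000 / (n_samples + 1000), so exploration noise shrinks as more
            # labelled mushrooms are collected (the decay compounds because it
            # rescales the current rate rather than the initial one).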

            x = self.env.sample()

            action = self.act(x.float())

            reward = self.env.hit(action)
            regret = self.env.regret(action)

            cumulative_regret += regret

            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)

            if action == 1:
                self.n_samples += 1
                self.replay_buffer.add(x, reward)

            if self.logging:
                self.logger.add_scalar('Cumulative_Regret', cumulative_regret,
                                       timestep)
                self.logger.add_scalar('Mushrooms_Eaten', self.n_samples,
                                       timestep)

            if timestep % self.train_freq == 0 and self.n_samples > self.start_train_step:

                if self.verbose:
                    print('Timestep: {}, cumulative regret {}'.format(
                        str(timestep), str(cumulative_regret)))

                for update in range(self.updates_per_train):

                    samples = self.replay_buffer.sample(
                        np.min([self.n_samples, self.batch_size]))
                    self.train_step(samples)

        if self.logging:
            self.logger.save()

    def train_step(self, samples):

        states, rewards = samples

        # Compute the TD target
        target = rewards

        # Compute the Q value for the action played

        q_value = self.network(states.float(), self.dropout)

        loss_function = torch.nn.MSELoss()
        loss = loss_function(q_value.squeeze(), target.squeeze())

        # Update the weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, x):

        action = self.predict(x)

        return action

    @torch.no_grad()
    def predict(self, x):

        estimated_value = self.network(x, self.dropout)
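        # Dropout stays active at prediction time, so each forward pass acts as a
        # Monte-Carlo dropout sample of the network output; the mushroom is eaten
        # whenever the sampled value is positive.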

        if estimated_value > 0:
            action = 1
        else:
            action = 0

        return action
    def learn(self, timesteps, verbose=False):

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details,self.train_parameters)

        state = torch.as_tensor(self.env.reset())
        this_episode_time = 0
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            is_training_ready = timestep >= self.replay_start_size

            action = self.act(state.to(self.device).float(), directed_exploration=True)

            state_next, reward, done, _ = self.env.step(action)

            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()
            this_episode_time += 1

            if done:
                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                    self.logger.add_scalar('Non_greedy_fraction', 1-self.n_greedy_actions/this_episode_time, timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar('Q_at_start', self.get_max_q(state.to(self.device).float()), timestep)
                t1 = time.time()
                self.n_greedy_actions = 0
                this_episode_time = 0
            else:
                state = state_next
                

            if is_training_ready:

                if timestep % self.update_frequency == 0:

                    transitions = self.replay_buffer.sample(self.minibatch_size, self.device)

                    loss = self.train_step(transitions)
                    if self.logging and timesteps < 1000000:
                        self.logger.add_scalar('Loss', loss, timestep)

                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())

            if self.logging and (timestep+1) % 250000 == 0:
                self.save(timestep=timestep+1)

        if self.logging:
            self.logger.save()
            self.save()
class EGreedyAgent():
    def __init__(self,
                 env,
                 network,
                 epsilon=0.05,
                 n_quantiles=20,
                 mean_prior=0,
                 std_prior=0.01,
                 logging=True,
                 train_freq=10,
                 updates_per_train=100,
                 batch_size=32,
                 start_train_step=10,
                 log_folder_details=None,
                 learning_rate=1e-3,
                 verbose=False):

        self.env = env
        self.network = network(env.n_features, n_quantiles, mean_prior,
                               std_prior)
        self.logging = logging
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        self.log_folder_details = log_folder_details
        self.epsilon = epsilon
        self.optimizer = optim.Adam(self.network.parameters(),
                                    lr=learning_rate,
                                    eps=1e-8)
        self.n_quantiles = n_quantiles
        self.train_freq = train_freq
        self.start_train_step = start_train_step
        self.updates_per_train = updates_per_train
        self.verbose = verbose

        self.n_samples = 0

        self.train_parameters = {
            'epsilon': epsilon,
            'n_quantiles': n_quantiles,
            'mean_prior': mean_prior,
            'std_prior': std_prior,
            'train_freq': train_freq,
            'updates_per_train': updates_per_train,
            'batch_size': batch_size,
            'start_train_step': start_train_step,
            'learning_rate': learning_rate
        }

    def learn(self, n_steps):

        self.train_parameters['n_steps'] = n_steps

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        cumulative_regret = 0

        for timestep in range(n_steps):

            x = self.env.sample()

            action = self.act(x.float())

            reward = self.env.hit(action)
            regret = self.env.regret(action)

            cumulative_regret += regret

            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)

            if action == 1:
                self.n_samples += 1
                self.replay_buffer.add(x, reward)

            if self.logging:
                self.logger.add_scalar('Cumulative_Regret', cumulative_regret,
                                       timestep)
                self.logger.add_scalar('Mushrooms_Eaten', self.n_samples,
                                       timestep)

            if timestep % self.train_freq == 0 and self.n_samples > self.start_train_step:

                if self.verbose:
                    print('Timestep: {}, cumulative regret {}'.format(
                        str(timestep), str(cumulative_regret)))

                for update in range(self.updates_per_train):

                    samples = self.replay_buffer.sample(
                        np.min([self.n_samples, self.batch_size]))
                    self.train_step(samples)

        if self.logging:
            self.logger.save()

    def train_step(self, samples):

        states, rewards = samples

        target = rewards.repeat(1, self.n_quantiles)
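        # Every one of the n_quantiles predicted return quantiles is regressed
        # toward the single observed reward via the quantile Huber loss below.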

        q_value = self.network(states.float()).view(
            np.min([self.n_samples, self.batch_size]), self.n_quantiles)

        loss = quantile_huber_loss(q_value.squeeze(), target.squeeze())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, x):

        if np.random.uniform() >= self.epsilon:
            action = self.predict(x)
        else:
            action = np.random.randint(0, 2)
        return action

    @torch.no_grad()
    def predict(self, x):

        estimated_value = torch.mean(self.network(x))

        if estimated_value > 0:
            action = 1
        else:
            action = 0

        return action
class EQRDQN():
    def __init__(self,
                 env,
                 network,
                 n_quantiles=50,
                 kappa=1,
                 replay_start_size=50000,
                 replay_buffer_size=1000000,
                 gamma=0.99,
                 update_target_frequency=10000,
                 minibatch_size=32,
                 learning_rate=1e-4,
                 update_frequency=1,
                 prior=0.01,
                 adam_epsilon=1e-8,
                 logging=False,
                 log_folder_details=None,
                 seed=None,
                 notes=None):

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.replay_start_size = replay_start_size
        self.replay_buffer_size = replay_buffer_size
        self.gamma = gamma
        self.update_target_frequency = update_target_frequency
        self.minibatch_size = minibatch_size
        self.learning_rate = learning_rate
        self.update_frequency = update_frequency
        self.adam_epsilon = adam_epsilon
        self.logging = logging
        self.logger = []
        self.timestep = 0
        self.log_folder_details = log_folder_details

        self.env = env
        self.replay_buffer = ReplayBuffer(self.replay_buffer_size)
        self.seed = random.randint(0, int(1e6)) if seed is None else seed
        self.logger = None

        set_global_seed(self.seed, self.env)

        self.n_quantiles = n_quantiles

        self.network = network(self.env.observation_space,
                               self.env.action_space.n * self.n_quantiles,
                               self.env.action_space.n * self.n_quantiles).to(
                                   self.device)
        self.target_network = network(
            self.env.observation_space,
            self.env.action_space.n * self.n_quantiles,
            self.env.action_space.n * self.n_quantiles).to(self.device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(),
                                    lr=self.learning_rate,
                                    eps=self.adam_epsilon)

        self.anchor1 = [
            p.data.clone() for p in list(self.network.output_1.parameters())
        ]
        self.anchor2 = [
            p.data.clone() for p in list(self.network.output_2.parameters())
        ]
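        # Anchored ensembling: snapshot the initial weights of both output heads
        # and, in train_step, penalise their squared distance from these anchors.
        # Regularising toward a random initialisation acts like a Gaussian prior
        # over the head weights and keeps the two heads diverse.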

        self.loss = quantile_huber_loss
        self.kappa = kappa
        self.prior = prior

        self.train_parameters = {
            'Notes': notes,
            'env': env.unwrapped.spec.id,
            'network': str(self.network),
            'replay_start_size': replay_start_size,
            'replay_buffer_size': replay_buffer_size,
            'gamma': gamma,
            'update_target_frequency': update_target_frequency,
            'minibatch_size': minibatch_size,
            'learning_rate': learning_rate,
            'update_frequency': update_frequency,
            'kappa': kappa,
            'n_quantiles': n_quantiles,
            'weight_scale': self.network.weight_scale,
            'prior': prior,
            'adam_epsilon': adam_epsilon,
            'seed': self.seed
        }

        self.n_greedy_actions = 0

    def learn(self, timesteps, verbose=False):

        self.train_parameters['train_steps'] = timesteps
        pprint.pprint(self.train_parameters)

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        # Initialize the state
        state = torch.as_tensor(self.env.reset())
        this_episode_time = 0
        score = 0
        t1 = time.time()

        for timestep in range(timesteps):

            is_training_ready = timestep >= self.replay_start_size

            # Select action
            action = self.act(state.to(self.device).float(),
                              thompson_sampling=True)

            # Perform action in the environment
            state_next, reward, done, _ = self.env.step(action)

            # Store transition in replay buffer
            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)
            done = torch.as_tensor([done], dtype=torch.float)
            state_next = torch.as_tensor(state_next)
            self.replay_buffer.add(state, action, reward, state_next, done)

            score += reward.item()
            this_episode_time += 1

            if done:

                if verbose:
                    print("Timestep : {}, score : {}, Time : {} s".format(
                        timestep, score, round(time.time() - t1, 3)))
                if self.logging:
                    self.logger.add_scalar('Episode_score', score, timestep)
                    self.logger.add_scalar(
                        'Non_greedy_fraction',
                        1 - self.n_greedy_actions / this_episode_time,
                        timestep)
                state = torch.as_tensor(self.env.reset())
                score = 0
                if self.logging:
                    self.logger.add_scalar(
                        'Q_at_start',
                        self.get_max_q(state.to(self.device).float()),
                        timestep)
                t1 = time.time()
                self.n_greedy_actions = 0
                this_episode_time = 0
            else:
                state = state_next

            if is_training_ready:

                # Update main network
                if timestep % self.update_frequency == 0:

                    # Sample a batch of transitions
                    transitions = self.replay_buffer.sample(
                        self.minibatch_size, self.device)

                    # Train on selected batch
                    loss = self.train_step(transitions)
                    if self.logging and timesteps < 1000000:
                        self.logger.add_scalar('Loss', loss, timestep)

                # Update target Q
                if timestep % self.update_target_frequency == 0:
                    self.target_network.load_state_dict(
                        self.network.state_dict())

            if self.logging and (timestep + 1) % 250000 == 0:
                self.save(timestep=timestep + 1)

            self.timestep = timestep

        if self.logging:
            self.logger.save()
            self.save()

    def train_step(self, transitions):
        states, actions, rewards, states_next, dones = transitions

        with torch.no_grad():

            target1, target2 = self.target_network(states_next.float())
            target1 = target1.view(self.minibatch_size,
                                   self.env.action_space.n, self.n_quantiles)
            target2 = target2.view(self.minibatch_size,
                                   self.env.action_space.n, self.n_quantiles)

        best_action_idx = torch.mean((target1 + target2) / 2,
                                     dim=2).max(1, True)[1].unsqueeze(2)
        q_value_target = 0.5*target1.gather(1, best_action_idx.repeat(1,1,self.n_quantiles))\
            + 0.5*target2.gather(1, best_action_idx.repeat(1,1,self.n_quantiles))
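        # Double-estimator target: the greedy next action is picked by the mean of
        # the two heads, and its target quantiles are the average of both heads.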

        # Calculate TD target
        td_target = rewards.unsqueeze(2).repeat(1,1,self.n_quantiles) \
            + (1 - dones.unsqueeze(2).repeat(1,1,self.n_quantiles)) * self.gamma * q_value_target

        out1, out2 = self.network(states.float())
        out1 = out1.view(self.minibatch_size, self.env.action_space.n,
                         self.n_quantiles)
        out2 = out2.view(self.minibatch_size, self.env.action_space.n,
                         self.n_quantiles)

        q_value1 = out1.gather(
            1,
            actions.unsqueeze(2).repeat(1, 1, self.n_quantiles))
        q_value2 = out2.gather(
            1,
            actions.unsqueeze(2).repeat(1, 1, self.n_quantiles))

        loss1 = self.loss(q_value1.squeeze(),
                          td_target.squeeze(),
                          self.device,
                          kappa=self.kappa)
        loss2 = self.loss(q_value2.squeeze(),
                          td_target.squeeze(),
                          self.device,
                          kappa=self.kappa)

        quantile_loss = loss1 + loss2

        diff1 = []
        for i, p in enumerate(self.network.output_1.parameters()):
            diff1.append(torch.sum((p - self.anchor1[i])**2))

        diff2 = []
        for i, p in enumerate(self.network.output_2.parameters()):
            diff2.append(torch.sum((p - self.anchor2[i])**2))

        diff1 = torch.stack(diff1).sum()
        diff2 = torch.stack(diff2).sum()

        anchor_loss = self.prior * (diff1 + diff2)
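        # The anchor loss is the squared L2 distance of each output head from its
        # weights at initialisation, scaled by `prior`; this keeps the heads from
        # collapsing onto each other, preserving the disagreement that predict()
        # uses as an uncertainty estimate.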

        loss = quantile_loss + anchor_loss

        # Update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def act(self, state, thompson_sampling=False):

        action = self.predict(state, thompson_sampling=thompson_sampling)

        return action

    @torch.no_grad()
    def predict(self, state, thompson_sampling=False):
        if not thompson_sampling:
            net1, net2 = self.network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            action = action_means.argmax().item()
        else:
            net1, net2 = self.network(state)
            net1_target, net2_target = self.target_network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)
            net1_target = net1_target.view(self.env.action_space.n,
                                           self.n_quantiles)
            net2_target = net2_target.view(self.env.action_space.n,
                                           self.n_quantiles)
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            action_uncertainties = torch.mean(
                (net1_target - net2_target)**2, dim=1) / 2
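            # Thompson sampling: per-action epistemic uncertainty is estimated from
            # the squared disagreement between the two target heads, then a value
            # for each action is drawn from a diagonal Gaussian and the agent acts
            # greedily with respect to that sample.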
            samples = torch.distributions.multivariate_normal.MultivariateNormal(
                action_means,
                covariance_matrix=torch.diagflat(
                    action_uncertainties)).sample()
            action = samples.argmax().item()
            if action == action_means.argmax().item():
                self.n_greedy_actions += 1

        return action

    @torch.no_grad()
    def get_max_q(self, state):
        net1, net2 = self.network(state)
        net1 = net1.view(self.env.action_space.n, self.n_quantiles)
        net2 = net2.view(self.env.action_space.n, self.n_quantiles)
        action_means = torch.mean((net1 + net2) / 2, dim=1)
        max_q = action_means.max().item()
        return max_q

    def save(self, timestep=None):
        if not self.logging:
            raise NotImplementedError('Cannot save without log folder.')

        if timestep is not None:
            filename = 'network_' + str(timestep) + '.pth'
        else:
            filename = 'network.pth'

        save_path = self.logger.log_folder + '/' + filename

        torch.save(self.network.state_dict(), save_path)

    def load(self, path):
        self.network.load_state_dict(torch.load(path, map_location='cpu'))
        self.target_network.load_state_dict(
            torch.load(path, map_location='cpu'))
class BBBAgent():
    def __init__(self,
                 env,
                 network,
                 mean_prior=0,
                 std_prior=0.01,
                 noise_scale=0.01,
                 logging=True,
                 train_freq=10,
                 updates_per_train=100,
                 batch_size=32,
                 start_train_step=10,
                 log_folder_details=None,
                 learning_rate=1e-3,
                 bayesian_sample_size=20,
                 verbose=False):

        self.env = env
        self.network = BayesianNetwork(env.n_features, torch.device('cpu'),
                                       std_prior, noise_scale)
        self.logging = logging
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        self.log_folder_details = log_folder_details
        self.optimizer = optim.Adam(self.network.parameters(),
                                    lr=learning_rate,
                                    eps=1e-8)
        self.train_freq = train_freq
        self.start_train_step = start_train_step
        self.updates_per_train = updates_per_train
        self.bayesian_sample_size = bayesian_sample_size
        self.verbose = verbose

        self.n_samples = 0
        self.timestep = 0

        self.train_parameters = {
            'mean_prior': mean_prior,
            'std_prior': std_prior,
            'noise_scale': noise_scale,
            'train_freq': train_freq,
            'updates_per_train': updates_per_train,
            'batch_size': batch_size,
            'start_train_step': start_train_step,
            'learning_rate': learning_rate,
            'bayesian_sample_size': bayesian_sample_size
        }

    def learn(self, n_steps):

        self.train_parameters['n_steps'] = n_steps

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        cumulative_regret = 0

        for timestep in range(n_steps):

            x = self.env.sample()

            action, sampled_value = self.act(x.float())

            reward = self.env.hit(action)
            regret = self.env.regret(action)

            cumulative_regret += regret

            action = torch.as_tensor([action], dtype=torch.long)
            reward = torch.as_tensor([reward], dtype=torch.float)

            if action == 1:
                self.n_samples += 1
                self.replay_buffer.add(x, reward)

            if self.logging:
                self.logger.add_scalar('Cumulative_Regret', cumulative_regret,
                                       timestep)
                self.logger.add_scalar('Mushrooms_Eaten', self.n_samples,
                                       timestep)
                if self.env.y_sample == 1:
                    self.logger.add_scalar('Sampled_Value_Good',
                                           sampled_value.item(), self.timestep)
                else:
                    self.logger.add_scalar('Sampled_Value_Bad',
                                           sampled_value.item(), self.timestep)

            if timestep % self.train_freq == 0 and self.n_samples > self.start_train_step:

                if self.verbose:
                    print('Timestep: {}, cumulative regret {}'.format(
                        str(timestep), str(cumulative_regret)))

                for update in range(self.updates_per_train):

                    samples = self.replay_buffer.sample(
                        np.min([self.n_samples, self.batch_size]))
                    self.train_step(samples)

            self.timestep += 1

        if self.logging:
            self.logger.save()

    def train_step(self, samples):

        states, rewards = samples

        loss, _, _, _ = self.network.sample_elbo(
            states.float(), rewards, self.n_samples,
            np.min([self.n_samples, self.batch_size]),
            self.bayesian_sample_size)
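        # Bayes-by-Backprop objective: sample_elbo is assumed to draw
        # bayesian_sample_size weight samples and return a stochastic estimate of
        # the variational free energy (KL to the prior plus the expected negative
        # log-likelihood), which is minimised below.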

        if self.logging:
            self.logger.add_scalar('Loss', loss.detach().item(), self.timestep)

        # Update weights
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, x):

        action, sampled_value = self.predict(x)

        return action, sampled_value

    @torch.no_grad()
    def predict(self, x):

        sampled_value = self.network.forward(x)

        if sampled_value > 0:
            action = 1
        else:
            action = 0

        return action, sampled_value
class ThompsonAgent():
    def __init__(self,
                 env,
                 network,
                 n_quantiles=20,
                 mean_prior=0,
                 std_prior=0.01,
                 noise_scale=0.01,
                 logging=True,
                 train_freq=10,
                 updates_per_train=100,
                 batch_size=32,
                 start_train_step=10,
                 log_folder_details=None,
                 learning_rate=1e-3,
                 verbose=False):

        self.env = env
        self.network1 = network(env.n_features, n_quantiles, mean_prior,
                                std_prior)
        self.network2 = network(env.n_features, n_quantiles, mean_prior,
                                std_prior)
        self.optimizer = optim.Adam(list(self.network1.parameters()) +
                                    list(self.network2.parameters()),
                                    lr=learning_rate,
                                    eps=1e-8)

        self.logging = logging
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        self.log_folder_details = log_folder_details
        self.n_quantiles = n_quantiles
        self.train_freq = train_freq
        self.start_train_step = start_train_step
        self.updates_per_train = updates_per_train
        self.n_samples = 0
        self.noise_scale = noise_scale
        self.std_prior = std_prior
        self.verbose = verbose

        self.prior1 = [
            p.data.clone() for p in list(self.network1.features.parameters())
        ]
        self.prior2 = [
            p.data.clone() for p in list(self.network2.features.parameters())
        ]

        self.train_parameters = {
            'n_quantiles': n_quantiles,
            'mean_prior': mean_prior,
            'std_prior': std_prior,
            'train_freq': train_freq,
            'updates_per_train': updates_per_train,
            'batch_size': batch_size,
            'start_train_step': start_train_step,
            'learning_rate': learning_rate,
            'noise_scale': noise_scale
        }

    def learn(self, n_steps):

        self.train_parameters['n_steps'] = n_steps

        if self.logging:
            self.logger = Logger(self.log_folder_details,
                                 self.train_parameters)

        cumulative_regret = 0

        for timestep in range(n_steps):

            x = self.env.sample()

            action, uncertainty, sampled_value = self.act(x.float())

            reward = self.env.hit(action)
            regret = self.env.regret(action)

            cumulative_regret += regret

            reward = torch.as_tensor([reward], dtype=torch.float)

            if self.logging:
                self.logger.add_scalar('Cumulative_Regret', cumulative_regret,
                                       timestep)
                self.logger.add_scalar('Mushrooms_Eaten', self.n_samples,
                                       timestep)
                if self.env.y_sample == 1:
                    self.logger.add_scalar('Uncertainty_Good', uncertainty,
                                           timestep)
                    self.logger.add_scalar('Sampled_Value_Good', sampled_value,
                                           timestep)
                else:
                    self.logger.add_scalar('Uncertainty_Bad', uncertainty,
                                           timestep)
                    self.logger.add_scalar('Sampled_Value_Bad', sampled_value,
                                           timestep)

            if action == 1:
                self.replay_buffer.add(x, reward)
                self.n_samples += 1

            if timestep % self.train_freq == 0 and self.n_samples > self.start_train_step:

                if self.verbose:
                    print('Timestep: {}, cumulative regret {}'.format(
                        str(timestep), str(cumulative_regret)))

                for update in range(self.updates_per_train):

                    samples = self.replay_buffer.sample(
                        np.min([self.n_samples, self.batch_size]))
                    self.train_step(samples)

        if self.logging:
            self.logger.save()

    def train_step(self, samples):

        states, rewards = samples

        target = rewards.repeat(1, self.n_quantiles)

        q_value1 = self.network1(states.float()).view(
            np.min([self.n_samples, self.batch_size]), self.n_quantiles)
        q_value2 = self.network2(states.float()).view(
            np.min([self.n_samples, self.batch_size]), self.n_quantiles)

        loss1 = quantile_huber_loss(q_value1.squeeze(), target.squeeze())
        loss2 = quantile_huber_loss(q_value2.squeeze(), target.squeeze())

        reg = []
        for i, p in enumerate(self.network1.features.parameters()):
            diff = (p - self.prior1[i])
            reg.append(torch.sum(diff**2))
        loss_anchored1 = torch.sum(torch.stack(reg))

        reg = []
        for i, p in enumerate(self.network2.features.parameters()):
            diff = (p - self.prior2[i])
            reg.append(torch.sum(diff**2))
        loss_anchored2 = torch.sum(torch.stack(reg))

        loss = loss1 + loss2 + self.noise_scale * (
            loss_anchored1 + loss_anchored2) / (self.std_prior**2 *
                                                self.n_samples)
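        # Anchored regularisation: the feature weights are pulled back toward
        # their values at initialisation, weighted by
        # noise_scale / (std_prior**2 * n_samples), so the prior's influence
        # fades as more mushrooms are observed.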

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, x):

        action, uncertainty, sampled_value = self.predict(x)

        return action, uncertainty, sampled_value

    @torch.no_grad()
    def predict(self, x):

        net1 = self.network1(x)
        net2 = self.network2(x)

        action_mean = torch.mean((net1 + net2) / 2)
        action_uncertainty = torch.sqrt(torch.mean((net1 - net2)**2) / 2)
        sampled_value = torch.distributions.Normal(
            action_mean, action_uncertainty).sample()
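        # The two anchored networks form a size-2 ensemble: their mean estimates
        # the expected reward, their disagreement the epistemic uncertainty, and a
        # value is Thompson-sampled from the resulting Gaussian.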

        if sampled_value > 0:
            action = 1
        else:
            action = 0

        return action, action_uncertainty.item(), sampled_value
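
# A minimal usage sketch for ThompsonAgent (commented out: MushroomEnv and
# QuantileNetwork are illustrative placeholders for the env and network classes
# this listing expects, not names defined here):
#
#   env = MushroomEnv()          # exposes n_features, sample(), hit(), regret()
#   agent = ThompsonAgent(env, QuantileNetwork, n_quantiles=20, logging=False)
#   agent.learn(n_steps=10000)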