Example #1
class TestPrioritizedReplayBuffer(unittest.TestCase):
    def setUp(self):
        # The buffer builds its own prioritized memory; keep a reference to it
        # so the tests below can inspect the underlying sum tree.
        self.buffer = PrioritizedReplayBuffer(buffer_size=4,
                                              seed=0,
                                              device="cpu")
        self.memory = self.buffer.memory

    def test_add(self):
        self.buffer.add(state=[1, 2, 3],
                        action=0,
                        reward=0,
                        next_state=[4, 5, 6],
                        done=False)
        priority_one = self.buffer._calculate_priority(1)
        np.testing.assert_array_equal(
            [priority_one, priority_one, 0, priority_one, 0, 0, 0],
            self.memory.tree)

    def test_sample(self):
        self.buffer.add(state=[1, 2, 3],
                        action=0,
                        reward=0,
                        next_state=[4, 5, 6],
                        done=False)
        self.buffer.add(state=[4, 5, 6],
                        action=0,
                        reward=0,
                        next_state=[7, 8, 9],
                        done=False)
        self.buffer.add(state=[7, 8, 9],
                        action=0,
                        reward=0,
                        next_state=[10, 11, 12],
                        done=False)

        sample = self.buffer.sample(2)
        print(sample)
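
The tree layout asserted in test_add follows from a proportional sum tree: with buffer_size=4 there are 7 nodes, the leaves occupy indices 3-6, and each leaf's priority is propagated up to the root. A minimal sketch of such a tree, assuming a flat-array layout like the one the test inspects (illustrative only, not the buffer's actual memory class):

import numpy as np

class SumTreeSketch:
    """Minimal proportional-priority sum tree (hypothetical helper for illustration)."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # internal nodes followed by leaves
        self.write = 0

    def add(self, priority):
        leaf = self.write + self.capacity - 1   # first leaf sits at index capacity - 1
        self.update(leaf, priority)
        self.write = (self.write + 1) % self.capacity

    def update(self, leaf, priority):
        change = priority - self.tree[leaf]
        self.tree[leaf] = priority
        while leaf != 0:                        # propagate the change up to the root
            leaf = (leaf - 1) // 2
            self.tree[leaf] += change

# One add with priority p produces [p, p, 0, p, 0, 0, 0], matching the assertion above.
tree = SumTreeSketch(4)
tree.add(0.5)
print(tree.tree)
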
Example #2
class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, seed, QNetwork):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            QNetwork: a class inheriting from torch.nn.Module that defines the structure of the neural network
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
        # When using dropout, the target network (qnetwork_target) should stay in eval mode
        self.qnetwork_target.eval()
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = PrioritizedReplayBuffer(seed, device, action_size, BUFFER_SIZE, BATCH_SIZE, DEFAULT_PRIORITY, PRIORITY_FACTOR)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.u_step = 0
        self.t_step = 0
        self.up_step = 0
        # Controls the importance sampling weight; as the network converges, b should move toward 1
        self.b = torch.tensor(1., device=device, requires_grad=False)
        self.b_decay = torch.tensor(0.00015, device=device, requires_grad=False)
    
    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy()).astype(np.int_)
        else:
            return random.choice(np.arange(self.action_size)).astype(np.int_)
    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)
        
        self.t_step = (self.t_step + 1) % TRANSFER_EVERY
        self.u_step = (self.u_step + 1) % UPDATE_EVERY
        self.up_step = (self.up_step + 1) % UPDATE_PRIORITY_EVERY
        
        # Learn from experiences
        if len(self.memory) > BATCH_SIZE and self.u_step == 0:
            # sample the experiences from the memory based on their priority
            experiences = self.memory.sample()
            self.learn(experiences)
        # Transfer the knowledge from the local network to the fixed (target) one
        if len(self.memory) > BATCH_SIZE and self.t_step == 0:
            self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
        # Update the priorities in the memory to alter the sampling
        # Ideally, this would happen right before sampling takes place,
        # but for the sake of performance it is cheaper to recalculate the priorities less often
        if len(self.memory) > 1 and self.up_step == 0:
            for experiences in self.memory.get_all_experiences(512):
                with torch.no_grad():
                    self.qnetwork_local.eval()
                    current_estimate, from_env = self.get_target_estimate(experiences)
                    # update the priorities based on newly learned errors
                    self.memory.update(experiences[-1], (from_env - current_estimate).squeeze())
            
    def get_target_estimate(self, experiences):
        states, actions, rewards, next_states, dones, probabilities, selected = experiences
        with torch.no_grad():
            best_actions = self.qnetwork_local(next_states).detach().max(1)[1].unsqueeze(1)
            evaluations = self.qnetwork_target(next_states).gather(1, best_actions)
            from_env = rewards + GAMMA*evaluations*(1 - dones)
        return self.qnetwork_local(states).gather(1, actions), from_env
        
    def learn(self, experiences):
        self.qnetwork_local.train()
        current_estimate,from_env = self.get_target_estimate(experiences)
        probabilities = experiences[-2]
        errors = (from_env - current_estimate)
        # Since experiences are sampled according to their priorities, the batch is biased.
        # We correct for that bias with an importance sampling weight.
        loss = (errors * errors / (len(self.memory) * probabilities) * self.b).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.b = min(self.b + self.b_decay,1)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        θ_target = θ_target + τ*(θ_local - θ_target)

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
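
For reference, the loss in Agent.learn() above scales each squared TD error by b / (N * P(i)), where N is the memory size and P(i) the sampling probability of transition i. A standalone sketch of that computation, with illustrative names rather than the class's API:

import torch

def weighted_td_loss(current_estimate, from_env, probabilities, memory_size, b):
    """Importance-sampling weighted squared TD error, mirroring Agent.learn() above."""
    errors = from_env - current_estimate
    weights = b / (memory_size * probabilities)   # w_i = b / (N * P(i))
    return (errors * errors * weights).mean()

# Three sampled transitions with unequal sampling probabilities.
q = torch.tensor([[0.2], [0.5], [0.1]])
target = torch.tensor([[1.0], [0.4], [0.3]])
p = torch.tensor([[0.5], [0.3], [0.2]])
print(weighted_td_loss(q, target, p, memory_size=100, b=0.4))
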
Example #3
class D3QNetAgent(AgentWithConverter):
    def __init__(self,
                 observation_space,
                 action_space,
                 num_frames=4,
                 batch_size=32,
                 learning_rate=1e-5,
                 learning_rate_decay_steps=10000,
                 learning_rate_decay_rate=0.95,
                 discount_factor=0.95,
                 tau=1e-2,
                 lam=1,
                 per_size=50000,
                 per_alpha=0.6,
                 per_beta=0.4,
                 per_anneal_rate=1.5e6,
                 epsilon=0.99,
                 decay_epsilon=1024 * 32,
                 final_epsilon=0.0001):

        # initializes AgentWithConverter class to handle action conversions
        AgentWithConverter.__init__(self,
                                    action_space,
                                    action_space_converter=IdToAct)

        self.obs_space = observation_space
        self.act_space = action_space
        self.num_frames = num_frames

        self.batch_size = batch_size
        self.lr = learning_rate
        self.lr_decay_steps = learning_rate_decay_steps
        self.lr_decay_rate = learning_rate_decay_rate
        self.gamma = discount_factor
        self.lam = lam
        self.tau = tau
        # epsilon is the degree of exploration
        self.initial_epsilon = epsilon
        # Adaptive epsilon decay constants
        self.decay_epsilon = decay_epsilon
        self.final_epsilon = final_epsilon

        #PER data
        self.buff_size = per_size
        self.alpha = per_alpha
        self.beta = per_beta
        self.anneal = per_anneal_rate

        self.observation_size = self.obs_space.size_obs()
        self.action_size = self.act_space.size()

        self.d3qn = D3QNet(self.action_size, self.observation_size,
                           self.num_frames, self.lr, self.lr_decay_steps,
                           self.lr_decay_rate, self.batch_size, self.gamma,
                           self.tau)

        #State variables
        self.obs = None
        self.done = None
        self.epsilon = self.initial_epsilon

        self.state = []
        self.frames = []
        self.next_frames = []
        self.replay_buffer = PrioritizedReplayBuffer(self.buff_size,
                                                     self.alpha, self.beta,
                                                     self.anneal)

        return

    ## Helper Functions

    # Adds to current frame buffer, enforces a maximum length of num_frames
    def update_curr_frame(self, obs):
        self.frames.append(obs.copy())
        if (len(self.frames) > self.num_frames):
            self.frames.pop(0)
        return

    # Adds next frame to next frame buffer, enforces a maximum length of num_frames
    def update_next_frame(self, next_obs):
        self.next_frames.append(next_obs.copy())
        if (len(self.next_frames) > self.num_frames):
            self.next_frames.pop(0)
        return

    # Adaptive epsilon decay: determines the next epsilon based on the number of
    # steps completed and the current epsilon
    def set_next_epsilon(self, current_step):
        ada_div = self.decay_epsilon / 10.0
        step_off = current_step + ada_div
        ada_eps = self.initial_epsilon * -math.log10(
            (step_off + 1) / (self.decay_epsilon + ada_div))
        ada_eps_up_clip = min(self.initial_epsilon, ada_eps)
        ada_eps_low_clip = max(self.final_epsilon, ada_eps_up_clip)
        self.epsilon = ada_eps_low_clip
        return

    ## Agent Interface

    #Adapted from l2rpn-baselines from RTE-France
    # Vectorizes observations from grid2op environment for neural network uses
    def convert_obs(self, obs):
        li_vect = []
        for el in obs.attr_list_vect:
            v = obs._get_array_from_attr_name(el).astype(np.float32)
            v_fix = np.nan_to_num(v)
            v_norm = np.linalg.norm(v_fix)
            if v_norm > 1e6:
                v_res = (v_fix / v_norm) * 10.0
            else:
                v_res = v_fix
            li_vect.append(v_res)
        return np.concatenate(li_vect)

    # converts encoded action number to action used to interact with grid2op
    # environment
    def convert_act(self, encoded_act):
        return super().convert_act(encoded_act)

    # Required for agent evaluation
    # Returns random action or best action as estimated by Q network based on
    # exploration parameter (epsilon)
    def my_act(self, state, reward, done):
        if (len(self.frames) < self.num_frames): return 0  #do nothing
        random_act = random.randint(0, self.action_size - 1)
        self.update_curr_frame(state)
        qnet_act, _ = self.d3qn.model_action(np.array(self.frames))
        if (np.random.rand(1) < self.epsilon):
            return random_act
        else:
            return qnet_act

    ## Training Loop
    def learn(self,
              env,
              num_epochs,
              num_steps,
              soft_update_freq=250,
              hard_update_freq=1000):

        #pre-training to fill buffer

        print("Starting Pretraining...\n")
        self.done = True
        # Plays random moves and saves the resulting (s, a, r, s', d) pair to
        # replay buffer. Resets environment when done and continues
        while (len(self.replay_buffer) < self.buff_size):
            if (self.done):
                # reset environment and state parameters
                new_env = env.reset()
                self.frames = []
                self.next_frames = []
                self.done = False
                self.obs = new_env
                self.state = self.convert_obs(self.obs)

            self.update_curr_frame(self.state)

            # action is random
            encoded_act = np.random.randint(0, self.action_size)
            act = self.convert_act(encoded_act)
            new_obs, reward, self.done, info = env.step(act)

            gplay_reward = info['rewards']['gameplay']
            adj_reward = self.lam * reward + (1 - self.lam) * gplay_reward

            new_state = self.convert_obs(new_obs)
            self.update_next_frame(new_state)

            # only add to buffer if num_frames states are seen
            if (len(self.frames) == self.num_frames
                    and len(self.next_frames) == self.num_frames):

                agg_state = np.array(self.frames)
                agg_next_state = np.array(self.next_frames)
                #(s,a,r,s',d) pair
                self.replay_buffer.add(agg_state, encoded_act, adj_reward,
                                       agg_next_state, self.done)

            self.obs = new_obs
            self.state = new_state

        epoch = 0  # number of complete runs through environment
        total_steps = 0  # total number of training steps across all epochs

        losses = []  # losses[i] is loss from dqn at total_step i
        avg_losses = []  # avg_losses[i] is avg loss during training in epoch i
        net_reward = []  #net_reward[i] is total reward during epoch i
        alive = []  # alive[i] is number of steps survived for at epoch i

        print("Starting training...\n")

        # Trains until at least num_steps total steps and num_epochs epochs are completed
        while (total_steps < num_steps or epoch < num_epochs):

            total_reward = 0
            curr_steps = 0
            total_loss = []

            # Reset state parameters
            self.frames = []
            self.next_frames = []
            self.done = False
            self.obs = env.reset()
            self.state = self.convert_obs(self.obs)

            # continues until failure
            while (not self.done):

                self.update_curr_frame(self.state)

                # Determine action
                if (len(self.frames) < self.num_frames):
                    enc_act = 0  # do nothing
                elif (np.random.rand(1) < self.epsilon):
                    enc_act = np.random.randint(0, self.action_size)
                else:
                    net_input = np.array(self.frames)
                    enc_act, _ = self.d3qn.model_action(net_input)

                # converts action and steps in environment
                act = self.convert_act(enc_act)
                new_obs, reward, self.done, info = env.step(act)

                gplay_reward = info['rewards']['gameplay']
                adj_reward = self.lam * reward + (1 - self.lam) * gplay_reward

                new_state = self.convert_obs(new_obs)
                # updates next_state frame
                self.update_next_frame(new_state)

                if (len(self.frames) == self.num_frames
                        and len(self.next_frames) == self.num_frames):

                    agg_state = np.array(self.frames)
                    agg_next_state = np.array(self.next_frames)
                    # Adds (s,a,r,s',d) tuple to replay buffer
                    self.replay_buffer.add(agg_state, enc_act, adj_reward,
                                           agg_next_state, self.done)
                # finds the next epsilon
                self.set_next_epsilon(total_steps)

                # samples a batch_size number of experience samples from replay
                # buffer
                (s_batch, a_batch, r_batch, s_next_batch, d_batch, w_batch,
                 ind_batch) = (self.replay_buffer.sample(
                     self.batch_size, total_steps))

                # updates network estimates based on replay
                loss = self.d3qn.train_on_minibatch(s_batch, a_batch, r_batch,
                                                    s_next_batch, d_batch,
                                                    w_batch)

                priorities = self.d3qn.prio
                self.replay_buffer.update_priorities(ind_batch, priorities)

                # periodically hard updates the network
                if (total_steps % hard_update_freq == 0):
                    self.d3qn.hard_update_target_network()

                # periodically soft updates the network
                elif (total_steps % soft_update_freq == 0):
                    self.d3qn.soft_update_target_network()

                # update state variables
                self.obs = new_obs
                self.state = new_state

                # increase steps, updates metrics
                curr_steps += 1
                total_steps += 1
                total_reward += reward
                losses.append(loss)
                total_loss.append(loss)

            # updates metrics throughout epoch
            alive.append(curr_steps)
            net_reward.append(total_reward)
            avg_losses.append(np.average(np.array(total_loss)))

            epoch += 1
            # sanity check to ensure it's working
            if (epoch % 100 == 0):
                print("Completed epoch {}".format(epoch))
                print("Total steps: {}".format(total_steps))

        return (epoch, total_steps, losses, avg_losses, net_reward, alive)
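
The adaptive epsilon schedule in set_next_epsilon above can be inspected in isolation. A minimal sketch of the same arithmetic, using the constructor defaults shown above:

import math

def adaptive_epsilon(step, initial_eps=0.99, final_eps=0.0001, decay_steps=1024 * 32):
    """Log-shaped epsilon decay mirroring D3QNetAgent.set_next_epsilon above."""
    ada_div = decay_steps / 10.0
    ada_eps = initial_eps * -math.log10((step + ada_div + 1) / (decay_steps + ada_div))
    return max(final_eps, min(initial_eps, ada_eps))  # clip into [final_eps, initial_eps]

# Epsilon stays near its initial value early on and falls toward final_eps as steps accumulate.
for step in (0, 1000, 10000, 32768):
    print(step, round(adaptive_epsilon(step), 4))
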
Example #4
class Agent():
    """An agent that interacts with and learns from its environment. As its baseline it
    uses Deep Q-Learning (DQN). The following can optionally be enabled in all possible
    combinations:
        * Double DQN (DDQN)
        * Prioritized Experience Replay
        * Dueling Network Architecture
    """
    def __init__(self,
                 state_size,
                 action_size,
                 buffer_size=int(1e5),
                 batch_size=64,
                 gamma=0.99,
                 tau=0.025,
                 learning_rate=5e-4,
                 update_every=4,
                 enable_double_dqn=False,
                 enable_prioritized_experience_replay=False,
                 enable_dueling_network=False,
                 alpha=0.6):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): the dimension of each state
            action_size (int): the dimension of each action
            
            buffer_size (int): the replay buffer size
            batch_size (int): the minibatch size 
            gamma (float): the reward discount factor
            tau (float): for soft updates of the target parameters
            learning_rate (float): the learning rate
            update_every (int): controls how regularly the network learns
            
            enable_double_dqn (bool): enables Double DQN (DDQN)
            enable_prioritized_experience_replay (bool): enables Prioritized Experience Replay
            enable_dueling_network (bool): enables a Dueling Network architecture
            
            alpha (float): the priority dampening effect in Prioritized Experience Replay
        """
        self.state_size = state_size
        self.action_size = action_size

        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.learning_rate = learning_rate
        self.update_every = update_every

        self.enable_double_dqn = enable_double_dqn
        self.enable_prioritized_experience_replay = enable_prioritized_experience_replay
        self.enable_dueling_network = enable_dueling_network

        self.alpha = alpha

        # Instantiate the local and target networks.
        Network = QStandardNetwork if not self.enable_dueling_network else QDuelingNetwork
        self.qnetwork_local = Network(state_size, action_size).to(device)
        self.qnetwork_target = Network(state_size, action_size).to(device)
        # Starting off with the same random weights in self.qnetwork_local and self.qnetwork_target.
        self._perform_soft_update(self.qnetwork_local,
                                  self.qnetwork_target,
                                  tau=1)

        # Instantiate the experience memory.
        if not self.enable_prioritized_experience_replay:
            self.memory = ReplayBuffer(self.buffer_size)
        else:
            self.memory = PrioritizedReplayBuffer(self.buffer_size,
                                                  alpha=self.alpha)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=self.learning_rate)

        # Start counting the steps.
        self.t_step = 0

        # Clear the performance timers.
        self.reset_timers()

    def act(self, state, epsilon):
        """Returns an action for the given state and epison-greedy value as per the current policy.
        
        Params
        ======
            state (array_like): the current state
            epsilon (float): the epsilon-greedy value
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > epsilon:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def step(self, state, action, reward, next_state, done, beta=None):
        """Updates the policy based on the state, action, reward and next_state.
        Also takes into account that the episode might be done. 
        
        For Prioritized Experience Replay, also receives the beta value that should 
        be used for the de-biasing factor.
        
        Params
        ======
            state (array_like): the current state
            action (int): the action taken
            reward (float): the reward received for taking the action in the state
            next_state (array_like): the resulting state
            done (bool): indicates whether the episode is done or not
            beta (float): For Prioritized Experience Replay, the beta value that
                should be used next for the de-biasing factor
        """

        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample(self.batch_size)
                self._learn(experiences, beta)

    def save(self, path):
        torch.save(self.qnetwork_local.state_dict(), path)

    def reset_timers(self):
        """Resets performance timers.        
        """
        self.time_1 = 0
        self.time_2 = 0
        self.time_3 = 0
        self.time_4 = 0

    def _learn(self, experiences, beta):
        """Updates Q by learning from a batch of experiences.

        Params
        ======
            experiences (tuple): A batch of sampled experiences.
        """

        predictions, targets, nodes = self._calculate_predictions_and_targets(
            experiences)

        if not self.enable_prioritized_experience_replay:
            self._learn_from_experience_replay(predictions, targets)
        else:
            self._learn_from_prioritized_experience_replay(
                predictions, targets, nodes, beta)

    def _calculate_predictions_and_targets(self, experiences):
        """From a batch of sampled experiences, calculates the predictions and targets.
        Also returns the nodes of the samples.
        
        Params
        ======
            experiences (tuple): a batch of sampled experiences
        
        Returns
        =======
            A tuple of predictions, targets and nodes.
        """

        in_states, in_actions, in_rewards, in_next_states, in_dones, nodes = experiences

        states = torch.from_numpy(in_states).float().to(device)
        actions = torch.from_numpy(in_actions).long().to(device)
        rewards = torch.from_numpy(in_rewards).float().to(device)
        next_states = torch.from_numpy(in_next_states).float().to(device)
        dones = torch.from_numpy(in_dones).float().to(device)

        predictions = self.qnetwork_local(states)[
            torch.arange(states.shape[0], dtype=torch.long),
            torch.squeeze(actions)].to(device)

        with torch.no_grad():

            if not self.enable_double_dqn:
                inputs_for_targets = self.qnetwork_target(next_states).to(
                    device)
                targets = (torch.squeeze(rewards) +
                           (1.0 - torch.squeeze(dones)) * self.gamma *
                           inputs_for_targets.max(1)[0]).to(device)
            else:
                temp_1 = self.qnetwork_local(next_states).to(device)
                temp_2 = temp_1.max(1)[1].to(device)
                temp_3 = self.qnetwork_target(next_states)[
                    torch.arange(next_states.shape[0], dtype=torch.long),
                    temp_2].to(device)
                targets = (torch.squeeze(rewards) +
                           (1.0 - torch.squeeze(dones)) * self.gamma *
                           temp_3).to(device)

        return (predictions, targets, nodes)

    def _learn_from_experience_replay(self, predictions, targets):
        """Updates Q by learning from (non-prioritized) Experience Replay.
        
        Params
        ======
            predictions (array_like): batch-size predictions
            targets (array_like): batch-size targets
        """

        assert not self.enable_prioritized_experience_replay

        td_errors = targets - predictions
        torch.Tensor.clamp_(td_errors, min=-1, max=1)

        loss = torch.mean(torch.pow(td_errors, 2))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the target network:
        self._perform_soft_update(self.qnetwork_local,
                                  self.qnetwork_target,
                                  tau=self.tau)

    def _learn_from_prioritized_experience_replay(self, predictions, targets,
                                                  nodes, beta):
        """Updates Q by learning from Prioritized Experience Replay.
        
        Params
        ======
            predictions (array_like): batch-size predictions
            targets (array_like): batch-size targets
            nodes (array_like): the nodes associated with the predictions and targets
            beta (float): The beta value that should be used next for the de-biasing factor        
        """

        assert self.enable_prioritized_experience_replay

        # Calculate the gradient weights:
        time_1_start = time.process_time()
        root = nodes[0].get_root()
        total_weight = root.get_max_of_weights_overall(
        )  # 'alpha' has already been applied.
        sampled_weights = np.array([n.own_weight for n in nodes
                                    ])  # 'alpha' has already been applied.
        scaled_weights = sampled_weights / total_weight  # P
        gradient_weights = np.power(self.buffer_size * scaled_weights, -beta)
        gradient_weights = gradient_weights / np.max(gradient_weights)
        gradient_weights = torch.from_numpy(gradient_weights).float().to(
            device)
        self.time_1 += time.process_time(
        ) - time_1_start  # Measure the performance.

        # Calculate the TD errors and loss; update the local network weights:
        time_2_start = time.process_time()
        td_errors = targets - predictions
        torch.Tensor.clamp_(td_errors, min=-1,
                            max=1)  # Clip the TD errors for greater stability.
        loss = torch.mean(torch.pow(td_errors, 2) * gradient_weights)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # An alternative but less elegant and slower approach involves applying the
        # gradient weights to the gradient components instead of the loss components
        # as was done above:
        # weighted_gradients = {}
        # for k in range(self.batch_size):
        #   loss_single = self.loss_function(predictions[k], targets[k])  # Where for equivalent self.loss_function is MSE.
        #   self.qnetwork_local.zero_grad()
        #   loss_single.backward(retain_graph=True)
        #   with torch.no_grad():
        #       for name, param in self.qnetwork_local.named_parameters():
        #           if name not in weighted_gradients:
        #               weighted_gradients[name] = param.grad * gradient_weights[k]
        #           else:
        #               weighted_gradients[name] += param.grad * gradient_weights[k]
        # with torch.no_grad():
        #    for name, param in self.qnetwork_local.named_parameters():
        #        param.data -= self.learning_rate * weighted_gradients[name]
        self.time_2 += time.process_time(
        ) - time_2_start  # Measure the performance.

        # Update the target network:
        time_3_start = time.process_time()
        self._perform_soft_update(self.qnetwork_local,
                                  self.qnetwork_target,
                                  tau=self.tau)
        self.time_3 += time.process_time(
        ) - time_3_start  # Measure the performance.

        # Update the node weights:
        time_4_start = time.process_time()
        with torch.no_grad():
            for node, td_error in zip(nodes, td_errors.cpu().numpy()):
                weight = self.memory.calculate_weight_from_raw_weight(td_error)
                node.update_weight(weight)
        self.time_4 += time.process_time(
        ) - time_4_start  # Measure the performance.

    def _perform_soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
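
The gradient weights computed in _learn_from_prioritized_experience_replay above follow the usual importance-sampling correction: w_i = (N * P_i)^(-beta), normalized by the batch maximum so the largest weight is 1. A small sketch with illustrative inputs (the real class reads the priorities from its sum-tree nodes):

import numpy as np

def per_gradient_weights(sampled_priorities, total_priority, buffer_size, beta):
    """Importance-sampling weights: w_i = (N * P_i) ** -beta, normalized by the batch max."""
    probs = np.asarray(sampled_priorities) / total_priority   # P_i ('alpha' already applied)
    weights = np.power(buffer_size * probs, -beta)
    return weights / np.max(weights)

# Four sampled leaves whose priorities sum (with the rest of the tree) to total_priority.
print(per_gradient_weights([0.5, 2.0, 3.0, 4.5], total_priority=10.0,
                           buffer_size=int(1e5), beta=0.4))
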
Example #5
class PrioritizedDQNAgent(Agent):
    """Interacts with and learns from the environment."""

    def __init__(self, movie_dict=None, act_set=None, slot_set=None, params=None, seed=1, compute_weights=False):
        self.movie_dict = movie_dict
        self.act_set = act_set
        self.slot_set = slot_set
        self.act_cardinality = len(act_set.keys())
        self.slot_cardinality = len(slot_set.keys())
        self.seed = seed
        self.compute_weights = compute_weights

        self.feasible_actions = dialog_config.feasible_actions
        self.num_actions = len(self.feasible_actions)

        self.epsilon = params['epsilon']
        self.agent_run_mode = params['agent_run_mode']
        self.agent_act_level = params['agent_act_level']
        self.experience_replay_pool = []  # experience replay pool <s_t, a_t, r_t, s_t+1>

        # Replay memory
        self.memory = PrioritizedReplayBuffer(
            self.num_actions, BUFFER_SIZE, BATCH_SIZE, EXPERIENCES_PER_SAMPLING, seed, compute_weights)
        # Initialize time step (for updating every UPDATE_NN_EVERY steps)
        self.t_step_nn = 0
        # Initialize time step (for updating every UPDATE_MEM_PAR_EVERY steps)
        self.t_step_mem_par = 0
        # Initialize time step (for updating every UPDATE_MEM_EVERY steps)
        self.t_step_mem = 0
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        self.experience_replay_pool_size = params.get(
            'experience_replay_pool_size', 1000)
        self.hidden_size = params.get('dqn_hidden_size', 60)
        self.gamma = params.get('gamma', 0.9)
        self.predict_mode = params.get('predict_mode', False)
        self.warm_start = params.get('warm_start', 0)

        self.max_turn = params['max_turn'] + 4
        self.state_dimension = 2 * self.act_cardinality + \
            7 * self.slot_cardinality + 3 + self.max_turn

        self.qnetwork_local = QNetwork(
            self.state_dimension, self.num_actions, seed).to(self.device)
        self.qnetwork_target = QNetwork(
            self.state_dimension, self.num_actions, seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=5e-4)

        self.cur_bellman_err = 0

        # Prediction Mode: load trained DQN model
        if params['trained_model_path'] is not None:
            self.dqn.model = copy.deepcopy(
                self.load_trained_DQN(params['trained_model_path']))
            self.clone_dqn = copy.deepcopy(self.dqn)
            self.predict_mode = True
            self.warm_start = 2


    def initialize_episode(self):
        """ Initialize a new episode. This function is called every time a new episode is run. """

        self.current_slot_id = 0
        self.phase = 0
        self.request_set = ['moviename', 'starttime',
                            'city', 'date', 'theater', 'numberofpeople']

    def state_to_action(self, state):
        """ DQN: Input state, output action """

        self.representation = self.prepare_state_representation(state)
        self.action = self.run_policy(self.representation)
        act_slot_response = copy.deepcopy(self.feasible_actions[self.action])
        return {'act_slot_response': act_slot_response, 'act_slot_value_response': None}


    def prepare_state_representation(self, state):
        """ Create the representation for each state """

        user_action = state['user_action']
        current_slots = state['current_slots']
        kb_results_dict = state['kb_results_dict']
        agent_last = state['agent_action']

        ########################################################################
        #   Create one-hot of acts to represent the current user action
        ########################################################################
        user_act_rep = np.zeros((1, self.act_cardinality))
        user_act_rep[0, self.act_set[user_action['diaact']]] = 1.0

        ########################################################################
        #     Create bag of inform slots representation to represent the current user action
        ########################################################################
        user_inform_slots_rep = np.zeros((1, self.slot_cardinality))
        for slot in user_action['inform_slots'].keys():
            user_inform_slots_rep[0, self.slot_set[slot]] = 1.0

        ########################################################################
        #   Create bag of request slots representation to represent the current user action
        ########################################################################
        user_request_slots_rep = np.zeros((1, self.slot_cardinality))
        for slot in user_action['request_slots'].keys():
            user_request_slots_rep[0, self.slot_set[slot]] = 1.0

        ########################################################################
        #   Create bag of filled-in slots based on the current_slots
        ########################################################################
        current_slots_rep = np.zeros((1, self.slot_cardinality))
        for slot in current_slots['inform_slots']:
            current_slots_rep[0, self.slot_set[slot]] = 1.0

        ########################################################################
        #   Encode last agent act
        ########################################################################
        agent_act_rep = np.zeros((1, self.act_cardinality))
        if agent_last:
            agent_act_rep[0, self.act_set[agent_last['diaact']]] = 1.0

        ########################################################################
        #   Encode last agent inform slots
        ########################################################################
        agent_inform_slots_rep = np.zeros((1, self.slot_cardinality))
        if agent_last:
            for slot in agent_last['inform_slots'].keys():
                agent_inform_slots_rep[0, self.slot_set[slot]] = 1.0

        ########################################################################
        #   Encode last agent request slots
        ########################################################################
        agent_request_slots_rep = np.zeros((1, self.slot_cardinality))
        if agent_last:
            for slot in agent_last['request_slots'].keys():
                agent_request_slots_rep[0, self.slot_set[slot]] = 1.0

        turn_rep = np.zeros((1, 1)) + state['turn'] / 10.

        ########################################################################
        #  One-hot representation of the turn count?
        ########################################################################
        turn_onehot_rep = np.zeros((1, self.max_turn))
        turn_onehot_rep[0, state['turn']] = 1.0

        ########################################################################
        #   Representation of KB results (scaled counts)
        ########################################################################
        kb_count_rep = np.zeros((1, self.slot_cardinality + 1)) + \
            kb_results_dict['matching_all_constraints'] / 100.
        for slot in kb_results_dict:
            if slot in self.slot_set:
                kb_count_rep[0, self.slot_set[slot]
                             ] = kb_results_dict[slot] / 100.

        ########################################################################
        #   Representation of KB results (binary)
        ########################################################################
        kb_binary_rep = np.zeros((1, self.slot_cardinality + 1)) + \
            np.sum(kb_results_dict['matching_all_constraints'] > 0.)
        for slot in kb_results_dict:
            if slot in self.slot_set:
                kb_binary_rep[0, self.slot_set[slot]] = np.sum(
                    kb_results_dict[slot] > 0.)

        self.final_representation = np.hstack([user_act_rep, user_inform_slots_rep, user_request_slots_rep, agent_act_rep,
                                               agent_inform_slots_rep, agent_request_slots_rep, current_slots_rep, turn_rep, turn_onehot_rep, kb_binary_rep, kb_count_rep])
        return self.final_representation



    def run_policy(self, state):
        """ epsilon-greedy policy """

        if random.random() < self.epsilon:
            return random.randint(0, self.num_actions - 1)
        else:
            if self.warm_start == 1:
                if len(self.experience_replay_pool) > self.experience_replay_pool_size:
                    self.warm_start = 2
                return self.rule_policy()
            else:
                state = torch.from_numpy(
                    state).float().unsqueeze(0).to(self.device)
                self.qnetwork_local.eval()
                with torch.no_grad():
                    action_values = self.qnetwork_local(state)
                    self.qnetwork_local.train()
                return np.argmax(action_values.cpu().data.numpy())


    def rule_policy(self):
        """ Rule Policy """

        if self.current_slot_id < len(self.request_set):
            slot = self.request_set[self.current_slot_id]
            self.current_slot_id += 1

            act_slot_response = {}
            act_slot_response['diaact'] = "request"
            act_slot_response['inform_slots'] = {}
            act_slot_response['request_slots'] = {slot: "UNK"}
        elif self.phase == 0:
            act_slot_response = {'diaact': "inform", 'inform_slots': {
                'taskcomplete': "PLACEHOLDER"}, 'request_slots': {}}
            self.phase += 1
        elif self.phase == 1:
            act_slot_response = {'diaact': "thanks",
                                 'inform_slots': {}, 'request_slots': {}}

        return self.action_index(act_slot_response)


    def action_index(self, act_slot_response):
        """ Return the index of action """

        for (i, action) in enumerate(self.feasible_actions):
            if act_slot_response == action:
                return i
        print(act_slot_response)
        raise Exception("action index not found")


    def register_experience_replay_tuple(self, s_t, a_t, reward, s_tplus1, episode_over):
        """ Register feedback from the environment, to be stored as future training data """

        state = self.prepare_state_representation(s_t)
        action = self.action
        reward = reward
        next_state = self.prepare_state_representation(s_tplus1)
        done = episode_over

        if self.predict_mode == False:  # Training Mode
            if self.warm_start == 1:
                self.memory.add(state, action, reward, next_state, done)
        else:  # Prediction Mode
            self.memory.add(state, action, reward, next_state, done)



    def train(self, batch_size=16, num_batches=100, gamma = 0.99):
        """ Train DQN with experience replay """
        for iter_batch in range(num_batches):
            self.cur_bellman_err = 0
            self.memory.update_memory_sampling()
            self.memory.update_parameters()
            for iter in range(int(EXPERIENCES_PER_SAMPLING/batch_size)):
                experiences = self.memory.sample()
                self.learn(experiences, gamma)


    def learn(self, sampling, gamma):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            sampling (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones, weights, indices = sampling

        # Compute and minimize the loss
        q_target = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        expected_values = rewards + gamma * q_target * (1 - dones)
        output = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(output, expected_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

        # ------------------- update priorities ------------------- #
        delta = abs(expected_values - output.detach()).cpu().numpy()
        self.memory.update_priorities(delta, indices)


    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)


    def reinitialize_memory(self):
        self.memory = PrioritizedReplayBuffer(
            self.num_actions, BUFFER_SIZE, BATCH_SIZE, EXPERIENCES_PER_SAMPLING, self.seed, self.compute_weights)

    ################################################################################
    #    Debug Functions
    ################################################################################

    def save_experience_replay_to_file(self, path):
        """ Save the experience replay pool to a file """

        try:
            pickle.dump(self.experience_replay_pool, open(path, "wb"))
            print('saved model in %s' % (path, ))
        except Exception as e:
            print('Error: Writing model fails: %s' % (path, ))
            print(e)
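
The prioritized agents above feed TD errors back into the buffer as new priorities. Under the standard proportional scheme, which these buffers appear to approximate (the exact constants inside them may differ), the sampling probability is P(i) = p_i^alpha / sum_j p_j^alpha with p_i = |delta_i| + eps. A small sketch:

import numpy as np

def proportional_probabilities(td_errors, alpha=0.6, eps=1e-5):
    """Proportional prioritization sketch: larger |TD error| -> higher sampling probability."""
    priorities = np.abs(np.asarray(td_errors, dtype=np.float64)) + eps
    scaled = priorities ** alpha
    return scaled / scaled.sum()

# Transitions with larger TD errors are replayed more often.
print(proportional_probabilities([0.01, 0.1, 1.0, 2.0]))
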
Example #6
class Agent:
    """Interacts with and learns from the environment."""
    def __init__(self, create_network, state_size, action_size, seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        self.qnetwork_local = create_network(state_size, action_size,
                                             seed).to(device)
        self.qnetwork_target = create_network(state_size, action_size,
                                              seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
        # For debugging
        self.learn_step = 0

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                batch = self.memory.sample(BATCH_SIZE)
                self.learn(batch, GAMMA)

        if done:
            self.memory.increase_beta()

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, batch, gamma):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        experiences, indices, w_is = batch
        states, actions, rewards, next_states, dones = self._vstack_experiences(
            experiences)
        w_is = torch.from_numpy(w_is).float().to(device)

        # Compute and minimize the loss using the Double DQN target and importance-sampling weights
        max_actions = self.qnetwork_local(next_states).detach().argmax(
            dim=1, keepdim=True)
        next_Q_values = self.qnetwork_target(next_states).detach().gather(
            dim=1, index=max_actions)
        targets = rewards + gamma * next_Q_values * (1.0 - dones)
        Q_values = self.qnetwork_local.forward(states).gather(dim=1,
                                                              index=actions)

        errors = torch.abs(Q_values - targets).squeeze(1)
        self.memory.update_errors(indices, errors.detach().cpu().numpy())
        loss = ((errors**2) * w_is).mean()
        # print(f"w_IS: {w_is}")
        # print(f"errors: {errors}")
        # print(f"loss: {loss}")
        # self.learn_step += 1
        # if self.learn_step > 10:
        #     raise Exception("stop here")

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def _vstack_experiences(self, experiences):
        states = torch.from_numpy(
            np.vstack([e.state for e in experiences
                       if e is not None])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([e.action for e in experiences
                       if e is not None])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([e.reward for e in experiences
                       if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])) \
            .float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)) \
            .float().to(device)

        return states, actions, rewards, next_states, dones

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
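
The target computation in Agent.learn() above is the Double DQN rule: the local network selects the greedy action for the next state, and the target network evaluates it. A compact sketch of just that step (tensor shapes are [batch, n_actions] for the Q values and [batch, 1] for rewards and dones):

import torch

def double_dqn_targets(q_local_next, q_target_next, rewards, dones, gamma=0.99):
    """Double DQN target: select actions with the local net, evaluate with the target net."""
    best_actions = q_local_next.argmax(dim=1, keepdim=True)   # action selection (local)
    next_q = q_target_next.gather(1, best_actions)            # action evaluation (target)
    return rewards + gamma * next_q * (1.0 - dones)

# A batch of two transitions with three actions; the second transition is terminal.
q_local_next = torch.tensor([[0.1, 0.7, 0.2], [0.4, 0.3, 0.9]])
q_target_next = torch.tensor([[0.0, 0.5, 0.6], [0.2, 0.1, 0.8]])
rewards = torch.tensor([[1.0], [0.0]])
dones = torch.tensor([[0.0], [1.0]])
print(double_dqn_targets(q_local_next, q_target_next, rewards, dones))
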