Example #1
0
def run(args):
    """Train a feature-based option-critic agent on ``args.env``.

    Builds the environment and an ``OptionCriticFeatures`` model from the
    hyper-parameters in ``args`` and runs episodes until
    ``args.max_steps_total`` environment steps have been taken.

    On every step, once the replay buffer holds more than ``args.batch_size``
    transitions:
    - an actor (intra-option policy / termination) update is performed;
    - every ``args.update_frequency`` steps a critic loss on a sampled batch is
      added;
    - every ``args.freeze_interval`` steps the online network is copied into
      the target ("prime") network.

    Args:
        args: parsed command-line namespace holding the environment name,
            model/optimizer hyper-parameters, seed and logging options.
    """
    env = make_env(args.env)
    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    # Seed every RNG *before* the networks are built so that their initial
    # weights are reproducible as well (seeding after construction left the
    # weight initialisation unseeded).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    env.seed(args.seed)

    option_critic = OptionCriticFeatures(
        in_features=env.observation_space.shape[0],
        num_actions=env.action_space.n,
        num_options=args.num_options,
        temperature=args.temp,
        eps_start=args.epsilon_start,
        eps_min=args.epsilon_min,
        eps_decay=args.epsilon_decay,
        eps_test=args.optimal_eps,
        device=device,
        input_size=args.input_size,
        num_classes=args.num_classes,
        num_experts=args.num_experts,
        hidden_size=args.hidden_size,
        batch_size=args.batch_size,
        top_k=args.top_k
    )
    # Create a prime network for more stable Q values
    option_critic_prime = deepcopy(option_critic)

    optim = torch.optim.RMSprop(option_critic.parameters(), lr=args.learning_rate)

    buffer = ReplayBuffer(capacity=args.max_history, seed=args.seed)
    logger = Logger(logdir=args.logdir, run_name=f"{OptionCriticFeatures.__name__}-{args.env}-{args.exp}-{time.ctime()}")

    steps = 0
    print(f"Current goal {env.goal}")
    while steps < args.max_steps_total:

        rewards = 0
        option_lengths = {opt: [] for opt in range(args.num_options)}

        obs = env.reset()
        state = option_critic.get_state(to_tensor(obs))
        greedy_option = option_critic.greedy_option(state)
        current_option = 0

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
        while not done and ep_steps < args.max_steps_ep:
            epsilon = option_critic.epsilon

            # Pick a new option (epsilon-greedy over options) when the current one terminates.
            if option_termination:
                option_lengths[current_option].append(curr_op_len)
                current_option = np.random.choice(args.num_options) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0

            action, logp, entropy = option_critic.get_action(state, current_option)

            next_obs, reward, done, _ = env.step(action)
            buffer.push(obs, current_option, reward, next_obs, done)

            state = option_critic.get_state(to_tensor(next_obs))

            option_termination, greedy_option = option_critic.predict_option_termination(state, current_option)
            rewards += reward

            actor_loss, critic_loss = None, None
            if len(buffer) > args.batch_size:
                actor_loss = actor_loss_fn(obs, current_option, logp, entropy,
                                           reward, done, next_obs, option_critic, option_critic_prime, args)
                loss = actor_loss

                if steps % args.update_frequency == 0:
                    data_batch = buffer.sample(args.batch_size)
                    critic_loss = critic_loss_fn(option_critic, option_critic_prime, data_batch, args)
                    # Out-of-place add: ``loss += critic_loss`` would mutate the
                    # very tensor bound to ``actor_loss`` and corrupt its logged value.
                    loss = loss + critic_loss

                optim.zero_grad()
                loss.backward()
                optim.step()

                if steps % args.freeze_interval == 0:
                    option_critic_prime.load_state_dict(option_critic.state_dict())

            # update global steps etc
            steps += 1
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs

            logger.log_data(steps, actor_loss, critic_loss, entropy.item(), epsilon)

        logger.log_episode(steps, rewards, option_lengths, ep_steps, epsilon)
def run(args):
    """Train option-critic on ``args.env``, optionally running the goal-switch experiment.

    The model class is ``OptionCriticConv`` when ``make_env`` reports an Atari
    environment, otherwise ``OptionCriticFeatures``. Training proceeds until
    ``args.max_steps_total`` environment steps have been taken.

    On every step, once the replay buffer holds more than ``args.batch_size``
    transitions:
    - an actor update is performed;
    - every ``args.update_frequency`` steps a critic loss on a sampled batch is
      added;
    - every ``args.freeze_interval`` steps the online network is copied into
      the target ("prime") network.

    With ``args.switch_goal`` set, a checkpoint is written to
    ``models/option_critic_<seed>_1k`` and the goal is switched when
    ``logger.n_eps == 1000``. Once ``logger.n_eps > 2000`` a final checkpoint
    ``models/option_critic_<seed>_2k`` is written and training stops.

    Args:
        args: parsed command-line namespace holding the environment name,
            model/optimizer hyper-parameters, seed and logging options.
    """
    import os

    env, is_atari = make_env(args.env)
    model_cls = OptionCriticConv if is_atari else OptionCriticFeatures
    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    # Seed every RNG before the networks are built so that the initial weights
    # are reproducible too.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    env.seed(args.seed)

    option_critic = model_cls(in_features=env.observation_space.shape[0],
                              num_actions=env.action_space.n,
                              num_options=args.num_options,
                              temperature=args.temp,
                              eps_start=args.epsilon_start,
                              eps_min=args.epsilon_min,
                              eps_decay=args.epsilon_decay,
                              eps_test=args.optimal_eps,
                              device=device)
    # Create a prime network for more stable Q values
    option_critic_prime = deepcopy(option_critic)

    optim = torch.optim.RMSprop(option_critic.parameters(),
                                lr=args.learning_rate)

    buffer = ReplayBuffer(capacity=args.max_history, seed=args.seed)
    logger = Logger(
        logdir=args.logdir,
        # Name the run after the model class actually in use (it previously
        # always said OptionCriticFeatures, even for Atari runs).
        run_name=f"{model_cls.__name__}-{args.env}-{args.exp}-{time.ctime()}"
    )

    steps = 0
    if args.switch_goal:
        print(f"Current goal {env.goal}")
    while steps < args.max_steps_total:

        rewards = 0
        option_lengths = {opt: [] for opt in range(args.num_options)}

        obs = env.reset()
        state = option_critic.get_state(to_tensor(obs))
        greedy_option = option_critic.greedy_option(state)
        current_option = 0

        # Goal switching experiment: run for 1k episodes in fourrooms, switch goals and run for another
        # 2k episodes. In option-critic, if the options have some meaning, only the policy-over-options
        # should be fine-tuned (this is what we would hope).
        if args.switch_goal and logger.n_eps == 1000:
            os.makedirs('models', exist_ok=True)
            torch.save(
                {
                    'model_params': option_critic.state_dict(),
                    'goal_state': env.goal
                }, f'models/option_critic_{args.seed}_1k')
            env.switch_goal()
            print(f"New goal {env.goal}")

        if args.switch_goal and logger.n_eps > 2000:
            os.makedirs('models', exist_ok=True)
            torch.save(
                {
                    'model_params': option_critic.state_dict(),
                    'goal_state': env.goal
                }, f'models/option_critic_{args.seed}_2k')
            break

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
        while not done and ep_steps < args.max_steps_ep:
            epsilon = option_critic.epsilon

            # Pick a new option (epsilon-greedy over options) when the current one terminates.
            if option_termination:
                option_lengths[current_option].append(curr_op_len)
                current_option = np.random.choice(
                    args.num_options
                ) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0

            action, logp, entropy = option_critic.get_action(
                state, current_option)

            next_obs, reward, done, _ = env.step(action)
            buffer.push(obs, current_option, reward, next_obs, done)

            state = option_critic.get_state(to_tensor(next_obs))

            option_termination, greedy_option = option_critic.predict_option_termination(
                state, current_option)
            rewards += reward

            actor_loss, critic_loss = None, None
            if len(buffer) > args.batch_size:
                actor_loss = actor_loss_fn(obs, current_option, logp, entropy,
                                           reward, done, next_obs, option_critic,
                                           option_critic_prime, args)
                loss = actor_loss

                if steps % args.update_frequency == 0:
                    data_batch = buffer.sample(args.batch_size)
                    critic_loss = critic_loss_fn(option_critic,
                                                 option_critic_prime,
                                                 data_batch, args)
                    # Out-of-place add: ``loss += critic_loss`` would mutate the
                    # very tensor bound to ``actor_loss`` and corrupt its logged value.
                    loss = loss + critic_loss

                optim.zero_grad()
                loss.backward()
                optim.step()

                if steps % args.freeze_interval == 0:
                    option_critic_prime.load_state_dict(
                        option_critic.state_dict())

            # update global steps etc
            steps += 1
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs

            logger.log_data(steps, actor_loss, critic_loss, entropy.item(),
                            epsilon)

        logger.log_episode(steps, rewards, option_lengths, ep_steps, epsilon)
Example #3
0
class DQNAgent:
    """DQN agent with optional Double, Dueling, prioritized-replay and noisy-net variants."""

    def __init__(self,
                 state_size,
                 action_size,
                 buffer_size=int(1e5),
                 batch_size=64,
                 gamma=.99,
                 tau=1e-3,
                 lr=5e-4,
                 update_every=4,
                 use_double=False,
                 use_dueling=False,
                 use_priority=False,
                 use_noise=False,
                 seed=42):
        """Deep Q-Network Agent

        Args:
            state_size (int): input size passed to the Q-networks
            action_size (int): number of discrete actions
            buffer_size (int): Experience Replay buffer size
            batch_size (int): number of experiences sampled per learning step
            gamma (float):
                discount factor, used to balance immediate and future reward
            tau (float): interpolation parameter for soft update target network
            lr (float): neural network learning rate
            update_every (int): how often (in agent steps) we learn
            use_double (bool): whether or not to use double networks improvement
            use_dueling (bool): whether or not to use dueling network improvement
            use_priority (bool): whether or not to use priority experience replay
            use_noise (bool): whether or not to use noisy nets for exploration
            seed (int): seed for the `random`, NumPy and PyTorch RNGs
        """

        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.update_every = update_every
        self.use_double = use_double
        self.use_dueling = use_dueling
        self.use_priority = use_priority
        self.use_noise = use_noise

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Q-Networks: same architecture for the local (online) and target
        # network; local is built first so seeded initialisation is unchanged.
        network_cls = DuelingQNetwork if use_dueling else QNetwork
        self.qn_local = network_cls(state_size, action_size,
                                    noisy=use_noise).to(device)
        self.qn_target = network_cls(state_size, action_size,
                                     noisy=use_noise).to(device)

        # Initialize target model parameters with local model parameters
        self.soft_update(1.0)

        # TODO: make the optimizer configurable
        self.optimizer = optim.Adam(self.qn_local.parameters(), lr=lr)

        if use_priority:
            self.memory = PrioritizedReplayBuffer(buffer_size, batch_size)
        else:
            self.memory = ReplayBuffer(buffer_size, batch_size)

        # Initialize time step (for updating every update_every steps)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition and learn every ``update_every`` calls.

        Learning only happens once the memory holds more than
        ``batch_size`` experiences.

        Args:
            state (np.ndarray): observation before the action
            action (int)
            reward (float)
            next_state (np.ndarray): observation after the action
            done (bool)
        """

        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every update_every time steps.
        self.t_step = (self.t_step + 1) % self.update_every

        if self.t_step == 0:

            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:

                if self.use_priority:
                    experiences, indices, weights = self.memory.sample()
                    self.learn(experiences, indices, weights)
                else:
                    experiences = self.memory.sample()
                    self.learn(experiences)

    def act(self, state, eps=0.):
        """Given a state what's the next action to take

        Args:
            state (np.ndarray): observation, converted with ``torch.from_numpy``
            eps (float):
                probability of taking a random action instead of the greedy
                one (ignored when noisy nets are used)

        Returns:
            int: action to take
        """

        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        # NOTE(review): with noisy nets, eval() may switch off the noise
        # depending on the NoisyLinear implementation -- confirm that
        # exploration still happens.
        self.qn_local.eval()
        with torch.no_grad():
            action_values = self.qn_local(state)
        self.qn_local.train()

        if self.use_noise:
            return np.argmax(action_values.cpu().numpy())
        else:
            # Epsilon-greedy action selection
            if random.random() > eps:
                return np.argmax(action_values.cpu().numpy())
            else:
                return random.choice(np.arange(self.action_size))

    def learn(self, experiences, indices=None, weights=None):
        """Use a batch of experiences to calculate TD errors and update Q networks

        Args:
            experiences: iterable of records with state, action, reward,
                next_state and done attributes (``None`` entries are skipped)
            indices (Numpy array):
                array of indices to update priorities (only used with PER)
            weights (Numpy array):
                importance-sampling weights (only used with PER)
        """

        valid = [e for e in experiences if e is not None]
        states = torch.from_numpy(
            np.vstack([e.state for e in valid])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([e.action for e in valid])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([e.reward for e in valid])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([e.next_state for e in valid])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([e.done for e in valid]).astype(np.uint8)).float().to(device)

        if self.use_priority:
            weights = torch.from_numpy(np.vstack(weights)).float().to(device)

        # Bootstrap targets are constants, so evaluate next states without
        # building an autograd graph.
        with torch.no_grad():
            if self.use_double:  # uses Double Deep Q-Network
                # Select the best action with the local model, evaluate it
                # with the target model.
                best_action = self.qn_local(next_states).argmax(-1, keepdim=True)
                max_q = self.qn_target(next_states).gather(-1, best_action)
            else:  # normal Deep Q-Network
                # Max predicted Q value (for next states) from target model
                max_q = self.qn_target(next_states).max(-1, keepdim=True)[0]

            # Compute Q targets for current states
            q_targets = rewards + (self.gamma * max_q * (1 - dones))

        # Get expected Q values from local model
        q_expected = self.qn_local(states).gather(-1, actions)

        # Compute loss...
        if self.use_priority:
            # NOTE(review): priorities below are updated with the weighted
            # *squared* TD error rather than |TD error|; confirm this matches
            # PrioritizedReplayBuffer.update.
            weighted_td_errors = weights * (q_targets - q_expected)**2
            loss = weighted_td_errors.mean()
        else:
            loss = F.mse_loss(q_expected, q_targets)

        # ...and minimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.use_priority:
            self.memory.update(indices,
                               weighted_td_errors.detach().cpu().numpy())

        # Update target network
        self.soft_update(self.tau)

    def soft_update(self, tau):
        """Soft update target network parameters from the local network:
            θ_target = τ*θ_local + (1 - τ)*θ_target

        Args:
            tau (float): interpolation parameter; 1.0 copies the local
                weights into the target network
        """

        with torch.no_grad():
            for target_param, local_param in zip(self.qn_target.parameters(),
                                                 self.qn_local.parameters()):
                target_param.copy_(tau * local_param +
                                   (1.0 - tau) * target_param)

    def make_filename(self, filename):
        """Prefix ``filename`` with the enabled variants (prioritized_/double_/dueling_/noisy_)."""
        filename = 'noisy_' + filename if self.use_noise else filename
        filename = 'dueling_' + filename if self.use_dueling else filename
        filename = 'double_' + filename if self.use_double else filename
        filename = 'prioritized_' + filename if self.use_priority else filename

        return filename

    def save_weights(self, filename='local_weights.pth', path='weights'):
        """Save the local network to ``path/<variant prefixes><filename>``."""
        filename = self.make_filename(filename)
        torch.save(self.qn_local.state_dict(), '{}/{}'.format(path, filename))

    def load_weights(self, filename='local_weights.pth', path='weights'):
        """Load local-network weights and sync the target network to them.

        ``save_weights`` writes to the variant-prefixed name, so that file is
        preferred; the literal ``filename`` is used as a fallback so callers
        that already pass a prefixed name keep working.
        """
        import os

        prefixed = '{}/{}'.format(path, self.make_filename(filename))
        plain = '{}/{}'.format(path, filename)
        weights_path = prefixed if os.path.exists(prefixed) else plain
        # Load onto CPU first; load_state_dict copies into whatever device the
        # network lives on, so GPU-saved weights also load on CPU-only hosts.
        self.qn_local.load_state_dict(
            torch.load(weights_path, map_location='cpu'))
        # Keep the target network consistent with the restored weights.
        self.soft_update(1.0)

    def summary(self):
        """Print the enabled variants and the local network architecture."""
        print('DQNAgent:')
        print('========')
        print('')
        print('Using Double:', self.use_double)
        print('Using Dueling:', self.use_dueling)
        print('Using Priority:', self.use_priority)
        print('Using Noise:', self.use_noise)
        print('')
        print(self.qn_local)
Example #4
0
class DDPGAgent:
    """DDPG agent: deterministic actor, Q critic, soft-updated target copies and OU noise."""

    def __init__(self,
                 state_size,
                 action_size,
                 seed,
                 n_hidden_units=128,
                 n_layers=3):
        """Build actor/critic networks, their targets, optimizers, noise and replay.

        Args:
            state_size (int): passed to Actor and Critic as the state size
            action_size (int): passed to Actor, Critic and OUNoise as the action size
            seed (int): seed for Python's `random` module; also forwarded to
                the networks, the noise process and the replay buffer
            n_hidden_units (int): currently unused (not forwarded to the networks)
            n_layers (int): currently unused (not forwarded to the networks)
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so store the seed value itself.
        random.seed(seed)
        self.seed = seed

        # actor
        self.actor = Actor(state_size, action_size, seed).to(device)
        self.actor_target = Actor(state_size, action_size, seed).to(device)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=1e-4)

        # critic
        self.critic = Critic(state_size, action_size, seed).to(device)
        self.critic_target = Critic(state_size, action_size, seed).to(device)
        self.critic_opt = optim.Adam(self.critic.parameters(),
                                     lr=3e-4,
                                     weight_decay=0.0001)

        # exploration noise added to actions in act()
        self.noise = OUNoise(action_size, seed)

        # experience replay
        self.replay = ReplayBuffer(seed)

    def act(self, state, noise=True):
        '''
            Return the actor's action for ``state`` (a numpy array), optionally
            perturbed by exploration noise, clipped to [-1, 1].
        '''
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.actor.eval()
        with torch.no_grad():
            action = self.actor(state).cpu().data.numpy()
        self.actor.train()
        if noise:
            action += self.noise.sample()
        return np.clip(action, -1, 1)

    def reset(self):
        '''Reset the exploration noise process.'''
        self.noise.reset()

    def step(self, state, action, reward, next_state, done):
        '''
            Save experiences into replay and sample if replay contains enough experiences
        '''
        self.replay.add(state, action, reward, next_state, done)

        if self.replay.len() > self.replay.batch_size:
            experiences = self.replay.sample()
            self.learn(experiences, GAMMA)

    def learn(self, experiences, gamma):
        '''
            Update policy and value parameters using given batch of experience tuples.
            Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
            where:
                actor_target(state) -> action
                critic_target(state, action) -> Q-value
            Params: experiences (Tuple[torch.Tensor]): tuple of (s, a, r, n_s, done) tensors
                    gamma (float): discount factor (previously ignored in favour
                                   of the global GAMMA)
        '''
        states, actions, rewards, next_states, dones = experiences

        # update critic:
        #   targets are constants -- computing them without autograd avoids
        #   accumulating unused gradients on the target networks
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            next_Q_targets = self.critic_target(next_states, next_actions)
            Q_targets = rewards + (gamma * next_Q_targets * (1 - dones))
        #   compute critic loss
        Q_expected = self.critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        #   minimize loss (the actor step below builds its own graph, so
        #   retaining this one is unnecessary)
        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # update actor:
        #   compute actor loss
        action_predictions = self.actor(states)
        actor_loss = -self.critic(states, action_predictions).mean()
        #   minimize actor loss
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # update target networks
        self.soft_update(self.critic, self.critic_target, TAU)
        self.soft_update(self.actor, self.actor_target, TAU)

    def soft_update(self, local, target, tau):
        '''
            Soft update model parameters.
            θ_target = τ*θ_local + (1 - τ)*θ_target
            Params: local: PyTorch model (weights will be copied from)
                    target: PyTorch model (weights will be copied to)
                    tau (float): interpolation parameter
        '''
        with torch.no_grad():
            for target_param, local_param in zip(target.parameters(),
                                                 local.parameters()):
                target_param.copy_(tau * local_param +
                                   (1.0 - tau) * target_param)
Example #5
0
class Runner(object):
    """Train and evaluate an option-critic agent (``OptionCriticFeatures``)."""

    def __init__(self, args):
        """Seed RNGs, configure the environment and build model, optimizer and replay.

        Args:
            args: parsed command-line namespace (seed, env/level, rendering
                flags, epsilon schedule, learning rate, replay size, ...).
        """
        self.args = args
        # Set the random seed during training and deterministic cudnn
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Anomaly detection is a process-wide switch (and slows backward
        # considerably); enabling it once here is enough.
        torch.autograd.set_detect_anomaly(True)
        device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

        # Check if we are running evaluation alone
        self.evaling_checkpoint = args.eval_checkpoint != ""
        # Create Logger
        self.logger = Logger(logdir=args.logdir,
                             run_name=f"{OptionCriticFeatures.__name__}-{args.env}-{args.exp}-{time.ctime()}")

        # Load Env
        # NOTE(review): self.env is never assigned in this class -- the creation
        # call below is commented out, so self.env must be provided elsewhere
        # or the next statement raises AttributeError.
        # self.env = gym.make('beaverkitchen-v0')
        self.env.set_args(self.env.Args(
            human_play=False,
            level=args.level,
            # NOTE(review): this reads inverted (headless while render_train is
            # set); confirm against the environment's Args semantics.
            headless=args.render_train and not args.render_eval,
            render_slowly=False))

        self.env.seed(self.args.seed)

        self.num_t_steps = 0
        self.num_e_steps = 0
        self.epsilon = self.args.epsilon_start
        self.state = self.env.reset()
        self.num_states = len(self.state)
        self.num_actions = len(self.env.sample_action())

        # Create Model
        self.option_critic = OptionCriticFeatures(
            env_name=self.args.env,
            in_features=self.num_states,
            num_actions=self.num_actions,
            num_options=self.args.num_options,
            temperature=self.args.temp,
            eps_start=self.args.epsilon_start,
            eps_min=self.args.epsilon_min,
            eps_decay=self.args.epsilon_decay,
            eps_test=self.args.optimal_eps,
            device=device
        )

        # Create a prime network for more stable Q values
        self.option_critic_prime = deepcopy(self.option_critic)

        self.optim = torch.optim.RMSprop(self.option_critic.parameters(), lr=self.args.learning_rate)
        # Gradient clipping with args.clip_value is applied in _learn_step after
        # backward(); clipping here, before any gradient exists, had no effect.

        self.replay_buffer = ReplayBuffer(capacity=self.args.max_history, seed=self.args.seed)

    def run(self):
        """Alternate training episodes with periodic evaluation and checkpointing.

        When ``args.eval_checkpoint`` is set, only evaluation episodes are run.
        """
        self.n_episodes = self.args.n_episodes if not self.args.eval_checkpoint else self.args.n_eval_episodes

        for ep in range(self.n_episodes):
            if not self.args.eval_checkpoint:
                train_return = self.run_episode(ep, train=True)

                timestamp = str(datetime.now())
                print("[{}] Episode: {}, Train Return: {}".format(timestamp, ep, train_return))
                self.logger.log_return("reward/train_return", train_return, ep)

            if ep % self.args.eval_freq == 0:
                # Output and plot eval episode results
                eval_returns = [self.run_episode(ep, train=False)
                                for _ in range(self.args.n_eval_samples)]
                eval_return = np.array(eval_returns).mean()

                timestamp = str(datetime.now())
                print("[{}] Episode: {}, Eval Return: {}".format(timestamp, ep, eval_return))
                self.logger.log_return("reward/eval_return", eval_return, ep)

            if ep % self.args.checkpoint_freq == 0:
                os.makedirs(self.args.modeldir, exist_ok=True)
                model_dir = os.path.join(self.args.modeldir, "episode_{:05d}".format(ep))
                self.option_critic.save(model_dir)

        print("Done running...")

    def _learn_step(self, obs, current_option, logp, entropy, reward, done, next_obs):
        """Run one optimisation step and return ``(actor_loss, critic_loss)``.

        Both are None until the replay buffer holds more than
        ``args.batch_size`` transitions; ``critic_loss`` is None on steps
        without a critic update.
        """
        actor_loss, critic_loss = None, None
        if len(self.replay_buffer) <= self.args.batch_size:
            return actor_loss, critic_loss

        actor_loss = actor_loss_fn(obs, current_option, logp, entropy,
                                   reward, done, next_obs, self.option_critic, self.option_critic_prime,
                                   self.args)
        loss = actor_loss

        # Schedules are keyed on the global training-step counter; keying them
        # on the episode index updated the critic / synced the target network
        # on every step of selected episodes.
        if self.num_t_steps % self.args.update_frequency == 0:
            data_batch = self.replay_buffer.sample(self.args.batch_size)
            critic_loss = critic_loss_fn(self.option_critic, self.option_critic_prime, data_batch,
                                         self.args)
            # Out-of-place add: ``loss += ...`` would mutate the tensor bound to
            # actor_loss, so the logged actor loss would include the critic loss.
            loss = loss + critic_loss

        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.option_critic.parameters(), self.args.clip_value)
        self.optim.step()

        if self.num_t_steps % self.args.freeze_interval == 0:
            self.option_critic_prime.load_state_dict(self.option_critic.state_dict())

        return actor_loss, critic_loss

    def run_episode(self, ep, train=False):
        """Play one episode, learning from it when ``train`` is True.

        Args:
            ep (int): index of the current episode.
            train (bool): store transitions, optimise and decay epsilon when
                True; act greedily over options (epsilon 0) when False.

        Returns:
            The undiscounted sum of rewards collected during the episode.
        """
        option_lengths = {opt: [] for opt in range(self.args.num_options)}
        obs = self.env.reset()
        state = self.option_critic.get_state(to_tensor(obs))

        greedy_option = self.option_critic.greedy_option(state)
        current_option = 0

        done = False
        ep_steps = 0
        option_termination = True
        curr_op_len = 0
        rewards = 0
        cum_reward = 0

        while not done and ep_steps < self.args.max_steps_ep:
            epsilon = self.epsilon if train else 0.0

            if train:
                self.num_t_steps += 1
            else:
                self.num_e_steps += 1

            if option_termination:
                option_lengths[current_option].append(curr_op_len)
                current_option = np.random.choice(
                    self.args.num_options) if np.random.rand() < epsilon else greedy_option
                curr_op_len = 0

            action_idx, logp, entropy = self.option_critic.get_action(state, current_option)

            # The environment takes a one-hot action vector.
            action = np.zeros(self.num_actions)
            action[int(action_idx)] = 1.0
            next_obs, reward, done, info = self.env.step(action)
            rewards += reward
            # Discounted return: the first reward is weighted by gamma ** 0.
            cum_reward += (self.args.gamma ** ep_steps) * reward

            if train:
                self.replay_buffer.push(obs, current_option, reward, next_obs, done)

            # Advance the state and re-evaluate option termination in BOTH
            # modes; previously this only ran while training, so evaluation
            # acted on the initial state for the whole episode.
            state = self.option_critic.get_state(to_tensor(next_obs))
            option_termination, greedy_option = self.option_critic.predict_option_termination(state, current_option)

            # Render domain
            if (self.args.render_train and train) or (self.args.render_eval and not train):
                self.env.render()

            if train:
                actor_loss, critic_loss = self._learn_step(
                    obs, current_option, logp, entropy, reward, done, next_obs)
                self.logger.log_return("train/cum_reward", cum_reward, self.num_t_steps)
                self.logger.log_data(self.num_t_steps, actor_loss, critic_loss, entropy.item(), self.epsilon)
                # Exploration decays only with training experience.
                self.epsilon = max(self.args.epsilon_min, self.epsilon * self.args.epsilon_decay)

            # update global steps etc
            ep_steps += 1
            curr_op_len += 1
            obs = next_obs

            self.logger.log_return("error/volume_error_{}".format("train" if train else "eval"),
                                   float(info['volume_error']), self.num_t_steps if train else self.num_e_steps)

        self.logger.log_episode(self.num_t_steps, rewards, option_lengths, ep_steps, self.epsilon)

        return rewards
class Agent:
    """Interacts with and learns from the environment (Double-DQN target over a DuelingDQN)."""

    def __init__(self, config):
        """Initialize an Agent object.

        Args:
            config (dict): nested configuration. This class reads
                config["general"]["seed" | "action_size" | "checkpoint_dir" | "env_name"],
                config["agent"]["learning_rate" | "update_rate" | "gamma" | "tau"]
                and config["train"]["batch_size"].
        """
        # random.seed() returns None, so store the seed value itself.
        random.seed(config["general"]["seed"])
        self.seed = config["general"]["seed"]
        self.config = config

        # Q-Network (online) and its target copy
        self.q = DuelingDQN(config).to(DEVICE)
        self.q_target = DuelingDQN(config).to(DEVICE)

        self.optimizer = optim.RMSprop(self.q.parameters(),
                                       lr=config["agent"]["learning_rate"])
        self.criterion = F.mse_loss

        self.memory = ReplayBuffer(config)
        # Initialize time step (for updating every `update_rate` steps)
        self.t_step = 0

    def save_experiences(self, state, action, reward, next_state, done):
        """Clip the reward to [-1, 1] and save the experience in replay memory."""
        reward = np.clip(reward, -1.0, 1.0)
        self.memory.add(state, action, reward, next_state, done)

    def _current_step_is_a_learning_step(self):
        """Advance the step counter and report whether this is an update step."""
        self.t_step = (self.t_step + 1) % self.config["agent"]["update_rate"]
        return self.t_step == 0

    def _enough_samples_in_memory(self):
        """Check if minimum amount of samples are in memory"""
        return len(self.memory) > self.config["train"]["batch_size"]

    def epsilon_greedy_action_selection(self, action_values, eps):
        """Epsilon-greedy action selection over ``action_values``."""
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(
                np.arange(self.config["general"]["action_size"]))

    def act(self, state, eps=0.0):
        """Returns actions for given state (a numpy array) as per current policy"""
        state = torch.from_numpy(state).float().unsqueeze(0).to(DEVICE)
        self.q.eval()
        with torch.no_grad():
            action_values = self.q(state)
        self.q.train()

        return self.epsilon_greedy_action_selection(action_values, eps)

    def _calc_loss(self, states, actions, rewards, next_states, dones):
        """Double-DQN MSE loss for one batch.

        The online network selects the greedy next action and the target
        network evaluates it. ``actions`` is used as a gather index along
        dim 1, so it presumably has shape (B, 1) and dtype long -- confirm
        against ReplayBuffer.sample.
        """
        q_eval = self.q(states).gather(1, actions)
        # Targets are constants: no graph through the online next-state pass
        # or the target network (grads previously leaked into q_target).
        with torch.no_grad():
            _, q_argmax = self.q(next_states).max(1)
            q_next = self.q_target(next_states).gather(1, q_argmax.unsqueeze(1))
            q_target = rewards + (self.config["agent"]["gamma"] * q_next *
                                  (1 - dones))
        return self.criterion(q_eval, q_target)

    def _update_weights(self, loss):
        """Backpropagate ``loss`` and step the optimizer with value-clipped gradients."""
        self.optimizer.zero_grad()
        loss.backward()
        # Clip the freshly computed gradients; clipping before
        # zero_grad()/backward() had no effect on the update.
        torch.nn.utils.clip_grad_value_(self.q.parameters(), 1.0)
        self.optimizer.step()

    def learn(self):
        """Update network using one sample of experience from memory"""
        # The step counter is advanced on every call (it is evaluated first).
        if self._current_step_is_a_learning_step(
        ) and self._enough_samples_in_memory():
            states, actions, rewards, next_states, dones = self.memory.sample(
                self.config["train"]["batch_size"])
            loss = self._calc_loss(states, actions, rewards, next_states,
                                   dones)
            self._update_weights(loss)
            self._soft_update(self.q, self.q_target)

    def _soft_update(self, local_model, target_model):
        """Soft update target network parameters: θ_target = τ*θ_local + (1 - τ)*θ_target"""
        tau = self.config["agent"]["tau"]
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(),
                                                 local_model.parameters()):
                target_param.copy_(tau * local_param +
                                   (1.0 - tau) * target_param)

    def save(self):
        """Save the online network weights to a timestamped checkpoint file."""
        helper.mkdir(
            os.path.join(".", *self.config["general"]["checkpoint_dir"],
                         self.config["general"]["env_name"]))
        current_date_time = helper.get_current_date_time()
        # Make the timestamp filesystem-safe.
        current_date_time = current_date_time.replace(" ", "__").replace(
            "/", "_").replace(":", "_")

        torch.save(
            self.q.state_dict(),
            os.path.join(".", *self.config["general"]["checkpoint_dir"],
                         self.config["general"]["env_name"],
                         "ckpt_" + current_date_time))

    def load(self):
        """Load the newest checkpoint (by ``os.path.getctime``) into both networks.

        Raises:
            ValueError: if the checkpoint directory contains no files.
        """
        list_of_files = glob.glob(
            os.path.join(".", *self.config["general"]["checkpoint_dir"],
                         self.config["general"]["env_name"], "*"))
        latest_file = max(list_of_files, key=os.path.getctime)
        # Read the file once; load_state_dict copies values, so both networks
        # can be initialised from the same dict.
        state_dict = torch.load(latest_file, map_location=DEVICE)
        self.q.load_state_dict(state_dict)
        self.q_target.load_state_dict(state_dict)