コード例 #1
0
ファイル: agent.py プロジェクト: wenxingliu/p1_navigation
class Agent():
    """DQN agent with a local/target Q-network pair and a replay buffer.

    Every call to ``step`` stores one transition. Every ``update_every``
    calls, if the buffer holds more than one batch, a batch is sampled and
    used for one gradient step on the local network followed by a soft
    update of the target network.
    """

    def __init__(self, state_size, action_size, fc_units, seed, lr,
                 buffer_size, batch_size, update_every):
        """Create the networks, optimizer, LR scheduler and replay buffer.

        Args:
            state_size: forwarded to ``QNetwork``.
            action_size: number of discrete actions; also forwarded to
                ``QNetwork`` and ``ReplayBuffer``.
            fc_units: forwarded to ``QNetwork``.
            seed: seed for Python's ``random`` module; also forwarded to
                ``QNetwork`` and ``ReplayBuffer``.
            lr: learning rate for the Adam optimizer.
            buffer_size: forwarded to ``ReplayBuffer``.
            batch_size: forwarded to ``ReplayBuffer``.
            update_every: learn once every this many ``step`` calls.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG first and keep the
        # seed value itself on the instance.
        random.seed(seed)
        self.seed = seed
        self.update_every = update_every

        self.qnetwork_local = QNetwork(state_size, action_size, seed,
                                       fc_units).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed,
                                        fc_units).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=lr)
        # NOTE(review): nothing in this class calls self.scheduler.step();
        # the caller has to step it for the LR decay to take effect.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=100,
                                                   gamma=0.5)
        self.memory = ReplayBuffer(action_size, buffer_size, batch_size, seed)
        # Counts steps modulo update_every.
        self.t_step = 0

    def step(self, state, action, reward, next_state, done, gamma, tau):
        """Store one transition and learn when the update period elapses.

        Args:
            state, action, reward, next_state, done: the transition, passed
                unchanged to ``self.memory.add``.
            gamma: discount factor forwarded to ``learn``.
            tau: soft-update coefficient forwarded to ``learn``.
        """
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.update_every

        if self.t_step != 0:
            return
        # Only learn once the buffer holds more than one batch.
        if len(self.memory) > self.memory.batch_size:
            experiences = self.memory.sample()
            self.learn(experiences, gamma, tau)

    def act(self, state, eps):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: NumPy array; a leading batch dimension is added before it
                is fed to the local network.
            eps: exploration threshold; the greedy action is taken when a
                uniform random draw is strictly greater than ``eps``.

        Returns:
            The index of the chosen action.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Evaluate in eval mode without autograd, then restore train mode.
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() > eps:
            action = np.argmax(action_values.cpu().data.numpy())
        else:
            action = random.choice(np.arange(self.action_size))
        return action

    def learn(self, experiences, gamma, tau):
        """Run one DQN update on a sampled batch.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states,
                dones)``; assumed to be tensors as produced by
                ``ReplayBuffer.sample`` -- TODO confirm shapes there.
            gamma: discount factor.
            tau: soft-update coefficient for the target network.
        """
        states, actions, rewards, next_states, dones = experiences
        # Max target-network value of the next state, kept out of the graph.
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Terminal transitions (dones == 1) contribute the reward only.
        Q_targets = rewards + gamma * Q_targets_next * (1 - dones)
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.soft_update(self.qnetwork_local, self.qnetwork_target, tau)

    def soft_update(self, local_model, target_model, tau):
        """Blend local parameters into the target model in place.

        Each target parameter becomes
        ``tau * local + (1 - tau) * target``.
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
コード例 #2
0
class Agent():
    """DQN agent that can use either a vector-state or a CNN Q-network.

    The network implementation is picked at construction time from
    ``q_network`` (``pixels is False``) or ``q_network_cnn`` (otherwise).
    """

    def __init__(self, state_size, action_size, seed, training, pixels, lr=LR):
        """Create the networks, optimizer and replay buffer.

        Args:
            state_size: forwarded to ``QNetwork``.
            action_size: number of discrete actions.
            seed: seed for Python's ``random``; also forwarded to the
                networks and the replay buffer.
            training: forwarded to ``QNetwork``.
            pixels: ``False`` selects the plain network; anything else
                selects the CNN network and builds ``self.loader``.
            lr: learning rate for the Adam optimizer.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed first and store the value.
        random.seed(seed)
        self.seed = seed
        # Step counter, taken modulo UPDATE_RATE in step().
        self.t_step = 0
        self.pixels = pixels
        if pixels is False:
            from q_network import QNetwork
        else:
            from q_network_cnn import QNetwork
            print('loaded cnn network')
            self.loader = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        self.QN_local = QNetwork(state_size, action_size, seed,
                                 training).to(device)
        self.QN_target = QNetwork(state_size, action_size, seed,
                                  training).to(device)
        self.optimizer = optim.Adam(self.QN_local.parameters(), lr=lr)
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed,
                                   device)  # TODO

    def act(self, state, eps):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: NumPy array when ``self.pixels`` is falsy (it is converted
                to a batched float tensor here). When ``self.pixels`` is
                truthy the value is passed to the network unchanged, so the
                caller must already supply what the network accepts.
            eps: exploration threshold; the greedy action is taken when a
                uniform random draw is strictly greater than ``eps``.

        Returns:
            int: index of the chosen action.
        """
        if not self.pixels:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.QN_local.eval()

        # Must be a `with` block: `if torch.no_grad():` merely builds a
        # truthy object and leaves autograd enabled.
        with torch.no_grad():
            action_values = self.QN_local(state)
        self.QN_local.train()
        if random.random() > eps:
            return int(np.argmax(action_values.cpu().data.numpy()))
        else:
            return int(random.choice(np.arange(self.action_size)))

    def step(self, state, action, reward, next_state, done, stack_size):
        """Store one transition and learn every ``UPDATE_RATE`` steps.

        Learning only happens once the buffer holds more than
        ``BATCH_SIZE`` entries. ``stack_size`` is forwarded to ``learn``.
        """
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % UPDATE_RATE
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE:
            samples = self.memory.sample()
            self.learn(samples, GAMMA, stack_size)

    def learn(self, experiences, gamma, stack_size):
        """Run one DQN update followed by a soft target update.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states,
                dones)`` as returned by ``self.memory.sample()``.
            gamma: discount factor.
            stack_size: accepted for interface compatibility; not used by
                the current implementation.
        """
        states, actions, rewards, next_states, dones = experiences

        if self.pixels:
            # Wrapped as in the original pixel path; no reshaping is done.
            next_states = Variable(next_states)
            states = Variable(states)

        _target = self.QN_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Terminal transitions (dones == 1) contribute the reward only.
        action_values_target = rewards + gamma * _target * (1 - dones)
        action_values_expected = self.QN_local(states).gather(1, actions)

        loss = F.mse_loss(action_values_expected, action_values_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft update: target <- TAU * local + (1 - TAU) * target.
        for target_param, local_param in zip(self.QN_target.parameters(),
                                             self.QN_local.parameters()):
            target_param.data.copy_(TAU * local_param.data +
                                    (1.0 - TAU) * target_param.data)
コード例 #3
0
ファイル: ddqn.py プロジェクト: adreena/DQN_DDQN
class Agent:
    """DQN / Double-DQN agent with hard target-network updates.

    Rewards are clipped to their sign before being stored. Learning starts
    only once the replay memory length equals ``buffer_size``.
    """

    def __init__(self, env, state_space, action_space, device, learning_rate,
                 buffer_size, batch_size, gamma, in_channels, train_freq=4,
                 target_update_freq=1e4, is_ddqn=False):
        """Create the networks, optimizer, loss and replay memory.

        Args:
            env: stored on the instance; not used by the methods shown here.
            state_space: forwarded to ``QNetwork``.
            action_space: object exposing ``.n`` and ``.sample()``.
            device: torch device for the networks and sampled batches.
            learning_rate: RMSprop learning rate.
            buffer_size: replay capacity; learning waits until it is full.
            batch_size: forwarded to ``ReplayBuffer``.
            gamma: discount factor.
            in_channels: forwarded to ``QNetwork``.
            train_freq: learn once every this many stored steps.
            target_update_freq: period used in the hard-update condition.
            is_ddqn: use the Double-DQN target when truthy.
        """
        self.env = env
        self.state_space = state_space
        self.action_space = action_space

        self.QNetwork_local = QNetwork(in_channels, self.state_space,
                                       self.action_space.n,
                                       device=device).to(device)
        self.QNetwork_local.init_weights()
        self.QNetwork_target = QNetwork(in_channels, self.state_space,
                                        self.action_space.n,
                                        device=device).to(device)
        # Start with identical weights in both networks.
        self.QNetwork_target.load_state_dict(self.QNetwork_local.state_dict())

        self.optimizer = torch.optim.RMSprop(self.QNetwork_local.parameters(),
                                             lr=learning_rate, alpha=0.95,
                                             eps=0.01, centered=True)
        self.criterion = torch.nn.MSELoss()
        self.memory = ReplayBuffer(capacity=int(buffer_size),
                                   batch_size=batch_size)
        self.step_count = 0.
        self.batch_size = batch_size
        self.gamma = gamma

        self.device = device
        self.buffer_size = buffer_size
        self.num_train_updates = 0
        self.train_freq = train_freq
        self.target_update_freq = target_update_freq

        self.is_ddqn = is_ddqn
        # Defined up front so it can be read before the first step() call.
        self.network_is_updated = False

        print('Agent initialized with memory size:{}'.format(buffer_size))

    def act(self, state, epsilon):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: input passed unchanged to the local network; its first
                output row is used.
            epsilon: exploration threshold; the greedy action is taken when
                a uniform random draw is strictly greater than ``epsilon``.

        Returns:
            The greedy action index (int) or ``action_space.sample()``.
        """
        if random.random() > epsilon:
            # Evaluate in eval mode without autograd, then restore train
            # mode. `if torch.no_grad():` would NOT disable autograd.
            self.QNetwork_local.eval()
            with torch.no_grad():
                actions = self.QNetwork_local(state)
            self.QNetwork_local.train()

            actions_np = actions.data.cpu().numpy()[0]
            return int(np.argmax(actions_np))
        return self.action_space.sample()

    def step(self, state, action, reward, next_state, done, add_memory=True):
        """Store a transition and train when the schedule allows.

        Nothing happens when ``add_memory`` is falsy. Otherwise the reward
        is clipped with ``np.sign``, the transition is stored with priority
        1, and every ``train_freq`` steps -- once the memory length equals
        ``buffer_size`` -- one batch is sampled and learned from.
        ``network_is_updated`` reports whether this call trained.
        """
        # TODO calculate priority here?
        if not add_memory:
            return
        priority = 1.
        reward_clip = np.sign(reward)
        self.memory.add(state=state, action=action, next_state=next_state,
                        reward=reward_clip, done=done, priority=priority)
        self.step_count = (self.step_count + 1) % self.train_freq

        self.network_is_updated = False
        if self.step_count == 0 and len(self.memory) == self.buffer_size:
            samples = self.memory.random_sample(self.device)
            self.learn(samples)
            self.num_train_updates += 1
            self.network_is_updated = True

    def learn(self, samples):
        """Run one (Double-)DQN update and maybe refresh the target network.

        Args:
            samples: tuple ``(states, actions, rewards, next_states, dones)``
                as returned by ``self.memory.random_sample``.
        """
        states, actions, rewards, next_states, dones = samples

        # Bootstrap values are constants for the loss: computing them under
        # no_grad keeps gradients from accumulating on the target network.
        with torch.no_grad():
            if self.is_ddqn:
                # DDQN: the local network selects the action, the target
                # network evaluates it.
                next_actions = torch.argmax(
                    self.QNetwork_local(next_states), dim=1).unsqueeze(1)
                q_target_next = self.QNetwork_target(next_states).gather(
                    1, next_actions)
            else:
                # DQN: max over the target network's action values.
                q_target_next = self.QNetwork_target(next_states).max(
                    1)[0].unsqueeze(1)

        # Q-values of the actions that were actually taken.
        q_local_current = self.QNetwork_local(states).gather(1, actions)

        self.optimizer.zero_grad()

        # Terminal transitions (dones == 1) contribute the reward only.
        TD_target = rewards + (self.gamma * q_target_next * (1 - dones))
        TD_error = self.criterion(q_local_current, TD_target)
        TD_error.backward()
        torch.nn.utils.clip_grad_norm_(self.QNetwork_local.parameters(), 5.)
        self.optimizer.step()

        # NOTE(review): num_train_updates is incremented by step() after
        # learn() returns, and is divided by train_freq here; confirm this
        # is the intended hard-update cadence.
        if (self.num_train_updates / self.train_freq) % self.target_update_freq == 0:
            self.QNetwork_target.load_state_dict(
                self.QNetwork_local.state_dict())
コード例 #4
0
class DeepQ_agent:
    """Double-DQN agent built on two ``QNetwork`` wrappers and a replay memory."""

    def __init__(self, env, hidden_units = None, network_LR=0.01, batch_size=1024, update_every=5, gamma=0.95):
        """
        Set up the local/target networks and the replay memory.

        :param env: game environment; ``STATE_SPACE`` and ``ACTION_SPACE``
            are read from it.
        :param hidden_units: forwarded to ``QNetwork`` as ``hidden_units``.
        :param network_LR: learning rate forwarded to both networks.
        :type network_LR: float
        :param batch_size: minibatch size used for sampling and training.
        :type batch_size: int
        :param update_every: counter value at which the target network is
            synchronised (see ``learn``).
        :type update_every: int
        :param gamma: discount factor.
        :type gamma: float
        """
        self.env = env
        self.BATCH_SIZE = batch_size
        self.GAMMA = gamma
        self.NETWORK_LR = network_LR
        self.MEMORY_CAPACITY = int(1e5)
        self.ACTION_SIZE = env.ACTION_SPACE
        self.HIDDEN_UNITS = hidden_units
        self.UPDATE_EVERY = update_every

        # Local network first, then the target network, with the same
        # construction arguments.
        self.qnetwork_local = QNetwork(
            input_shape=self.env.STATE_SPACE,
            hidden_units=self.HIDDEN_UNITS,
            output_size=self.ACTION_SIZE,
            learning_rate=self.NETWORK_LR,
        )
        self.qnetwork_target = QNetwork(
            input_shape=self.env.STATE_SPACE,
            hidden_units=self.HIDDEN_UNITS,
            output_size=self.ACTION_SIZE,
            learning_rate=self.NETWORK_LR,
        )

        self.memory = ReplayMemory(self.MEMORY_CAPACITY, self.BATCH_SIZE)

        # Number of learning calls since the last target synchronisation.
        self.t = 0

    def learn(self):
        """
        Train the local network on one sampled minibatch.

        Does nothing until the memory holds more than ``BATCH_SIZE`` entries.
        The target weights are copied when the internal counter equals
        ``UPDATE_EVERY``, i.e. on every ``UPDATE_EVERY + 1``-th learning call.
        """
        if len(self.memory) <= self.BATCH_SIZE:
            return

        batch = self.memory.sample(self.env.STATE_SPACE)
        states, actions, rewards, next_states, dones = batch

        # Current estimates from the local network; overwritten below for
        # the actions that were actually taken.
        q_current = self.qnetwork_local.predict(states, self.BATCH_SIZE)
        # Next-state values from the target network (action evaluation).
        q_next_target = self.qnetwork_target.predict(next_states, self.BATCH_SIZE)
        # Next-state values from the local network (action selection).
        q_next_local = self.qnetwork_local.predict(next_states, self.BATCH_SIZE)

        greedy_next = np.argmax(q_next_local, axis=1)

        for row in range(self.BATCH_SIZE):
            new_value = rewards[row]
            if not dones[row]:
                bootstrap = q_next_target[row][greedy_next[row]]
                new_value = rewards[row] + self.GAMMA*bootstrap
            q_current[row][actions[row]] = new_value

        self.qnetwork_local.train(states, q_current, batch_size=self.BATCH_SIZE)

        if self.t != self.UPDATE_EVERY:
            self.t += 1
            return
        self.update_target_weights()
        self.t = 0

    def act(self, state, epsilon=0.0):
        """
        Pick an action with an epsilon-greedy policy.

        :param state: current state as a NumPy array; a leading batch axis
            is added before prediction.
        :param epsilon: the greedy action is taken when a uniform random
            draw is strictly greater than this value.
        :type epsilon: float
        :return: index of the chosen action.
        """
        batched = np.expand_dims(state, axis=0)
        q_values = self.qnetwork_local.predict(batched)
        if random() > epsilon:
            # Exploitation.
            return np.argmax(q_values)
        # Exploration: uniform over all action indices.
        return randint(0, self.ACTION_SIZE-1)

    def add_experience(self, state, action, reward, next_state, done):
        """Store one transition in the replay memory."""
        self.memory.add(state, action, reward, next_state, done)

    def update_target_weights(self):
        """Copy the local model's weights into the target model."""
        local_weights = self.qnetwork_local.model.get_weights()
        self.qnetwork_target.model.set_weights(local_weights)
コード例 #5
0
class Agent:
    """Double-DQN agent with importance-weighted (prioritised) replay.

    A control network is trained; a target network is soft-updated after
    every learning step. Each stored transition carries a priority derived
    from its Bellman error at insertion time.
    """

    def __init__(self,
                 device,
                 state_size,
                 action_size,
                 buffer_size=10,
                 batch_size=10,
                 learning_rate=0.1,
                 discount_rate=0.99,
                 eps_decay=0.9,
                 tau=0.1,
                 steps_per_update=4):
        """Create the networks, optimizer and replay buffer.

        Args:
            device: torch device used for networks and tensors.
            state_size: forwarded to ``QNetwork`` and ``ReplayBuffer``.
            action_size: number of discrete actions.
            buffer_size: forwarded to ``ReplayBuffer``.
            batch_size: number of transitions sampled per learning step.
            learning_rate: Adam learning rate.
            discount_rate: discount factor in the Bellman target.
            eps_decay: factor applied to ``eps`` at the end of each episode.
            tau: soft-update coefficient for the target network.
            steps_per_update: learn once every this many ``step`` calls.
        """
        self.device = device
        self.state_size = state_size
        self.action_size = action_size

        self.q_network_control = QNetwork(state_size, action_size).to(device)
        self.q_network_target = QNetwork(state_size, action_size).to(device)
        self.optimizer = torch.optim.Adam(self.q_network_control.parameters(),
                                          lr=learning_rate)

        self.batch_size = batch_size
        self.replay_buffer = ReplayBuffer(device, state_size, action_size,
                                          buffer_size)

        self.discount_rate = discount_rate

        # Exploration rate; starts fully random and decays per episode.
        self.eps = 1.0
        self.eps_decay = eps_decay

        self.tau = tau

        self.step_count = 0
        self.steps_per_update = steps_per_update

    def policy(self, state):
        """Return an epsilon-greedy action for a NumPy ``state``."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        return self.epsilon_greedy_policy(self.eps, state)

    def epsilon_greedy_policy(self, eps, state):
        """Choose an action for an already-batched state tensor.

        The greedy action is taken when a uniform random draw is strictly
        greater than ``eps``; otherwise a uniformly random action index.
        """
        self.q_network_control.eval()
        with torch.no_grad():
            action_values = self.q_network_control(state)
        self.q_network_control.train()

        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(self.action_size))

    def step(self, state, action, reward, next_state, done):
        """Store a transition with its priority and learn on schedule.

        Learning is attempted when ``step_count`` is a multiple of
        ``steps_per_update`` (checked before the counter is incremented).
        """
        p = self.calculate_p(state, action, reward, next_state, done)
        self.replay_buffer.add(state, action, reward, next_state, done, p)
        if self.step_count % self.steps_per_update == 0:
            self.learn()
        self.step_count += 1

    def learn(self):
        """Run one importance-weighted update, then soft-update the target.

        Returns immediately while the buffer holds fewer than ``batch_size``
        transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones, p = \
            self.replay_buffer.sample(self.batch_size)

        error = self.bellman_eqn_error(states, actions, rewards, next_states,
                                       dones)
        # Weight each squared error by 1 / (buffer_size * priority).
        importance_scaling = (self.replay_buffer.buffer_size * p)**-1
        loss = (importance_scaling * (error**2)).sum() / self.batch_size
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_target()

    def bellman_eqn_error(self, states, actions, rewards, next_states, dones):
        """Return the Double-DQN Bellman error for a batch.

        The control network selects the best next action and the target
        network evaluates it. The bootstrap target is computed without
        autograd, so gradients only flow through the control network's
        estimate for ``states``. The control network is left in train mode.
        """
        self.q_network_control.eval()
        with torch.no_grad():
            a_max = self.q_network_control(next_states).argmax(1).unsqueeze(1)
            # Kept inside no_grad so backward() never touches the target
            # network's parameters.
            target_action_values = self.q_network_target(next_states).gather(
                1, a_max)
        target_rewards = rewards + self.discount_rate * (1 - dones) \
                         * target_action_values

        self.q_network_control.train()
        current_rewards = self.q_network_control(states).gather(1, actions)
        return current_rewards - target_rewards

    def calculate_p(self, state, action, reward, next_state, done):
        """Return the priority of a single transition.

        The priority is ``|Bellman error| + 1e-3``. It is computed without
        autograd so the value stored in the replay buffer does not keep a
        computation graph alive.
        """
        next_state = torch.from_numpy(next_state[np.newaxis, :]).float().to(
            self.device)
        state = torch.from_numpy(state[np.newaxis, :]).float().to(self.device)
        action = torch.from_numpy(np.array([[action]])).long().to(self.device)
        reward = torch.from_numpy(np.array([reward])).float().to(self.device)
        done = torch.from_numpy(np.array([[done]], dtype=np.uint8)).float().to(
            self.device)

        with torch.no_grad():
            error = self.bellman_eqn_error(state, action, reward, next_state,
                                           done)
        return abs(error) + 1e-3

    def update_target(self):
        """Soft-update: target <- tau * control + (1 - tau) * target."""
        for target_param, control_param in zip(
                self.q_network_target.parameters(),
                self.q_network_control.parameters()):
            target_param.data.copy_(self.tau * control_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def end_of_episode(self):
        """Decay ``eps`` and reset the per-episode step counter."""
        self.eps *= self.eps_decay
        self.step_count = 0

    def save(self, path):
        """Save the control network's state dict to ``path``."""
        torch.save(self.q_network_control.state_dict(), path)

    def restore(self, path):
        """Load the control network's state dict from ``path``.

        NOTE(review): the target network is not touched here.
        """
        self.q_network_control.load_state_dict(torch.load(path))
コード例 #6
0
ファイル: agent.py プロジェクト: SIakovlev/Navigation
class Agent():
    """Configurable DQN / Double-DQN / Dueling agent driven by a params dict.

    Supports optional prioritised replay: when the replay buffer returns
    sampling probabilities and indices, importance weights are applied to
    the loss and the priorities are written back to the buffer.
    """

    def __init__(self, params):
        """Build networks, replay buffer and optimiser from ``params``.

        Keys read from ``params``: ``action_size``, ``state_size``,
        ``buf_params``, ``nn_params``, ``learning_mode``, ``eps_initial``,
        ``gamma``, ``learning_rate``, ``update_period``, ``a``, ``b``,
        ``e``, ``tau``.

        NOTE: ``params['nn_params']`` is modified in place -- the input
        size of ``'l1'`` and the output size of ``'l5'`` are overwritten.
        """
        action_size = params['action_size']
        state_size = params['state_size']
        buf_params = params['buf_params']
        nn_params = params['nn_params']
        nn_params['l1'][0] = state_size
        nn_params['l5'][1] = action_size

        self.__learning_mode = params['learning_mode']

        if self.__learning_mode['DuelingDDQN']:
            self.__qnetwork_local = DuelingQNetwork(nn_params).to(device)
            self.__qnetwork_target = DuelingQNetwork(nn_params).to(device)
        else:
            self.__qnetwork_local = QNetwork(nn_params).to(device)
            self.__qnetwork_target = QNetwork(nn_params).to(device)

        self.__action_size = action_size
        self.__state_size = state_size
        self.__memory = ReplayBuffer(buf_params)
        # Step counter, taken modulo update_period in step().
        self.__t = 0

        self.eps = params['eps_initial']
        self.gamma = params['gamma']
        self.learning_rate = params['learning_rate']
        self.update_period = params['update_period']
        # a: priority exponent, b: importance-weight exponent,
        # e: additive priority offset (as used in __update).
        self.a = params['a']
        self.b = params['b']
        self.e = params['e']
        self.tau = params['tau']

        self.__optimiser = optim.Adam(self.__qnetwork_local.parameters(),
                                      lr=self.learning_rate)

        # Last weighted loss value, exposed for logging.
        self.agent_loss = 0.0

    # Set methods
    def set_learning_rate(self, lr):
        """Set the learning rate on the agent and on every optimiser group."""
        self.learning_rate = lr
        for param_group in self.__optimiser.param_groups:
            param_group['lr'] = lr

    # Get methods
    def get_qlocal(self):
        """Return the local Q-network."""
        return self.__qnetwork_local

    # Other methods
    def step(self, state, action, reward, next_state, done):
        """Store a transition and learn every ``update_period`` steps.

        Learning only happens when the replay buffer reports it is ready.
        """
        self.__memory.add(state, action, reward, next_state, done)

        self.__t = (self.__t + 1) % self.update_period
        if not self.__t and self.__memory.is_ready():
            experiences = self.__memory.sample()
            self.__update(experiences)

    def choose_action(self, state, mode='train'):
        """Choose an action for a NumPy ``state``.

        Args:
            state: NumPy array, converted to a batched float tensor.
            mode: ``'train'`` for epsilon-greedy using ``self.eps``;
                ``'test'`` for the purely greedy action.

        Returns:
            The chosen action index, or ``None`` (after printing a message)
            when ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        if mode == 'train':
            if random.random() > self.eps:
                return self.__greedy_action(state)
            return np.random.choice(np.arange(self.__action_size))
        if mode == 'test':
            return self.__greedy_action(state)
        print("Invalid mode value")
        return None

    def __greedy_action(self, state):
        """Return the argmax action of the local network for ``state``."""
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Evaluate in eval mode without autograd, then restore train mode.
        self.__qnetwork_local.eval()
        with torch.no_grad():
            actions = self.__qnetwork_local(state)
        self.__qnetwork_local.train()
        return np.argmax(actions.cpu().data.numpy())

    def __update(self, experiences):
        """Run one weighted TD update and refresh priorities.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states,
                dones, indices, probs)``. ``indices`` and ``probs`` are
                tested for truthiness, so they are presumably lists or
                ``None`` -- verify against ``ReplayBuffer.sample``.
        """
        states, actions, rewards, next_states, dones, indices, probs = experiences
        self.__optimiser.zero_grad()

        # Per-sample loss is needed for importance weighting and for the
        # priority update; `reduction='none'` replaces the deprecated
        # `reduce=False`.
        loss_fn = nn.MSELoss(reduction='none')

        if self.__learning_mode['DQN']:
            Q_target_next = self.__qnetwork_target.forward(next_states).max(
                1)[0].unsqueeze(1).detach()
        else:
            # Double DQN: local network selects, target network evaluates.
            best_next = self.__qnetwork_local.forward(next_states).max(
                1)[1].unsqueeze(1)
            Q_target_next = self.__qnetwork_target.forward(
                next_states).gather(1, best_next).detach()

        # Terminal transitions (dones == 1) contribute the reward only.
        targets = rewards + self.gamma * Q_target_next * (1 - dones)
        outputs = self.__qnetwork_local.forward(states).gather(1, actions)
        loss = loss_fn(outputs, targets)

        # Importance-sampling weights, normalised by their maximum.
        if probs:
            memory_len = len(self.__memory)
            raw = np.array([(prob * memory_len)**(-self.b) for prob in probs])
            weights = (raw / raw.max()).reshape((-1, 1))
        else:
            # np.float was removed from NumPy; the builtin is equivalent.
            weights = np.ones(loss.shape, dtype=float)

        # Move the weights to the loss's device so this also works on GPU.
        weight_tensor = torch.from_numpy(weights).float().to(loss.device)
        weighted_loss = torch.mean(weight_tensor * loss)
        weighted_loss.backward()

        self.__optimiser.step()

        if indices:
            # reshape(-1) keeps a 1-D array even for a batch of one.
            per_sample = loss.detach().cpu().numpy().reshape(-1)
            self.__memory.update(indices, list(per_sample**self.a + self.e))

        self.__soft_update(self.__qnetwork_local, self.__qnetwork_target,
                           self.tau)

        self.agent_loss = weighted_loss.detach().cpu().numpy().squeeze()

    def __soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.

        θ_target = τ*θ_local + (1 - τ)*θ_target
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
コード例 #7
0
class Agent():
    """Double-DQN agent configured through module-level constants.

    Uses ``LR``, ``BUFFER_SIZE``, ``BATCH_SIZE``, ``UPDATE_EVERY``,
    ``GAMMA`` and ``TAU`` defined elsewhere in the module.
    """

    def __init__(self, state_size, action_size, seed):
        """Create the networks, optimizer and replay memory.

        Args:
            state_size: forwarded to ``QNetwork``.
            action_size: number of discrete actions.
            seed: seed for Python's ``random``; also forwarded to the
                networks and the replay buffer.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed first and store the value.
        random.seed(seed)
        self.seed = seed

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size,
                                       seed).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size,
                                        seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Time-step counter, taken modulo UPDATE_EVERY in step().
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition and learn every ``UPDATE_EVERY`` steps.

        Learning only happens once the memory holds more than
        ``BATCH_SIZE`` transitions.
        """
        self.memory.add(state, action, reward, next_state, done)

        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample()
            self.learn(experiences, GAMMA)

    def act(self, state, eps=0.0):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: NumPy array; a leading batch dimension is added.
            eps: the greedy action is taken when a uniform random draw is
                strictly greater than this value.

        Returns:
            Index of the chosen action.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Evaluate in eval mode without autograd, then restore train mode.
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma):
        """Run one Double-DQN update followed by a soft target update.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states,
                dones)`` as returned by ``self.memory.sample()``.
            gamma: discount factor.
        """
        states, actions, rewards, next_states, dones = experiences

        # Action selection by the local network, evaluation by the target
        # network; both are detached so they act as constants in the loss.
        q_local_idx = self.qnetwork_local(next_states).detach().argmax(
            1).unsqueeze(1)
        Q_targets_next = self.qnetwork_target(next_states).detach().gather(
            1, q_local_idx)
        # Terminal transitions (dones == 1) contribute the reward only.
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Q-values of the actions that were actually taken.
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Blend local parameters into the target model in place.

        Each target parameter becomes
        ``tau * local + (1 - tau) * target``.
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
コード例 #8
0
class DeepQ_agent:
    """Double-DQN agent built on two ``QNetwork`` wrappers and a replay memory."""

    def __init__(self,
                 env,
                 hidden_units=None,
                 network_LR=0.001,
                 batch_size=64,
                 update_every=4,
                 gamma=1.0):
        """Set up the local/target networks and the replay memory.

        Args:
            env: environment; ``STATE_SPACE`` and ``ACTION_SPACE`` are read
                from it.
            hidden_units: forwarded to ``QNetwork`` as ``hidden_units``.
            network_LR: learning rate forwarded to both networks.
            batch_size: minibatch size used for sampling and training.
            update_every: counter value at which the target network is
                synchronised (see ``learn``).
            gamma: discount factor.
        """
        self.env = env
        self.BATCH_SIZE = batch_size
        self.GAMMA = gamma
        self.NETWORK_LR = network_LR
        self.MEMORY_CAPACITY = int(1e5)

        # Number of actions the agent can choose from.
        self.nA = env.ACTION_SPACE
        self.HIDDEN_UNITS = hidden_units
        self.UPDATE_EVERY = update_every

        # Local network: the one that is trained and used for acting.
        self.qnetwork_local = QNetwork(
            input_shape=self.env.STATE_SPACE,
            hidden_units=self.HIDDEN_UNITS,
            output_size=self.nA,
            learning_rate=self.NETWORK_LR,
        )
        print(self.qnetwork_local.model.summary())

        # Target network: same construction arguments, synchronised
        # periodically from the local network.
        self.qnetwork_target = QNetwork(
            input_shape=self.env.STATE_SPACE,
            hidden_units=self.HIDDEN_UNITS,
            output_size=self.nA,
            learning_rate=self.NETWORK_LR,
        )

        self.memory = ReplayMemory(self.MEMORY_CAPACITY, self.BATCH_SIZE)

        # Number of learning calls since the last target synchronisation.
        self.t = 0

    def learn(self):
        """Train the local network on one sampled minibatch.

        Does nothing until the memory holds more than ``BATCH_SIZE``
        entries. The target weights are copied when the internal counter
        equals ``UPDATE_EVERY``, i.e. on every ``UPDATE_EVERY + 1``-th
        learning call.
        """
        if len(self.memory) <= self.BATCH_SIZE:
            return

        batch = self.memory.sample(self.env.STATE_SPACE)
        states, actions, rewards, next_states, dones = batch

        # Current estimates from the local network; overwritten below for
        # the actions that were actually taken.
        q_current = self.qnetwork_local.predict(states, self.BATCH_SIZE)
        # Next-state values from the target network (action evaluation).
        q_next_target = self.qnetwork_target.predict(next_states,
                                                     self.BATCH_SIZE)
        # Next-state values from the local network (action selection).
        q_next_local = self.qnetwork_local.predict(next_states,
                                                   self.BATCH_SIZE)

        # Double DQN: the local network picks the action, the target
        # network supplies its value.
        greedy_next = np.argmax(q_next_local, axis=1)

        for row in range(self.BATCH_SIZE):
            new_value = rewards[row]
            if not dones[row]:
                bootstrap = q_next_target[row][greedy_next[row]]
                new_value = rewards[row] + self.GAMMA * bootstrap
            q_current[row][actions[row]] = new_value

        self.qnetwork_local.train(states,
                                  q_current,
                                  batch_size=self.BATCH_SIZE)

        if self.t != self.UPDATE_EVERY:
            self.t += 1
            return
        self.update_target_weights()
        self.t = 0

    def act(self, state, epsilon=0):
        """Pick an action with an epsilon-greedy policy.

        Args:
            state: current state as a NumPy array; a leading batch axis is
                added before prediction.
            epsilon: the greedy action is taken when a uniform random draw
                is strictly greater than this value (default: no
                exploration).

        Returns:
            Index of the chosen action.
        """
        batched = np.expand_dims(state, axis=0)
        q_values = self.qnetwork_local.predict(batched)
        if random.random() > epsilon:
            # Exploitation.
            return np.argmax(q_values)
        # Exploration: uniform over all action indices.
        return random.randint(0, self.nA - 1)

    def add_experience(self, state, action, reward, next_state, done):
        """Store one transition in the replay memory."""
        self.memory.add(state, action, reward, next_state, done)

    def update_target_weights(self):
        """Hard update: copy the local model's weights into the target model."""
        local_weights = self.qnetwork_local.model.get_weights()
        self.qnetwork_target.model.set_weights(local_weights)

    def save(self, model_num, directory):
        """Save the local model under ``directory``.

        The file name contains ``model_num`` and the current
        ``time.asctime()`` string.
        """
        destination = f'{directory}/snake_dqn_{model_num}_{time.asctime()}.h5'
        self.qnetwork_local.model.save(destination)