def __init__(self, state_space, action_space, seed, opts):
    """Module-level copy of ``Agent.__init__``.

    NOTE(review): this function is defined at module level, outside any class,
    and repeats the body of ``Agent.__init__`` — most likely a stray paste.
    Confirm nothing calls it and remove it.

    Initializes ``self`` with a local/target ``DQNetwork`` pair, an Adam
    optimizer over the local network, a replay memory and the learning period.

    :param state_space: forwarded to ``DQNetwork`` (presumably the state
        dimension — confirm against DQNetwork).
    :param action_space: number of discrete actions; forwarded to
        ``DQNetwork`` and ``replayMemory``.
    :param seed: seed for Python's ``random`` module; also forwarded to the
        networks and the replay memory.
    :param opts: options object; ``batch``, ``lr``, ``memory_size`` and
        ``update_freq`` are read here.
    """
    self.state_space = state_space
    self.action_space = action_space
    # random.seed() returns None, so seed first and store the value itself.
    random.seed(seed)
    self.seed = seed
    self.opts = opts
    self.batch_size = opts.batch

    # Q-networks: the local one is trained, the target one provides targets.
    self.local_model = DQNetwork(state_space, action_space, seed).to(device)
    self.target_model = DQNetwork(state_space, action_space, seed).to(device)
    self.optimizer = Adam(self.local_model.parameters(), lr=opts.lr)

    # Replay memory.
    self.memory = replayMemory(action_space, opts.memory_size, self.batch_size, seed)

    # Learning period in steps, clamped to >= 1 so a modulo by it is safe,
    # plus a separate step counter so the period itself is never mutated.
    self.update_every = max(1, opts.update_freq)
    self.t_step = 0
class Agent:
    """DQN agent with a local/target Q-network pair and a replay memory.

    Every call to ``step`` stores one transition. Every ``update_every`` calls,
    provided the memory holds more than ``batch_size`` transitions, a batch is
    sampled and the local network takes one optimizer step. The target network
    is then soft-updated toward the local one.
    """

    def __init__(self, state_space, action_space, seed, opts):
        """Build the networks, optimizer and replay memory.

        :param state_space: forwarded to ``DQNetwork`` (presumably the state
            dimension — confirm against DQNetwork).
        :param action_space: number of discrete actions; ``act`` draws random
            actions from ``range(action_space)``.
        :param seed: seed for Python's ``random`` module; also forwarded to the
            networks and the replay memory.
        :param opts: options object; this class reads ``batch``, ``lr``,
            ``memory_size``, ``update_freq``, ``discount_rate`` and
            ``transfer_rate`` from it.
        """
        self.state_space = state_space
        self.action_space = action_space
        # random.seed() returns None, so seed first and store the value itself.
        random.seed(seed)
        self.seed = seed
        self.opts = opts
        self.batch_size = opts.batch

        # Q-networks: the local one is trained, the target one tracks it via
        # soft updates and provides the bootstrapped targets.
        self.local_model = DQNetwork(state_space, action_space, seed).to(device)
        self.target_model = DQNetwork(state_space, action_space, seed).to(device)
        self.optimizer = Adam(self.local_model.parameters(), lr=opts.lr)

        # Replay memory.
        self.memory = replayMemory(action_space, opts.memory_size, self.batch_size, seed)

        # Learn once every `update_every` calls to step(); clamped to >= 1 so
        # the modulo in step() can never divide by zero.
        self.update_every = max(1, opts.update_freq)
        # Step counter, kept separate so the configured period is never mutated.
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition and periodically learn from a sampled batch.

        Learning happens on every ``update_every``-th call, and only once the
        memory holds more than ``batch_size`` transitions.

        :param state: observation before the action.
        :param action: action taken.
        :param reward: reward received.
        :param next_state: observation after the action.
        :param done: episode-termination flag.
        All five are forwarded unchanged to ``self.memory.add``.
        """
        self.memory.add(state, action, reward, next_state, done)

        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0 and len(self.memory) > self.batch_size:
            experience = self.memory.sample()
            self.learn(experience, self.opts.discount_rate)

    def learn(self, experience, gamma):
        """Take one DQN optimizer step on a batch, then soft-update the target net.

        :param experience: tuple ``(states, actions, rewards, next_states,
            dones)`` as returned by ``self.memory.sample()``.
        :param gamma: discount factor applied to the bootstrapped next-state value.
        """
        sampled_state, sampled_action, sampled_reward, sampled_next_state, sampled_done = experience

        # Greedy value of the next state under the target network. It is
        # detached so no gradient flows into the target network.
        next_value = self.target_model(sampled_next_state).detach().max(1)[0].unsqueeze(1)

        # Bootstrapped target. Transitions whose done flag is 1 contribute no
        # future value (assumes dones are encoded as 0/1).
        DQN_target = sampled_reward + (gamma * next_value * (1 - sampled_done))

        # Q-value of the action actually taken. gather() requires int64
        # indices — assumes sampled_action is an int64 (batch, 1) tensor;
        # TODO confirm against replayMemory.sample().
        DQN_estimation = self.local_model(sampled_state).gather(1, sampled_action)

        loss = F.mse_loss(DQN_estimation, DQN_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.local_model, self.target_model, self.opts.transfer_rate)

    def soft_update(self, local_model, target_model, transfer_rate):
        """Blend local parameters into target parameters in place.

        For each parameter pair this sets
        ``target = transfer_rate * local + (1 - transfer_rate) * target``.

        :param local_model: module whose parameters are read.
        :param target_model: module whose parameters are overwritten.
        :param transfer_rate: interpolation weight given to the local parameters.
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(transfer_rate * local_param.data + (1.0 - transfer_rate) * target_param.data)

    def act(self, state, epsilon=0.):
        """Choose an action epsilon-greedily from the local network's Q-values.

        :param state: a single observation as a NumPy array; it is converted
            with ``torch.from_numpy`` and given a leading batch dimension.
        :param epsilon: probability of taking a uniformly random action instead
            of the greedy one.
        :return: index of the chosen action.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Switch to eval mode for inference, then back to train mode.
        self.local_model.eval()
        with torch.no_grad():
            action_value = self.local_model(state)
        self.local_model.train()

        if np.random.uniform(0.0, 1.0) < epsilon:
            return np.random.choice(np.arange(self.action_space))
        return np.argmax(action_value.cpu().data.numpy())