Example #1
class DDPG(object):
    def __init__(self,
                 state_dim,
                 action_dim,
                 action_bounds,
                 gamma=0.99,
                 sess=None):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma

        self.action_mean = (action_bounds[0] + action_bounds[1]) * 0.5
        self.action_scale = (action_bounds[1] - action_bounds[0]) * 0.5

        self.batch_size = 5

        self.replay_buffer = ReplayBuffer(1000000,
                                          state_dim=state_dim,
                                          action_dim=action_dim)

        if sess is None:
            self.sess = tf.InteractiveSession()
        else:
            self.sess = sess

        self.actor = ActorModel(state_dim, action_dim, self.action_mean,
                                self.action_scale, self.sess)
        self.critic = CriticModel(state_dim, action_dim, self.sess)

        self.reset_policy()

        writer = tf.summary.FileWriter('logs', self.sess.graph)
        writer.close()

    def reset_policy(self):
        tf.global_variables_initializer().run()

        self.actor.reset_target_model()
        self.critic.reset_target_model()

        self.train_idx = 0
        self.replay_buffer.clear()

    def curr_policy(self):
        return self.actor.get_action

    def save_model(self, filename='/tmp/model.ckpt'):
        saver = tf.train.Saver()
        save_path = saver.save(self.sess, filename)
        print("Model saved in file: %s" % filename)

    def load_model(self, filename='/tmp/model.ckpt'):
        saver = tf.train.Saver()
        saver.restore(self.sess, filename)
        print("Model loaded from file: %s" % filename)

    def update(self, env, get_state, max_iter=1000):
        state = env.reset()

        total_reward = 0
        rand_process = OrnsteinUhlenbeckProcess(dt=1.0,
                                                theta=0.15,
                                                sigma=0.2,
                                                mu=np.zeros(self.action_dim),
                                                x0=np.zeros(self.action_dim))
        for i in range(max_iter):
            # get action
            action = self.actor.get_action(state)
            # generate random noise for action
            action_noise = rand_process.get_next()
            action += action_noise
            action = np.clip(action, self.action_mean - self.action_scale,
                             self.action_mean + self.action_scale)
            # action = np.array([action.squeeze()])

            [new_state, reward, done, _] = env.step(action)
            new_state = np.reshape(new_state, (1, self.state_dim))
            self.replay_buffer.insert(state, action, reward, new_state, done)

            total_reward += reward
            state = new_state

            if self.train_idx >= (self.batch_size * 3):
                sample = self.replay_buffer.sample(self.batch_size)

                # get target actions
                target_actions = self.actor.get_target_action(
                    sample['next_state'])
                target_q_vals = self.critic.get_target_q_val(
                    sample['next_state'], target_actions)

                disc_return = sample['reward'] + \
                    self.gamma * target_q_vals.squeeze() * (1.0 - sample['terminal'])

                # update critic network
                loss = self.critic.train(sample['state'], sample['action'],
                                         disc_return)

                # get actions grads from critic network
                action_grads = self.critic.get_action_grads(
                    sample['state'], sample['action'])[0]

                # update actor network
                self.actor.train(sample['state'], action_grads)

                # # update target networks
                self.actor.update_target_model()
                self.critic.update_target_model()

            if done:
                break

            self.train_idx += 1

        return total_reward
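
A minimal usage sketch for the DDPG class above, assuming a Gym-style continuous-control environment and that the ReplayBuffer, ActorModel, CriticModel and OrnsteinUhlenbeckProcess helpers it references are importable. 'Pendulum-v0' and the episode count are illustrative placeholders, and get_state is not used by update() above, so None is passed.

import gym

env = gym.make('Pendulum-v0')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bounds = (env.action_space.low, env.action_space.high)

agent = DDPG(state_dim, action_dim, action_bounds, gamma=0.99)

for episode in range(200):
    # update() runs one episode: it explores with OU noise and, once the
    # replay buffer holds a few batches, performs a critic and actor update
    # (plus target-network updates) at every step.
    episode_return = agent.update(env, get_state=None, max_iter=1000)
    print('episode %d return %.2f' % (episode, episode_return))
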
Example #2
class Agent:
    def __init__(self, tau, gamma, batch_size, lr, state_size, actions_size, kappa, N, device, double=True,
                 visual=False, prioritized=False):
        self.tau = tau
        self.gamma = gamma
        self.batch_size = batch_size
        self.actions_size = actions_size
        self.N = N
        self.tau_q = torch.linspace(0, 1, N + 1)
        self.tau_q = (self.tau_q[1:] + self.tau_q[:-1]) / 2
        self.tau_q = self.tau_q.to(device).unsqueeze(0)  # (1,N)
        self.kappa = kappa
        self.device = device
        self.double = double
        self.prioritized = prioritized

        if visual:
            self.Q_target = Visual_Q_Networks(state_size, actions_size, N).to(self.device)
            self.Q_local = Visual_Q_Networks(state_size, actions_size, N).to(self.device)
        else:
            self.Q_target = Q_Network(state_size, actions_size, N).to(self.device)
            self.Q_local = Q_Network(state_size, actions_size, N).to(self.device)

        self.optimizer = optim.Adam(self.Q_local.parameters(), lr=lr)
        self.soft_update()

        if self.prioritized:
            self.memory = ProportionalReplayBuffer(int(1e5), batch_size)
            # self.memory = RankedReplayBuffer(int(1e5), batch_size)
        else:
            self.memory = ReplayBuffer(int(1e5), batch_size)

    def act(self, state, epsilon=0.1):
        """
        :param state: 输入的input
        :param epsilon:  以epsilon的概率进行exploration, 以 1 - epsilon的概率进行exploitation
        :return:
        """
        if random.random() > epsilon:
            state = torch.tensor(state, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                actions_value = self.Q_local(state)  # (batch_size,action_size,N)
                actions_value = actions_value.sum(dim=2, keepdim=False).view(-1)
            return np.argmax(actions_value.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.actions_size))

    def learn(self):
        if self.prioritized:
            index_list, states, actions, rewards, next_states, dones, probs = self.memory.sample(self.batch_size)
            w = 1 / len(self.memory) / probs
            w = w / torch.max(w)
            w = w.to(self.device)
        else:
            states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)

            w = torch.ones(actions.size())
            w = w.to(self.device)

        states = states.to(self.device)
        actions = actions.to(self.device)  # (batch_size,1)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        quantiles_local = self.Q_local(states)
        actions = actions.unsqueeze(1).repeat(1, 1, self.N)  # (batch_size,1,N)
        quantiles_local = torch.gather(input=quantiles_local, dim=1, index=actions)  # (batch_size,1,N)

        with torch.no_grad():
            quantiles_target = self.Q_target(next_states)
            _, actions_target = torch.max(input=quantiles_target.sum(dim=2, keepdim=True), dim=1,
                                          keepdim=True)  # (batch_size,1,1)
            actions_target = actions_target.repeat(1, 1, self.N)  # (batch_size,1,N)
            quantiles_target = torch.gather(input=quantiles_target, dim=1, index=actions_target)  # (batch_size,1,N)
            quantiles_target = rewards.unsqueeze(1).repeat(1, 1, self.N) + self.gamma * \
                               (1 - (dones.unsqueeze(1).repeat(1, 1, self.N))) * quantiles_target  # (batch_size,1,N)

        diff = quantiles_target.permute(0, 2, 1) - quantiles_local  # (batch_size,N,N)
        loss = self.huber_loss(diff, self.kappa, self.tau_q)  # (batch_size,N,N)
        loss = loss.mean(dim=2, keepdim=False).sum(dim=1, keepdim=False).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.prioritized:
            deltas = quantiles_target.sum(dim=2, keepdim=False) - quantiles_local.sum(dim=2, keepdim=False)
            deltas = np.abs(deltas.detach().cpu().numpy().reshape(-1))
            for i in range(self.batch_size):
                self.memory.insert(deltas[i], index_list[i])

    def huber_loss(self, u, kappa, tau):
        if kappa > 0:
            flag = (u.abs() < kappa).float()
            huber = 0.5 * u.pow(2) * flag + kappa * (u.abs() - 0.5 * kappa) * (1 - flag)
        else:
            huber = u.abs()
        loss = (tau - (u < 0).float()).abs() * huber
        return loss

    def soft_update(self):
        for Q_target_param, Q_local_param in zip(self.Q_target.parameters(), self.Q_local.parameters()):
            Q_target_param.data.copy_(self.tau * Q_local_param.data + (1.0 - self.tau) * Q_target_param.data)
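
As a quick, self-contained check of the quantile Huber loss used in learn() above, the sketch below reproduces its shape logic on dummy tensors; the batch size, number of quantiles and kappa are arbitrary.

import torch

batch_size, N, kappa = 4, 8, 1.0

# quantile midpoints, shape (1, N), as built in __init__
tau = torch.linspace(0, 1, N + 1)
tau = ((tau[1:] + tau[:-1]) / 2).unsqueeze(0)

# dummy pairwise TD errors: target quantiles minus local quantiles
u = torch.randn(batch_size, N, N)

flag = (u.abs() < kappa).float()
huber = 0.5 * u.pow(2) * flag + kappa * (u.abs() - 0.5 * kappa) * (1 - flag)
loss = (tau - (u < 0).float()).abs() * huber   # (batch_size, N, N)
loss = loss.mean(dim=2).sum(dim=1).mean()      # same reduction as learn()
print(loss.item())
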
Example #3
def main(test=False,
         checkpoint=None,
         device='cuda',
         project_name='drqn',
         run_name='example'):
    if not test:
        wandb.init(project=project_name, name=run_name)

    ## HYPERPARAMETERS
    memory_size = 500000  # DARQN paper - 500k, 400k DRQN paper
    min_rb_size = 50000  # arbitrary choice: number of steps collected before training starts
    sample_size = 32  # arbitrary; DARQN uses 32
    lr = 0.1  # DARQN paper - 0.01
    lr_min = 0.00025
    lr_decay = (lr - lr_min) / 1e6
    boltzmann_exploration = False  # not used in the papers
    eps = 1
    eps_min = 0.1  # DARQN - 0.1
    eps_decay = (eps - eps_min) / 1e6  # anneal should be linear
    train_interval = 4  # DARQN - 4
    update_interval = 10000  # all papers
    test_interval = 5000  # arbitrary; irrelevant to convergence
    episode_reward = 0
    episode_rewards = []
    screen_flicker_probability = 0.5

    # replay buffer
    replay = ReplayBuffer(memory_size, truncate_batch=True, guaranteed_size=2)
    step_num = -1 * min_rb_size

    # environment creation
    env = gym.make('Frostbite-v0')
    env = BreakoutEnv(env, 84, 84)
    test_env = gym.make('Frostbite-v0')
    test_env = BreakoutEnv(test_env, 84, 84)
    last_observation = env.reset()

    # model creation
    model = DRQN(env.observation_space.shape, env.action_space.n,
                 lr=lr).to(device)
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
    target = DRQN(env.observation_space.shape, env.action_space.n).to(device)
    update_target_model(model, target)

    hidden = None
    # training loop
    tq = tqdm()
    while True:
        if test:
            env.render()
            time.sleep(0.05)
        tq.update(1)
        eps = max(eps_min, eps - eps_decay)
        #lr = max(lr_min, lr - lr_decay)
        if test:
            eps = 0
        if boltzmann_exploration:
            x = torch.Tensor(last_observation).unsqueeze(0).to(device)
            logits, hidden = model(x, hidden)
            action = torch.distributions.Categorical(
                logits=logits[0]).sample().item()
        else:
            # epsilon-greedy
            x = torch.Tensor(last_observation).unsqueeze(0).to(device)
            qvals, hidden = model(x, hidden)
            if random() < eps:
                action = env.action_space.sample()
            else:
                action = qvals.max(-1)[-1].item()

        # screen flickering
        # if random() < screen_flicker_probability:
        #    last_observation = np.zeros_like(last_observation)

        # observe and obtain reward
        observation, reward, done, info = env.step(action)
        episode_reward += reward

        # add to replay buffer
        replay.insert(
            Sarsd(last_observation, action, reward, observation, done))
        last_observation = observation

        # episode end logic
        if done:
            hidden = None
            episode_rewards.append(episode_reward)
            if len(episode_rewards) > 100:
                del episode_rewards[0]
            wandb.log({
                "reward_ep":
                episode_reward,
                "avg_reward_100ep":
                0 if len(episode_rewards) != 100 else np.mean(episode_rewards)
            })
            episode_reward = 0
            last_observation = env.reset()
        step_num += 1

        # testing, model updating and checkpointing
        if (not test) and (replay.idx > min_rb_size):
            if step_num % train_interval == 0:
                loss = train_step(model, replay.sample(sample_size), target,
                                  env.action_space.n, lr, device)
                wandb.log({
                    "loss": loss.detach().cpu().item(),
                    "step": step_num,
                    "lr": lr
                })
                if not boltzmann_exploration:
                    wandb.log({"eps": eps})
            if step_num % update_interval == 0:
                print('updating target model')
                update_target_model(model, target)
                torch.save(target.state_dict(), 'target.model')
                model_artifact = wandb.Artifact("model_checkpoint",
                                                type="raw_data")

                model_artifact.add_file('target.model')
                wandb.log_artifact(model_artifact)
            if step_num % test_interval == 0:
                print('running test')
                avg_reward, best_reward, frames = policy_evaluation(
                    model,
                    test_env,
                    device,
                    boltzmann_exploration=boltzmann_exploration
                )  # model or target?
                wandb.log({
                    'test_avg_reward':
                    avg_reward,
                    'test_best_reward':
                    best_reward,
                    'test_best_video':
                    wandb.Video(frames.transpose(0, 3, 1, 2),
                                str(best_reward),
                                fps=24)
                })
    env.close()
Example #4
def main(test=False,
         checkpoint=None,
         device='cuda',
         project_name='drqn',
         run_name='example'):
    if not test:
        wandb.init(project=project_name, name=run_name)

    ## HYPERPARAMETERS
    memory_size = 500000
    min_rb_size = 100000
    sample_size = 64
    lr = 0.005
    boltzmann_exploration = False
    eps_min = 0.05
    eps_decay = 0.999995
    train_interval = 8
    update_interval = 10000
    test_interval = 5000
    episode_reward = 0
    episode_rewards = []
    screen_flicker_probability = 0.5

    # additional hparams
    living_reward = -0.01
    same_frame_ctr = 0
    same_frame_limit = 200

    # replay buffer
    replay = ReplayBuffer(memory_size, truncate_batch=True, guaranteed_size=32)
    step_num = -1 * min_rb_size

    # environment creation
    env = gym.make('BreakoutDeterministic-v4')
    env = BreakoutEnv(env, 84, 84)
    test_env = gym.make('BreakoutDeterministic-v4')
    test_env = BreakoutEnv(test_env, 84, 84)
    last_observation = env.reset()

    # model creation
    model = DRQN(env.observation_space.shape, env.action_space.n,
                 lr=lr).to(device)
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
    target = DRQN(env.observation_space.shape, env.action_space.n).to(device)
    update_target_model(model, target)

    # lstm hidden state
    hidden = None

    # training loop
    tq = tqdm()
    while True:
        if test:
            env.render()
            time.sleep(0.05)
        tq.update(1)
        eps = max(eps_min, eps_decay**(step_num))
        if test:
            eps = 0
        if boltzmann_exploration:
            x = torch.Tensor(last_observation).unsqueeze(0).to(device)
            logits, next_hidden = model(x, hidden)
            action = torch.distributions.Categorical(
                logits=logits[0]).sample().item()
            hidden = next_hidden
        else:
            # epsilon-greedy
            if random() < eps:
                action = env.action_space.sample()
            else:
                x = torch.Tensor(last_observation).unsqueeze(0).to(device)
                qvals, next_hidden = model(x, hidden)
                action = qvals.max(-1)[-1].item()
                hidden = next_hidden

        # screen flickering
        # if random() < screen_flicker_probability:
        #    last_observation = np.zeros_like(last_observation)

        # observe and obtain reward
        observation, reward, done, info = env.step(action)
        episode_reward += reward

        # add to replay buffer
        replay.insert(Sarsd(last_observation, action, reward, done))
        last_observation = observation

        # episode end logic
        if done:
            hidden = None
            episode_rewards.append(episode_reward)
            if len(episode_rewards) > 100:
                del episode_rewards[0]
            wandb.log({
                "reward_ep": episode_reward,
                "avg_reward_100ep": np.mean(episode_rewards)
            })
            episode_reward = 0
            last_observation = env.reset()
        step_num += 1

        # testing, model updating and checkpointing
        if (not test) and (replay.idx > min_rb_size):
            if step_num % train_interval == 0:
                loss = train_step(model, replay.sample(sample_size), target,
                                  env.action_space.n, device)
                wandb.log({
                    "loss": loss.detach().cpu().item(),
                    "step": step_num
                })
                if not boltzmann_exploration:
                    wandb.log({"eps": eps})
            if step_num % update_interval == 0:
                print('updating target model')
                update_target_model(model, target)
                torch.save(target.state_dict(), 'target.model')
                model_artifact = wandb.Artifact("model_checkpoint",
                                                type="raw_data")
                model_artifact.add_file('target.model')
                wandb.log_artifact(model_artifact)
            if step_num % test_interval == 0:
                print('running test')
                avg_reward, best_reward, frames = policy_evaluation(
                    model, test_env, device)  # model or target?
                wandb.log({
                    'test_avg_reward':
                    avg_reward,
                    'test_best_reward':
                    best_reward,
                    'test_best_video':
                    wandb.Video(frames.transpose(0, 3, 1, 2),
                                str(best_reward),
                                fps=24)
                })
    env.close()
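
For reference, a small standalone comparison of the two exploration schedules used by the two training scripts above: the linear anneal from the first main() and the exponential decay used here. The step counts are illustrative only.

eps_min_linear, eps_min_exp = 0.1, 0.05
linear_decay = (1.0 - eps_min_linear) / 1e6
exp_decay = 0.999995

for step in (0, 100000, 500000, 1000000, 2000000):
    eps_linear = max(eps_min_linear, 1.0 - linear_decay * step)
    eps_exp = max(eps_min_exp, exp_decay ** step)
    print(step, round(eps_linear, 3), round(eps_exp, 3))
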
Example #5
class DQN(object):
    def __init__(self,
                 state_dim,
                 num_actions,
                 eps_anneal,
                 gamma=0.99,
                 update_freq=100,
                 sess=None):
        self.state_dim = state_dim
        self.num_actions = num_actions
        self.gamma = gamma
        self.eps_anneal = eps_anneal
        self.update_freq = update_freq

        self.batch_size = 64

        self.replay_buffer = ReplayBuffer(3000,
                                          state_dim=state_dim,
                                          action_dim=1)
        self.__build_model()

        if sess is None:
            self.sess = tf.InteractiveSession()
        else:
            self.sess = sess

        self.reset_policy()

        writer = tf.summary.FileWriter('logs', self.sess.graph)
        writer.close()

    def reset_policy(self):
        tf.global_variables_initializer().run()
        self.train_idx = 0
        self.replay_buffer.clear()
        self.eps_anneal.reset()

    def __build_q_func(self, input_var, name, reuse=False):
        with tf.variable_scope(name, reuse=reuse) as scope:
            layer1 = tf.contrib.layers.fully_connected(
                input_var, 32, activation_fn=tf.nn.relu, scope='layer1')
            layer2 = tf.contrib.layers.fully_connected(
                layer1, 16, activation_fn=tf.nn.relu, scope='layer2')
            q_vals = tf.contrib.layers.fully_connected(layer2,
                                                       self.num_actions,
                                                       activation_fn=None,
                                                       scope='q_vals')
        return q_vals

    def __build_model(self):
        # forward model
        self.states = tf.placeholder(tf.float32, [None, self.state_dim],
                                     name='states')
        self.actions = tf.placeholder(tf.int32, [None], name='actions')
        self.action_q_vals = self.__build_q_func(self.states,
                                                 name='action_q_func')
        self.output_actions = tf.argmax(self.action_q_vals,
                                        axis=1,
                                        name='output_actions')
        self.sampled_q_vals = tf.reduce_sum(tf.multiply(
            self.action_q_vals, tf.one_hot(self.actions, self.num_actions)),
                                            1,
                                            name='sampled_q_vals')

        self.target_q_vals = self.__build_q_func(self.states,
                                                 name='target_q_func')
        self.max_q_vals = tf.reduce_max(self.target_q_vals,
                                        axis=1,
                                        name='max_q_vals')

        # loss
        self.rewards = tf.placeholder(tf.float32, [None], name='rewards')
        self.terminal = tf.placeholder(tf.float32, [None], name='terminal')
        self.q_vals_next_state = tf.placeholder(tf.float32, [None],
                                                name='q_vals_next_state')

        self.terminal_mask = tf.subtract(1.0, self.terminal)

        self.disc_return = tf.add(self.rewards,
                                  tf.multiply(
                                      self.terminal_mask,
                                      tf.multiply(self.gamma,
                                                  self.q_vals_next_state)),
                                  name='disc_return')

        self.td_error = tf.subtract(self.disc_return,
                                    self.sampled_q_vals,
                                    name='td_error')
        self.loss = tf.reduce_mean(tf.square(self.td_error), name='loss')
        self.optimizer = tf.train.RMSPropOptimizer(0.00025).minimize(self.loss)

        # updating target network
        var_sort_lambd = lambda x: x.name
        self.action_q_vars = sorted(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='action_q_func'),
                                    key=var_sort_lambd)
        self.target_q_vars = sorted(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='target_q_func'),
                                    key=var_sort_lambd)

        update_target_ops = []
        for action_q, target_q in zip(self.action_q_vars, self.target_q_vars):
            update_target_ops.append(target_q.assign(action_q))
        self.update_target_ops = tf.group(*update_target_ops,
                                          name='update_target_ops')

    def __update_target_network(self):
        self.sess.run(self.update_target_ops)

    def get_action(self, state):
        sample = np.random.random_sample()
        if sample > self.eps_anneal.eps:
            fd = {self.states: np.array([state])}
            output_action = self.sess.run(self.output_actions, feed_dict=fd)
            action = output_action.item()
        else:
            action = np.random.randint(self.num_actions)

        return action

    def curr_policy(self):
        return partial(DQN.get_action, self)

    def save_model(self, filename='/tmp/model.ckpt'):
        saver = tf.train.Saver()
        save_path = saver.save(self.sess, filename)
        print("Model saved in file: %s" % filename)

    def load_model(self, filename='/tmp/model.ckpt'):
        saver = tf.train.Saver()
        saver.restore(self.sess, filename)
        print("Model loaded from file: %s" % filename)

    def update(self, env, get_state, max_iter=1000):
        state = env.reset()

        action = self.get_action(state)

        total_reward = 0
        for i in range(max_iter):
            [new_state, reward, done, _] = env.step(action)
            total_reward += reward

            self.replay_buffer.insert(state, action, reward, new_state, done)

            state = new_state

            if self.train_idx >= self.batch_size:
                sample = self.replay_buffer.sample(self.batch_size)

                # get max q values of next state
                fd = {self.states: sample['next_state']}
                max_q_vals = self.sess.run(self.max_q_vals, feed_dict=fd)

                fd = {
                    self.states: sample['state'],
                    self.actions: sample['action'].squeeze(),
                    self.rewards: sample['reward'],
                    self.terminal: sample['terminal'],
                    self.q_vals_next_state: max_q_vals
                }

                loss, _ = self.sess.run([self.loss, self.optimizer],
                                        feed_dict=fd)

                if self.train_idx % self.update_freq == 0:
                    self.__update_target_network()
            if done:
                break

            action = self.get_action(state)
            self.train_idx += 1

        self.eps_anneal.update()
        return total_reward
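
A minimal usage sketch for the DQN class above on a discrete-action Gym task. The epsilon-annealing helper expected by the constructor is not shown in this example, so LinearAnneal below is a hypothetical stand-in that only provides the .eps attribute and the reset()/update() methods DQN calls; 'CartPole-v0' and the episode count are placeholders.

import gym

class LinearAnneal(object):
    """Hypothetical stand-in for the epsilon-annealing helper used by DQN."""
    def __init__(self, start=1.0, end=0.05, steps=500):
        self.start, self.end, self.steps = start, end, steps
        self.reset()

    def reset(self):
        self.eps = self.start
        self.num_updates = 0

    def update(self):
        self.num_updates += 1
        frac = min(1.0, float(self.num_updates) / self.steps)
        self.eps = self.start + frac * (self.end - self.start)

env = gym.make('CartPole-v0')
agent = DQN(state_dim=env.observation_space.shape[0],
            num_actions=env.action_space.n,
            eps_anneal=LinearAnneal())

for episode in range(300):
    # update() runs one episode and performs a replayed Q-learning step per
    # environment step once the buffer holds a full batch
    episode_return = agent.update(env, get_state=None)
    print('episode %d return %.1f' % (episode, episode_return))
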
Example #6
class Agent:
    def __init__(self,
                 tau,
                 gamma,
                 batch_size,
                 lr,
                 state_size,
                 actions_size,
                 v_min,
                 v_max,
                 N,
                 device,
                 double=True,
                 visual=False,
                 prioritized=False):
        self.tau = tau
        self.gamma = gamma
        self.batch_size = batch_size
        self.actions_size = actions_size
        self.v_min = v_min
        self.v_max = v_max
        self.N = N
        self.vals = torch.linspace(v_min, v_max, N).to(device)  # (N,)
        self.unit = (v_max - v_min) / (N - 1)
        self.device = device
        self.double = double
        self.prioritized = prioritized

        if visual:
            self.Q_target = Visual_Q_Networks(state_size, actions_size, v_min,
                                              v_max, N).to(self.device)
            self.Q_local = Visual_Q_Networks(state_size, actions_size, v_min,
                                             v_max, N).to(self.device)
        else:
            self.Q_target = Q_Network(state_size, actions_size, v_min, v_max,
                                      N).to(self.device)
            self.Q_local = Q_Network(state_size, actions_size, v_min, v_max,
                                     N).to(self.device)

        self.optimizer = optim.Adam(self.Q_local.parameters(), lr=lr)
        self.soft_update()

        if self.prioritized:
            self.memory = ProportionalReplayBuffer(int(1e5), batch_size)
            # self.memory = RankedReplayBuffer(int(1e5), batch_size)
        else:
            self.memory = ReplayBuffer(int(1e5), batch_size)

    def act(self, state, epsilon=0.1):
        """
        :param state: 输入的input
        :param epsilon:  以epsilon的概率进行exploration, 以 1 - epsilon的概率进行exploitation
        :return:
        """
        if random.random() > epsilon:
            state = torch.tensor(state, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                _, actions_value = self.Q_local(state)
            return np.argmax(actions_value.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.actions_size))

    def learn(self):
        if self.prioritized:
            index_list, states, actions, rewards, next_states, dones, probs = self.memory.sample(
                self.batch_size)
            w = 1 / len(self.memory) / probs
            w = w / torch.max(w)
            w = w.to(self.device)
        else:
            states, actions, rewards, next_states, dones = self.memory.sample(
                self.batch_size)

            w = torch.ones(actions.size())
            w = w.to(self.device)

        states = states.to(self.device)
        actions = actions.to(self.device)  # (batch_size,1)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        local_log_prob, Q_local_value = self.Q_local(states)
        actions = actions.unsqueeze(1).repeat(1, 1, self.N)  # (batch_size,1,N)
        local_log_prob = torch.gather(input=local_log_prob,
                                      dim=1,
                                      index=actions)  # (batch_size,1,N)

        with torch.no_grad():
            target_log_prob, Q_target_value = self.Q_target(next_states)
            _, actions_target = torch.max(input=Q_target_value,
                                          dim=1,
                                          keepdim=True)  # (batch_size,1)
            actions_target = actions_target.unsqueeze(1).repeat(
                1, 1, self.N)  # (batch_size,1,1)
            target_log_prob = torch.gather(
                input=target_log_prob, dim=1,
                index=actions_target)  # (batch_size,1,N)
            target_log_prob = self.update_distribution(target_log_prob.exp(),
                                                       rewards, self.gamma,
                                                       dones)
            # (batch_size,1,N)

        loss = -local_log_prob * target_log_prob
        loss = loss.sum(dim=2, keepdim=False).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.prioritized:
            deltas = local_log_prob.sum(
                dim=2, keepdim=False) - target_log_prob.sum(dim=2,
                                                            keepdim=False)
            deltas = np.abs(deltas.detach().cpu().numpy().reshape(-1))
            for i in range(self.batch_size):
                self.memory.insert(deltas[i], index_list[i])

    def update_distribution(self, prev_distribution, reward, gamma, dones):
        """
        :param prev_distribution: Q_target(X_t+1, a*)
        :param reward: reward
        :param gamma: discount factor gamma
        :param dones: episode-termination flags
        :return: the updated (projected) distribution
        """
        with torch.no_grad():
            reward = reward.view(-1, 1)  # (batch_size,1)
            batch_size = reward.size(0)
            assert prev_distribution.size(0) == batch_size
            new_vals = self.vals.view(
                1, -1) * gamma * (1 - dones) + reward  # (batch_size,N)
            new_vals = torch.clamp(new_vals, self.v_min,
                                   self.v_max).to(self.device)
            lower_indexes = torch.floor(
                (new_vals - self.v_min) / self.unit).long().to(
                    self.device)  # (batch_size,N)
            upper_indexes = torch.min(
                lower_indexes + 1,
                other=torch.tensor(self.N - 1).to(self.device)).to(
                    self.device)  # (batch_size,N)
            lower_vals = self.vals[lower_indexes].to(
                self.device)  # (batch_size,N)
            lower_distances = 1 - torch.min(
                (new_vals - lower_vals) / self.unit,
                other=torch.tensor(1, dtype=torch.float32).to(self.device)).to(
                    self.device)  # (batch_size,N)
            transition = torch.zeros(
                (batch_size, self.N, self.N)).to(self.device)
            first_dim = torch.tensor(range(batch_size), dtype=torch.long).view(
                -1, 1).repeat(1, self.N).view(-1).to(self.device)  # (batch_size * N)

            second_dim = torch.tensor(range(
                self.N), dtype=torch.long).repeat(batch_size).view(-1).to(
                    self.device)  # (batch_size * N)
            transition[first_dim, second_dim,
                       lower_indexes.view(-1)] += lower_distances.view(-1)
            transition[first_dim, second_dim,
                       upper_indexes.view(-1)] += 1 - lower_distances.view(-1)
            if len(prev_distribution.size()) == 2:
                prev_distribution = prev_distribution.unsqueeze(
                    1)  # (batch_size,action_size,N)
            return torch.bmm(prev_distribution,
                             transition)  # (batch_size,action_size,N)

    def soft_update(self):
        for Q_target_param, Q_local_param in zip(self.Q_target.parameters(),
                                                 self.Q_local.parameters()):
            Q_target_param.data.copy_(self.tau * Q_local_param.data +
                                      (1.0 - self.tau) * Q_target_param.data)
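
update_distribution() above implements the categorical projection of the shifted target distribution back onto the fixed support. The standalone sketch below checks the same idea on dummy numbers (support, reward and discount are arbitrary, and this is not the Agent's own method): redistributing mass between neighbouring atoms preserves total probability.

import torch

v_min, v_max, N = -10.0, 10.0, 51
vals = torch.linspace(v_min, v_max, N)
unit = (v_max - v_min) / (N - 1)

prob = torch.softmax(torch.randn(N), dim=0)  # dummy target distribution
reward, gamma, done = 1.0, 0.99, 0.0

new_vals = torch.clamp(reward + gamma * (1 - done) * vals, v_min, v_max)
lower = torch.floor((new_vals - v_min) / unit).long()
upper = torch.clamp(lower + 1, max=N - 1)
lower_w = 1 - torch.clamp((new_vals - vals[lower]) / unit, max=1.0)

projected = torch.zeros(N)
projected.index_add_(0, lower, prob * lower_w)
projected.index_add_(0, upper, prob * (1 - lower_w))
print(projected.sum())  # ~= 1.0
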
Example #7
class Agent:
    def __init__(self,
                 tau,
                 gamma,
                 batch_size,
                 lr,
                 state_size,
                 actions_size,
                 device,
                 double=True,
                 duel=False,
                 visual=False,
                 prioritized=False):
        self.tau = tau
        self.gamma = gamma
        self.batch_size = batch_size
        self.actions_size = actions_size
        self.device = device
        self.double = double
        self.duel = duel
        self.prioritized = prioritized

        if visual:
            self.Q_target = Visual_Q_Networks(state_size,
                                              actions_size,
                                              duel=duel).to(self.device)
            self.Q_local = Visual_Q_Networks(state_size, actions_size,
                                             duel=duel).to(self.device)
        else:
            self.Q_target = Q_Network(state_size, actions_size,
                                      duel=duel).to(self.device)
            self.Q_local = Q_Network(state_size, actions_size,
                                     duel=duel).to(self.device)

        self.optimizer = optim.Adam(self.Q_local.parameters(), lr=lr)
        self.soft_update()

        if self.prioritized:
            self.memory = ProportionalReplayBuffer(int(1e5), batch_size)
            # self.memory = RankedReplayBuffer(int(1e5), batch_size)
        else:
            self.memory = ReplayBuffer(int(1e5), batch_size)

    def act(self, state, epsilon=0.1):
        """
        :param state: 输入的input
        :param epsilon:  以epsilon的概率进行exploration, 以 1 - epsilon的概率进行exploitation
        :return:
        """
        if random.random() > epsilon:
            state = torch.tensor(state, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                actions_value = self.Q_local(state)
            return np.argmax(actions_value.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.actions_size))

    def learn(self):
        if self.prioritized:
            index_list, states, actions, rewards, next_states, dones, probs = self.memory.sample(
                self.batch_size)
            w = 1 / len(self.memory) / probs
            w = w / torch.max(w)
            w = w.to(self.device)
        else:
            states, actions, rewards, next_states, dones = self.memory.sample(
                self.batch_size)
            w = torch.ones(actions.size())
            w = w.to(self.device)

        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        Q_local_values = self.Q_local(states)
        Q_local_values = torch.gather(Q_local_values, dim=-1, index=actions)

        with torch.no_grad():
            Q_targets_values = self.Q_target(next_states)
            if self.double:
                max_actions = torch.max(input=self.Q_local(next_states),
                                        dim=1,
                                        keepdim=True)[1]
                Q_targets_values = torch.gather(Q_targets_values,
                                                dim=1,
                                                index=max_actions)
            else:
                Q_targets_values = torch.max(input=Q_targets_values,
                                             dim=1,
                                             keepdim=True)[0]

            Q_targets_values = rewards + self.gamma * (
                1 - dones) * Q_targets_values

        deltas = Q_local_values - Q_targets_values

        loss = (w * deltas).pow(2).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.prioritized:
            deltas = np.abs(deltas.detach().cpu().numpy().reshape(-1))
            for i in range(self.batch_size):
                self.memory.insert(deltas[i], index_list[i])

    def soft_update(self):
        for Q_target_param, Q_local_param in zip(self.Q_target.parameters(),
                                                 self.Q_local.parameters()):
            Q_target_param.data.copy_(self.tau * Q_local_param.data +
                                      (1.0 - self.tau) * Q_target_param.data)
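
Finally, a minimal training-loop sketch for the (double) DQN Agent above on CartPole. The replay buffer's insertion method is not shown in this example, so the agent.memory.add(...) call is a hypothetical placeholder; the epsilon schedule, warm-up threshold and hyperparameters are illustrative.

import gym
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make('CartPole-v0')

agent = Agent(tau=1e-3, gamma=0.99, batch_size=64, lr=5e-4,
              state_size=env.observation_space.shape[0],
              actions_size=env.action_space.n,
              device=device, double=True)

eps, warmup, steps = 1.0, 1000, 0
for episode in range(500):
    state, done, episode_return = env.reset(), False, 0.0
    while not done:
        action = agent.act(state, epsilon=eps)
        next_state, reward, done, _ = env.step(action)
        agent.memory.add(state, action, reward, next_state, done)  # hypothetical API
        steps += 1
        if steps > warmup:
            agent.learn()
            agent.soft_update()  # Polyak-average the target towards the local network
        state, episode_return = next_state, episode_return + reward
    eps = max(0.05, eps * 0.995)
    print('episode %d return %.1f eps %.2f' % (episode, episode_return, eps))
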