Example #1
    def __init__(self,
                 meta_controller_experience_memory=None,
                 lr=0.00025,
                 alpha=0.95,
                 eps=0.01,
                 batch_size=32,
                 gamma=0.99,
                 num_options=12):
        # experience replay memory
        self.meta_controller_experience_memory = meta_controller_experience_memory
        self.lr = lr  # learning rate
        self.alpha = alpha  # optimizer parameter
        self.eps = eps  # optimizer parameter
        self.gamma = gamma  # discount factor
        # BUILD MODEL
        USE_CUDA = torch.cuda.is_available()
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.device = torch.device("cuda:1")
        elif torch.cuda.device_count() == 1:
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

        dfloat_cpu = torch.FloatTensor
        dfloat_gpu = torch.cuda.FloatTensor

        dlong_cpu = torch.LongTensor
        dlong_gpu = torch.cuda.LongTensor

        duint_cpu = torch.ByteTensor
        duint_gpu = torch.cuda.ByteTensor

        dtype = torch.cuda.FloatTensor if torch.cuda.is_available(
        ) else torch.FloatTensor
        dlongtype = torch.cuda.LongTensor if torch.cuda.is_available(
        ) else torch.LongTensor
        duinttype = torch.cuda.ByteTensor if torch.cuda.is_available(
        ) else torch.ByteTensor

        self.dtype = dtype
        self.dlongtype = dlongtype
        self.duinttype = duinttype

        Q = DQN(in_channels=4, num_actions=num_options).type(dtype)
        Q_t = DQN(in_channels=4, num_actions=num_options).type(dtype)
        Q_t.load_state_dict(Q.state_dict())
        Q_t.eval()
        for param in Q_t.parameters():
            param.requires_grad = False

        Q = Q.to(self.device)
        Q_t = Q_t.to(self.device)

        self.batch_size = batch_size
        self.Q = Q
        self.Q_t = Q_t
        # optimizer
        optimizer = optim.RMSprop(Q.parameters(), lr=lr, alpha=alpha, eps=eps)
        self.optimizer = optimizer
        print('init: Meta Controller --> OK')
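Note: every example on this page assumes a project-specific DQN network class that is not shown here, and the constructors differ between projects. For orientation only, a minimal convolutional Q-network in the spirit of the Atari-style examples (a hypothetical sketch, not code from any of the projects above) could look like this:

import torch
import torch.nn as nn

class MinimalDQN(nn.Module):
    # Hypothetical stand-in for the project-specific DQN classes used in these examples.
    def __init__(self, in_channels=4, num_actions=6):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),  # 7x7 feature map assumes 84x84 inputs
            nn.Linear(512, num_actions),
        )

    def forward(self, x):
        # Scale raw byte observations to [0, 1] and return one Q-value per action.
        return self.head(self.features(x / 255.0))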
Example #2
class ParallelNashAgent():
    def __init__(self, env, id, args):
        super(ParallelNashAgent, self).__init__()
        self.id = id
        self.current_model = DQN(env, args).to(args.device)
        self.target_model = DQN(env, args).to(args.device)
        update_target(self.current_model, self.target_model)

        if args.load_model and os.path.isfile(args.load_model):
            self.load_model(args.load_model)

        self.epsilon_by_frame = epsilon_scheduler(args.eps_start,
                                                  args.eps_final,
                                                  args.eps_decay)
        self.replay_buffer = ParallelReplayBuffer(args.buffer_size)
        self.rl_optimizer = optim.Adam(self.current_model.parameters(),
                                       lr=args.lr)

    def save_model(self, model_path):
        torch.save(self.current_model.state_dict(),
                   model_path + f'/{self.id}_dqn')
        torch.save(self.target_model.state_dict(),
                   model_path + f'/{self.id}_dqn_target')

    def load_model(self, model_path, eval=False, map_location=None):
        self.current_model.load_state_dict(
            torch.load(model_path + f'/{self.id}_dqn',
                       map_location=map_location))
        self.target_model.load_state_dict(
            torch.load(model_path + f'/{self.id}_dqn_target',
                       map_location=map_location))
        if eval:
            self.current_model.eval()
            self.target_model.eval()
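update_target and epsilon_scheduler above are helpers from that project and are not reproduced on this page. Under the usual DQN conventions they amount to a hard parameter copy and an exponentially decaying exploration schedule; a sketch under those assumptions (names and signatures chosen to match the call sites above) is:

import math

def update_target(current_model, target_model):
    # Hard copy of the online network's weights into the target network.
    target_model.load_state_dict(current_model.state_dict())

def epsilon_scheduler(eps_start, eps_final, eps_decay):
    # Returns a function mapping the frame index to an epsilon value that
    # decays exponentially from eps_start towards eps_final.
    def epsilon_by_frame(frame_idx):
        return eps_final + (eps_start - eps_final) * math.exp(-frame_idx / eps_decay)
    return epsilon_by_frame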
Example #3
def test(env, args):
    current_model = DQN(env, args).to(args.device)
    current_model.eval()

    load_model(current_model, args)

    episode_reward = 0
    episode_length = 0

    state = env.reset()
    while True:
        if args.render:
            env.render()

        action = current_model.act(
            torch.FloatTensor(state).to(args.device), 0.)

        next_state, reward, done, _ = env.step(action)

        state = next_state
        episode_reward += reward
        episode_length += 1

        if done:
            break

    print("Test Result - Reward {} Length {}".format(episode_reward,
                                                     episode_length))
Example #4
File: deepq.py Project: shukon/SDC
def evaluate(env, load_path='agent.pt'):
    """ Evaluate a trained model and compute your leaderboard scores
	
	NO CHANGES SHOULD BE MADE TO THIS FUNCTION
	
    Parameters
    -------
    env: gym.Env
        environment to evaluate on
    load_path: str
        path to load the model (.pt) from
    """
    episode_rewards = [0.0]
    actions = get_action_set()
    action_size = len(actions)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # These are not the final evaluation seeds, do not overfit on these tracks!
    seeds = [
        22597174, 68545857, 75568192, 91140053, 86018367, 49636746, 66759182,
        91294619, 84274995, 31531469
    ]

    # Build & load network
    policy_net = DQN(action_size, device).to(device)
    checkpoint = torch.load(load_path, map_location=device)
    policy_net.load_state_dict(checkpoint)
    policy_net.eval()

    # Iterate over a number of evaluation episodes
    for i in range(10):
        env.seed(seeds[i])
        obs, done = env.reset(), False
        obs = get_state(obs)
        t = 0

        # Run each episode until episode has terminated or 600 time steps have been reached
        while not done and t < 600:
            env.render()
            action_id = select_greedy_action(obs, policy_net, action_size)
            action = actions[action_id]
            obs, rew, done, _ = env.step(action)
            obs = get_state(obs)
            episode_rewards[-1] += rew
            t += 1
        print('episode %d \t reward %f' % (i, episode_rewards[-1]))
        episode_rewards.append(0.0)

    print('---------------------------')
    print(' total score: %f' % np.mean(np.array(episode_rewards)))
    print('---------------------------')
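select_greedy_action and get_state are helpers defined elsewhere in the shukon/SDC project. Assuming the observation has already been preprocessed into a single network-input tensor, a greedy selection step would be roughly:

import torch

def select_greedy_action(obs, policy_net, action_size):
    # Pick the index of the highest predicted Q-value; no exploration.
    # action_size is kept only to mirror the call site above.
    with torch.no_grad():
        q_values = policy_net(obs.unsqueeze(0))    # add a batch dimension
        return int(q_values.argmax(dim=1).item())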
Example #5
def train_setting(env, device):

    init_screen = get_screen(env, device)
    _, _, screen_height, screen_width = init_screen.shape

    # Get number of actions from gym action space
    n_actions = env.action_space.n

    policy_net = DQN(screen_height, screen_width, n_actions).to(device)
    target_net = DQN(screen_height, screen_width, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.RMSprop(policy_net.parameters())
    memory = ReplayMemory(10000)
    return n_actions, policy_net, target_net, optimizer, memory
Example #6
def test(env, args): 
    p1_current_model = DQN(env, args).to(args.device)
    p2_current_model = DQN(env, args).to(args.device)
    p1_policy = Policy(env).to(args.device)
    p2_policy = Policy(env).to(args.device)
    p1_current_model.eval(), p2_current_model.eval()
    p1_policy.eval(), p2_policy.eval()

    load_model(models={"p1": p1_current_model, "p2": p2_current_model},
               policies={"p1": p1_policy, "p2": p2_policy}, args=args)

    p1_reward_list = []
    p2_reward_list = []
    length_list = []

    for _ in range(30):
        (p1_state, p2_state) = env.reset()
        p1_episode_reward = 0
        p2_episode_reward = 0
        episode_length = 0
        while True:
            if args.render:
                env.render()
                sleep(0.01)

            # Agents follow average strategy
            p1_action = p1_policy.act(torch.FloatTensor(p1_state).to(args.device))
            p2_action = p2_policy.act(torch.FloatTensor(p2_state).to(args.device))

            actions = {"1": p1_action, "2": p2_action}

            (p1_next_state, p2_next_state), reward, done, _ = env.step(actions)

            (p1_state, p2_state) = (p1_next_state, p2_next_state)
            p1_episode_reward += reward[0]
            p2_episode_reward += reward[1]
            episode_length += 1

            if done:
                p1_reward_list.append(p1_episode_reward)
                p2_reward_list.append(p2_episode_reward)
                length_list.append(episode_length)
                break
    
    print("Test Result - Length {:.2f} p1/Reward {:.2f} p2/Reward {:.2f}".format(
        np.mean(length_list), np.mean(p1_reward_list), np.mean(p2_reward_list)))
    
Example #7
def test(env, args): 
    p1_current_model = DQN(env, args).to(args.device)
    p2_current_model = DQN(env, args).to(args.device)
    p1_current_model.eval()
    p2_current_model.eval()

    load_model(p1_current_model, args, 1)
    load_model(p2_current_model, args, 2)

    p1_reward_list = []
    p2_reward_list = []
    length_list = []

    for _ in range(30):
        (p1_state, p2_state) = env.reset()
        p1_episode_reward = 0
        p2_episode_reward = 0
        episode_length = 0
        while True:
            if args.render:
                env.render()
                from time import sleep
                sleep(0.2)

            p1_action = p1_current_model.act(torch.FloatTensor(p1_state).to(args.device), 0.0)
            p2_action = p2_current_model.act(torch.FloatTensor(p2_state).to(args.device), 0.0)

            actions = {"1": p1_action, "2": p2_action}

            (p1_next_state, p2_next_state), reward, done, _ = env.step(actions)

            (p1_state, p2_state) = (p1_next_state, p2_next_state)
            p1_episode_reward += reward[0]
            p2_episode_reward += reward[1]
            episode_length += 1

            if done:
                p1_reward_list.append(p1_episode_reward)
                p2_reward_list.append(p2_episode_reward)
                length_list.append(episode_length)
                break
    
    print("Test Result - p1/Reward {} p2/Reward Length {}".format(
        np.mean(p1_reward_list), np.mean(p2_reward_list)))
    
Example #8
def collect_training_rewards(policy_dir_path):
    env = gym.make("PongDeterministic-v4")
    epsilon = 0.05
    reward_tuples = []

    try:
        for f in os.listdir(policy_dir_path):
            print(f"processing {f}")
            dqn = DQN()
            dqn.load_state_dict(
                torch.load(policy_dir_path + "/" + f, map_location=torch.device("cpu"))
            )
            dqn.eval()

            obs = env.reset()
            s = TrainPongV0.prepare_state(obs)
            tot_reward = 0

            while True:
                if np.random.rand() < epsilon:
                    a = np.random.choice(range(0, 6))
                else:
                    a = dqn(s).argmax()

                prev_s = s
                obs, r, d, _ = env.step(a)
                s = TrainPongV0.prepare_state(obs, prev_s=prev_s)

                tot_reward += r

                if d:
                    break
            reward_tuples.append((tot_reward, int(f.replace("HPC_", ""))))

    finally:
        reward_tuples.sort(key=lambda s: s[1])
        ls = [s[1] for s in reward_tuples]
        rs = [s[0] for s in reward_tuples]
        np.save("reward_tuples", reward_tuples)
        plt.plot(ls, rs)
        plt.show()

    env.close()
Example #9
def render_model(path, for_gif=False, epsilon=0):
    """Render model from the given path

    0 <= epsilon < 1
    """
    env = gym.make("PongDeterministic-v4")
    dqn = DQN()
    dqn.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
    dqn.eval()

    obs = env.reset()
    s = TrainPongV0.prepare_state(obs)
    frames = []
    tot_reward = 0
    try:
        for i in range(15000):

            if for_gif:
                frames.append(Image.fromarray(env.render(mode="rgb_array")))
            else:
                env.render()

            if np.random.rand() < epsilon:
                a = np.random.choice([1, 2, 3])
            else:
                with torch.no_grad():
                    a = dqn(torch.from_numpy(s))[0].argmax() + 1

            prev_s = s
            obs, r, d, _ = env.step(a)
            tot_reward += r
            s = TrainPongV0.prepare_state(obs, prev_s=prev_s)
            if d:
                break

    except KeyboardInterrupt:
        pass

    env.close()

    if for_gif:
        return tot_reward, frames
Example #10
def test(env, args): 
    current_model = DQN(env, args).to(args.device)
    current_model.eval()

    load_model(current_model, args)

    episode_reward = 0
    episode_length = 0

    state_buffer = deque(maxlen=args.action_repeat)
    states_deque = actions_deque = rewards_deque = None
    state, state_buffer = get_initial_state(env, state_buffer, args.action_repeat)
    while True:

        action = current_model.act(torch.FloatTensor(state).to(args.device), 0.)
        next_state, _, done, end = env.step(action, save_screenshots=True)
        add_state(next_state, state_buffer)
        next_state = recent_state(state_buffer)

        state = next_state

        if end:
            break
        # delete the agents that have reached the goal
        r_index = 0
        for r in range(len(done)):
            if done[r] is True:
                state_buffer, states_deque, actions_deque, rewards_deque = \
                    del_record(r_index, state_buffer, states_deque, actions_deque, rewards_deque)
                r_index -= 1
            r_index += 1
        next_state = recent_state(state_buffer)

        state = next_state
    PanicEnv.display(True)
    print("Test Result - Reward {} Length {}".format(episode_reward, episode_length))
    
Example #11
class AgentEval:
    def __init__(self, args, env):
        self.action_space = env.action_space
        self.atoms = args.atoms
        self.v_min = args.V_min
        self.v_max = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        for m in self.online_net.modules():
            print(m)

        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
        self.online_net.eval()

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return self.action_space.sample() if np.random.random() < epsilon else self.act(state)

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()
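The (online_net(state.unsqueeze(0)) * self.support).sum(2) pattern used by the distributional (C51/Rainbow-style) agents on this page turns a per-action probability distribution over the support z into an expected Q-value before the argmax. A small self-contained illustration with dummy tensors:

import torch

atoms, num_actions = 51, 4
support = torch.linspace(-10.0, 10.0, atoms)                       # fixed support z
probs = torch.softmax(torch.randn(1, num_actions, atoms), dim=2)   # p(s, a, z), batch of 1

q_values = (probs * support).sum(2)        # expected Q(s, a) = sum over z of z * p(s, a, z)
greedy_action = q_values.argmax(1).item()  # greedy action index
print(q_values.shape, greedy_action)       # torch.Size([1, 4]) and an int in [0, 4)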
Example #12
if torch.cuda.is_available():
    device0 = torch.device("cuda:0")
else:
    device0 = torch.device("cpu")

dtype = torch.cuda.FloatTensor if torch.cuda.is_available(
) else torch.FloatTensor
dlongtype = torch.cuda.LongTensor if torch.cuda.is_available(
) else torch.LongTensor
duinttype = torch.cuda.ByteTensor if torch.cuda.is_available(
) else torch.ByteTensor

Qt = DQN(in_channels=5, num_actions=18).type(dtype)
Qt_t = DQN(in_channels=5, num_actions=18).type(dtype)
Qt_t.load_state_dict(Qt.state_dict())
Qt_t.eval()
for param in Qt_t.parameters():
    param.requires_grad = False

if torch.cuda.device_count() > 0:
    Qt = nn.DataParallel(Qt).to(device0)
    Qt_t = nn.DataParallel(Qt_t).to(device0)
    batch_size = BATCH_SIZE * torch.cuda.device_count()
else:
    batch_size = BATCH_SIZE

# optimizer
optimizer = optim.RMSprop(Qt.parameters(),
                          lr=LEARNING_RATE,
                          alpha=ALPHA,
                          eps=EPS)
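The upper-case names in this example (BATCH_SIZE, LEARNING_RATE, ALPHA, EPS) are module-level hyperparameters defined elsewhere in that project. Representative values, shown here only as an assumption mirroring common DQN settings, would be:

BATCH_SIZE = 32         # per-GPU batch size before the DataParallel scaling above
LEARNING_RATE = 0.00025
ALPHA = 0.95            # RMSprop smoothing constant
EPS = 0.01              # RMSprop epsilon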
Example #13
class DQNAgent:
    def __init__(self, state_size, action_size, config=RLConfig()):
        self.seed = random.seed(config.seed)
        self.state_size = state_size
        self.action_size = action_size
        self.batch_size = config.batch_size
        self.batch_indices = torch.arange(config.batch_size).long().to(device)
        self.samples_before_learning = config.samples_before_learning
        self.learn_interval = config.learning_interval
        self.parameter_update_interval = config.parameter_update_interval
        self.per_epsilon = config.per_epsilon
        self.tau = config.tau
        self.gamma = config.gamma

        if config.useDuelingDQN:
            self.qnetwork_local = DuelingDQN(state_size, action_size,
                                             config.seed).to(device)
            self.qnetwork_target = DuelingDQN(state_size, action_size,
                                              config.seed).to(device)
        else:
            self.qnetwork_local = DQN(state_size, action_size,
                                      config.seed).to(device)
            self.qnetwork_target = DQN(state_size, action_size,
                                       config.seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=config.learning_rate)

        self.doubleDQN = config.useDoubleDQN
        self.usePER = config.usePER
        if self.usePER:
            self.memory = PrioritizedReplayBuffer(config.buffer_size,
                                                  config.per_alpha)
        else:
            self.memory = ReplayBuffer(config.buffer_size)

        self.t_step = 0

    def act(self, state, eps=0.):
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() < eps:
            return random.choice(np.arange(self.action_size))
        else:
            return np.argmax(action_values.cpu().data.numpy())

    def step(self, state, action, reward, next_state, done, beta):
        self.memory.add(state, action, reward, next_state, done)
        self.t_step += 1
        if self.t_step % self.learn_interval == 0:
            if len(self.memory) > self.samples_before_learning:
                state = torch.from_numpy(state).float().unsqueeze(0).to(device)
                next_state = torch.from_numpy(next_state).float().unsqueeze(
                    0).to(device)
                target = self.qnetwork_local(state).data
                old_val = target[0][action]
                target_val = self.qnetwork_target(next_state).data
                if done:
                    target[0][action] = reward
                else:
                    target[0][
                        action] = reward + self.gamma * torch.max(target_val)
                if self.usePER:
                    states, actions, rewards, next_states, dones, weights, indices = self.memory.sample(
                        self.batch_size, beta)
                else:
                    indices = None
                    weights = None
                    states, actions, rewards, next_states, dones = self.memory.sample(
                        self.batch_size)

                self.learn(states, actions, rewards, next_states, dones,
                           indices, weights, self.gamma)

    def learn(self, states, actions, rewards, next_states, dones, indices,
              weights, gamma):
        states = torch.from_numpy(np.vstack(states)).float().to(device)
        actions = torch.from_numpy(np.vstack(actions)).long().to(device)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
        next_states = torch.from_numpy(
            np.vstack(next_states)).float().to(device)
        dones = torch.from_numpy(np.vstack(dones.astype(
            np.uint8))).float().to(device)
        Q_targets_next = self.qnetwork_target(next_states).detach()

        if self.doubleDQN:
            # choose the best action from the local network
            next_actions = self.qnetwork_local(next_states).argmax(dim=-1)
            Q_targets_next = Q_targets_next[self.batch_indices, next_actions]
        else:
            Q_targets_next = Q_targets_next.max(1)[0]

        Q_targets = rewards + gamma * Q_targets_next.reshape(
            (self.batch_size, 1)) * (1 - dones)

        pred = self.qnetwork_local(states)
        Q_expected = pred.gather(1, actions)

        if self.usePER:
            errors = torch.abs(Q_expected - Q_targets).detach().cpu().numpy() + self.per_epsilon
            self.memory.update_priorities(indices, errors)

        self.optimizer.zero_grad()
        loss = F.mse_loss(Q_expected, Q_targets)
        loss.backward()
        self.optimizer.step()

        if self.t_step % self.parameter_update_interval == 0:
            self.soft_update(self.qnetwork_local, self.qnetwork_target,
                             self.tau)

    def soft_update(self, qnetwork_local, qnetwork_target, tau):
        for local_param, target_param in zip(qnetwork_local.parameters(),
                                             qnetwork_target.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
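A minimal interaction loop for an agent with this act/step interface might look like the sketch below. The environment name, sizes, and hyperparameters are placeholders, and device, RLConfig, DQN and the replay buffers are assumed to be defined as in the surrounding project:

import gym

env = gym.make("CartPole-v1")                  # placeholder environment (old Gym step/reset API)
agent = DQNAgent(state_size=4, action_size=2)  # sizes must match the chosen environment

for episode in range(10):
    state = env.reset()
    done, score = False, 0.0
    while not done:
        action = agent.act(state, eps=0.1)                              # epsilon-greedy action
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done, beta=0.4)   # store transition, maybe learn
        state, score = next_state, score + reward
    print("episode {}: return {:.1f}".format(episode, score))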
Example #14
# ------------------------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"  # note: this overrides the line above and forces CPU execution
print(f"Using device: {DEVICE}")
print(f"Settings:\n{SETTINGS}")

n_actions = len(SETTINGS["actions"])
n_episodes = SETTINGS["num_episodes"]
max_episode_len = SETTINGS["max_episode_length"]
dims = SETTINGS["world_dims"]
eps = SETTINGS["eps"]

policy_net = DQN(dims[0], dims[1], dims[2], n_actions).to(DEVICE)
target_net = DQN(dims[0], dims[1], dims[2], n_actions).to(DEVICE)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ExperienceReplay(100)

total_steps = 0

env = gameEnv(partial=False, size=SETTINGS["world_size"])

#
# Methods
# ------------------------------------------------------------------------------


def select_action(state, eps, n_actions):
    f"""
Example #15
class Agent():
    def __init__(self, args, env):
        self.args = args
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip
        self.coeff = 0.01 if args.game in [
            'pong', 'boxing', 'private_eye', 'freeway'
        ] else 1.

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        self.momentum_net = DQN(args, self.action_space).to(device=args.device)
        # self.predictor = prediction_MLP(in_dim=128, hidden_dim=128, out_dim=128)

        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()
        # self.pred.train()
        self.initialize_momentum_net()
        self.momentum_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        for param in self.momentum_net.parameters():
            param.requires_grad = False
        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.learning_rate,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)

        _, z_1, p_1 = self.online_net(aug_states_1, log=True)
        _, z_2, p_2 = self.online_net(aug_states_2, log=True)
        # p_1, p_2 = self.pred(z_1), self.pred(z_2)

        # with torch.no_grad():
        #     p_2 = self.pred(z_2)

        simsiam_loss = 2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2
        # simsiam_loss = p_1.mean() + p_2.mean()
        # simsiam_loss = p_1.mean() * 128
        # simsiam_loss = - F.cosine_similarity(p_1, z_2.detach(), dim=-1).mean()
        # print(simsiam_loss)
        # simsiam_loss = 0

        # _, z_target = self.momentum_net(aug_states_2, log=True) #z_k
        # z_proj = torch.matmul(self.online_net.W, z_target.T)
        # logits = torch.matmul(z_anch, z_proj)
        # logits = (logits - torch.max(logits, 1)[0][:, None])
        # logits = logits * 0.1
        # labels = torch.arange(logits.shape[0]).long().to(device=self.args.device)
        # moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_1: {z_1.shape}')
        # print(f'p_1: {p_1.shape}')
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        # loss = loss + (moco_loss * self.coeff)
        loss = loss + (simsiam_loss * self.coeff)
        self.online_net.zero_grad()
        # self.pred.zero_grad()
        curl_loss = (weights * loss).mean()
        # print(curl_loss)
        curl_loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def learn_old(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        _, z_anch, _ = self.online_net(aug_states_1, log=True)  #z_q
        _, z_target, _ = self.momentum_net(aug_states_2, log=True)  #z_k
        z_proj = torch.matmul(self.online_net.W, z_target.T)
        logits = torch.matmul(z_anch, z_proj)
        logits = (logits - torch.max(logits, 1)[0][:, None])
        logits = logits * 0.1
        labels = torch.arange(
            logits.shape[0]).long().to(device=self.args.device)
        moco_loss = (nn.CrossEntropyLoss()(logits,
                                           labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_anch: {z_anch.shape}')
        # print(f'z_target: {z_target.shape}')
        # print(f'z_proj: {z_proj.shape}')
        # print(f'logits: {logits.shape}')
        # print(logits)
        # print(f'labels: {labels.shape}')
        # print(labels)
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        print(moco_loss)
        loss = loss + (moco_loss * self.coeff)
        self.online_net.zero_grad()
        curl_loss = (weights * loss).mean()
        curl_loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def initialize_momentum_net(self):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(param_q.data)  # update
            param_k.requires_grad = False  # not update by gradient

    # Code for this function from https://github.com/facebookresearch/moco
    @torch.no_grad()
    def update_momentum_net(self, momentum=0.999):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(momentum * param_k.data +
                               (1. - momentum) * param_q.data)  # update

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
Example #16
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.max_gradient_norm = args.max_gradient_norm

        self.policy_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            self.policy_net.load_state_dict(torch.load(args.model))
        self.policy_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        self.target_net.eval()

        self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr)

    def act(self, state, epsilon):
        if random.random() > epsilon:
            return self.policy_net(state.unsqueeze(0)).max(1)[1].data[0]
        else:
            return random.randint(0, self.action_space - 1)

    def learn(self, mem):
        transitions = mem.sample(self.batch_size)
        batch = Transition(*zip(*transitions))  # Transpose the batch

        states = Variable(torch.stack(batch.state, 0))
        actions = Variable(torch.LongTensor(batch.action).unsqueeze(1))
        rewards = Variable(torch.Tensor(batch.reward))
        non_final_mask = torch.ByteTensor(
            tuple(map(
                lambda s: s is not None,
                batch.next_state)))  # Only process non-terminal next states
        next_states = Variable(
            torch.stack(tuple(s for s in batch.next_state if s is not None),
                        0),
            volatile=True
        )  # Prevent backpropagating through expected action values

        Qs = self.policy_net(states).gather(1, actions)  # Q(s_t, a_t; θpolicy)
        next_state_argmax_indices = self.policy_net(next_states).max(
            1, keepdim=True
        )[1]  # Perform argmax action selection using policy network: argmax_a[Q(s_t+1, a; θpolicy)]
        Qns = Variable(torch.zeros(
            self.batch_size))  # Q(s_t+1, a) = 0 if s_t+1 is terminal
        Qns[non_final_mask] = self.target_net(next_states).gather(
            1, next_state_argmax_indices
        )  # Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)
        Qns.volatile = False  # Remove volatile flag to prevent propagating it through loss
        target = rewards + (
            self.discount * Qns
        )  # Double-Q target: Y = r + γ.Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)

        loss = F.smooth_l1_loss(
            Qs, target)  # Huber loss on TD-error δ: δ = Y - Q(s_t, a_t)
        # TODO: TD-error clipping?
        self.policy_net.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm(self.policy_net.parameters(),
                                self.max_gradient_norm)  # Clamp gradients
        self.optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path):
        torch.save(self.policy_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    def evaluate_q(self, state):
        return self.policy_net(state.unsqueeze(0)).max(1)[0].data[0]

    def train(self):
        self.policy_net.train()

    def eval(self):
        self.policy_net.eval()
Example #17
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self,
                 state_size,
                 action_size,
                 seed,
                 gamma=0.99,
                 step_size=1,
                 dueling_dqn=False):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)

        # Q-Network
        if dueling_dqn:
            print("Use dueling dqn")
            self.qnetwork_local = NoisyDuelingDQN(state_size, action_size,
                                                  seed).to(device)
            self.qnetwork_target = NoisyDuelingDQN(state_size, action_size,
                                                   seed).to(device)
        else:
            print("Use non-dueling dqn")
            self.qnetwork_local = DQN(state_size, action_size, seed).to(device)
            self.qnetwork_target = DQN(state_size, action_size,
                                       seed).to(device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0
        self.gamma = gamma
        self.step_size = step_size

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences)

    def act(self, state):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        return np.argmax(action_values.cpu().data.numpy())

    def learn(self, experiences):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Variable]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Compute and minimize loss
        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states
        ## gamma ^ step_size for nstep dqn
        Q_targets = rewards + (pow(self.gamma, self.step_size) *
                               Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
Example #18
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.detach())  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
Example #19
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Do the soft target update
                paramlist = list()
                for i, param in enumerate(model.parameters()):
                    paramlist.append(param)

                for i, tparam in enumerate(target.parameters()):
                    tparam.data.copy_(tau * paramlist[i].data +
                                      (1 - tau) * tparam.data)

            # Handle epsilon-greedy exploration
            state = torch.from_numpy(state).float().unsqueeze(0)
            model.eval()
            with torch.no_grad():
                Qsa = model(state)

            model.train()

            # Handle exploration/exploitation
            rand = random.uniform(0, 1)
            if rand < epsilon:  # Explore
                action = random.choice(np.arange(total_actions))  #TODO: change
            else:  # Exploit
                action = np.argmax(Qsa.data.numpy())

            # Get the next state
            next_state, reward, done, info = env.step(action)
            score += reward
Example #20
class Agent:
    """
    The intelligent agent of the simulation. Sets the neural network model used and its general parameters.
    It is responsible for selecting actions, optimizing the neural network and managing the models.
    """

    def __init__(self, action_set, train=True, load_path=None):
        #1. Initialize agent params
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.action_set = action_set
        self.action_number = len(action_set)
        self.steps_done = 0
        self.epsilon = Config.EPS_START
        self.episode_durations = []

        print('LOAD PATH    --  agent.init:', load_path)
        time.sleep(2)

        #2. Build networks
        self.policy_net = DQN().to(self.device)
        self.target_net = DQN().to(self.device)
        
        self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=Config.LEARNING_RATE)

        if not train:
            print('entered the "not train" branch')
            self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=0)    
            self.policy_net.load(load_path, optimizer=self.optimizer)
            self.policy_net.eval()

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.memory = ReplayMemory(1000)

    def select_action(self, state, train=True):
        """
        Select the best action according to the Q-values output by the neural network

        Parameters
        ----------
            state: float ndarray
                The current state on the simulation
            train: bool
                Define whether we are evaluating or training the model
        Returns
        -------
            a.max(1)[1]: int
                The action with the highest Q-value
            a.max(0): float
                The Q-value of the action taken
        """
        global steps_done
        sample = random.random()
        #1. Perform a epsilon-greedy algorithm
        #a. set the value for epsilon
        self.epsilon = Config.EPS_END + (Config.EPS_START - Config.EPS_END) * \
            math.exp(-1. * self.steps_done / Config.EPS_DECAY)
            
        self.steps_done += 1

        #b. make the decision for selecting a random action or selecting an action from the neural network
        if sample > self.epsilon or (not train):
            # select an action from the neural network
            with torch.no_grad():
                # a <- argmax Q(s, theta)
                a = self.policy_net(state)
                return a.max(1)[1].view(1, 1), a.max(0)
        else:
            # select a random action
            print('random action')
            return torch.tensor([[random.randrange(2)]], device=self.device, dtype=torch.long), None

    def optimize_model(self):
        """
        Perform one step of optimization on the neural network
        """

        if len(self.memory) < Config.BATCH_SIZE:
            return
        transitions = self.memory.sample(Config.BATCH_SIZE)

        # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for detailed explanation).
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                              batch.next_state)), device=self.device, dtype=torch.uint8)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                                    if s is not None])
        
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of actions taken
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        
    
        # Compute argmax Q(s', a; θ)        
        next_state_actions = self.policy_net(non_final_next_states).max(1)[1].detach().unsqueeze(1)

        # Compute Q(s', argmax Q(s', a; θ), θ-)
        next_state_values = torch.zeros(Config.BATCH_SIZE, device=self.device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).gather(1, next_state_actions).squeeze(1).detach()

        # Compute the expected Q values
        expected_state_action_values = (next_state_values * Config.GAMMA) + reward_batch


        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
        
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def save(self, step, logs_path, label):
        """
        Save the model on hard disc

        Parameters
        ----------
            step: int
                current step on the simulation
            logs_path: string
                path to where we will store the model
            label: string
                label that will be used to store the model
        """

        os.makedirs(logs_path + label, exist_ok=True)

        full_label = label + str(step) + '.pth'
        logs_path = os.path.join(logs_path, label, full_label)

        self.policy_net.save(logs_path, step=step, optimizer=self.optimizer)
    
    def restore(self, logs_path):
        """
        Load the model from hard disc

        Parameters
        ----------
            logs_path: string
                path to where we will store the model
        """
        self.policy_net.load(logs_path)
        self.target_net.load(logs_path)
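
The optimize_model above is a Double DQN update: the policy network selects argmax_a Q(s', a; θ) and the target network evaluates that action. As a compact, self-contained sketch of the same target computation (the function name and tensor layout are illustrative assumptions, not part of the example):

import torch

def double_dqn_targets(policy_net, target_net, rewards, non_final_mask,
                       non_final_next_states, gamma, batch_size, device):
    """Illustrative Double DQN target: r + gamma * Q(s', argmax_a Q(s', a; theta); theta-)."""
    next_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        # action selection with the online (policy) network
        next_actions = policy_net(non_final_next_states).max(1)[1].unsqueeze(1)
        # action evaluation with the target network
        next_values[non_final_mask] = target_net(non_final_next_states).gather(1, next_actions).squeeze(1)
    return rewards + gamma * next_values

Terminal next states keep a zero bootstrap value through the mask, just like the next_state_values buffer in the example.
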
Example #21
0
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, args.atoms)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.priority_exponent = args.priority_exponent
    self.max_gradient_norm = args.max_gradient_norm

    self.policy_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.policy_net.load_state_dict(torch.load(args.model))
    self.policy_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.eval()

    self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr, eps=args.adam_eps)
    if args.cuda:
      self.policy_net.cuda()
      self.target_net.cuda()
      self.support = self.support.cuda()

  # Resets noisy weights in all linear layers (of policy and target nets)
  def reset_noise(self):
    self.policy_net.reset_noise()
    self.target_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[1][0]

  def learn(self, mem):
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    batch_size = len(idxs)  # May return less than specified if invalid transitions sampled

    # Calculate current state probabilities
    ps = self.policy_net(states)  # Probabilities p(s_t, ·; θpolicy)
    ps_a = ps[range(batch_size), actions]  # p(s_t, a_t; θpolicy)

    # Calculate nth next state probabilities
    pns = self.policy_net(next_states).data  # Probabilities p(s_t+n, ·; θpolicy)
    dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θpolicy))
    argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using policy network: argmax_a[(z, p(s_t+n, a; θpolicy))]
    pns = self.target_net(next_states).data  # Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
    pns_a *= nonterminals  # Set p = 0 for terminal nth next states as all possible expected returns = expected reward at final transition

    # Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
    # Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()

    # Distribute probability of Tz
    m = states.data.new(batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).long().unsqueeze(1).expand(batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  # Cross-entropy loss (minimises Kullback-Leibler divergence)
    self.policy_net.zero_grad()
    (weights * loss).mean().backward()  # Importance weight losses
    nn.utils.clip_grad_norm(self.policy_net.parameters(), self.max_gradient_norm)  # Clip gradients (normalising by max value of gradient L2 norm)
    self.optimiser.step()

    mem.update_priorities(idxs, loss.data.abs().pow(self.priority_exponent))  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def save(self, path):
    torch.save(self.policy_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[0][0]

  def train(self):
    self.policy_net.train()

  def eval(self):
    self.policy_net.eval()
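
The learn method above follows the categorical (C51) update: the Bellman backup Tz = R^n + γ^n·z is clamped to [Vmin, Vmax] and its probability mass is split between the two nearest atoms l = floor(b) and u = ceil(b). A stripped-down sketch of that projection for a batch of next-state distributions (it omits the terminal mask and the l == u correction for brevity; all names are illustrative):

import torch

def project_distribution(p_next, rewards, gamma_n, support, v_min, v_max):
    """Illustrative L2 projection of Tz onto the fixed support z."""
    batch_size, atoms = p_next.shape
    delta_z = (v_max - v_min) / (atoms - 1)
    Tz = (rewards.unsqueeze(1) + gamma_n * support.unsqueeze(0)).clamp(v_min, v_max)
    b = (Tz - v_min) / delta_z                      # fractional atom index
    l, u = b.floor().long(), b.ceil().long()        # neighbouring atoms
    m = p_next.new_zeros(batch_size, atoms)
    offset = (torch.arange(batch_size, device=l.device) * atoms).unsqueeze(1)
    m.view(-1).index_add_(0, (l + offset).view(-1), (p_next * (u.float() - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1), (p_next * (b - l.float())).view(-1))
    return m  # target distribution for the cross-entropy loss
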
Example #22
0
class DQNAgent:
    """
    Interacts with and learns from the environment.
    Vanilla DQN.
    """
    def __init__(self, state_size: int, action_size: int, seed: int):
        """
        Initialize an Agent object.

        :param state_size: dimension of each state;
        :param action_size: dimension of each action;
        :param seed: random seed.
        """

        self.state_size = state_size
        self.action_size = action_size
        random.seed(seed)

        # Q-Network
        self.network_local = DQN(state_size, action_size, seed).to(DEVICE)
        self.network_target = DQN(state_size, action_size, seed).to(DEVICE)
        self.optimizer = optim.Adam(self.network_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, seed)

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action: int, reward: float, next_state, done):
        """
        Save experiences in the replay memory and check if it's time to learn.

        :param state: (array_like) current state;
        :param action: action taken;
        :param reward: reward received;
        :param next_state: (array_like) next state;
        :param done: terminal state indicator; int or bool.
        """

        # Save experience in replay memory
        self.memory.push(state, action, reward, next_state, done)

        # Increment time step and compare it to the network update frequency
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            # Check if there are enough samples in the memory to learn
            if len(self.memory) > BATCH_SIZE:
                # sample experiences from memory
                experiences = self.memory.sample()
                # learn from sampled experiences
                self.learn(experiences, GAMMA)

    def act(self, state, eps: float = 0.):
        """
        Returns actions for given state as per current policy.

        :param state: (array_like) current state
        :param eps: epsilon, for epsilon-greedy action selection
        """

        state = torch.from_numpy(state).float().unsqueeze(0).to(DEVICE)
        self.network_local.eval()
        with torch.no_grad():
            action_values = self.network_local(state)
        self.network_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma: float):
        """
        Update value parameters using given batch of experience tuples.

        :param experiences: (Tuple[torch.Tensor]) tuple of (s, a, r, s', done) tuples;
        :param gamma: discount factor.
        """

        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.network_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.network_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.network_local, self.network_target, TAU)

    @staticmethod
    def soft_update(local_model, target_model, tau: float):
        """
        Soft update model parameters,
        θ_target = τ*θ_local + (1 - τ)*θ_target.

        :param local_model: (PyTorch model) weights will be copied from;
        :param target_model: (PyTorch model) weights will be copied to;
        :param tau: interpolation parameter.
        """

        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
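
The soft_update above is Polyak averaging, θ_target ← τ·θ_local + (1 − τ)·θ_target, applied parameter by parameter. An equivalent in-place sketch (a drop-in alternative, assuming both models expose parameters in the same order, which state_dict-compatible models do):

import torch

def polyak_update(local_model: torch.nn.Module, target_model: torch.nn.Module, tau: float) -> None:
    """Illustrative soft update: theta_target <- tau * theta_local + (1 - tau) * theta_target."""
    with torch.no_grad():
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            # in-place update avoids allocating intermediate tensors
            target_param.mul_(1.0 - tau).add_(local_param, alpha=tau)

With tau = 1.0 this degenerates to the hard copy that the update_target_net-style methods elsewhere in these examples perform.
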
Example #23
0
class Agent:
  state: int
  actions: int
  history: int = 4
  atoms: int = 5 #51
  Vmin: float = -10
  Vmax: float = 10
  
  lr: float = 1e-5
  batch_size: int = 32
  discount: float = 0.99
  norm_clip: float = 10.

  def __post_init__(self):
    self.support = torch.linspace(self.Vmin, self.Vmax, self.atoms)
    self.delta_z = (self.Vmax - self.Vmin) / (self.atoms - 1)

    self.online_net = DQN(self.state, self.actions, self.history, self.atoms)
    self.online_net.train()

    self.target_net = DQN(self.state, self.actions, self.history, self.atoms)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters(): param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=self.lr)

  def act(self, state):
    state = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
      return (self.online_net(state) * self.support).sum(2).argmax(1).item()

  def act_e_greedy(self, state, epsilon=0.001):
    return random.randrange(self.actions) if random.random() < epsilon else self.act(state)

  def learn(self, buffer):
    state, action, reward, next_state, terminal, weights, idx = buffer.sample(self.batch_size)
    state = torch.FloatTensor(state)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    next_state = torch.FloatTensor(next_state)
    terminal = torch.FloatTensor(terminal)
    weights = torch.FloatTensor(weights)

    log_ps = self.online_net(state, log=True)
    log_ps_a = log_ps[range(self.batch_size), action]

    with torch.no_grad():
      # Calculate nth next state probabilities
      pns = self.online_net(next_state)
      dns = self.support.expand_as(pns) * pns
      argmax_indices_ns = dns.sum(2).argmax(1)
      self.target_net.sample_noise()
      pns = self.target_net(next_state)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]

      # Compute Bellman operator T applied to z
      Tz = reward.unsqueeze(1) + (1 - terminal).unsqueeze(1) * self.discount * self.support.unsqueeze(0) # -10 ... 10 + reward
      Tz.clamp_(min=self.Vmin, max=self.Vmax)
      
      # Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z # 0 ... 4
      l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
      # Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      # Distribute probability of Tz
      m = state.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(action)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    loss = weights * loss

#     q_values = self.online_net(state)
#     q_value = q_values[range(self.batch_size), action]

#     next_q_values = self.target_net(next_state)
#     next_q_value = next_q_values.max(1)[0]

#     expected_q_value = reward + self.discount * next_q_value * (1 - terminal)
#     loss = weights * (q_value - expected_q_value).pow(2)

    self.optimiser.zero_grad()
    loss.mean().backward()
    nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # clip before the optimiser step so it actually bounds the update
    self.optimiser.step()

    buffer.update_priorities(idx, loss.tolist())

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  def sample_noise(self):
    self.online_net.sample_noise()

  def save(self, path):
    torch.save(self.online_net.state_dict(), path)

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return self.online_net(state.unsqueeze(0)).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
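
The two index adjustments in learn (l[...] -= 1 and u[...] += 1) exist because when b lands exactly on an atom, l == u and both index_add_ terms are zero, so probability mass would silently vanish. A tiny numeric check of that edge case (values chosen purely for illustration):

import torch

atoms = 5
p = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])   # all next-state mass on one atom
b = torch.full((atoms,), 2.0)                  # projected index lands exactly on atom 2
l, u = b.floor().long(), b.ceil().long()       # l == u == 2 everywhere

m = torch.zeros(atoms)
m.index_add_(0, l, p * (u.float() - b))        # adds 0, since u - b == 0
m.index_add_(0, u, p * (b - l.float()))        # adds 0, since b - l == 0
print(m.sum())                                 # tensor(0.) -- mass disappeared

# the same guard used in learn() above
l[(u > 0) & (l == u)] -= 1
u[(l < (atoms - 1)) & (l == u)] += 1
m = torch.zeros(atoms)
m.index_add_(0, l, p * (u.float() - b))
m.index_add_(0, u, p * (b - l.float()))
print(m.sum())                                 # tensor(1.) -- mass preserved
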
Example #24
0
def test(env):
    # run greedy evaluation episodes until interrupted
    while True:
        state = get_state(env.reset()).to(device)
        while True:
            with torch.no_grad():
                action = policy_net(state).max(1)[1].view(1, 1)
            next_state, _, done, _ = env.step(action)

            if done:
                break
            next_state = get_state(next_state).to(device)
            state = next_state

            
if __name__ == '__main__': 
    # enable cuda if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialise the game
    env = gym.make('ChromeDino-v0')
    # env = gym.make('ChromeDinoNoBrowser-v0')
    env = make_dino(env, timer=True, frame_stack=True)
    # Get the number of actions and the dimension of input
    n_actions = env.action_space.n
    
    # initialise networks 
    policy_net = DQN(n_actions=n_actions).to(device)
    trained_model = torch.load('checkpoints/model_2000.pth')
    policy_net.load_state_dict(trained_model)
    policy_net.eval() 
    
    test(env)
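
The checkpoint here is loaded without a map_location, which only works when the file was saved on a device type that is also available at test time. A slightly more defensive variant of the same load (a sketch reusing the example's own DQN, n_actions and checkpoint path; nothing new is assumed beyond torch.load's map_location argument):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = DQN(n_actions=n_actions).to(device)
# map_location remaps GPU-saved tensors onto whatever device is actually present
state_dict = torch.load('checkpoints/model_2000.pth', map_location=device)
policy_net.load_state_dict(state_dict)
policy_net.eval()
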
Example #25
0
class Agent(object):
    def __init__(self, args, action_space):
        self.action_space = action_space
        self.batch_size = args.batch_size
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)
        self.loss_func = nn.MSELoss()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return self.online_net([state]).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.05):  # High ε can reduce evaluation scores drastically
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):

        # Sample transitions
        states, actions, next_states, rewards = mem.sample(self.batch_size)

        q_eval = self.online_net(states).gather(
            1, actions.unsqueeze(1)).squeeze()
        with torch.no_grad():
            q_eval_next_a = self.online_net(next_states).argmax(1)
            q_next = self.target_net(next_states)
            q_target = rewards + self.discount * q_next.gather(
                1, q_eval_next_a.unsqueeze(1)).squeeze()

        loss = self.loss_func(q_eval, q_target)
        self.online_net.zero_grad()
        loss.backward()
        self.optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(), path + '.pth')

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net([state])).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
Example #26
0
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        # self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps)
        self.convs_optimiser = optim.Adam(self.online_net.convs.parameters(),
                                          lr=args.learning_rate,
                                          eps=args.adam_eps)
        self.linear_optimiser = optim.Adam(chain(
            self.online_net.fc_h_v.parameters(),
            self.online_net.fc_h_a.parameters(),
            self.online_net.fc_z_v.parameters(),
            self.online_net.fc_z_a.parameters()),
                                           lr=args.learning_rate,
                                           eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):

        with torch.no_grad():
            # don't count these calls since it is accounted for after "action = dqn.act(state)" in main.py
            ret = (self.online_net(state.unsqueeze(0)) *
                   self.support).sum(2).argmax(1).item()
            return ret

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem, freeze=False):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, _ = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        if not freeze:
            self.convs_optimiser.step()
        self.linear_optimiser.step()

    def learn_with_latent(self, latent_mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, ns = latent_mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net.forward_with_latent(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)
        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution ds_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # use ns instead of self.n since n is possibly different for each sequence in the batch
            ns = torch.tensor(ns, device=latent_mem.device).unsqueeze(1)
            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**ns) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        self.linear_optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
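
learn_with_latent replaces the fixed horizon self.n with a per-sample ns because sequences drawn near an episode boundary can contain fewer than multi_step transitions. A rough sketch of how a truncated n-step return and its actual length might be computed from a trajectory slice (purely illustrative; not the replay memory's real API):

def n_step_return(rewards, gamma, n):
    """Illustrative truncated n-step return: R = sum_{k < m} gamma**k * r_k with m = min(n, len(rewards))."""
    m = min(n, len(rewards))
    R = 0.0
    for k in range(m):
        R += (gamma ** k) * rewards[k]
    return R, m   # m plays the role of 'ns' in the gamma**ns bootstrap term above
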
Example #27
0
class DQNAgent:
    """
        初始化
        @:param env_id : gym环境id
    """

    def __init__(self, env_id, config):
        # gym
        self._env_id = env_id
        self._env = gym.make(env_id)
        self._state_size = self._env.observation_space.shape[0]
        self._action_size = self._env.action_space.n
        # hyper-parameters
        self._gamma = config.gamma
        self._learning_rate = config.lr
        self._reward_boundary = config.reward_boundary
        self._device = torch.device("cuda" if config.cuda and torch.cuda.is_available() else "cpu")
        # model
        self._model = DQN(self._state_size, self._action_size).to(self._device)
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._learning_rate)
        # experience replay buffer
        self._replay_buffer = deque(maxlen=config.buffer_size)
        self._mini_batch = config.mini_batch
        # epsilon
        self._epsilon = config.epsilon
        self._epsilon_min = config.epsilon_min
        self._epsilon_decay = config.epsilon_decay

    """
        将observation放入双向队列中,队列满时自动删除最旧的元素
    """

    def remember(self, state, action, next_state, reward, done):
        self._replay_buffer.append((state, action, next_state, reward, done))

        # exponential decay of epsilon
        if len(self._replay_buffer) > self._mini_batch:
            if self._epsilon > self._epsilon_min:
                self._epsilon *= self._epsilon_decay
        pass

    """
        epsilon-greedy action
    """

    def act(self, state):
        # similar to simulated annealing; random.random() returns a value in [0, 1)
        if np.random.random() <= self._epsilon:
            return random.randrange(self._action_size)
        else:
            # convert the numpy array to a tensor; unsqueeze adds a batch dimension at index 0
            state = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self._device)
            # model prediction
            predict = self._model(state)
            # take the max over dim 1: [1] is the index, [0] is the value, e.g. [512, 2] -> [512]
            return predict.max(1)[1].item()
        pass

    """
        训练
        1、从双向队列中采样mini_batch
        2、预测next_state
        3、更新优化器
    """

    def replay(self):
        if len(self._replay_buffer) < self._mini_batch:
            return
        # 1. Sample a mini_batch from the deque
        mini_batch = random.sample(self._replay_buffer, self._mini_batch)

        # Loading approach 1
        # state = np.zeros((self._mini_batch, self._state_size))
        # next_state = np.zeros((self._mini_batch, self._state_size))
        # action, reward, done = [], [], []
        #
        # for i in range(self._mini_batch):
        #     state[i] = mini_batch[i][0]
        #     action.append(mini_batch[i][1])
        #     next_state[i] = mini_batch[i][2]
        #     reward.append(mini_batch[i][3])
        #     done.append(mini_batch[i][4])

        # Loading approach 2
        state, action, next_state, reward, done = zip(*mini_batch)
        state = torch.tensor(state, dtype=torch.float).to(self._device)
        action = torch.tensor(action, dtype=torch.long).to(self._device)
        next_state = torch.tensor(next_state, dtype=torch.float).to(self._device)
        reward = torch.tensor(reward, dtype=torch.float).to(self._device)
        done = torch.tensor(done, dtype=torch.float).to(self._device)

        # 2. Predict next_state (detach the bootstrap term so gradients only flow through q_values)
        q_target = reward + \
                   self._gamma * self._model(next_state).to(self._device).max(1)[0].detach() * (1 - done)

        q_values = self._model(state).to(self._device).gather(1, action.unsqueeze(1)).squeeze(1)
        loss_func = nn.MSELoss()
        loss = loss_func(q_values, q_target)
        # loss = (q_values - q_target.detach()).pow(2).mean()

        # 3. Update the optimizer
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        return loss.item()

    """
        1、渲染gym环境开始交互
        2、训练模型
    """

    def training(self):
        writer = SummaryWriter(comment="-train-" + self._env_id)
        print(self._model)

        # parameters
        frame_index = 0
        episode_index = 1
        best_mean_reward = None
        mean_reward = 0
        total_rewards = []

        while mean_reward < self._reward_boundary:

            state = self._env.reset()
            # new episode: reset the episode reward to zero
            episode_reward = 0

            while True:
                # 1. Render the gym environment and interact
                self._env.render()

                # Select an action and interact
                action = self.act(state)
                next_state, reward, done, _ = self._env.step(action)
                self.remember(state, action, next_state, reward, done)
                state = next_state
                frame_index += 1
                episode_reward += reward

                # 2. Train the model
                loss = self.replay()

                # Episode over: log progress and save the best model
                if done:
                    if loss is not None:
                        print("episode: %4d, frames: %5d, reward: %5f, loss: %4f, epsilon: %4f" % (
                            episode_index, frame_index, np.mean(total_rewards[-10:]), loss, self._epsilon))

                    episode_index += 1
                    total_rewards.append(episode_reward)
                    mean_reward = np.mean(total_rewards[-10:])

                    writer.add_scalar("epsilon", self._epsilon, frame_index)
                    writer.add_scalar("episode_reward", episode_reward, frame_index)
                    writer.add_scalar("mean_reward", mean_reward, frame_index)
                    if best_mean_reward is None or best_mean_reward < mean_reward:
                        best_mean_reward = mean_reward
                        torch.save(self._model.state_dict(), "training-best.dat")
                    break

        self._env.close()
        pass

    def test(self, model_path):
        if model_path is None:
            return
        self._model.load_state_dict(torch.load(model_path))
        self._model.eval()

        total_rewards = []

        for episode_index in range(10):
            episode_reward = 0
            done = False
            state = self._env.reset()

            while not done:
                action = self.act(state)
                next_state, reward, done, _ = self._env.step(action)

                state = next_state
                episode_reward += reward

            total_rewards.append(episode_reward)
            print("episode: %4d, reward: %5f" % (episode_index, np.mean(total_rewards[-10:])))
class Agent:
	"""
	The intelligent agent of the simulation. Sets the neural network model used and the general parameters.
	It is responsible for selecting actions, optimizing the neural network and managing the models.
	"""

	def __init__(self, action_set, train=True, load_path=None):
		#1. Initialize agent params
		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		self.action_set = action_set
		self.action_number = len(action_set)
		self.steps_done = 0
		self.epsilon = Config.EPS_START
		self.episode_durations = []

		#2. Build networks
		self.policy_net = DQN().to(self.device)
		self.target_net = DQN().to(self.device)
		
		self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=Config.LEARNING_RATE)

		if not train:		
			self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=0)	
			self.policy_net.load(load_path, optimizer=self.optimizer)
			self.policy_net.eval()

		self.target_net.load_state_dict(self.policy_net.state_dict())
		self.target_net.eval()

		#3. Create Prioritized Experience Replay Memory
		self.memory = Memory(Config.MEMORY_SIZE)


	 
	def append_sample(self, state, action, next_state, reward):
		"""
		Save the sample (error, <s, a, s', r>) to the prioritized replay memory
		"""

		# Determine whether this is the end of the simulation
		done = next_state is None

		# Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of actions taken
		state_action_values = self.policy_net(state)
		state_action_values = state_action_values.gather(1, action.view(-1,1))

		
		if not done:
			# Compute argmax Q(s', a; θ)		
			next_state_actions = self.policy_net(next_state).max(1)[1].detach().unsqueeze(1)

			# Compute Q(s', argmax Q(s', a; θ), θ-)
			next_state_values = self.target_net(next_state).gather(1, next_state_actions).squeeze(1).detach()

			# Compute the expected Q values
			expected_state_action_values = (next_state_values * Config.GAMMA) + reward
		else:
			expected_state_action_values = reward


		error = abs(state_action_values - expected_state_action_values).data.cpu().numpy()


		self.memory.add(error, state, action, next_state, reward)

	def select_action(self, state, train=True):
		"""
		Select the best action according to the Q-values output by the neural network

		Parameters
		----------
			state: float ndarray
				The current state of the simulation
			train: bool
				Defines whether we are evaluating or training the model
		Returns
		-------
			a.max(1)[1]: int
				The action with the highest Q-value
			a.max(0): float
				The Q-value of the action taken
		"""
		global steps_done
		sample = random.random()
		#1. Perform an epsilon-greedy algorithm
		#a. set the value for epsilon
		self.epsilon = Config.EPS_END + (Config.EPS_START - Config.EPS_END) * \
			math.exp(-1. * self.steps_done / Config.EPS_DECAY)
			
		self.steps_done += 1

		#b. decide whether to select a random action or an action from the neural network
		if sample > self.epsilon or (not train):
			# select an action from the neural network
			with torch.no_grad():
				# a <- argmax Q(s, theta)
				a = self.policy_net(state)
				return a.max(1)[1].view(1, 1), a.max(0)
		else:
			# select a random action
			print('random action')
			return torch.tensor([[random.randrange(2)]], device=self.device, dtype=torch.long), None

	"""
	def select_action(self, state, train=True):
		
		Select the best action according to the Q-values output by the neural network

		Parameters
		----------
			state: float ndarray
				The current state of the simulation
			train: bool
				Defines whether we are evaluating or training the model
		Returns
		-------
			a.max(1)[1]: int
				The action with the highest Q-value
			a.max(0): float
				The Q-value of the action taken
		
		global steps_done
		sample = random.random()
		#1. Perform an epsilon-greedy algorithm
		#a. set the value for epsilon
		self.epsilon = Config.EPS_END + (Config.EPS_START - Config.EPS_END) * \
			math.exp(-1. * self.steps_done / Config.EPS_DECAY)
			
		self.steps_done += 1

		#b. decide whether to select a random action or an action from the neural network
		if sample > self.epsilon or (not train):
			# select an action from the neural network
			with torch.no_grad():
				# a <- argmax Q(s, theta)
				# setting the network to train mode is important to enable dropout
				self.policy_net.train()
				output_list = []
				# Retrieve the outputs from n stochastic forward passes to build a statistical estimate
				for i in range(Config.STOCHASTIC_PASSES):
					#print(agent.policy_net(data))
					output_list.append(torch.unsqueeze(F.softmax(self.policy_net(state)), 0))
					#print(output_list[i])

				self.policy_net.eval()
				# The result of the network is the mean of n passes
				output_mean = torch.cat(output_list, 0).mean(0)
				q_value = output_mean.data.cpu().numpy().max()
				action = output_mean.max(1)[1].view(1, 1)

				uncertainty = torch.cat(output_list, 0).var(0).mean().item()
				
				return action, q_value, uncertainty
				
		else:
			# select a random action
			print('random action')
			return torch.tensor([[random.randrange(2)]], device=self.device, dtype=torch.long), None, None

	"""
	def optimize_model(self):
		"""
		Perform one step of optimization on the neural network
		"""

		if self.memory.tree.n_entries < Config.BATCH_SIZE:
			return
		transitions, idxs, is_weights = self.memory.sample(Config.BATCH_SIZE)

		# Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for detailed explanation).
		batch = Transition(*zip(*transitions))

		# Compute a mask of non-final states and concatenate the batch elements
		non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
											  batch.next_state)), device=self.device, dtype=torch.uint8)
		non_final_next_states = torch.cat([s for s in batch.next_state
													if s is not None])
		
		state_batch = torch.cat(batch.state)
		action_batch = torch.cat(batch.action)
		reward_batch = torch.cat(batch.reward)
		
		# Compute Q(s_t, a) - the model computes Q(s_t), then we select the columns of actions taken
		state_action_values = self.policy_net(state_batch).gather(1, action_batch)
		
	
		# Compute argmax Q(s', a; θ)		
		next_state_actions = self.policy_net(non_final_next_states).max(1)[1].detach().unsqueeze(1)

		# Compute Q(s', argmax Q(s', a; θ), θ-)
		next_state_values = torch.zeros(Config.BATCH_SIZE, device=self.device)
		next_state_values[non_final_mask] = self.target_net(non_final_next_states).gather(1, next_state_actions).squeeze(1).detach()

		# Compute the expected Q values
		expected_state_action_values = (next_state_values * Config.GAMMA) + reward_batch

		# Compute the TD errors used to update priorities
		errors = torch.abs(state_action_values.squeeze() - expected_state_action_values).data.cpu().numpy()

		# update the priority of each sampled transition
		for i in range(Config.BATCH_SIZE):
			idx = idxs[i]
			self.memory.update(idx, errors[i])


		# Compute Huber loss
		loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
		loss_return = loss.item()

		# Optimize the model
		self.optimizer.zero_grad()
		loss.backward()
		for param in self.policy_net.parameters():
			param.grad.data.clamp_(-1, 1)
		self.optimizer.step()

		return loss_return

	def save(self, step, logs_path, label):
		"""
		Save the model on hard disc

		Parameters
		----------
			step: int
				current step on the simulation
			logs_path: string
				path to where we will store the model
			label: string
				label that will be used to store the model
		"""

		os.makedirs(logs_path + label, exist_ok=True)

		full_label = label + str(step) + '.pth'
		logs_path = os.path.join(logs_path, label, full_label)

		self.policy_net.save(logs_path, step=step, optimizer=self.optimizer)
	
	def restore(self, logs_path):
		"""
		Load the model from hard disc

		Parameters
		----------
			logs_path: string
				path to where we will store the model
		"""
		self.policy_net.load(logs_path)
		self.target_net.load(logs_path)
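
append_sample and optimize_model together drive a proportional prioritized replay: new transitions enter the memory with priority |δ| (the absolute TD error) and the priorities of sampled transitions are refreshed after each batch. A rough sketch of how such a memory commonly maps errors to sampling probabilities (the exponent alpha, the epsilon floor and the class itself are assumptions, not this example's Memory implementation):

import numpy as np

class ProportionalPriority:
    """Illustrative proportional prioritization: p_i = (|delta_i| + eps)**alpha."""

    def __init__(self, alpha=0.6, eps=1e-2):
        self.alpha = alpha   # how strongly the TD error shapes the sampling distribution
        self.eps = eps       # small floor so zero-error transitions can still be sampled

    def priority(self, td_error):
        return (abs(td_error) + self.eps) ** self.alpha

    def probabilities(self, priorities):
        # P(i) = p_i / sum_j p_j, used when drawing a batch (e.g. from a sum tree)
        priorities = np.asarray(priorities, dtype=np.float64)
        return priorities / priorities.sum()
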
Example #29
0
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      args.atoms)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)
        if args.cuda:
            self.online_net.cuda()
            self.target_net.cuda()
            self.support = self.support.cuda()

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[1][0]

    # Acts with an ε-greedy policy
    def act_e_greedy(self, state, epsilon=0.001):
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities
        self.online_net.reset_noise()  # Sample new noise for online network
        ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
        ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

        # Calculate nth next state probabilities
        self.online_net.reset_noise()  # Sample new noise for action selection
        pns = self.online_net(
            next_states).data  # Probabilities p(s_t+n, ·; θonline)
        dns = self.support.expand_as(
            pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
        argmax_indices_ns = dns.sum(2).max(
            1
        )[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
        self.target_net.reset_noise()  # Sample new target net noise
        pns = self.target_net(
            next_states).data  # Probabilities p(s_t+n, ·; θtarget)
        pns_a = pns[range(
            self.batch_size
        ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (
            self.discount**self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin,
                      max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().long(), b.ceil().long()
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.data.new(self.batch_size, self.atoms).zero_()
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                self.batch_size).unsqueeze(1).expand(
                                    self.batch_size,
                                    self.atoms).type_as(actions)
        m.view(-1).index_add_(
            0, (l + offset).view(-1),
            (pns_a *
             (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(
            0, (u + offset).view(-1),
            (pns_a *
             (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
        loss = -torch.sum(
            Variable(m) * ps_a.log(),
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Importance weight losses
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.data)  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[0][0]

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
class Agent(object):
    """ all improvments from Rainbow research work
    """
    def __init__(self, args, state_size, action_size):
        """
        Args:
           param1 (args): hyper-parameter configuration
           param2 (int): state_size
           param3 (int): action_size
        """
        self.action_size = action_size
        self.state_size = state_size
        self.atoms = args.atoms
        self.V_min = args.V_min
        self.V_max = args.V_max
        self.device = args.device
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=self.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.qnetwork_local = DQN(args, self.state_size,
                                  self.action_size).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.qnetwork_local.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.qnetwork_local.train()

        self.target_net = DQN(args, self.state_size,
                              self.action_size).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    def reset_noise(self):
        """ resets noisy weights in all linear layers """
        self.qnetwork_local.reset_noise()

    def act(self, state):
        """
          Acts greedily (argmax) based on a single state
          Args:
             param1 (torch.Tensor) : state
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0).to(self.device)) *
                    self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        """ acts with epsilon greedy policy
            epsilon exploration vs exploitation traide off
        Args:
            param1(int): state
            param2(float): epsilon
        Return : action int number between 0 and 4
        """
        return np.random.randint(
            0, self.action_size) if np.random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        """ uses samples with the given batch size to improve the Q function
        Args:
            param1 (Experince Replay Buffer) : mem
        """
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.qnetwork_local(
            states, log=True)  # Log probabilities log p(s_t, *; theta online)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; theta online)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.qnetwork_local(
                next_states)  # Probabilities p(s_t+n, *; theta online)
            dns = self.support.expand_as(
                pns
            ) * pns  # Distribution d_t+n = (z, p(s_t+n, *; theta online))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; theta online))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, *; theta target)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; theta online))]; theta target)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n
            ) * self.support.unsqueeze(
                0)  # Tz = R^n + (discount^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.V_min,
                          max=self.V_max)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.V_min) / self.delta_z  # b = (Tz - Vmin) / delta z
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.qnetwork_local.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimizer.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions
        self.soft_update()

    def soft_update(self, tau=1e-3):
        """ swaps the network weights from the online to the target

        Args:
           param1 (float): tau
        """
        for target_param, local_param in zip(self.target_net.parameters(),
                                             self.qnetwork_local.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def update_target_net(self):
        """ copy the model weights from the online to the target network """
        self.target_net.load_state_dict(self.qnetwork_local.state_dict())

    def save(self, path):
        """ save the model weights to a file
        Args:
           param1 (string): pathname
        """
        torch.save(self.qnetwork_local.state_dict(),
                   os.path.join(path, 'model.pth'))

    def evaluate_q(self, state):
        """ Evaluates Q-value based on single state
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        """
        puts the online network in training mode (enables layers such as dropout and noisy layers)
        """
        self.qnetwork_local.train()

    def eval(self):
        """ invoke the eval from the online network
            deactivates the backprob
            layers like dropout will work in eval model instead
        """
        self.qnetwork_local.eval()