Example No. 1
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.detach())  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
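The heart of `learn` above is the categorical (C51) projection of the Bellman-updated atoms Tz = R^n + (γ^n)z back onto the fixed support. A minimal standalone sketch of just that step, with shapes spelled out; the function and argument names are illustrative, not part of the listing:

import torch


def project_distribution(returns, pns_a, nonterminals, support, discount, n,
                         Vmin, Vmax):
    # returns: (batch,), pns_a: (batch, atoms), nonterminals: (batch, 1) float mask
    batch_size, atoms = pns_a.shape
    delta_z = (Vmax - Vmin) / (atoms - 1)

    # Bellman update of each atom, clamped to the supported range
    Tz = returns.unsqueeze(1) + nonterminals * (discount**n) * support.unsqueeze(0)
    Tz = Tz.clamp(min=Vmin, max=Vmax)

    # Fractional index of each updated atom on the fixed support
    b = (Tz - Vmin) / delta_z
    l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
    l[(u > 0) * (l == u)] -= 1  # avoid losing mass when b is exactly an integer
    u[(l < (atoms - 1)) * (l == u)] += 1

    # Spread each atom's probability mass onto its two neighbouring support points
    m = returns.new_zeros(batch_size, atoms)
    offset = torch.arange(0, batch_size * atoms, atoms,
                          device=returns.device).unsqueeze(1)
    m.view(-1).index_add_(0, (l + offset).view(-1),
                          (pns_a * (u.float() - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1),
                          (pns_a * (b - l.float())).view(-1))
    return m  # target distribution used in the cross-entropy loss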
Example No. 2
class Agent():
    def __init__(self, args, env):
        self.args = args
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip
        self.coeff = 0.01 if args.game in [
            'pong', 'boxing', 'private_eye', 'freeway'
        ] else 1.

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        self.momentum_net = DQN(args, self.action_space).to(device=args.device)
        # self.predictor = prediction_MLP(in_dim=128, hidden_dim=128, out_dim=128)

        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()
        # self.pred.train()
        self.initialize_momentum_net()
        self.momentum_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        for param in self.momentum_net.parameters():
            param.requires_grad = False
        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.learning_rate,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)

        _, z_1, p_1 = self.online_net(aug_states_1, log=True)
        _, z_2, p_2 = self.online_net(aug_states_2, log=True)
        # p_1, p_2 = self.pred(z_1), self.pred(z_2)

        # with torch.no_grad():
        #     p_2 = self.pred(z_2)

        simsiam_loss = 2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2
        # simsiam_loss = p_1.mean() + p_2.mean()
        # simsiam_loss = p_1.mean() * 128
        # simsiam_loss = - F.cosine_similarity(p_1, z_2.detach(), dim=-1).mean()
        # print(simsiam_loss)
        # simsiam_loss = 0

        # _, z_target = self.momentum_net(aug_states_2, log=True) #z_k
        # z_proj = torch.matmul(self.online_net.W, z_target.T)
        # logits = torch.matmul(z_anch, z_proj)
        # logits = (logits - torch.max(logits, 1)[0][:, None])
        # logits = logits * 0.1
        # labels = torch.arange(logits.shape[0]).long().to(device=self.args.device)
        # moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_1: {z_1.shape}')
        # print(f'p_1: {p_1.shape}')
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        # loss = loss + (moco_loss * self.coeff)
        loss = loss + (simsiam_loss * self.coeff)
        self.online_net.zero_grad()
        # self.pred.zero_grad()
        curl_loss = (weights * loss).mean()
        # print(curl_loss)
        curl_loss.backward()  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def learn_old(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        _, z_anch, _ = self.online_net(aug_states_1, log=True)  #z_q
        _, z_target, _ = self.momentum_net(aug_states_2, log=True)  #z_k
        z_proj = torch.matmul(self.online_net.W, z_target.T)
        logits = torch.matmul(z_anch, z_proj)
        logits = (logits - torch.max(logits, 1)[0][:, None])
        logits = logits * 0.1
        labels = torch.arange(
            logits.shape[0]).long().to(device=self.args.device)
        moco_loss = (nn.CrossEntropyLoss()(logits,
                                           labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_anch: {z_anch.shape}')
        # print(f'z_target: {z_target.shape}')
        # print(f'z_proj: {z_proj.shape}')
        # print(f'logits: {logits.shape}')
        # print(logits)
        # print(f'labels: {labels.shape}')
        # print(labels)
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        print(moco_loss)
        loss = loss + (moco_loss * self.coeff)
        self.online_net.zero_grad()
        curl_loss = (weights * loss).mean()
        curl_loss.backward()  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def initialize_momentum_net(self):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(param_q.data)  # update
            param_k.requires_grad = False  # not update by gradient

    # Code for this function from https://github.com/facebookresearch/moco
    @torch.no_grad()
    def update_momentum_net(self, momentum=0.999):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(momentum * param_k.data +
                               (1. - momentum) * param_q.data)  # update

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
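Example No. 2 depends on two helpers that are not part of the listing: `aug` (presumably an image augmentation applied to the sampled states) and `D`. A plausible definition of `D`, consistent with the commented-out line `- F.cosine_similarity(p_1, z_2.detach(), dim=-1).mean()` above, is the SimSiam negative cosine similarity with a stop-gradient on the target projection:

import torch.nn.functional as F


def D(p, z):
    # Negative cosine similarity; z is detached so gradients flow only through p
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

With D in [-1, 1], the term `2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2` in `learn` is the symmetrised SimSiam loss shifted so that it stays non-negative.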
Example No. 3
class Agent(object):
    def __init__(self, args, action_space):
        self.action_space = action_space
        self.batch_size = args.batch_size
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)
        self.loss_func = nn.MSELoss()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return self.online_net([state]).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.05):  # High ε can reduce evaluation scores drastically
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):

        # Sample transitions
        states, actions, next_states, rewards = mem.sample(self.batch_size)

        q_eval = self.online_net(states).gather(
            1, actions.unsqueeze(1)).squeeze()
        with torch.no_grad():
            q_eval_next_a = self.online_net(next_states).argmax(1)
            q_next = self.target_net(next_states)
            q_target = rewards + self.discount * q_next.gather(
                1, q_eval_next_a.unsqueeze(1)).squeeze()

        loss = self.loss_func(q_eval, q_target)
        self.online_net.zero_grad()
        loss.backward()
        self.optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(), path + '.pth')

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net([state])).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
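Example No. 3 is plain Double DQN with an MSE loss: the online network picks the greedy next action and the target network evaluates it. A minimal standalone sketch of that target computation; the explicit `nonterminals` mask here is an addition for illustration, since the listing above leaves terminal handling to the replay memory:

import torch


@torch.no_grad()
def double_q_target(online_net, target_net, rewards, next_states, nonterminals,
                    discount):
    # Y = r + gamma * Q_target(s', argmax_a Q_online(s', a))
    best_actions = online_net(next_states).argmax(1, keepdim=True)  # selection
    q_next = target_net(next_states).gather(1, best_actions).squeeze(1)  # evaluation
    return rewards + nonterminals * discount * q_next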
Example No. 4
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        # self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps)
        self.convs_optimiser = optim.Adam(self.online_net.convs.parameters(),
                                          lr=args.learning_rate,
                                          eps=args.adam_eps)
        self.linear_optimiser = optim.Adam(chain(
            self.online_net.fc_h_v.parameters(),
            self.online_net.fc_h_a.parameters(),
            self.online_net.fc_z_v.parameters(),
            self.online_net.fc_z_a.parameters()),
                                           lr=args.learning_rate,
                                           eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):

        with torch.no_grad():
            # don't count these calls since it is accounted for after "action = dqn.act(state)" in main.py
            ret = (self.online_net(state.unsqueeze(0)) *
                   self.support).sum(2).argmax(1).item()
            return ret

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem, freeze=False):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, _ = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        if not freeze:
            self.convs_optimiser.step()
        self.linear_optimiser.step()

    def learn_with_latent(self, latent_mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, ns = latent_mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net.forward_with_latent(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)
        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # use ns instead of self.n since n is possibly different for each sequence in the batch
            ns = torch.tensor(ns, device=latent_mem.device).unsqueeze(1)
            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**ns) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        self.linear_optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
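Example No. 4 keeps two optimisers (`convs_optimiser` and `linear_optimiser`) so the convolutional trunk can be skipped when `freeze=True`, or when training from latents in `learn_with_latent`. A sketch of an equivalent single optimiser built from Adam parameter groups, assuming the same attribute names on the `DQN` module as used above:

from itertools import chain

import torch.optim as optim


def build_optimiser(net, lr, adam_eps):
    head_params = chain(net.fc_h_v.parameters(), net.fc_h_a.parameters(),
                        net.fc_z_v.parameters(), net.fc_z_a.parameters())
    return optim.Adam(
        [
            {'params': net.convs.parameters()},  # group 0: convolutional trunk
            {'params': list(head_params)},  # group 1: dueling heads
        ],
        lr=lr,
        eps=adam_eps)

Freezing the trunk then amounts to setting `optimiser.param_groups[0]['lr'] = 0` for the duration of latent training, instead of maintaining two separate optimisers.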
Example No. 5
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.batch_size = args.batch_size
        self.discount = args.discount
        self.max_gradient_norm = args.max_gradient_norm

        self.policy_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            self.policy_net.load_state_dict(torch.load(args.model))
        self.policy_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        self.target_net.eval()

        self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr)

    def act(self, state, epsilon):
        if random.random() > epsilon:
            return self.policy_net(state.unsqueeze(0)).max(1)[1].data[0]
        else:
            return random.randint(0, self.action_space - 1)

    def learn(self, mem):
        transitions = mem.sample(self.batch_size)
        batch = Transition(*zip(*transitions))  # Transpose the batch

        states = Variable(torch.stack(batch.state, 0))
        actions = Variable(torch.LongTensor(batch.action).unsqueeze(1))
        rewards = Variable(torch.Tensor(batch.reward))
        non_final_mask = torch.ByteTensor(
            tuple(map(
                lambda s: s is not None,
                batch.next_state)))  # Only process non-terminal next states
        next_states = Variable(
            torch.stack(tuple(s for s in batch.next_state if s is not None),
                        0),
            volatile=True
        )  # Prevent backpropagating through expected action values

        Qs = self.policy_net(states).gather(1, actions)  # Q(s_t, a_t; θpolicy)
        next_state_argmax_indices = self.policy_net(next_states).max(
            1, keepdim=True
        )[1]  # Perform argmax action selection using policy network: argmax_a[Q(s_t+1, a; θpolicy)]
        Qns = Variable(torch.zeros(
            self.batch_size))  # Q(s_t+1, a) = 0 if s_t+1 is terminal
        Qns[non_final_mask] = self.target_net(next_states).gather(
            1, next_state_argmax_indices
        )  # Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)
        Qns.volatile = False  # Remove volatile flag to prevent propagating it through loss
        target = rewards + (
            self.discount * Qns
        )  # Double-Q target: Y = r + γ.Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)

        loss = F.smooth_l1_loss(
            Qs, target)  # Huber loss on TD-error δ: δ = Y - Q(s_t, a_t)
        # TODO: TD-error clipping?
        self.policy_net.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm(self.policy_net.parameters(),
                                self.max_gradient_norm)  # Clamp gradients
        self.optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path):
        torch.save(self.policy_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    def evaluate_q(self, state):
        return self.policy_net(state.unsqueeze(0)).max(1)[0].data[0]

    def train(self):
        self.policy_net.train()

    def eval(self):
        self.policy_net.eval()
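Example No. 5 is written against the pre-0.4 PyTorch API (`Variable`, `volatile`, `ByteTensor`, `nn.utils.clip_grad_norm`). A sketch of the same Huber-loss Double-Q update in current PyTorch, assuming the same `Transition` namedtuple fields and a network that returns per-action Q-values:

import torch
import torch.nn.functional as F


def dqn_loss(policy_net, target_net, batch, discount):
    states = torch.stack(batch.state, 0)
    actions = torch.tensor(batch.action).unsqueeze(1)
    rewards = torch.tensor(batch.reward)
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  dtype=torch.bool)
    next_states = torch.stack([s for s in batch.next_state if s is not None], 0)

    q_sa = policy_net(states).gather(1, actions).squeeze(1)  # Q(s_t, a_t; θpolicy)
    q_next = torch.zeros_like(rewards)  # Q(s_t+1, ·) = 0 for terminal s_t+1
    with torch.no_grad():
        best = policy_net(next_states).argmax(1, keepdim=True)  # policy net selects
        q_next[non_final_mask] = target_net(next_states).gather(1, best).squeeze(1)
    target = rewards + discount * q_next  # Double-Q target
    return F.smooth_l1_loss(q_sa, target)  # Huber loss on the TD error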
Example No. 6
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, args.atoms)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.priority_exponent = args.priority_exponent
    self.max_gradient_norm = args.max_gradient_norm

    self.policy_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.policy_net.load_state_dict(torch.load(args.model))
    self.policy_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.eval()

    self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr, eps=args.adam_eps)
    if args.cuda:
      self.policy_net.cuda()
      self.target_net.cuda()
      self.support = self.support.cuda()

  # Resets noisy weights in all linear layers (of policy and target nets)
  def reset_noise(self):
    self.policy_net.reset_noise()
    self.target_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[1][0]

  def learn(self, mem):
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    batch_size = len(idxs)  # May return less than specified if invalid transitions sampled

    # Calculate current state probabilities
    ps = self.policy_net(states)  # Probabilities p(s_t, ·; θpolicy)
    ps_a = ps[range(batch_size), actions]  # p(s_t, a_t; θpolicy)

    # Calculate nth next state probabilities
    pns = self.policy_net(next_states).data  # Probabilities p(s_t+n, ·; θpolicy)
    dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θpolicy))
    argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using policy network: argmax_a[(z, p(s_t+n, a; θpolicy))]
    pns = self.target_net(next_states).data  # Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
    pns_a *= nonterminals  # Set p = 0 for terminal nth next states as all possible expected returns = expected reward at final transition

    # Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
    # Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()

    # Distribute probability of Tz
    m = states.data.new(batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).long().unsqueeze(1).expand(batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  # Cross-entropy loss (minimises Kullback-Leibler divergence)
    self.policy_net.zero_grad()
    (weights * loss).mean().backward()  # Importance weight losses
    nn.utils.clip_grad_norm(self.policy_net.parameters(), self.max_gradient_norm)  # Clip gradients (normalising by max value of gradient L2 norm)
    self.optimiser.step()

    mem.update_priorities(idxs, loss.data.abs().pow(self.priority_exponent))  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def save(self, path):
    torch.save(self.policy_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[0][0]

  def train(self):
    self.policy_net.train()

  def eval(self):
    self.policy_net.eval()
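Example No. 6 applies the priority exponent in the agent itself (`loss.data.abs().pow(self.priority_exponent)`) before handing priorities back, whereas the other examples pass the raw loss and leave exponentiation to the memory. For reference, a minimal sketch of how proportional prioritised replay typically turns such priorities into sampling probabilities and importance weights; this is an assumption about the `mem` implementation, which is not shown:

import torch


def sampling_probs_and_weights(priorities, beta):
    # p_i = prio_i / sum(prio);  w_i = (N * p_i)^(-beta), normalised by max(w)
    probs = priorities / priorities.sum()
    weights = (len(priorities) * probs).pow(-beta)
    return probs, weights / weights.max()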
Example No. 7
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      args.atoms)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)
        if args.cuda:
            self.online_net.cuda()
            self.target_net.cuda()
            self.support = self.support.cuda()

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[1][0]

    # Acts with an ε-greedy policy
    def act_e_greedy(self, state, epsilon=0.001):
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities
        self.online_net.reset_noise()  # Sample new noise for online network
        ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
        ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

        # Calculate nth next state probabilities
        self.online_net.reset_noise()  # Sample new noise for action selection
        pns = self.online_net(
            next_states).data  # Probabilities p(s_t+n, ·; θonline)
        dns = self.support.expand_as(
            pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
        argmax_indices_ns = dns.sum(2).max(
            1
        )[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
        self.target_net.reset_noise()  # Sample new target net noise
        pns = self.target_net(
            next_states).data  # Probabilities p(s_t+n, ·; θtarget)
        pns_a = pns[range(
            self.batch_size
        ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (
            self.discount**self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin,
                      max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().long(), b.ceil().long()
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.data.new(self.batch_size, self.atoms).zero_()
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                self.batch_size).unsqueeze(1).expand(
                                    self.batch_size,
                                    self.atoms).type_as(actions)
        m.view(-1).index_add_(
            0, (l + offset).view(-1),
            (pns_a *
             (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(
            0, (u + offset).view(-1),
            (pns_a *
             (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
        loss = -torch.sum(
            Variable(m) * ps_a.log(),
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Importance weight losses
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.data)  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[0][0]

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
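Example No. 7 works with probabilities and clamps them before taking the log (`ps_a.clamp(min=1e-3)`), whereas the later examples request log-probabilities directly from the network (`log=True`), which is the numerically safer route if the network uses `log_softmax`. A small illustration of the difference, with `logits` standing in for the network's pre-softmax outputs:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 51)  # e.g. a batch of 4 states, 51 atoms
log_ps_direct = F.log_softmax(logits, dim=1)  # stable log-probabilities
log_ps_clamped = F.softmax(logits, dim=1).clamp(min=1e-3).log()  # Example No. 7's route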
Example No. 8
class Agent(object):
    """ All improvements from the Rainbow research work """
    def __init__(self, args, state_size, action_size):
        """
        Args:
           param1 (args): args
           param2 (int): args
           param3 (int): args
        """
        self.action_size = action_size
        self.state_size = state_size
        self.atoms = args.atoms
        self.V_min = args.V_min
        self.V_max = args.V_max
        self.device = args.device
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=self.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.qnetwork_local = DQN(args, self.state_size,
                                  self.action_size).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.qnetwork_local.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.qnetwork_local.train()

        self.target_net = DQN(args, self.state_size,
                              self.action_size).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    def reset_noise(self):
        """ resets noisy weights in all linear layers """
        self.qnetwork_local.reset_noise()

    def act(self, state):
        """
          acts greedy(max) based on a single state
          Args:
             param1 (int) : state
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0).to(self.device)) *
                    self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        """ acts with epsilon greedy policy
            epsilon exploration vs exploitation traide off
        Args:
            param1(int): state
            param2(float): epsilon
        Return : action int number between 0 and 4
        """
        return np.random.randint(
            0, self.action_size) if np.random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        """ uses samples with the given batch size to improve the Q function
        Args:
            param1 (Experince Replay Buffer) : mem
        """
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.qnetwork_local(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.qnetwork_local(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns
            ) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n
            ) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.V_min,
                          max=self.V_max)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.V_min) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.qnetwork_local.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimizer.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions
        self.soft_update()

    def soft_update(self, tau=1e-3):
        """ swaps the network weights from the online to the target

        Args:
           param1 (float): tau
        """
        for target_param, local_param in zip(self.target_net.parameters(),
                                             self.qnetwork_local.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def update_target_net(self):
        """ copy the model weights from the online to the target network """
        self.target_net.load_state_dict(self.qnetwork_local.state_dict())

    def save(self, path):
        """ save the model weights to a file
        Args:
           param1 (string): pathname
        """
        torch.save(self.qnetwork_local.state_dict(),
                   os.path.join(path, 'model.pth'))

    def evaluate_q(self, state):
        """ Evaluates Q-value based on single state
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        """
        activates the backprob. layers for the online network
        """
        self.qnetwork_local.train()

    def eval(self):
        """ invoke the eval from the online network
            deactivates the backprob
            layers like dropout will work in eval model instead
        """
        self.qnetwork_local.eval()
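Unlike Examples No. 1-7, which hard-copy the online weights into the target network at a fixed interval, Example No. 8 blends them after every learning step (Polyak averaging) via `soft_update`. A minimal standalone sketch of the two update rules over plain modules:

import torch
import torch.nn as nn


def hard_update(target_net: nn.Module, online_net: nn.Module) -> None:
    # Full copy, as done by update_target_net() at the chosen update interval
    target_net.load_state_dict(online_net.state_dict())


@torch.no_grad()
def soft_update(target_net: nn.Module, online_net: nn.Module, tau=1e-3) -> None:
    # θtarget ← τ·θonline + (1 − τ)·θtarget, applied after every learn() call
    for t_param, o_param in zip(target_net.parameters(), online_net.parameters()):
        t_param.data.copy_(tau * o_param.data + (1.0 - tau) * t_param.data)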
Example No. 9
class Agent:
    def __init__(self):
        self.mode = "train"
        with open("config.yaml") as reader:
            self.config = yaml.safe_load(reader)
        print(self.config)
        self.load_config()

        self.online_net = DQN(config=self.config,
                              word_vocab=self.word_vocab,
                              char_vocab=self.char_vocab,
                              answer_type=self.answer_type)
        self.target_net = DQN(config=self.config,
                              word_vocab=self.word_vocab,
                              char_vocab=self.char_vocab,
                              answer_type=self.answer_type)
        self.online_net.train()
        self.target_net.train()
        self.update_target_net()
        for param in self.target_net.parameters():
            param.requires_grad = False

        if self.use_cuda:
            self.online_net.cuda()
            self.target_net.cuda()

        self.naozi = ObservationPool(capacity=self.naozi_capacity)
        # optimizer
        self.optimizer = torch.optim.Adam(
            self.online_net.parameters(),
            lr=self.config['training']['optimizer']['learning_rate'])
        self.clip_grad_norm = self.config['training']['optimizer'][
            'clip_grad_norm']

    def load_config(self):
        # word vocab
        with open("vocabularies/word_vocab.txt") as f:
            self.word_vocab = f.read().split("\n")
        self.word2id = {}
        for i, w in enumerate(self.word_vocab):
            self.word2id[w] = i
        # char vocab
        with open("vocabularies/char_vocab.txt") as f:
            self.char_vocab = f.read().split("\n")
        self.char2id = {}
        for i, w in enumerate(self.char_vocab):
            self.char2id[w] = i

        self.EOS_id = self.word2id["</s>"]
        self.train_data_size = self.config['general']['train_data_size']
        self.question_type = self.config['general']['question_type']
        self.random_map = self.config['general']['random_map']
        self.testset_path = self.config['general']['testset_path']
        self.naozi_capacity = self.config['general']['naozi_capacity']
        self.eval_folder = pjoin(
            self.testset_path, self.question_type,
            ("random_map" if self.random_map else "fixed_map"))
        self.eval_data_path = pjoin(self.testset_path, "data.json")

        self.batch_size = self.config['training']['batch_size']
        self.max_nb_steps_per_episode = self.config['training'][
            'max_nb_steps_per_episode']
        self.max_episode = self.config['training']['max_episode']
        self.target_net_update_frequency = self.config['training'][
            'target_net_update_frequency']
        self.learn_start_from_this_episode = self.config['training'][
            'learn_start_from_this_episode']

        self.run_eval = self.config['evaluate']['run_eval']
        self.eval_batch_size = self.config['evaluate']['batch_size']
        self.eval_max_nb_steps_per_episode = self.config['evaluate'][
            'max_nb_steps_per_episode']

        # Set the random seed manually for reproducibility.
        self.random_seed = self.config['general']['random_seed']
        np.random.seed(self.random_seed)
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            if not self.config['general']['use_cuda']:
                print(
                    "WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml"
                )
                self.use_cuda = False
            else:
                torch.backends.cudnn.deterministic = True
                torch.cuda.manual_seed(self.random_seed)
                self.use_cuda = True
        else:
            self.use_cuda = False

        if self.question_type == "location":
            self.answer_type = "pointing"
        elif self.question_type in ["attribute", "existence"]:
            self.answer_type = "2 way"
        else:
            raise NotImplementedError

        self.save_checkpoint = self.config['checkpoint']['save_checkpoint']
        self.experiment_tag = self.config['checkpoint']['experiment_tag']
        self.save_frequency = self.config['checkpoint']['save_frequency']
        self.load_pretrained = self.config['checkpoint']['load_pretrained']
        self.load_from_tag = self.config['checkpoint']['load_from_tag']

        self.qa_loss_lambda = self.config['training']['qa_loss_lambda']
        self.interaction_loss_lambda = self.config['training'][
            'interaction_loss_lambda']

        # replay buffer and updates
        self.discount_gamma = self.config['replay']['discount_gamma']
        self.replay_batch_size = self.config['replay']['replay_batch_size']
        self.command_generation_replay_memory = command_generation_memory.PrioritizedReplayMemory(
            self.config['replay']['replay_memory_capacity'],
            priority_fraction=self.config['replay']
            ['replay_memory_priority_fraction'],
            discount_gamma=self.discount_gamma)
        self.qa_replay_memory = qa_memory.PrioritizedReplayMemory(
            self.config['replay']['replay_memory_capacity'],
            priority_fraction=0.0)
        self.update_per_k_game_steps = self.config['replay'][
            'update_per_k_game_steps']
        self.multi_step = self.config['replay']['multi_step']

        # distributional RL
        self.use_distributional = self.config['distributional']['enable']
        self.atoms = self.config['distributional']['atoms']
        self.v_min = self.config['distributional']['v_min']
        self.v_max = self.config['distributional']['v_max']
        self.support = torch.linspace(self.v_min, self.v_max,
                                      self.atoms)  # Support (range) of z
        if self.use_cuda:
            self.support = self.support.cuda()
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)

        # dueling networks
        self.dueling_networks = self.config['dueling_networks']

        # double dqn
        self.double_dqn = self.config['double_dqn']

        # counting reward
        self.revisit_counting_lambda_anneal_episodes = self.config[
            'episodic_counting_bonus'][
                'revisit_counting_lambda_anneal_episodes']
        self.revisit_counting_lambda_anneal_from = self.config[
            'episodic_counting_bonus']['revisit_counting_lambda_anneal_from']
        self.revisit_counting_lambda_anneal_to = self.config[
            'episodic_counting_bonus']['revisit_counting_lambda_anneal_to']
        self.revisit_counting_lambda = self.revisit_counting_lambda_anneal_from

        # valid command bonus
        self.valid_command_bonus_lambda = self.config[
            'valid_command_bonus_lambda']

        # epsilon greedy
        self.epsilon_anneal_episodes = self.config['epsilon_greedy'][
            'epsilon_anneal_episodes']
        self.epsilon_anneal_from = self.config['epsilon_greedy'][
            'epsilon_anneal_from']
        self.epsilon_anneal_to = self.config['epsilon_greedy'][
            'epsilon_anneal_to']
        self.epsilon = self.epsilon_anneal_from
        self.noisy_net = self.config['epsilon_greedy']['noisy_net']
        if self.noisy_net:
            # disable epsilon greedy
            self.epsilon_anneal_episodes = -1
            self.epsilon = 0.0

        self.nlp = spacy.load('en', disable=['ner', 'parser', 'tagger'])
        self.single_word_verbs = set(["inventory", "look", "wait"])
        self.two_word_verbs = set(["go"])

    def train(self):
        """
        Tell the agent that it is in the training phase.
        """
        self.mode = "train"
        self.online_net.train()

    def eval(self):
        """
        Tell the agent that it is in the evaluation phase.
        """
        self.mode = "eval"
        self.online_net.eval()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def reset_noise(self):
        if self.noisy_net:
            # Resets noisy weights in all linear layers (of online net only)
            self.online_net.reset_noise()

    def zero_noise(self):
        if self.noisy_net:
            self.online_net.zero_noise()
            self.target_net.zero_noise()

    def load_pretrained_model(self, load_from):
        """
        Load pretrained checkpoint from file.

        Arguments:
            load_from: File name of the pretrained model checkpoint.
        """
        print("loading model from %s\n" % (load_from))
        try:
            if self.use_cuda:
                state_dict = torch.load(load_from)
            else:
                state_dict = torch.load(load_from, map_location='cpu')
            self.online_net.load_state_dict(state_dict)
        except Exception as e:
            print("Failed to load checkpoint: %s" % e)

    def save_model_to_path(self, save_to):
        torch.save(self.online_net.state_dict(), save_to)
        print("Saved checkpoint to %s..." % (save_to))

    def init(self, obs, infos):
        """
        Prepare the agent for the upcoming games.

        Arguments:
            obs: Previous command's feedback for each game.
            infos: Additional information for each game.
        """
        # reset agent, get vocabulary masks for verbs / adjectives / nouns
        batch_size = len(obs)
        self.reset_binarized_counter(batch_size)
        self.not_finished_yet = np.ones((batch_size, ), dtype="float32")
        self.prev_actions = [["" for _ in range(batch_size)]]
        self.prev_step_is_still_interacting = np.ones(
            (batch_size, ), dtype="float32"
        )  # all 1s initially; an entry turns to 0 once that game's previous action was "wait"
        self.naozi.reset(batch_size=batch_size)

    def get_agent_inputs(self, string_list):
        sentence_token_list = [item.split() for item in string_list]
        sentence_id_list = [
            _words_to_ids(tokens, self.word2id)
            for tokens in sentence_token_list
        ]
        input_sentence_char = list_of_token_list_to_char_input(
            sentence_token_list, self.char2id)
        input_sentence = pad_sequences(
            sentence_id_list, maxlen=max_len(sentence_id_list)).astype('int32')
        input_sentence = to_pt(input_sentence, self.use_cuda)
        input_sentence_char = to_pt(input_sentence_char, self.use_cuda)
        return input_sentence, input_sentence_char, sentence_id_list

    def get_game_info_at_certain_step(self, obs, infos):
        """
        Get all needed info from game engine for training.
        Arguments:
            obs: Previous command's feedback for each game.
            infos: Additional information for each game.
        """
        batch_size = len(obs)
        feedback_strings = [preproc(item, tokenizer=self.nlp) for item in obs]
        description_strings = [
            preproc(item, tokenizer=self.nlp) for item in infos["description"]
        ]
        observation_strings = [
            d + " <|> " + fb if fb != d else d + " <|> hello"
            for fb, d in zip(feedback_strings, description_strings)
        ]

        inventory_strings = [
            preproc(item, tokenizer=self.nlp) for item in infos["inventory"]
        ]
        local_word_list = [
            obs.split() + inv.split()
            for obs, inv in zip(observation_strings, inventory_strings)
        ]

        directions = ["east", "west", "north", "south"]
        if self.question_type in ["location", "existence"]:
            # the agent only observes the environment and does not change it
            possible_verbs = [["go", "inventory", "wait", "open", "examine"]
                              for _ in range(batch_size)]
        else:
            possible_verbs = [
                list(set(item) - set(["", "look"])) for item in infos["verbs"]
            ]

        possible_adjs, possible_nouns = [], []
        for i in range(batch_size):
            object_nouns = [
                item.split()[-1] for item in infos["object_nouns"][i]
            ]
            object_adjs = [
                w for item in infos["object_adjs"][i] for w in item.split()
            ]
            possible_nouns.append(
                list(set(object_nouns) & set(local_word_list[i]) - set([""])) +
                directions)
            possible_adjs.append(
                list(set(object_adjs) & set(local_word_list[i]) - set([""])) +
                ["</s>"])

        return observation_strings, [
            possible_verbs, possible_adjs, possible_nouns
        ]

    def get_state_strings(self, infos):
        description_strings = infos["description"]
        inventory_strings = infos["inventory"]
        observation_strings = [
            _d + _i for (_d, _i) in zip(description_strings, inventory_strings)
        ]
        return observation_strings

    def get_local_word_masks(self, possible_words):
        possible_verbs, possible_adjs, possible_nouns = possible_words
        batch_size = len(possible_verbs)

        verb_mask = np.zeros((batch_size, len(self.word_vocab)),
                             dtype="float32")
        noun_mask = np.zeros((batch_size, len(self.word_vocab)),
                             dtype="float32")
        adj_mask = np.zeros((batch_size, len(self.word_vocab)),
                            dtype="float32")
        for i in range(batch_size):
            for w in possible_verbs[i]:
                if w in self.word2id:
                    verb_mask[i][self.word2id[w]] = 1.0
            for w in possible_adjs[i]:
                if w in self.word2id:
                    adj_mask[i][self.word2id[w]] = 1.0
            for w in possible_nouns[i]:
                if w in self.word2id:
                    noun_mask[i][self.word2id[w]] = 1.0
        adj_mask[:, self.EOS_id] = 1.0

        return [verb_mask, adj_mask, noun_mask]

    def get_match_representations(self,
                                  input_observation,
                                  input_observation_char,
                                  input_quest,
                                  input_quest_char,
                                  use_model="online"):
        model = self.online_net if use_model == "online" else self.target_net
        description_representation_sequence, description_mask = model.representation_generator(
            input_observation, input_observation_char)
        quest_representation_sequence, quest_mask = model.representation_generator(
            input_quest, input_quest_char)

        match_representation_sequence = model.get_match_representations(
            description_representation_sequence, description_mask,
            quest_representation_sequence, quest_mask)
        match_representation_sequence = match_representation_sequence * description_mask.unsqueeze(
            -1)
        return match_representation_sequence

    def get_ranks(self,
                  input_observation,
                  input_observation_char,
                  input_quest,
                  input_quest_char,
                  word_masks,
                  use_model="online"):
        """
        Given input observation and question tensors, to get Q values of words.
        """
        model = self.online_net if use_model == "online" else self.target_net
        match_representation_sequence = self.get_match_representations(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            use_model=use_model)
        action_ranks = model.action_scorer(match_representation_sequence,
                                           word_masks)  # list of 3 tensors
        return action_ranks

    def choose_maxQ_command(self, action_ranks, word_mask=None):
        """
        Generate a command by maximum q values, for epsilon greedy.
        """
        if self.use_distributional:
            action_ranks = [
                (item * self.support).sum(2) for item in action_ranks
            ]  # list of batch x n_vocab
        action_indices = []
        for i in range(len(action_ranks)):
            ar = action_ranks[i]
            ar = ar - torch.min(
                ar, -1, keepdim=True
            )[0] + 1e-2  # subtract the per-row minimum so all scores are strictly positive before masking
            if word_mask is not None:
                assert word_mask[i].size() == ar.size(), (
                    word_mask[i].size(), ar.size())
                ar = ar * word_mask[i]
            action_indices.append(torch.argmax(ar, -1))  # batch
        return action_indices

    def choose_random_command(self,
                              batch_size,
                              action_space_size,
                              possible_words=None):
        """
        Generate a command randomly, for epsilon greedy.
        """
        action_indices = []
        for i in range(3):
            if possible_words is None:
                indices = np.random.choice(action_space_size, batch_size)
            else:
                indices = []
                for j in range(batch_size):
                    mask_ids = []
                    for w in possible_words[i][j]:
                        if w in self.word2id:
                            mask_ids.append(self.word2id[w])
                    indices.append(np.random.choice(mask_ids))
                indices = np.array(indices)
            action_indices.append(to_pt(indices, self.use_cuda))  # batch
        return action_indices

    def get_chosen_strings(self, chosen_indices):
        """
        Turns list of word indices into actual command strings.
        chosen_indices: Word indices chosen by model.
        """
        chosen_indices_np = [to_np(item) for item in chosen_indices]
        res_str = []
        batch_size = chosen_indices_np[0].shape[0]
        for i in range(batch_size):
            verb, adj, noun = chosen_indices_np[0][i], chosen_indices_np[1][
                i], chosen_indices_np[2][i]
            res_str.append(self.word_ids_to_commands(verb, adj, noun))
        return res_str

    def word_ids_to_commands(self, verb, adj, noun):
        """
        Turn the 3 indices into actual command strings.

        Arguments:
            verb: Index of the guessing verb in vocabulary
            adj: Index of the guessing adjective in vocabulary
            noun: Index of the guessing noun in vocabulary
        """
        # turns 3 indices into actual command strings
        if self.word_vocab[verb] in self.single_word_verbs:
            return self.word_vocab[verb]
        if self.word_vocab[verb] in self.two_word_verbs:
            return " ".join([self.word_vocab[verb], self.word_vocab[noun]])
        if adj == self.EOS_id:
            return " ".join([self.word_vocab[verb], self.word_vocab[noun]])
        else:
            return " ".join([
                self.word_vocab[verb], self.word_vocab[adj],
                self.word_vocab[noun]
            ])

    def act_random(self, obs, infos, input_observation, input_observation_char,
                   input_quest, input_quest_char, possible_words):
        with torch.no_grad():
            batch_size = len(obs)
            word_indices_random = self.choose_random_command(
                batch_size, len(self.word_vocab), possible_words)
            chosen_indices = word_indices_random
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting with the game; DQN needs this one extra step of information when forming targets
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def act_greedy(self, obs, infos, input_observation, input_observation_char,
                   input_quest, input_quest_char, possible_words):
        """
        Acts upon the current list of observations.
        One text command must be returned for each observation.
        """
        with torch.no_grad():
            batch_size = len(obs)
            local_word_masks_np = self.get_local_word_masks(possible_words)
            local_word_masks = [
                to_pt(item, self.use_cuda, type="float")
                for item in local_word_masks_np
            ]

            # generate commands for one game step, epsilon greedy is applied, i.e.,
            # there is epsilon of chance to generate random commands
            action_ranks = self.get_ranks(
                input_observation,
                input_observation_char,
                input_quest,
                input_quest_char,
                local_word_masks,
                use_model="online")  # list of batch x vocab
            word_indices_maxq = self.choose_maxQ_command(
                action_ranks, local_word_masks)
            chosen_indices = word_indices_maxq
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting with the game; DQN needs this one extra step of information when forming targets
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def act(self,
            obs,
            infos,
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            possible_words,
            random=False):
        """
        Acts upon the current list of observations.
        One text command must be returned for each observation.
        """
        with torch.no_grad():
            if self.mode == "eval":
                return self.act_greedy(obs, infos, input_observation,
                                       input_observation_char, input_quest,
                                       input_quest_char, possible_words)
            if random:
                return self.act_random(obs, infos, input_observation,
                                       input_observation_char, input_quest,
                                       input_quest_char, possible_words)
            batch_size = len(obs)

            local_word_masks_np = self.get_local_word_masks(possible_words)
            local_word_masks = [
                to_pt(item, self.use_cuda, type="float")
                for item in local_word_masks_np
            ]

            # generate commands for one game step, epsilon greedy is applied, i.e.,
            # there is epsilon of chance to generate random commands
            action_ranks = self.get_ranks(
                input_observation,
                input_observation_char,
                input_quest,
                input_quest_char,
                local_word_masks,
                use_model="online")  # list of batch x vocab
            word_indices_maxq = self.choose_maxQ_command(
                action_ranks, local_word_masks)
            word_indices_random = self.choose_random_command(
                batch_size, len(self.word_vocab), possible_words)

            # random number for epsilon greedy
            rand_num = np.random.uniform(low=0.0,
                                         high=1.0,
                                         size=(batch_size, ))
            less_than_epsilon = (rand_num < self.epsilon).astype(
                "float32")  # batch
            greater_than_epsilon = 1.0 - less_than_epsilon
            less_than_epsilon = to_pt(less_than_epsilon,
                                      self.use_cuda,
                                      type='long')
            greater_than_epsilon = to_pt(greater_than_epsilon,
                                         self.use_cuda,
                                         type='long')
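            # mix per sample: take the random indices where rand < epsilon, the greedy (max-Q) indices otherwise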
            chosen_indices = [
                less_than_epsilon * idx_random +
                greater_than_epsilon * idx_maxq
                for idx_random, idx_maxq in zip(word_indices_random,
                                                word_indices_maxq)
            ]
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting with the game; DQN needs this one extra step of information when forming targets
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def get_dqn_loss(self):
        """
        Compute the DQN loss on a batch sampled from the command-generation replay
        memory; the actual parameter update is performed in update_interaction().
        """
        if len(self.command_generation_replay_memory) < self.replay_batch_size:
            return None

        data = self.command_generation_replay_memory.get_batch(
            self.replay_batch_size, self.multi_step)
        if data is None:
            return None

        obs_list, quest_list, possible_words_list, chosen_indices, rewards, next_obs_list, next_possible_words_list, actual_n_list = data
        batch_size = len(actual_n_list)

        input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list)
        input_observation, input_observation_char, _ = self.get_agent_inputs(
            obs_list)
        next_input_observation, next_input_observation_char, _ = self.get_agent_inputs(
            next_obs_list)

        possible_words, next_possible_words = [], []
        for i in range(3):
            possible_words.append([item[i] for item in possible_words_list])
            next_possible_words.append(
                [item[i] for item in next_possible_words_list])

        local_word_masks = [
            to_pt(item, self.use_cuda, type="float")
            for item in self.get_local_word_masks(possible_words)
        ]
        next_local_word_masks = [
            to_pt(item, self.use_cuda, type="float")
            for item in self.get_local_word_masks(next_possible_words)
        ]

        action_ranks = self.get_ranks(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            local_word_masks,
            use_model="online"
        )  # list of batch x vocab or list of batch x vocab x atoms
        # ps_a
        word_qvalues = [
            ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1)
            for w_rank, idx in zip(action_ranks, chosen_indices)
        ]  # list of batch or list of batch x atoms
        q_value = torch.mean(torch.stack(word_qvalues, -1),
                             -1)  # batch or batch x atoms
        # log_ps_a
        log_q_value = torch.log(q_value)  # batch or batch x atoms

        with torch.no_grad():
            if self.noisy_net:
                self.target_net.reset_noise()  # Sample new target net noise
            if self.double_dqn:
                # pns Probabilities p(s_t+n, ·; θonline)
                next_action_ranks = self.get_ranks(next_input_observation,
                                                   next_input_observation_char,
                                                   input_quest,
                                                   input_quest_char,
                                                   next_local_word_masks,
                                                   use_model="online")
                # list of batch x vocab or list of batch x vocab x atoms
                # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
                next_word_indices = self.choose_maxQ_command(
                    next_action_ranks,
                    next_local_word_masks)  # list of batch x 1
                # pns # Probabilities p(s_t+n, ·; θtarget)
                next_action_ranks = self.get_ranks(
                    next_input_observation,
                    next_input_observation_char,
                    input_quest,
                    input_quest_char,
                    next_local_word_masks,
                    use_model="target"
                )  # batch x vocab or list of batch x vocab x atoms
                # pns_a # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
                next_word_qvalues = [
                    ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for
                    w_rank, idx in zip(next_action_ranks, next_word_indices)
                ]  # list of batch or list of batch x atoms
            else:
                # pns Probabilities p(s_t+n, ·; θonline)
                next_action_ranks = self.get_ranks(next_input_observation,
                                                   next_input_observation_char,
                                                   input_quest,
                                                   input_quest_char,
                                                   next_local_word_masks,
                                                   use_model="target")
                # list of batch x vocab or list of batch x vocab x atoms
                next_word_indices = self.choose_maxQ_command(
                    next_action_ranks,
                    next_local_word_masks)  # list of batch x 1
                next_word_qvalues = [
                    ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for
                    w_rank, idx in zip(next_action_ranks, next_word_indices)
                ]  # list of batch or list of batch x atoms

            next_q_value = torch.mean(torch.stack(next_word_qvalues, -1),
                                      -1)  # batch or batch x atoms
            # Compute Tz (Bellman operator T applied to z)
            discount = to_pt((np.ones_like(actual_n_list) *
                              self.discount_gamma)**actual_n_list,
                             self.use_cuda,
                             type="float")
        if not self.use_distributional:
            rewards = rewards + next_q_value * discount  # batch
            loss = F.smooth_l1_loss(q_value, rewards)
            return loss

        with torch.no_grad():
            Tz = rewards.unsqueeze(
                -1) + discount.unsqueeze(-1) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.v_min,
                          max=self.v_max)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.v_min) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = torch.zeros(batch_size, self.atoms).float()
            if self.use_cuda:
                m = m.cuda()
            offset = torch.linspace(0, ((batch_size - 1) * self.atoms),
                                    batch_size).unsqueeze(1).expand(
                                        batch_size, self.atoms).long()
            if self.use_cuda:
                offset = offset.cuda()
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (next_q_value *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (next_q_value *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_q_value,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        loss = torch.mean(loss)
        return loss

    def update_interaction(self):
        # update neural model by replaying snapshots in replay memory
        interaction_loss = self.get_dqn_loss()
        if interaction_loss is None:
            return None
        loss = interaction_loss * self.interaction_loss_lambda
        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(),
                                       self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(torch.mean(interaction_loss))

    def answer_question(self,
                        input_observation,
                        input_observation_char,
                        observation_id_list,
                        input_quest,
                        input_quest_char,
                        use_model="online"):
        # first pad answerer_input, and get the mask
        model = self.online_net if use_model == "online" else self.target_net
        batch_size = len(observation_id_list)
        max_length = input_observation.size(1)
        mask = compute_mask(input_observation)  # batch x obs_len

        # noun mask for location question
        if self.question_type in ["location"]:
            location_mask = []
            for i in range(batch_size):
                m = [1 for item in observation_id_list[i]]
                location_mask.append(m)
            location_mask = pad_sequences(location_mask,
                                          maxlen=max_length,
                                          dtype="float32")
            location_mask = to_pt(location_mask,
                                  enable_cuda=self.use_cuda,
                                  type='float')
            assert mask.size() == location_mask.size()
            mask = mask * location_mask

        match_representation_sequence = self.get_match_representations(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            use_model=use_model)
        pred = model.answer_question(match_representation_sequence,
                                     mask)  # batch x vocab or batch x 2

        # Attention sum:
        # a word can appear multiple times in the observation, so its scores are
        # merged (summed) into a single vocabulary entry before further computation.
        # If the answer type is not "pointing", we instead use a fixed mapping from
        # the 0/1 answers to their positions in the vocabulary.
        if self.answer_type == "2 way":
            observation_id_list = []
            max_length = 2
            for i in range(batch_size):
                observation_id_list.append(
                    [self.word2id["0"], self.word2id["1"]])

        observation = to_pt(
            pad_sequences(observation_id_list,
                          maxlen=max_length).astype('int32'), self.use_cuda)
        vocab_distribution = np.zeros(
            (batch_size, len(self.word_vocab)))  # batch x vocab
        vocab_distribution = to_pt(vocab_distribution,
                                   self.use_cuda,
                                   type='float')
        vocab_distribution = vocab_distribution.scatter_add_(
            1, observation, pred)  # batch x vocab
        non_zero_words = []
        for i in range(batch_size):
            non_zero_words.append(list(set(observation_id_list[i])))
        vocab_mask = torch.ne(vocab_distribution, 0).float()
        return vocab_distribution, non_zero_words, vocab_mask

    def point_maxq_position(self, vocab_distribution, mask):
        """
        Select the answer word with the maximum Q value.

        Arguments:
            vocab_distribution: Q values per vocabulary entry.
            mask: vocabulary mask.
        """
        vocab_distribution = vocab_distribution - torch.min(
            vocab_distribution, -1, keepdim=True
        )[0] + 1e-2  # subtract the per-row minimum so all values are strictly positive before masking
        vocab_distribution = vocab_distribution * mask  # batch x vocab
        indices = torch.argmax(vocab_distribution, -1)  # batch
        return indices

    def answer_question_act_greedy(self, input_observation,
                                   input_observation_char, observation_id_list,
                                   input_quest, input_quest_char):

        with torch.no_grad():
            vocab_distribution, _, vocab_mask = self.answer_question(
                input_observation,
                input_observation_char,
                observation_id_list,
                input_quest,
                input_quest_char,
                use_model="online")  # batch x time
            positions_maxq = self.point_maxq_position(vocab_distribution,
                                                      vocab_mask)
            return positions_maxq  # batch

    def get_qa_loss(self):
        """
        Compute the question-answering loss on a batch sampled from the QA replay
        memory; the actual parameter update is performed in update_qa().
        """
        if len(self.qa_replay_memory) < self.replay_batch_size:
            return None
        transitions = self.qa_replay_memory.sample(self.replay_batch_size)
        batch = qa_memory.qa_Transition(*zip(*transitions))

        observation_list = batch.observation_list
        quest_list = batch.quest_list
        answer_strings = batch.answer_strings
        answer_position = np.array(_words_to_ids(answer_strings, self.word2id))
        groundtruth = to_pt(answer_position, self.use_cuda)  # batch

        input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list)
        input_observation, input_observation_char, observation_id_list = self.get_agent_inputs(
            observation_list)

        answer_distribution, _, _ = self.answer_question(
            input_observation,
            input_observation_char,
            observation_id_list,
            input_quest,
            input_quest_char,
            use_model="online")  # batch x vocab

        batch_loss = NegativeLogLoss(answer_distribution, groundtruth)  # batch
        return torch.mean(batch_loss)

    def update_qa(self):
        # update neural model by replaying snapshots in replay memory
        qa_loss = self.get_qa_loss()
        if qa_loss is None:
            return None
        loss = qa_loss * self.qa_loss_lambda
        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(),
                                       self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(torch.mean(qa_loss))

    def finish_of_episode(self, episode_no, batch_size):
        # Update target network
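        # episodes advance batch_size at a time, so the modulo comparison below fires
        # exactly when the episode counter crosses a multiple of target_net_update_frequency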
        if (
                episode_no + batch_size
        ) % self.target_net_update_frequency <= episode_no % self.target_net_update_frequency:
            self.update_target_net()
        # decay lambdas
        if episode_no < self.learn_start_from_this_episode:
            return
        if episode_no < self.epsilon_anneal_episodes + self.learn_start_from_this_episode:
            self.epsilon -= (self.epsilon_anneal_from - self.epsilon_anneal_to
                             ) / float(self.epsilon_anneal_episodes)
            self.epsilon = max(self.epsilon, 0.0)
        if episode_no < self.revisit_counting_lambda_anneal_episodes + self.learn_start_from_this_episode:
            self.revisit_counting_lambda -= (
                self.revisit_counting_lambda_anneal_from -
                self.revisit_counting_lambda_anneal_to) / float(
                    self.revisit_counting_lambda_anneal_episodes)
            self.revisit_counting_lambda = max(self.revisit_counting_lambda, 0.0)

    def reset_binarized_counter(self, batch_size):
        self.binarized_counter_dict = [{} for _ in range(batch_size)]

    def get_binarized_count(self, observation_strings, update=True):
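        # Episodic binary count bonus: returns 1.0 the first time an observation string
        # is seen in the current episode and 0.0 afterwards (counts only change when update=True).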
        count_rewards = []
        batch_size = len(observation_strings)
        for i in range(batch_size):
            key = observation_strings[i]
            if key not in self.binarized_counter_dict[i]:
                self.binarized_counter_dict[i][key] = 0.0
            if update:
                self.binarized_counter_dict[i][key] += 1.0
            r = self.binarized_counter_dict[i][key]
            r = float(r == 1.0)
            count_rewards.append(r)
        return count_rewards
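As a side note, a tiny runnable illustration (toy tensors, not taken from the example) of the scatter_add_ "attention sum" used in answer_question above: scores of repeated word ids are merged by summation into a single vocabulary slot.

import torch

vocab_size = 6
word_ids = torch.tensor([[2, 4, 2]])           # the word with id 2 appears twice
scores = torch.tensor([[0.1, 0.7, 0.2]])       # per-position scores from the model
vocab_dist = torch.zeros(1, vocab_size).scatter_add_(1, word_ids, scores)
print(vocab_dist)                              # id 2 ends up with 0.1 + 0.2 = 0.3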
Example No. 10
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms  # size of value distribution.
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      self.atoms).to(device=args.device)
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(
            device=args.device)  # greedily selects the action.
        if args.model and os.path.isfile(args.model):
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu')
            )  # state_dict: python dictionary that maps each layer to its parameters.
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(
            device=args.device)  # use to compute target q-values.
        self.update_target_net(
        )  # sets it to the parameters of the online network.
        self.target_net.train()
        for param in self.target_net.parameters(
        ):  # not updated through backpropagation.
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    def reset_noise(self):
        self.online_net.reset_noise()

    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        log_ps = self.online_net(states, log=True)
        log_ps_a = log_ps[range(self.batch_size), actions]

        with torch.no_grad():
            pns = self.online_net(next_states)
            dns = self.support.expand_as(pns) * pns
            argmax_indices_ns = dns.sum(2).argmax(1)
            self.target_net.reset_noise()
            pns = self.target_net(next_states)
            pns_a = pns[range(self.batch_size), argmax_indices_ns]
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(0)
            Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)
            b = (Tz - self.Vmin) / self.delta_z
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(0, (l + offset).view(-1),
                                  (pns_a * (u.float() - b)).view(-1))
            m.view(-1).index_add_(0, (u + offset).view(-1),
                                  (pns_a * (b - l.float())).view(-1))

        loss = -torch.sum(m * log_ps_a, 1)
        self.online_net.zero_grad()
        (weights * loss).mean().backward()
        self.optimiser.step()
        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model_all_layers.pth'))

    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
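Since several of these examples repeat the same categorical (C51) projection inline, here is a compact, self-contained sketch of that step on toy tensors; the function name and shapes below are illustrative only, not part of any example.

import torch

# Sketch of the C51 projection used inside the learn() methods above.
def project_distribution(next_probs, returns, nonterminals, support, discount, v_min, v_max):
    batch_size, atoms = next_probs.shape
    delta_z = (v_max - v_min) / (atoms - 1)
    # Tz = R^n + (gamma^n) * z, clamped to the support range
    Tz = (returns.unsqueeze(1) + nonterminals * discount * support.unsqueeze(0)).clamp(v_min, v_max)
    b = (Tz - v_min) / delta_z                      # fractional index of Tz on the support
    l, u = b.floor().long(), b.ceil().long()
    l[(u > 0) & (l == u)] -= 1                      # keep mass when b lands exactly on an atom
    u[(l < atoms - 1) & (l == u)] += 1
    m = torch.zeros(batch_size, atoms)
    offset = (torch.arange(batch_size) * atoms).unsqueeze(1)
    m.view(-1).index_add_(0, (l + offset).view(-1), (next_probs * (u.float() - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1), (next_probs * (b - l.float())).view(-1))
    return m                                        # projected target distribution, batch x atoms

# toy call: 2 samples, 5 atoms, uniform next-state distribution
support = torch.linspace(-1.0, 1.0, 5)
m = project_distribution(torch.full((2, 5), 0.2), torch.tensor([0.5, -0.2]),
                         torch.ones(2, 1), support, 0.99, -1.0, 1.0)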
Example No. 11
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.quantile = args.quantile
    self.atoms = args.quantiles if args.quantile else args.atoms
    if args.quantile:
      self.cumulative_density = (2 * torch.arange(self.atoms).to(device=args.device) + 1) / (2 * self.atoms)  # Quantile cumulative probability weights τ
    else:
      self.Vmin = args.V_min
      self.Vmax = args.V_max
      self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
      self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.norm_clip = args.norm_clip

    self.online_net = DQN(args, self.action_space, args.quantile).to(device=args.device)
    if args.model and os.path.isfile(args.model):
      # Always load tensors onto CPU by default, will shift to GPU if necessary
      self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
    self.online_net.train()

    self.target_net = DQN(args, self.action_space, args.quantile).to(device=args.device)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
    self.online_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).argmax(1).item()

  # Acts with an ε-greedy policy
  def act_e_greedy(self, state, epsilon=0.05):
    return random.randrange(self.action_space) if random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

    # Calculate current state probabilities (online network noise already sampled)
    ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)/quantile probabilities θ(s_t, ·; θonline)
    ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

    with torch.no_grad():
      # Calculate nth next state probabilities
      pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
      dns = ((1 / self.atoms) if self.quantile else self.support.expand_as(pns)) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  # Sample new target net noise
      pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      if self.quantile:
        # Compute distributional Bellman target Tθ = R^n + (γ^n)p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
        Ttheta = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * pns_a  # (accounting for terminal states)
      else:
        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.new_zeros(self.batch_size, self.atoms)
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
        m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    if self.quantile:
      u = Ttheta - ps_a  # Residual u
      kappa_cond = (u.abs() <= 1).to(torch.float32)  # |u| ≤ κ (κ = 1)
      huber_loss = 0.5 * u ** 2 * kappa_cond + (u.abs() - 0.5) * (1 - kappa_cond)  # Huber loss Lκ(u)
      loss = torch.sum(torch.abs(self.cumulative_density - (u < 0).to(torch.float32)) * huber_loss, 1)  # Quantile Huber loss ρκτ(u) = |τ − δ{u<0}|Lκ(u)
    else:
      loss = -torch.sum(m * ps_a.log(), 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    loss = weights * loss  # Importance weight losses
    self.online_net.zero_grad()
    loss.mean().backward()  # Backpropagate minibatch loss
    nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm (before the optimiser step)
    self.optimiser.step()
    if self.quantile:
      loss = (self.atoms * loss).clamp(max=5)  # Heuristic for prioritised replay

    mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path):
    torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
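For comparison, a sketch of the standard pairwise quantile Huber loss from QR-DQN (kappa = 1); the example above uses an element-wise simplification of this formulation, and the names below are illustrative only.

import torch

def quantile_huber_loss(theta, Ttheta, tau, kappa=1.0):
    # theta, Ttheta: batch x atoms quantile estimates / targets; tau: atoms cumulative densities
    u = Ttheta.unsqueeze(1) - theta.unsqueeze(2)            # batch x atoms(theta) x atoms(target) residuals
    huber = torch.where(u.abs() <= kappa, 0.5 * u ** 2, kappa * (u.abs() - 0.5 * kappa))
    weight = (tau.view(1, -1, 1) - (u < 0).float()).abs()   # |tau_i - 1{u < 0}|
    return (weight * huber).mean(2).sum(1)                  # per-sample loss, shape: batch

# toy call: 3 samples, 4 quantiles
atoms = 4
tau = (2 * torch.arange(atoms).float() + 1) / (2 * atoms)
loss = quantile_huber_loss(torch.randn(3, atoms), torch.randn(3, atoms), tau)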
Example No. 12
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.saved_model_path = args.saved_model_path
        self.experiment = args.experiment
        self.plots_path = args.plots_path
        self.data_save_path = args.data_save_path


        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)

        # list of layers:
        self.online_net_layers = [self.online_net.conv1,
                                  self.online_net.conv2,
                                  self.online_net.conv3,
                                  self.online_net.fc_h_v,
                                  self.online_net.fc_h_a,
                                  self.online_net.fc_z_v,
                                  self.online_net.fc_z_a
                                  ]

        self.target_net_layers = [self.target_net.conv1,
                                  self.target_net.conv2,
                                  self.target_net.conv3,
                                  self.target_net.fc_h_v,
                                  self.target_net.fc_h_a,
                                  self.target_net.fc_z_v,
                                  self.target_net.fc_z_a
                                  ]

        # optionally freeze the first layers (freeze_layers also reinitialises the unfrozen ones) and/or reinitialise layers
        if args.freeze_layers > 0:
            self.freeze_layers(args.freeze_layers)

        if args.reinitialize_layers > 0:
            self.reinit_layers(args.reinitialize_layers)


    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(
                self.batch_size, self.atoms).to(actions)
            m.view(-1).index_add_(0, (l + offset).view(-1),
                                  (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(0, (u + offset).view(-1),
                                  (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
        self.optimiser.step()

        mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions

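    # Hard update: copy the online network's parameters into the target network
    # (typically invoked periodically from the training loop)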
    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(), os.path.join(path, self.experiment + '_model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()

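    # Toggle the online network between train and eval mode; with noisy linear layers this
    # typically switches between sampled-noise and deterministic mean-weight behaviour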
    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()

    def freeze_layers(self, num_frozen_layers):

        # Reinitialise the layer groups that are NOT frozen (there are 5 groups in total:
        # the 3 convolutional layers, the noisy hidden pair and the noisy output pair)
        self.reinit_layers(5 - num_frozen_layers)

        for i in range(num_frozen_layers):
            if i < 3:
                # Freeze the i-th convolutional layer of the online network
                self.online_net_layers[i].weight.requires_grad = False
                self.online_net_layers[i].bias.requires_grad = False
            elif i == 3:
                # Freeze the noisy hidden layers of both dueling streams (fc_h_v and fc_h_a);
                # noisy linear layers expose mu/sigma parameters rather than plain weight/bias
                for layer in (self.online_net_layers[3], self.online_net_layers[4]):
                    layer.weight_mu.requires_grad = False
                    layer.weight_sigma.requires_grad = False
                    layer.bias_mu.requires_grad = False
                    layer.bias_sigma.requires_grad = False

        # Debug output: show both architectures and which parameters remain trainable
        print(self.online_net)
        print([p.requires_grad for p in self.online_net.parameters()])
        print(self.target_net)
        print([p.requires_grad for p in self.target_net.parameters()])


    def reinit_layers(self, num_layers):
        # Reinitialise layer groups from the output side inwards, in both online and target nets:
        # group 0 = the noisy output heads (fc_z_a, fc_z_v), group 1 = the noisy hidden layers
        # (fc_h_a, fc_h_v), groups 2-4 = the convolutional layers, last to first
        layer_groups = [(6, 5), (4, 3), (2,), (1,), (0,)]
        for group in layer_groups[:num_layers]:
            for idx in group:
                self.online_net_layers[idx].reset_parameters()
            for idx in group:
                self.target_net_layers[idx].reset_parameters()