Code Example #1
class AgentEval:
    def __init__(self, args, env):
        self.action_space = env.action_space
        self.atoms = args.atoms
        self.v_min = args.V_min
        self.v_max = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        # Print the network architecture module by module (debug/inspection aid)
        for m in self.online_net.modules():
            print(m)

        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
        self.online_net.eval()

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return self.action_space.sample() if np.random.random() < epsilon else self.act(state)

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()
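
Both act() and evaluate_q() above reduce the learned return distribution to expected Q-values before choosing an action. A minimal, self-contained sketch of that reduction, with a random softmax standing in for the DQN output and illustrative values for the number of actions, atoms and support range:

import torch

# Toy stand-in for the categorical network output p(s, a, z_i) for a single state
atoms, n_actions = 51, 4
support = torch.linspace(-10.0, 10.0, atoms)                   # fixed support z
probs = torch.softmax(torch.randn(1, n_actions, atoms), dim=2)

q_values = (probs * support).sum(2)  # E[Z(s, a)] = sum_i z_i * p_i, shape (1, n_actions)
action = q_values.argmax(1).item()   # greedy action, as in act()
max_q = q_values.max(1)[0].item()    # greedy value, as in evaluate_q()
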
Code Example #2
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.detach())  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
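
The target construction inside learn() is the categorical (C51) projection: the Bellman-updated support Tz is clamped to [Vmin, Vmax] and its probability mass is split between the two nearest fixed atoms. A minimal sketch for a single transition; the Vmin/Vmax/atoms values, the n-step return pieces and the random target distribution are illustrative:

import torch

Vmin, Vmax, atoms = -10.0, 10.0, 51
support = torch.linspace(Vmin, Vmax, atoms)
delta_z = (Vmax - Vmin) / (atoms - 1)

returns, nonterminal, gamma_n = 1.0, 1.0, 0.99 ** 3    # assumed 3-step return pieces
pns_a = torch.softmax(torch.randn(atoms), dim=0)       # stand-in for p(s_t+n, a*; θtarget)

Tz = (returns + nonterminal * gamma_n * support).clamp(Vmin, Vmax)  # Tz = R^n + (γ^n)z
b = (Tz - Vmin) / delta_z                              # fractional atom index
l, u = b.floor().long(), b.ceil().long()
l[(u > 0) & (l == u)] -= 1                             # keep mass when b lands exactly on an atom
u[(l < atoms - 1) & (l == u)] += 1

m = torch.zeros(atoms)                                 # projected target distribution
m.index_add_(0, l, pns_a * (u.float() - b))            # m_l += p(s_t+n, a*)(u - b)
m.index_add_(0, u, pns_a * (b - l.float()))            # m_u += p(s_t+n, a*)(b - l)
assert torch.allclose(m.sum(), torch.tensor(1.0))      # no probability mass is lost
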
Code Example #3
File: agent.py Project: teshnizi/curl_rainbow
class Agent():
    def __init__(self, args, env):
        self.args = args
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip
        self.coeff = 0.01 if args.game in [
            'pong', 'boxing', 'private_eye', 'freeway'
        ] else 1.

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        self.momentum_net = DQN(args, self.action_space).to(device=args.device)
        # self.predictor = prediction_MLP(in_dim=128, hidden_dim=128, out_dim=128)

        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()
        # self.pred.train()
        self.initialize_momentum_net()
        self.momentum_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        for param in self.momentum_net.parameters():
            param.requires_grad = False
        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.learning_rate,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)

        _, z_1, p_1 = self.online_net(aug_states_1, log=True)
        _, z_2, p_2 = self.online_net(aug_states_2, log=True)
        # p_1, p_2 = self.pred(z_1), self.pred(z_2)

        # with torch.no_grad():
        #     p_2 = self.pred(z_2)

        simsiam_loss = 2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2
        # simsiam_loss = p_1.mean() + p_2.mean()
        # simsiam_loss = p_1.mean() * 128
        # simsiam_loss = - F.cosine_similarity(p_1, z_2.detach(), dim=-1).mean()
        # print(simsiam_loss)
        # simsiam_loss = 0

        # _, z_target = self.momentum_net(aug_states_2, log=True) #z_k
        # z_proj = torch.matmul(self.online_net.W, z_target.T)
        # logits = torch.matmul(z_anch, z_proj)
        # logits = (logits - torch.max(logits, 1)[0][:, None])
        # logits = logits * 0.1
        # labels = torch.arange(logits.shape[0]).long().to(device=self.args.device)
        # moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_1: {z_1.shape}')
        # print(f'p_1: {p_1.shape}')
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        # loss = loss + (moco_loss * self.coeff)
        loss = loss + (simsiam_loss * self.coeff)
        self.online_net.zero_grad()
        # self.pred.zero_grad()
        curl_loss = (weights * loss).mean()
        # print(curl_loss)
        curl_loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def learn_old(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # print('\n\n---------------')
        # print(f'idxs: {idxs}, ')
        # print(f'states: {states.shape}, ')
        # print(f'actions: {actions.shape}, ')
        # print(f'returns: {returns.shape}, ')
        # print(f'next_states: {next_states.shape}, ')
        # print(f'nonterminals: {nonterminals.shape}, ')
        # print(f'weights: {weights.shape},')

        aug_states_1 = aug(states).to(device=self.args.device)
        aug_states_2 = aug(states).to(device=self.args.device)

        # print(f'aug_states_1: {aug_states_1.shape}')
        # print(f'aug_states_2: {aug_states_2.shape}')

        # Calculate current state probabilities (online network noise already sampled)
        log_ps, _, _ = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        _, z_anch, _ = self.online_net(aug_states_1, log=True)  #z_q
        _, z_target, _ = self.momentum_net(aug_states_2, log=True)  #z_k
        z_proj = torch.matmul(self.online_net.W, z_target.T)
        logits = torch.matmul(z_anch, z_proj)
        logits = (logits - torch.max(logits, 1)[0][:, None])
        logits = logits * 0.1
        labels = torch.arange(
            logits.shape[0]).long().to(device=self.args.device)
        moco_loss = (nn.CrossEntropyLoss()(logits,
                                           labels)).to(device=self.args.device)

        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        # print(f'z_anch: {z_anch.shape}')
        # print(f'z_target: {z_target.shape}')
        # print(f'z_proj: {z_proj.shape}')
        # print(f'logits: {logits.shape}')
        # print(logits)
        # print(f'labels: {labels.shape}')
        # print(labels)
        # print('---------------\n\n')

        # 1/0

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns, _, _ = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns, _, _ = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        print(moco_loss)
        loss = loss + (moco_loss * self.coeff)
        self.online_net.zero_grad()
        curl_loss = (weights * loss).mean()
        curl_loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        self.optimiser.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def initialize_momentum_net(self):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(param_q.data)  # update
            param_k.requires_grad = False  # not update by gradient

    # Code for this function from https://github.com/facebookresearch/moco
    @torch.no_grad()
    def update_momentum_net(self, momentum=0.999):
        for param_q, param_k in zip(self.online_net.parameters(),
                                    self.momentum_net.parameters()):
            param_k.data.copy_(momentum * param_k.data +
                               (1. - momentum) * param_q.data)  # update

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            a, _, _ = self.online_net(state.unsqueeze(0))
            return (a * self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
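
This variant mixes Rainbow with self-supervised representation learning: aug() is presumably a random image augmentation applied to the state batch, and the D(·, ·) used in the SimSiam loss is not defined in the snippet. A hypothetical sketch of the usual SimSiam similarity term, assuming negative cosine similarity with a stop-gradient on the target projection (the actual definition lives elsewhere in the curl_rainbow project):

import torch
import torch.nn.functional as F

def D(p, z):
    # SimSiam-style similarity: stop-gradient on z, negative cosine similarity against p
    z = z.detach()
    return -F.cosine_similarity(p, z, dim=-1).mean()

With a D of this form, 2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2 simply shifts the symmetric loss into a non-negative range before it is scaled by self.coeff.
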
Code Example #4
File: agent.py Project: lili-chen/SEER
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.norm_clip = args.norm_clip

        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model:  # Load pretrained model if provided
            if os.path.isfile(args.model):
                state_dict = torch.load(
                    args.model, map_location='cpu'
                )  # Always load tensors onto CPU by default, will shift to GPU if necessary
                if 'conv1.weight' in state_dict.keys():
                    for old_key, new_key in (('conv1.weight',
                                              'convs.0.weight'),
                                             ('conv1.bias', 'convs.0.bias'),
                                             ('conv2.weight',
                                              'convs.2.weight'),
                                             ('conv2.bias', 'convs.2.bias'),
                                             ('conv3.weight',
                                              'convs.4.weight'),
                                             ('conv3.bias', 'convs.4.bias')):
                        state_dict[new_key] = state_dict[
                            old_key]  # Re-map state dict for old pretrained models
                        del state_dict[
                            old_key]  # Delete old keys for strict load_state_dict
                self.online_net.load_state_dict(state_dict)
                print("Loading pretrained model: " + args.model)
            else:  # Raise error if incorrect model path provided
                raise FileNotFoundError(args.model)

        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        # self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps)
        self.convs_optimiser = optim.Adam(self.online_net.convs.parameters(),
                                          lr=args.learning_rate,
                                          eps=args.adam_eps)
        self.linear_optimiser = optim.Adam(chain(
            self.online_net.fc_h_v.parameters(),
            self.online_net.fc_h_a.parameters(),
            self.online_net.fc_z_v.parameters(),
            self.online_net.fc_z_a.parameters()),
                                           lr=args.learning_rate,
                                           eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):

        with torch.no_grad():
            # don't count these calls since it is accounted for after "action = dqn.act(state)" in main.py
            ret = (self.online_net(state.unsqueeze(0)) *
                   self.support).sum(2).argmax(1).item()
            return ret

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem, freeze=False):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, _ = mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        if not freeze:
            self.convs_optimiser.step()
        self.linear_optimiser.step()

    def learn_with_latent(self, latent_mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights, ns = latent_mem.sample(
            self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net.forward_with_latent(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)
        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net.forward_with_latent(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # use ns instead of self.n since n is possibly different for each sequence in the batch
            ns = torch.tensor(ns, device=latent_mem.device).unsqueeze(1)
            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**ns) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin,
                          max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        loss.mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        clip_grad_norm_(self.online_net.parameters(),
                        self.norm_clip)  # Clip gradients by L2 norm
        # self.optimiser.step()
        self.linear_optimiser.step()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path, name='model.pth'):
        torch.save(self.online_net.state_dict(), os.path.join(path, name))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
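
The distinguishing feature of this example is the pair of optimisers: one over the convolutional encoder and one over the dueling heads, so that learn(mem, freeze=True) keeps training the heads while leaving the encoder untouched. A minimal sketch of that pattern; the modules and hyperparameters below are illustrative stand-ins, not the project's DQN:

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

encoder = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(), nn.Flatten())  # 84x84 frames -> 32*20*20
heads = nn.ModuleList([nn.Linear(32 * 20 * 20, 64), nn.Linear(64, 6)])

conv_optimiser = optim.Adam(encoder.parameters(), lr=1e-4, eps=1.5e-4)
linear_optimiser = optim.Adam(itertools.chain(*(h.parameters() for h in heads)),
                              lr=1e-4, eps=1.5e-4)

def step(loss, freeze):
    conv_optimiser.zero_grad()
    linear_optimiser.zero_grad()
    loss.backward()
    if not freeze:
        conv_optimiser.step()   # the encoder is only updated while it is not frozen
    linear_optimiser.step()     # the heads keep learning either way

x = torch.randn(2, 4, 84, 84)
q = heads[1](torch.relu(heads[0](encoder(x))))
step(q.mean(), freeze=True)     # heads train; the conv encoder stays fixed
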
Code Example #5
File: agent.py Project: Ashutosh-Adhikari/Rainbow
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      args.atoms)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)
        if args.cuda:
            self.online_net.cuda()
            self.target_net.cuda()
            self.support = self.support.cuda()

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[1][0]

    # Acts with an ε-greedy policy
    def act_e_greedy(self, state, epsilon=0.001):
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities
        self.online_net.reset_noise()  # Sample new noise for online network
        ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
        ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

        # Calculate nth next state probabilities
        self.online_net.reset_noise()  # Sample new noise for action selection
        pns = self.online_net(
            next_states).data  # Probabilities p(s_t+n, ·; θonline)
        dns = self.support.expand_as(
            pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
        argmax_indices_ns = dns.sum(2).max(
            1
        )[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
        self.target_net.reset_noise()  # Sample new target net noise
        pns = self.target_net(
            next_states).data  # Probabilities p(s_t+n, ·; θtarget)
        pns_a = pns[range(
            self.batch_size
        ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (
            self.discount**self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin,
                      max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().long(), b.ceil().long()
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.data.new(self.batch_size, self.atoms).zero_()
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                self.batch_size).unsqueeze(1).expand(
                                    self.batch_size,
                                    self.atoms).type_as(actions)
        m.view(-1).index_add_(
            0, (l + offset).view(-1),
            (pns_a *
             (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(
            0, (u + offset).view(-1),
            (pns_a *
             (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
        loss = -torch.sum(
            Variable(m) * ps_a.log(),
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Importance weight losses
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.data)  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[0][0]

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
Code Example #6
File: agent.py Project: renly/Rainbow
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, args.atoms)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.priority_exponent = args.priority_exponent
    self.max_gradient_norm = args.max_gradient_norm

    self.policy_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.policy_net.load_state_dict(torch.load(args.model))
    self.policy_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.eval()

    self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr, eps=args.adam_eps)
    if args.cuda:
      self.policy_net.cuda()
      self.target_net.cuda()
      self.support = self.support.cuda()

  # Resets noisy weights in all linear layers (of policy and target nets)
  def reset_noise(self):
    self.policy_net.reset_noise()
    self.target_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[1][0]

  def learn(self, mem):
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    batch_size = len(idxs)  # May return less than specified if invalid transitions sampled

    # Calculate current state probabilities
    ps = self.policy_net(states)  # Probabilities p(s_t, ·; θpolicy)
    ps_a = ps[range(batch_size), actions]  # p(s_t, a_t; θpolicy)

    # Calculate nth next state probabilities
    pns = self.policy_net(next_states).data  # Probabilities p(s_t+n, ·; θpolicy)
    dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θpolicy))
    argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using policy network: argmax_a[(z, p(s_t+n, a; θpolicy))]
    pns = self.target_net(next_states).data  # Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
    pns_a *= nonterminals  # Set p = 0 for terminal nth next states as all possible expected returns = expected reward at final transition

    # Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
    # Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()

    # Distribute probability of Tz
    m = states.data.new(batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).long().unsqueeze(1).expand(batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  # Cross-entropy loss (minimises Kullback-Leibler divergence)
    self.policy_net.zero_grad()
    (weights * loss).mean().backward()  # Importance weight losses
    nn.utils.clip_grad_norm(self.policy_net.parameters(), self.max_gradient_norm)  # Clip gradients (normalising by max value of gradient L2 norm)
    self.optimiser.step()

    mem.update_priorities(idxs, loss.data.abs().pow(self.priority_exponent))  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def save(self, path):
    torch.save(self.policy_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[0][0]

  def train(self):
    self.policy_net.train()

  def eval(self):
    self.policy_net.eval()
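
Examples #5 and #6 target a pre-0.4 PyTorch API: Variable, .data, .max(1)[1][0] and nn.utils.clip_grad_norm (no trailing underscore) have since been deprecated or folded into Tensor. A hedged sketch of the equivalent gradient clipping on current PyTorch, with a small linear layer standing in for the project's DQN:

import torch
import torch.nn as nn

policy_net = nn.Linear(8, 4)                 # illustrative stand-in for the DQN
loss = policy_net(torch.randn(2, 8)).sum()
loss.backward()
nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=10.0)  # in-place successor of clip_grad_norm
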
Code Example #7
class Agent(object):
    """ all improvments from Rainbow research work
    """
    def __init__(self, args, state_size, action_size):
        """
        Args:
           param1 (args): args
           param2 (int): args
           param3 (int): args
        """
        self.action_size = action_size
        self.state_size = state_size
        self.atoms = args.atoms
        self.V_min = args.V_min
        self.V_max = args.V_max
        self.device = args.device
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(
            device=self.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.qnetwork_local = DQN(args, self.state_size,
                                  self.action_size).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.qnetwork_local.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.qnetwork_local.train()

        self.target_net = DQN(args, self.state_size,
                              self.action_size).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    def reset_noise(self):
        """ resets noisy weights in all linear layers """
        self.qnetwork_local.reset_noise()

    def act(self, state):
        """
          Acts greedily (argmax of expected Q-value) based on a single state (no batch).
          Args:
             state: current state tensor
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0).to(self.device)) *
                    self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        """ acts with epsilon greedy policy
            epsilon exploration vs exploitation traide off
        Args:
            param1(int): state
            param2(float): epsilon
        Return : action int number between 0 and 4
        """
        return np.random.randint(
            0, self.action_size) if np.random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        """ uses samples with the given batch size to improve the Q function
        Args:
            param1 (Experince Replay Buffer) : mem
        """
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.qnetwork_local(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size),
                          actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.qnetwork_local(
                next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(
                pns
            ) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(
                1
            )  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(
                self.batch_size
            ), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n
            ) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.V_min,
                          max=self.V_max)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.V_min) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (pns_a *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (pns_a *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_ps_a,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.qnetwork_local.zero_grad()
        (weights * loss).mean().backward(
        )  # Backpropagate importance-weighted minibatch loss
        self.optimizer.step()

        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # Update priorities of sampled transitions
        self.soft_update()

    def soft_update(self, tau=1e-3):
        """ swaps the network weights from the online to the target

        Args:
           param1 (float): tau
        """
        for target_param, local_param in zip(self.target_net.parameters(),
                                             self.qnetwork_local.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def update_target_net(self):
        """ copy the model weights from the online to the target network """
        self.target_net.load_state_dict(self.qnetwork_local.state_dict())

    def save(self, path):
        """ save the model weights to a file
        Args:
           param1 (string): pathname
        """
        torch.save(self.qnetwork_local.state_dict(),
                   os.path.join(path, 'model.pth'))

    def evaluate_q(self, state):
        """ Evaluates Q-value based on single state
        """
        with torch.no_grad():
            return (self.qnetwork_local(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        """
        activates the backprob. layers for the online network
        """
        self.qnetwork_local.train()

    def eval(self):
        """ invoke the eval from the online network
            deactivates the backprob
            layers like dropout will work in eval model instead
        """
        self.qnetwork_local.eval()
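
Unlike the other examples, this agent also performs a Polyak (soft) target update after every learning step, in addition to the hard copy in update_target_net(). A minimal sketch of that interpolation, with small linear layers standing in for the networks:

import copy
import torch
import torch.nn as nn

local_net = nn.Linear(4, 2)
target_net = copy.deepcopy(local_net)

def soft_update(target, local, tau=1e-3):
    with torch.no_grad():
        for t_param, l_param in zip(target.parameters(), local.parameters()):
            t_param.copy_(tau * l_param + (1.0 - tau) * t_param)  # θtarget ← τ·θlocal + (1−τ)·θtarget

soft_update(target_net, local_net)  # each call moves the target a fraction tau toward the online net
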
Code Example #8
File: agent.py Project: yyht/qait_public
class Agent:
    def __init__(self):
        self.mode = "train"
        with open("config.yaml") as reader:
            self.config = yaml.safe_load(reader)
        print(self.config)
        self.load_config()

        self.online_net = DQN(config=self.config,
                              word_vocab=self.word_vocab,
                              char_vocab=self.char_vocab,
                              answer_type=self.answer_type)
        self.target_net = DQN(config=self.config,
                              word_vocab=self.word_vocab,
                              char_vocab=self.char_vocab,
                              answer_type=self.answer_type)
        self.online_net.train()
        self.target_net.train()
        self.update_target_net()
        for param in self.target_net.parameters():
            param.requires_grad = False

        if self.use_cuda:
            self.online_net.cuda()
            self.target_net.cuda()

        self.naozi = ObservationPool(capacity=self.naozi_capacity)
        # optimizer
        self.optimizer = torch.optim.Adam(
            self.online_net.parameters(),
            lr=self.config['training']['optimizer']['learning_rate'])
        self.clip_grad_norm = self.config['training']['optimizer'][
            'clip_grad_norm']

    def load_config(self):
        # word vocab
        with open("vocabularies/word_vocab.txt") as f:
            self.word_vocab = f.read().split("\n")
        self.word2id = {}
        for i, w in enumerate(self.word_vocab):
            self.word2id[w] = i
        # char vocab
        with open("vocabularies/char_vocab.txt") as f:
            self.char_vocab = f.read().split("\n")
        self.char2id = {}
        for i, w in enumerate(self.char_vocab):
            self.char2id[w] = i

        self.EOS_id = self.word2id["</s>"]
        self.train_data_size = self.config['general']['train_data_size']
        self.question_type = self.config['general']['question_type']
        self.random_map = self.config['general']['random_map']
        self.testset_path = self.config['general']['testset_path']
        self.naozi_capacity = self.config['general']['naozi_capacity']
        self.eval_folder = pjoin(
            self.testset_path, self.question_type,
            ("random_map" if self.random_map else "fixed_map"))
        self.eval_data_path = pjoin(self.testset_path, "data.json")

        self.batch_size = self.config['training']['batch_size']
        self.max_nb_steps_per_episode = self.config['training'][
            'max_nb_steps_per_episode']
        self.max_episode = self.config['training']['max_episode']
        self.target_net_update_frequency = self.config['training'][
            'target_net_update_frequency']
        self.learn_start_from_this_episode = self.config['training'][
            'learn_start_from_this_episode']

        self.run_eval = self.config['evaluate']['run_eval']
        self.eval_batch_size = self.config['evaluate']['batch_size']
        self.eval_max_nb_steps_per_episode = self.config['evaluate'][
            'max_nb_steps_per_episode']

        # Set the random seed manually for reproducibility.
        self.random_seed = self.config['general']['random_seed']
        np.random.seed(self.random_seed)
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            if not self.config['general']['use_cuda']:
                print(
                    "WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml"
                )
                self.use_cuda = False
            else:
                torch.backends.cudnn.deterministic = True
                torch.cuda.manual_seed(self.random_seed)
                self.use_cuda = True
        else:
            self.use_cuda = False

        if self.question_type == "location":
            self.answer_type = "pointing"
        elif self.question_type in ["attribute", "existence"]:
            self.answer_type = "2 way"
        else:
            raise NotImplementedError

        self.save_checkpoint = self.config['checkpoint']['save_checkpoint']
        self.experiment_tag = self.config['checkpoint']['experiment_tag']
        self.save_frequency = self.config['checkpoint']['save_frequency']
        self.load_pretrained = self.config['checkpoint']['load_pretrained']
        self.load_from_tag = self.config['checkpoint']['load_from_tag']

        self.qa_loss_lambda = self.config['training']['qa_loss_lambda']
        self.interaction_loss_lambda = self.config['training'][
            'interaction_loss_lambda']

        # replay buffer and updates
        self.discount_gamma = self.config['replay']['discount_gamma']
        self.replay_batch_size = self.config['replay']['replay_batch_size']
        self.command_generation_replay_memory = command_generation_memory.PrioritizedReplayMemory(
            self.config['replay']['replay_memory_capacity'],
            priority_fraction=self.config['replay']
            ['replay_memory_priority_fraction'],
            discount_gamma=self.discount_gamma)
        self.qa_replay_memory = qa_memory.PrioritizedReplayMemory(
            self.config['replay']['replay_memory_capacity'],
            priority_fraction=0.0)
        self.update_per_k_game_steps = self.config['replay'][
            'update_per_k_game_steps']
        self.multi_step = self.config['replay']['multi_step']

        # distributional RL
        self.use_distributional = self.config['distributional']['enable']
        self.atoms = self.config['distributional']['atoms']
        self.v_min = self.config['distributional']['v_min']
        self.v_max = self.config['distributional']['v_max']
        self.support = torch.linspace(self.v_min, self.v_max,
                                      self.atoms)  # Support (range) of z
        if self.use_cuda:
            self.support = self.support.cuda()
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)

        # dueling networks
        self.dueling_networks = self.config['dueling_networks']

        # double dqn
        self.double_dqn = self.config['double_dqn']

        # counting reward
        self.revisit_counting_lambda_anneal_episodes = self.config[
            'episodic_counting_bonus'][
                'revisit_counting_lambda_anneal_episodes']
        self.revisit_counting_lambda_anneal_from = self.config[
            'episodic_counting_bonus']['revisit_counting_lambda_anneal_from']
        self.revisit_counting_lambda_anneal_to = self.config[
            'episodic_counting_bonus']['revisit_counting_lambda_anneal_to']
        self.revisit_counting_lambda = self.revisit_counting_lambda_anneal_from

        # valid command bonus
        self.valid_command_bonus_lambda = self.config[
            'valid_command_bonus_lambda']

        # epsilon greedy
        self.epsilon_anneal_episodes = self.config['epsilon_greedy'][
            'epsilon_anneal_episodes']
        self.epsilon_anneal_from = self.config['epsilon_greedy'][
            'epsilon_anneal_from']
        self.epsilon_anneal_to = self.config['epsilon_greedy'][
            'epsilon_anneal_to']
        self.epsilon = self.epsilon_anneal_from
        self.noisy_net = self.config['epsilon_greedy']['noisy_net']
        if self.noisy_net:
            # disable epsilon greedy
            self.epsilon_anneal_episodes = -1
            self.epsilon = 0.0

        self.nlp = spacy.load('en', disable=['ner', 'parser', 'tagger'])
        self.single_word_verbs = set(["inventory", "look", "wait"])
        self.two_word_verbs = set(["go"])

    def train(self):
        """
        Tell the agent that it is in the training phase.
        """
        self.mode = "train"
        self.online_net.train()

    def eval(self):
        """
        Tell the agent that it is in the evaluation phase.
        """
        self.mode = "eval"
        self.online_net.eval()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def reset_noise(self):
        if self.noisy_net:
            # Resets noisy weights in all linear layers (of online net only)
            self.online_net.reset_noise()

    def zero_noise(self):
        if self.noisy_net:
            self.online_net.zero_noise()
            self.target_net.zero_noise()

    def load_pretrained_model(self, load_from):
        """
        Load pretrained checkpoint from file.

        Arguments:
            load_from: File name of the pretrained model checkpoint.
        """
        print("loading model from %s\n" % (load_from))
        try:
            if self.use_cuda:
                state_dict = torch.load(load_from)
            else:
                state_dict = torch.load(load_from, map_location='cpu')
            self.online_net.load_state_dict(state_dict)
        except Exception:
            print("Failed to load checkpoint...")

    def save_model_to_path(self, save_to):
        torch.save(self.online_net.state_dict(), save_to)
        print("Saved checkpoint to %s..." % (save_to))

    def init(self, obs, infos):
        """
        Prepare the agent for the upcoming games.

        Arguments:
            obs: Previous command's feedback for each game.
            infos: Additional information for each game.
        """
        # reset agent, get vocabulary masks for verbs / adjectives / nouns
        batch_size = len(obs)
        self.reset_binarized_counter(batch_size)
        self.not_finished_yet = np.ones((batch_size, ), dtype="float32")
        self.prev_actions = [["" for _ in range(batch_size)]]
        self.prev_step_is_still_interacting = np.ones(
            (batch_size, ), dtype="float32"
        )  # all 1s; becomes 0 once the previous action is "wait"
        self.naozi.reset(batch_size=batch_size)

    def get_agent_inputs(self, string_list):
        sentence_token_list = [item.split() for item in string_list]
        sentence_id_list = [
            _words_to_ids(tokens, self.word2id)
            for tokens in sentence_token_list
        ]
        input_sentence_char = list_of_token_list_to_char_input(
            sentence_token_list, self.char2id)
        input_sentence = pad_sequences(
            sentence_id_list, maxlen=max_len(sentence_id_list)).astype('int32')
        input_sentence = to_pt(input_sentence, self.use_cuda)
        input_sentence_char = to_pt(input_sentence_char, self.use_cuda)
        return input_sentence, input_sentence_char, sentence_id_list

    def get_game_info_at_certain_step(self, obs, infos):
        """
        Get all needed info from game engine for training.
        Arguments:
            obs: Previous command's feedback for each game.
            infos: Additional information for each game.
        """
        batch_size = len(obs)
        feedback_strings = [preproc(item, tokenizer=self.nlp) for item in obs]
        description_strings = [
            preproc(item, tokenizer=self.nlp) for item in infos["description"]
        ]
        observation_strings = [
            d + " <|> " + fb if fb != d else d + " <|> hello"
            for fb, d in zip(feedback_strings, description_strings)
        ]

        inventory_strings = [
            preproc(item, tokenizer=self.nlp) for item in infos["inventory"]
        ]
        local_word_list = [
            obs.split() + inv.split()
            for obs, inv in zip(observation_strings, inventory_strings)
        ]

        directions = ["east", "west", "north", "south"]
        if self.question_type in ["location", "existence"]:
            # the agent only observes the environment and does not change it
            possible_verbs = [["go", "inventory", "wait", "open", "examine"]
                              for _ in range(batch_size)]
        else:
            possible_verbs = [
                list(set(item) - set(["", "look"])) for item in infos["verbs"]
            ]

        possible_adjs, possible_nouns = [], []
        for i in range(batch_size):
            object_nouns = [
                item.split()[-1] for item in infos["object_nouns"][i]
            ]
            object_adjs = [
                w for item in infos["object_adjs"][i] for w in item.split()
            ]
            possible_nouns.append(
                list(set(object_nouns) & set(local_word_list[i]) - set([""])) +
                directions)
            possible_adjs.append(
                list(set(object_adjs) & set(local_word_list[i]) - set([""])) +
                ["</s>"])

        return observation_strings, [
            possible_verbs, possible_adjs, possible_nouns
        ]

    def get_state_strings(self, infos):
        description_strings = infos["description"]
        inventory_strings = infos["inventory"]
        observation_strings = [
            _d + _i for (_d, _i) in zip(description_strings, inventory_strings)
        ]
        return observation_strings

    def get_local_word_masks(self, possible_words):
        possible_verbs, possible_adjs, possible_nouns = possible_words
        batch_size = len(possible_verbs)

        verb_mask = np.zeros((batch_size, len(self.word_vocab)),
                             dtype="float32")
        noun_mask = np.zeros((batch_size, len(self.word_vocab)),
                             dtype="float32")
        adj_mask = np.zeros((batch_size, len(self.word_vocab)),
                            dtype="float32")
        for i in range(batch_size):
            for w in possible_verbs[i]:
                if w in self.word2id:
                    verb_mask[i][self.word2id[w]] = 1.0
            for w in possible_adjs[i]:
                if w in self.word2id:
                    adj_mask[i][self.word2id[w]] = 1.0
            for w in possible_nouns[i]:
                if w in self.word2id:
                    noun_mask[i][self.word2id[w]] = 1.0
        adj_mask[:, self.EOS_id] = 1.0

        return [verb_mask, adj_mask, noun_mask]

    def get_match_representations(self,
                                  input_observation,
                                  input_observation_char,
                                  input_quest,
                                  input_quest_char,
                                  use_model="online"):
        model = self.online_net if use_model == "online" else self.target_net
        description_representation_sequence, description_mask = model.representation_generator(
            input_observation, input_observation_char)
        quest_representation_sequence, quest_mask = model.representation_generator(
            input_quest, input_quest_char)

        match_representation_sequence = model.get_match_representations(
            description_representation_sequence, description_mask,
            quest_representation_sequence, quest_mask)
        match_representation_sequence = match_representation_sequence * description_mask.unsqueeze(
            -1)
        return match_representation_sequence

    def get_ranks(self,
                  input_observation,
                  input_observation_char,
                  input_quest,
                  input_quest_char,
                  word_masks,
                  use_model="online"):
        """
        Given input observation and question tensors, compute Q-values over words.
        """
        model = self.online_net if use_model == "online" else self.target_net
        match_representation_sequence = self.get_match_representations(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            use_model=use_model)
        action_ranks = model.action_scorer(match_representation_sequence,
                                           word_masks)  # list of 3 tensors
        return action_ranks

    def choose_maxQ_command(self, action_ranks, word_mask=None):
        """
        Generate a command by maximum q values, for epsilon greedy.
        """
        if self.use_distributional:
            action_ranks = [
                (item * self.support).sum(2) for item in action_ranks
            ]  # list of batch x n_vocab
        action_indices = []
        for i in range(len(action_ranks)):
            ar = action_ranks[i]
            ar = ar - torch.min(
                ar, -1, keepdim=True
            )[0] + 1e-2  # minus the min value, so that all values are non-negative
            if word_mask is not None:
                assert word_mask[i].size() == ar.size(), (
                    word_mask[i].size(), ar.size())
                ar = ar * word_mask[i]
            action_indices.append(torch.argmax(ar, -1))  # batch
        return action_indices

    def choose_random_command(self,
                              batch_size,
                              action_space_size,
                              possible_words=None):
        """
        Generate a command randomly, for epsilon greedy.
        """
        action_indices = []
        for i in range(3):
            if possible_words is None:
                indices = np.random.choice(action_space_size, batch_size)
            else:
                indices = []
                for j in range(batch_size):
                    mask_ids = []
                    for w in possible_words[i][j]:
                        if w in self.word2id:
                            mask_ids.append(self.word2id[w])
                    indices.append(np.random.choice(mask_ids))
                indices = np.array(indices)
            action_indices.append(to_pt(indices, self.use_cuda))  # batch
        return action_indices

    def get_chosen_strings(self, chosen_indices):
        """
        Turns list of word indices into actual command strings.
        chosen_indices: Word indices chosen by model.
        """
        chosen_indices_np = [to_np(item) for item in chosen_indices]
        res_str = []
        batch_size = chosen_indices_np[0].shape[0]
        for i in range(batch_size):
            verb, adj, noun = chosen_indices_np[0][i], chosen_indices_np[1][
                i], chosen_indices_np[2][i]
            res_str.append(self.word_ids_to_commands(verb, adj, noun))
        return res_str

    def word_ids_to_commands(self, verb, adj, noun):
        """
        Turn the 3 indices into actual command strings.

        Arguments:
            verb: Vocabulary index of the chosen verb.
            adj: Vocabulary index of the chosen adjective.
            noun: Vocabulary index of the chosen noun.
        """
        # turns 3 indices into actual command strings
        if self.word_vocab[verb] in self.single_word_verbs:
            return self.word_vocab[verb]
        if self.word_vocab[verb] in self.two_word_verbs:
            return " ".join([self.word_vocab[verb], self.word_vocab[noun]])
        if adj == self.EOS_id:
            return " ".join([self.word_vocab[verb], self.word_vocab[noun]])
        else:
            return " ".join([
                self.word_vocab[verb], self.word_vocab[adj],
                self.word_vocab[noun]
            ])

    def act_random(self, obs, infos, input_observation, input_observation_char,
                   input_quest, input_quest_char, possible_words):
        with torch.no_grad():
            batch_size = len(obs)
            word_indices_random = self.choose_random_command(
                batch_size, len(self.word_vocab), possible_words)
            chosen_indices = word_indices_random
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting (DQN needs one extra step of computation)
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def act_greedy(self, obs, infos, input_observation, input_observation_char,
                   input_quest, input_quest_char, possible_words):
        """
        Acts upon the current list of observations.
        One text command must be returned for each observation.
        """
        with torch.no_grad():
            batch_size = len(obs)
            local_word_masks_np = self.get_local_word_masks(possible_words)
            local_word_masks = [
                to_pt(item, self.use_cuda, type="float")
                for item in local_word_masks_np
            ]

            # generate commands greedily for one game step (no exploration in this method)
            action_ranks = self.get_ranks(
                input_observation,
                input_observation_char,
                input_quest,
                input_quest_char,
                local_word_masks,
                use_model="online")  # list of batch x vocab
            word_indices_maxq = self.choose_maxQ_command(
                action_ranks, local_word_masks)
            chosen_indices = word_indices_maxq
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting (DQN needs one extra step of computation)
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def act(self,
            obs,
            infos,
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            possible_words,
            random=False):
        """
        Acts upon the current list of observations.
        One text command must be returned for each observation.
        """
        with torch.no_grad():
            if self.mode == "eval":
                return self.act_greedy(obs, infos, input_observation,
                                       input_observation_char, input_quest,
                                       input_quest_char, possible_words)
            if random:
                return self.act_random(obs, infos, input_observation,
                                       input_observation_char, input_quest,
                                       input_quest_char, possible_words)
            batch_size = len(obs)

            local_word_masks_np = self.get_local_word_masks(possible_words)
            local_word_masks = [
                to_pt(item, self.use_cuda, type="float")
                for item in local_word_masks_np
            ]

            # generate commands for one game step; epsilon greedy is applied, i.e.,
            # there is an epsilon chance of generating random commands
            action_ranks = self.get_ranks(
                input_observation,
                input_observation_char,
                input_quest,
                input_quest_char,
                local_word_masks,
                use_model="online")  # list of batch x vocab
            word_indices_maxq = self.choose_maxQ_command(
                action_ranks, local_word_masks)
            word_indices_random = self.choose_random_command(
                batch_size, len(self.word_vocab), possible_words)

            # random number for epsilon greedy
            rand_num = np.random.uniform(low=0.0,
                                         high=1.0,
                                         size=(batch_size, ))
            less_than_epsilon = (rand_num < self.epsilon).astype(
                "float32")  # batch
            greater_than_epsilon = 1.0 - less_than_epsilon
            less_than_epsilon = to_pt(less_than_epsilon,
                                      self.use_cuda,
                                      type='long')
            greater_than_epsilon = to_pt(greater_than_epsilon,
                                         self.use_cuda,
                                         type='long')
            chosen_indices = [
                less_than_epsilon * idx_random +
                greater_than_epsilon * idx_maxq
                for idx_random, idx_maxq in zip(word_indices_random,
                                                word_indices_maxq)
            ]
            chosen_strings = self.get_chosen_strings(chosen_indices)

            for i in range(batch_size):
                if chosen_strings[i] == "wait":
                    self.not_finished_yet[i] = 0.0

            # info for replay memory
            for i in range(batch_size):
                if self.prev_actions[-1][i] == "wait":
                    self.prev_step_is_still_interacting[i] = 0.0
            # whether the previous step was still interacting (DQN needs one extra step of computation)
            replay_info = [
                chosen_indices,
                to_pt(self.prev_step_is_still_interacting, self.use_cuda,
                      "float")
            ]

            # cache new info in current game step into caches
            self.prev_actions.append(chosen_strings)
            return chosen_strings, replay_info

    def get_dqn_loss(self):
        """
        Update the neural model in the agent. In this example we follow the
        algorithm of updating a DQN model with replay memory.
        """
        if len(self.command_generation_replay_memory) < self.replay_batch_size:
            return None

        data = self.command_generation_replay_memory.get_batch(
            self.replay_batch_size, self.multi_step)
        if data is None:
            return None

        obs_list, quest_list, possible_words_list, chosen_indices, rewards, next_obs_list, next_possible_words_list, actual_n_list = data
        batch_size = len(actual_n_list)

        input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list)
        input_observation, input_observation_char, _ = self.get_agent_inputs(
            obs_list)
        next_input_observation, next_input_observation_char, _ = self.get_agent_inputs(
            next_obs_list)

        possible_words, next_possible_words = [], []
        for i in range(3):
            possible_words.append([item[i] for item in possible_words_list])
            next_possible_words.append(
                [item[i] for item in next_possible_words_list])

        local_word_masks = [
            to_pt(item, self.use_cuda, type="float")
            for item in self.get_local_word_masks(possible_words)
        ]
        next_local_word_masks = [
            to_pt(item, self.use_cuda, type="float")
            for item in self.get_local_word_masks(next_possible_words)
        ]

        action_ranks = self.get_ranks(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            local_word_masks,
            use_model="online"
        )  # list of batch x vocab or list of batch x vocab x atoms
        # ps_a
        word_qvalues = [
            ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1)
            for w_rank, idx in zip(action_ranks, chosen_indices)
        ]  # list of batch or list of batch x atoms
        q_value = torch.mean(torch.stack(word_qvalues, -1),
                             -1)  # batch or batch x atoms
        # log_ps_a
        log_q_value = torch.log(q_value)  # batch or batch x atoms

        with torch.no_grad():
            if self.noisy_net:
                self.target_net.reset_noise()  # Sample new target net noise
            if self.double_dqn:
                # pns Probabilities p(s_t+n, ·; θonline)
                next_action_ranks = self.get_ranks(next_input_observation,
                                                   next_input_observation_char,
                                                   input_quest,
                                                   input_quest_char,
                                                   next_local_word_masks,
                                                   use_model="online")
                # list of batch x vocab or list of batch x vocab x atoms
                # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
                next_word_indices = self.choose_maxQ_command(
                    next_action_ranks,
                    next_local_word_masks)  # list of batch x 1
                # pns # Probabilities p(s_t+n, ·; θtarget)
                next_action_ranks = self.get_ranks(
                    next_input_observation,
                    next_input_observation_char,
                    input_quest,
                    input_quest_char,
                    next_local_word_masks,
                    use_model="target"
                )  # batch x vocab or list of batch x vocab x atoms
                # pns_a # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
                next_word_qvalues = [
                    ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for
                    w_rank, idx in zip(next_action_ranks, next_word_indices)
                ]  # list of batch or list of batch x atoms
            else:
                # pns Probabilities p(s_t+n, ·; θonline)
                next_action_ranks = self.get_ranks(next_input_observation,
                                                   next_input_observation_char,
                                                   input_quest,
                                                   input_quest_char,
                                                   next_local_word_masks,
                                                   use_model="target")
                # list of batch x vocab or list of batch x vocab x atoms
                next_word_indices = self.choose_maxQ_command(
                    next_action_ranks,
                    next_local_word_masks)  # list of batch x 1
                next_word_qvalues = [
                    ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for
                    w_rank, idx in zip(next_action_ranks, next_word_indices)
                ]  # list of batch or list of batch x atoms

            next_q_value = torch.mean(torch.stack(next_word_qvalues, -1),
                                      -1)  # batch or batch x atoms
            # Compute Tz (Bellman operator T applied to z)
            discount = to_pt((np.ones_like(actual_n_list) *
                              self.discount_gamma)**actual_n_list,
                             self.use_cuda,
                             type="float")
        if not self.use_distributional:
            rewards = rewards + next_q_value * discount  # batch
            loss = F.smooth_l1_loss(q_value, rewards)
            return loss

        with torch.no_grad():
            Tz = rewards.unsqueeze(
                -1) + discount.unsqueeze(-1) * self.support.unsqueeze(
                    0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.v_min,
                          max=self.v_max)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.v_min) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = torch.zeros(batch_size, self.atoms).float()
            if self.use_cuda:
                m = m.cuda()
            offset = torch.linspace(0, ((batch_size - 1) * self.atoms),
                                    batch_size).unsqueeze(1).expand(
                                        batch_size, self.atoms).long()
            if self.use_cuda:
                offset = offset.cuda()
            m.view(-1).index_add_(
                0, (l + offset).view(-1),
                (next_q_value *
                 (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(
                0, (u + offset).view(-1),
                (next_q_value *
                 (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(
            m * log_q_value,
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        loss = torch.mean(loss)
        return loss

    def update_interaction(self):
        # update neural model by replaying snapshots in replay memory
        interaction_loss = self.get_dqn_loss()
        if interaction_loss is None:
            return None
        loss = interaction_loss * self.interaction_loss_lambda
        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(),
                                       self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(torch.mean(interaction_loss))

    def answer_question(self,
                        input_observation,
                        input_observation_char,
                        observation_id_list,
                        input_quest,
                        input_quest_char,
                        use_model="online"):
        # first pad answerer_input, and get the mask
        model = self.online_net if use_model == "online" else self.target_net
        batch_size = len(observation_id_list)
        max_length = input_observation.size(1)
        mask = compute_mask(input_observation)  # batch x obs_len

        # noun mask for location question
        if self.question_type in ["location"]:
            location_mask = []
            for i in range(batch_size):
                m = [1 for item in observation_id_list[i]]
                location_mask.append(m)
            location_mask = pad_sequences(location_mask,
                                          maxlen=max_length,
                                          dtype="float32")
            location_mask = to_pt(location_mask,
                                  enable_cuda=self.use_cuda,
                                  type='float')
            assert mask.size() == location_mask.size()
            mask = mask * location_mask

        match_representation_sequence = self.get_match_representations(
            input_observation,
            input_observation_char,
            input_quest,
            input_quest_char,
            use_model=use_model)
        pred = model.answer_question(match_representation_sequence,
                                     mask)  # batch x vocab or batch x 2

        # attention sum:
        # sometimes a certain word appears multiple times in the observation,
        # so we merge (sum) its scores before doing further computations.
        # However, if the answer type is not pointing, we just use a pre-defined
        # mapping that maps 0/1 to their positions in the vocab.
        if self.answer_type == "2 way":
            observation_id_list = []
            max_length = 2
            for i in range(batch_size):
                observation_id_list.append(
                    [self.word2id["0"], self.word2id["1"]])

        observation = to_pt(
            pad_sequences(observation_id_list,
                          maxlen=max_length).astype('int32'), self.use_cuda)
        vocab_distribution = np.zeros(
            (batch_size, len(self.word_vocab)))  # batch x vocab
        vocab_distribution = to_pt(vocab_distribution,
                                   self.use_cuda,
                                   type='float')
        vocab_distribution = vocab_distribution.scatter_add_(
            1, observation, pred)  # batch x vocab
        non_zero_words = []
        for i in range(batch_size):
            non_zero_words.append(list(set(observation_id_list[i])))
        vocab_mask = torch.ne(vocab_distribution, 0).float()
        return vocab_distribution, non_zero_words, vocab_mask

    def point_maxq_position(self, vocab_distribution, mask):
        """
        Generate a command by maximum q values, for epsilon greedy.

        Arguments:
            point_distribution: Q values for each position (mapped to vocab).
            mask: vocab masks.
        """
        vocab_distribution = vocab_distribution - torch.min(
            vocab_distribution, -1, keepdim=True
        )[0] + 1e-2  # minus the min value, so that all values are non-negative
        vocab_distribution = vocab_distribution * mask  # batch x vocab
        indices = torch.argmax(vocab_distribution, -1)  # batch
        return indices

    def answer_question_act_greedy(self, input_observation,
                                   input_observation_char, observation_id_list,
                                   input_quest, input_quest_char):

        with torch.no_grad():
            vocab_distribution, _, vocab_mask = self.answer_question(
                input_observation,
                input_observation_char,
                observation_id_list,
                input_quest,
                input_quest_char,
                use_model="online")  # batch x time
            positions_maxq = self.point_maxq_position(vocab_distribution,
                                                      vocab_mask)
            return positions_maxq  # batch

    def get_qa_loss(self):
        """
        Update the neural model in the agent. In this example we follow the
        algorithm of updating a DQN model with replay memory.
        """
        if len(self.qa_replay_memory) < self.replay_batch_size:
            return None
        transitions = self.qa_replay_memory.sample(self.replay_batch_size)
        batch = qa_memory.qa_Transition(*zip(*transitions))

        observation_list = batch.observation_list
        quest_list = batch.quest_list
        answer_strings = batch.answer_strings
        answer_position = np.array(_words_to_ids(answer_strings, self.word2id))
        groundtruth = to_pt(answer_position, self.use_cuda)  # batch

        input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list)
        input_observation, input_observation_char, observation_id_list = self.get_agent_inputs(
            observation_list)

        answer_distribution, _, _ = self.answer_question(
            input_observation,
            input_observation_char,
            observation_id_list,
            input_quest,
            input_quest_char,
            use_model="online")  # batch x vocab

        batch_loss = NegativeLogLoss(answer_distribution, groundtruth)  # batch
        return torch.mean(batch_loss)

    def update_qa(self):
        # update neural model by replaying snapshots in replay memory
        qa_loss = self.get_qa_loss()
        if qa_loss is None:
            return None
        loss = qa_loss * self.qa_loss_lambda
        # Backpropagate
        self.online_net.zero_grad()
        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(),
                                       self.clip_grad_norm)
        self.optimizer.step()  # apply gradients
        return to_np(torch.mean(qa_loss))

    def finish_of_episode(self, episode_no, batch_size):
        # Update target network
        if (
                episode_no + batch_size
        ) % self.target_net_update_frequency <= episode_no % self.target_net_update_frequency:
            self.update_target_net()
        # decay lambdas
        if episode_no < self.learn_start_from_this_episode:
            return
        if episode_no < self.epsilon_anneal_episodes + self.learn_start_from_this_episode:
            self.epsilon -= (self.epsilon_anneal_from - self.epsilon_anneal_to
                             ) / float(self.epsilon_anneal_episodes)
            self.epsilon = max(self.epsilon, 0.0)
        if episode_no < self.revisit_counting_lambda_anneal_episodes + self.learn_start_from_this_episode:
            self.revisit_counting_lambda -= (
                self.revisit_counting_lambda_anneal_from -
                self.revisit_counting_lambda_anneal_to) / float(
                    self.revisit_counting_lambda_anneal_episodes)
            self.revisit_counting_lambda = max(self.revisit_counting_lambda, 0.0)

    def reset_binarized_counter(self, batch_size):
        self.binarized_counter_dict = [{} for _ in range(batch_size)]

    def get_binarized_count(self, observation_strings, update=True):
        count_rewards = []
        batch_size = len(observation_strings)
        for i in range(batch_size):
            key = observation_strings[i]
            if key not in self.binarized_counter_dict[i]:
                self.binarized_counter_dict[i][key] = 0.0
            if update:
                self.binarized_counter_dict[i][key] += 1.0
            r = self.binarized_counter_dict[i][key]
            r = float(r == 1.0)
            count_rewards.append(r)
        return count_rewards
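The categorical update inside get_dqn_loss above (clamping Tz to [v_min, v_max] and splitting probability mass between the floor and ceil atoms) is the standard C51 projection. Below is a minimal standalone sketch of just that projection step; the batch size, rewards and next-state distribution are random placeholders, not values from the agent above.

import torch

# Minimal sketch of the C51 projection step (toy sizes; rewards/next_probs are hypothetical).
atoms, v_min, v_max = 51, -10.0, 10.0
batch_size, discount = 4, 0.99
support = torch.linspace(v_min, v_max, atoms)                    # fixed support z
delta_z = (v_max - v_min) / (atoms - 1)

rewards = torch.rand(batch_size)                                 # stand-in n-step returns
next_probs = torch.softmax(torch.randn(batch_size, atoms), -1)   # stand-in p(s_t+n, a*)

Tz = rewards.unsqueeze(-1) + discount * support.unsqueeze(0)     # Tz = R^n + (γ^n)z
Tz = Tz.clamp(min=v_min, max=v_max)                              # clamp to supported range
b = (Tz - v_min) / delta_z                                       # b = (Tz - Vmin) / Δz
l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
l[(u > 0) * (l == u)] -= 1                                       # fix disappearing mass when l == b == u
u[(l < (atoms - 1)) * (l == u)] += 1

m = torch.zeros(batch_size, atoms)
offset = torch.linspace(0, (batch_size - 1) * atoms,
                        batch_size).unsqueeze(1).expand(batch_size, atoms).long()
m.view(-1).index_add_(0, (l + offset).view(-1), (next_probs * (u.float() - b)).view(-1))
m.view(-1).index_add_(0, (u + offset).view(-1), (next_probs * (b - l.float())).view(-1))
# each row of m now holds the projected target distribution and sums to 1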
コード例 #9
0
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = np.linspace(args.V_min, args.V_max,
                                   self.atoms)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_steps
        self.discount = args.discount
        self.norm_clip = args.max_norm_clip

        self.sess = tf.Session()

        with tf.variable_scope("online_net"):
            self.online_net = DQN(args, self.action_space)

        self.online_net.train()

        with tf.variable_scope("target_net"):
            self.target_net = DQN(args, self.action_space)
        self.target_net.train()

        self.sess.run(tf.global_variables_initializer())

        self.saver = tf.train.Saver()
        if tf.gfile.Exists("./models/model.ckpt"):
            self.saver.restore(self.sess, "./models/model.ckpt")

        online_net_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                 scope="online_net")
        target_net_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                 scope="target_net")
        update_target_op = []
        for var, var_target in zip(
                sorted(online_net_func_vars, key=lambda v: v.name),
                sorted(target_net_func_vars, key=lambda v: v.name)):
            update_target_op.append(var_target.assign(var))
        self.update_target_op = tf.group(*update_target_op)

        self.update_target_net()

        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=args.learning_rate, epsilon=args.adam_eps)

    def forward(self, network, inputs, log=False):
        if log:
            output = self.sess.run(network.action_log,
                                   feed_dict={network.inputs: inputs})
            return output
        else:
            output = self.sess.run(network.action,
                                   feed_dict={network.inputs: inputs})
            return output

    def reset_noise(self):
        self.online_net.reset_noise()

    def act(self, state):
        return np.argmax(
            np.sum(
                (self.forward(self.online_net, state.reshape(1, 84, 84, 4)) *
                 self.support),
                axis=-1))

    def act_e_greedy(self, state, epsilon=0.001):
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        log_ps = self.forward(self.online_net, states, log=True)
        log_ps_a = []
        for i in range(self.batch_size):
            log_ps_a.append(log_ps[i][actions[i]])

        pns = self.forward(self.online_net, next_states)
        dns = np.broadcast_to(self.support, (self.action_space, self.atoms))
        dns = np.multiply(
            np.broadcast_to(dns,
                            (self.batch_size, self.action_space, self.atoms)),
            pns)
        argmax_indices_ns = np.argmax(np.sum(dns, axis=2), axis=1)
        self.target_net.reset_noise()
        pns = self.forward(self.target_net, next_states)
        pns_a = pns[range(self.batch_size), argmax_indices_ns]

        Tz = np.expand_dims(
            returns, axis=1) + (self.discount**self.n) * np.multiply(
                nonterminals, np.expand_dims(self.support, axis=0))
        Tz = np.clip(Tz, self.Vmin, self.Vmax)
        b = (Tz - self.Vmin) / self.delta_z
        l, u = np.floor(b).astype(dtype=np.int64), np.ceil(b).astype(
            dtype=np.int64)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        m = np.zeros([self.batch_size, self.atoms], dtype=states.dtype)
        offset = np.broadcast_to(
            np.expand_dims(np.linspace(0, ((self.batch_size - 1) * self.atoms),
                                       self.batch_size),
                           axis=1),
            (self.batch_size, self.atoms)).astype(actions.dtype)
        # np.ndarray.flatten() returns a copy, so scatter-add into a flat view of m instead
        m_flat = m.reshape(-1)
        np.add.at(m_flat, (l + offset).flatten(),
                  (pns_a * (u.astype(np.float32) - b)).flatten())
        np.add.at(m_flat, (u + offset).flatten(),
                  (pns_a * (b - l.astype(np.float32))).flatten())

        loss = -np.sum(m * log_ps_a, 1)  # per-sample cross-entropy loss
        loss = weights * loss  # weight by importance-sampling weights from prioritised replay
        return loss  # this example only computes the loss; no gradient step is applied here

    def update_target_net(self):
        self.sess.run(self.update_target_op)

    def save(self, path):
        self.save_path = self.saver.save(self.sess, "./models/model.ckpt")

    def evaluate_q(self, state):
        return np.sum(
            (self.forward(self.online_net, state.reshape(1, 84, 84, 4)) *
             self.support),
            axis=-1).max(axis=1)[0]

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
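A note on the scatter-add in learn() above: np.add.at updates the array it is given in place, so the target has to be a view of m (e.g. reshape(-1)); np.ndarray.flatten() would hand it a copy and leave m untouched. A minimal sketch with a toy array:

import numpy as np

m = np.zeros((2, 3), dtype=np.float32)

np.add.at(m.flatten(), [0, 4], 1.0)    # adds into a throwaway copy
print(m.sum())                         # 0.0 -- m is unchanged

np.add.at(m.reshape(-1), [0, 4], 1.0)  # reshape(-1) is a view for contiguous arrays
print(m.sum())                         # 2.0 -- m is updated in place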
コード例 #10
0
ファイル: agent.py プロジェクト: cg31/cule
class Agent():
  def __init__(self, args, action_space):
      self.action_space = action_space
      self.n = args.multi_step
      self.discount = args.discount
      self.target_update = args.target_update
      self.categorical = args.categorical
      self.noisy_linear = args.noisy_linear
      self.double_q = args.double_q
      self.max_grad_norm = args.max_grad_norm
      self.device = torch.device('cuda', args.gpu)
      self.num_param_updates = 0

      if args.categorical:
          self.atoms = args.atoms
          self.v_min = args.v_min
          self.v_max = args.v_max
          self.support = torch.linspace(self.v_min, args.v_max, self.atoms).to(device=self.device)  # Support (range) of z
          self.delta_z = (args.v_max - self.v_min) / (self.atoms - 1)

      self.online_net = DQN(args, self.action_space.n).to(device=self.device)

      if args.model and os.path.isfile(args.model):
          # Always load tensors onto CPU by default, will shift to GPU if necessary
          self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
      self.online_net.train()

      self.target_net = DQN(args, self.action_space.n).to(device=self.device)
      self.update_target_net()
      self.target_net.eval()
      for param in self.target_net.parameters():
          param.requires_grad = False

      self.optimizer = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps, amsgrad=True)

      if args.distributed:
          self.online_net = DDP(self.online_net)

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
      if isinstance(self.online_net, DQN):
          self.online_net.reset_noise()
      else:
          self.online_net.module.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      probs = self.online_net(state.to(self.device))
      if self.categorical:
          probs = self.support.expand_as(probs) * probs
      actions = probs.sum(-1).argmax(-1).to(state.device)
      return actions

  # Acts with an ε-greedy policy (used for evaluation only)
  def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
    actions = self.act(state)

    mask = torch.rand(state.size(0), device=state.device, dtype=torch.float32) < epsilon
    masked = mask.sum().item()
    if masked > 0:
        actions[mask] = torch.randint(0, self.action_space.n, (masked,), device=state.device, dtype=torch.long)

    return actions

  def learn(self, states, actions, returns, next_states, nonterminals, weights):

    tactions = actions.unsqueeze(-1).unsqueeze(-1)
    if self.categorical:
        tactions = tactions.expand(-1, -1, self.atoms)

    # Calculate current state probabilities (online network noise already sampled)
    nvtx.range_push('agent:online (state) probs')
    ps = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline)
    ps_a = ps.gather(1, tactions)  # log p(s_t, a_t; θonline)
    nvtx.range_pop()

    with torch.no_grad():
      if isinstance(self.target_net, DQN):
          self.target_net.reset_noise()
      else:
          self.target_net.module.reset_noise()  # Sample new target net noise

      nvtx.range_push('agent:target (next state) probs')
      tns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      nvtx.range_pop()

      if self.double_q:
        # Calculate nth next state probabilities
        nvtx.range_push('agent:online (next state) probs')
        pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
        nvtx.range_pop()
      else:
        pns = tns

      if self.categorical:
        pns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))

      # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      argmax_indices_ns = pns.sum(-1).argmax(-1).unsqueeze(-1).unsqueeze(-1)
      if self.categorical:
          argmax_indices_ns = argmax_indices_ns.expand(-1, -1, self.atoms)
      pns_a = tns.gather(1, argmax_indices_ns)  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      if self.categorical:
        # Compute Tz (Bellman operator T applied to z)
        # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = returns.unsqueeze(-1) + nonterminals.float().unsqueeze(-1) * (self.discount ** self.n) * self.support.unsqueeze(0)
        Tz = Tz.clamp(min=self.v_min, max=self.v_max)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.v_min) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        batch_size = states.size(0)
        m = states.new_zeros(batch_size, self.atoms)
        offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).unsqueeze(1).expand(batch_size, self.atoms).to(actions)
        m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a.squeeze(1) * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a.squeeze(1) * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)
      else:
        Tz = returns + nonterminals.float() * (self.discount ** self.n) * pns_a.squeeze(-1).squeeze(-1)

    if self.categorical:
        loss = -torch.sum(m * ps_a.squeeze(1), 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        weights = weights.unsqueeze(-1)
    else:
        loss = F.mse_loss(ps_a.squeeze(-1).squeeze(-1), Tz, reduction='none')

    nvtx.range_push('agent:loss + step')
    self.optimizer.zero_grad()
    weighted_loss = (weights * loss).mean()
    weighted_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.max_grad_norm)
    self.optimizer.step()
    nvtx.range_pop()

    return loss.detach()

  def update_target_net(self):
      self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path):
      torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
      with torch.no_grad():
          q = self.online_net(state.unsqueeze(0).to(self.device))
          if self.categorical:
              q *= self.support
          return q.sum(-1).max(-1)[0].item()

  def train(self):
      self.online_net.train()

  def eval(self):
      self.online_net.eval()

  def __str__(self):
      return self.online_net.__str__()
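act_e_greedy above handles a whole batch of environments at once: a Bernoulli(ε) mask selects which entries get a uniformly random action while the rest keep the greedy one. A minimal standalone sketch of that masking pattern, with a toy batch and a hypothetical action count (the greedy actions are random placeholders):

import torch

batch_size, num_actions, epsilon = 8, 6, 0.1
greedy_actions = torch.randint(0, num_actions, (batch_size,))  # stand-in for self.act(state)

mask = torch.rand(batch_size) < epsilon                        # which environments explore this step
num_explore = int(mask.sum().item())
if num_explore > 0:
    greedy_actions[mask] = torch.randint(0, num_actions, (num_explore,))
print(greedy_actions)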
コード例 #11
0
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms  # size of value distribution.
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      self.atoms).to(device=args.device)
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space).to(
            device=args.device)  # greedily selects the action.
        if args.model and os.path.isfile(args.model):
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu')
            )  # state_dict: python dictionary that maps each layer to its parameters.
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(
            device=args.device)  # use to compute target q-values.
        self.update_target_net(
        )  # sets it to the parameters of the online network.
        self.target_net.train()
        for param in self.target_net.parameters(
        ):  # not updated through backpropagation.
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    def reset_noise(self):
        self.online_net.reset_noise()

    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        return np.random.randint(
            0, self.action_space
        ) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)
        log_ps = self.online_net(states, log=True)
        log_ps_a = log_ps[range(self.batch_size), actions]

        with torch.no_grad():
            pns = self.online_net(next_states)
            dns = self.support.expand_as(pns) * pns
            argmax_indices_ns = dns.sum(2).argmax(1)
            self.target_net.reset_noise()
            pns = self.target_net(next_states)
            pns_a = pns[range(self.batch_size), argmax_indices_ns]
            Tz = returns.unsqueeze(1) + nonterminals * (
                self.discount**self.n) * self.support.unsqueeze(0)
            Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)
            b = (Tz - self.Vmin) / self.delta_z
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms),
                                    self.batch_size).unsqueeze(1).expand(
                                        self.batch_size,
                                        self.atoms).to(actions)
            m.view(-1).index_add_(0, (l + offset).view(-1),
                                  (pns_a * (u.float() - b)).view(-1))
            m.view(-1).index_add_(0, (u + offset).view(-1),
                                  (pns_a * (b - l.float())).view(-1))

        loss = -torch.sum(m * log_ps_a, 1)
        self.online_net.zero_grad()
        (weights * loss).mean().backward()
        self.optimiser.step()
        mem.update_priorities(idxs,
                              loss.detach().cpu().numpy()
                              )  # update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model_all_layers.pth'))

    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) *
                    self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()
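learn() above follows the Double-DQN pattern for distributional outputs: the online network's next-state distribution picks the greedy action, and the target network's distribution for that action is what gets projected. A minimal sketch of just that action-selection step, with random stand-in distributions:

import torch

batch_size, num_actions, atoms = 4, 5, 51
support = torch.linspace(-10.0, 10.0, atoms)

pns_online = torch.softmax(torch.randn(batch_size, num_actions, atoms), -1)  # p(s_t+n, ·; θonline)
pns_target = torch.softmax(torch.randn(batch_size, num_actions, atoms), -1)  # p(s_t+n, ·; θtarget)

dns = support.expand_as(pns_online) * pns_online   # weight each atom by its support value
argmax_a = dns.sum(2).argmax(1)                    # greedy action from the online net
pns_a = pns_target[range(batch_size), argmax_a]    # target-net distribution of that action
print(pns_a.shape)                                 # torch.Size([4, 51])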
コード例 #12
0
ファイル: agent.py プロジェクト: cocobol/Rainbow-1
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.quantile = args.quantile
    self.atoms = args.quantiles if args.quantile else args.atoms
    if args.quantile:
      self.cumulative_density = (2 * torch.arange(self.atoms).to(device=args.device) + 1) / (2 * self.atoms)  # Quantile cumulative probability weights τ
    else:
      self.Vmin = args.V_min
      self.Vmax = args.V_max
      self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
      self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.norm_clip = args.norm_clip

    self.online_net = DQN(args, self.action_space, args.quantile).to(device=args.device)
    if args.model and os.path.isfile(args.model):
      # Always load tensors onto CPU by default, will shift to GPU if necessary
      self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
    self.online_net.train()

    self.target_net = DQN(args, self.action_space, args.quantile).to(device=args.device)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
    self.online_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).argmax(1).item()

  # Acts with an ε-greedy policy
  def act_e_greedy(self, state, epsilon=0.05):
    return random.randrange(self.action_space) if random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

    # Calculate current state probabilities (online network noise already sampled)
    ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)/quantile probabilities θ(s_t, ·; θonline)
    ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

    with torch.no_grad():
      # Calculate nth next state probabilities
      pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
      dns = ((1 / self.atoms) if self.quantile else self.support.expand_as(pns)) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  # Sample new target net noise
      pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      if self.quantile:
        # Compute distributional Bellman target Tθ = R^n + (γ^n)p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
        Ttheta = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * pns_a  # (accounting for terminal states)
      else:
        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
        # Fix disappearing probability mass when l = b = u (b is int)
        l[(u > 0) * (l == u)] -= 1
        u[(l < (self.atoms - 1)) * (l == u)] += 1

        # Distribute probability of Tz
        m = states.new_zeros(self.batch_size, self.atoms)
        offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
        m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    if self.quantile:
      u = Ttheta - ps_a  # Residual u
      kappa_cond = (u.abs() <= 1).to(torch.float32)  # |u| ≤ κ (κ = 1)
      huber_loss = 0.5 * u ** 2 * kappa_cond + (u.abs() - 0.5) * (1 - kappa_cond)  # Huber loss Lκ(u)
      loss = torch.sum(torch.abs(self.cumulative_density - (u < 0).to(torch.float32)) * huber_loss, 1)  # Quantile Huber loss ρκτ(u) = |τ − δ{u<0}|Lκ(u)
    else:
      loss = -torch.sum(m * ps_a.log(), 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    loss = weights * loss  # Importance-weight the per-sample losses
    self.online_net.zero_grad()
    loss.mean().backward()  # Backpropagate minibatch loss
    nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm (before the optimiser step)
    self.optimiser.step()
    if self.quantile:
      loss = (self.atoms * loss).clamp(max=5)  # Heuristic for prioritised replay

    mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path):
    torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
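
The quantile branch of Code Example #12 computes a quantile Huber loss. The following is a minimal, self-contained sketch of that loss in isolation; it mirrors the element-wise simplification used above rather than the full pairwise QR-DQN formulation, and the names quantile_huber_loss, theta and Ttheta are illustrative, not identifiers from the cocobol/Rainbow-1 repository.

import torch

def quantile_huber_loss(theta, Ttheta, kappa=1.0):
    # theta:  (batch, N) predicted quantiles θ(s_t, a_t; θonline)
    # Ttheta: (batch, N) Bellman targets Tθ = R^n + (γ^n)θ(s_t+n, a*; θtarget)
    num_quantiles = theta.size(1)
    tau = (2 * torch.arange(num_quantiles, dtype=theta.dtype) + 1) / (2 * num_quantiles)  # τ_i
    u = Ttheta - theta                                              # residual u
    abs_u = u.abs()
    huber = torch.where(abs_u <= kappa,
                        0.5 * u ** 2,
                        kappa * (abs_u - 0.5 * kappa))              # Huber loss Lκ(u)
    rho = torch.abs(tau - (u < 0).to(theta.dtype)) * huber          # ρκτ(u) = |τ − δ{u<0}|·Lκ(u)
    return rho.sum(dim=1)                                           # per-sample loss

# Usage: a batch of 4 transitions with 51 quantiles each
per_sample_loss = quantile_huber_loss(torch.randn(4, 51), torch.randn(4, 51))
print(per_sample_loss.shape)  # torch.Size([4])
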
Code Example #13
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, args, state_size, action_size):
        """Initialize an Agent object.
        
        Params
        ======
            args (object defined in the notebook): parameters defining the agent's hyperparameters
            state_size (int): dimension of each state
            action_size (int): dimension of each action
        """
        self.state_size = state_size
        self.action_size = action_size
        self.params = args

        # Deep Q-Network
        if args.use_NoisyNet:
            self.DQN_local = DQN_NoisyNet(args, state_size,
                                          action_size).to(args.device)
            self.DQN_target = DQN_NoisyNet(args, state_size,
                                           action_size).to(args.device)
        else:
            self.DQN_local = DQN(args, state_size, action_size).to(args.device)
            self.DQN_target = DQN(args, state_size,
                                  action_size).to(args.device)

        self.optimizer = optim.Adam(self.DQN_local.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

        # Replay memory
        self.memory = ReplayBuffer(args, action_size)
        # Initialize time step (for updating every args.target_update steps)
        self.t_step = 0

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.DQN_local.reset_noise()

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every args.target_update time steps.
        self.t_step = (self.t_step + 1) % self.params.target_update
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.params.batch_size:
                experiences = self.memory.sample()
                self.learn(experiences)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(
            self.params.device)
        self.DQN_local.eval()
        with torch.no_grad():
            action_values = self.DQN_local(state)
        self.DQN_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            args.discount (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.DQN_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.params.discount * Q_targets_next *
                               (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.DQN_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.DQN_local, self.DQN_target, self.params.tau)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            args.tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
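
Code Example #13 blends the local network into the target network with a soft update after every learning step instead of copying weights periodically. Below is a minimal sketch of that update applied to two stand-in linear layers; the layer sizes and the value of tau are illustrative assumptions.

import torch
import torch.nn as nn

local_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
tau = 1e-3

def soft_update(local_model, target_model, tau):
    # θ_target ← τ·θ_local + (1 − τ)·θ_target
    with torch.no_grad():  # in-place edits of parameters must bypass autograd
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * local_param)

# Each call nudges the target a small step towards the local network;
# tau = 1.0 would reduce this to a hard copy of the local weights.
soft_update(local_net, target_net, tau)
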
Code Example #14
class Agent():
    def __init__(self, args, env):
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
        self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount
        self.saved_model_path = args.saved_model_path
        self.experiment = args.experiment
        self.plots_path = args.plots_path
        self.data_save_path = args.data_save_path


        self.online_net = DQN(args, self.action_space).to(device=args.device)
        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space).to(device=args.device)
        self.update_target_net()
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)

        # list of layers:
        self.online_net_layers = [self.online_net.conv1,
                                  self.online_net.conv2,
                                  self.online_net.conv3,
                                  self.online_net.fc_h_v,
                                  self.online_net.fc_h_a,
                                  self.online_net.fc_z_v,
                                  self.online_net.fc_z_a
                                  ]

        self.target_net_layers = [self.target_net.conv1,
                                  self.target_net.conv2,
                                  self.target_net.conv3,
                                  self.target_net.fc_h_v,
                                  self.target_net.fc_h_a,
                                  self.target_net.fc_z_v,
                                  self.target_net.fc_z_a
                                  ]

        # freeze all layers except the last, and reinitialize last
        if args.freeze_layers > 0:
            self.freeze_layers(args.freeze_layers)

        if args.reinitialize_layers > 0:
            self.reinit_layers(args.reinitialize_layers)


    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
        return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

        # Calculate current state probabilities (online network noise already sampled)
        log_ps = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline)

        with torch.no_grad():
            # Calculate nth next state probabilities
            pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
            dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
            argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            self.target_net.reset_noise()  # Sample new target net noise
            pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
            pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            # Compute Tz (Bellman operator T applied to z)
            Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
            Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
            # Compute L2 projection of Tz onto fixed support z
            b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.atoms - 1)) * (l == u)] += 1

            # Distribute probability of Tz
            m = states.new_zeros(self.batch_size, self.atoms)
            offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(
                self.batch_size, self.atoms).to(actions)
            m.view(-1).index_add_(0, (l + offset).view(-1),
                                  (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
            m.view(-1).index_add_(0, (u + offset).view(-1),
                                  (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
        self.optimiser.step()

        mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(), os.path.join(path, self.experiment + '_model.pth'))  # 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()

    def freeze_layers(self, num_frozen_layers):

        # Reinitialize the layer groups that were not frozen (5 groups in total)
        self.reinit_layers(5 - num_frozen_layers)

        for i in range(num_frozen_layers):
            if i == 0:
                # freeze the first convolutional layer
                self.online_net_layers[0].weight.requires_grad = False
                self.online_net_layers[0].bias.requires_grad = False
            elif i == 1:
                self.online_net_layers[1].weight.requires_grad = False
                self.online_net_layers[1].bias.requires_grad = False
            elif i == 2:
                self.online_net_layers[2].weight.requires_grad = False
                self.online_net_layers[2].bias.requires_grad = False
            elif i == 3:
                # Freeze both noisy hidden layers (fc_h_v and fc_h_a); NoisyLinear
                # layers expose mu/sigma parameters instead of plain weight/bias
                self.online_net_layers[3].weight_mu.requires_grad = False
                self.online_net_layers[3].weight_sigma.requires_grad = False
                self.online_net_layers[3].bias_mu.requires_grad = False
                self.online_net_layers[3].bias_sigma.requires_grad = False
                self.online_net_layers[4].weight_mu.requires_grad = False
                self.online_net_layers[4].weight_sigma.requires_grad = False
                self.online_net_layers[4].bias_mu.requires_grad = False
                self.online_net_layers[4].bias_sigma.requires_grad = False


        print(self.online_net)
        print(list(i.requires_grad for i in self.online_net.parameters()))
        print(self.target_net)
        print(list(i.requires_grad for i in self.target_net.parameters()))


    def reinit_layers(self, num_layers):
        for i in range(num_layers):
            if i == 0:
                # reinitialize the last layer (two modules in the list: fc_z_v and fc_z_a)
                self.online_net_layers[6].reset_parameters()
                self.online_net_layers[5].reset_parameters()
                self.target_net_layers[6].reset_parameters()
                self.target_net_layers[5].reset_parameters()
            elif i == 1:
                self.online_net_layers[4].reset_parameters()
                self.online_net_layers[3].reset_parameters()
                self.target_net_layers[4].reset_parameters()
                self.target_net_layers[3].reset_parameters()
            elif i == 2:
                self.online_net_layers[2].reset_parameters()
                self.target_net_layers[2].reset_parameters()
            elif i == 3:
                self.online_net_layers[1].reset_parameters()
                self.target_net_layers[1].reset_parameters()
            elif i == 4:
                self.online_net_layers[0].reset_parameters()
                self.target_net_layers[0].reset_parameters()
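
Code Example #14 transfers a trained Rainbow agent by freezing early layers through their requires_grad flags and reinitializing the remaining ones with reset_parameters(). The sketch below shows the same freeze/reinitialize pattern on a small stand-in network; the architecture and the "freeze everything except the head" split are illustrative assumptions and do not correspond to the DQN class used above.

import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 9 * 9, 6),
)

# Freeze every module except the final linear head by clearing requires_grad...
for module in list(net.children())[:-1]:
    for param in module.parameters():
        param.requires_grad = False

# ...and reinitialize the head so it is trained from scratch on the new task.
net[-1].reset_parameters()

print([p.requires_grad for p in net.parameters()])  # conv params False, head params True
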