class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount

    self.online_net = DQN(args, self.action_space).to(device=args.device)
    if args.model and os.path.isfile(args.model):
      # Always load tensors onto CPU by default, will shift to GPU if necessary
      self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
    self.online_net.train()

    self.target_net = DQN(args, self.action_space).to(device=args.device)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
    self.online_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

  # Acts with an ε-greedy policy (used for evaluation only)
  def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
    return random.randrange(self.action_space) if random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

    # Calculate current state probabilities (online network noise already sampled)
    log_ps = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline)
    log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline)

    with torch.no_grad():
      # Calculate nth next state probabilities
      pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
      dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  # Sample new target net noise
      pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      # Compute Tz (Bellman operator T applied to z)
      Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
      Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
      # Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
      l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
      # Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      # Distribute probability of Tz
      m = states.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    self.online_net.zero_grad()
    (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
    self.optimiser.step()

    mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path):
    torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
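# The projection step inside learn() above is dense, so the standalone sketch below re-derives the
# same categorical (C51) projection of the target distribution Tz onto the fixed support z as a plain
# function that can be tested in isolation. It is an illustrative re-derivation, not the code used by
# the Agent class: the function name (project_distribution) and the argument names are assumptions
# introduced here for clarity.
import torch

def project_distribution(next_probs, returns, nonterminals, support, discount=0.99, n=3, Vmin=-10., Vmax=10.):
  batch_size, atoms = next_probs.shape
  delta_z = (Vmax - Vmin) / (atoms - 1)
  # Tz = R^n + (γ^n) z, clamped to the support's range
  Tz = returns.unsqueeze(1) + nonterminals * (discount ** n) * support.unsqueeze(0)
  Tz = Tz.clamp(min=Vmin, max=Vmax)
  b = (Tz - Vmin) / delta_z                    # fractional index of each target atom
  l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
  l[(u > 0) & (l == u)] -= 1                   # avoid losing mass when b is exactly an integer
  u[(l < (atoms - 1)) & (l == u)] += 1
  m = next_probs.new_zeros(batch_size, atoms)
  offset = torch.arange(0, batch_size * atoms, atoms, device=l.device).unsqueeze(1)
  m.view(-1).index_add_(0, (l + offset).view(-1), (next_probs * (u.float() - b)).view(-1))
  m.view(-1).index_add_(0, (u + offset).view(-1), (next_probs * (b - l.float())).view(-1))
  return m  # target distribution used by the cross-entropy loss

# Tiny sanity check with made-up shapes (batch of 4, 51 atoms): the projected mass still sums to 1.
if __name__ == '__main__':
  atoms, Vmin, Vmax, batch = 51, -10., 10., 4
  support = torch.linspace(Vmin, Vmax, atoms)
  probs = torch.softmax(torch.randn(batch, atoms), dim=1)
  m = project_distribution(probs, torch.randn(batch), torch.ones(batch, 1), support, Vmin=Vmin, Vmax=Vmax)
  assert torch.allclose(m.sum(1), torch.ones(batch), atol=1e-5)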
class Agent(): def __init__(self, args, env): self.args = args self.action_space = env.action_space() self.atoms = args.atoms self.Vmin = args.V_min self.Vmax = args.V_max self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to( device=args.device) # Support (range) of z self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.norm_clip = args.norm_clip self.coeff = 0.01 if args.game in [ 'pong', 'boxing', 'private_eye', 'freeway' ] else 1. self.online_net = DQN(args, self.action_space).to(device=args.device) self.momentum_net = DQN(args, self.action_space).to(device=args.device) # self.predictor = prediction_MLP(in_dim=128, hidden_dim=128, out_dim=128) if args.model: # Load pretrained model if provided if os.path.isfile(args.model): state_dict = torch.load( args.model, map_location='cpu' ) # Always load tensors onto CPU by default, will shift to GPU if necessary if 'conv1.weight' in state_dict.keys(): for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): state_dict[new_key] = state_dict[ old_key] # Re-map state dict for old pretrained models del state_dict[ old_key] # Delete old keys for strict load_state_dict self.online_net.load_state_dict(state_dict) print("Loading pretrained model: " + args.model) else: # Raise error if incorrect model path provided raise FileNotFoundError(args.model) self.online_net.train() # self.pred.train() self.initialize_momentum_net() self.momentum_net.train() self.target_net = DQN(args, self.action_space).to(device=args.device) self.update_target_net() self.target_net.train() for param in self.target_net.parameters(): param.requires_grad = False for param in self.momentum_net.parameters(): param.requires_grad = False self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps) # Resets noisy weights in all linear layers (of online net only) def reset_noise(self): self.online_net.reset_noise() # Acts based on single state (no batch) def act(self, state): with torch.no_grad(): a, _, _ = self.online_net(state.unsqueeze(0)) return (a * self.support).sum(2).argmax(1).item() # Acts with an ε-greedy policy (used for evaluation only) def act_e_greedy( self, state, epsilon=0.001): # High ε can reduce evaluation scores drastically return np.random.randint( 0, self.action_space ) if np.random.random() < epsilon else self.act(state) def learn(self, mem): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample( self.batch_size) # print('\n\n---------------') # print(f'idxs: {idxs}, ') # print(f'states: {states.shape}, ') # print(f'actions: {actions.shape}, ') # print(f'returns: {returns.shape}, ') # print(f'next_states: {next_states.shape}, ') # print(f'nonterminals: {nonterminals.shape}, ') # print(f'weights: {weights.shape},') aug_states_1 = aug(states).to(device=self.args.device) aug_states_2 = aug(states).to(device=self.args.device) # print(f'aug_states_1: {aug_states_1.shape}') # print(f'aug_states_2: {aug_states_2.shape}') # Calculate current state probabilities (online network noise already sampled) log_ps, _, _ = self.online_net( states, log=True) # Log probabilities log p(s_t, ·; θonline) _, z_1, p_1 = self.online_net(aug_states_1, log=True) _, z_2, p_2 = self.online_net(aug_states_2, log=True) # p_1, 
p_2 = self.pred(z_1), self.pred(z_2) # with torch.no_grad(): # p_2 = self.pred(z_2) simsiam_loss = 2 + D(p_1, z_2) / 2 + D(p_2, z_1) / 2 # simsiam_loss = p_1.mean() + p_2.mean() # simsiam_loss = p_1.mean() * 128 # simsiam_loss = - F.cosine_similarity(p_1, z_2.detach(), dim=-1).mean() # print(simsiam_loss) # simsiam_loss = 0 # _, z_target = self.momentum_net(aug_states_2, log=True) #z_k # z_proj = torch.matmul(self.online_net.W, z_target.T) # logits = torch.matmul(z_anch, z_proj) # logits = (logits - torch.max(logits, 1)[0][:, None]) # logits = logits * 0.1 # labels = torch.arange(logits.shape[0]).long().to(device=self.args.device) # moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) # print(f'z_1: {z_1.shape}') # print(f'p_1: {p_1.shape}') # print('---------------\n\n') # 1/0 with torch.no_grad(): # Calculate nth next state probabilities pns, _, _ = self.online_net( next_states) # Probabilities p(s_t+n, ·; θonline) dns = self.support.expand_as( pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax( 1 ) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns, _, _ = self.target_net( next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a = pns[range( self.batch_size ), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**self.n) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_( 0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) # loss = loss + (moco_loss * self.coeff) loss = loss + (simsiam_loss * self.coeff) self.online_net.zero_grad() # self.pred.zero_grad() curl_loss = (weights * loss).mean() # print(curl_loss) curl_loss.mean().backward( ) # Backpropagate importance-weighted minibatch loss clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm self.optimiser.step() mem.update_priorities(idxs, loss.detach().cpu().numpy() ) # Update priorities of sampled transitions def learn_old(self, mem): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample( self.batch_size) # print('\n\n---------------') # print(f'idxs: {idxs}, ') # print(f'states: {states.shape}, ') # print(f'actions: {actions.shape}, ') # print(f'returns: {returns.shape}, ') # print(f'next_states: {next_states.shape}, ') # print(f'nonterminals: 
{nonterminals.shape}, ') # print(f'weights: {weights.shape},') aug_states_1 = aug(states).to(device=self.args.device) aug_states_2 = aug(states).to(device=self.args.device) # print(f'aug_states_1: {aug_states_1.shape}') # print(f'aug_states_2: {aug_states_2.shape}') # Calculate current state probabilities (online network noise already sampled) log_ps, _, _ = self.online_net( states, log=True) # Log probabilities log p(s_t, ·; θonline) _, z_anch, _ = self.online_net(aug_states_1, log=True) #z_q _, z_target, _ = self.momentum_net(aug_states_2, log=True) #z_k z_proj = torch.matmul(self.online_net.W, z_target.T) logits = torch.matmul(z_anch, z_proj) logits = (logits - torch.max(logits, 1)[0][:, None]) logits = logits * 0.1 labels = torch.arange( logits.shape[0]).long().to(device=self.args.device) moco_loss = (nn.CrossEntropyLoss()(logits, labels)).to(device=self.args.device) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) # print(f'z_anch: {z_anch.shape}') # print(f'z_target: {z_target.shape}') # print(f'z_proj: {z_proj.shape}') # print(f'logits: {logits.shape}') # print(logits) # print(f'labels: {labels.shape}') # print(labels) # print('---------------\n\n') # 1/0 with torch.no_grad(): # Calculate nth next state probabilities pns, _, _ = self.online_net( next_states) # Probabilities p(s_t+n, ·; θonline) dns = self.support.expand_as( pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax( 1 ) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns, _, _ = self.target_net( next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a = pns[range( self.batch_size ), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**self.n) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_( 0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) print(moco_loss) loss = loss + (moco_loss * self.coeff) self.online_net.zero_grad() curl_loss = (weights * loss).mean() curl_loss.mean().backward( ) # Backpropagate importance-weighted minibatch loss clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm self.optimiser.step() mem.update_priorities(idxs, loss.detach().cpu().numpy() ) # Update priorities of sampled transitions def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) def initialize_momentum_net(self): for param_q, param_k in 
zip(self.online_net.parameters(), self.momentum_net.parameters()): param_k.data.copy_(param_q.data) # update param_k.requires_grad = False # not update by gradient # Code for this function from https://github.com/facebookresearch/moco @torch.no_grad() def update_momentum_net(self, momentum=0.999): for param_q, param_k in zip(self.online_net.parameters(), self.momentum_net.parameters()): param_k.data.copy_(momentum * param_k.data + (1. - momentum) * param_q.data) # update # Save model parameters on current device (don't move model between devices) def save(self, path, name='model.pth'): torch.save(self.online_net.state_dict(), os.path.join(path, name)) # Evaluates Q-value based on single state (no batch) def evaluate_q(self, state): with torch.no_grad(): a, _, _ = self.online_net(state.unsqueeze(0)) return (a * self.support).sum(2).max(1)[0].item() def train(self): self.online_net.train() def eval(self): self.online_net.eval()
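# The agent above relies on two helpers that are not defined in this section: aug() for image
# augmentation and D() for the SimSiam similarity term. Sketched below, under that assumption, are a
# common definition of D (negative cosine similarity with a stop-gradient on the target branch, as in
# the SimSiam paper) and the exponential-moving-average update performed by update_momentum_net().
# Names and hyperparameters here are illustrative, not taken from this repository.
import torch
import torch.nn.functional as F

def D(p, z):
  # Negative cosine similarity; the stop-gradient (detach) on z is what prevents representational collapse.
  return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

@torch.no_grad()
def ema_update(online_net, momentum_net, momentum=0.999):
  # θ_k ← m·θ_k + (1 − m)·θ_q, the same rule as update_momentum_net() above
  for param_q, param_k in zip(online_net.parameters(), momentum_net.parameters()):
    param_k.data.mul_(momentum).add_(param_q.data, alpha=1. - momentum)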
class Agent(object):
  def __init__(self, args, action_space):
    self.action_space = action_space
    self.batch_size = args.batch_size
    self.discount = args.discount

    self.online_net = DQN(args, self.action_space).to(device=args.device)
    self.online_net.train()

    self.target_net = DQN(args, self.action_space).to(device=args.device)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)
    self.loss_func = nn.MSELoss()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      return self.online_net([state]).argmax(1).item()

  # Acts with an ε-greedy policy (used for evaluation only)
  def act_e_greedy(self, state, epsilon=0.05):  # High ε can reduce evaluation scores drastically
    return random.randrange(self.action_space) if random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    states, actions, next_states, rewards = mem.sample(self.batch_size)

    q_eval = self.online_net(states).gather(1, actions.unsqueeze(1)).squeeze()  # Q(s_t, a_t; θonline)
    with torch.no_grad():
      q_eval_next_a = self.online_net(next_states).argmax(1)  # Select next actions with the online network
      q_next = self.target_net(next_states)  # Evaluate them with the target network (Double-Q)
      q_target = rewards + self.discount * q_next.gather(1, q_eval_next_a.unsqueeze(1)).squeeze()

    loss = self.loss_func(q_eval, q_target)
    self.online_net.zero_grad()
    loss.backward()
    self.optimiser.step()

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path):
    torch.save(self.online_net.state_dict(), path + '.pth')

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return (self.online_net([state])).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
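# learn() above implements the Double-DQN target: the next action is selected with the online network
# and evaluated with the target network, which reduces the overestimation bias of vanilla Q-learning.
# The sketch below isolates just that target computation with plain tensors; the toy linear "networks"
# are stand-ins for any module mapping states to per-action Q-values, not this repository's DQN class.
import torch
import torch.nn as nn

def double_q_target(online_net, target_net, next_states, rewards, discount=0.99):
  with torch.no_grad():
    best_actions = online_net(next_states).argmax(1)                                    # argmax_a Q(s', a; θonline)
    q_next = target_net(next_states).gather(1, best_actions.unsqueeze(1)).squeeze(1)    # Q(s', a*; θtarget)
    return rewards + discount * q_next                                                  # Y = r + γ·Q(s', a*; θtarget)

# Example with a 4-dimensional state and 3 actions:
online, target = nn.Linear(4, 3), nn.Linear(4, 3)
y = double_q_target(online, target, torch.randn(8, 4), torch.zeros(8))
print(y.shape)  # torch.Size([8])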
class Agent(): def __init__(self, args, env): self.action_space = env.action_space() self.atoms = args.atoms self.Vmin = args.V_min self.Vmax = args.V_max self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to( device=args.device) # Support (range) of z self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.norm_clip = args.norm_clip self.online_net = DQN(args, self.action_space).to(device=args.device) if args.model: # Load pretrained model if provided if os.path.isfile(args.model): state_dict = torch.load( args.model, map_location='cpu' ) # Always load tensors onto CPU by default, will shift to GPU if necessary if 'conv1.weight' in state_dict.keys(): for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')): state_dict[new_key] = state_dict[ old_key] # Re-map state dict for old pretrained models del state_dict[ old_key] # Delete old keys for strict load_state_dict self.online_net.load_state_dict(state_dict) print("Loading pretrained model: " + args.model) else: # Raise error if incorrect model path provided raise FileNotFoundError(args.model) self.online_net.train() self.target_net = DQN(args, self.action_space).to(device=args.device) self.update_target_net() self.target_net.train() for param in self.target_net.parameters(): param.requires_grad = False # self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps) self.convs_optimiser = optim.Adam(self.online_net.convs.parameters(), lr=args.learning_rate, eps=args.adam_eps) self.linear_optimiser = optim.Adam(chain( self.online_net.fc_h_v.parameters(), self.online_net.fc_h_a.parameters(), self.online_net.fc_z_v.parameters(), self.online_net.fc_z_a.parameters()), lr=args.learning_rate, eps=args.adam_eps) # Resets noisy weights in all linear layers (of online net only) def reset_noise(self): self.online_net.reset_noise() # Acts based on single state (no batch) def act(self, state): with torch.no_grad(): # don't count these calls since it is accounted for after "action = dqn.act(state)" in main.py ret = (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item() return ret # Acts with an ε-greedy policy (used for evaluation only) def act_e_greedy( self, state, epsilon=0.001): # High ε can reduce evaluation scores drastically return np.random.randint( 0, self.action_space ) if np.random.random() < epsilon else self.act(state) def learn(self, mem, freeze=False): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights, _ = mem.sample( self.batch_size) # Calculate current state probabilities (online network noise already sampled) log_ps = self.online_net( states, log=True) # Log probabilities log p(s_t, ·; θonline) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) with torch.no_grad(): # Calculate nth next state probabilities pns = self.online_net( next_states) # Probabilities p(s_t+n, ·; θonline) dns = self.support.expand_as( pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax( 1 ) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns = self.target_net( next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a 
= pns[range( self.batch_size ), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**self.n) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_( 0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) self.online_net.zero_grad() loss.mean().backward( ) # Backpropagate importance-weighted minibatch loss clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm # self.optimiser.step() if not freeze: self.convs_optimiser.step() self.linear_optimiser.step() def learn_with_latent(self, latent_mem): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights, ns = latent_mem.sample( self.batch_size) # Calculate current state probabilities (online network noise already sampled) log_ps = self.online_net.forward_with_latent( states, log=True) # Log probabilities log p(s_t, ·; θonline) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) with torch.no_grad(): # Calculate nth next state probabilities pns = self.online_net.forward_with_latent( next_states) # Probabilities p(s_t+n, ·; θonline) dns = self.support.expand_as( pns) * pns # Distribution ds_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax( 1 ) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns = self.target_net.forward_with_latent( next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a = pns[range( self.batch_size ), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) # use ns instead of self.n since n is possibly different for each sequence in the batch ns = torch.tensor(ns, device=latent_mem.device).unsqueeze(1) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**ns) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), 
self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_( 0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) self.online_net.zero_grad() loss.mean().backward( ) # Backpropagate importance-weighted minibatch loss clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm # self.optimiser.step() self.linear_optimiser.step() def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) # Save model parameters on current device (don't move model between devices) def save(self, path, name='model.pth'): torch.save(self.online_net.state_dict(), os.path.join(path, name)) # Evaluates Q-value based on single state (no batch) def evaluate_q(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() def train(self): self.online_net.train() def eval(self): self.online_net.eval()
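# The agent above deliberately keeps two optimisers (convs_optimiser and linear_optimiser) so the
# convolutional encoder can be frozen while the linear head keeps training (see the freeze flag in
# learn() and the latent-only update in learn_with_latent()). A minimal sketch of that pattern with a
# toy network is given below; the module names (encoder, head) are placeholders, not the fields of
# this repository's DQN.
import torch
import torch.nn as nn
import torch.optim as optim

encoder = nn.Sequential(nn.Linear(8, 32), nn.ReLU())                # stand-in for the conv encoder
head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 6))  # stand-in for the dueling head

conv_optim = optim.Adam(encoder.parameters(), lr=1e-4)
head_optim = optim.Adam(head.parameters(), lr=1e-4)

def step(loss, freeze_encoder=False):
  conv_optim.zero_grad(); head_optim.zero_grad()
  loss.backward()
  if not freeze_encoder:
    conv_optim.step()   # encoder is only updated when it is not frozen
  head_optim.step()     # the head is always updated

x = torch.randn(4, 8)
step(head(encoder(x)).pow(2).mean(), freeze_encoder=True)  # only the head moves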
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.batch_size = args.batch_size
    self.discount = args.discount
    self.max_gradient_norm = args.max_gradient_norm

    self.policy_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.policy_net.load_state_dict(torch.load(args.model))
    self.policy_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.eval()

    self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr)

  def act(self, state, epsilon):
    if random.random() > epsilon:
      return self.policy_net(state.unsqueeze(0)).max(1)[1].data[0]
    else:
      return random.randint(0, self.action_space - 1)

  def learn(self, mem):
    transitions = mem.sample(self.batch_size)
    batch = Transition(*zip(*transitions))  # Transpose the batch

    states = Variable(torch.stack(batch.state, 0))
    actions = Variable(torch.LongTensor(batch.action).unsqueeze(1))
    rewards = Variable(torch.Tensor(batch.reward))
    non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None, batch.next_state)))  # Only process non-terminal next states
    next_states = Variable(torch.stack(tuple(s for s in batch.next_state if s is not None), 0), volatile=True)  # Prevent backpropagating through expected action values

    Qs = self.policy_net(states).gather(1, actions)  # Q(s_t, a_t; θpolicy)
    next_state_argmax_indices = self.policy_net(next_states).max(1, keepdim=True)[1]  # Perform argmax action selection using policy network: argmax_a[Q(s_t+1, a; θpolicy)]
    Qns = Variable(torch.zeros(self.batch_size))  # Q(s_t+1, a) = 0 if s_t+1 is terminal
    Qns[non_final_mask] = self.target_net(next_states).gather(1, next_state_argmax_indices)  # Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)
    Qns.volatile = False  # Remove volatile flag to prevent propagating it through loss
    target = rewards + (self.discount * Qns)  # Double-Q target: Y = r + γ.Q(s_t+1, argmax_a[Q(s_t+1, a; θpolicy)]; θtarget)

    loss = F.smooth_l1_loss(Qs, target)  # Huber loss on TD-error δ: δ = Y - Q(s_t, a_t)
    # TODO: TD-error clipping?
    self.policy_net.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm(self.policy_net.parameters(), self.max_gradient_norm)  # Clamp gradients
    self.optimiser.step()

  def update_target_net(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def save(self, path):
    torch.save(self.policy_net.state_dict(), os.path.join(path, 'model.pth'))

  def evaluate_q(self, state):
    return self.policy_net(state.unsqueeze(0)).max(1)[0].data[0]

  def train(self):
    self.policy_net.train()

  def eval(self):
    self.policy_net.eval()
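# The class above targets the old pre-0.4 PyTorch API (Variable, volatile=True, ByteTensor masks).
# On current PyTorch the same terminal-state masking is usually written with torch.no_grad() and a
# boolean mask, roughly as below; this is a modernised sketch for reference, not part of the original
# code, and the function name and argument layout are assumptions.
import torch

def masked_double_q_target(policy_net, target_net, batch_next_states, rewards, discount=0.99):
  # batch_next_states: list with None for terminal transitions, state tensors otherwise
  non_final_mask = torch.tensor([s is not None for s in batch_next_states], dtype=torch.bool)
  non_final_next = torch.stack([s for s in batch_next_states if s is not None], 0)
  q_next = torch.zeros(len(batch_next_states))                    # Q(s_t+1, ·) = 0 for terminal states
  with torch.no_grad():
    argmax_a = policy_net(non_final_next).argmax(1, keepdim=True)                       # online selection
    q_next[non_final_mask] = target_net(non_final_next).gather(1, argmax_a).squeeze(1)  # target evaluation
  return rewards + discount * q_next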
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, args.atoms)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.priority_exponent = args.priority_exponent
    self.max_gradient_norm = args.max_gradient_norm

    self.policy_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.policy_net.load_state_dict(torch.load(args.model))
    self.policy_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.eval()

    self.optimiser = optim.Adam(self.policy_net.parameters(), lr=args.lr, eps=args.adam_eps)
    if args.cuda:
      self.policy_net.cuda()
      self.target_net.cuda()
      self.support = self.support.cuda()

  # Resets noisy weights in all linear layers (of policy and target nets)
  def reset_noise(self):
    self.policy_net.reset_noise()
    self.target_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[1][0]

  def learn(self, mem):
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
    batch_size = len(idxs)  # May return less than specified if invalid transitions sampled

    # Calculate current state probabilities
    ps = self.policy_net(states)  # Probabilities p(s_t, ·; θpolicy)
    ps_a = ps[range(batch_size), actions]  # p(s_t, a_t; θpolicy)

    # Calculate nth next state probabilities
    pns = self.policy_net(next_states).data  # Probabilities p(s_t+n, ·; θpolicy)
    dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θpolicy))
    argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using policy network: argmax_a[(z, p(s_t+n, a; θpolicy))]
    pns = self.target_net(next_states).data  # Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
    pns_a *= nonterminals  # Set p = 0 for terminal nth next states as all possible expected returns = expected reward at final transition

    # Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
    # Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()

    # Distribute probability of Tz
    m = states.data.new(batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).long().unsqueeze(1).expand(batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  # Cross-entropy loss (minimises Kullback-Leibler divergence)
    self.policy_net.zero_grad()
    (weights * loss).mean().backward()  # Importance weight losses
    nn.utils.clip_grad_norm(self.policy_net.parameters(), self.max_gradient_norm)  # Clip gradients (normalising by max value of gradient L2 norm)
    self.optimiser.step()

    mem.update_priorities(idxs, loss.data.abs().pow(self.priority_exponent))  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def save(self, path):
    torch.save(self.policy_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    return (self.policy_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[0][0]

  def train(self):
    self.policy_net.train()

  def eval(self):
    self.policy_net.eval()
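# In the class above the priority exponent ω is applied to the absolute loss before it is handed to
# the replay memory, so the memory can store p_i = |δ_i|^ω directly. For reference, the sampling
# probabilities and importance-sampling weights of proportional prioritised replay follow
# P(i) = p_i / Σ_k p_k and w_i = (N · P(i))^-β, normalised by max_j w_j. The NumPy sketch below is a
# reference computation starting from raw errors, not this repository's memory implementation.
import numpy as np

def per_probabilities_and_weights(td_errors, omega=0.5, beta=0.4, eps=1e-6):
  priorities = (np.abs(td_errors) + eps) ** omega
  probs = priorities / priorities.sum()
  weights = (len(td_errors) * probs) ** (-beta)
  return probs, weights / weights.max()   # normalise so weights only ever scale updates down

probs, weights = per_probabilities_and_weights(np.array([0.1, 1.0, 2.5, 0.0]))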
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, args.atoms)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount

    self.online_net = DQN(args, self.action_space)
    if args.model and os.path.isfile(args.model):
      self.online_net.load_state_dict(torch.load(args.model, map_location='cpu'))
    self.online_net.train()

    self.target_net = DQN(args, self.action_space)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps)
    if args.cuda:
      self.online_net.cuda()
      self.target_net.cuda()
      self.support = self.support.cuda()

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
    self.online_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    return (self.online_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[1][0]

  # Acts with an ε-greedy policy
  def act_e_greedy(self, state, epsilon=0.001):
    return random.randrange(self.action_space) if random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

    # Calculate current state probabilities
    self.online_net.reset_noise()  # Sample new noise for online network
    ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
    ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

    # Calculate nth next state probabilities
    self.online_net.reset_noise()  # Sample new noise for action selection
    pns = self.online_net(next_states).data  # Probabilities p(s_t+n, ·; θonline)
    dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
    argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
    self.target_net.reset_noise()  # Sample new target net noise
    pns = self.target_net(next_states).data  # Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

    # Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
    # Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (self.atoms - 1)) * (l == u)] += 1

    # Distribute probability of Tz
    m = states.data.new(self.batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    self.online_net.zero_grad()
    (weights * loss).mean().backward()  # Importance weight losses
    self.optimiser.step()

    mem.update_priorities(idxs, loss.data)  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  def save(self, path):
    torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth'))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    return (self.online_net(state.unsqueeze(0)).data * self.support).sum(2).max(1)[0][0]

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
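# The agent above takes the log of the network's output probabilities and therefore clamps ps_a at
# 1e-3 to keep the loss finite; the newer agents in this file instead request log-probabilities
# directly (log=True), which is typically implemented with log_softmax. The generic snippet below
# just illustrates why the latter is the numerically safer route; it is a demonstration, not code
# from this repository.
import torch
import torch.nn.functional as F

logits = torch.tensor([[100., -100., 0.]])
unsafe = torch.log(F.softmax(logits, dim=1))   # the smallest probability underflows to 0, giving -inf
safe = F.log_softmax(logits, dim=1)            # stays finite thanks to the log-sum-exp formulation
assert torch.isinf(unsafe).any() and torch.isfinite(safe).all()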
class Agent(object): """ all improvments from Rainbow research work """ def __init__(self, args, state_size, action_size): """ Args: param1 (args): args param2 (int): args param3 (int): args """ self.action_size = action_size self.state_size = state_size self.atoms = args.atoms self.V_min = args.V_min self.V_max = args.V_max self.device = args.device self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to( device=self.device) # Support (range) of z self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.qnetwork_local = DQN(args, self.state_size, self.action_size).to(device=args.device) if args.model and os.path.isfile(args.model): # Always load tensors onto CPU by default, will shift to GPU if necessary self.qnetwork_local.load_state_dict( torch.load(args.model, map_location='cpu')) self.qnetwork_local.train() self.target_net = DQN(args, self.state_size, self.action_size).to(device=args.device) self.update_target_net() self.target_net.train() for param in self.target_net.parameters(): param.requires_grad = False self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=args.lr, eps=args.adam_eps) def reset_noise(self): """ resets noisy weights in all linear layers """ self.qnetwork_local.reset_noise() def act(self, state): """ acts greedy(max) based on a single state Args: param1 (int) : state """ with torch.no_grad(): return (self.qnetwork_local(state.unsqueeze(0).to(self.device)) * self.support).sum(2).argmax(1).item() def act_e_greedy(self, state, epsilon=0.001): """ acts with epsilon greedy policy epsilon exploration vs exploitation traide off Args: param1(int): state param2(float): epsilon Return : action int number between 0 and 4 """ return np.random.randint( 0, self.action_size) if np.random.random() < epsilon else self.act( state) def learn(self, mem): """ uses samples with the given batch size to improve the Q function Args: param1 (Experince Replay Buffer) : mem """ # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample( self.batch_size) # Calculate current state probabilities (online network noise already sampled) log_ps = self.qnetwork_local( states, log=True) # Log probabilities log p(s_t, *; theta online) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; theat online) with torch.no_grad(): # Calculate nth next state probabilities pns = self.qnetwork_local( next_states) # Probabilities p(s_t+n, *; theta online) dns = self.support.expand_as( pns ) * pns # Distribution d_t+n = (z, p(s_t+n, *; theat online)) argmax_indices_ns = dns.sum(2).argmax( 1 ) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; theat online))] self.target_net.reset_noise() # Sample new target net noise pns = self.target_net( next_states) # Probabilities p(s_t+n, ; theata target) pns_a = pns[range( self.batch_size ), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; theat online))]; theat target) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**self.n ) * self.support.unsqueeze( 0) # Tz = R^n + (discoit ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.V_min, max=self.V_max) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.V_min) / self.delta_z # b = (Tz - Vmin) / delta z l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing 
probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_( 0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) self.qnetwork_local.zero_grad() (weights * loss).mean().backward( ) # Backpropagate importance-weighted minibatch loss self.optimizer.step() mem.update_priorities(idxs, loss.detach().cpu().numpy() ) # Update priorities of sampled transitions self.soft_update() def soft_update(self, tau=1e-3): """ swaps the network weights from the online to the target Args: param1 (float): tau """ for target_param, local_param in zip(self.target_net.parameters(), self.qnetwork_local.parameters()): target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) def update_target_net(self): """ copy the model weights from the online to the target network """ self.target_net.load_state_dict(self.qnetwork_local.state_dict()) def save(self, path): """ save the model weights to a file Args: param1 (string): pathname """ torch.save(self.qnetwork_local.state_dict(), os.path.join(path, 'model.pth')) def evaluate_q(self, state): """ Evaluates Q-value based on single state """ with torch.no_grad(): return (self.qnetwork_local(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() def train(self): """ activates the backprob. layers for the online network """ self.qnetwork_local.train() def eval(self): """ invoke the eval from the online network deactivates the backprob layers like dropout will work in eval model instead """ self.qnetwork_local.eval()
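# Several of the DQN networks used above expose a reset_noise() method, which the agents call to
# resample exploration noise. For reference, a factorised-Gaussian NoisyLinear layer in the style of
# the NoisyNet paper (Fortunato et al., 2017) is sketched below; this is a generic implementation
# written for illustration, not this repository's layer.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyLinear(nn.Module):
  def __init__(self, in_features, out_features, std_init=0.5):
    super().__init__()
    self.in_features, self.out_features, self.std_init = in_features, out_features, std_init
    self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
    self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
    self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))
    self.bias_mu = nn.Parameter(torch.empty(out_features))
    self.bias_sigma = nn.Parameter(torch.empty(out_features))
    self.register_buffer('bias_epsilon', torch.empty(out_features))
    self.reset_parameters()
    self.reset_noise()

  def reset_parameters(self):
    mu_range = 1 / math.sqrt(self.in_features)
    self.weight_mu.data.uniform_(-mu_range, mu_range)
    self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features))
    self.bias_mu.data.uniform_(-mu_range, mu_range)
    self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features))

  def _scale_noise(self, size):
    x = torch.randn(size)
    return x.sign().mul_(x.abs().sqrt_())   # f(x) = sgn(x)·sqrt(|x|)

  def reset_noise(self):
    # Factorised noise: one noise vector per input and output, combined by an outer product
    epsilon_in = self._scale_noise(self.in_features)
    epsilon_out = self._scale_noise(self.out_features)
    self.weight_epsilon.copy_(torch.outer(epsilon_out, epsilon_in))
    self.bias_epsilon.copy_(epsilon_out)

  def forward(self, input):
    if self.training:
      return F.linear(input, self.weight_mu + self.weight_sigma * self.weight_epsilon,
                      self.bias_mu + self.bias_sigma * self.bias_epsilon)
    return F.linear(input, self.weight_mu, self.bias_mu)   # noise is ignored in eval mode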
class Agent: def __init__(self): self.mode = "train" with open("config.yaml") as reader: self.config = yaml.safe_load(reader) print(self.config) self.load_config() self.online_net = DQN(config=self.config, word_vocab=self.word_vocab, char_vocab=self.char_vocab, answer_type=self.answer_type) self.target_net = DQN(config=self.config, word_vocab=self.word_vocab, char_vocab=self.char_vocab, answer_type=self.answer_type) self.online_net.train() self.target_net.train() self.update_target_net() for param in self.target_net.parameters(): param.requires_grad = False if self.use_cuda: self.online_net.cuda() self.target_net.cuda() self.naozi = ObservationPool(capacity=self.naozi_capacity) # optimizer self.optimizer = torch.optim.Adam( self.online_net.parameters(), lr=self.config['training']['optimizer']['learning_rate']) self.clip_grad_norm = self.config['training']['optimizer'][ 'clip_grad_norm'] def load_config(self): # word vocab with open("vocabularies/word_vocab.txt") as f: self.word_vocab = f.read().split("\n") self.word2id = {} for i, w in enumerate(self.word_vocab): self.word2id[w] = i # char vocab with open("vocabularies/char_vocab.txt") as f: self.char_vocab = f.read().split("\n") self.char2id = {} for i, w in enumerate(self.char_vocab): self.char2id[w] = i self.EOS_id = self.word2id["</s>"] self.train_data_size = self.config['general']['train_data_size'] self.question_type = self.config['general']['question_type'] self.random_map = self.config['general']['random_map'] self.testset_path = self.config['general']['testset_path'] self.naozi_capacity = self.config['general']['naozi_capacity'] self.eval_folder = pjoin( self.testset_path, self.question_type, ("random_map" if self.random_map else "fixed_map")) self.eval_data_path = pjoin(self.testset_path, "data.json") self.batch_size = self.config['training']['batch_size'] self.max_nb_steps_per_episode = self.config['training'][ 'max_nb_steps_per_episode'] self.max_episode = self.config['training']['max_episode'] self.target_net_update_frequency = self.config['training'][ 'target_net_update_frequency'] self.learn_start_from_this_episode = self.config['training'][ 'learn_start_from_this_episode'] self.run_eval = self.config['evaluate']['run_eval'] self.eval_batch_size = self.config['evaluate']['batch_size'] self.eval_max_nb_steps_per_episode = self.config['evaluate'][ 'max_nb_steps_per_episode'] # Set the random seed manually for reproducibility. 
self.random_seed = self.config['general']['random_seed'] np.random.seed(self.random_seed) torch.manual_seed(self.random_seed) if torch.cuda.is_available(): if not self.config['general']['use_cuda']: print( "WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml" ) self.use_cuda = False else: torch.backends.cudnn.deterministic = True torch.cuda.manual_seed(self.random_seed) self.use_cuda = True else: self.use_cuda = False if self.question_type == "location": self.answer_type = "pointing" elif self.question_type in ["attribute", "existence"]: self.answer_type = "2 way" else: raise NotImplementedError self.save_checkpoint = self.config['checkpoint']['save_checkpoint'] self.experiment_tag = self.config['checkpoint']['experiment_tag'] self.save_frequency = self.config['checkpoint']['save_frequency'] self.load_pretrained = self.config['checkpoint']['load_pretrained'] self.load_from_tag = self.config['checkpoint']['load_from_tag'] self.qa_loss_lambda = self.config['training']['qa_loss_lambda'] self.interaction_loss_lambda = self.config['training'][ 'interaction_loss_lambda'] # replay buffer and updates self.discount_gamma = self.config['replay']['discount_gamma'] self.replay_batch_size = self.config['replay']['replay_batch_size'] self.command_generation_replay_memory = command_generation_memory.PrioritizedReplayMemory( self.config['replay']['replay_memory_capacity'], priority_fraction=self.config['replay'] ['replay_memory_priority_fraction'], discount_gamma=self.discount_gamma) self.qa_replay_memory = qa_memory.PrioritizedReplayMemory( self.config['replay']['replay_memory_capacity'], priority_fraction=0.0) self.update_per_k_game_steps = self.config['replay'][ 'update_per_k_game_steps'] self.multi_step = self.config['replay']['multi_step'] # distributional RL self.use_distributional = self.config['distributional']['enable'] self.atoms = self.config['distributional']['atoms'] self.v_min = self.config['distributional']['v_min'] self.v_max = self.config['distributional']['v_max'] self.support = torch.linspace(self.v_min, self.v_max, self.atoms) # Support (range) of z if self.use_cuda: self.support = self.support.cuda() self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1) # dueling networks self.dueling_networks = self.config['dueling_networks'] # double dqn self.double_dqn = self.config['double_dqn'] # counting reward self.revisit_counting_lambda_anneal_episodes = self.config[ 'episodic_counting_bonus'][ 'revisit_counting_lambda_anneal_episodes'] self.revisit_counting_lambda_anneal_from = self.config[ 'episodic_counting_bonus']['revisit_counting_lambda_anneal_from'] self.revisit_counting_lambda_anneal_to = self.config[ 'episodic_counting_bonus']['revisit_counting_lambda_anneal_to'] self.revisit_counting_lambda = self.revisit_counting_lambda_anneal_from # valid command bonus self.valid_command_bonus_lambda = self.config[ 'valid_command_bonus_lambda'] # epsilon greedy self.epsilon_anneal_episodes = self.config['epsilon_greedy'][ 'epsilon_anneal_episodes'] self.epsilon_anneal_from = self.config['epsilon_greedy'][ 'epsilon_anneal_from'] self.epsilon_anneal_to = self.config['epsilon_greedy'][ 'epsilon_anneal_to'] self.epsilon = self.epsilon_anneal_from self.noisy_net = self.config['epsilon_greedy']['noisy_net'] if self.noisy_net: # disable epsilon greedy self.epsilon_anneal_episodes = -1 self.epsilon = 0.0 self.nlp = spacy.load('en', disable=['ner', 'parser', 'tagger']) self.single_word_verbs = set(["inventory", "look", "wait"]) self.two_word_verbs = set(["go"]) def train(self): 
""" Tell the agent that it's training phase. """ self.mode = "train" self.online_net.train() def eval(self): """ Tell the agent that it's evaluation phase. """ self.mode = "eval" self.online_net.eval() def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) def reset_noise(self): if self.noisy_net: # Resets noisy weights in all linear layers (of online net only) self.online_net.reset_noise() def zero_noise(self): if self.noisy_net: self.online_net.zero_noise() self.target_net.zero_noise() def load_pretrained_model(self, load_from): """ Load pretrained checkpoint from file. Arguments: load_from: File name of the pretrained model checkpoint. """ print("loading model from %s\n" % (load_from)) try: if self.use_cuda: state_dict = torch.load(load_from) else: state_dict = torch.load(load_from, map_location='cpu') self.online_net.load_state_dict(state_dict) except: print("Failed to load checkpoint...") def save_model_to_path(self, save_to): torch.save(self.online_net.state_dict(), save_to) print("Saved checkpoint to %s..." % (save_to)) def init(self, obs, infos): """ Prepare the agent for the upcoming games. Arguments: obs: Previous command's feedback for each game. infos: Additional information for each game. """ # reset agent, get vocabulary masks for verbs / adjectives / nouns batch_size = len(obs) self.reset_binarized_counter(batch_size) self.not_finished_yet = np.ones((batch_size, ), dtype="float32") self.prev_actions = [["" for _ in range(batch_size)]] self.prev_step_is_still_interacting = np.ones( (batch_size, ), dtype="float32" ) # 1s and starts to be 0 when previous action is "wait" self.naozi.reset(batch_size=batch_size) def get_agent_inputs(self, string_list): sentence_token_list = [item.split() for item in string_list] sentence_id_list = [ _words_to_ids(tokens, self.word2id) for tokens in sentence_token_list ] input_sentence_char = list_of_token_list_to_char_input( sentence_token_list, self.char2id) input_sentence = pad_sequences( sentence_id_list, maxlen=max_len(sentence_id_list)).astype('int32') input_sentence = to_pt(input_sentence, self.use_cuda) input_sentence_char = to_pt(input_sentence_char, self.use_cuda) return input_sentence, input_sentence_char, sentence_id_list def get_game_info_at_certain_step(self, obs, infos): """ Get all needed info from game engine for training. Arguments: obs: Previous command's feedback for each game. infos: Additional information for each game. 
""" batch_size = len(obs) feedback_strings = [preproc(item, tokenizer=self.nlp) for item in obs] description_strings = [ preproc(item, tokenizer=self.nlp) for item in infos["description"] ] observation_strings = [ d + " <|> " + fb if fb != d else d + " <|> hello" for fb, d in zip(feedback_strings, description_strings) ] inventory_strings = [ preproc(item, tokenizer=self.nlp) for item in infos["inventory"] ] local_word_list = [ obs.split() + inv.split() for obs, inv in zip(observation_strings, inventory_strings) ] directions = ["east", "west", "north", "south"] if self.question_type in ["location", "existence"]: # agents observes the env, but do not change them possible_verbs = [["go", "inventory", "wait", "open", "examine"] for _ in range(batch_size)] else: possible_verbs = [ list(set(item) - set(["", "look"])) for item in infos["verbs"] ] possible_adjs, possible_nouns = [], [] for i in range(batch_size): object_nouns = [ item.split()[-1] for item in infos["object_nouns"][i] ] object_adjs = [ w for item in infos["object_adjs"][i] for w in item.split() ] possible_nouns.append( list(set(object_nouns) & set(local_word_list[i]) - set([""])) + directions) possible_adjs.append( list(set(object_adjs) & set(local_word_list[i]) - set([""])) + ["</s>"]) return observation_strings, [ possible_verbs, possible_adjs, possible_nouns ] def get_state_strings(self, infos): description_strings = infos["description"] inventory_strings = infos["inventory"] observation_strings = [ _d + _i for (_d, _i) in zip(description_strings, inventory_strings) ] return observation_strings def get_local_word_masks(self, possible_words): possible_verbs, possible_adjs, possible_nouns = possible_words batch_size = len(possible_verbs) verb_mask = np.zeros((batch_size, len(self.word_vocab)), dtype="float32") noun_mask = np.zeros((batch_size, len(self.word_vocab)), dtype="float32") adj_mask = np.zeros((batch_size, len(self.word_vocab)), dtype="float32") for i in range(batch_size): for w in possible_verbs[i]: if w in self.word2id: verb_mask[i][self.word2id[w]] = 1.0 for w in possible_adjs[i]: if w in self.word2id: adj_mask[i][self.word2id[w]] = 1.0 for w in possible_nouns[i]: if w in self.word2id: noun_mask[i][self.word2id[w]] = 1.0 adj_mask[:, self.EOS_id] = 1.0 return [verb_mask, adj_mask, noun_mask] def get_match_representations(self, input_observation, input_observation_char, input_quest, input_quest_char, use_model="online"): model = self.online_net if use_model == "online" else self.target_net description_representation_sequence, description_mask = model.representation_generator( input_observation, input_observation_char) quest_representation_sequence, quest_mask = model.representation_generator( input_quest, input_quest_char) match_representation_sequence = model.get_match_representations( description_representation_sequence, description_mask, quest_representation_sequence, quest_mask) match_representation_sequence = match_representation_sequence * description_mask.unsqueeze( -1) return match_representation_sequence def get_ranks(self, input_observation, input_observation_char, input_quest, input_quest_char, word_masks, use_model="online"): """ Given input observation and question tensors, to get Q values of words. 
""" model = self.online_net if use_model == "online" else self.target_net match_representation_sequence = self.get_match_representations( input_observation, input_observation_char, input_quest, input_quest_char, use_model=use_model) action_ranks = model.action_scorer(match_representation_sequence, word_masks) # list of 3 tensors return action_ranks def choose_maxQ_command(self, action_ranks, word_mask=None): """ Generate a command by maximum q values, for epsilon greedy. """ if self.use_distributional: action_ranks = [ (item * self.support).sum(2) for item in action_ranks ] # list of batch x n_vocab action_indices = [] for i in range(len(action_ranks)): ar = action_ranks[i] ar = ar - torch.min( ar, -1, keepdim=True )[0] + 1e-2 # minus the min value, so that all values are non-negative if word_mask is not None: assert word_mask[i].size() == ar.size(), ( word_mask[i].size().shape, ar.size()) ar = ar * word_mask[i] action_indices.append(torch.argmax(ar, -1)) # batch return action_indices def choose_random_command(self, batch_size, action_space_size, possible_words=None): """ Generate a command randomly, for epsilon greedy. """ action_indices = [] for i in range(3): if possible_words is None: indices = np.random.choice(action_space_size, batch_size) else: indices = [] for j in range(batch_size): mask_ids = [] for w in possible_words[i][j]: if w in self.word2id: mask_ids.append(self.word2id[w]) indices.append(np.random.choice(mask_ids)) indices = np.array(indices) action_indices.append(to_pt(indices, self.use_cuda)) # batch return action_indices def get_chosen_strings(self, chosen_indices): """ Turns list of word indices into actual command strings. chosen_indices: Word indices chosen by model. """ chosen_indices_np = [to_np(item) for item in chosen_indices] res_str = [] batch_size = chosen_indices_np[0].shape[0] for i in range(batch_size): verb, adj, noun = chosen_indices_np[0][i], chosen_indices_np[1][ i], chosen_indices_np[2][i] res_str.append(self.word_ids_to_commands(verb, adj, noun)) return res_str def word_ids_to_commands(self, verb, adj, noun): """ Turn the 3 indices into actual command strings. 
Arguments: verb: Index of the guessing verb in vocabulary adj: Index of the guessing adjective in vocabulary noun: Index of the guessing noun in vocabulary """ # turns 3 indices into actual command strings if self.word_vocab[verb] in self.single_word_verbs: return self.word_vocab[verb] if self.word_vocab[verb] in self.two_word_verbs: return " ".join([self.word_vocab[verb], self.word_vocab[noun]]) if adj == self.EOS_id: return " ".join([self.word_vocab[verb], self.word_vocab[noun]]) else: return " ".join([ self.word_vocab[verb], self.word_vocab[adj], self.word_vocab[noun] ]) def act_random(self, obs, infos, input_observation, input_observation_char, input_quest, input_quest_char, possible_words): with torch.no_grad(): batch_size = len(obs) word_indices_random = self.choose_random_command( batch_size, len(self.word_vocab), possible_words) chosen_indices = word_indices_random chosen_strings = self.get_chosen_strings(chosen_indices) for i in range(batch_size): if chosen_strings[i] == "wait": self.not_finished_yet[i] = 0.0 # info for replay memory for i in range(batch_size): if self.prev_actions[-1][i] == "wait": self.prev_step_is_still_interacting[i] = 0.0 # previous step is still interacting, this is because DQN requires one step extra computation replay_info = [ chosen_indices, to_pt(self.prev_step_is_still_interacting, self.use_cuda, "float") ] # cache new info in current game step into caches self.prev_actions.append(chosen_strings) return chosen_strings, replay_info def act_greedy(self, obs, infos, input_observation, input_observation_char, input_quest, input_quest_char, possible_words): """ Acts upon the current list of observations. One text command must be returned for each observation. """ with torch.no_grad(): batch_size = len(obs) local_word_masks_np = self.get_local_word_masks(possible_words) local_word_masks = [ to_pt(item, self.use_cuda, type="float") for item in local_word_masks_np ] # generate commands for one game step, epsilon greedy is applied, i.e., # there is epsilon of chance to generate random commands action_ranks = self.get_ranks( input_observation, input_observation_char, input_quest, input_quest_char, local_word_masks, use_model="online") # list of batch x vocab word_indices_maxq = self.choose_maxQ_command( action_ranks, local_word_masks) chosen_indices = word_indices_maxq chosen_strings = self.get_chosen_strings(chosen_indices) for i in range(batch_size): if chosen_strings[i] == "wait": self.not_finished_yet[i] = 0.0 # info for replay memory for i in range(batch_size): if self.prev_actions[-1][i] == "wait": self.prev_step_is_still_interacting[i] = 0.0 # previous step is still interacting, this is because DQN requires one step extra computation replay_info = [ chosen_indices, to_pt(self.prev_step_is_still_interacting, self.use_cuda, "float") ] # cache new info in current game step into caches self.prev_actions.append(chosen_strings) return chosen_strings, replay_info def act(self, obs, infos, input_observation, input_observation_char, input_quest, input_quest_char, possible_words, random=False): """ Acts upon the current list of observations. One text command must be returned for each observation. 
""" with torch.no_grad(): if self.mode == "eval": return self.act_greedy(obs, infos, input_observation, input_observation_char, input_quest, input_quest_char, possible_words) if random: return self.act_random(obs, infos, input_observation, input_observation_char, input_quest, input_quest_char, possible_words) batch_size = len(obs) local_word_masks_np = self.get_local_word_masks(possible_words) local_word_masks = [ to_pt(item, self.use_cuda, type="float") for item in local_word_masks_np ] # generate commands for one game step, epsilon greedy is applied, i.e., # there is epsilon of chance to generate random commands action_ranks = self.get_ranks( input_observation, input_observation_char, input_quest, input_quest_char, local_word_masks, use_model="online") # list of batch x vocab word_indices_maxq = self.choose_maxQ_command( action_ranks, local_word_masks) word_indices_random = self.choose_random_command( batch_size, len(self.word_vocab), possible_words) # random number for epsilon greedy rand_num = np.random.uniform(low=0.0, high=1.0, size=(batch_size, )) less_than_epsilon = (rand_num < self.epsilon).astype( "float32") # batch greater_than_epsilon = 1.0 - less_than_epsilon less_than_epsilon = to_pt(less_than_epsilon, self.use_cuda, type='long') greater_than_epsilon = to_pt(greater_than_epsilon, self.use_cuda, type='long') chosen_indices = [ less_than_epsilon * idx_random + greater_than_epsilon * idx_maxq for idx_random, idx_maxq in zip(word_indices_random, word_indices_maxq) ] chosen_strings = self.get_chosen_strings(chosen_indices) for i in range(batch_size): if chosen_strings[i] == "wait": self.not_finished_yet[i] = 0.0 # info for replay memory for i in range(batch_size): if self.prev_actions[-1][i] == "wait": self.prev_step_is_still_interacting[i] = 0.0 # previous step is still interacting, this is because DQN requires one step extra computation replay_info = [ chosen_indices, to_pt(self.prev_step_is_still_interacting, self.use_cuda, "float") ] # cache new info in current game step into caches self.prev_actions.append(chosen_strings) return chosen_strings, replay_info def get_dqn_loss(self): """ Update neural model in agent. In this example we follow algorithm of updating model in dqn with replay memory. 
""" if len(self.command_generation_replay_memory) < self.replay_batch_size: return None data = self.command_generation_replay_memory.get_batch( self.replay_batch_size, self.multi_step) if data is None: return None obs_list, quest_list, possible_words_list, chosen_indices, rewards, next_obs_list, next_possible_words_list, actual_n_list = data batch_size = len(actual_n_list) input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list) input_observation, input_observation_char, _ = self.get_agent_inputs( obs_list) next_input_observation, next_input_observation_char, _ = self.get_agent_inputs( next_obs_list) possible_words, next_possible_words = [], [] for i in range(3): possible_words.append([item[i] for item in possible_words_list]) next_possible_words.append( [item[i] for item in next_possible_words_list]) local_word_masks = [ to_pt(item, self.use_cuda, type="float") for item in self.get_local_word_masks(possible_words) ] next_local_word_masks = [ to_pt(item, self.use_cuda, type="float") for item in self.get_local_word_masks(next_possible_words) ] action_ranks = self.get_ranks( input_observation, input_observation_char, input_quest, input_quest_char, local_word_masks, use_model="online" ) # list of batch x vocab or list of batch x vocab x atoms # ps_a word_qvalues = [ ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for w_rank, idx in zip(action_ranks, chosen_indices) ] # list of batch or list of batch x atoms q_value = torch.mean(torch.stack(word_qvalues, -1), -1) # batch or batch x atoms # log_ps_a log_q_value = torch.log(q_value) # batch or batch x atoms with torch.no_grad(): if self.noisy_net: self.target_net.reset_noise() # Sample new target net noise if self.double_dqn: # pns Probabilities p(s_t+n, ·; θonline) next_action_ranks = self.get_ranks(next_input_observation, next_input_observation_char, input_quest, input_quest_char, next_local_word_masks, use_model="online") # list of batch x vocab or list of batch x vocab x atoms # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] next_word_indices = self.choose_maxQ_command( next_action_ranks, next_local_word_masks) # list of batch x 1 # pns # Probabilities p(s_t+n, ·; θtarget) next_action_ranks = self.get_ranks( next_input_observation, next_input_observation_char, input_quest, input_quest_char, next_local_word_masks, use_model="target" ) # batch x vocab or list of batch x vocab x atoms # pns_a # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) next_word_qvalues = [ ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for w_rank, idx in zip(next_action_ranks, next_word_indices) ] # list of batch or list of batch x atoms else: # pns Probabilities p(s_t+n, ·; θonline) next_action_ranks = self.get_ranks(next_input_observation, next_input_observation_char, input_quest, input_quest_char, next_local_word_masks, use_model="target") # list of batch x vocab or list of batch x vocab x atoms next_word_indices = self.choose_maxQ_command( next_action_ranks, next_local_word_masks) # list of batch x 1 next_word_qvalues = [ ez_gather_dim_1(w_rank, idx.unsqueeze(-1)).squeeze(1) for w_rank, idx in zip(next_action_ranks, next_word_indices) ] # list of batch or list of batch x atoms next_q_value = torch.mean(torch.stack(next_word_qvalues, -1), -1) # batch or batch x atoms # Compute Tz (Bellman operator T applied to z) discount = to_pt((np.ones_like(actual_n_list) * self.discount_gamma)**actual_n_list, self.use_cuda, type="float") if not self.use_distributional: rewards = 
rewards + next_q_value * discount # batch loss = F.smooth_l1_loss(q_value, rewards) return loss with torch.no_grad(): Tz = rewards.unsqueeze( -1) + discount.unsqueeze(-1) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.v_min, max=self.v_max) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.v_min) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = torch.zeros(batch_size, self.atoms).float() if self.use_cuda: m = m.cuda() offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).unsqueeze(1).expand( batch_size, self.atoms).long() if self.use_cuda: offset = offset.cuda() m.view(-1).index_add_( 0, (l + offset).view(-1), (next_q_value * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_( 0, (u + offset).view(-1), (next_q_value * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum( m * log_q_value, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) loss = torch.mean(loss) return loss def update_interaction(self): # update neural model by replaying snapshots in replay memory interaction_loss = self.get_dqn_loss() if interaction_loss is None: return None loss = interaction_loss * self.interaction_loss_lambda # Backpropagate self.online_net.zero_grad() self.optimizer.zero_grad() loss.backward() # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm) self.optimizer.step() # apply gradients return to_np(torch.mean(interaction_loss)) def answer_question(self, input_observation, input_observation_char, observation_id_list, input_quest, input_quest_char, use_model="online"): # first pad answerer_input, and get the mask model = self.online_net if use_model == "online" else self.target_net batch_size = len(observation_id_list) max_length = input_observation.size(1) mask = compute_mask(input_observation) # batch x obs_len # noun mask for location question if self.question_type in ["location"]: location_mask = [] for i in range(batch_size): m = [1 for item in observation_id_list[i]] location_mask.append(m) location_mask = pad_sequences(location_mask, maxlen=max_length, dtype="float32") location_mask = to_pt(location_mask, enable_cuda=self.use_cuda, type='float') assert mask.size() == location_mask.size() mask = mask * location_mask match_representation_sequence = self.get_match_representations( input_observation, input_observation_char, input_quest, input_quest_char, use_model=use_model) pred = model.answer_question(match_representation_sequence, mask) # batch x vocab or batch x 2 # attention sum: # sometimes certain word appears multiple times in the observation, # thus we need to merge them together before doing further computations # ------- but # if answer type is not pointing, we just use a pre-defined mapping # that maps 0/1 to their positions in vocab if self.answer_type == "2 way": observation_id_list = [] max_length = 2 for i in range(batch_size): observation_id_list.append( [self.word2id["0"], self.word2id["1"]]) observation = to_pt( pad_sequences(observation_id_list, maxlen=max_length).astype('int32'), self.use_cuda) vocab_distribution = np.zeros( (batch_size, len(self.word_vocab))) # batch x vocab 
vocab_distribution = to_pt(vocab_distribution, self.use_cuda, type='float') vocab_distribution = vocab_distribution.scatter_add_( 1, observation, pred) # batch x vocab non_zero_words = [] for i in range(batch_size): non_zero_words.append(list(set(observation_id_list[i]))) vocab_mask = torch.ne(vocab_distribution, 0).float() return vocab_distribution, non_zero_words, vocab_mask def point_maxq_position(self, vocab_distribution, mask): """ Generate a command by maximum q values, for epsilon greedy. Arguments: point_distribution: Q values for each position (mapped to vocab). mask: vocab masks. """ vocab_distribution = vocab_distribution - torch.min( vocab_distribution, -1, keepdim=True )[0] + 1e-2 # minus the min value, so that all values are non-negative vocab_distribution = vocab_distribution * mask # batch x vocab indices = torch.argmax(vocab_distribution, -1) # batch return indices def answer_question_act_greedy(self, input_observation, input_observation_char, observation_id_list, input_quest, input_quest_char): with torch.no_grad(): vocab_distribution, _, vocab_mask = self.answer_question( input_observation, input_observation_char, observation_id_list, input_quest, input_quest_char, use_model="online") # batch x time positions_maxq = self.point_maxq_position(vocab_distribution, vocab_mask) return positions_maxq # batch def get_qa_loss(self): """ Update neural model in agent. In this example we follow algorithm of updating model in dqn with replay memory. """ if len(self.qa_replay_memory) < self.replay_batch_size: return None transitions = self.qa_replay_memory.sample(self.replay_batch_size) batch = qa_memory.qa_Transition(*zip(*transitions)) observation_list = batch.observation_list quest_list = batch.quest_list answer_strings = batch.answer_strings answer_position = np.array(_words_to_ids(answer_strings, self.word2id)) groundtruth = to_pt(answer_position, self.use_cuda) # batch input_quest, input_quest_char, _ = self.get_agent_inputs(quest_list) input_observation, input_observation_char, observation_id_list = self.get_agent_inputs( observation_list) answer_distribution, _, _ = self.answer_question( input_observation, input_observation_char, observation_id_list, input_quest, input_quest_char, use_model="online") # batch x vocab batch_loss = NegativeLogLoss(answer_distribution, groundtruth) # batch return torch.mean(batch_loss) def update_qa(self): # update neural model by replaying snapshots in replay memory qa_loss = self.get_qa_loss() if qa_loss is None: return None loss = qa_loss * self.qa_loss_lambda # Backpropagate self.online_net.zero_grad() self.optimizer.zero_grad() loss.backward() # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), self.clip_grad_norm) self.optimizer.step() # apply gradients return to_np(torch.mean(qa_loss)) def finish_of_episode(self, episode_no, batch_size): # Update target network if ( episode_no + batch_size ) % self.target_net_update_frequency <= episode_no % self.target_net_update_frequency: self.update_target_net() # decay lambdas if episode_no < self.learn_start_from_this_episode: return if episode_no < self.epsilon_anneal_episodes + self.learn_start_from_this_episode: self.epsilon -= (self.epsilon_anneal_from - self.epsilon_anneal_to ) / float(self.epsilon_anneal_episodes) self.epsilon = max(self.epsilon, 0.0) if episode_no < self.revisit_counting_lambda_anneal_episodes + self.learn_start_from_this_episode: self.revisit_counting_lambda -= ( self.revisit_counting_lambda_anneal_from - self.revisit_counting_lambda_anneal_to) / float( self.revisit_counting_lambda_anneal_episodes) self.revisit_counting_lambda = max(self.revisit_counting_lambda, 0.0) def reset_binarized_counter(self, batch_size): self.binarized_counter_dict = [{} for _ in range(batch_size)] def get_binarized_count(self, observation_strings, update=True): count_rewards = [] batch_size = len(observation_strings) for i in range(batch_size): key = observation_strings[i] if key not in self.binarized_counter_dict[i]: self.binarized_counter_dict[i][key] = 0.0 if update: self.binarized_counter_dict[i][key] += 1.0 r = self.binarized_counter_dict[i][key] r = float(r == 1.0) count_rewards.append(r) return count_rewards
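The act() path above mixes random and greedy word indices per sample with an epsilon mask. Below is a minimal, self-contained sketch of that mixing step; the function name mix_epsilon_greedy and the toy tensors are illustrative only, not part of the agent above.

import numpy as np
import torch

def mix_epsilon_greedy(idx_random, idx_maxq, epsilon):
    # idx_random, idx_maxq: LongTensors of shape (batch,); epsilon: float in [0, 1]
    batch_size = idx_random.size(0)
    rand_num = np.random.uniform(low=0.0, high=1.0, size=(batch_size,))
    take_random = torch.tensor((rand_num < epsilon).astype("int64"))  # 1 where we explore
    take_greedy = 1 - take_random                                     # 1 where we exploit
    return take_random * idx_random + take_greedy * idx_maxq

# toy usage: with epsilon=0.3, roughly 30% of batch entries come from idx_random
idx_random = torch.randint(0, 100, (8,))
idx_maxq = torch.randint(0, 100, (8,))
print(mix_epsilon_greedy(idx_random, idx_maxq, 0.3))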
class Agent(): def __init__(self, args, env): self.action_space = env.action_space() self.atoms = args.atoms # size of value distribution. self.Vmin = args.V_min self.Vmax = args.V_max self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.online_net = DQN(args, self.action_space).to( device=args.device) # greedily selects the action. if args.model and os.path.isfile(args.model): self.online_net.load_state_dict( torch.load(args.model, map_location='cpu') ) # state_dict: python dictionary that maps each layer to its parameters. self.online_net.train() self.target_net = DQN(args, self.action_space).to( device=args.device) # use to compute target q-values. self.update_target_net( ) # sets it to the parameters of the online network. self.target_net.train() for param in self.target_net.parameters( ): # not updated through backpropagation. param.requires_grad = False self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps) def reset_noise(self): self.online_net.reset_noise() def act(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item() def act_e_greedy(self, state, epsilon=0.001): return np.random.randint( 0, self.action_space ) if np.random.random() < epsilon else self.act(state) def learn(self, mem): idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample( self.batch_size) log_ps = self.online_net(states, log=True) log_ps_a = log_ps[range(self.batch_size), actions] with torch.no_grad(): pns = self.online_net(next_states) dns = self.support.expand_as(pns) * pns argmax_indices_ns = dns.sum(2).argmax(1) self.target_net.reset_noise() pns = self.target_net(next_states) pns_a = pns[range(self.batch_size), argmax_indices_ns] Tz = returns.unsqueeze(1) + nonterminals * ( self.discount**self.n) * self.support.unsqueeze(0) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) b = (Tz - self.Vmin) / self.delta_z l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) loss = -torch.sum(m * log_ps_a, 1) self.online_net.zero_grad() (weights * loss).mean().backward() self.optimiser.step() mem.update_priorities(idxs, loss.detach().cpu().numpy() ) # update priorities of sampled transitions def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) def save(self, path): torch.save(self.online_net.state_dict(), os.path.join(path, 'model_all_layers.pth')) def evaluate_q(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() def train(self): self.online_net.train() def eval(self): self.online_net.eval()
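For reference, here is a standalone numeric sketch of the categorical (C51) projection performed inside learn() above: shift the support by the n-step return, clamp it to [Vmin, Vmax], and distribute each atom's probability mass between its two neighbouring atoms on the fixed support. All values below are toy numbers chosen for illustration.

import torch

batch_size, atoms, Vmin, Vmax = 2, 5, -10.0, 10.0
support = torch.linspace(Vmin, Vmax, atoms)
delta_z = (Vmax - Vmin) / (atoms - 1)

returns = torch.tensor([1.0, -2.0])                   # n-step returns R^n
nonterminals = torch.tensor([[1.0], [0.0]])           # 0 where the next state is terminal
pns_a = torch.full((batch_size, atoms), 1.0 / atoms)  # target-net probabilities p(s_t+n, a*)

Tz = (returns.unsqueeze(1) + nonterminals * (0.99 ** 3) * support.unsqueeze(0)).clamp(Vmin, Vmax)
b = (Tz - Vmin) / delta_z                              # fractional atom index
l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
l[(u > 0) * (l == u)] -= 1                             # fix disappearing mass when l == b == u
u[(l < (atoms - 1)) * (l == u)] += 1

m = torch.zeros(batch_size, atoms)
offset = torch.arange(batch_size).unsqueeze(1) * atoms  # flat row offsets for index_add_
m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))
m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))
print(m)         # projected target distribution
print(m.sum(1))  # each row sums to 1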
class Agent(): def __init__(self, args, env): self.action_space = env.action_space() self.quantile = args.quantile self.atoms = args.quantiles if args.quantile else args.atoms if args.quantile: self.cumulative_density = (2 * torch.arange(self.atoms).to(device=args.device) + 1) / (2 * self.atoms) # Quantile cumulative probability weights τ else: self.Vmin = args.V_min self.Vmax = args.V_max self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.norm_clip = args.norm_clip self.online_net = DQN(args, self.action_space, args.quantile).to(device=args.device) if args.model and os.path.isfile(args.model): # Always load tensors onto CPU by default, will shift to GPU if necessary self.online_net.load_state_dict(torch.load(args.model, map_location='cpu')) self.online_net.train() self.target_net = DQN(args, self.action_space, args.quantile).to(device=args.device) self.update_target_net() self.target_net.train() for param in self.target_net.parameters(): param.requires_grad = False self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps) # Resets noisy weights in all linear layers (of online net only) def reset_noise(self): self.online_net.reset_noise() # Acts based on single state (no batch) def act(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).argmax(1).item() # Acts with an ε-greedy policy def act_e_greedy(self, state, epsilon=0.05): return random.randrange(self.action_space) if random.random() < epsilon else self.act(state) def learn(self, mem): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) # Calculate current state probabilities (online network noise already sampled) ps = self.online_net(states) # Probabilities p(s_t, ·; θonline)/quantile probabilities θ(s_t, ·; θonline) ps_a = ps[range(self.batch_size), actions] # p(s_t, a_t; θonline) with torch.no_grad(): # Calculate nth next state probabilities pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) dns = ((1 / self.atoms) if self.quantile else self.support.expand_as(pns)) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax(1) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a = pns[range(self.batch_size), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) if self.quantile: # Compute distributional Bellman target Tθ = R^n + (γ^n)p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) Ttheta = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * pns_a # (accounting for terminal states) else: # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing 
probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions) m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) if self.quantile: u = Ttheta - ps_a # Residual u kappa_cond = (u.abs() <= 1).to(torch.float32) # |u| ≤ κ, with κ = 1 huber_loss = 0.5 * u ** 2 * kappa_cond + (u.abs() - 0.5) * (1 - kappa_cond) # Huber loss Lκ(u) loss = torch.sum(torch.abs(self.cumulative_density - (u < 0).to(torch.float32)) * huber_loss, 1) # Quantile Huber loss ρκτ(u) = |τ − δ{u<0}|Lκ(u) else: loss = -torch.sum(m * ps_a.log(), 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) loss = weights * loss # Importance weight losses self.online_net.zero_grad() loss.mean().backward() # Backpropagate minibatch loss nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm before the optimiser step self.optimiser.step() if self.quantile: loss = (self.atoms * loss).clamp(max=5) # Heuristic for prioritised replay mem.update_priorities(idxs, loss.detach()) # Update priorities of sampled transitions def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) # Save model parameters on current device (don't move model between devices) def save(self, path): torch.save(self.online_net.state_dict(), os.path.join(path, 'model.pth')) # Evaluates Q-value based on single state (no batch) def evaluate_q(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * ((1 / self.atoms) if self.quantile else self.support)).sum(2).max(1)[0].item() def train(self): self.online_net.train() def eval(self): self.online_net.eval()
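For comparison, the standard QR-DQN quantile Huber loss is usually written over pairwise differences between all target and predicted quantiles, rather than the element-wise residual used above. A self-contained sketch of that formulation follows; the function name and toy inputs are illustrative, not part of the agent above.

import torch

def quantile_huber_loss(theta, Ttheta, kappa=1.0):
    # theta, Ttheta: (batch, n_quantiles) predicted / target quantile values
    n = theta.size(1)
    tau = (2 * torch.arange(n, dtype=torch.float32) + 1) / (2 * n)  # quantile midpoints τ_i
    u = Ttheta.unsqueeze(1) - theta.unsqueeze(2)                    # pairwise residuals, (batch, n, n)
    huber = torch.where(u.abs() <= kappa,
                        0.5 * u ** 2,
                        kappa * (u.abs() - 0.5 * kappa))            # Huber loss L_kappa(u)
    rho = (tau.view(1, n, 1) - (u < 0).float()).abs() * huber       # |τ − δ{u<0}| L_kappa(u)
    return rho.mean(2).sum(1)                                       # mean over target quantiles, sum over τ

# toy usage
theta = torch.randn(4, 8)
Ttheta = torch.randn(4, 8)
print(quantile_huber_loss(theta, Ttheta))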
class Agent(): def __init__(self, args, env): self.action_space = env.action_space() self.atoms = args.atoms self.Vmin = args.V_min self.Vmax = args.V_max self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1) self.batch_size = args.batch_size self.n = args.multi_step self.discount = args.discount self.saved_model_path = args.saved_model_path self.experiment = args.experiment self.plots_path = args.plots_path self.data_save_path = args.data_save_path self.online_net = DQN(args, self.action_space).to(device=args.device) if args.model and os.path.isfile(args.model): # Always load tensors onto CPU by default, will shift to GPU if necessary self.online_net.load_state_dict(torch.load(args.model, map_location='cpu')) self.online_net.train() self.target_net = DQN(args, self.action_space).to(device=args.device) self.update_target_net() self.target_net.train() for param in self.target_net.parameters(): param.requires_grad = False self.optimiser = optim.Adam(self.online_net.parameters(), lr=args.lr, eps=args.adam_eps) # list of layers: self.online_net_layers = [self.online_net.conv1, self.online_net.conv2, self.online_net.conv3, self.online_net.fc_h_v, self.online_net.fc_h_a, self.online_net.fc_z_v, self.online_net.fc_z_a ] self.target_net_layers = [self.target_net.conv1, self.target_net.conv2, self.target_net.conv3, self.target_net.fc_h_v, self.target_net.fc_h_a, self.target_net.fc_z_v, self.target_net.fc_z_a ] # freeze all layers except the last, and reinitialize last if args.freeze_layers > 0: self.freeze_layers(args.freeze_layers) if args.reinitialize_layers > 0: self.reinit_layers(args.reinitialize_layers) # Resets noisy weights in all linear layers (of online net only) def reset_noise(self): self.online_net.reset_noise() # Acts based on single state (no batch) def act(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item() # Acts with an ε-greedy policy (used for evaluation only) def act_e_greedy(self, state, epsilon=0.001): # High ε can reduce evaluation scores drastically return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state) def learn(self, mem): # Sample transitions idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size) # Calculate current state probabilities (online network noise already sampled) log_ps = self.online_net(states, log=True) # Log probabilities log p(s_t, ·; θonline) log_ps_a = log_ps[range(self.batch_size), actions] # log p(s_t, a_t; θonline) with torch.no_grad(): # Calculate nth next state probabilities pns = self.online_net(next_states) # Probabilities p(s_t+n, ·; θonline) dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θonline)) argmax_indices_ns = dns.sum(2).argmax(1) # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))] self.target_net.reset_noise() # Sample new target net noise pns = self.target_net(next_states) # Probabilities p(s_t+n, ·; θtarget) pns_a = pns[range(self.batch_size), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) # Compute Tz (Bellman operator T applied to z) Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze( 0) # Tz = R^n + (γ^n)z (accounting for terminal states) Tz = Tz.clamp(min=self.Vmin, max=self.Vmax) # Clamp between 
supported values # Compute L2 projection of Tz onto fixed support z b = (Tz - self.Vmin) / self.delta_z # b = (Tz - Vmin) / Δz l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64) # Fix disappearing probability mass when l = b = u (b is int) l[(u > 0) * (l == u)] -= 1 u[(l < (self.atoms - 1)) * (l == u)] += 1 # Distribute probability of Tz m = states.new_zeros(self.batch_size, self.atoms) offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand( self.batch_size, self.atoms).to(actions) m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b) m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l) loss = -torch.sum(m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t))) self.online_net.zero_grad() (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss self.optimiser.step() mem.update_priorities(idxs, loss.detach().cpu().numpy()) # Update priorities of sampled transitions def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) # Save model parameters on current device (don't move model between devices) def save(self, path): torch.save(self.online_net.state_dict(), os.path.join(path, self.experiment + '_model.pth')) # 'model.pth')) # Evaluates Q-value based on single state (no batch) def evaluate_q(self, state): with torch.no_grad(): return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item() def train(self): self.online_net.train() def eval(self): self.online_net.eval() def freeze_layers(self, num_frozen_layers): # reinitialize the proper layers (all that were not frozen self.reinit_layers(5 - num_frozen_layers) for i in range(num_frozen_layers): if i == 0: # freeze last layer (two in list) self.online_net_layers[0].weight.requires_grad = False self.online_net_layers[0].bias.requires_grad = False elif i == 1: self.online_net_layers[1].weight.requires_grad = False self.online_net_layers[1].bias.requires_grad = False elif i == 2: self.online_net_layers[2].weight.requires_grad = False self.online_net_layers[2].bias.requires_grad = False elif i == 3: self.online_net_layers[3].weight_mu.requires_grad = False self.online_net_layers[3].weight_sigma.requires_grad = False self.online_net_layers[3].bias_mu.requires_grad = False self.online_net_layers[3].bias_sigma.requires_grad = False # self.online_net_layers[3].weight.requires_grad = False self.online_net_layers[4].bias_mu.requires_grad = False self.online_net_layers[4].bias_sigma.requires_grad = False self.online_net_layers[4].weight_mu.requires_grad = False self.online_net_layers[4].weight_sigma.requires_grad = False # self.online_net_layers[4].bias.requires_grad = False # elif i == 4: # self.online_net_layers[0].reset_parameters() # self.target_net_layers[0].reset_parameters() # freeze the proper layers - complicated work around for dueling architecture # ct = 0 # fourth_layer_first_time = True # for child in self.online_net.children(): # if ct < num_frozen_layers and ct < 3: # for param in child.parameters(): # print('something1') # param.required_grad = False # if ct < num_frozen_layers and ct == 3: # for param in child.parameters(): # print('something2') # param.required_grad = False # if fourth_layer_first_time: # fourth_layer_first_time = False # ct -= 1 # ct += 1 # # ct = 0 # fourth_layer_first_time = True # for child in self.target_net.children(): # if ct < 
num_frozen_layers and ct < 3: # for param in child.parameters(): # print('something3') # param.required_grad = False # if ct < num_frozen_layers and ct == 3: # for param in child.parameters(): # print('something4') # param.required_grad = False # if fourth_layer_first_time: # fourth_layer_first_time = False # ct -= 1 # ct += 1 print(self.online_net) print(list(i.requires_grad for i in self.online_net.parameters())) print(self.target_net) print(list(i.requires_grad for i in self.target_net.parameters())) def reinit_layers(self, num_layers): for i in range(num_layers): if i == 0: # freeze last layer (two in list) self.online_net_layers[6].reset_parameters() self.online_net_layers[5].reset_parameters() self.target_net_layers[6].reset_parameters() self.target_net_layers[5].reset_parameters() elif i == 1: self.online_net_layers[4].reset_parameters() self.online_net_layers[3].reset_parameters() self.target_net_layers[4].reset_parameters() self.target_net_layers[3].reset_parameters() elif i == 2: self.online_net_layers[2].reset_parameters() self.target_net_layers[2].reset_parameters() elif i == 3: self.online_net_layers[1].reset_parameters() self.target_net_layers[1].reset_parameters() elif i == 4: self.online_net_layers[0].reset_parameters() self.target_net_layers[0].reset_parameters()
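The freeze/reinitialize logic above enumerates layers by hand; the same idea can be written generically over any list of modules that expose reset_parameters() (true for nn.Linear and nn.Conv2d, and noisy-linear layers typically define it as well). The helper names and the toy layer list below are illustrative only.

import torch.nn as nn

def freeze_first_k(layers, k):
    # stop gradient updates for the first k layers (counting from the input side)
    for layer in layers[:k]:
        for param in layer.parameters():
            param.requires_grad = False

def reinitialize_last_k(layers, k):
    # re-draw the parameters of the last k layers, e.g. before fine-tuning on a new task
    if k <= 0:
        return
    for layer in layers[-k:]:
        layer.reset_parameters()

# toy usage
layers = [nn.Conv2d(4, 32, 8), nn.Conv2d(32, 64, 4), nn.Linear(64, 512), nn.Linear(512, 6)]
freeze_first_k(layers, 2)
reinitialize_last_k(layers, 1)
print([all(not p.requires_grad for p in l.parameters()) for l in layers])  # [True, True, False, False]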