class RLSVIIncrementalTDAgent(Agent):
    def __init__(self,
                 action_set,
                 reward_function,
                 prior_variance,
                 noise_variance,
                 feature_extractor,
                 prior_network,
                 num_ensemble,
                 hidden_dims=[10, 10],
                 learning_rate=5e-4,
                 buffer_size=50000,
                 batch_size=64,
                 num_batches=100,
                 starts_learning=5000,
                 discount=0.99,
                 target_freq=10,
                 verbose=False,
                 print_every=1,
                 test_model_path=None):
        Agent.__init__(self, action_set, reward_function)
        self.prior_variance = prior_variance
        self.noise_variance = noise_variance

        self.feature_extractor = feature_extractor
        self.feature_dim = self.feature_extractor.dimension
        dims = [self.feature_dim] + hidden_dims + [len(self.action_set)]
        self.prior_network = prior_network
        self.num_ensemble = num_ensemble  # number of models in the ensemble
        self.index = np.random.randint(self.num_ensemble)

        # build the Q networks; each ensemble member is a multilayer perceptron
        if test_model_path is None:
            self.test_mode = False
            self.learning_rate = learning_rate
            self.buffer_size = buffer_size
            self.batch_size = batch_size
            self.num_batches = num_batches
            self.starts_learning = starts_learning
            self.discount = discount
            self.timestep = 0

            self.buffer = Buffer(self.buffer_size)

            self.models = []
            for i in range(self.num_ensemble):
                if self.prior_network:
                    '''
                    The second network is a prior network whose weights are
                    fixed; the first network is the learned difference
                    network, i.e. its weights are mutable.
                    '''
                    self.models.append(
                        DQNWithPrior(dims,
                                     scale=np.sqrt(self.prior_variance)).to(device))
                else:
                    self.models.append(MLP(dims).to(device))
                self.models[i].initialize()

            '''
            The prior networks' weights are immutable, so it is enough to
            keep the difference network in the target.
            '''
            self.target_nets = []
            for i in range(self.num_ensemble):
                if self.prior_network:
                    self.target_nets.append(
                        DQNWithPrior(dims,
                                     scale=np.sqrt(self.prior_variance)).to(device))
                else:
                    self.target_nets.append(MLP(dims).to(device))
                self.target_nets[i].load_state_dict(self.models[i].state_dict())
                self.target_nets[i].eval()

            self.target_freq = target_freq  # target networks updated every target_freq episodes
            self.num_episodes = 0

            self.optimizer = []
            for i in range(self.num_ensemble):
                self.optimizer.append(
                    torch.optim.Adam(self.models[i].parameters(),
                                     lr=self.learning_rate))

            # for debugging purposes
            self.verbose = verbose
            self.running_loss = 1.
            self.print_every = print_every
        else:
            self.models = []
            self.test_mode = True
            if self.prior_network:
                self.models.append(
                    DQNWithPrior(dims, scale=np.sqrt(self.prior_variance)))
            else:
                self.models.append(MLP(dims))
            self.models[0].load_state_dict(torch.load(test_model_path))
            self.models[0].eval()
            self.index = 0

    def __str__(self):
        return 'rlsvi_incremental_TD_' + str(self.num_ensemble) + 'models'

    def update_buffer(self, observation_history, action_history):
        """ update the buffer with data collected from the current episode """
        reward_history = self.get_episode_reward(observation_history,
                                                 action_history)
        self.cummulative_reward += np.sum(reward_history)

        tau = len(action_history)
        feature_history = np.zeros((tau + 1, self.feature_extractor.dimension))
        for t in range(tau + 1):
            feature_history[t] = self.feature_extractor.get_feature(
                observation_history[:t + 1])

        # every transition stores an independent Gaussian perturbation per ensemble member
        for t in range(tau - 1):
            perturbations = np.random.randn(self.num_ensemble) * np.sqrt(
                self.noise_variance)
            self.buffer.add((feature_history[t], action_history[t],
                             reward_history[t], feature_history[t + 1],
                             perturbations))

        done = observation_history[tau][1]
        if done:
            feat_next = None
        else:
            feat_next = feature_history[tau]
        perturbations = np.random.randn(self.num_ensemble) * np.sqrt(
            self.noise_variance)
        self.buffer.add((feature_history[tau - 1], action_history[tau - 1],
                         reward_history[tau - 1], feat_next, perturbations))

    def learn_from_buffer(self):
        """ update the Q networks by applying TD steps """
        if self.timestep < self.starts_learning:
            return

        loss_ensemble = 0.
        for _ in range(self.num_batches):
            for sample_num in range(self.num_ensemble):
                minibatch = self.buffer.sample(batch_size=self.batch_size)
                feature_batch = torch.zeros(self.batch_size,
                                            self.feature_dim,
                                            device=device)
                action_batch = torch.zeros(self.batch_size,
                                           1,
                                           dtype=torch.long,
                                           device=device)
                reward_batch = torch.zeros(self.batch_size, 1, device=device)
                perturb_batch = torch.zeros(self.batch_size,
                                            self.num_ensemble,
                                            device=device)
                non_terminal_idxs = []
                next_feature_batch = []
                for i, d in enumerate(minibatch):
                    s, a, r, s_next, perturb = d
                    feature_batch[i] = torch.from_numpy(s)
                    action_batch[i] = torch.tensor(a, dtype=torch.long)
                    reward_batch[i] = r
                    perturb_batch[i] = torch.from_numpy(perturb)
                    if s_next is not None:
                        non_terminal_idxs.append(i)
                        next_feature_batch.append(s_next)

                model_estimates = self.models[sample_num](feature_batch).gather(
                    1, action_batch).float()

                future_values = torch.zeros(self.batch_size, device=device)
                if non_terminal_idxs != []:
                    next_feature_batch = torch.tensor(next_feature_batch,
                                                      dtype=torch.float,
                                                      device=device)
                    future_values[non_terminal_idxs] = self.target_nets[
                        sample_num](next_feature_batch).max(1)[0].detach()
                future_values = future_values.unsqueeze(1)

                # each member regresses toward a target perturbed by its own noise sample
                target_values = reward_batch + self.discount * future_values \
                    + perturb_batch[:, sample_num].unsqueeze(1)

                assert model_estimates.shape == target_values.shape

                loss = nn.functional.mse_loss(model_estimates, target_values)
                self.optimizer[sample_num].zero_grad()
                loss.backward()
                self.optimizer[sample_num].step()
                loss_ensemble += loss.item()

        self.running_loss = 0.99 * self.running_loss + 0.01 * loss_ensemble

        self.num_episodes += 1
        # resample which ensemble member acts during the next episode
        self.index = np.random.randint(self.num_ensemble)

        if self.verbose and (self.num_episodes % self.print_every == 0):
            print("rlsvi ep %d, running loss %.2f, reward %.3f, index %d" %
                  (self.num_episodes, self.running_loss,
                   self.cummulative_reward, self.index))

        if self.num_episodes % self.target_freq == 0:
            for sample_num in range(self.num_ensemble):
                self.target_nets[sample_num].load_state_dict(
                    self.models[sample_num].state_dict())
                # if self.verbose:
                #     print("rlsvi via ensemble sampling ep %d update target network"
                #           % self.num_episodes)

    def act(self, observation_history, action_history):
        """ select a greedy action with respect to the Q network of the
        currently sampled ensemble member """
        feature = self.feature_extractor.get_feature(observation_history)
        if not self.test_mode:
            # count environment steps so learning can be delayed until starts_learning
            self.timestep += 1
        with torch.no_grad():
            if str(device) == "cpu":
                action_values = self.models[self.index](
                    torch.tensor(feature).float()).numpy()
            else:
                out = self.models[self.index](
                    torch.tensor(feature).float().to(device))
                action_values = out.to("cpu").numpy()
        action = self._random_argmax(action_values)
        return action

    def save(self, path=None):
        if path is None:
            path = './' + self.__str__() + '.pt'
        torch.save(self.models[self.index].state_dict(), path)
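# `DQNWithPrior` is referenced above but not defined in this excerpt. As a
# minimal sketch of the idea it appears to implement (a trainable "difference"
# network summed with a frozen, randomly initialized prior network scaled by
# `scale`), one could write something like the class below. The name
# `PriorQNetworkSketch` and its exact architecture are assumptions made for
# illustration, not the repository's actual implementation.
import torch
import torch.nn as nn


class PriorQNetworkSketch(nn.Module):
    """Q(s) = f_theta(s) + scale * p(s), where p has frozen random weights."""

    def __init__(self, dims, scale=1.0):
        super().__init__()

        def make_mlp():
            layers = []
            for d_in, d_out in zip(dims[:-1], dims[1:]):
                layers += [nn.Linear(d_in, d_out), nn.ReLU()]
            return nn.Sequential(*layers[:-1])  # no activation on the output layer

        self.trainable = make_mlp()  # the learned "difference" network
        self.prior = make_mlp()      # the prior network, never updated
        self.scale = scale
        for p in self.prior.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.trainable(x) + self.scale * self.prior(x)


# e.g. q = PriorQNetworkSketch([4, 10, 10, 2], scale=2.0); q(torch.zeros(1, 4))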
def main(args):
    torch.manual_seed(args.seed)

    # start the training simulators
    mp.set_start_method('spawn')

    episode_q = Queue()

    player_qs = []
    simulators = []
    for si in range(args.n_simulators):
        player_qs.append(Queue())
        simulators.append(
            mp.Process(target=simulator,
                       args=(si, player_qs[-1], episode_q, args, False)))
        simulators[-1].start()

    # start the validation simulator
    return_q = Queue()
    valid_q = Queue()
    valid_simulator = mp.Process(target=simulator,
                                 args=(args.n_simulators, valid_q, return_q,
                                       args, True))
    valid_simulator.start()

    env = gym.make(args.env)
    # env = gym.make('Assault-ram-v0')

    n_frames = args.n_frames

    # initialize the replay buffer
    replay_buffer = Buffer(max_items=args.buffer_size,
                           n_frames=n_frames,
                           priority_ratio=args.priority_ratio,
                           store_ratio=args.store_ratio)

    n_iter = args.n_iter
    init_collect = args.init_collect
    n_collect = args.n_collect
    n_value = args.n_value
    n_policy = args.n_policy
    n_hid = args.n_hid
    critic_aware = args.critic_aware

    update_every = args.update_every

    disp_iter = args.disp_iter
    val_iter = args.val_iter
    save_iter = args.save_iter

    max_len = args.max_len
    batch_size = args.batch_size

    max_collected_frames = args.max_collected_frames

    clip_coeff = args.grad_clip
    ent_coeff = args.ent_coeff
    discount_factor = args.discount_factor

    value_loss = -numpy.Inf
    entropy = -numpy.Inf
    valid_ret = -numpy.Inf
    ess = -numpy.Inf

    n_collected_frames = 0

    offset = 0
    return_history = []

    if args.nn == "ff":
        # create the policy network
        player = ff.Player(n_in=128 * n_frames, n_hid=args.n_hid,
                           n_out=6).to(args.device)
        if args.player_coeff > 0.:
            player_old = ff.Player(n_in=128 * n_frames, n_hid=args.n_hid,
                                   n_out=6).to(args.device)
        player_copy = ff.Player(n_in=128 * n_frames, n_hid=args.n_hid,
                                n_out=6).to('cpu')

        # create the value estimator
        value = ff.Value(n_in=128 * n_frames, n_hid=args.n_hid).to(args.device)
        value_old = ff.Value(n_in=128 * n_frames, n_hid=args.n_hid).to(args.device)

        for m in player.parameters():
            m.data.normal_(0., 0.01)
        for m in value.parameters():
            m.data.normal_(0., 0.01)
    elif args.nn == "conv":
        # create the policy network
        player = conv.Player(n_frames=n_frames, n_hid=args.n_hid).to(args.device)
        if args.player_coeff > 0.:
            player_old = conv.Player(n_frames=n_frames,
                                     n_hid=args.n_hid).to(args.device)
        player_copy = conv.Player(n_frames=n_frames, n_hid=args.n_hid).to('cpu')

        # create the value estimator
        value = conv.Value(n_frames, n_hid=args.n_hid).to(args.device)
        value_old = conv.Value(n_frames, n_hid=args.n_hid).to(args.device)
    else:
        raise Exception('Unknown network type')

    if args.cont:
        # resume from the most recent checkpoint
        files = glob.glob("{}*th".format(args.saveto))
        iterations = [
            int(".".join(f.split('.')[:-1]).split('_')[-1].strip()) for f in files
        ]
        last_iter = numpy.max(iterations)
        offset = last_iter - 1
        print('Reloading from {}_{}.th'.format(args.saveto, last_iter))
        checkpoint = torch.load("{}_{}.th".format(args.saveto, last_iter))
        player.load_state_dict(checkpoint['player'])
        value.load_state_dict(checkpoint['value'])
        return_history = checkpoint['return_history']
        n_collected_frames = checkpoint['n_collected_frames']

    copy_params(value, value_old)
    if args.player_coeff > 0.:
        copy_params(player, player_old)

    # push the initial policy parameters to the simulators
    player.to('cpu')
    copy_params(player, player_copy)
    for si in range(args.n_simulators):
        player_qs[si].put(
            [copy.deepcopy(p.data.numpy()) for p in player_copy.parameters()] +
            [copy.deepcopy(p.data.numpy()) for p in player_copy.buffers()])
    valid_q.put(
        [copy.deepcopy(p.data.numpy()) for p in player_copy.parameters()] +
        [copy.deepcopy(p.data.numpy()) for p in player_copy.buffers()])
    player.to(args.device)

    if args.device == 'cuda':
        torch.set_num_threads(1)

    initial = True
    pre_filled = 0

    for ni in range(n_iter):
        # re-initialize the optimizers
        opt_player = eval(args.optimizer_player)(player.parameters(),
                                                 lr=args.lr,
                                                 weight_decay=args.l2)
        opt_value = eval(args.optimizer_value)(value.parameters(),
                                               lr=args.lr,
                                               weight_decay=args.l2)

        try:
            if not initial:
                # decay the learning rate and the entropy coefficient
                lr = args.lr / (1 + (ni - pre_filled + 1) * args.lr_factor)
                ent_coeff = args.ent_coeff / (
                    1 + (ni - pre_filled + 1) * args.ent_factor)
                print('lr', lr, 'ent_coeff', ent_coeff)
                for param_group in opt_player.param_groups:
                    param_group['lr'] = lr
                for param_group in opt_value.param_groups:
                    param_group['lr'] = lr

            if numpy.mod((ni - pre_filled + 1), save_iter) == 0:
                torch.save(
                    {
                        'n_iter': n_iter,
                        'n_collect': n_collect,
                        'n_value': n_value,
                        'n_policy': n_policy,
                        'max_len': max_len,
                        'n_hid': n_hid,
                        'batch_size': batch_size,
                        'player': player.state_dict(),
                        'value': value.state_dict(),
                        'return_history': return_history,
                        'n_collected_frames': n_collected_frames,
                    },
                    '{}_{}.th'.format(args.saveto,
                                      (ni - pre_filled + 1) + offset + 1))

            player.eval()

            # read the latest validation return, if any
            ret_ = -numpy.Inf
            while True:
                try:
                    ret_ = return_q.get_nowait()
                except queue.Empty:
                    break
            if ret_ != -numpy.Inf:
                return_history.append(ret_)
                if valid_ret == -numpy.Inf:
                    valid_ret = ret_
                else:
                    valid_ret = 0.9 * valid_ret + 0.1 * ret_
                print('Valid run', ret_, valid_ret)

            # push the latest policy parameters to the simulators
            player.to('cpu')
            copy_params(player, player_copy)
            for si in range(args.n_simulators):
                while True:
                    try:
                        # empty the queue, as a newer parameter set has arrived
                        player_qs[si].get_nowait()
                    except queue.Empty:
                        break
                player_qs[si].put(
                    [copy.deepcopy(p.data.numpy())
                     for p in player_copy.parameters()] +
                    [copy.deepcopy(p.data.numpy())
                     for p in player_copy.buffers()])
            while True:
                try:
                    # empty the queue, as a newer parameter set has arrived
                    valid_q.get_nowait()
                except queue.Empty:
                    break
            valid_q.put(
                [copy.deepcopy(p.data.numpy())
                 for p in player_copy.parameters()] +
                [copy.deepcopy(p.data.numpy())
                 for p in player_copy.buffers()])
            player.to(args.device)

            # collect newly finished episodes from the simulators
            n_collected_frames_ = 0
            while True:
                try:
                    epi = episode_q.get_nowait()
                    replay_buffer.add(epi[0], epi[1], epi[2], epi[3])
                    n_collected_frames_ = n_collected_frames_ + len(epi[0])
                except queue.Empty:
                    break
                if n_collected_frames_ >= max_collected_frames \
                        and (len(replay_buffer.buffer) +
                             len(replay_buffer.priority_buffer)) > 0:
                    break
            n_collected_frames = n_collected_frames + n_collected_frames_

            if len(replay_buffer.buffer) + len(replay_buffer.priority_buffer) < 1:
                continue

            if len(replay_buffer.buffer) + len(
                    replay_buffer.priority_buffer) < args.initial_buffer:
                if initial:
                    print('Pre-filling the buffer...',
                          len(replay_buffer.buffer) +
                          len(replay_buffer.priority_buffer))
                continue
            else:
                if initial:
                    pre_filled = ni
                    initial = False

            # fit the value function with TD(0)
            value.train()
            for vi in range(n_value):
                if numpy.mod(vi, update_every) == 0:
                    opt_player.zero_grad()
                    opt_value.zero_grad()

                batch = replay_buffer.sample(batch_size)

                batch_x = torch.from_numpy(
                    numpy.stack([ex.current_['obs'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                batch_r = torch.from_numpy(
                    numpy.stack([ex.current_['rew'] for ex in batch
                                 ]).astype('float32')).to(args.device)
                batch_xn = torch.from_numpy(
                    numpy.stack([ex.next_['obs'] for ex in batch
                                 ]).astype('float32')).to(args.device)

                pred_y = value(batch_x)
                pred_next = value_old(batch_xn).clone().detach()

                batch_pi = player(batch_x)

                loss_ = ((batch_r + discount_factor * pred_next.squeeze() -
                          pred_y.squeeze())**2)

                batch_a = torch.from_numpy(
                    numpy.stack([ex.current_['act'] for ex in batch
                                 ]).astype('float32')[:, None]).to(args.device)
                batch_q = torch.from_numpy(
                    numpy.stack([ex.current_['prob'] for ex in batch
                                 ]).astype('float32')).to(args.device)

                logp = torch.log(batch_pi.gather(1, batch_a.long()) + 1e-8)

                # (clipped) importance weight, because the policy may have
                # changed since the tuple was collected
                log_iw = logp.squeeze().clone().detach() - torch.log(
                    batch_q.squeeze() + 1e-8)
                ess_ = torch.exp(-torch.logsumexp(2 * log_iw, dim=0)).item()
                iw = torch.exp(log_iw.clamp(max=0.))

                if args.iw:
                    loss = iw * loss_
                else:
                    loss = loss_
                loss = loss.mean()

                loss.backward()

                if numpy.mod(vi, update_every) == (update_every - 1):
                    if clip_coeff > 0.:
                        nn.utils.clip_grad_norm_(value.parameters(), clip_coeff)
                    opt_value.step()
                    copy_params(value, value_old)

                if value_loss < 0.:
                    value_loss = loss_.mean().item()
                else:
                    value_loss = 0.9 * value_loss + 0.1 * loss_.mean().item()

            if numpy.mod((ni - pre_filled + 1), disp_iter) == 0:
                print('# frames', n_collected_frames, 'value_loss', value_loss,
                      'entropy', -entropy, 'ess', ess)

            # fit the policy
            value.eval()
            player.train()
            if args.player_coeff > 0.:
                player_old.eval()

            for pi in range(n_policy):
                if numpy.mod(pi, update_every) == 0:
                    opt_player.zero_grad()
                    opt_value.zero_grad()

                batch = replay_buffer.sample(batch_size)

                batch_x = numpy.zeros(
                    tuple([len(batch)] + list(batch[0].current_['obs'].shape)),
                    dtype='float32')
                batch_xn = numpy.zeros(
                    tuple([len(batch)] + list(batch[0].current_['obs'].shape)),
                    dtype='float32')
                batch_r = numpy.zeros((len(batch)), dtype='float32')[:, None]
                for ei, ex in enumerate(batch):
                    batch_x[ei, :] = ex.current_['obs']
                    batch_xn[ei, :] = ex.next_['obs']
                    batch_r[ei, 0] = ex.current_['rew']

                batch_x = torch.from_numpy(batch_x).to(args.device)
                batch_xn = torch.from_numpy(batch_xn).to(args.device)
                batch_r = torch.from_numpy(batch_r).to(args.device)

                batch_v = value(batch_x).clone().detach()
                batch_vn = value(batch_xn).clone().detach()

                batch_a = torch.from_numpy(
                    numpy.stack([ex.current_['act'] for ex in batch
                                 ]).astype('float32')[:, None]).to(args.device)
                batch_q = torch.from_numpy(
                    numpy.stack([ex.current_['prob'] for ex in batch
                                 ]).astype('float32')).to(args.device)

                batch_pi = player(batch_x)
                logp = torch.log(batch_pi.gather(1, batch_a.long()) + 1e-8)

                if args.player_coeff > 0.:
                    batch_pi_old = player_old(batch_x).clone().detach()

                # entropy regularization
                ent = -(batch_pi * torch.log(batch_pi + 1e-8)).sum(1)
                if entropy == -numpy.Inf:
                    entropy = ent.mean().item()
                else:
                    entropy = 0.9 * entropy + 0.1 * ent.mean().item()

                # advantage: r(s, a) + \gamma * V(s') - V(s)
                adv = batch_r + discount_factor * batch_vn - batch_v
                # adv = adv / adv.abs().max().clamp(min=1.)

                loss = -(adv * logp).squeeze()
                loss = loss - ent_coeff * ent

                # (clipped) importance weight
                log_iw = logp.squeeze().clone().detach() - torch.log(batch_q + 1e-8)
                iw = torch.exp(log_iw.clamp(max=0.))
                ess_ = torch.exp(-torch.logsumexp(2 * log_iw, dim=0)).item()
                if ess == -numpy.Inf:
                    ess = ess_
                else:
                    ess = 0.9 * ess + 0.1 * ess_

                if args.iw:
                    loss = iw * loss

                if critic_aware:
                    # down-weight samples on which the critic's TD error is large
                    pred_y = value(batch_x).squeeze()
                    pred_next = value(batch_xn).squeeze()
                    critic_loss_ = -((batch_r.squeeze() +
                                      discount_factor * pred_next -
                                      pred_y)**2).clone().detach()
                    critic_loss_ = torch.exp(critic_loss_)
                    loss = loss * critic_loss_

                loss = loss.mean()

                if args.player_coeff > 0.:
                    # interpolate with a cross-entropy term toward the previous policy
                    loss_old = -(batch_pi_old *
                                 torch.log(batch_pi + 1e-8)).sum(1).mean()
                    loss = (1. - args.player_coeff) * loss \
                        + args.player_coeff * loss_old

                loss.backward()

                if numpy.mod(pi, update_every) == (update_every - 1):
                    if clip_coeff > 0.:
                        nn.utils.clip_grad_norm_(player.parameters(), clip_coeff)
                    opt_player.step()

            if args.player_coeff > 0.:
                copy_params(player, player_old)

        except KeyboardInterrupt:
            print('Terminating...')
            break

    for si in range(args.n_simulators):
        player_qs[si].put("END")

    print('Waiting for the simulators...')
    for si in range(args.n_simulators):
        simulators[si].join()

    print('Done')
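# Both fitting loops in main() weight each stored transition by a clipped
# importance ratio pi(a|s) / q(a|s), where q is the probability recorded when
# the tuple was collected, and they track exp(-logsumexp(2 * log_iw)) as an
# effective-sample-size statistic. The helper below restates those two
# computations in isolation; the function name and the example values are
# illustrative only and not part of the original code.
import torch


def clipped_iw_and_ess(logp_current, logp_behavior):
    """logp_current: log pi(a|s) under the current policy;
    logp_behavior: log q(a|s) under the policy that collected the data."""
    log_iw = logp_current.detach() - logp_behavior
    iw = torch.exp(log_iw.clamp(max=0.))                  # ratios capped at 1
    ess = torch.exp(-torch.logsumexp(2 * log_iw, dim=0))  # 1 / sum of squared ratios
    return iw, ess


# e.g. iw, ess = clipped_iw_and_ess(torch.log(torch.tensor([0.5, 0.2, 0.3])),
#                                   torch.log(torch.tensor([0.4, 0.4, 0.2])))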
class DQNAgent(Agent):
    def __init__(self,
                 action_set,
                 reward_function,
                 feature_extractor,
                 hidden_dims=[50, 50],
                 learning_rate=5e-4,
                 buffer_size=50000,
                 batch_size=64,
                 num_batches=100,
                 starts_learning=5000,
                 final_epsilon=0.02,
                 discount=0.99,
                 target_freq=10,
                 verbose=False,
                 print_every=1,
                 test_model_path=None):
        Agent.__init__(self, action_set, reward_function)
        self.feature_extractor = feature_extractor
        self.feature_dim = self.feature_extractor.dimension

        # build the Q network; we use a multilayer perceptron
        dims = [self.feature_dim] + hidden_dims + [len(self.action_set)]
        self.model = MLP(dims)

        if test_model_path is None:
            self.test_mode = False
            self.learning_rate = learning_rate
            self.buffer_size = buffer_size
            self.batch_size = batch_size
            self.num_batches = num_batches
            self.starts_learning = starts_learning
            self.epsilon = 1.0  # anneals as starts_learning / (starts_learning + t)
            self.final_epsilon = final_epsilon
            self.timestep = 0
            self.discount = discount

            self.buffer = Buffer(self.buffer_size)

            self.target_net = MLP(dims)
            self.target_net.load_state_dict(self.model.state_dict())
            self.target_net.eval()
            self.target_freq = target_freq  # target network updated every target_freq episodes
            self.num_episodes = 0

            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=self.learning_rate)

            # for debugging purposes
            self.verbose = verbose
            self.running_loss = 1.
            self.print_every = print_every
        else:
            self.test_mode = True
            self.model.load_state_dict(torch.load(test_model_path))
            self.model.eval()

    def __str__(self):
        return "dqn"

    def update_buffer(self, observation_history, action_history):
        """ update the buffer with data collected from the current episode """
        reward_history = self.get_episode_reward(observation_history,
                                                 action_history)
        self.cummulative_reward += np.sum(reward_history)

        tau = len(action_history)
        feature_history = np.zeros((tau + 1, self.feature_extractor.dimension))
        for t in range(tau + 1):
            feature_history[t] = self.feature_extractor.get_feature(
                observation_history[:t + 1])

        for t in range(tau - 1):
            self.buffer.add((feature_history[t], action_history[t],
                             reward_history[t], feature_history[t + 1]))

        done = observation_history[tau][1]
        if done:
            feat_next = None
        else:
            feat_next = feature_history[tau]
        self.buffer.add((feature_history[tau - 1], action_history[tau - 1],
                         reward_history[tau - 1], feat_next))

    def learn_from_buffer(self):
        """ update the Q network by applying TD steps """
        if self.timestep < self.starts_learning:
            return

        for _ in range(self.num_batches):
            minibatch = self.buffer.sample(batch_size=self.batch_size)

            feature_batch = torch.zeros(self.batch_size, self.feature_dim)
            action_batch = torch.zeros(self.batch_size, 1, dtype=torch.long)
            reward_batch = torch.zeros(self.batch_size, 1)
            non_terminal_idxs = []
            next_feature_batch = []

            for i, d in enumerate(minibatch):
                x, a, r, x_next = d
                feature_batch[i] = torch.from_numpy(x)
                action_batch[i] = torch.tensor(a, dtype=torch.long)
                reward_batch[i] = r
                if x_next is not None:
                    non_terminal_idxs.append(i)
                    next_feature_batch.append(x_next)

            model_estimates = self.model(feature_batch).gather(1, action_batch)

            # terminal transitions contribute no future value
            future_values = torch.zeros(self.batch_size)
            if next_feature_batch != []:
                next_feature_batch = torch.tensor(next_feature_batch,
                                                  dtype=torch.float)
                future_values[non_terminal_idxs] = self.target_net(
                    next_feature_batch).max(1)[0].detach()
            future_values = future_values.unsqueeze(1)

            target_values = reward_batch + self.discount * future_values

            loss = nn.functional.mse_loss(model_estimates, target_values)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.running_loss = 0.99 * self.running_loss + 0.01 * loss.item()

        self.epsilon = self.starts_learning / (self.starts_learning + self.timestep)
        self.epsilon = max(self.final_epsilon, self.epsilon)

        self.num_episodes += 1

        if self.verbose and (self.num_episodes % self.print_every == 0):
            print("dqn ep %d, running loss %.2f" %
                  (self.num_episodes, self.running_loss))

        if self.num_episodes % self.target_freq == 0:
            self.target_net.load_state_dict(self.model.state_dict())
            if self.verbose:
                print("dqn ep %d update target network" % self.num_episodes)

    def act(self, observation_history, action_history):
        """ select an action according to an epsilon-greedy policy with
        respect to the Q network """
        feature = self.feature_extractor.get_feature(observation_history)
        with torch.no_grad():
            action_values = self.model(torch.from_numpy(feature).float()).numpy()
        if not self.test_mode:
            action = self._epsilon_greedy_action(action_values, self.epsilon)
            self.timestep += 1
        else:
            action = self._random_argmax(action_values)
        return action

    def save(self, path=None):
        if path is None:
            path = './dqn.pt'
        torch.save(self.model.state_dict(), path)
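# DQNAgent keeps epsilon at 1.0 until `starts_learning` environment steps have
# been taken and then anneals it as starts_learning / (starts_learning + t),
# floored at `final_epsilon`. The standalone helper below sketches that
# schedule; the function name and the values in the example comment are
# illustrative only.
def epsilon_schedule(timestep, starts_learning=5000, final_epsilon=0.02):
    if timestep < starts_learning:
        return 1.0  # act fully at random until learning starts
    epsilon = starts_learning / (starts_learning + timestep)
    return max(final_epsilon, epsilon)


# e.g. epsilon_schedule(0) == 1.0, epsilon_schedule(5000) == 0.5,
# and epsilon_schedule(10**6) == 0.02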