def train(env_name, seed=42, timesteps=1, epsilon_decay_last_step=1000, er_capacity=1e4,
          batch_size=16, lr=1e-3, gamma=1.0, update_target=16, exp_name='test',
          init_timesteps=100, save_every_steps=1e4, arch='nature', dueling=False,
          play_steps=2, n_jobs=2):
    """
    Main training function. Spawns the worker subprocesses that gather experience
    and trains the Q network on the collected transitions.
    """
    # Multiprocessing start method ('spawn' is required to share CUDA tensors safely)
    mp.set_start_method('spawn')

    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set random seeds for PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Create logger
    logger = Logger(exp_name, loggers=['tensorboard'])

    # Create the Q network
    _env = make_env(env_name, seed)
    net = QNetwork(_env.observation_space, _env.action_space, arch=arch, dueling=dueling).to(device)

    # Create the target network as a copy of the Q network
    target_net = copy.deepcopy(net)

    # Create replay buffer, optimizer and learning-rate scheduler
    buffer = ExperienceReplay(capacity=int(er_capacity))
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=LR_STEPS, gamma=0.99)

    # Multiprocessing queues: observations and transitions flow from the workers,
    # actions flow back to each worker through its own queue
    obs_queue = mp.Queue(maxsize=n_jobs)
    transition_queue = mp.Queue(maxsize=n_jobs)
    workers, action_queues = [], []
    for i in range(n_jobs):
        action_queue = mp.Queue(maxsize=1)
        _seed = seed + i * 1000
        play_proc = mp.Process(target=play_func,
                               args=(i, env_name, obs_queue, transition_queue, action_queue, _seed))
        play_proc.start()
        workers.append(play_proc)
        action_queues.append(action_queue)

    # Per-worker bookkeeping of returns, episode lengths and timing
    timestep = 0
    current_reward = np.zeros(n_jobs)
    current_len = np.zeros(n_jobs, dtype=np.int64)
    current_time = [time.time() for _ in range(n_jobs)]

    # Training loop
    while timestep < timesteps:
        # Linearly anneal epsilon from EPSILON_START to EPSILON_STOP
        epsilon = EPSILON_STOP + max(0, (EPSILON_START - EPSILON_STOP)
                                     * (epsilon_decay_last_step - timestep) / epsilon_decay_last_step)
        logger.log_kv('internals/epsilon', epsilon, timestep)

        # Gather one observation from play_steps workers
        ids, obs_batch = zip(*[obs_queue.get() for _ in range(play_steps)])

        # Pre-process the observation batch for PyTorch
        obs_batch = torch.from_numpy(np.array(obs_batch)).to(device)

        # Select the greedy actions from the policy, then apply epsilon-greedy exploration
        # (one independent random draw per worker)
        with torch.no_grad():
            greedy_actions = net(obs_batch).argmax(dim=1).cpu().numpy()
        random_actions = np.array([_env.action_space.sample() for _ in range(len(greedy_actions))])
        explore = np.random.rand(len(greedy_actions)) < epsilon
        actions = np.where(explore, random_actions, greedy_actions)

        # Send each worker its action
        for worker_id, action in zip(ids, actions):
            action_queues[worker_id].put(action)

        # Add the resulting transitions to the experience replay
        transitions = [transition_queue.get() for _ in range(play_steps)]
        buffer.pushTransitions(transitions)

        # Update the per-worker reward, length and time counters
        t_ids, _, _, rewards, dones, _ = zip(*transitions)
        for worker_id, reward, done in zip(t_ids, rewards, dones):
            current_reward[worker_id] += reward
            current_len[worker_id] += 1
            if done:
                # Log episode statistics
                logger.log_kv('performance/return', current_reward[worker_id], timestep)
                logger.log_kv('performance/length', current_len[worker_id], timestep)
                logger.log_kv('performance/speed',
                              current_len[worker_id] / (time.time() - current_time[worker_id]), timestep)
                # Reset the counters for this worker
                current_reward[worker_id] = 0.0
                current_len[worker_id] = 0
                current_time[worker_id] = time.time()

        # Update the number of environment steps
        timestep += play_steps

        # Skip the policy update while we are still in the warm-up phase
        if timestep < init_timesteps:
            continue

        # Learning rate update and log
        scheduler.step()
        logger.log_kv('internals/lr', scheduler.get_lr()[0], timestep)

        # Clear the gradients
        optimizer.zero_grad()

        # Sample a batch from experience replay
        batch = buffer.sampleTransitions(batch_size)

        def batch_preprocess(batch_item):
            # Integer entries (worker ids and actions) become long tensors,
            # everything else keeps the dtype inferred by torch.tensor
            dtype = torch.long if isinstance(batch_item[0], np.int64) else None
            return torch.tensor(batch_item, dtype=dtype).to(device)

        ids, states_batch, actions_batch, rewards_batch, done_batch, next_states_batch = \
            map(batch_preprocess, zip(*batch))

        # Compute the TD loss: Q(s, a) against r + gamma * max_a' Q_target(s', a')
        state_action_values = net(states_batch).gather(1, actions_batch.unsqueeze(-1)).squeeze(-1)
        next_state_values = target_net(next_states_batch).max(1)[0]
        next_state_values[done_batch] = 0.0
        expected_state_action_values = next_state_values.detach() * gamma + rewards_batch
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        logger.log_kv('internals/loss', loss.item(), timestep)
        loss.backward()

        # Clip the gradients to avoid too abrupt changes (roughly equivalent to a Huber loss)
        for param in net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

        # Check if the target network needs to be synched
        if timestep % update_target == 0:
            target_net.load_state_dict(net.state_dict())

        # Check if we need to save a checkpoint
        if timestep % save_every_steps == 0:
            torch.save(net.get_extended_state(), exp_name + '.pth')

    # Shutdown: send the stop signal to every worker and wait for it to exit
    for i, worker in enumerate(workers):
        action_queues[i].put(None)
        worker.join()
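
# play_func (the worker entry point referenced above) is not defined in this section.
# The following is a minimal, hypothetical sketch inferred from how train() uses the
# queues: each worker publishes (worker_id, observation) on obs_queue, blocks on its
# private action_queue (None is the shutdown signal), steps its environment, and pushes
# (worker_id, state, action, reward, done, next_state) onto transition_queue.
# It assumes the module's make_env() helper and the classic gym step API returning
# (obs, reward, done, info); the name play_func_sketch is ours, not the original's.
def play_func_sketch(worker_id, env_name, obs_queue, transition_queue, action_queue, seed):
    env = make_env(env_name, seed)
    obs = env.reset()
    while True:
        # Ask the trainer for an action on the current observation
        obs_queue.put((worker_id, obs))
        action = action_queue.get()
        if action is None:
            # Shutdown signal sent by the trainer at the end of training
            break
        # Step the environment and ship the transition back to the trainer
        next_obs, reward, done, _ = env.step(int(action))
        transition_queue.put((worker_id, obs, action, reward, done, next_obs))
        obs = env.reset() if done else next_obs
    env.close()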
def train(env_name, seed=42, timesteps=1, epsilon_decay_last_step=1000, er_capacity=1e4,
          batch_size=16, lr=1e-3, gamma=1.0, update_target=16, exp_name='test',
          init_timesteps=100, save_every_steps=1e4, arch='nature', dueling=False):
    """
    Main training function (ptan-based variant). Spawns the play subprocess that
    gathers experience and trains the Q network on the collected transitions.
    """
    # Multiprocessing start method ('spawn' is required to share CUDA tensors safely)
    mp.set_start_method('spawn')

    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create logger
    logger = Logger(exp_name, loggers=['tensorboard'])

    # Create the Q network
    _env = make_env(env_name, seed)
    net = QNetwork(_env.observation_space, _env.action_space, arch=arch, dueling=dueling).to(device)

    # Create the target network as a copy of the Q network
    tgt_net = ptan.agent.TargetNet(net)

    # Create replay buffer, optimizer and learning-rate scheduler
    buffer = ptan.experience.ExperienceReplayBuffer(experience_source=None, buffer_size=int(er_capacity))
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=LR_STEPS, gamma=0.99)

    # Multiprocessing queue: the play process pushes (experience, info) pairs
    exp_queue = mp.Queue(maxsize=PLAY_STEPS * 2)
    play_proc = mp.Process(target=play_func,
                           args=(env_name, net, exp_queue, seed, timesteps,
                                 epsilon_decay_last_step, gamma))
    play_proc.start()

    # Main training loop
    timestep = 0
    while play_proc.is_alive() and timestep < timesteps:
        timestep += PLAY_STEPS

        # Drain the queue and log results if an episode has ended
        for _ in range(PLAY_STEPS):
            exp, info = exp_queue.get()
            if exp is None:
                # The play process has finished: wait for it and stop draining
                play_proc.join()
                break
            buffer._add(exp)
            logger.log_kv('internals/epsilon', info['epsilon'][0], info['epsilon'][1])
            if 'ep_reward' in info.keys():
                logger.log_kv('performance/return', info['ep_reward'], timestep)
                logger.log_kv('performance/length', info['ep_length'], timestep)
                logger.log_kv('performance/speed', info['speed'], timestep)

        # Skip the policy update while the buffer is still warming up
        if len(buffer) < init_timesteps:
            continue

        # Learning rate update and log
        scheduler.step()
        logger.log_kv('internals/lr', scheduler.get_lr()[0], timestep)

        # Get a batch from experience replay
        optimizer.zero_grad()
        batch = buffer.sample(batch_size * PLAY_STEPS)

        # Unpack the batch and move it to the device
        states, actions, rewards, dones, next_states = unpack_batch(batch)
        states_v = torch.tensor(states).to(device)
        next_states_v = torch.tensor(next_states).to(device)
        actions_v = torch.tensor(actions).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        done_mask = torch.tensor(dones, dtype=torch.bool).to(device)

        # Compute the TD loss: Q(s, a) against r + gamma * max_a' Q_target(s', a')
        state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
        next_state_values = tgt_net.target_model(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0
        expected_state_action_values = next_state_values.detach() * gamma + rewards_v
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        logger.log_kv('internals/loss', loss.item(), timestep)
        loss.backward()

        # Clip the gradients to avoid too abrupt changes (roughly equivalent to a Huber loss)
        for param in net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

        # Check if the target network needs to be synched
        if timestep % update_target == 0:
            tgt_net.sync()

        # Check if we need to save a checkpoint
        if timestep % save_every_steps == 0:
            torch.save(net.get_extended_state(), exp_name + '.pth')
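
# unpack_batch() is referenced above but not defined in this section. Below is a common
# sketch, assuming the replay buffer holds ptan.experience.ExperienceFirstLast entries
# (state, action, reward, last_state), where last_state is None when the episode ended
# within the unrolled segment. It relies on the module's numpy import (np), and the name
# unpack_batch_sketch is ours, not the original's.
def unpack_batch_sketch(batch):
    states, actions, rewards, dones, last_states = [], [], [], [], []
    for exp in batch:
        states.append(np.array(exp.state))
        actions.append(exp.action)
        rewards.append(exp.reward)
        dones.append(exp.last_state is None)
        # For terminal transitions reuse the first state: it is zeroed out via done_mask anyway
        last_states.append(np.array(exp.state if exp.last_state is None else exp.last_state))
    return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.bool_), np.array(last_states))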