Example #1
def train(env_name, seed=42, timesteps=1, epsilon_decay_last_step=1000,
          er_capacity=1e4, batch_size=16, lr=1e-3, gamma=1.0, update_target=16,
          exp_name='test', init_timesteps=100, save_every_steps=1e4, arch='nature',
          dueling=False, play_steps=2, n_jobs=2):
    """
        Main training function. Calls the subprocesses to get experience and
        train the network.
    """
    # Multiprocessing method
    mp.set_start_method('spawn')

    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Set random seed for PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Create logger
    logger = Logger(exp_name, loggers=['tensorboard'])

    # Create the Q network
    _env = make_env(env_name, seed)
    net = QNetwork(_env.observation_space, _env.action_space, arch=arch, dueling=dueling).to(device)
    # Create the target network as a copy of the Q network
    target_net = copy.deepcopy(net)
    # Create buffer and optimizer
    buffer = ExperienceReplay(capacity=int(er_capacity))
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=LR_STEPS, gamma=0.99)

    # Multiprocessing queue
    obs_queue = mp.Queue(maxsize=n_jobs)
    transition_queue = mp.Queue(maxsize=n_jobs)
    workers, action_queues = [], []
    for i in range(n_jobs):
        action_queue = mp.Queue(maxsize=1)
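        # Give each worker a distinct seed so the parallel environments do not play identical episodes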
        _seed = seed + i * 1000
        play_proc = mp.Process(target=play_func, args=(i, env_name, obs_queue, transition_queue, action_queue, _seed))
        play_proc.start()
        workers.append(play_proc)
        action_queues.append(action_queue)

    # Vars to keep track of performance and time, one slot per worker
    timestep = 0
    current_reward, current_len = np.zeros(n_jobs), np.zeros(n_jobs, dtype=np.int64)
    current_time = [time.time() for _ in range(n_jobs)]
    # Training loop
    while timestep < timesteps:
        # Compute the current epsilon: linear decay from EPSILON_START to EPSILON_STOP
        # over the first epsilon_decay_last_step steps, constant afterwards
        epsilon = EPSILON_STOP + max(0, (EPSILON_START - EPSILON_STOP)*(epsilon_decay_last_step-timestep)/epsilon_decay_last_step)
        logger.log_kv('internals/epsilon', epsilon, timestep)
        # Gather play_steps observations from the workers
        ids, obs_batch = zip(*[obs_queue.get() for _ in range(play_steps)])
        # Pre-process observation_batch for PyTorch
        obs_batch = torch.from_numpy(np.array(obs_batch)).to(device)
        # Select the greedy action from the policy, then apply epsilon-greedy exploration
        with torch.no_grad():
            greedy_actions = net(obs_batch).argmax(dim=1).cpu().numpy()
        probs = np.random.rand(*greedy_actions.shape)
        random_actions = np.array([_env.action_space.sample() for _ in greedy_actions])
        actions = np.where(probs < epsilon, random_actions, greedy_actions)
        # Send actions
        for id, action in zip(ids, actions):
            action_queues[id].put(action)
        # Add transitions to experience replay
        transitions = [transition_queue.get() for _ in range(play_steps)]
        buffer.pushTransitions(transitions)
        # Update per-worker reward, length and time counters
        w_ids, _, _, rewards, dones, _ = zip(*transitions)
        for w_id, reward, done in zip(w_ids, rewards, dones):
            current_reward[w_id] += reward
            current_len[w_id] += 1
            if done:
                # Log quantities
                logger.log_kv('performance/return', current_reward[w_id], timestep)
                logger.log_kv('performance/length', current_len[w_id], timestep)
                logger.log_kv('performance/speed', current_len[w_id] / (time.time() - current_time[w_id]), timestep)
                # Reset counters
                current_reward[w_id] = 0.0
                current_len[w_id] = 0
                current_time[w_id] = time.time()

        # Update number of steps
        timestep += play_steps

        # Check if we are in the warm-up phase, otherwise go on with policy update
        if timestep < init_timesteps:
            continue
        # Learning rate update and log
        scheduler.step()
        logger.log_kv('internals/lr', scheduler.get_lr()[0], timestep)
        # Clear grads
        optimizer.zero_grad()
        # Get a batch from experience replay
        batch = buffer.sampleTransitions(batch_size)
        def batch_preprocess(batch_item):
            # Stack one field of the batch into a tensor; integer fields (ids, actions) become LongTensors
            return torch.tensor(batch_item, dtype=(torch.long if isinstance(batch_item[0], np.int64) else None)).to(device)
        ids, states_batch, actions_batch, rewards_batch, done_batch, next_states_batch = map(batch_preprocess, zip(*batch))
        # Compute the loss function
        state_action_values = net(states_batch).gather(1, actions_batch.unsqueeze(-1)).squeeze(-1)
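        # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term zeroed at terminal states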
        next_state_values = target_net(next_states_batch).max(1)[0]
        next_state_values[done_batch] = 0.0
        expected_state_action_values = next_state_values.detach() * gamma + rewards_batch
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        logger.log_kv('internals/loss', loss.item(), timestep)
        loss.backward()
        # Clamp the gradients to avoid too abrupt updates (similar in spirit to the error clipping / Huber loss used in the DQN paper)
        for param in net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

        if timestep % update_target == 0:
            target_net.load_state_dict(net.state_dict())

        # Check if we need to save a checkpoint
        if timestep % save_every_steps == 0:
            torch.save(net.get_extended_state(), exp_name + '.pth')

    # Ending
    for i, worker in enumerate(workers):
        action_queues[i].put(None)
        worker.join()
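
A minimal entry point for this example might look like the sketch below; the environment id ('PongNoFrameskip-v4') and all hyperparameter values are illustrative assumptions, not settings taken from the example above.

if __name__ == '__main__':
    # Hypothetical invocation with placeholder values; tune for your environment.
    train('PongNoFrameskip-v4',
          seed=42,
          timesteps=int(1e6),
          epsilon_decay_last_step=int(1e5),
          batch_size=32,
          lr=1e-4,
          gamma=0.99,
          update_target=1000,
          exp_name='pong_dqn',
          play_steps=2,
          n_jobs=2)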
Example #2
def train(env_name,
          arch,
          timesteps=1,
          init_timesteps=0,
          seed=42,
          er_capacity=1,
          epsilon_start=1.0,
          epsilon_stop=0.05,
          epsilon_decay_stop=1,
          batch_size=16,
          target_sync=16,
          lr=1e-3,
          gamma=1.0,
          dueling=False,
          play_steps=1,
          lr_steps=1e4,
          lr_gamma=0.99,
          save_steps=5e4,
          logger=None,
          experiment_name='test'):
    """
        Main training function. Calls the subprocesses to get experience and
        train the network.
    """

    # Cast params that may be passed in scientific notation
    def int_scientific(x):
        return int(float(x))

    timesteps, init_timesteps = map(int_scientific,
                                    [timesteps, init_timesteps])
    lr_steps, epsilon_decay_stop = map(int_scientific,
                                       [lr_steps, epsilon_decay_stop])
    er_capacity, target_sync, save_steps = map(
        int_scientific, [er_capacity, target_sync, save_steps])
    lr, lr_gamma = float(lr), float(lr_gamma)

    # Multiprocessing method
    mp.set_start_method('spawn')

    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the Q network
    _env = make_env(env_name, seed)
    net = QNetwork(_env.observation_space,
                   _env.action_space,
                   arch=arch,
                   dueling=dueling).to(device)
    # Create the target network as a copy of the Q network
    tgt_net = ptan.agent.TargetNet(net)
    # Create buffer and optimizer
    buffer = ptan.experience.ExperienceReplayBuffer(experience_source=None,
                                                    buffer_size=er_capacity)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=lr_steps, gamma=lr_gamma)

    # Epsilon schedule for the play subprocess
    epsilon_schedule = (epsilon_start, epsilon_stop, epsilon_decay_stop)
    # Multiprocessing queue and play subprocess
    exp_queue = mp.Queue(maxsize=play_steps * 2)
    play_proc = mp.Process(target=play_func,
                           args=(env_name, net, exp_queue, seed, timesteps,
                                 epsilon_schedule, gamma))
    play_proc.start()

    # Main training loop
    timestep = 0
    while play_proc.is_alive() and timestep < timesteps:
        timestep += play_steps
        # Query the environments and log results if the episode has ended
        for _ in range(play_steps):
            exp, info = exp_queue.get()
            if exp is None:
                play_proc.join()
                break
            buffer._add(exp)
            logger.log_kv('internals/epsilon', info['epsilon'][0],
                          info['epsilon'][1])
            if 'ep_reward' in info.keys():
                logger.log_kv('performance/return', info['ep_reward'],
                              timestep)
                logger.log_kv('performance/length', info['ep_length'],
                              timestep)
                logger.log_kv('performance/speed', info['speed'], timestep)

        # Check if we are in the starting phase
        if len(buffer) < init_timesteps:
            continue

        scheduler.step()
        logger.log_kv('internals/lr', scheduler.get_lr()[0], timestep)
        # Clear the gradients and sample a batch from experience replay
        # (play_steps times the batch size, since one update is done every play_steps environment steps)
        optimizer.zero_grad()
        batch = buffer.sample(batch_size * play_steps)
        # Unpack the batch
        states, actions, rewards, dones, next_states = unpack_batch(batch)
        states_v = torch.tensor(states).to(device)
        next_states_v = torch.tensor(next_states).to(device)
        actions_v = torch.tensor(actions).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        done_mask = torch.tensor(dones, dtype=torch.bool).to(device)
        # Compute the TD target r + gamma * max_a' Q_target(s', a') and the loss
        state_action_values = net(states_v).gather(
            1, actions_v.unsqueeze(-1)).squeeze(-1)
        next_state_values = tgt_net.target_model(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0
        expected_state_action_values = next_state_values.detach() * gamma + rewards_v
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        logger.log_kv('internals/loss', loss.item(), timestep)
        loss.backward()
        # Clamp the gradients to avoid too abrupt updates (similar in spirit to the error clipping / Huber loss used in the DQN paper)
        for param in net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

        # Check if the target network needs to be synced
        if timestep % target_sync == 0:
            tgt_net.sync()

        # Check if we need to save a checkpoint
        if timestep % save_steps == 0:
            torch.save(net.get_extended_state(), experiment_name + '.pth')
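
Example #2 calls an unpack_batch helper that is not shown here. A plausible sketch, assuming the buffer holds ptan ExperienceFirstLast transitions (the original helper may differ in detail):

import numpy as np

def unpack_batch(batch):
    """Split a batch of ptan ExperienceFirstLast tuples into NumPy arrays (assumed layout)."""
    states, actions, rewards, dones, last_states = [], [], [], [], []
    for exp in batch:
        states.append(np.asarray(exp.state))
        actions.append(exp.action)
        rewards.append(exp.reward)
        dones.append(exp.last_state is None)
        # Terminal transitions have no next state; reuse the current state as a placeholder,
        # since done_mask zeroes out its value estimate anyway.
        last_states.append(np.asarray(exp.state if exp.last_state is None else exp.last_state))
    return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8), np.array(last_states))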