Example #1

import copy
import random
from collections import deque
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-local helpers (Settings, DQN, SumTree, control, prediction, rl,
# get_reward_function, get_targets, do_dqn_control,
# evaluate_q_model_and_log_metrics) are assumed to be imported from the
# surrounding package.

def train(model_name):
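    """Train a DQN with experience replay, logging metrics to TensorBoard and checkpointing periodically."""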
    writer = SummaryWriter(log_dir=Settings.FULL_LOG_DIR)
    for key, value in Settings.export_settings().items():
        writer.add_text(key, str(value))

    if Settings.INIT_MODEL_NAME:
        dqn = DQN.load(Settings.INIT_MODEL_NAME)
    else:
        dqn = DQN(dropout=Settings.USE_DROPOUT)

    reward_function = get_reward_function()
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(dqn.parameters(), lr=Settings.LEARNING_RATE)
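    # Target network: a frozen copy of the online DQN, refreshed every TARGET_NET_FREEZE_PERIOD episodes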
    target_dqn = copy.deepcopy(dqn)
    target_dqn.eval()  # No dropout for the target network

    if Settings.USE_PRIORITIZED_ER:
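        # A sum tree lets transitions be sampled in proportion to their stored priority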
        history = SumTree(capacity=Settings.REPLAY_BUFFER_SIZE)
    else:
        history = deque(maxlen=Settings.REPLAY_BUFFER_SIZE)

    for iteration in tqdm(range(Settings.NUM_TRAINING_EPISODES)):

        # Exponentially decay the exploration rate from EPS_START toward a floor of EPS_END,
        # stepping the schedule every EPS_DECAY_RATE episodes
        epsilon = Settings.EPS_END + (
            Settings.EPS_START - Settings.EPS_END) * np.exp(
                -Settings.EPS_DECAY_COEFFICIENT *
                np.floor(iteration / Settings.EPS_DECAY_RATE))

        if iteration % Settings.TARGET_NET_FREEZE_PERIOD == 0 and iteration != 0:
            target_dqn = copy.deepcopy(dqn)
            target_dqn.eval()

        if iteration % Settings.EVALUATION_PERIOD == 0 and iteration != 0:
            evaluate_q_model_and_log_metrics(dqn, iteration, writer,
                                             reward_function)
            writer.add_scalar("epsilon", epsilon, iteration)
            dqn.checkpoint("{}_checkpoint_{}".format(model_name, iteration))

        control_function = partial(do_dqn_control, dqn=dqn, epsilon=epsilon)
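        # Roll out one episode in SUMO under the current epsilon-greedy policy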
        episode_metrics = control.run_episode(
            control_function=control_function,
            state_function=prediction.HighwayState.from_sumo,
            max_episode_length=Settings.TRAINING_EPISODE_LENGTH,
            limit_metrics=True)
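        # Convert the episode metrics into (state, action, reward, next_state) transitions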
        episode_history = rl.get_history(episode_metrics, reward_function)

        if Settings.USE_PRIORITIZED_ER:
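            # Insert new transitions with maximum priority, the standard prioritized-replay heuristic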
            for item in episode_history:
                history.add_node(item,
                                 Settings.PER_MAX_PRIORITY**Settings.PER_ALPHA)
        else:
            history.extend(episode_history)

        if iteration % 10 == 0:
            writer.add_scalar("Length", len(episode_history), iteration)

        total_loss = 0
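        # Take several gradient steps per episode on minibatches drawn from the replay buffer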
        for train_index in range(Settings.TRAINING_STEPS_PER_EPISODE):
            # Sample a minibatch of (state, action, reward, next_state) tuples from the replay buffer
            if Settings.USE_PRIORITIZED_ER:
                train_sars = []
                train_indices = []
                for k in range(min(len(history), Settings.BATCH_SIZE)):
                    position, sars = history.sample()
                    train_sars.append(sars)
                    train_indices.append(position)
            else:
                train_sars = random.choices(history,
                                            k=min(len(history),
                                                  Settings.BATCH_SIZE))
                train_indices = []

            # Calculate target = r + gamma * max_a Q(s', a)
            targets = get_targets(train_sars,
                                  dqn,
                                  target_dqn,
                                  gamma=Settings.DISCOUNT_FACTOR)
            target_tensor = dqn.get_target_tensor_bulk(targets)

            # Convert the states and actions to pytorch tensors
            state_tensor = dqn.get_q_tensor_bulk(
                [item.state for item in train_sars])
            action_tensor = dqn.get_action_tensor_bulk(
                [item.action for item in train_sars]).reshape((-1, 1))

            optimizer.zero_grad()

            # Calculate Q(s, a)
            outputs = dqn(state_tensor)
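            # gather picks out, for each row, the Q-value of the action that was actually taken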
            q_values = torch.gather(outputs, 1, action_tensor).flatten()

            # Gradient descent step
            loss = criterion(q_values, target_tensor)
            loss.backward()
            optimizer.step()

            if Settings.USE_PRIORITIZED_ER:
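                # Re-prioritize the sampled transitions: |TD error| offset by PER_MIN_PRIORITY,
                # capped at PER_MAX_PRIORITY, then raised to PER_ALPHA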
                td_errors = torch.abs(q_values - target_tensor).detach()
                for error_index, error in enumerate(td_errors):
                    priority = min(
                        error.item() + Settings.PER_MIN_PRIORITY,
                        Settings.PER_MAX_PRIORITY)**Settings.PER_ALPHA
                    history.update_weight(priority, train_indices[error_index])

            total_loss += loss.item()

        if iteration % 10 == 0:
            writer.add_scalar("Loss",
                              total_loss / Settings.TRAINING_STEPS_PER_EPISODE,
                              iteration)

    evaluate_q_model_and_log_metrics(dqn, Settings.NUM_TRAINING_EPISODES,
                                     writer, reward_function)

    dqn.save(model_name)
    writer.close()
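
The prioritized branch assumes a project-specific SumTree class exposing add_node(item, priority), sample() -> (position, item), and update_weight(priority, position). A minimal proportional-priority sketch that matches this interface (an assumption for illustration, not the project's actual implementation) could look like this:

import random

class SumTree:
    """Minimal sum-tree replay buffer sketch matching the interface used above."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity)  # internal nodes in [1, capacity), leaves in [capacity, 2 * capacity)
        self.data = [None] * capacity
        self.write = 0                      # next leaf slot to overwrite
        self.size = 0

    def __len__(self):
        return self.size

    def update_weight(self, priority, position):
        # Set a leaf's priority and propagate the change up to the root
        index = position + self.capacity
        delta = priority - self.tree[index]
        while index >= 1:
            self.tree[index] += delta
            index //= 2

    def add_node(self, item, priority):
        # Overwrite the oldest slot once the buffer is full
        self.data[self.write] = item
        self.update_weight(priority, self.write)
        self.write = (self.write + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self):
        # Walk down the tree so leaves are chosen in proportion to their priority
        target = random.uniform(0.0, self.tree[1])
        index = 1
        while index < self.capacity:
            left = 2 * index
            if target <= self.tree[left]:
                index = left
            else:
                target -= self.tree[left]
                index = left + 1
        position = index - self.capacity
        return position, self.data[position]

Because both sampling and priority updates only walk one root-to-leaf path, they run in O(log capacity), which is what keeps prioritized replay cheap even for large buffers.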