import copy
import random
from collections import deque
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Assumption: SummaryWriter is taken from torch.utils.tensorboard; the original repo may use tensorboardX instead.
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-specific names used below (Settings, DQN, SumTree, control, prediction, rl,
# get_reward_function, get_targets, do_dqn_control, evaluate_q_model_and_log_metrics)
# are assumed to be imported elsewhere in this repo.


def train(model_name):
    writer = SummaryWriter(log_dir=Settings.FULL_LOG_DIR)
    for key, value in Settings.export_settings().items():
        writer.add_text(key, str(value))

    if Settings.INIT_MODEL_NAME:
        dqn = DQN.load(Settings.INIT_MODEL_NAME)
    else:
        dqn = DQN(dropout=Settings.USE_DROPOUT)

    reward_function = get_reward_function()
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(dqn.parameters(), lr=Settings.LEARNING_RATE)

    target_dqn = copy.deepcopy(dqn)
    target_dqn.eval()  # No dropout for the target network

    if Settings.USE_PRIORITIZED_ER:
        history = SumTree(capacity=Settings.REPLAY_BUFFER_SIZE)
    else:
        history = deque(maxlen=Settings.REPLAY_BUFFER_SIZE)

    for iteration in tqdm(range(Settings.NUM_TRAINING_EPISODES)):
        # Decay the chance to make a random move to a minimum of Settings.EPS_END
        epsilon = Settings.EPS_END + (Settings.EPS_START - Settings.EPS_END) * np.exp(
            -Settings.EPS_DECAY_COEFFICIENT * np.floor(iteration / Settings.EPS_DECAY_RATE))

        # Periodically refresh the frozen target network with the current weights
        if iteration % Settings.TARGET_NET_FREEZE_PERIOD == 0 and iteration != 0:
            target_dqn = copy.deepcopy(dqn)
            target_dqn.eval()

        # Periodically evaluate, log and checkpoint the model
        if iteration % Settings.EVALUATION_PERIOD == 0 and iteration != 0:
            evaluate_q_model_and_log_metrics(dqn, iteration, writer, reward_function)
            writer.add_scalar("epsilon", epsilon, iteration)
            dqn.checkpoint("{}_checkpoint_{}".format(model_name, iteration))

        # Run one training episode with an epsilon-greedy policy and add it to the replay buffer
        control_function = partial(do_dqn_control, dqn=dqn, epsilon=epsilon)
        episode_metrics = control.run_episode(
            control_function=control_function,
            state_function=prediction.HighwayState.from_sumo,
            max_episode_length=Settings.TRAINING_EPISODE_LENGTH,
            limit_metrics=True)
        episode_history = rl.get_history(episode_metrics, reward_function)
        if Settings.USE_PRIORITIZED_ER:
            # New transitions enter the sum tree with maximal priority
            for item in episode_history:
                history.add_node(item, Settings.PER_MAX_PRIORITY ** Settings.PER_ALPHA)
        else:
            history.extend(episode_history)

        if iteration % 10 == 0:
            writer.add_scalar("Length", len(episode_history), iteration)

        total_loss = 0
        for train_index in range(Settings.TRAINING_STEPS_PER_EPISODE):
            # Sample a batch of (state, action, reward, next state) tuples from the replay buffer
            if Settings.USE_PRIORITIZED_ER:
                train_sars = []
                train_indices = []
                for k in range(min(len(history), Settings.BATCH_SIZE)):
                    position, sars = history.sample()
                    train_sars.append(sars)
                    train_indices.append(position)
            else:
                train_sars = random.choices(history, k=min(len(history), Settings.BATCH_SIZE))
                train_indices = []

            # Calculate target = r + gamma * max_a Q(s', a)
            targets = get_targets(train_sars, dqn, target_dqn, gamma=Settings.DISCOUNT_FACTOR)
            target_tensor = dqn.get_target_tensor_bulk(targets)

            # Convert the states and actions to pytorch tensors
            state_tensor = dqn.get_q_tensor_bulk([item.state for item in train_sars])
            action_tensor = dqn.get_action_tensor_bulk(
                [item.action for item in train_sars]).reshape((-1, 1))

            optimizer.zero_grad()

            # Calculate Q(s, a) for the actions that were actually taken
            outputs = dqn(state_tensor)
            q_values = torch.gather(outputs, 1, action_tensor).flatten()

            # Gradient descent step
            loss = criterion(q_values, target_tensor)
            loss.backward()
            optimizer.step()

            if Settings.USE_PRIORITIZED_ER:
                # Update the sampled transitions' priorities from their TD errors
                td_errors = torch.abs(q_values - target_tensor).detach()
                for error_index, error in enumerate(td_errors):
                    priority = min(
                        error + Settings.PER_MIN_PRIORITY,
                        Settings.PER_MAX_PRIORITY) ** Settings.PER_ALPHA
                    history.update_weight(priority, train_indices[error_index])

            # Accumulate a plain float so the computation graph is not kept alive across steps
            total_loss += loss.item()

        if iteration % 10 == 0:
            writer.add_scalar("Loss", total_loss / Settings.TRAINING_STEPS_PER_EPISODE, iteration)

    # Final evaluation, then save the trained model
    evaluate_q_model_and_log_metrics(dqn, Settings.NUM_TRAINING_EPISODES, writer, reward_function)
    dqn.save(model_name)
    writer.close()
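

# For illustration only: a minimal sketch of the Bellman target that the
# `get_targets` call in train() is expected to produce
# (target = r + gamma * max_a Q(s', a), with no bootstrap term for terminal
# transitions). This is a hypothetical stand-in, not the repo's actual
# implementation; the SARS field names (`reward`, `next_state`) and the reuse of
# `dqn.get_q_tensor_bulk` for state conversion are assumptions.
def _get_targets_sketch(train_sars, dqn, target_dqn, gamma):
    targets = []
    with torch.no_grad():
        for sars in train_sars:
            if sars.next_state is None:
                # Terminal transition: the target is just the immediate reward
                targets.append(sars.reward)
            else:
                # Bootstrap from the frozen target network's best action value
                next_q = target_dqn(dqn.get_q_tensor_bulk([sars.next_state]))
                targets.append(sars.reward + gamma * next_q.max().item())
    return targets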