    # ExperienceSourceFirstLast accumulates the reward from the first step
    # through the n-th step using the given discount factor.
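    # For instance (illustrative numbers, not taken from this snippet): with
    # steps_count=4 and gamma=0.99, each experience item carries
    # R = r_t + 0.99*r_{t+1} + 0.99**2*r_{t+2} + 0.99**3*r_{t+3}
    # together with the state observed after those four steps.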
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, steps_count=unroll_steps, gamma=params.gamma)
    buffer = ptan.experience.ExperienceReplayBuffer(
        exp_source, buffer_size=params.replay_size)

    optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine: Engine, batch):
        optimizer.zero_grad()
        # Because the experience source already accumulates the n-step reward,
        # the loss function only has to bootstrap from the last state using the
        # discount factor raised to the n-th power.
        loss_v = calc_loss_dqn(batch,
                               net,
                               target_net.target_model,
                               params.gamma**unroll_steps,
                               device=device)
        loss_v.backward()
        optimizer.step()

        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            target_net.sync()

        return {'loss': loss_v.item(), 'epsilon': selector.epsilon}

    engine = Engine(process_batch)
    setup_ignite(engine, params, exp_source, NAME)
    engine.run(
        batch_generator(buffer, params.replay_initial, params.batch_size))
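
A minimal sketch of the loss computed above, assuming the batch holds ptan
ExperienceFirstLast items with state, action, reward and last_state fields
(last_state being None at episode end); calc_loss_dqn_sketch is an illustrative
name, not the library's exact code. The only n-step-specific detail is that
gamma arrives already raised to the n-th power (params.gamma**unroll_steps).

import numpy as np
import torch
import torch.nn.functional as F

def calc_loss_dqn_sketch(batch, net, tgt_net, gamma, device="cpu"):
    # unpack the batch into tensors
    states = torch.as_tensor(np.asarray([e.state for e in batch]),
                             dtype=torch.float32, device=device)
    actions = torch.as_tensor([e.action for e in batch], device=device)
    rewards = torch.as_tensor([e.reward for e in batch],
                              dtype=torch.float32, device=device)
    done_mask = torch.as_tensor([e.last_state is None for e in batch],
                                dtype=torch.bool, device=device)
    last_states = torch.as_tensor(
        np.asarray([e.state if e.last_state is None else e.last_state
                    for e in batch]),
        dtype=torch.float32, device=device)

    # Q(s, a) for the actions actually taken
    q_sa = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # bootstrap from the state reached after the accumulated n steps
        next_q = tgt_net(last_states).max(1)[0]
        next_q[done_mask] = 0.0
        target = rewards + gamma * next_q
    return F.mse_loss(q_sa, target)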
Example #2
    selector = ptan.actions.EpsilonGreedyActionSelector(
        epsilon=params.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = ptan.agent.DQNAgent(net, selector, device=device)

    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, gamma=params.gamma)
    buffer = ptan.experience.ExperienceReplayBuffer(
        exp_source, buffer_size=params.replay_size)
    optimizer = optim.Adam(net.parameters(),
                           lr=params.learning_rate)

    def process_batch(engine, batch):
        optimizer.zero_grad()
        loss_v = common.calc_loss_dqn(
            batch, net, tgt_net.target_model,
            gamma=params.gamma, device=device)
        loss_v.backward()
        optimizer.step()
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {
            "loss": loss_v.item(),
            "epsilon": selector.epsilon,
        }

    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME)
    engine.run(common.batch_generator(buffer, params.replay_initial,
                                      params.batch_size))
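
Both examples feed the engine from common.batch_generator and decay epsilon
through common.EpsilonTracker. A plausible sketch of those two helpers is shown
below (assumed shapes, not the library's exact code); it relies only on the
ptan replay buffer's populate() and sample() methods, and it assumes params
exposes epsilon_start, epsilon_final and epsilon_frames.

def batch_generator(buffer, initial, batch_size):
    # warm up the replay buffer before the first gradient step
    buffer.populate(initial)
    while True:
        buffer.populate(1)                # one environment step per batch
        yield buffer.sample(batch_size)

class EpsilonTracker:
    def __init__(self, selector, params):
        self.selector = selector
        self.params = params

    def frame(self, frame_idx):
        # linearly anneal epsilon with the frame index,
        # never letting it fall below epsilon_final
        eps = self.params.epsilon_start - frame_idx / self.params.epsilon_frames
        self.selector.epsilon = max(self.params.epsilon_final, eps)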