Example #1
    selector = EpsilonGreedyActionSelector(epsilon=params.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = DQNAgent(net, selector, device=device)
    exp_source = ExperienceSourceFirstLast(env,
                                           agent,
                                           gamma=params.gamma,
                                           steps_count=args.n)
    buffer = ExperienceReplayBuffer(exp_source, buffer_size=params.replay_size)
    optimizer = Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine, batch):
        optimizer.zero_grad()
        # n-step DQN loss: the Bellman target bootstraps through n steps,
        # so the discount factor becomes gamma**n
        loss = common.calc_loss_dqn(batch,
                                    net,
                                    tgt_net.target_model,
                                    gamma=params.gamma**args.n,
                                    device=device)
        loss.backward()
        optimizer.step()
        # anneal epsilon according to the schedule in params
        epsilon_tracker.frame(engine.state.iteration)
        # periodically copy the online weights into the target network
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {"loss": loss.item(), "epsilon": selector.epsilon}

    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, f"{NAME}={args.n}")
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))
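Each of these examples drives the Ignite Engine from common.batch_generator, which first fills the replay buffer with replay_initial transitions and then yields training batches indefinitely. A minimal sketch of such a generator, assuming the PTAN-style buffer API (populate() pulls transitions from the experience source, sample() draws a batch):

def batch_generator(buffer, initial: int, batch_size: int):
    # warm-up: fill the buffer before any training step
    buffer.populate(initial)
    while True:
        # one fresh transition from the environment per training step
        buffer.populate(1)
        yield buffer.sample(batch_size)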
Example #2
    env = make(params.env_name)
    env = wrap_dqn(env)
    env.seed(123)
    net = dqn_extra.RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = TargetNet(net)

    # noisy networks drive exploration, so a plain greedy selector suffices
    selector = ArgmaxActionSelector()
    agent = DQNAgent(net, selector, device=device)
    exp_source = ExperienceSourceFirstLast(env, agent, gamma=params.gamma, steps_count=N_STEPS)
    buffer = dqn_extra.PrioReplayBuffer(exp_source, params.replay_size, PRIO_REPLAY_ALPHA)
    optimizer = Adam(net.parameters(), lr=params.learning_rate)


    def process_batch(engine, batch_data):
        # the prioritized buffer yields samples together with their
        # indices and importance-sampling weights
        batch, batch_indices, batch_weights = batch_data
        optimizer.zero_grad()
        loss, sample_prios = calc_loss_prio(batch, batch_weights, net, tgt_net.target_model,
                                            gamma=params.gamma ** N_STEPS)
        loss.backward()
        optimizer.step()
        # write the fresh TD-errors back as the samples' new priorities
        buffer.update_priorities(batch_indices, sample_prios)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {"loss": loss.item(), "beta": buffer.update_beta(engine.state.iteration)}


    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME)
    engine.run(common.batch_generator(buffer, params.replay_initial, params.batch_size))
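This example (and Example #4 below) trains from a prioritized buffer: each sample's TD-error is scaled by the importance-sampling weight the buffer returned, and the per-sample errors are handed back as new priorities. A minimal sketch of such a loss, assuming PTAN-style ExperienceFirstLast transitions (the Rainbow net used above would additionally need the distributional loss, omitted here):

import numpy as np
import torch

def calc_loss_prio(batch, batch_weights, net, tgt_net, gamma, device="cpu"):
    # unpack transitions into tensors
    states = torch.as_tensor(np.asarray([e.state for e in batch])).to(device)
    actions = torch.as_tensor(np.asarray([e.action for e in batch])).to(device)
    rewards = torch.as_tensor(np.asarray([e.reward for e in batch], dtype=np.float32)).to(device)
    done_mask = torch.as_tensor(np.asarray([e.last_state is None for e in batch])).to(device)
    next_states = torch.as_tensor(np.asarray(
        [e.state if e.last_state is None else e.last_state for e in batch])).to(device)
    weights = torch.as_tensor(np.asarray(batch_weights, dtype=np.float32)).to(device)

    q_sa = net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        # bootstrap from the target net; terminal transitions contribute nothing
        next_q = tgt_net(next_states).max(1)[0]
        next_q[done_mask] = 0.0
        target = rewards + gamma * next_q
    # importance-weighted per-sample squared error
    losses = weights * (q_sa - target) ** 2
    # the small constant keeps every priority strictly positive
    return losses.mean(), (losses + 1e-5).detach().cpu().numpy()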
Example #3
    selector = EpsilonGreedyActionSelector(epsilon=params.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = DQNAgent(net, selector, device=device)
    exp_source = ExperienceSourceFirstLast(env, agent, gamma=params.gamma)
    buffer = ExperienceReplayBuffer(exp_source, buffer_size=params.replay_size)
    optimizer = Adam(net.parameters(), lr=params.learning_rate)


    def process_batch(engine, batch):
        optimizer.zero_grad()
        loss = common.calc_loss_dqn(batch, net, tgt_net.target_model, gamma=params.gamma, device=device)
        loss.backward()
        optimizer.step()
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        if engine.state.iteration % EVAL_EVER_FRAME == 0:
            # sample a fixed evaluation set once and cache it on the engine
            # state, so the metric stays comparable across iterations
            eval_states = getattr(engine.state, "eval_states", None)
            if eval_states is None:
                eval_states = buffer.sample(STATES_TO_EVALUATE)
                eval_states = [np.array(transition.state, copy=False) for transition in eval_states]
                eval_states = np.array(eval_states, copy=False)
                engine.state.eval_states = eval_states
            evaluate_states(eval_states, net, device, engine)
        return {"loss": loss.item(), "epsilon": selector.epsilon}


    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME, extra_metrics=("adv", "val"))
    engine.run(common.batch_generator(buffer, params.replay_initial, params.batch_size))
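The extra "adv" and "val" metrics suggest a dueling network whose advantage and value streams are probed on the fixed set of held-out states sampled above. A sketch of what evaluate_states might look like, with adv_val() as a hypothetical accessor for the two streams:

import torch

@torch.no_grad()
def evaluate_states(states, net, device, engine):
    states_v = torch.as_tensor(states).to(device)
    adv, val = net.adv_val(states_v)  # hypothetical: returns the two dueling streams
    engine.state.metrics["adv"] = adv.mean().item()
    engine.state.metrics["val"] = val.mean().item()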
Example #4
    exp_source = ExperienceSourceFirstLast(env, agent, gamma=params.gamma)
    buffer = PrioReplayBuffer(exp_source,
                              buffer_size=params.replay_size,
                              prob_alpha=PRIO_REPLAY_ALPHA)
    optimizer = Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine, batch_data):
        batch, batch_indices, batch_weights = batch_data
        optimizer.zero_grad()
        loss, sample_prios = calc_loss(batch,
                                       batch_weights,
                                       net,
                                       tgt_net.target_model,
                                       gamma=params.gamma)
        loss.backward()
        optimizer.step()
        buffer.update_priorities(batch_indices, sample_prios)
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {
            "loss": loss.item(),
            "epsilon": selector.epsilon,
            "beta": buffer.update_beta(engine.state.iteration)
        }

    engine = Engine(process_batch)
    setup_ignite(engine, params, exp_source, NAME)
    engine.run(
        batch_generator(buffer, params.replay_initial, params.batch_size))
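buffer.update_beta(...) anneals the importance-sampling exponent toward 1.0 over training, so the bias correction becomes exact as learning converges. A sketch of such a schedule as a buffer method, with BETA_START and BETA_FRAMES as assumed constants:

BETA_START = 0.4        # assumed initial exponent
BETA_FRAMES = 100_000   # assumed number of frames until beta reaches 1.0

def update_beta(self, idx: int) -> float:
    # linear annealing from BETA_START to 1.0
    self.beta = min(1.0, BETA_START + idx * (1.0 - BETA_START) / BETA_FRAMES)
    return self.beta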
Example #5
    exp_source = ExperienceSourceFirstLast(env, agent, gamma=params.gamma)
    buffer = ExperienceReplayBuffer(exp_source, buffer_size=params.replay_size)
    optimizer = Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine, batch):
        optimizer.zero_grad()
        loss = common.calc_loss_dqn(batch,
                                    net,
                                    tgt_net.target_model,
                                    gamma=params.gamma,
                                    device=device)
        loss.backward()
        optimizer.step()
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        if engine.state.iteration % NOISY_SNR_EVERY_ITERS == 0:
            # log each noisy layer's signal-to-noise ratio as a metric
            for layer_idx, sigma_l2 in enumerate(net.noisy_layers_sigma_snr()):
                engine.state.metrics[f"snr_{layer_idx + 1}"] = sigma_l2
        return {"loss": loss.item()}

    engine = Engine(process_batch)
    common.setup_ignite(engine,
                        params,
                        exp_source,
                        NAME,
                        extra_metrics=("snr_1", "snr_2"))
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))
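The snr_1/snr_2 metrics track the signal-to-noise ratio of the network's noisy layers: the RMS magnitude of the learned weights relative to the RMS of their noise scales, so a growing SNR means the exploration noise is dying out. A sketch of such a method on the network, assuming each noisy layer stores its noise scale in a sigma_weight parameter:

def noisy_layers_sigma_snr(self):
    # RMS(weight) / RMS(sigma) per noisy layer
    return [
        ((layer.weight ** 2).mean().sqrt() /
         (layer.sigma_weight ** 2).mean().sqrt()).item()
        for layer in self.noisy_layers
    ]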
Example #6
    def process_batch(engine, batch):
        optimizer.zero_grad()
        # the source snippet is truncated above this point; the opening of
        # process_batch is reconstructed to match the examples above (given
        # the args.double flag below, the original may call a double-DQN
        # loss variant instead of calc_loss_dqn)
        loss = common.calc_loss_dqn(batch,
                                    net,
                                    tgt_net.target_model,
                                    gamma=params.gamma,
                                    device=device)
        loss.backward()
        optimizer.step()
        epsilon_tracker.frame(engine.state.iteration)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        if engine.state.iteration % EVAL_EVER_FRAME == 0:
            eval_states = getattr(engine.state, "eval_states", None)
            if eval_states is None:
                eval_states = buffer.sample(STATES_TO_EVALUATE)
                eval_states = [
                    np.array(transition.state, copy=False)
                    for transition in eval_states
                ]
                eval_states = np.array(eval_states, copy=False)
                engine.state.eval_states = eval_states
            engine.state.metrics["values"] = common.calc_values_of_states(
                eval_states, net, device)
        return {"loss": loss.item(), "epsilon": selector.epsilon}

    engine = Engine(process_batch)
    common.setup_ignite(engine,
                        params,
                        exp_source,
                        f"{NAME}={args.double}",
                        extra_metrics=("values", ))
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))
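calc_values_of_states logs the mean best Q-value over the fixed evaluation states, a common proxy for spotting the overestimation that double DQN is meant to curb. A minimal sketch of such a helper:

import numpy as np
import torch

@torch.no_grad()
def calc_values_of_states(states, net, device="cpu"):
    mean_vals = []
    # evaluate in chunks to keep memory bounded
    for chunk in np.array_split(states, 64):
        states_v = torch.as_tensor(chunk).to(device)
        best_q = net(states_v).max(1)[0]
        mean_vals.append(best_q.mean().item())
    return float(np.mean(mean_vals))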