import time

import rela
import create  # assumed repo-local module providing create_envs / create_threads


def create_dataset(agent, sad, device):
    # use it in "vdn" mode so that trajecoties from the same game are
    # grouped together
    agent = agent.clone(device, {"vdn": True})
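    # the runner batches "act" / "compute_priority" calls from all actor threads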
    runner = rela.BatchRunner(agent, device, 100, ["act", "compute_priority"])

    dataset_size = 1000
    replay_buffer = rela.RNNPrioritizedReplay(
        dataset_size,
        1,  # seed
        0,  # priority_exponent: 0 => uniform sampling
        1,  # priority_weight
        0,  # prefetch
    )

    num_thread = 100
    num_game_per_thread = 1
    max_len = 80
    actors = []
    for _ in range(num_thread):
        actor = rela.R2D2Actor(
            runner,
            1,  # multi_step
            num_game_per_thread,
            0.99,  # gamma
            0.9,  # eta
            max_len,
            2,  # num_player
            replay_buffer,
        )
        actors.append(actor)

    eps = [0]  # one exploration epsilon per game; 0 => greedy
    num_game = num_thread * num_game_per_thread
    games = create.create_envs(
        num_game, 1, 2, 5, 0, eps, max_len, sad, False, False
    )
    context, threads = create.create_threads(
        num_thread, num_game_per_thread, actors, games
    )

    runner.start()
    context.start()
    while replay_buffer.size() < dataset_size:
        print("collecting data from replay buffer:", replay_buffer.size())
        time.sleep(0.2)

    context.pause()  # stop the game threads; the buffer stops growing

    # remove extra data: a couple of sample/update cycles let the buffer
    # settle, evicting entries pushed beyond dataset_size while threads ran
    for _ in range(2):
        _, unif = replay_buffer.sample(10, "cpu")
        replay_buffer.update_priority(unif.detach().cpu())
        time.sleep(0.2)

    print("dataset size:", replay_buffer.size())
    print("done about to return")
    return replay_buffer, agent, context
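
A minimal usage sketch, assuming `agent` has already been created and loaded as
in Example 2 below (the device string and batch size are illustrative, not part
of the example):

replay_buffer, clone, context = create_dataset(agent, sad=True, device="cuda:0")
batch, weight = replay_buffer.sample(32, "cpu")  # uniform draw (priority_exponent=0)
replay_buffer.update_priority(weight.detach().cpu())  # finish the sample transaction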
Example 2
    if args.load_model:
        print("*****loading pretrained model*****")
        utils.load_weight(agent.online_net, args.load_model, args.train_device)
        print("*****done*****")

    agent = agent.to(args.train_device)
    optim = torch.optim.Adam(agent.online_net.parameters(),
                             lr=args.lr,
                             eps=args.eps)
    print(agent)
    eval_agent = agent.clone(args.train_device, {"vdn": False})

    replay_buffer = rela.RNNPrioritizedReplay(
        args.replay_buffer_size,
        args.seed,
        args.priority_exponent,
        args.priority_weight,
        args.prefetch,
    )

    act_group = ActGroup(
        args.method,
        args.act_device,
        agent,
        args.num_thread,
        args.num_game_per_thread,
        args.multi_step,
        args.gamma,
        args.eta,
        args.max_len,
        args.num_player,