def create_dataset(
    agent,
    sad,
    device,
    *,
    dataset_size=1000,
    num_thread=100,
    num_game_per_thread=1,
    max_len=80,
):
    """Fill a replay buffer with game trajectories generated by ``agent``.

    A clone of ``agent`` (in "vdn" mode) is placed on ``device`` and driven by
    ``num_thread`` actor threads until the replay buffer holds at least
    ``dataset_size`` entries. Game threads are then paused and the buffer is
    returned for use as a fixed dataset.

    Args:
        agent: Agent to clone; must provide ``clone(device, overrides)``.
        sad: Forwarded unchanged to ``create.create_envs``.
        device: Device the cloned agent and batch runner run on.
        dataset_size: Replay-buffer capacity and the size to collect up to.
        num_thread: Number of actor threads.
        num_game_per_thread: Games run per actor thread.
        max_len: Maximum trajectory length, passed to actors and envs.

    Returns:
        ``(replay_buffer, agent, context)``. ``agent`` is the vdn-mode clone,
        not the object passed in, and ``context`` has been paused (not
        stopped) — the caller owns its lifecycle.
    """
    # Use it in "vdn" mode so that trajectories from the same game are
    # grouped together.
    agent = agent.clone(device, {"vdn": True})
    runner = rela.BatchRunner(agent, device, 100, ["act", "compute_priority"])

    replay_buffer = rela.RNNPrioritizedReplay(
        dataset_size,  # capacity
        1,  # seed
        0,  # priority exponent: 0 => uniform sampling
        1,  # priority weight
        0,  # prefetch
    )

    actors = [
        rela.R2D2Actor(
            runner,
            1,  # multi_step
            num_game_per_thread,
            0.99,  # gamma
            0.9,  # eta
            max_len,
            2,  # num_player
            replay_buffer,
        )
        for _ in range(num_thread)
    ]

    num_game = num_thread * num_game_per_thread
    # NOTE(review): create_envs takes opaque positional args; ``[0]`` appears
    # to be the per-game exploration-epsilon list — confirm against its
    # definition.
    games = create.create_envs(
        num_game, 1, 2, 5, 0, [0], max_len, sad, False, False
    )
    # ``_threads`` is unused, but it stays bound for the whole call so the
    # thread objects remain referenced while data is collected.
    context, _threads = create.create_threads(
        num_thread, num_game_per_thread, actors, games
    )

    runner.start()
    context.start()
    while replay_buffer.size() < dataset_size:
        print("collecting data from replay buffer:", replay_buffer.size())
        time.sleep(0.2)

    context.pause()

    # Remove extra data.
    # NOTE(review): mechanically this samples 10 items and writes priorities
    # back, twice; confirm why this is required (possibly to flush pending
    # sample/priority state in the buffer).
    for _ in range(2):
        _, unif = replay_buffer.sample(10, "cpu")
        replay_buffer.update_priority(unif.detach().cpu())
        time.sleep(0.2)

    print("dataset size:", replay_buffer.size())
    print("done about to return")
    return replay_buffer, agent, context
if args.load_model: print("*****loading pretrained model*****") utils.load_weight(agent.online_net, args.load_model, args.train_device) print("*****done*****") agent = agent.to(args.train_device) optim = torch.optim.Adam(agent.online_net.parameters(), lr=args.lr, eps=args.eps) print(agent) eval_agent = agent.clone(args.train_device, {"vdn": False}) replay_buffer = rela.RNNPrioritizedReplay( args.replay_buffer_size, args.seed, args.priority_exponent, args.priority_weight, args.prefetch, ) act_group = ActGroup( args.method, args.act_device, agent, args.num_thread, args.num_game_per_thread, args.multi_step, args.gamma, args.eta, args.max_len, args.num_player,