def train_shared_model(agent_params, env_params, training_params, id, actor, critic):
    actor.train()
    critic.train()
    pid = os.getpid()
    print(f"Instantiating process PID {pid}")
    env = Poker(env_params)
    nS = env.state_space
    nO = env.observation_space
    nA = env.action_space
    nB = env.betsize_space
    nC = nA - 2 + nB
    print_every = (training_params['epochs'] + 1) // 5
    seed = 154
    agent = ParallelAgent(nS, nO, nA, nB, seed, agent_params, actor, critic)
    training_data = copy.deepcopy(training_params['training_data'])
    for e in range(1, training_params['epochs'] + 1):
        last_state, state, obs, done, mask, betsize_mask = env.reset()
        while not done:
            # Historical Kuhn conditions the policy on the full state; other games use last_state
            if env.game == pdt.GameTypes.HISTORICALKUHN:
                actor_outputs = agent(state, mask, betsize_mask) if env.rules.betsize == True else agent(state, mask)
            else:
                actor_outputs = agent(last_state, mask, betsize_mask) if env.rules.betsize == True else agent(last_state, mask)
            last_state, state, obs, done, mask, betsize_mask = env.step(actor_outputs)
        ml_inputs = env.ml_inputs()
        agent.learn(ml_inputs)
        ml_inputs = detach_ml(ml_inputs)
        for position in ml_inputs.keys():
            training_data[position].append(ml_inputs[position])
        if id == 0 and e % print_every == 0:
            print(f'PID {pid}, Epoch {e}')
    mongo = MongoDB()
    mongo.clean_db()
    mongo.store_data(training_data, env.db_mapping, training_params['training_round'], env.game, id, training_params['epochs'])
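
# Note (assumption based on the share_memory()/mp.Process pattern below): actor and
# critic are built once in the parent process and placed in shared memory, so each
# worker's agent.learn() call updates the same underlying parameter tensors in place,
# Hogwild-style, while every worker runs its own Poker environment and Mongo writes.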
# return_dict = manager.dict()
# online training
seed = 123
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
actor = env_networks['actor'](seed, nS, nA, nB, agent_params)
critic = env_networks['critic'][args.critic](seed, nS, nA, nB, agent_params)
del env
# Move the networks into shared memory so every worker process updates the same parameters
actor.share_memory()  # .to(device)
critic.share_memory()  # .to(device)
processes = []
num_processes = mp.cpu_count()
if args.clean:
    print('Cleaning db')
    mongo = MongoDB()
    mongo.clean_db()
    del mongo
for i in range(num_processes):  # No. of processes
    p = mp.Process(target=train_shared_model,
                   args=(agent_params, env_params, training_params, i, actor, critic))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
# Save checkpoints relative to the script's directory; dirname avoids joining onto the script file itself
basepath = os.path.dirname(os.path.abspath(sys.argv[0]))
torch.save(actor.state_dict(),
           os.path.join(basepath, 'checkpoints/Historical_kuhn' + '_actor'))
torch.save(critic.state_dict(),
           os.path.join(basepath, 'checkpoints/Historical_kuhn' + '_critic'))
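
# Usage sketch (assumption, not part of the original script): the saved state dicts
# can be reloaded into freshly constructed networks for evaluation, reusing the
# env_networks constructors and hyperparameters from above.
# eval_actor = env_networks['actor'](seed, nS, nA, nB, agent_params)
# eval_actor.load_state_dict(
#     torch.load(os.path.join(basepath, 'checkpoints/Historical_kuhn' + '_actor')))
# eval_actor.eval()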