def run_distributed():
    """Spawn one worker process per configured rank and supervise them.

    Blocks until at least one worker terminates, then force-kills and
    reaps every worker before returning.
    """
    worker_count = config.NB_PROCESSES
    workers = []
    for rank in range(worker_count):
        proc = Process(target=init_process, args=(rank, worker_count, train_model))
        proc.start()
        workers.append(proc)
    # Poll every 5 seconds; leave the loop as soon as any worker is no longer alive.
    while all(proc.is_alive() for proc in workers):
        time.sleep(5)
    # One worker finished (or died) -- tear down whatever is still running.
    for proc in workers:
        proc.kill()
        proc.join()
    logging.info("Main process exit")
def train():
    """Stage-3 training driver: shared actor-critic model updated by worker
    processes, with TensorBoard logging, periodic checkpointing, and a
    separate validation process.

    Runs forever (the supervisor loop below has no break); the trailing
    join/kill code is unreachable under normal operation.
    """
    # Seed both RNGs for reproducibility of this (main) process.
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    writer = SummaryWriter()
    # Shared actor-critic network; latent_num / cnn_chanel_num / stat_dim are
    # module-level hyperparameters defined elsewhere in this file.
    ac = AC(latent_num, cnn_chanel_num, stat_dim)
    # Log the model graph with dummy inputs: a 1x1xHxW image tensor and a
    # 1 x stat_dim stats vector (assumed input signature of AC.forward --
    # TODO confirm against the AC definition).
    writer.add_graph(ac, (torch.zeros([1, 1, img_shape[0], img_shape[1]]),
                          torch.zeros([1, stat_dim])))
    # Per-parameter-group learning rates: the encoder/policy/actor groups get a
    # reduced lr (2.5e-5), while f and V fall back to the default lr=5e-3.
    optim = GlobalAdam([{'params': ac.encode_img.parameters(), 'lr': 2.5e-5},
                        {'params': ac.encode_stat.parameters(), 'lr': 2.5e-5},
                        {'params': ac.pi.parameters(), 'lr': 2.5e-5},
                        {'params': ac.actor.parameters(), 'lr': 2.5e-5},
                        {'params': ac.f.parameters()},
                        {'params': ac.V.parameters()}],
                       lr=5e-3, weight_decay=weight_decay)
    # Resume from a stage-3 checkpoint when present; otherwise warm-start from
    # the stage-2 weights (strict=False: stage-2 lacks some stage-3 modules).
    if os.path.exists('S3_state_dict.pt'):
        ac.load_state_dict(torch.load('S3_state_dict.pt'))
        optim.load_state_dict(torch.load('S3_Optim_state_dict.pt'))
    else:
        ac.load_state_dict(torch.load('../stage2/S2_state_dict.pt'), strict=False)
    # Inter-process channels: episode results, validation results, gradients
    # pushed by workers, and per-update losses (paired 1:1 with results --
    # NOTE(review): loss_queue.get() below blocks if a worker ever reports a
    # result without a matching loss entry).
    result_queue = Queue()
    validate_queue = Queue()
    gradient_queue = Queue()
    loss_queue = Queue()
    ep_cnt = Value('i', 0)  # shared episode counter
    optimizer_lock = Lock()  # serializes optimizer steps vs. checkpointing
    processes = []
    # Move model parameters to shared memory so all child processes see updates.
    ac.share_memory()
    # Dedicated process that applies gradients from gradient_queue to ac.
    optimizer_worker = Process(target=update_shared_model,
                               args=(gradient_queue, optimizer_lock, optim, ac))
    optimizer_worker.start()
    # Leave 3 cores free: one each for main / optimizer / validation, presumably.
    for no in range(mp.cpu_count() - 3):
        worker = Worker(no, ac, ep_cnt, optimizer_lock, result_queue,
                        gradient_queue, loss_queue)
        worker.start()
        processes.append(worker)
    validater = Validate(ac, ep_cnt, optimizer_lock, validate_queue)
    validater.start()
    # NOTE(review): best_reward starts at 0, so a best checkpoint is only ever
    # written once some reward exceeds 0.
    best_reward = 0
    while True:
        # Hold the counter lock while consuming a result so ep_cnt increments
        # stay consistent with the scalars logged against ep_cnt.value.
        with ep_cnt.get_lock():
            if not result_queue.empty():
                ep_cnt.value += 1
                reward, money, win_rate = result_queue.get()
                objective_actor, loss_critic, loss_f = loss_queue.get()
                writer.add_scalar('Interaction/Reward', reward, ep_cnt.value)
                writer.add_scalar('Interaction/Money', money, ep_cnt.value)
                writer.add_scalar('Interaction/win_rate', win_rate, ep_cnt.value)
                writer.add_scalar('Update/objective_actor', objective_actor, ep_cnt.value)
                writer.add_scalar('Update/loss_critic', loss_critic, ep_cnt.value)
                writer.add_scalar('Update/loss_f', loss_f, ep_cnt.value)
                # Checkpoint under the optimizer lock so weights are not being
                # updated mid-save.
                with optimizer_lock:
                    if reward > best_reward:
                        best_reward = reward
                        torch.save(ac.state_dict(), 'S3_BEST_state_dict.pt')
                    if ep_cnt.value % save_every == 0:
                        torch.save(ac.state_dict(), 'S3_state_dict.pt')
                        torch.save(optim.state_dict(), 'S3_Optim_state_dict.pt')
        # Drain validation results independently of training results.
        if not validate_queue.empty():
            val_reward, val_money, val_win_rate = validate_queue.get()
            writer.add_scalar('Validation/reward', val_reward, ep_cnt.value)
            writer.add_scalar('Validation/money', val_money, ep_cnt.value)
            writer.add_scalar('Validation/win_rate', val_win_rate, ep_cnt.value)
    # Unreachable: the supervisor loop above never breaks.
    for worker in processes:
        worker.join()
    optimizer_worker.kill()