def client_loop(
    model_manager: polygames.ModelManager,
    start_time: float,
    context: tube.Context,
    execution_params: ExecutionParams,
) -> None:
    model_manager.start()
    max_time = execution_params.max_time
    while max_time is None or time.time() < start_time + max_time:
        time.sleep(60)
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
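# Minimal sketch (hypothetical values, not part of the original module): the
# wall-clock budget guard shared by client_loop above and train_model below.
# With max_time=None the loop runs until interrupted; otherwise it stops once
# start_time + max_time seconds have elapsed.
def _example_time_budget(max_time=1.0):
    start_time = time.time()
    ticks = 0
    while max_time is None or time.time() < start_time + max_time:
        ticks += 1
        time.sleep(0.25)  # stand-in for one unit of work
    return ticks  # roughly max_time / 0.25 iterations when max_time is set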
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]
    predicts = (2 if game_params.predict_end_state else 0) + game_params.predict_n_states

    # per-sample tensor shapes for each replay-buffer field
    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime],
    }

    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]

    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_state_mask"] = [1]

    # recurrent models train on sequences: prepend the sequence dimension
    if execution_params.rnn_seqlen > 0:
        for k, v in batchsizes.items():
            batchsizes[k] = [execution_params.rnn_seqlen, *v]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    rank = 0
    if ddpmodel:
        rank = torch.distributed.get_rank()

    # checkpoints are written asynchronously on a single worker thread
    executor = ThreadPoolExecutor(max_workers=1)
    savefuture = None

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        # only rank 0 publishes tournament models and writes checkpoints
        if rank == 0 and epoch % execution_params.saving_period == 0:
            model_manager.add_tournament_model("e%d" % (epoch), model.state_dict())
            savestart = time.time()
            # wait for the previous async save before starting a new one
            if savefuture is not None:
                savefuture.result()
            savefuture = utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
                executor=executor,
            )
            print("checkpoint saved in %gs" % (time.time() - savestart))
        _train_epoch(
            model=model,
            device=device,
            ddpmodel=ddpmodel,
            batchsizes=batchsizes,
            optim=optim,
            model_manager=model_manager,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    if savefuture is not None:
        savefuture.result()
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
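# Illustrative sketch (hypothetical board and action sizes, not from the
# original module): demonstrates how the per-key batch shapes in train_model
# are assembled, including the rnn_seqlen prefix applied for recurrent models.
def _example_batchsizes():
    c, h, w = 3, 9, 9                     # assumed state-feature planes
    c_prime, h_prime, w_prime = 1, 9, 9   # assumed action planes
    rnn_seqlen = 4                        # assumed truncated-sequence length
    shapes = {
        "s": [c, h, w],
        "pi": [c_prime, h_prime, w_prime],
    }
    # a positive rnn_seqlen prepends the sequence dimension to every entry,
    # mirroring the loop in train_model above
    if rnn_seqlen > 0:
        shapes = {k: [rnn_seqlen, *v] for k, v in shapes.items()}
    return shapes  # {"s": [4, 3, 9, 9], "pi": [4, 1, 9, 9]}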
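# Hedged wiring sketch (an assumption about the launcher, not code from the
# original module): shows how a caller might dispatch between the two entry
# points above. All dependency objects are expected to be constructed by the
# surrounding launcher; `run_loop` and `train_kwargs` are names introduced
# here for illustration only.
def run_loop(
    is_client: bool,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    execution_params: ExecutionParams,
    train_kwargs: dict,
) -> None:
    start_time = time.time()
    if is_client:
        # inference-only worker: serve models and periodically log stats
        client_loop(
            model_manager=model_manager,
            start_time=start_time,
            context=context,
            execution_params=execution_params,
        )
    else:
        # trainer process: train_kwargs must supply the remaining arguments of
        # train_model (command_history, model, device, ddpmodel, optim,
        # get_train_reward, and the *_params objects)
        train_model(
            start_time=start_time,
            context=context,
            model_manager=model_manager,
            execution_params=execution_params,
            **train_kwargs,
        )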