Example #1
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
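    # Throttle: if the previous epoch sampled from the replay buffer much faster
    # than actors added new games, sleep before training again (see below).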
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"],
                          batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:  # Sampling outpaces adding by more than 8x: throttle the next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    stat.reset()
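A minimal, self-contained sketch of the throttling rule used above (the function name and the numbers are illustrative; in the real code the add/sample counters come from tube.ChannelAssembler):

_waiting_time = 0.0

def update_waiting_time(delta_add: int, delta_sample: int, time_elapsed: float) -> None:
    global _waiting_time
    if delta_sample > 8 * delta_add:
        # Training consumes samples more than 8x faster than actors add new
        # ones: extend the sleep at the start of the next epoch.
        _waiting_time += time_elapsed
    else:
        # Actors keep up: drop the delay and train at full speed.
        _waiting_time = 0

update_waiting_time(delta_add=1000, delta_sample=20000, time_elapsed=30.0)
print(_waiting_time)  # 30.0 -> the next epoch starts with a 30 s sleep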
Example #2
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    executor: Optional[ThreadPoolExecutor] = None,
) -> Optional[Future]:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": {
            k: v.cpu().clone()
            if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in model.state_dict().items()
        },
        "optim_state_dict": {
            k: v.cpu().clone()
            if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in optim.state_dict().items()
        },
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }

    def saveit():
        nonlocal save_uncompressed
        nonlocal checkpoint
        nonlocal checkpoint_dir
        if save_uncompressed:
            torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
        else:
            # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
            #    with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
            #        torch.save(checkpoint, f)
            with gzip.open(checkpoint_dir / f"{checkpoint_name}.pt.gz",
                           "wb") as f:
                torch.save(checkpoint, f)

    if executor is not None:
        return executor.submit(saveit)
    else:
        saveit()
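A minimal, runnable sketch of the background-save hand-off this function enables (fake_save and the loop are placeholders; Example #4 below drives the real save_checkpoint with the same executor/future pattern):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

def fake_save(tag: str) -> str:
    return f"saved {tag}"

executor = ThreadPoolExecutor(max_workers=1)
savefuture: Optional[Future] = None
for epoch in (1, 2, 3):
    if savefuture is not None:
        savefuture.result()  # wait for the previous checkpoint to finish
    savefuture = executor.submit(fake_save, f"checkpoint_{epoch}")
print(savefuture.result())  # "saved checkpoint_3"
executor.shutdown()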
Example #3
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    do_not_save_replay_buffer = execution_params.do_not_save_replay_buffer
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optim.state_dict(),
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }
    if not do_not_save_replay_buffer:
        checkpoint.update({"replay_buffer": assembler.buffer})

    if save_uncompressed:
        torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
    else:
        # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
        #    with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
        #        torch.save(checkpoint, f)
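        # Write to a temporary file and rename it afterwards, so a concurrent
        # reader never sees a partially written checkpoint.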
        with gzip.open(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
                       "wb") as f:
            torch.save(checkpoint, f)
        os.rename(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
                  checkpoint_dir / f"{checkpoint_name}.pt.gz")
Example #4
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:

    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]

    predicts = (2 if game_params.predict_end_state else
                0) + game_params.predict_n_states

    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime]
    }

    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]

    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_state_mask"] = [1]

    if execution_params.rnn_seqlen > 0:
        for k, v in batchsizes.items():
            batchsizes[k] = [execution_params.rnn_seqlen, *v]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    rank = 0
    if ddpmodel:
        rank = torch.distributed.get_rank()

    executor = ThreadPoolExecutor(max_workers=1)
    savefuture = None

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if rank == 0 and epoch % execution_params.saving_period == 0:
            model_manager.add_tournament_model("e%d" % (epoch),
                                               model.state_dict())
            savestart = time.time()
            if savefuture is not None:
                savefuture.result()
            savefuture = utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
                executor=executor)
            print("checkpoint saved in %gs" % (time.time() - savestart))
        _train_epoch(
            model=model,
            device=device,
            ddpmodel=ddpmodel,
            batchsizes=batchsizes,
            optim=optim,
            model_manager=model_manager,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" %
            (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    if savefuture is not None:
        savefuture.result()
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
Example #5
def _train_epoch(
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel: ModelWrapperForDDP,
    batchsizes,
    optim: torch.optim.Optimizer,
    model_manager: polygames.ModelManager,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _pre_num_add
    global _pre_num_sample
    global _running_add_rate
    global _running_sample_rate
    global _last_train_time
    global _remote_replay_buffer_inited
    if _pre_num_add is None:
        pre_num_add = model_manager.buffer_num_add()
        pre_num_sample = model_manager.buffer_num_sample()
    else:
        pre_num_add = _pre_num_add
        pre_num_sample = _pre_num_sample
    sync_s = 0.
    num_sync = 0

    train_start_time = time.time()

    if pre_num_sample > 0:
        print("sample/add ratio ", float(pre_num_sample) / pre_num_add)

    if _last_train_time == 0:
        _last_train_time = time.time()

    batchsize = optim_params.batchsize

    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model

    lossmodel.train()

    world_size = 0
    rank = 0
    if ddpmodel is not None:
        print("DDP is active")
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("World size %d, rank %d. Waiting for all processes" %
              (world_size, rank))
        torch.distributed.barrier()
        print("Synchronizing model")
        for p in ddpmodel.parameters():
            torch.distributed.broadcast(p.data, 0)
        for p in ddpmodel.buffers():
            torch.distributed.broadcast(p.data, 0)
        print("Synchronized, start training")

    has_predict = False
    cpubatch = {}
    for k, v in batchsizes.items():
        sizes = v.copy()
        sizes.insert(0, batchsize)
        cpubatch[k] = torch.empty(sizes)
        if k == "predict_pi":
            has_predict = True

    for eid in range(optim_params.epoch_len):
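        # Stall training while sampling outpaces the actors' add rate by more
        # than 25%, re-estimating both running rates every 5 seconds.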
        while _running_add_rate * 1.25 < _running_sample_rate:
            print("add rate insufficient, waiting")
            time.sleep(5)
            t = time.time()
            time_elapsed = t - _last_train_time
            _last_train_time = t
            alpha = pow(0.99, time_elapsed)
            post_num_add = model_manager.buffer_num_add()
            post_num_sample = model_manager.buffer_num_sample()
            delta_add = post_num_add - pre_num_add
            delta_sample = post_num_sample - pre_num_sample
            _running_add_rate = _running_add_rate * alpha + (
                delta_add / time_elapsed) * (1 - alpha)
            _running_sample_rate = _running_sample_rate * alpha + (
                delta_sample / time_elapsed) * (1 - alpha)
            pre_num_add = post_num_add
            pre_num_sample = post_num_sample
            print("running add rate: %.2f / s" % (_running_add_rate))
            print("running sample rate: %.2f / s" % (_running_sample_rate))
            print("current add rate: %.2f / s" % (delta_add / time_elapsed))
            print("current sample rate: %.2f / s" %
                  (delta_sample / time_elapsed))

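        # With DDP, rank 0 samples one batch per process and scatters them; every
        # rank receives its slice into the preallocated cpubatch tensors.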
        if world_size > 0:
            batchlist = None
            if rank == 0:
                batchlist = {}
                for k in cpubatch.keys():
                    batchlist[k] = []
                for i in range(world_size):
                    for k, v in model_manager.sample(batchsize).items():
                        batchlist[k].append(v)
            for k, v in cpubatch.items():
                torch.distributed.scatter(v,
                                          batchlist[k] if rank == 0 else None)
            batch = utils.to_device(cpubatch, device)
        else:
            batch = model_manager.sample(batchsize)
            batch = utils.to_device(batch, device)
        for k, v in batch.items():
            batch[k] = v.detach()
        loss, v_err, pi_err, predict_err = model_loss.mcts_loss(
            model, lossmodel, batch)
        loss.backward()

        grad_norm = nn.utils.clip_grad_norm_(lossmodel.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        stat["v_err"].feed(v_err.item())
        stat["pi_err"].feed(pi_err.item())
        if has_predict:
            stat["predict_err"].feed(predict_err.item())
        stat["loss"].feed(loss.item())
        stat["grad_norm"].feed(grad_norm)

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            model_manager.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        t = time.time()
        time_elapsed = t - _last_train_time
        _last_train_time = t
        alpha = pow(0.99, time_elapsed)
        post_num_add = model_manager.buffer_num_add()
        post_num_sample = model_manager.buffer_num_sample()
        delta_add = post_num_add - pre_num_add
        delta_sample = post_num_sample - pre_num_sample
        _running_add_rate = _running_add_rate * alpha + (
            delta_add / time_elapsed) * (1 - alpha)
        _running_sample_rate = _running_sample_rate * alpha + (
            delta_sample / time_elapsed) * (1 - alpha)
        pre_num_add = post_num_add
        pre_num_sample = post_num_sample

    total_time_elapsed = time.time() - train_start_time

    print("running add rate: %.2f / s" % (_running_add_rate))
    print("running sample rate: %.2f / s" % (_running_sample_rate))
    print("current add rate: %.2f / s" % (delta_add / time_elapsed))
    print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:2f}s for {num_sync} syncs ({int(100 * sync_s / total_time_elapsed)}% of train time)"
    )

    _pre_num_add = pre_num_add
    _pre_num_sample = pre_num_sample

    stat.summary(epoch)
    stat.reset()
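A self-contained sketch of the time-aware exponential moving average used for the running add/sample rates above (update_running_rate and the sample numbers are illustrative):

def update_running_rate(running_rate: float, delta: int, time_elapsed: float) -> float:
    alpha = pow(0.99, time_elapsed)  # decay depends on elapsed wall time
    current_rate = delta / time_elapsed
    return running_rate * alpha + current_rate * (1 - alpha)

rate = 0.0
for delta, dt in [(500, 10.0), (450, 10.0), (600, 10.0)]:
    rate = update_running_rate(rate, delta, dt)
print("running add rate: %.2f / s" % rate)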
Example #6
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if not (epoch - init_epoch) % execution_params.saving_period:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                assembler=assembler,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
            )
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" %
            (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        assembler=assembler,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )