Example #1
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()        # frames added to the replay buffer so far
    pre_num_sample = assembler.buffer_num_sample()  # frames sampled from it so far
    sync_s = 0.
    num_sync = 0
    t = time.time()
    # Throttle the learner: sleep if the previous epoch drew samples much faster
    # than the actors added new data (see the rate check at the end of this function).
    time.sleep(_train_epoch_waiting_time)
    # Route the forward pass through the DDP wrapper when distributed training is enabled.
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"],
                          batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            # Periodically push the latest weights to the actors via the assembler.
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # The learner re-sampled each added frame more than 8x this epoch:
        # sleep before the next epoch so the actors can catch up.
        _train_epoch_waiting_time += time_elapsed
    else:
        # Actors are keeping up with the learner; no throttling needed.
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    stat.reset()
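
A minimal, self-contained sketch of the throttling rule used at the end of `_train_epoch`; the helper name `update_waiting_time` is hypothetical. The idea: if the learner re-sampled each added frame more than roughly 8x during the epoch, sleep for the length of that epoch before the next one; otherwise reset the wait.

import time

_waiting_time = 0.0  # mirrors the module-level _train_epoch_waiting_time above

def update_waiting_time(delta_add: int, delta_sample: int, time_elapsed: float) -> float:
    """Grow the inter-epoch sleep when the learner outpaces the actors."""
    global _waiting_time
    if delta_sample > 8 * delta_add:
        _waiting_time += time_elapsed  # data is being re-used too aggressively
    else:
        _waiting_time = 0.0            # actors keep up; no throttling needed
    return _waiting_time

# e.g. 10,000 samples drawn vs. 1,000 frames added over a 30 s epoch -> start throttling
print(update_waiting_time(delta_add=1_000, delta_sample=10_000, time_elapsed=30.0))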
Example #2
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(
        model.parameters(), lr=optim_params.lr, eps=optim_params.eps
    )
    if optim_state_dict is not None:
        optim.load_state_dict(optim_state_dict)
    return optim
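
A hedged usage sketch for create_optimizer; the OptimParams dataclass below is a hypothetical stand-in exposing only the lr and eps fields the function actually reads, and the values are illustrative.

import torch
from dataclasses import dataclass

@dataclass
class OptimParams:  # hypothetical stand-in for the project's OptimParams
    lr: float = 1e-3
    eps: float = 1.5e-4

model = torch.jit.script(torch.nn.Linear(4, 2))  # any ScriptModule works here
optim = create_optimizer(model, OptimParams())   # fresh optimizer

# Resuming from a checkpoint restores Adam's moment estimates and step counts.
checkpoint = {"optim": optim.state_dict()}
optim = create_optimizer(model, OptimParams(), checkpoint["optim"])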
Example #3
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(model.parameters(),
                             lr=optim_params.lr,
                             eps=optim_params.eps)
    # Skip restoring the optimizer state when a reset was requested or the
    # checkpointed state no longer matches the model's parameter groups.
    if optim_state_dict is not None and not optim_params.reset_optimizer_state:
        try:
            optim.load_state_dict(optim_state_dict)
        except ValueError:
            print("Optimizer state not compatible... skipping.")
    return optim
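
Example #3's try/except guards the case where the checkpointed optimizer state no longer matches the current model. A hedged illustration (plain nn modules instead of ScriptModules, for brevity): torch.optim.Optimizer.load_state_dict raises ValueError when the saved parameter groups differ in size from the optimizer's, which is exactly what the except branch absorbs so training can continue with a freshly initialized optimizer.

import torch

old_model = torch.nn.Linear(4, 2, bias=False)  # 1 parameter
new_model = torch.nn.Linear(4, 2, bias=True)   # 2 parameters -> group size mismatch
stale_state = torch.optim.Adam(old_model.parameters()).state_dict()

optim = torch.optim.Adam(new_model.parameters())
try:
    optim.load_state_dict(stale_state)  # raises ValueError: parameter group sizes differ
except ValueError:
    print("Optimizer state not compatible... skipping.")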