Example #1
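Example #1 runs one training epoch with MegEngine's legacy static-graph API: the forward and backward passes are compiled with jit.trace(symbolic=True), optionally under sublinear memory recomputation, and each mini-batch is fed through pre-declared input placeholders.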
import time

import megengine.jit as jit

# Helpers such as AverageMeter, adjust_learning_rate, and logger are
# project-level utilities assumed to be in scope.
def train_one_epoch(
    model,
    data_queue,
    opt,
    tot_steps,
    rank,
    epoch_id,
    world_size,
    enable_sublinear=False,
):
    # Optionally trade extra compute for memory via sublinear recomputation
    # (MegEngine static-graph API).
    sublinear_cfg = jit.SublinearMemoryConfig() if enable_sublinear else None

    # Compile the forward/backward pass into a static graph.
    @jit.trace(symbolic=True, sublinear_memory_config=sublinear_cfg)
    def propagate():
        loss_dict = model(model.inputs)
        opt.backward(loss_dict["total_loss"])
        losses = list(loss_dict.values())
        return losses

    meter = AverageMeter(record_len=model.cfg.num_losses)
    time_meter = AverageMeter(record_len=2)  # tracks [train_time, data_time]
    log_interval = model.cfg.log_interval
    for step in range(tot_steps):
        adjust_learning_rate(opt, epoch_id, step, model, world_size)

        # Time the data-loading half of the step separately.
        data_tik = time.time()
        mini_batch = next(data_queue)
        data_tok = time.time()

        # Feed the static-graph input placeholders in place.
        model.inputs["image"].set_value(mini_batch["data"])
        model.inputs["gt_boxes"].set_value(mini_batch["gt_boxes"])
        model.inputs["im_info"].set_value(mini_batch["im_info"])

        # One optimization step: zero grads, forward/backward, update.
        tik = time.time()
        opt.zero_grad()
        loss_list = propagate()
        opt.step()
        tok = time.time()

        time_meter.update([tok - tik, data_tok - data_tik])

        if rank == 0:
            meter.update([loss.numpy() for loss in loss_list])
            if step % log_interval == 0:
                # Build the format string only when actually logging.
                info_str = "e%d, %d/%d, lr:%f, "
                loss_str = ", ".join(
                    ["{}:%f".format(loss_name) for loss_name in model.cfg.losses_keys])
                time_str = ", train_time:%.3fs, data_time:%.3fs"
                log_info_str = info_str + loss_str + time_str
                average_loss = meter.average()
                logger.info(log_info_str, epoch_id, step, tot_steps,
                            opt.param_groups[0]["lr"], *average_loss,
                            *time_meter.average())
                meter.reset()
                time_meter.reset()
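Both examples call into a project-level AverageMeter that accumulates a fixed-length vector of readings and reports their running mean. Its implementation is not shown here; the sketch below is a minimal, hypothetical version inferred purely from the call sites (record_len, update, average, reset), not the original helper.

import numpy as np

class AverageMeter:
    """Minimal sketch of the AverageMeter helper assumed above."""

    def __init__(self, record_len=1):
        self.record_len = record_len
        self.reset()

    def reset(self):
        # One running sum per tracked value, plus a shared sample count.
        self.sum = np.zeros(self.record_len, dtype=np.float64)
        self.count = 0

    def update(self, values):
        assert len(values) == self.record_len
        self.sum += np.asarray(values, dtype=np.float64).reshape(self.record_len)
        self.count += 1

    def average(self):
        # Per-slot mean since the last reset (guarded against divide-by-zero).
        return (self.sum / max(self.count, 1)).tolist()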
Example #2
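Example #2 is the same loop ported to MegEngine's imperative API: gradients are recorded with a GradManager passed in as gm, inputs are ordinary tensors rather than placeholders, and the per-epoch step count is derived from the global batch size.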
import time

import megengine as mge
import megengine.distributed as dist

# Helpers such as AverageMeter, adjust_learning_rate, and logger are
# project-level utilities assumed to be in scope.
def train_one_epoch(model, data_queue, opt, gm, epoch, args):
    def train_func(image, im_info, gt_boxes):
        # Record the forward pass under GradManager, then backpropagate
        # from the aggregate loss (MegEngine imperative API).
        with gm:
            loss_dict = model(image=image, im_info=im_info, gt_boxes=gt_boxes)
            gm.backward(loss_dict["total_loss"])
            loss_list = list(loss_dict.values())
        opt.step().clear_grad()  # apply the update, then zero the gradients
        return loss_list

    meter = AverageMeter(record_len=model.cfg.num_losses)
    time_meter = AverageMeter(record_len=2)  # tracks [train_time, data_time]
    log_interval = model.cfg.log_interval
    # Steps per epoch = images per epoch / global (all-worker) batch size.
    tot_step = model.cfg.nr_images_epoch // (args.batch_size * dist.get_world_size())
    for step in range(tot_step):
        adjust_learning_rate(opt, epoch, step, model.cfg, args)

        # Time the data-loading half of the step separately.
        data_tik = time.time()
        mini_batch = next(data_queue)
        data_tok = time.time()

        tik = time.time()
        loss_list = train_func(
            image=mge.tensor(mini_batch["data"]),
            im_info=mge.tensor(mini_batch["im_info"]),
            gt_boxes=mge.tensor(mini_batch["gt_boxes"]),
        )
        tok = time.time()

        time_meter.update([tok - tik, data_tok - data_tik])

        if dist.get_rank() == 0:
            meter.update([loss.numpy() for loss in loss_list])
            if step % log_interval == 0:
                # Build the format string only when actually logging.
                info_str = "e%d, %d/%d, lr:%f, "
                loss_str = ", ".join(
                    ["{}:%f".format(loss_name) for loss_name in model.cfg.losses_keys]
                )
                time_str = ", train_time:%.3fs, data_time:%.3fs"
                log_info_str = info_str + loss_str + time_str
                logger.info(
                    log_info_str,
                    epoch,
                    step,
                    tot_step,
                    opt.param_groups[0]["lr"],
                    *meter.average(),
                    *time_meter.average()
                )
                meter.reset()
                time_meter.reset()
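Example #2 assumes the GradManager gm was attached to the model's parameters before the epoch loop, with all-reduce callbacks when training on multiple workers. A minimal, hypothetical driver is sketched below; build_trainer, the SGD hyperparameters, and the commented epoch loop are illustrative assumptions, not part of the original code.

import megengine.distributed as dist
import megengine.optimizer as optim
from megengine.autodiff import GradManager

def build_trainer(model, args):
    # Plain SGD; train_one_epoch rescales the learning rate per step
    # through adjust_learning_rate.
    opt = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=0.9,
        weight_decay=1e-4,
    )

    gm = GradManager()
    if dist.get_world_size() > 1:
        # Sum gradients across workers during backward.
        gm.attach(
            model.parameters(),
            callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)],
        )
    else:
        gm.attach(model.parameters())
    return opt, gm

# for epoch in range(args.epochs):
#     train_one_epoch(model, data_queue, opt, gm, epoch, args)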