Example #1
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time
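
Every example on this page assumes a Timer with tic/toc/reset methods and diff, average_time, and total_time attributes. A minimal sketch of such a class, inferred from how the snippets use it (an assumption about the interface, not the project's actual implementation):

import time


class Timer:
    """Minimal tic/toc timer (sketch of the interface assumed above)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all accumulated timing state
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0
        self.average_time = 0.0

    def tic(self):
        # Start timing; perf_counter() is monotonic and high-resolution
        self.start_time = time.perf_counter()

    def toc(self):
        # Stop timing and update the running average
        self.diff = time.perf_counter() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls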
Example #2
    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch mIoU (smoothed over a window)
        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
        # Max mIoU (over the full test set)
        self.max_miou = 0.0
        # Per-class intersection/union counts and number of samples seen
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0
Example #3
    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full test set)
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
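
The meters likewise assume a ScalarMeter that smooths values over a window of cfg.LOG_PERIOD entries and reports a windowed median via get_win_median. A minimal sketch under that assumption:

from collections import deque

import numpy as np


class ScalarMeter:
    """Windowed scalar meter (sketch of the interface assumed above)."""

    def __init__(self, window_size):
        # Keep only the most recent window_size values
        self.deque = deque(maxlen=window_size)

    def reset(self):
        self.deque.clear()

    def add_value(self, value):
        self.deque.append(value)

    def get_win_median(self):
        # Median over the current window, as logged by the meters above
        return float(np.median(self.deque))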
Example #4
    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch mIoU (smoothed over a window)
        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
        # Per-class intersection/union counts and number of samples seen
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0
Example #5
    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0
Example #6
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE /
                                                 cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(batch_size, cfg.JIGSAW_GRID**2,
                            cfg.MODEL.INPUT_CHANNELS, im_size,
                            im_size).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size,
                             im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time
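
Note that the function does not disable autograd itself; since no backward pass follows, a caller may wrap it in torch.no_grad() to skip graph construction during the timed forwards. A hypothetical invocation (the model here is a stand-in, and cfg is assumed to be populated):

import torch
import torchvision

model = torchvision.models.resnet18().cuda()
with torch.no_grad():  # avoid autograd bookkeeping during the timed forwards
    avg_fw_time = compute_time_eval(model)
print("avg forward time: {:.4f}s".format(avg_fw_time))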
Example #7
def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time
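
The loader.shuffle(data_loader, 0) call is a helper from the surrounding project, not shown here. A plausible minimal version (an assumption, sketched for context) reseeds a distributed sampler for the given epoch so each epoch draws a fresh permutation:

from torch.utils.data.distributed import DistributedSampler


def shuffle(data_loader, cur_epoch):
    """Reseed the loader's sampler for cur_epoch (hypothetical sketch)."""
    # Only distributed samplers need explicit per-epoch reseeding
    if isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(cur_epoch)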
Example #8
def sweep_setup():
    """Samples cfgs for the sweep."""
    setup_cfg = sweep_cfg.SETUP
    # Create output directories
    sweep_dir = os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME)
    cfgs_dir = os.path.join(sweep_dir, "cfgs")
    logs_dir = os.path.join(sweep_dir, "logs")
    print("Sweep directory is: {}".format(sweep_dir))
    assert not os.path.exists(logs_dir), "Sweep already started: " + sweep_dir
    if os.path.exists(sweep_dir) or os.path.exists(cfgs_dir):
        print("Overwriting sweep which has not yet launched")
    os.makedirs(sweep_dir, exist_ok=True)
    os.makedirs(cfgs_dir, exist_ok=True)
    # Dump the original sweep_cfg
    sweep_cfg_file = os.path.join(sweep_dir, "sweep_cfg.yaml")
    os.system("cp {} {}".format(sweep_cfg.SWEEP_CFG_FILE, sweep_cfg_file))
    # Create worker pool for sampling and saving configs
    n_proc, chunk = sweep_cfg.NUM_PROC, setup_cfg.CHUNK_SIZE
    process_pool = multiprocessing.Pool(n_proc)
    # Fix random number generator seed and generate per chunk seeds
    np.random.seed(setup_cfg.RNG_SEED)
    n_chunks = int(np.ceil(setup_cfg.NUM_SAMPLES / chunk))
    chunk_seeds = np.random.choice(1000000, size=n_chunks, replace=False)
    # Sample configs in chunks using multiple workers each with a unique seed
    info_str = "Number configs sampled: {}, configs kept: {} [t={:.2f}s]"
    n_samples, n_cfgs, i, cfgs, timer = 0, 0, 0, {}, Timer()
    while n_samples < setup_cfg.NUM_SAMPLES and n_cfgs < setup_cfg.NUM_CONFIGS:
        timer.tic()
        seeds = chunk_seeds[i * n_proc:i * n_proc + n_proc]
        cfgs_all = process_pool.map(sample_cfgs, seeds)
        cfgs = dict(cfgs, **{k: v for d in cfgs_all for k, v in d.items()})
        n_samples, n_cfgs, i = n_samples + chunk * n_proc, len(cfgs), i + 1
        timer.toc()
        print(info_str.format(n_samples, n_cfgs, timer.total_time))
    # Randomize cfgs order and subsample if oversampled
    keys, cfgs = list(cfgs.keys()), list(cfgs.values())
    n_cfgs = min(n_cfgs, setup_cfg.NUM_CONFIGS)
    ids = np.random.choice(len(cfgs), n_cfgs, replace=False)
    keys, cfgs = [keys[i] for i in ids], [cfgs[i] for i in ids]
    # Save the cfgs and a cfgs_summary
    timer.tic()
    cfg_names = ["{:06}.yaml".format(i) for i in range(n_cfgs)]
    cfgs_summary = {cfg_name: key for cfg_name, key in zip(cfg_names, keys)}
    with open(os.path.join(sweep_dir, "cfgs_summary.yaml"), "w") as f:
        yaml.dump(cfgs_summary, f, width=float("inf"))
    cfg_files = [os.path.join(cfgs_dir, cfg_name) for cfg_name in cfg_names]
    process_pool.starmap(dump_cfg, zip(cfg_files, cfgs))
    timer.toc()
    print(info_str.format(n_samples, n_cfgs, timer.total_time))
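
sweep_setup relies on two helpers not shown here, sample_cfgs and dump_cfg. As an illustration, a minimal dump_cfg compatible with the starmap call above (a hypothetical sketch; the real helper may validate or normalize the config first) might be:

import yaml


def dump_cfg(cfg_file, sampled_cfg):
    # One pool worker writes one sampled config to its YAML file
    with open(cfg_file, "w") as f:
        yaml.dump(sampled_cfg, f, width=float("inf"))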
Example #9
def compute_time_eval(model, im_size, batch_size):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    inputs = torch.zeros(batch_size, 3, im_size,
                         im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time
Example #10
def compute_fw_test_time(model, inputs):
    """Computes forward test time (no grad, eval mode)."""
    # Use eval mode
    model.eval()
    # Warm up the caches
    for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
        model(inputs)
    # Make sure warmup kernels completed
    torch.cuda.synchronize()
    # Compute precise forward pass time
    timer = Timer()
    for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    # Make sure forward kernels completed
    torch.cuda.synchronize()
    return timer.average_time
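
An alternative to host-side torch.cuda.synchronize() plus a wall clock is timing with CUDA events, which measure elapsed time on the device directly; for steady-state kernels the two approaches should closely agree. A self-contained sketch (the model and input shapes are placeholders):

import torch
import torchvision

model = torchvision.models.resnet18().cuda().eval()
inputs = torch.zeros(64, 3, 224, 224).cuda()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    start.record()
    model(inputs)
    end.record()
torch.cuda.synchronize()  # wait until both events are recorded on the GPU
elapsed_s = start.elapsed_time(end) / 1000.0  # elapsed_time reports milliseconds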
Example #11
class TrainMeter(object):
    """Measures training stats."""

    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))
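
A hypothetical training-loop wiring for TrainMeter; the error values, loss, and learning rate below are placeholders for whatever the surrounding trainer computes, and cfg is assumed to be populated:

meter = TrainMeter(epoch_iters=500)
for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
    for cur_iter in range(meter.epoch_iters):
        meter.iter_tic()
        # ... forward / backward / optimizer step here ...
        meter.iter_toc()
        meter.update_stats(top1_err=25.0, top5_err=5.0, loss=1.2, lr=0.1, mb_size=128)
        meter.log_iter_stats(cur_epoch, cur_iter)
    meter.log_epoch_stats(cur_epoch)
    meter.reset()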
Example #12
class TestMeter(object):
    """Measures testing stats."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full test set)
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.min_top1_err = 100.0
            self.min_top5_err = 100.0
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, mb_size):
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "top1_err": top1_err,
            "top5_err": top5_err,
            "min_top1_err": self.min_top1_err,
            "min_top5_err": self.min_top5_err,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
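
The epoch-level errors are sample-weighted means: update_stats accumulates top1_err * mb_size as "misclassification mass" in num_top1_mis, and get_epoch_stats divides the total by num_samples. A tiny worked check of that arithmetic:

num_mis, num_samples = 0.0, 0
for err, mb_size in [(10.0, 100), (40.0, 50)]:
    num_mis += err * mb_size  # same accumulation as update_stats
    num_samples += mb_size
assert num_mis / num_samples == 20.0  # (10*100 + 40*50) / 150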
Example #13
def compute_fw_bw_time(model, loss_fun, inputs, labels):
    """Computes forward backward time."""
    # Use train mode
    model.train()
    # Warm up the caches
    for _cur_iter in range(cfg.PREC_TIME.WARMUP_ITER):
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        loss.backward()
    # Make sure warmup kernels completed
    torch.cuda.synchronize()
    # Compute precise forward backward pass time
    fw_timer = Timer()
    bw_timer = Timer()
    for _cur_iter in range(cfg.PREC_TIME.NUM_ITER):
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Make sure forward backward kernels completed
    torch.cuda.synchronize()
    return fw_timer.average_time, bw_timer.average_time
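
A hypothetical invocation with a dummy batch and criterion. Note the function never zeroes gradients, so they accumulate across iterations; that is harmless here because only wall time is measured, not gradient values:

import torch
import torchvision

model = torchvision.models.resnet18().cuda()
inputs = torch.rand(64, 3, 224, 224).cuda()
labels = torch.zeros(64, dtype=torch.int64).cuda()
loss_fun = torch.nn.CrossEntropyLoss().cuda()
fw_time, bw_time = compute_fw_bw_time(model, loss_fun, inputs, labels)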
Example #14
class TestMeterIoU(object):
    """Measures testing stats."""
    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch mIoU (smoothed over a window)
        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
        # Max mIoU (over the full test set)
        self.max_miou = 0.0
        # Per-class intersection/union counts and number of samples seen
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.max_miou = 0.0
        self.iter_timer.reset()
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, mb_size):
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "miou": self.mb_miou.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        self.max_miou = max(self.max_miou, miou)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "miou": miou,
            "max_miou": self.max_miou,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
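
update_stats expects per-class intersection and union counts. A sketch of how they might be computed from predicted and ground-truth label maps (a hypothetical NumPy helper, not part of the meter itself):

import numpy as np


def inter_union(pred, target, num_classes):
    """Per-class intersection/union pixel counts for one prediction."""
    inter = np.zeros(num_classes)
    union = np.zeros(num_classes)
    for c in range(num_classes):
        p, t = pred == c, target == c
        inter[c] = np.logical_and(p, t).sum()
        union[c] = np.logical_or(p, t).sum()
    return inter, union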
Example #15
class TrainMeterIoU(object):
    """Measures training stats."""
    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch mIoU (smoothed over a window)
        self.mb_miou = ScalarMeter(cfg.LOG_PERIOD)
        # Per-class intersection/union counts and number of samples seen
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_miou.reset()
        self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES)
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, inter, union, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_miou.add_value((inter / (union + 1e-10)).mean())
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_inter += inter * mb_size
        self.num_union += union * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter -
                                                  cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "miou": self.mb_miou.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter -
                                                  cur_iter_total)
        mem_usage = gpu_mem_usage()
        miou = (self.num_inter / (self.num_union + 1e-10)).mean()
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "miou": miou,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))