Esempio n. 1
0
def train_epoch(epoch, model: BaseModel, dataset, device: str,
                tracker: BaseTracker, checkpoint: ModelCheckpoint, log):
    """Run one pass over the training loader, then publish and checkpoint.

    Every batch is moved to ``device``, handed to the model and optimised.
    The tracker is fed on every tenth batch, and the progress bar postfix
    shows the tracker metrics together with two per-batch timings.

    Args:
        epoch: Accepted in the signature but not read anywhere in this body.
        model: ``train()`` is called on it once; each batch is passed to
            ``set_input`` and followed by ``optimize_parameters``.
        dataset: Provides ``train_dataloader()`` and ``batch_size``.
        device: Argument given to ``data.to(...)`` for every batch.
        tracker: Reset with ``"train"``, updated during the loop and
            published once the loop is over.
        checkpoint: Receives the model and the published metrics.
        log: Object whose ``info`` method is used to report the model's
            learning rate at the end.
    """
    model.train()
    tracker.reset("train")
    loader = dataset.train_dataloader()

    # Marks the start of the current "waiting for data" window.
    wait_began = time.time()
    with Ctq(loader) as progress:
        for step, batch in enumerate(progress):
            batch = batch.to(device)  # This takes time
            model.set_input(batch)
            # Spans batch retrieval, the device transfer and set_input.
            load_seconds = time.time() - wait_began

            step_began = time.time()
            model.optimize_parameters(dataset.batch_size)

            # Only every tenth batch (0, 10, 20, ...) reaches the tracker.
            if not step % 10:
                tracker.track(model)

            live_metrics = tracker.get_metrics()
            step_seconds = time.time() - step_began
            progress.set_postfix(
                **live_metrics,
                data_loading=float(load_seconds),
                iteration=float(step_seconds),
                color=COLORS.TRAIN_COLOR,
            )
            # Restart the data-wait clock for the next batch.
            wait_began = time.time()

    published = tracker.publish()
    checkpoint.save_best_models_under_current_metrics(model, published)
    message = "Learning rate = %f" % model.learning_rate
    log.info(message)
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    """Run one pass over the training loader, then publish and checkpoint.

    Each batch is handed to the model together with ``device`` and then
    optimised. The tracker is fed on every tenth batch, the progress bar
    postfix shows tracker metrics plus two per-batch timings, and the
    model's current visuals are saved whenever the visualizer is active.

    The learning rate is reported through ``log``, which is not a parameter
    of this function and has to be resolvable from the surrounding scope.

    Args:
        epoch: Passed to ``visualizer.reset``, ``model.optimize_parameters``
            and ``tracker.publish``.
        model: ``train()`` is called on it once; each batch goes through
            ``set_input`` and ``optimize_parameters``.
        dataset: Provides the ``train_dataloader`` attribute and
            ``batch_size``.
        device: Forwarded to ``model.set_input`` with every batch.
        tracker: Reset with ``"train"``, updated during the loop and
            published once the loop is over.
        checkpoint: Receives the model and the published metrics.
        visualizer: Reset with the epoch and ``"train"``; consulted on
            every batch through ``is_active``.
        early_break: When truthy, the loop stops after its first batch.
    """
    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    loader = dataset.train_dataloader

    # Marks the start of the current "waiting for data" window.
    wait_began = time.time()
    with Ctq(loader) as progress:
        for step, batch in enumerate(progress):
            model.set_input(batch, device)
            # Spans batch retrieval and set_input.
            load_seconds = time.time() - wait_began

            step_began = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            # Only every tenth batch (0, 10, 20, ...) reaches the tracker.
            if not step % 10:
                tracker.track(model)

            live_metrics = tracker.get_metrics()
            step_seconds = time.time() - step_began
            progress.set_postfix(
                **live_metrics,
                data_loading=float(load_seconds),
                iteration=float(step_seconds),
                color=COLORS.TRAIN_COLOR,
            )

            if visualizer.is_active:
                visuals = model.get_current_visuals()
                visualizer.save_visuals(visuals)

            # Restart the data-wait clock for the next batch.
            wait_began = time.time()

            if early_break:
                break

    published = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, published)
    message = "Learning rate = %f" % model.learning_rate
    log.info(message)