Example #1
import os

import torch
from tqdm import tqdm

# AverageMeter, RunningScore and save_model are project-local helpers assumed to be in scope.


def validate(module, epoch, best_iou, num_classes, writer, logger):
    """Run validation on the validation split and return the best mean IoU seen so far."""
    # Unpack the module
    model = module.model
    device = module.device
    val_loader = module.val_loader
    loss_fn = module.loss_fn

    avg_loss = AverageMeter()
    running_score = RunningScore(num_classes)

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels)

            avg_loss.update(loss.item())

            pred = outputs.argmax(dim=1).cpu().numpy()
            gt = labels.cpu().numpy()
            running_score.update(gt, pred)

    writer.add_scalar("Val Loss", avg_loss.average(), epoch)
    logger.info("Epoch: {} Loss: {:.4f}".format(epoch, avg_loss.average()))

    mean_iou, disp_score = running_score.get_scores()
    logger.info(disp_score)
    if mean_iou >= best_iou:
        # Save a checkpoint whenever the current mean IoU matches or beats the best so far.
        best_iou = mean_iou
        path = os.path.join(writer.file_writer.get_logdir(), "best_model.pkl")
        save_model(model=model,
                   optimizer=module.optimizer,
                   epoch=epoch,
                   best_iou=best_iou,
                   path=path)
    return best_iou
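
Both snippets rely on a small running-average helper that the source does not show. Below is a minimal sketch of what an AverageMeter with the update()/average() interface used above might look like; this is an assumption, not the project's actual class:

class AverageMeter:
    """Running-average tracker (assumed interface: update() and average())."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # Accumulate a new observation, e.g. a per-batch loss value.
        self.sum += value * n
        self.count += n

    def average(self):
        # Mean of everything seen so far; 0.0 before the first update.
        return self.sum / self.count if self.count else 0.0
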
Example #2
import time

# AverageMeter is a project-local helper assumed to be in scope.


def train_epoch(module, config, writer, logger):
    """Train the model for a single epoch."""
    batch_time = AverageMeter()
    train_loss = AverageMeter()

    # Unpack the module
    model = module.model
    device = module.device
    train_loader = module.train_loader
    loss_fn = module.loss_fn
    optimizer = module.optimizer

    model.train()
    for idx, (images, labels) in enumerate(train_loader, start=1):
        start_tic = time.time()

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = loss_fn(input=outputs, target=labels)
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - start_tic)
        train_loss.update(loss.item())

        if idx % config.training.disp_iter == 0:
            # Periodically report the running loss and average time per batch.
            print_str = "Iter {:d} Loss: {:.4f} Time/Batch: {:.4f}".format(
                idx, train_loss.average(), batch_time.average())
            print(print_str)
            logger.info(print_str)
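
These two functions are presumably called from an outer loop that alternates training and validation. The sketch below is only a hedged illustration of how they could be wired together; the module and config objects, and the num_epochs/num_classes fields, are assumptions rather than the project's real API:

def train(module, config, writer, logger):
    # Illustrative driver: config.training.num_epochs and config.data.num_classes
    # are hypothetical field names, not confirmed by the source.
    best_iou = -1.0
    for epoch in range(config.training.num_epochs):
        train_epoch(module, config, writer, logger)
        best_iou = validate(module, epoch, best_iou,
                            num_classes=config.data.num_classes,
                            writer=writer, logger=logger)
    return best_iou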