Example #1
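All of the snippets below come from PyTorch training scripts and omit their module-level imports. A minimal common set that the shared calls rely on would look roughly like this; project-specific helpers such as utils, compute_loss, compute_metric, DEVICE, NUM_CLASSES, config and args are assumed to come from the surrounding project and are not shown (SummaryWriter may equally come from tensorboardX in older code):

import math
import os
import time

import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
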
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
        'dice': utils.Mean(),
        'fps': utils.Mean(),
    }

    model.eval()
    t1 = time.time()
    with torch.no_grad():
        for images, masks, _ in tqdm(data_loader,
                                     desc='epoch {} evaluation'.format(epoch)):
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            logits = model(images)

            loss = compute_loss(input=logits, target=masks)
            metrics['loss'].update(loss.data.cpu().numpy())

            metric = compute_metric(input=logits, target=masks)
            for k in metric:
                metrics[k].update(metric[k].data.cpu().numpy())

            t2 = time.time()
            metrics['fps'].update(1 / ((t2 - t1) / images.size(0)))
            t1 = t2

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        print('[FOLD {}][EPOCH {}][EVAL] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)

        images = images[:32]
        masks = mask_to_image(masks[:32], num_classes=NUM_CLASSES)
        preds = mask_to_image(logits[:32].argmax(1, keepdim=True),
                              num_classes=NUM_CLASSES)

        writer.add_image('images',
                         torchvision.utils.make_grid(images,
                                                     nrow=compute_nrow(images),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_image('masks',
                         torchvision.utils.make_grid(masks,
                                                     nrow=compute_nrow(masks),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_image('preds',
                         torchvision.utils.make_grid(preds,
                                                     nrow=compute_nrow(preds),
                                                     normalize=True),
                         global_step=epoch)

        return metrics
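
Every example accumulates scalar metrics through utils.Mean, which exposes update() and compute_and_reset(). That helper belongs to the surrounding project and is not shown; a minimal running-mean sketch with the same interface (an assumption, not the original implementation) could look like this:

import numpy as np

class Mean:
    # hypothetical stand-in for utils.Mean: accumulates scalars or arrays
    # and returns their mean on compute_and_reset()
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        value = np.asarray(value)
        self.total += value.sum()
        self.count += value.size

    def compute_and_reset(self):
        mean = self.total / max(self.count, 1)
        self.total, self.count = 0.0, 0
        return mean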
Example #2
def eval_epoch(model, data_loader, epoch):
    writer = SummaryWriter(os.path.join(args.experiment_path, 'eval'))

    metrics = {
        'loss': utils.Mean(),
        'dice': utils.Mean(),
        'iou': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(data_loader,
                                   desc='epoch {} evaluation'.format(epoch)):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model(images)

            loss = compute_loss(input=logits, target=labels)
            metrics['loss'].update(loss.data.cpu().numpy())

            metric = compute_metric(input=logits, target=labels)
            for k in metric:
                metrics[k].update(metric[k].data.cpu().numpy())

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        masks_true = draw_masks(labels)
        masks_pred = draw_masks(logits.argmax(1, keepdim=True))

        print('[EPOCH {}][EVAL] {}'.format(
            epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
        writer.add_image('masks_true',
                         torchvision.utils.make_grid(
                             masks_true,
                             nrow=math.ceil(math.sqrt(masks_true.size(0))),
                             normalize=False),
                         global_step=epoch)
        writer.add_image('masks_pred',
                         torchvision.utils.make_grid(
                             masks_pred,
                             nrow=math.ceil(math.sqrt(masks_pred.size(0))),
                             normalize=False),
                         global_step=epoch)

        return metrics
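
Example #2 computes the grid width for make_grid inline as math.ceil(math.sqrt(n)); compute_nrow in Example #1 and get_nrow in Example #8 presumably do the same thing. A one-line sketch of that assumption:

def compute_nrow(images):
    # square-ish grid: number of columns for torchvision.utils.make_grid
    return math.ceil(math.sqrt(images.size(0)))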
Example #3
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.train()
    for images, labels, ids in tqdm(data_loader, desc='epoch {} train'.format(epoch)):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model(images)

        loss = compute_loss(input=logits, target=labels, smoothing=config.label_smooth)
        metrics['loss'].update(loss.data.cpu().numpy())

        lr, beta = scheduler.get_lr()
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

        if args.debug:
            break

    with torch.no_grad():
        loss = metrics['loss'].compute_and_reset()

        print('[FOLD {}][EPOCH {}][TRAIN] loss: {:.4f}'.format(fold, epoch, loss))
        writer.add_scalar('loss', loss, global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_scalar('beta', beta, global_step=epoch)
        writer.add_image('image', torchvision.utils.make_grid(images[:32], normalize=True), global_step=epoch)
Example #4
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []

        for batch in tqdm(data_loader,
                          desc='epoch {} evaluation'.format(epoch)):
            batch = batch.to(DEVICE)
            exps = [EXPS[i] for i in batch.exps.data.cpu().numpy()]
            logits = model(batch)

            loss = compute_loss(input=logits, target=batch.y)
            metrics['loss'].update(loss.data.cpu().numpy())

            fold_labels.append(batch.y)
            fold_logits.append(logits)
            fold_exps.extend(exps)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        if epoch % 10 == 0:
            temp, metric, fig = find_temp_global(input=fold_logits,
                                                 target=fold_labels,
                                                 exps=fold_exps)
            writer.add_scalar('temp', temp, global_step=epoch)
            writer.add_scalar('metric_final', metric, global_step=epoch)
            writer.add_figure('temps', fig, global_step=epoch)

        temp = 1.  # use default temp
        fold_preds = assign_classes(probs=to_prob(fold_logits,
                                                  temp).data.cpu().numpy(),
                                    exps=fold_exps)
        fold_preds = torch.tensor(fold_preds).to(fold_logits.device)
        metric = compute_metric(input=fold_preds,
                                target=fold_labels,
                                exps=fold_exps)

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        for k in metric:
            metrics[k] = metric[k].mean().data.cpu().numpy()
        print('[FOLD {}][EPOCH {}][EVAL] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)

        return metrics
Example #5
def train_epoch(model, optimizer, scheduler, data_loader, epoch):
    writer = SummaryWriter(os.path.join(args.experiment_path, 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.train()
    for images, labels in tqdm(data_loader,
                               desc='epoch {} train'.format(epoch)):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model(images)

        loss = compute_loss(input=logits, target=labels)
        metrics['loss'].update(loss.data.cpu().numpy())

        lr, _ = scheduler.get_lr()
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        masks_true = draw_masks(labels)
        masks_pred = draw_masks(logits.argmax(1, keepdim=True))

        print('[EPOCH {}][TRAIN] {}'.format(
            epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
        writer.add_image('masks_true',
                         torchvision.utils.make_grid(
                             masks_true,
                             nrow=math.ceil(math.sqrt(masks_true.size(0))),
                             normalize=False),
                         global_step=epoch)
        writer.add_image('masks_pred',
                         torchvision.utils.make_grid(
                             masks_pred,
                             nrow=math.ceil(math.sqrt(masks_pred.size(0))),
                             normalize=False),
                         global_step=epoch)
Example #6
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []

        for images, feats, exps, labels, _ in tqdm(data_loader, desc='epoch {} evaluation'.format(epoch)):
            images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(DEVICE)
            logits = model(images, feats)

            loss = compute_loss(
                input=logits, target=labels, weight=np.linspace(1 / len(logits), 1., config.epochs)[epoch - 1].item())
            metrics['loss'].update(loss.data.cpu().numpy())
            *_, logits = logits

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        if epoch % 10 == 0:
            temp, metric, fig = find_temp_global(input=fold_logits, target=fold_labels, exps=fold_exps)
            writer.add_scalar('temp', temp, global_step=epoch)
            writer.add_scalar('metric_final', metric, global_step=epoch)
            writer.add_figure('temps', fig, global_step=epoch)
        temp = 1.  # use default temp
        fold_preds = assign_classes(probs=(fold_logits * temp).softmax(1).data.cpu().numpy(), exps=fold_exps)
        fold_preds = torch.tensor(fold_preds).to(fold_logits.device)
        metric = compute_metric(input=fold_preds, target=fold_labels)

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        for k in metric:
            metrics[k] = metric[k].mean().data.cpu().numpy()
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][EVAL] {}'.format(
            fold, epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_image('images', torchvision.utils.make_grid(
            images, nrow=math.ceil(math.sqrt(images.size(0))), normalize=True), global_step=epoch)

        return metrics
Example #7
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(
        round(224 + (config.crop_size - 224) *
              np.linspace(0, 1, config.epochs)[epoch - 1].item()))
    model.train()
    optimizer.zero_grad()
    for i, (images, feats, labels, ids) in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(
            DEVICE)
        logits = model(images, feats, labels)

        loss = compute_loss(input=logits,
                            target=labels,
                            weight=np.linspace(1., 0.8,
                                               config.epochs)[epoch - 1])
        logits, _ = logits
        metrics['loss'].update(loss.data.cpu().numpy())

        lr = scheduler.get_lr()
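        # gradient accumulation: scale the loss by 1 / acc_steps so that gradients
        # summed over acc_steps micro-batches match one full-batch update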
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
Example #8
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
    }

    predictions = []
    targets = []
    model.eval()
    with torch.no_grad():
        for sigs, labels, ids in tqdm(
                data_loader, desc='epoch {} evaluation'.format(epoch)):
            sigs, labels = sigs.to(DEVICE), labels.to(DEVICE)
            logits, images, weights = model(sigs)

            targets.append(labels)
            predictions.append(logits)

            loss = compute_loss(input=logits, target=labels)
            metrics['loss'].update(loss.data.cpu().numpy())

            if args.debug:
                break

        loss = metrics['loss'].compute_and_reset()

        predictions = torch.cat(predictions, 0)
        targets = torch.cat(targets, 0)
        score = compute_score(input=predictions, target=targets)

        print('[FOLD {}][EPOCH {}][EVAL] loss: {:.4f}, score: {:.4f}'.format(
            fold, epoch, loss, score))
        writer.add_scalar('loss', loss, global_step=epoch)
        writer.add_scalar('score', score, global_step=epoch)
        writer.add_image('image',
                         torchvision.utils.make_grid(images[:32],
                                                     nrow=get_nrow(
                                                         images[:32]),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_histogram('distribution', images[:32], global_step=epoch)
        writer.add_image('weights',
                         torchvision.utils.make_grid(weights[:32],
                                                     nrow=get_nrow(
                                                         weights[:32])),
                         global_step=epoch)

        return score
Example #9
def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))

    metrics = {
        'loss': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        predictions = []
        targets = []

        for images, labels, ids in tqdm(data_loader, desc='epoch {} evaluation'.format(epoch)):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model(images)

            targets.append(labels)
            predictions.append(logits)

            loss = compute_loss(input=logits, target=labels, smoothing=config.label_smooth)
            metrics['loss'].update(loss.data.cpu().numpy())

            if args.debug:
                break

        loss = metrics['loss'].compute_and_reset()

        predictions = torch.cat(predictions, 0)
        targets = torch.cat(targets, 0)
        threshold, score, fig = find_threshold_global(input=predictions, target=targets)

        scores = compute_score(input=predictions, target=targets, threshold=threshold)
        indices = scores.argsort()[:32]
        failure = [data_loader.dataset[i.item()][0] for i in indices]
        failure = torch.stack(failure, 0).to(DEVICE)
        # ground-truth labels go to `true`, thresholded predictions to `pred`
        failure = draw_errors(
            failure,
            true=targets[indices],
            pred=(output_to_logits(predictions[indices]).sigmoid() > threshold).float())

        print('[FOLD {}][EPOCH {}][EVAL] loss: {:.4f}, score: {:.4f}'.format(fold, epoch, loss, score))
        writer.add_scalar('loss', loss, global_step=epoch)
        writer.add_scalar('score', score, global_step=epoch)
        writer.add_image('image', torchvision.utils.make_grid(images[:32], normalize=True), global_step=epoch)
        writer.add_image('failure', torchvision.utils.make_grid(failure, normalize=True), global_step=epoch)
        writer.add_figure('thresholds', fig, global_step=epoch)

        return score
Example #10
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    model.train()
    optimizer.zero_grad()
    for i, batch in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        batch = batch.to(DEVICE)
        logits = model(batch)

        loss = compute_loss(input=logits, target=batch.y)
        metrics['loss'].update(loss.data.cpu().numpy())

        lr = scheduler.get_lr()
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
Example #11
def train_epoch(model, optimizer, scheduler, data_loader, unsup_data_loader,
                fold, epoch):
    assert len(data_loader) <= len(unsup_data_loader), (len(data_loader),
                                                        len(unsup_data_loader))

    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    data = zip(data_loader, unsup_data_loader)
    total = min(len(data_loader), len(unsup_data_loader))
    model.train()
    optimizer.zero_grad()
    for i, ((images_s, _, labels_s, _), (images_u, _, _)) \
            in enumerate(tqdm(data, desc='epoch {} train'.format(epoch), total=total), 1):
        images_s, labels_s, images_u = images_s.to(DEVICE), labels_s.to(
            DEVICE), images_u.to(DEVICE)
        labels_s = utils.one_hot(labels_s, NUM_CLASSES)

        with torch.no_grad():
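            # pseudo-label the unlabeled batch: average predictions over the n augmented
            # views of each image, then sharpen the averaged distribution (MixMatch-style)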
            b, n, c, h, w = images_u.size()
            images_u = images_u.view(b * n, c, h, w)
            logits_u = model(images_u, None, True)
            logits_u = logits_u.view(b, n, NUM_CLASSES)
            labels_u = logits_u.softmax(2).mean(1, keepdim=True)
            labels_u = labels_u.repeat(1, n, 1).view(b * n, NUM_CLASSES)
            labels_u = dist_sharpen(labels_u, temp=SHARPEN_TEMP)

        assert images_s.size() == images_u.size()
        assert labels_s.size() == labels_u.size()

        images, labels = torch.cat([images_s, images_u],
                                   0), torch.cat([labels_s, labels_u], 0)
        images, labels = mixup(images, labels)
        assert images.size(0) == config.batch_size * 2
        logits = model(images, None, True)

        loss = compute_loss(input=logits, target=labels, unsup=True)
        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        lr = scheduler.get_lr()
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images',
                         torchvision.utils.make_grid(
                             images,
                             nrow=math.ceil(math.sqrt(images.size(0))),
                             normalize=True),
                         global_step=epoch)
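
Example #11 above follows a MixMatch-style semi-supervised recipe: each unlabeled image is pseudo-labeled with the model prediction averaged over its n augmented views and then sharpened, and the combined labeled/unlabeled batch is mixed with mixup before training against the soft targets. dist_sharpen is a project helper; the standard temperature sharpening it presumably implements is (a sketch, not the original code):

def dist_sharpen(p, temp):
    # raise probabilities to the power 1 / temp and renormalize;
    # temp < 1 pushes the distribution towards its argmax
    p = p**(1.0 / temp)
    return p / p.sum(dim=1, keepdim=True)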
Example #12
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
    }

    if epoch >= config.finetune_epoch:
        for ds in data_loader.dataset.datasets:
            ds.transform = T.Compose([
                LoadSignal(config.model.sample_rate),
                RandomCrop(config.aug.crop.size * config.model.sample_rate),
                ToTensor(),
            ])

    model.train()
    for sigs, labels, ids in tqdm(data_loader,
                                  desc='epoch {} train'.format(epoch)):
        if config.mixup is not None and epoch < config.finetune_epoch:
            if np.random.random() > (epoch / config.finetune_epoch):
                sigs, labels, ids = mixup(sigs,
                                          labels,
                                          ids,
                                          alpha=config.mixup)

        sigs, labels = sigs.to(DEVICE), labels.to(DEVICE)
        logits, images, weights = model(sigs,
                                        spec_aug=config.aug.spec_aug
                                        and epoch < config.finetune_epoch)

        loss = compute_loss(input=logits, target=labels)
        metrics['loss'].update(loss.data.cpu().numpy())

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

        if args.debug:
            break

    with torch.no_grad():
        loss = metrics['loss'].compute_and_reset()

        print('[FOLD {}][EPOCH {}][TRAIN] loss: {:.4f}'.format(
            fold, epoch, loss))
        writer.add_scalar('loss', loss, global_step=epoch)
        lr, beta = scheduler.get_lr()
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_scalar('beta', beta, global_step=epoch)
        writer.add_image('image',
                         torchvision.utils.make_grid(images[:32],
                                                     nrow=get_nrow(
                                                         images[:32]),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_histogram('distribution', images[:32], global_step=epoch)
        writer.add_image('weights',
                         torchvision.utils.make_grid(weights[:32],
                                                     nrow=get_nrow(
                                                         weights[:32])),
                         global_step=epoch)
Example #13
def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))

    metrics = {
        'loss': utils.Mean(),
        'fps': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    model.train()
    optimizer.zero_grad()
    t1 = time.time()
    for i, (images, masks, ids) in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        logits = model(images)

        loss = compute_loss(input=logits, target=masks)
        metrics['loss'].update(loss.data.cpu().numpy())

        lr = scheduler.get_lr()
        (loss.mean() / config.opt.acc_steps).backward()

        # with amp.scale_loss((loss.mean() / config.opt.acc_steps), optimizer) as scaled_loss:
        #     scaled_loss.backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

        t2 = time.time()
        metrics['fps'].update(1 / ((t2 - t1) / images.size(0)))
        t1 = t2

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch,
            ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)

        images = images[:32]
        masks = mask_to_image(masks[:32], num_classes=NUM_CLASSES)
        preds = mask_to_image(logits[:32].argmax(1, keepdim=True),
                              num_classes=NUM_CLASSES)

        writer.add_image('images',
                         torchvision.utils.make_grid(images,
                                                     nrow=compute_nrow(images),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_image('masks',
                         torchvision.utils.make_grid(masks,
                                                     nrow=compute_nrow(masks),
                                                     normalize=True),
                         global_step=epoch)
        writer.add_image('preds',
                         torchvision.utils.make_grid(preds,
                                                     nrow=compute_nrow(preds),
                                                     normalize=True),
                         global_step=epoch)
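
The train_epoch/eval_epoch pairs above are meant to be driven by an outer loop over folds and epochs. Below is a minimal driver for the Example #13 / Example #1 pair, as a sketch under assumptions: build_model, build_optimizer, build_scheduler and build_data_loaders stand in for the project's own factories, config.folds is illustrative, and 'dice' is taken as the metric to monitor.

for fold in range(1, config.folds + 1):
    model = build_model(config).to(DEVICE)
    optimizer = build_optimizer(model.parameters(), config)
    scheduler = build_scheduler(optimizer, config)
    train_loader, eval_loader = build_data_loaders(fold, config)

    best_score = float('-inf')
    for epoch in range(1, config.epochs + 1):
        train_epoch(model, optimizer, scheduler, train_loader, fold, epoch)
        metrics = eval_epoch(model, eval_loader, fold, epoch)

        # keep the checkpoint with the best validation dice
        if metrics['dice'] > best_score:
            best_score = metrics['dice']
            torch.save(model.state_dict(),
                       os.path.join(args.experiment_path, 'fold{}'.format(fold), 'model.pth'))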