# Ejemplo n.º 1
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask):
    """Classify one image/mask pair and return ranked predictions.

    Runs ``model`` on a 224x224 crop built from ``image`` and ``seg_mask``
    and returns a dict with the top-10 material predictions, plus substance
    and roughness rankings when the model produces those heads.

    Args:
        label_to_mat_id: Maps integer class labels to material ids.
        model: Trained RendNet3 classifier (on the GPU).
        image: Input image; floats are assumed to be in [0, 1].
        seg_mask: Segmentation mask selecting the region to classify.

    Returns:
        Dict with a 'material' list and optional 'substance' / 'roughness'
        lists, each entry carrying a softmax 'score'.
    """
    # The transforms expect an 8-bit image.
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    # order=0 (nearest neighbor) keeps the mask values crisp; scale to 0/255.
    resized_mask = skimage.transform.resize(seg_mask, (224, 224),
                                            order=0,
                                            anti_aliasing=False,
                                            mode='reflect')
    resized_mask = resized_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(224)(image)
    mask_tensor = transforms.inference_mask_transform(224)(resized_mask)
    input_tensor = torch.cat((image_tensor, mask_tensor), dim=0).unsqueeze(0)

    # Show the network input in Visdom for debugging.
    vis.image(utils.visualize_input({'image': input_tensor}), win='input')

    output = model.forward(input_tensor.cuda())

    mat_probs = F.softmax(output['material'], dim=1)
    mat_scores, mat_labels = torch.topk(mat_probs, k=10)

    material_entries = []
    for score, label in zip(mat_scores.squeeze().tolist(),
                            mat_labels.squeeze().tolist()):
        material_entries.append({
            'score': score,
            'id': label_to_mat_id[int(label)],
        })
    topk_dict = {'material': material_entries}

    if 'substance' in output:
        # Rank every substance class (k = full class count).
        subst_probs = F.softmax(output['substance'].cpu(), dim=1)
        subst_scores, subst_labels = torch.topk(
            subst_probs, k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            }
            for score, label in zip(subst_scores.squeeze().tolist(),
                                    subst_labels.squeeze().tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoints of nrc equal-width bins over [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        rough_probs = F.softmax(output['roughness'].cpu(), dim=1)
        rough_scores, rough_labels = torch.topk(rough_probs, k=5)
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            }
            for score, label in zip(rough_scores.squeeze().tolist(),
                                    rough_labels.squeeze().tolist())
        ]

    return topk_dict
# Ejemplo n.º 2
def main(checkpoint_path, batch_size, normalized,
         visdom_port):
    """Evaluate a RendNet3 checkpoint on the validation set and log a
    confusion matrix heatmap to Visdom.

    Args:
        checkpoint_path: Path to the model checkpoint (``.pth.tar``).
        batch_size: Validation batch size.
        normalized: If True, the confusion matrix is normalized by
            torchnet's ConfusionMeter.
        visdom_port: Port of the Visdom server receiving the heatmap.
    """
    checkpoint_path = Path(checkpoint_path)
    # The snapshot metadata lives three directory levels above the checkpoint.
    snapshot_path = checkpoint_path.parent.parent.parent / 'snapshot.json'

    with snapshot_path.open('r') as f:
        snapshot_dict = json.load(f)
        mat_id_to_label = snapshot_dict['mat_id_to_label']
        label_to_mat_id = {int(v): int(k) for k, v in mat_id_to_label.items()}
        # +1 accounts for the background class at label 0.
        num_classes = len(label_to_mat_id) + 1

    print(f'Loading model checkpoint from {checkpoint_path!r}')
    checkpoint = torch.load(checkpoint_path)

    model = RendNet3(num_classes=num_classes,
                     num_roughness_classes=20,
                     num_substances=len(SUBSTANCES),
                     base_model=resnet.resnet18(pretrained=False))
    model.load_state_dict(checkpoint['state_dict'])
    # Switch to inference mode before moving the model to the GPU.
    # (The original called both model.train(False) and model.eval(),
    # which are equivalent; one call suffices.)
    model.eval()
    model = model.cuda()

    validation_dataset = rendering_dataset.MaterialRendDataset(
        snapshot_dict,
        snapshot_dict['examples']['validation'],
        shape=(384, 384),
        image_transform=transforms.inference_image_transform(INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(INPUT_SIZE))

    validation_loader = DataLoader(
        validation_dataset, batch_size=batch_size,
        num_workers=8,
        shuffle=False,
        pin_memory=True,
        collate_fn=rendering_dataset.collate_fn)

    confusion_meter = tnt.meter.ConfusionMeter(
        k=num_classes, normalized=normalized)

    pbar = tqdm(validation_loader)
    for batch_dict in pbar:
        input_tensor = batch_dict['image'].cuda()
        labels = batch_dict['material_label'].cuda()

        # Accumulate predictions into the confusion meter.
        output = model.forward(input_tensor)
        pbar.set_description(f"{output['material'].size()}")

        confusion_meter.add(output['material'].cpu(), labels.cpu())

    with session_scope() as sess:
        materials = sess.query(models.Material).filter_by(enabled=True).all()
        mat_by_id = {m.id: m for m in materials}

    # Row/column 0 is the background class; the rest come from the snapshot
    # label mapping.
    class_names = ['background']
    class_names.extend(
        mat_by_id[label_to_mat_id[i]].name for i in range(1, num_classes))

    print(len(class_names))

    confusion_matrix = confusion_meter.value()
    confusion_logger = VisdomLogger(
        'heatmap', opts={
            'title': 'Confusion matrix',
            'columnnames': class_names,
            'rownames': class_names,
            'xtickfont': {'size': 8},
            'ytickfont': {'size': 8},
        },
        env='brdf-classifier-confusion',
        port=visdom_port)

    confusion_logger.log(confusion_matrix)
# Ejemplo n.º 3
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask, *,
                 color_binner, minc_substance, mat_by_id):
    """Classify one image/mask pair and return ranked predictions.

    Like the simpler variant, but ranks every material class, skips the
    background label, attaches substance metadata to each material entry,
    and visualizes the predicted color histogram when the model has a
    color head.

    Args:
        label_to_mat_id: Maps integer class labels to material ids.
        model: Trained RendNet3 classifier (on the GPU).
        image: Input image; floats are assumed to be in [0, 1].
        seg_mask: Segmentation mask selecting the region to classify.
        color_binner: Used to visualize the predicted color histogram.
        minc_substance: MINC substance prediction to attach to each entry.
        mat_by_id: Maps material ids to material records.

    Returns:
        Dict with a 'material' list and optional 'substance' / 'roughness'
        lists, each entry carrying a softmax 'score'.
    """
    # The transforms expect an 8-bit image.
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    # order=0 (nearest neighbor) keeps the mask values crisp; scale to 0/255.
    mask_small = skimage.transform.resize(seg_mask, (224, 224),
                                          order=0,
                                          anti_aliasing=False,
                                          mode='constant')
    mask_small = mask_small[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(input_size=224,
                                                        output_size=224,
                                                        pad=0,
                                                        to_pil=True)(image)
    mask_tensor = transforms.inference_mask_transform(input_size=224,
                                                      output_size=224,
                                                      pad=0)(mask_small)
    input_tensor = (torch.cat((image_tensor, mask_tensor), dim=0)
                    .unsqueeze(0)
                    .cuda())

    # Show the network input in Visdom for debugging.
    vis.image(utils.visualize_input({'image': input_tensor}), win='input')

    output = model.forward(input_tensor)

    if 'color' in output:
        # Visualize the softmaxed color histogram prediction.
        color_hist = color_binner.visualize(
            F.softmax(output['color'][0].cpu().detach(), dim=0))
        vis.heatmap(color_hist,
                    win='color-hist',
                    opts=dict(title='Color Histogram'))

    # Rank every material class (k = full class count).
    mat_probs = F.softmax(output['material'], dim=1)
    mat_scores, mat_labels = torch.topk(mat_probs,
                                        k=output['material'].size(1))

    topk_dict = {'material': []}

    for score, label in zip(mat_scores.squeeze().tolist(),
                            mat_labels.squeeze().tolist()):
        # Label 0 is skipped (background).
        if int(label) == 0:
            continue
        mat_id = int(label_to_mat_id[int(label)])
        material = mat_by_id[mat_id]
        topk_dict['material'].append({
            'score': score,
            'id': mat_id,
            'pred_substance': material.substance,
            'minc_substance': minc_substance,
        })

    if 'substance' in output:
        subst_probs = F.softmax(output['substance'].cpu(), dim=1)
        subst_scores, subst_labels = torch.topk(
            subst_probs, k=output['substance'].size(1))
        substance_entries = []
        for score, label in zip(subst_scores.squeeze().tolist(),
                                subst_labels.squeeze().tolist()):
            substance_entries.append({
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            })
        topk_dict['substance'] = substance_entries

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoints of nrc equal-width bins over [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        rough_probs = F.softmax(output['roughness'].cpu(), dim=1)
        rough_scores, rough_labels = torch.topk(rough_probs, k=5)
        roughness_entries = []
        for score, label in zip(rough_scores.squeeze().tolist(),
                                rough_labels.squeeze().tolist()):
            roughness_entries.append({
                'score': score,
                'value': roughness_midpoints[int(label)],
            })
        topk_dict['roughness'] = roughness_entries

    return topk_dict
# Ejemplo n.º 4
def main():
    """Train a RendNet3 material classifier on rendered snapshot data.

    Loads the snapshot metadata and material database, builds train and
    validation loaders, constructs the model (optionally resuming from a
    checkpoint), then runs the train/validate loop, checkpointing and
    plotting stats to Visdom every epoch.

    NOTE(review): relies on module-level state not visible in this chunk —
    `args` (parsed CLI options), `color_binner`, `vis` (Visdom handle),
    and the constants SHAPE / INPUT_SIZE / SUBSTANCES; confirm against
    the full file.
    """
    snapshot_dir = args.snapshot_dir
    checkpoint_dir = args.checkpoint_dir / args.model_name

    train_path = Path(snapshot_dir, 'train')
    validation_path = Path(snapshot_dir, 'validation')
    meta_path = Path(snapshot_dir, 'meta.json')

    print(f' * train_path = {train_path!s}\n'
          f' * validation_path = {validation_path!s}\n'
          f' * meta_path = {meta_path!s}\n'
          f' * checkpoint_dir = {checkpoint_dir!s}')

    with meta_path.open('r') as f:
        meta_dict = json.load(f)

    with session_scope() as sess:
        materials = sess.query(models.Material).all()
        mat_by_id = {m.id: m for m in materials}

    # JSON object keys are strings; convert the material ids back to ints.
    mat_id_to_label = {
        int(k): v
        for k, v in meta_dict['mat_id_to_label'].items()
    }
    # Group material labels by substance, then re-key by substance index
    # as GPU LongTensors (used by the substance loss during training).
    subst_mat_labels = defaultdict(list)
    for mat_id, mat_label in mat_id_to_label.items():
        material = mat_by_id[mat_id]
        subst_mat_labels[material.substance].append(mat_label)
    subst_mat_labels = {
        SUBSTANCES.index(k): torch.LongTensor(v).cuda()
        for k, v in subst_mat_labels.items()
    }
    # NOTE(review): mat_label_to_subst_label is built but not used below.
    mat_label_to_subst_label = {
        label: mat_by_id[mat_id].substance
        for mat_id, label in mat_id_to_label.items()
    }

    # Assumes labels are contiguous from 0 — TODO confirm.
    num_classes = max(mat_id_to_label.values()) + 1

    if args.roughness_loss:
        # Only the cross-entropy roughness loss uses multiple classes;
        # other roughness losses use a single output.
        num_roughness_classes = (args.num_roughness_classes if
                                 args.roughness_loss == 'cross_entropy' else 1)
        output_roughness = True
    else:
        num_roughness_classes = 0
        output_roughness = False

    output_substance = (True if args.substance_loss == 'fc' else False)

    # Constructor kwargs for RendNet3 (this name is reused for the
    # hyperparameter log dict further below).
    model_params = dict(
        num_classes=num_classes,
        num_substances=len(SUBSTANCES),
        num_roughness_classes=num_roughness_classes,
        output_roughness=output_roughness,
        output_substance=output_substance,
        output_color=color_binner is not None,
        num_color_bins=color_binner.size if color_binner else 0,
    )

    base_model_fn = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
    }[args.base_model]

    if args.from_scratch:
        base_model = base_model_fn(pretrained=False)
    else:
        base_model = base_model_fn(pretrained=True)

    model = RendNet3(
        **model_params,
        base_model=base_model,
    ).train().cuda()

    train_cum_stats = defaultdict(list)
    val_cum_stats = defaultdict(list)
    if args.resume:
        print(f" * Loading weights from {args.resume!s}")
        checkpoint = torch.load(args.resume)
        # strict=False tolerates missing/extra keys when resuming across
        # model configurations.
        model.load_state_dict(checkpoint['state_dict'], strict=False)

        # Restore cumulative stats so plots continue from the resume point.
        stats_dict_path = checkpoint_dir / 'model_train_stats.json'
        if stats_dict_path.exists():
            with stats_dict_path.open('r') as f:
                stats_dict = json.load(f)
                train_cum_stats = stats_dict['train']
                val_cum_stats = stats_dict['validation']

    print(' * Loading datasets')

    train_dataset = rendering_dataset.MaterialRendDataset(
        train_path,
        meta_dict,
        color_binner=color_binner,
        shape=SHAPE,
        lmdb_name=snapshot_dir.name,
        image_transform=transforms.train_image_transform(INPUT_SIZE, pad=0),
        mask_transform=transforms.train_mask_transform(INPUT_SIZE, pad=0),
        mask_noise_p=args.mask_noise_p)

    validation_dataset = rendering_dataset.MaterialRendDataset(
        validation_path,
        meta_dict,
        color_binner=color_binner,
        shape=SHAPE,
        lmdb_name=snapshot_dir.name,
        image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE))

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
    )

    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
    )

    # Relative weights of each loss term.
    loss_weights = {
        'material': args.material_loss_weight,
        'roughness': args.roughness_loss_weight,
        'substance': args.substance_loss_weight,
        'color': args.color_loss_weight,
    }

    # Hyperparameter log dict; the constructor kwargs built earlier are
    # nested under the 'model_params' key at the bottom.
    model_params = {
        'init_lr':
        args.init_lr,
        'lr_decay_epochs':
        args.lr_decay_epochs,
        'lr_decay_frac':
        args.lr_decay_frac,
        'momentum':
        args.momentum,
        'weight_decay':
        args.weight_decay,
        'batch_size':
        args.batch_size,
        'from_scratch':
        args.from_scratch,
        'base_model':
        args.base_model,
        'resumed_from':
        str(args.resume),
        'mask_noise_p':
        args.mask_noise_p,
        'use_substance_loss':
        args.substance_loss is not None,
        'substance_loss':
        args.substance_loss,
        'roughness_loss':
        args.roughness_loss,
        'material_variance_init':
        args.material_variance_init,
        'substance_variance_init':
        args.substance_variance_init,
        'color_variance_init':
        args.color_variance_init,
        'color_loss':
        args.color_loss,
        'num_classes':
        num_classes,
        'num_roughness_classes':
        (args.num_roughness_classes if args.roughness_loss else None),
        'use_variance':
        args.use_variance,
        'loss_weights':
        loss_weights,
        'model_params':
        model_params,
    }

    if color_binner:
        # Record color-histogram configuration alongside the other params.
        model_params = {
            **model_params,
            'color_hist_name': color_binner.name,
            'color_hist_shape': color_binner.shape,
            'color_hist_space': color_binner.space,
            'color_hist_sigma': color_binner.sigma,
        }

    vis.text(f'<pre>{json.dumps(model_params, indent=2)}</pre>', win='params')

    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    model_params_path = (checkpoint_dir / 'model_params.json')
    # Only write params once; never overwrite an existing run's params.
    if not args.dry_run and not model_params_path.exists():
        with model_params_path.open('w') as f:
            print(f' * Saving model params to {model_params_path!s}')
            json.dump(model_params, f, indent=2)

    mat_criterion = nn.CrossEntropyLoss().cuda()
    # 'from_material' derives substance log-probs from the material head,
    # so it uses NLLLoss instead of CrossEntropyLoss.
    if args.substance_loss == 'from_material':
        subst_criterion = nn.NLLLoss().cuda()
    else:
        subst_criterion = nn.CrossEntropyLoss().cuda()

    if args.color_loss == 'cross_entropy':
        color_criterion = nn.BCEWithLogitsLoss().cuda()
    elif args.color_loss == 'kl_divergence':
        color_criterion = nn.KLDivLoss().cuda()
    else:
        color_criterion = None

    # Learnable per-loss variances; optimized jointly with the model below.
    loss_variances = {
        'material':
        torch.tensor([args.material_variance_init], requires_grad=True),
        'substance':
        torch.tensor([args.substance_variance_init], requires_grad=True),
        'color':
        torch.tensor([args.color_variance_init], requires_grad=True),
    }

    optimizer = optim.SGD([*model.parameters(), *loss_variances.values()],
                          args.init_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    best_prec1 = 0
    pbar = tqdm(total=args.epochs, dynamic_ncols=True)
    for epoch in range(args.start_epoch, args.epochs):
        pbar.set_description(f'Epoch {epoch}')
        decay_learning_rate(optimizer, epoch, args.init_lr,
                            args.lr_decay_epochs, args.lr_decay_frac)

        train_stats = train_epoch(
            train_loader,
            model,
            epoch,
            mat_criterion,
            subst_criterion,
            color_criterion,
            optimizer,
            subst_mat_labels,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # evaluate on validation set
        val_stats = validate_epoch(
            validation_loader,
            model,
            epoch,
            mat_criterion,
            subst_criterion,
            color_criterion,
            subst_mat_labels,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # remember best prec@1 and save checkpoint
        is_best = val_stats['mat_prec1'] > best_prec1
        best_prec1 = max(val_stats['mat_prec1'], best_prec1)

        for stat_name, stat_val in train_stats.items():
            train_cum_stats[stat_name].append(stat_val)

        for stat_name, stat_val in val_stats.items():
            val_cum_stats[stat_name].append(stat_val)

        stats_dict = {
            'epoch': epoch + 1,
            'train': train_cum_stats,
            'validation': val_cum_stats,
        }

        if not args.dry_run:
            with (checkpoint_dir / 'model_train_stats.json').open('w') as f:
                json.dump(stats_dict, f)

            # NaN stats are dropped since JSON cannot represent them.
            if is_best:
                with (checkpoint_dir / 'model_best_stats.json').open('w') as f:
                    json.dump(
                        {
                            'epoch': epoch + 1,
                            'train': {
                                k: v
                                for k, v in train_stats.items()
                                if not math.isnan(v)
                            },
                            'validation': {
                                k: v
                                for k, v in val_stats.items()
                                if not math.isnan(v)
                            },
                        }, f)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_name': args.model_name,
                    'state_dict': model.state_dict(),
                    'loss_variances': loss_variances,
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                    'params': model_params,
                },
                checkpoint_dir=checkpoint_dir,
                filename=f'model.{args.model_name}.epoch{epoch:03d}.pth.tar',
                is_best=is_best)

        # Plot graphs.
        for stat_name, train_stat_vals in train_cum_stats.items():
            # Only plot stats that exist for both train and validation.
            if stat_name not in val_cum_stats:
                continue
            val_stat_vals = val_cum_stats[stat_name]
            vis.line(Y=np.column_stack((train_stat_vals, val_stat_vals)),
                     X=np.array(list(range(0, len(train_stat_vals)))),
                     win=f'plot-{stat_name}',
                     name=stat_name,
                     opts={
                         'legend': ['training', 'validation'],
                         'title': stat_name,
                     })
# Ejemplo n.º 5
def main():
    """Train a RendNet3 substance/color classifier on OpenSurfaces data.

    Loads the train/validation photo split, builds augmented training and
    plain validation loaders, constructs the model (material and roughness
    heads disabled), then runs the train/validate loop, checkpointing and
    plotting stats to Visdom every epoch.

    NOTE(review): relies on module-level state not visible in this chunk —
    `args` (parsed CLI options), `color_binner`, `vis` (Visdom handle), and
    the constants INPUT_SIZE / CROP_RANGE / MAX_ROTATION /
    CROPPED_CROP_RANGE / CROPPED_MAX_ROTATION / SUBSTANCES; confirm
    against the full file.
    """
    checkpoint_dir = args.checkpoint_dir / args.model_name
    split_path = args.opensurfaces_dir / args.split_file

    with split_path.open('r') as f:
        split_dict = json.load(f)

    print(f' * opensurface_dir = {args.opensurfaces_dir!s}\n'
          f' * split_path = {split_path!s}\n'
          f' * checkpoint_dir = {checkpoint_dir!s}')

    print(' * Loading datasets')

    output_substance = (True if args.substance_loss == 'fc' else False)
    output_color = color_binner is not None

    # Training uses heavy augmentation (crop/rotation/color jitter);
    # the cropped_* transforms pad before augmenting.
    train_dataset = dataset.OpenSurfacesDataset(
        base_dir=args.opensurfaces_dir,
        color_binner=color_binner,
        photo_ids=split_dict['train'],
        image_transform=transforms.train_image_transform(
            INPUT_SIZE,
            crop_scales=CROP_RANGE,
            max_rotation=MAX_ROTATION,
            max_brightness_jitter=0.2,
            max_contrast_jitter=0.2,
            max_saturation_jitter=0.2,
            max_hue_jitter=0.1,
        ),
        mask_transform=transforms.train_mask_transform(
            INPUT_SIZE, crop_scales=CROP_RANGE, max_rotation=MAX_ROTATION),
        cropped_image_transform=transforms.train_image_transform(
            INPUT_SIZE,
            pad=200,
            crop_scales=CROPPED_CROP_RANGE,
            max_rotation=CROPPED_MAX_ROTATION,
            max_brightness_jitter=0.2,
            max_contrast_jitter=0.2,
            max_saturation_jitter=0.2,
            max_hue_jitter=0.1,
        ),
        cropped_mask_transform=transforms.train_mask_transform(
            INPUT_SIZE,
            pad=200,
            crop_scales=CROPPED_CROP_RANGE,
            max_rotation=CROPPED_MAX_ROTATION),
        use_cropped=args.use_cropped,
        p_cropped=args.p_cropped,
    )

    # Validation uses deterministic inference transforms only.
    validation_dataset = dataset.OpenSurfacesDataset(
        base_dir=args.opensurfaces_dir,
        color_binner=color_binner,
        photo_ids=split_dict['validation'],
        image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE),
        cropped_image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        cropped_mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE),
        use_cropped=args.use_cropped,
        p_cropped=args.p_cropped,
    )

    # Constructor kwargs for RendNet3 (this name is reused for the
    # hyperparameter log dict further below). Material and roughness
    # heads are disabled for this task.
    model_params = dict(
        num_substances=len(SUBSTANCES),
        output_material=False,
        output_roughness=False,
        output_substance=output_substance,
        output_color=output_color,
        num_color_bins=color_binner.size if color_binner else 0,
    )

    base_model_fn = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
    }[args.base_model]

    if args.from_scratch:
        base_model = base_model_fn(pretrained=False)
    else:
        base_model = base_model_fn(pretrained=True)

    model = RendNet3(**model_params, base_model=base_model).train().cuda()

    train_cum_stats = defaultdict(list)
    val_cum_stats = defaultdict(list)
    if args.resume:
        print(f" * Loading weights from {args.resume!s}")
        checkpoint = torch.load(args.resume)
        # strict=False tolerates missing/extra keys when resuming across
        # model configurations.
        model.load_state_dict(checkpoint['state_dict'], strict=False)

        # Restore cumulative stats so plots continue from the resume point.
        stats_dict_path = checkpoint_dir / 'model_train_stats.json'
        if stats_dict_path.exists():
            with stats_dict_path.open('r') as f:
                stats_dict = json.load(f)
                train_cum_stats = stats_dict['train']
                val_cum_stats = stats_dict['validation']

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
    )

    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
    )

    # Relative weights of each loss term.
    loss_weights = {
        # 'material': args.material_loss_weight,
        # 'roughness': args.roughness_loss_weight,
        'substance': args.substance_loss_weight,
        'color': args.color_loss_weight,
    }

    # Hyperparameter log dict; the constructor kwargs built earlier are
    # nested under the 'model_params' key at the bottom.
    model_params = {
        'init_lr': args.init_lr,
        'lr_decay_epochs': args.lr_decay_epochs,
        'lr_decay_frac': args.lr_decay_frac,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'from_scratch': args.from_scratch,
        'base_model': args.base_model,
        'resumed_from': str(args.resume),
        'use_substance_loss': args.substance_loss is not None,
        'substance_loss': args.substance_loss,
        'color_loss': args.color_loss,
        'color_hist_name': args.color_hist_name,
        'use_variance': args.use_variance,
        'loss_weights': loss_weights,
        'use_cropped': args.use_cropped,
        'p_cropped': args.p_cropped,
        'model_params': model_params,
    }

    if color_binner:
        # Overrides 'color_hist_name' from args with the binner's own
        # config, and records the rest of the histogram parameters.
        model_params = {
            **model_params,
            'color_hist_name': color_binner.name,
            'color_hist_shape': color_binner.shape,
            'color_hist_space': color_binner.space,
            'color_hist_sigma': color_binner.sigma,
        }

    vis.text(f'<pre>{json.dumps(model_params, indent=2)}</pre>', win='params')

    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    model_params_path = (checkpoint_dir / 'model_params.json')
    # Only write params once; never overwrite an existing run's params.
    if not model_params_path.exists():
        with model_params_path.open('w') as f:
            print(f' * Saving model params to {model_params_path!s}')
            json.dump(model_params, f, indent=2)

    subst_criterion = nn.CrossEntropyLoss().cuda()

    if args.color_loss == 'cross_entropy':
        color_criterion = nn.BCEWithLogitsLoss().cuda()
    elif args.color_loss == 'kl_divergence':
        color_criterion = nn.KLDivLoss().cuda()
    else:
        color_criterion = None
    # Learnable per-loss variances; optimized jointly with the model below.
    loss_variances = {
        'substance':
        torch.tensor([args.substance_variance_init], requires_grad=True),
        'color':
        torch.tensor([args.color_variance_init], requires_grad=True),
    }

    optimizer = optim.SGD([*model.parameters(), *loss_variances.values()],
                          args.init_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # When resuming, seed best_prec1 from the restored validation history.
    if 'subst_fc_prec1' in val_cum_stats:
        best_prec1 = max(val_cum_stats['subst_fc_prec1'])
    else:
        best_prec1 = 0

    pbar = tqdm(total=args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        pbar.set_description(f'Epoch {epoch}')
        decay_learning_rate(optimizer, epoch, args.init_lr,
                            args.lr_decay_epochs, args.lr_decay_frac)

        train_stats = train_epoch(
            train_loader,
            model,
            epoch,
            subst_criterion,
            color_criterion,
            optimizer,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # evaluate on validation set
        val_stats = validate_epoch(
            validation_loader,
            model,
            epoch,
            subst_criterion,
            color_criterion,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # remember best prec@1 and save checkpoint
        is_best = val_stats['subst_fc_prec1'] > best_prec1
        best_prec1 = max(val_stats['subst_fc_prec1'], best_prec1)

        for stat_name, stat_val in train_stats.items():
            train_cum_stats[stat_name].append(stat_val)

        for stat_name, stat_val in val_stats.items():
            val_cum_stats[stat_name].append(stat_val)

        stats_dict = {
            'epoch': epoch + 1,
            'train': train_cum_stats,
            'validation': val_cum_stats,
        }

        with (checkpoint_dir / 'model_train_stats.json').open('w') as f:
            json.dump(stats_dict, f)

        # NaN stats are dropped since JSON cannot represent them.
        if is_best:
            with (checkpoint_dir / 'model_best_stats.json').open('w') as f:
                json.dump(
                    {
                        'epoch': epoch + 1,
                        'train': {
                            k: v
                            for k, v in train_stats.items()
                            if not math.isnan(v)
                        },
                        'validation': {
                            k: v
                            for k, v in val_stats.items() if not math.isnan(v)
                        },
                    }, f)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model_name': args.model_name,
                'state_dict': model.state_dict(),
                'loss_variances': loss_variances,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'params': model_params,
            },
            checkpoint_dir=checkpoint_dir,
            filename=f'model.{args.model_name}.epoch{epoch:03d}.pth.tar',
            is_best=is_best)

        # Plot graphs.
        for stat_name, train_stat_vals in train_cum_stats.items():
            # Only plot stats that exist for both train and validation.
            if stat_name not in val_cum_stats:
                continue
            tqdm.write(f"Plotting {stat_name}")
            val_stat_vals = val_cum_stats[stat_name]
            vis.line(Y=np.column_stack((train_stat_vals, val_stat_vals)),
                     X=np.array(list(range(0, len(train_stat_vals)))),
                     win=f'plot-{stat_name}',
                     name=stat_name,
                     opts={
                         'legend': ['training', 'validation'],
                         'title': stat_name,
                     })