Ejemplo n.º 1
0
def main():
    """Load the training snapshot as an LMDB-backed dataset and stream
    batches to visdom.

    Relies on module-level ``args`` (with ``snapshot_dir``), ``vis``,
    ``INPUT_SIZE`` and the project's ``rendering_dataset``/``transforms``
    helpers.
    """
    train_path = Path(args.snapshot_dir, 'train')
    meta_path = Path(args.snapshot_dir, 'meta.json')

    # Dataset metadata (label maps etc.) lives alongside the LMDB files.
    with meta_path.open('r') as f:
        meta_dict = json.load(f)

    # Plain strings: the f-prefixes carried no placeholders.
    print("Opening LMDB datasets.")
    train_dataset = rendering_dataset.MaterialRendDataset(
        train_path,
        meta_dict,
        shape=(500, 500),
        image_transform=transforms.train_image_transform(INPUT_SIZE),
        mask_transform=transforms.train_mask_transform(INPUT_SIZE))

    # batch_size=1 / single worker: this is a visualization loop, not training.
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              num_workers=1,
                              shuffle=True,
                              pin_memory=False,
                              collate_fn=rendering_dataset.collate_fn)

    print("Woohoo")
    for batch in train_loader:
        vis.image(visualize_input(batch))
Ejemplo n.º 2
0
def main():
    """Open the training snapshot's LMDB environment and visualize batches.

    Relies on module-level ``args`` (with ``snapshot_dir``), ``vis``,
    ``json`` and ``lmdb``.

    NOTE(review): this example appears truncated/broken — see the BUG
    comment below.
    """
    train_path = Path(args.snapshot_dir, 'train')
    # NOTE(review): validation_path is computed but never used in this body.
    validation_path = Path(args.snapshot_dir, 'validation')
    meta_path = Path(args.snapshot_dir, 'meta.json')

    with meta_path.open('r') as f:
        meta_dict = json.load(f)  # NOTE(review): loaded but never used below.

    # NOTE(review): older lmdb versions require a str path, not a Path —
    # consider str(train_path); TODO confirm against the pinned lmdb version.
    env = lmdb.open(train_path, readonly=False)

    print(f"Woohoo")
    # BUG(review): `train_loader` is never defined in this function, so this
    # loop raises NameError at runtime. Presumably a DataLoader should be
    # constructed from the LMDB env first (compare the sibling example that
    # builds one via rendering_dataset.MaterialRendDataset).
    for batch in train_loader:
        vis.image(visualize_input(batch))
Ejemplo n.º 3
0
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask):
    """Classify a masked image region and return ranked per-head predictions.

    Args:
        label_to_mat_id: mapping from material class label -> material id.
        model: trained RendNet3 exposing a 'material' output head and,
            optionally, 'substance' and 'roughness' heads.
        image: HxWx3 image; uint8, or float assumed in [0, 1].
        seg_mask: segmentation mask for the region of interest.

    Returns:
        Dict with a 'material' list of {'score', 'id'} (descending score)
        and, when the model provides them, 'substance' and 'roughness' lists.
    """
    if image.dtype != np.uint8:
        # Assumes float images are scaled to [0, 1] — TODO confirm callers.
        image = (image * 255).astype(np.uint8)

    # order=0 (nearest neighbor) keeps the mask binary through the resize.
    seg_mask = skimage.transform.resize(seg_mask, (224, 224),
                                        order=0,
                                        anti_aliasing=False,
                                        mode='reflect')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(224)(image)
    mask_tensor = transforms.inference_mask_transform(224)(seg_mask)
    # Stack image and mask channels into a single 1x4x224x224 batch.
    input_tensor = torch.cat((image_tensor, mask_tensor), dim=0).unsqueeze(0)

    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    output = model(input_tensor.cuda())

    # Clamp k: torch.topk raises if k exceeds the number of classes.
    topk_mat_scores, topk_mat_labels = torch.topk(
        F.softmax(output['material'], dim=1),
        k=min(10, output['material'].size(1)))

    topk_dict = {
        'material': [{
            'score': score,
            'id': label_to_mat_id[int(label)],
        } for score, label in zip(topk_mat_scores.squeeze().tolist(),
                                  topk_mat_labels.squeeze().tolist())]
    }

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            } for score, label in zip(topk_subst_scores.squeeze().tolist(),
                                      topk_subst_labels.squeeze().tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoints of nrc equal-width bins on [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        # Clamp k so fewer than 5 roughness classes cannot crash topk.
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1),
            k=min(5, output['roughness'].size(1)))
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            } for score, label in zip(topk_roughness_scores.squeeze().tolist(),
                                      topk_roughness_labels.squeeze().tolist())
        ]

    return topk_dict
Ejemplo n.º 4
0
def validate_epoch(val_loader, model, epoch, mat_criterion, subst_criterion,
                   color_criterion, subst_mat_labels, loss_variances,
                   loss_weights):
    """Run one validation pass and return the mean of every tracked meter.

    Mirrors ``train_epoch`` but takes no optimizer step and evaluates the
    model under ``torch.no_grad()`` so no autograd graph is retained.

    Returns:
        Dict mapping meter name -> mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # Switch to evaluate mode (freezes dropout / batch-norm statistics).
    model.eval()

    pbar = tqdm(total=len(val_loader),
                desc='Validating Epoch',
                dynamic_ncols=True)

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(val_loader):
        input_var = batch_dict['image'].cuda()
        mat_labels = batch_dict['material_label'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        # Gradients are not needed for validation; disabling autograd saves
        # memory and time (model.eval() alone does not do this).
        with torch.no_grad():
            output = model(input_var)

            mat_output = output['material']

            losses = {}

            if args.substance_loss is not None:
                if args.substance_loss == 'fc':
                    subst_output = output['substance']
                    subst_fc_prec1 = compute_precision(subst_output,
                                                       subst_labels)
                    meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
                elif args.substance_loss == 'from_material':
                    subst_output = material_to_substance_output(
                        mat_output, subst_mat_labels)
                else:
                    raise ValueError('Invalid value for substance_loss')
                losses['substance'] = subst_criterion(subst_output,
                                                      subst_labels)

            mat_subst_output = material_to_substance_output(
                mat_output, subst_mat_labels)

            subst_prec1 = compute_precision(mat_subst_output, subst_labels)
            meter_dict['subst_from_mat_prec1'].add(subst_prec1)

            if args.roughness_loss:
                roughness_output = output['roughness']
                losses['roughness'], roughness_prec1 = compute_roughness_loss(
                    batch_dict['roughness'], roughness_output, mat_criterion)
                meter_dict['roughness_prec1'].add(roughness_prec1)

            if args.color_loss is not None:
                color_output = output['color']
                color_target = batch_dict['color_hist'].cuda()
                if isinstance(color_criterion, nn.KLDivLoss):
                    # KLDivLoss expects log-probabilities; dim=1 is the bin
                    # dimension (previously implicit, which is deprecated).
                    color_output = F.log_softmax(color_output, dim=1)
                losses['color'] = color_criterion(color_output, color_target)
            else:
                color_output = None
                color_target = None

            losses['material'] = mat_criterion(mat_output, mat_labels)
            loss = combine_losses(losses, loss_weights, loss_variances,
                                  args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        # Measure accuracy.
        mat_prec1, mat_prec5 = compute_precision(mat_output,
                                                 mat_labels,
                                                 topk=(1, 5))
        meter_dict['mat_prec1'].add(mat_prec1)
        meter_dict['mat_prec5'].add(mat_prec5)
        pbar.update()

        if batch_idx % args.show_freq == 0:
            vis.image(visualize_input(batch_dict),
                      win='validation-batch-example',
                      opts=dict(title='Validation Batch Example'))
            if 'color' in losses:
                color_output_vis = color_binner.visualize(
                    F.softmax(color_output[0].cpu().detach(), dim=0))
                color_target_vis = color_binner.visualize(color_target[0])
                color_hist_vis = np.vstack(
                    (color_output_vis, color_target_vis))
                vis.heatmap(color_hist_vis,
                            win='validation-color-hist',
                            opts=dict(title='Validation Color Histogram'))
        meters = [(k, v) for k, v in meter_dict.items()]
        meter_table = MeterTable(
            f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}", meters)

        vis.text(meter_table.render(),
                 win='validation-status',
                 opts=dict(title='Validation Status'))

        # Measure elapsed time exactly once per batch (the original added
        # batch_time twice per iteration, double-counting it).
        meter_dict['batch_time'].add(time.time() - last_end_time)
        last_end_time = time.time()

    pbar.close()
    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
Ejemplo n.º 5
0
def train_epoch(train_loader, model, epoch, mat_criterion, subst_criterion,
                color_criterion, optimizer, subst_mat_labels, loss_variances,
                loss_weights):
    """Run one training epoch and return the mean of every tracked meter.

    Computes material / substance / roughness / color losses (as enabled by
    ``args``), combines them via ``combine_losses`` and takes one optimizer
    step per batch. Periodically pushes visualizations and a meter table to
    visdom.

    Returns:
        Dict mapping meter name -> mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)
    pbar = tqdm(total=len(train_loader),
                desc='Training Epoch',
                dynamic_ncols=True)

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        mat_labels = batch_dict['material_label'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model(input_var)

        mat_output = output['material']

        losses = {}

        if args.substance_loss is not None:
            if args.substance_loss == 'fc':
                # Dedicated FC substance head.
                subst_output = output['substance']
                subst_fc_prec1 = compute_precision(subst_output, subst_labels)
                meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            elif args.substance_loss == 'from_material':
                # Derive substance scores by pooling material scores.
                subst_output = material_to_substance_output(
                    mat_output, subst_mat_labels)
            else:
                raise ValueError('Invalid value for substance_loss')
            losses['substance'] = subst_criterion(subst_output, subst_labels)

        # Always track substance-from-material precision, even when the
        # substance loss itself is disabled.
        mat_subst_output = material_to_substance_output(
            mat_output, subst_mat_labels)

        mat_subst_prec1 = compute_precision(mat_subst_output, subst_labels)
        meter_dict['subst_from_mat_prec1'].add(mat_subst_prec1)

        if args.roughness_loss is not None:
            roughness_output = output['roughness']
            losses['roughness'], roughness_prec1 = compute_roughness_loss(
                batch_dict['roughness'], roughness_output, mat_criterion)
            # NOTE(review): validate_epoch adds roughness_prec1 directly while
            # this indexes [0] — presumably compute_roughness_loss returns a
            # per-k sequence here; confirm the two paths agree.
            meter_dict['roughness_prec1'].add(roughness_prec1[0])

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; dim=1 is the bin
                # dimension (previously implicit, which is deprecated).
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        losses['material'] = mat_criterion(mat_output, mat_labels)

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        mat_prec1, mat_prec5 = compute_precision(mat_output,
                                                 mat_labels,
                                                 topk=(1, 5))
        meter_dict['mat_prec1'].add(mat_prec1)
        meter_dict['mat_prec5'].add(mat_prec5)
        # Track the learned homoscedastic-uncertainty terms per loss.
        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # Compute gradient and take one SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))
            # Visualize color prediction.
            if 'color' in losses:
                color_output_vis = color_binner.visualize(
                    F.softmax(color_output[0].cpu().detach(), dim=0))
                color_target_vis = color_binner.visualize(color_target[0])
                color_hist_vis = np.vstack(
                    (color_output_vis, color_target_vis))
                vis.heatmap(color_hist_vis,
                            win='train-color-hist',
                            opts=dict(title='Training Color Histogram'))
            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))

    pbar.close()
    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}

    return mean_stat_dict
Ejemplo n.º 6
0
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask, *,
                 color_binner, minc_substance, mat_by_id):
    """Classify a masked image region and return ranked per-head predictions.

    Args:
        label_to_mat_id: mapping from material class label -> material id.
        model: trained RendNet3 with a 'material' head and optional
            'substance', 'roughness' and 'color' heads.
        image: HxWx3 image; uint8, or float assumed in [0, 1].
        seg_mask: segmentation mask for the region of interest.
        color_binner: helper used to visualize predicted color histograms.
        minc_substance: MINC substance name recorded alongside each
            material prediction.
        mat_by_id: mapping from material id -> material record exposing a
            ``substance`` attribute.

    Returns:
        Dict with 'material' ranked over all classes (background label 0
        skipped) and, when present, 'substance' and 'roughness' rankings.
    """
    if image.dtype != np.uint8:
        # Assumes float images are scaled to [0, 1] — TODO confirm callers.
        image = (image * 255).astype(np.uint8)

    # order=0 (nearest neighbor) keeps the mask binary through the resize.
    seg_mask = skimage.transform.resize(seg_mask, (224, 224),
                                        order=0,
                                        anti_aliasing=False,
                                        mode='constant')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(input_size=224,
                                                        output_size=224,
                                                        pad=0,
                                                        to_pil=True)(image)
    mask_tensor = transforms.inference_mask_transform(input_size=224,
                                                      output_size=224,
                                                      pad=0)(seg_mask)
    # Stack image and mask channels into a single 1x4x224x224 batch on GPU.
    input_tensor = (torch.cat((image_tensor, mask_tensor),
                              dim=0).unsqueeze(0).cuda())

    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    output = model(input_tensor)

    if 'color' in output:
        color_output = output['color']
        color_hist_vis = color_binner.visualize(
            F.softmax(color_output[0].cpu().detach(), dim=0))
        vis.heatmap(color_hist_vis,
                    win='color-hist',
                    opts=dict(title='Color Histogram'))

    # Rank over all material classes (k = class count, so no clamp needed).
    topk_mat_scores, topk_mat_labels = torch.topk(F.softmax(output['material'],
                                                            dim=1),
                                                  k=output['material'].size(1))

    topk_dict = {'material': list()}

    for score, label in zip(topk_mat_scores.squeeze().tolist(),
                            topk_mat_labels.squeeze().tolist()):
        # Label 0 is treated as background/invalid and skipped.
        if int(label) == 0:
            continue
        mat_id = int(label_to_mat_id[int(label)])
        material = mat_by_id[mat_id]
        topk_dict['material'].append({
            'score': score,
            'id': mat_id,
            'pred_substance': material.substance,
            'minc_substance': minc_substance,
        })

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            } for score, label in zip(topk_subst_scores.squeeze().tolist(),
                                      topk_subst_labels.squeeze().tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoints of nrc equal-width bins on [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        # Clamp k: torch.topk raises when k exceeds the class count.
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1),
            k=min(5, output['roughness'].size(1)))
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            } for score, label in zip(topk_roughness_scores.squeeze().tolist(),
                                      topk_roughness_labels.squeeze().tolist())
        ]

    return topk_dict
Ejemplo n.º 7
0
def validate_epoch(val_loader, model, epoch, subst_criterion, color_criterion,
                   loss_variances, loss_weights):
    """Run one validation pass (substance + optional color) and return the
    mean of every tracked meter.

    Takes no optimizer step and evaluates the model under
    ``torch.no_grad()`` so no autograd graph is retained.

    Returns:
        Dict mapping meter name -> mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # Switch to evaluate mode (freezes dropout / batch-norm statistics).
    model.eval()

    pbar = tqdm(total=len(val_loader), desc='Validating Epoch')

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(val_loader):
        input_var = batch_dict['image'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        # Gradients are not needed for validation; disabling autograd saves
        # memory and time (model.eval() alone does not do this).
        with torch.no_grad():
            output = model(input_var)

            losses = {}

            subst_output = output['substance']
            subst_fc_prec1 = compute_precision(subst_output, subst_labels)
            meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            losses['substance'] = subst_criterion(subst_output, subst_labels)

            if args.color_loss is not None:
                color_output = output['color']
                color_target = batch_dict['color_hist'].cuda()
                if isinstance(color_criterion, nn.KLDivLoss):
                    # KLDivLoss expects log-probabilities; dim=1 is the bin
                    # dimension (previously implicit, which is deprecated).
                    color_output = F.log_softmax(color_output, dim=1)
                losses['color'] = color_criterion(color_output, color_target)
            else:
                color_output = None
                color_target = None

            loss = combine_losses(losses, loss_weights, loss_variances,
                                  args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        pbar.update()

        if batch_idx % args.show_freq == 0:
            vis.image(visualize_input(batch_dict),
                      win='validation-batch-example',
                      opts=dict(title='Validation Batch Example'))
            if 'color' in losses:
                visualize_color_hist_output(
                    color_output[0],
                    color_target[0],
                    'validation-color-hist',
                    'Validation Color Histogram',
                    color_hist_shape=color_binner.shape,
                    color_hist_space=color_binner.space,
                )
        meters = [(k, v) for k, v in meter_dict.items()]
        meter_table = MeterTable(
            f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}", meters)

        vis.text(meter_table.render(),
                 win='validation-status',
                 opts=dict(title='Validation Status'))

        # Measure elapsed time exactly once per batch (the original added
        # batch_time twice per iteration, double-counting it).
        meter_dict['batch_time'].add(time.time() - last_end_time)
        last_end_time = time.time()

    pbar.close()
    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
Ejemplo n.º 8
0
def train_epoch(train_loader, model, epoch, subst_criterion, color_criterion,
                optimizer, loss_variances, loss_weights):
    """Run one training epoch (substance + optional color losses) and return
    the mean of every tracked meter.

    Combines the per-head losses via ``combine_losses`` and takes one
    optimizer step per batch; periodically pushes visualizations and a
    meter table to visdom.

    Returns:
        Dict mapping meter name -> mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)
    pbar = tqdm(total=len(train_loader), desc='Training Epoch')

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model(input_var)

        losses = {}

        subst_output = output['substance']
        subst_fc_prec1 = compute_precision(subst_output, subst_labels)
        meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
        losses['substance'] = subst_criterion(subst_output, subst_labels)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; dim=1 is the bin
                # dimension (previously implicit, which is deprecated).
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        # Track the learned homoscedastic-uncertainty terms per loss.
        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # Compute gradient and take one SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))
            # Visualize color prediction.
            if 'color' in losses:
                visualize_color_hist_output(
                    color_output[0],
                    color_target[0],
                    win='train-color-hist',
                    title='Training Color Histogram',
                    color_hist_shape=color_binner.shape,
                    color_hist_space=color_binner.space,
                )
            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))

    pbar.close()
    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}

    return mean_stat_dict