Example #1
def compute_roughness_loss(roughness, roughness_output, criterion):
    """Compute the roughness loss and a matching precision/error metric."""
    roughness = roughness.float()
    roughness_prec1 = 0
    if args.roughness_loss == 'cross_entropy':
        # Discretize the (assumed [0, 1]) roughness values into class labels.
        roughness_labels = (
            (roughness *
             (args.num_roughness_classes - 1)).floor().long().cuda())
        roughness_loss = criterion(roughness_output, roughness_labels)
        roughness_prec1 = compute_precision(roughness_output, roughness_labels)
    elif args.roughness_loss == 'mse':
        roughness_loss = (roughness.cuda() - roughness_output).pow(2).mean()
        roughness_prec1 = roughness_loss.item()
    else:
        roughness_loss = 0

    return roughness_loss, roughness_prec1
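
The `compute_precision` helper used above is not defined in these examples. A minimal sketch of a standard top-k precision helper consistent with how it is called here (a single value for one k, one value per k otherwise) is shown below; it is an assumption, not the original implementation.

import torch

def compute_precision(output, target, topk=(1,)):
    # Hypothetical top-k precision (%) helper, reconstructed from the call
    # sites in these examples; the real implementation may differ.
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    results = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum()
        results.append((correct_k * 100.0 / target.size(0)).item())
    return results[0] if len(topk) == 1 else results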
Example #2
def validate_epoch(val_loader, model, epoch, mat_criterion, subst_criterion,
                   color_criterion, subst_mat_labels, loss_variances,
                   loss_weights):
    """Run one validation epoch and return the mean of each logged meter."""
    meter_dict = defaultdict(AverageValueMeter)

    # switch to evaluate mode
    model.eval()

    pbar = tqdm(total=len(val_loader),
                desc='Validating Epoch',
                dynamic_ncols=True)

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(val_loader):
        input_var = batch_dict['image'].cuda()
        mat_labels = batch_dict['material_label'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        output = model.forward(input_var)

        mat_output = output['material']

        losses = {}

        if args.substance_loss is not None:
            if args.substance_loss == 'fc':
                subst_output = output['substance']
                subst_fc_prec1 = compute_precision(subst_output, subst_labels)
                meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            elif args.substance_loss == 'from_material':
                subst_output = material_to_substance_output(
                    mat_output, subst_mat_labels)
            else:
                raise ValueError('Invalid value for substance_loss')
            losses['substance'] = subst_criterion(subst_output, subst_labels)

        mat_subst_output = material_to_substance_output(
            mat_output, subst_mat_labels)

        subst_prec1 = compute_precision(mat_subst_output, subst_labels)
        meter_dict['subst_from_mat_prec1'].add(subst_prec1)

        if args.roughness_loss:
            roughness_output = output['roughness']
            losses['roughness'], roughness_prec1 = compute_roughness_loss(
                batch_dict['roughness'], roughness_output, mat_criterion)
            meter_dict['roughness_prec1'].add(roughness_prec1)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; normalize over the
                # histogram bins.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        losses['material'] = mat_criterion(mat_output, mat_labels)
        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        # Measure and record material accuracy.
        mat_prec1, mat_prec5 = compute_precision(mat_output,
                                                 mat_labels,
                                                 topk=(1, 5))
        meter_dict['mat_prec1'].add(mat_prec1)
        meter_dict['mat_prec5'].add(mat_prec5)
        meter_dict['batch_time'].add(time.time() - last_end_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            vis.image(visualize_input(batch_dict),
                      win='validation-batch-example',
                      opts=dict(title='Validation Batch Example'))
            if 'color' in losses:
                color_output_vis = color_binner.visualize(
                    F.softmax(color_output[0].cpu().detach(), dim=0))
                color_target_vis = color_binner.visualize(color_target[0])
                color_hist_vis = np.vstack(
                    (color_output_vis, color_target_vis))
                vis.heatmap(color_hist_vis,
                            win='validation-color-hist',
                            opts=dict(title='Validation Color Histogram'))
        meters = [(k, v) for k, v in meter_dict.items()]
        meter_table = MeterTable(
            f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}", meters)

        vis.text(meter_table.render(),
                 win='validation-status',
                 opts=dict(title='Validation Status'))

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
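
`material_to_substance_output` is likewise assumed but not shown. One plausible sketch, assuming `subst_mat_labels` maps each material class index to its substance class index and that a substance's score is the maximum score over its materials (the actual aggregation may differ):

import torch

def material_to_substance_output(mat_output, subst_mat_labels):
    # Hypothetical sketch: derive substance scores from material scores.
    # Assumes subst_mat_labels[i] holds the substance index of material i
    # and lives on the same device as mat_output.
    num_substances = int(subst_mat_labels.max().item()) + 1
    subst_output = mat_output.new_full(
        (mat_output.size(0), num_substances), float('-inf'))
    for subst_idx in range(num_substances):
        mask = subst_mat_labels == subst_idx
        if mask.any():
            subst_output[:, subst_idx] = mat_output[:, mask].max(dim=1)[0]
    return subst_output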
Example #3
def train_epoch(train_loader, model, epoch, mat_criterion, subst_criterion,
                color_criterion, optimizer, subst_mat_labels, loss_variances,
                loss_weights):
    """Run one training epoch and return the mean of each logged meter."""
    meter_dict = defaultdict(AverageValueMeter)
    pbar = tqdm(total=len(train_loader),
                desc='Training Epoch',
                dynamic_ncols=True)

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        mat_labels = batch_dict['material_label'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model.forward(input_var)

        mat_output = output['material']

        losses = {}

        if args.substance_loss is not None:
            if args.substance_loss == 'fc':
                subst_output = output['substance']
                subst_fc_prec1 = compute_precision(subst_output, subst_labels)
                meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            elif args.substance_loss == 'from_material':
                subst_output = material_to_substance_output(
                    mat_output, subst_mat_labels)
            else:
                raise ValueError('Invalid value for substance_loss')
            losses['substance'] = subst_criterion(subst_output, subst_labels)

        mat_subst_output = material_to_substance_output(
            mat_output, subst_mat_labels)

        mat_subst_prec1 = compute_precision(mat_subst_output, subst_labels)
        meter_dict['subst_from_mat_prec1'].add(mat_subst_prec1)

        if args.roughness_loss is not None:
            roughness_output = output['roughness']
            losses['roughness'], roughness_prec1 = compute_roughness_loss(
                batch_dict['roughness'], roughness_output, mat_criterion)
            meter_dict['roughness_prec1'].add(roughness_prec1)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; normalize over the
                # histogram bins.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        losses['material'] = mat_criterion(mat_output, mat_labels)

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        mat_prec1, mat_prec5 = compute_precision(mat_output,
                                                 mat_labels,
                                                 topk=(1, 5))
        meter_dict['mat_prec1'].add(mat_prec1)
        meter_dict['mat_prec5'].add(mat_prec5)
        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))
            # Visualize color prediction.
            if 'color' in losses:
                color_output_vis = color_binner.visualize(
                    F.softmax(color_output[0].cpu().detach(), dim=0))
                color_target_vis = color_binner.visualize(color_target[0])
                color_hist_vis = np.vstack(
                    (color_output_vis, color_target_vis))
                vis.heatmap(color_hist_vis,
                            win='train-color-hist',
                            opts=dict(title='Training Color Histogram'))
            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}

    return mean_stat_dict
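
`combine_losses` and the `loss_variances` parameters are also left undefined in these examples. The meter names `{loss_name}_variance(s_hat)` and `{loss_name}_variance(exp[-s_hat])` suggest learned log-variance (uncertainty) weighting of the per-task losses; the sketch below assumes that interpretation, with fixed `loss_weights` as the fallback when `use_variance` is off.

import torch

def combine_losses(losses, loss_weights, loss_variances, use_variance):
    # Hypothetical sketch: reduce the per-task losses to a single scalar.
    # With use_variance, each task loss is scaled by exp(-s_hat) for a
    # learned log-variance s_hat plus an s_hat penalty (uncertainty
    # weighting); otherwise fixed per-task weights are applied.
    total = 0
    for name, task_loss in losses.items():
        if use_variance:
            s_hat = loss_variances[name]
            total = total + torch.exp(-s_hat) * task_loss + s_hat
        else:
            total = total + loss_weights[name] * task_loss
    return total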
Example #4
def compute_substance_loss(subst_label, subst_output, criterion):
    """Compute the substance classification loss and top-1 precision."""
    subst_labels = subst_label.cuda()
    subst_loss = criterion(subst_output, subst_labels)
    subst_prec1 = compute_precision(subst_output, subst_labels)

    return subst_loss, subst_prec1
Example #5
def validate_epoch(val_loader, model, epoch, subst_criterion, color_criterion,
                   loss_variances, loss_weights):
    """Run one validation epoch and return the mean of each logged meter."""
    meter_dict = defaultdict(AverageValueMeter)

    # switch to evaluate mode
    model.eval()

    pbar = tqdm(total=len(val_loader), desc='Validating Epoch')

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(val_loader):
        input_var = batch_dict['image'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        output = model.forward(input_var)

        losses = {}

        subst_output = output['substance']
        subst_fc_prec1 = compute_precision(subst_output, subst_labels)
        meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
        losses['substance'] = subst_criterion(subst_output, subst_labels)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; normalize over the
                # histogram bins.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        meter_dict['batch_time'].add(time.time() - last_end_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            vis.image(visualize_input(batch_dict),
                      win='validation-batch-example',
                      opts=dict(title='Validation Batch Example'))
            if 'color' in losses:
                visualize_color_hist_output(
                    color_output[0],
                    color_target[0],
                    'validation-color-hist',
                    'Validation Color Histogram',
                    color_hist_shape=color_binner.shape,
                    color_hist_space=color_binner.space,
                )
        meters = [(k, v) for k, v in meter_dict.items()]
        meter_table = MeterTable(
            f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}", meters)

        vis.text(meter_table.render(),
                 win='validation-status',
                 opts=dict(title='Validation Status'))

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
Example #6
def train_epoch(train_loader, model, epoch, subst_criterion, color_criterion,
                optimizer, loss_variances, loss_weights):
    """Run one training epoch and return the mean of each logged meter."""
    meter_dict = defaultdict(AverageValueMeter)
    pbar = tqdm(total=len(train_loader), desc='Training Epoch')

    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model.forward(input_var)

        losses = {}

        subst_output = output['substance']
        subst_fc_prec1 = compute_precision(subst_output, subst_labels)
        meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
        losses['substance'] = subst_criterion(subst_output, subst_labels)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # KLDivLoss expects log-probabilities; normalize over the
                # histogram bins.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()
        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))
            # Visualize color prediction.
            if 'color' in losses:
                visualize_color_hist_output(
                    color_output[0],
                    color_target[0],
                    win='train-color-hist',
                    title='Training Color Histogram',
                    color_hist_shape=color_binner.shape,
                    color_hist_space=color_binner.space,
                )
            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}

    return mean_stat_dict