Example #1
0
def write_video_summary(vid_dataset, model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log predicted vs. ground-truth video frames and their PSNR.

    The model is evaluated on full-resolution coordinate grids, one per frame
    index in ``frames``. The first coordinate component is the frame's time,
    normalized to [-1, 1]. Predictions are mapped from [-1, 1] to [0, 1],
    compared with ``vid_dataset.vid`` at the same frames, and logged as:
    - an image grid (ground truth concatenated above the prediction),
    - min/max summaries,
    - a PSNR scalar.

    Args:
        vid_dataset: Object exposing ``shape`` (indexed as frames, height,
            width) and ``vid``, a 4-D array indexed as ``vid[frames, :, :, :]``.
            NOTE(review): PSNR uses a peak value of 1, so ``vid`` is presumably
            scaled to [0, 1] -- confirm.
        model: Callable mapping ``{'coords': tensor}`` to a dict containing
            ``'model_out'``.
        model_input: Current training batch; only ``model_input['coords']`` is
            read (for a min/max summary).
        gt: Unused; kept so all summary functions share one signature.
        model_output: Unused; kept for the same reason.
        writer: Summary writer providing ``add_image`` and ``add_scalar``.
        total_steps: Global step attached to every logged value.
        prefix: String prepended to every summary tag.
    """
    resolution = vid_dataset.shape
    frames = [0, 60, 120, 200]
    Nslice = 10
    with torch.no_grad():
        coords = [dataio.get_mgrid((1, resolution[1], resolution[2]), dim=3)[None, ...].cpu() for _ in frames]
        for idx, f in enumerate(frames):
            # Map the frame index onto the normalized time axis in [-1, 1].
            coords[idx][..., 0] = (f / (resolution[0] - 1) - 0.5) * 2
        coords = torch.cat(coords, dim=0)

        # Evaluate in about Nslice chunks along the pixel axis to bound memory.
        # The chunk size is rounded UP so the final chunk covers any remainder.
        # Truncating with int(n / Nslice) left trailing pixels unevaluated
        # (stuck at zero).
        num_pixels = coords.shape[1]
        split = max(1, -(-num_pixels // Nslice))
        chunks = []
        for start in range(0, num_pixels, split):
            pred = model({'coords': coords[:, start:start + split, :]})['model_out']
            chunks.append(pred.cpu())
        output = torch.cat(chunks, dim=1)

    # Model output is in [-1, 1]; rescale to [0, 1] to match the video.
    pred_vid = output.view(len(frames), resolution[1], resolution[2], 3) / 2 + 0.5
    pred_vid = torch.clamp(pred_vid, 0, 1)
    gt_vid = torch.from_numpy(vid_dataset.vid[frames, :, :, :])
    psnr = 10 * torch.log10(1 / torch.mean((gt_vid - pred_vid) ** 2))

    # NHWC -> NCHW for make_grid.
    pred_vid = pred_vid.permute(0, 3, 1, 2)
    gt_vid = gt_vid.permute(0, 3, 1, 2)

    # Stack ground truth above prediction along the height axis.
    output_vs_gt = torch.cat((gt_vid, pred_vid), dim=-2)
    writer.add_image(prefix + 'output_vs_gt', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)
    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    min_max_summary(prefix + 'pred_vid', pred_vid, writer, total_steps)
    writer.add_scalar(prefix + "psnr", psnr, total_steps)
def val_fn(model, ckpt_dir, epoch):
    """Save a grid of zero-level-set plots of the learned value function.

    For every pair (time in ``times``, heading in ``thetas``), the model is
    evaluated on a ``sidelen x sidelen`` (x, y) grid. The output normalization
    is undone, the result is thresholded at ``<= 0.001``, and it is drawn in
    one subplot. The figure is saved to
    ``ckpt_dir/BRS_validation_plot_epoch_<epoch:04d>.png`` and then closed, so
    repeated calls do not accumulate open pyplot figures.

    Reads ``opt.tMax`` and ``opt.angle_alpha`` from the enclosing/module scope.

    Args:
        model: Callable mapping ``{'coords': tensor}`` to a dict containing
            ``'model_out'``.
        ckpt_dir: Directory the PNG is written to.
        epoch: Epoch number embedded in the file name.
    """
    # Time values at which the function needs to be plotted
    times = [0., 0.5 * (opt.tMax - 0.1), (opt.tMax - 0.1)]
    num_times = len(times)

    # Theta slices to be plotted
    thetas = [-math.pi, -0.5 * math.pi, 0., 0.5 * math.pi, math.pi]
    num_thetas = len(thetas)

    # Constants used below to un-normalize the model output (loop-invariant).
    norm_to = 0.02
    mean = 0.25
    var = 0.5

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create a figure
    fig = plt.figure(figsize=(5 * num_times, 5 * num_thetas))

    # Get the meshgrid in the (x, y) coordinate
    sidelen = 200
    mgrid_coords = dataio.get_mgrid(sidelen)
    num_points = mgrid_coords.shape[0]

    # Start plotting the results
    for i, t in enumerate(times):
        time_coords = torch.ones(num_points, 1) * t

        for j, theta in enumerate(thetas):
            theta_coords = torch.ones(num_points, 1) * theta
            theta_coords = theta_coords / (opt.angle_alpha * math.pi)
            coords = torch.cat((time_coords, mgrid_coords, theta_coords), dim=1)

            # Inference only: the output is detached right away, so skip
            # building an autograd graph.
            with torch.no_grad():
                model_out = model({'coords': coords.to(device)})['model_out']

            # Detach model output and reshape to the (x, y) grid
            model_out = model_out.detach().cpu().numpy()
            model_out = model_out.reshape((sidelen, sidelen))

            # Unnormalize the value function
            model_out = (model_out * var / norm_to) + mean

            # Plot the zero level sets
            model_out = (model_out <= 0.001) * 1.

            # Plot the actual data
            ax = fig.add_subplot(num_times, num_thetas,
                                 (j + 1) + i * num_thetas)
            ax.set_title('t = %0.2f, theta = %0.2f' % (t, theta))
            s = ax.imshow(model_out.T,
                          cmap='bwr',
                          origin='lower',
                          extent=(-1., 1., -1., 1.))
            fig.colorbar(s)

    fig.savefig(
        os.path.join(ckpt_dir, 'BRS_validation_plot_epoch_%04d.png' % epoch))
    # pyplot keeps a reference to every figure until it is closed; without
    # this, one figure leaks per validation call.
    plt.close(fig)
def write_sdf_summary(model,
                      model_input,
                      gt,
                      model_output,
                      writer,
                      total_steps,
                      prefix='train_'):
    """Log contour plots of three axis-aligned SDF slices and min/max stats.

    A 512x512 planar grid from ``dataio.get_mgrid`` is embedded in 3-D at
    x = 0 (yz slice), y = 0 (xz slice) and z = -0.75 (xy slice). Each embedded
    grid is run through ``model`` on the GPU, reshaped with ``dataio.lin2img``,
    rendered with ``make_contour_plot`` and written via ``writer.add_figure``.
    Finally, min/max summaries of ``model_output['model_out']`` and
    ``model_input['coords']`` are logged. ``gt`` is unused.
    """
    plane = dataio.get_mgrid(512)

    with torch.no_grad():
        first_col = plane[:, :1]
        # (summary tag, 3-D coordinates of the slice), in logging order.
        slices = (
            ('yz_sdf_slice',
             torch.cat((torch.zeros_like(first_col), plane), dim=-1)),
            ('xz_sdf_slice',
             torch.cat((first_col, torch.zeros_like(first_col), plane[:, -1:]),
                       dim=-1)),
            ('xy_sdf_slice',
             torch.cat((plane[:, :2], -0.75 * torch.ones_like(first_col)),
                       dim=-1)),
        )

        for tag, slice_coords in slices:
            batch = {'coords': slice_coords.cuda()[None, ...]}
            values = model(batch)['model_out']
            values = dataio.lin2img(values).squeeze().cpu().numpy()
            writer.add_figure(prefix + tag,
                              make_contour_plot(values),
                              global_step=total_steps)

        min_max_summary(prefix + 'model_out_min_max',
                        model_output['model_out'], writer, total_steps)
        min_max_summary(prefix + 'coords', model_input['coords'], writer,
                        total_steps)
Example #4
0
def write_wave_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log wave-field predictions at several time instants.

    The model is evaluated on an ``sl x sl`` spatial grid at each time in
    ``frames`` (the time is written into coordinate component 0). Then:
    - min/max statistics of the full prediction are logged,
    - a figure shows the centre row (index ``sl // 2``) of the first four
      frames,
    - all frames are logged as an image grid after clamping to
      [-0.002, 0.002].

    ``model_input``, ``gt`` and ``model_output`` are unused; they are kept so
    all summary functions share one signature.
    """
    sl = 256
    frames = [0.0, 0.05, 0.1, 0.15, 0.25]
    Nslice = 10

    with torch.no_grad():
        coords = [dataio.get_mgrid((1, sl, sl), dim=3)[None, ...].cpu() for _ in frames]
        for idx, f in enumerate(frames):
            coords[idx][..., 0] = f
        coords = torch.cat(coords, dim=0)

        # Evaluate in about Nslice chunks along the pixel axis. The chunk size
        # is rounded UP so the remainder pixels are also evaluated (int(n / Nslice)
        # left the last few pixels at zero).
        num_pixels = coords.shape[1]
        split = max(1, -(-num_pixels // Nslice))
        chunks = [
            model({'coords': coords[:, start:start + split, :]})['model_out'].cpu()
            for start in range(0, num_pixels, split)
        ]
        output = torch.cat(chunks, dim=1)

    # Summarize the whole prediction, not just the last evaluated chunk.
    min_max_summary(prefix + 'pred', output, writer, total_steps)
    pred = output.view(len(frames), 1, sl, sl)

    plt.switch_backend('agg')
    fig = plt.figure()
    xs = np.linspace(-1, 1, sl)
    # Centre-row profiles of the first four frames in a 2x2 layout.
    for k in range(4):
        plt.subplot(2, 2, k + 1)
        plt.plot(xs, pred[k, :, sl // 2, :].numpy().squeeze())
        plt.ylim([-0.01, 0.02])

    writer.add_figure(prefix + 'center_slice', fig, global_step=total_steps)

    pred = torch.clamp(pred, -0.002, 0.002)
    writer.add_image(prefix + 'pred_img', make_grid(pred, scale_each=False, normalize=True),
                     global_step=total_steps)
Example #5
0
def val_fn_BRS_posspace(model, model_1P):
    """Compare full-game and pairwise collision sets over position slices.

    For each slice ``i`` in ``range(num_slices)``, one (x, y) position is
    swept over a ``sidelen x sidelen`` grid. Evaders '1E' and '2E' stay fixed
    at ``poss[key][i]`` with headings ``thetas[key][i]``, and the heading paired
    with the grid is ``ego_vehicle_theta[i]``. ``model`` (full game) is
    evaluated once and ``model_1P`` (pairwise game) once per evader. Both
    outputs are un-normalized and thresholded at ``<= level``.

    Three figures are built, one subplot per slice:
      * ``fig``: the full-game set, the union of the pairwise sets, and
        rotated aircraft icons at the evader positions;
      * ``fig_error``: cells inside the full-game set but outside every
        pairwise set;
      * ``fig_valfunc``: the full-game value function with labelled contours.

    The per-slice value functions are appended to ``val_functions['full']`` and
    ``val_functions['pairwise']`` (elementwise minimum over evaders).

    NOTE(review): ``num_slices``, ``time``, ``poss``, ``thetas``,
    ``ego_vehicle_theta``, ``angle_alpha``, ``level`` and ``val_functions`` are
    free variables resolved from an enclosing/module scope not shown here.

    Returns:
        ``(fig, fig_error, fig_valfunc, val_functions)``.
    """
    # Constants used to un-normalize both models' outputs.
    norm_to = 0.02
    mean = 0.25
    var = 0.5
    aircraft_size = 0.2

    # Create the figures
    fig = plt.figure(figsize=(5 * num_slices, 5))
    fig_error = plt.figure(figsize=(5 * num_slices, 5))
    fig_valfunc = plt.figure(figsize=(5 * num_slices, 5))

    # Read the aircraft icon once; it was previously re-read from disk for
    # every evader in every slice.
    aircraft_icon = plt.imread('resources/ego_aircraft.png')

    # Get the meshgrid in the (x, y) coordinate
    sidelen = 200
    mgrid_coords = dataio.get_mgrid(sidelen)
    num_points = mgrid_coords.shape[0]

    # Time coordinates
    time_coords = torch.ones(num_points, 1) * time

    evader_keys = ['%i' % (j + 1) + 'E' for j in range(2)]

    for i in range(num_slices):
        coords = torch.cat((time_coords, mgrid_coords), dim=1)
        pairwise_coords = {}

        # X-Y coordinates of the evaders (full game and pairwise games)
        for evader_key in evader_keys:
            xcoords = torch.ones(num_points, 1) * poss[evader_key][i][0]
            ycoords = torch.ones(num_points, 1) * poss[evader_key][i][1]
            coords = torch.cat((coords, xcoords, ycoords), dim=1)
            pairwise_coords[evader_key] = torch.cat(
                (time_coords, xcoords, ycoords, mgrid_coords), dim=1)

        # Theta coordinates, scaled by 1 / (pi * angle_alpha)
        coords_ego_theta = ego_vehicle_theta[i] * torch.ones(
            num_points, 1) / (math.pi * angle_alpha)
        coords = torch.cat((coords, coords_ego_theta), dim=1)
        for evader_key in evader_keys:
            tcoords = torch.ones(
                num_points, 1) * thetas[evader_key][i] / (math.pi * angle_alpha)
            coords = torch.cat((coords, tcoords), dim=1)
            pairwise_coords[evader_key] = torch.cat(
                (pairwise_coords[evader_key], tcoords, coords_ego_theta),
                dim=1)

        # Inference only; no autograd graph is needed.
        with torch.no_grad():
            model_out = model({'coords': coords[:, None, :].cuda()})['model_out']

        model_out = model_out.detach().cpu().numpy().reshape((sidelen, sidelen))
        model_out = (model_out * var / norm_to) + mean

        # Keep the raw value function, then binarize to the sub-level set.
        valfunc = model_out * 1.
        model_out = (model_out <= level) * 1.

        ax = fig.add_subplot(1, num_slices, i + 1)
        ax_valfunc = fig_valfunc.add_subplot(1, num_slices, i + 1)
        ax_error = fig_error.add_subplot(1, num_slices, i + 1)

        # Aircraft icons, rotated to each evader's heading, plus a marker.
        for evader_key in evader_keys:
            ex = poss[evader_key][i][0]
            ey = poss[evader_key][i][1]
            aircraft_image = scipy.ndimage.rotate(
                aircraft_icon, 180.0 * thetas[evader_key][i] / math.pi)
            ax.imshow(aircraft_image,
                      extent=(ex - aircraft_size, ex + aircraft_size,
                              ey - aircraft_size, ey + aircraft_size))
            ax.plot(ex, ey, "o")
        ax.imshow(model_out.T,
                  cmap='bwr_r',
                  alpha=0.5,
                  origin='lower',
                  vmin=-1.,
                  vmax=1.,
                  extent=(-1., 1., -1., 1.))
        sV1 = ax_valfunc.imshow(valfunc.T,
                                cmap='bwr_r',
                                alpha=0.8,
                                origin='lower',
                                vmin=-0.2,
                                vmax=0.2,
                                extent=(-1., 1., -1., 1.))
        sV2 = ax_valfunc.contour(valfunc.T,
                                 cmap='bwr_r',
                                 alpha=0.5,
                                 origin='lower',
                                 vmin=-0.2,
                                 vmax=0.2,
                                 levels=30,
                                 extent=(-1., 1., -1., 1.))
        # clabel's ``levels`` must be a sequence of contour values; the old
        # ``levels=30`` raised in current matplotlib. Label all levels.
        sV2.clabel(colors='k')
        fig_valfunc.colorbar(sV1)

        # Union of the pairwise collision sets and min of pairwise values
        model_out_pairwise_sofar = None
        valfunc_pairwise = None
        for evader_key in evader_keys:
            with torch.no_grad():
                model_out_pairwise = model_1P(
                    {'coords': pairwise_coords[evader_key].cuda()})['model_out']
            model_out_pairwise = model_out_pairwise.detach().cpu().numpy()
            model_out_pairwise = model_out_pairwise.reshape((sidelen, sidelen))
            model_out_pairwise = (model_out_pairwise * var / norm_to) + mean

            pairwise_set = (model_out_pairwise <= level) * 1.
            if model_out_pairwise_sofar is None:
                model_out_pairwise_sofar = pairwise_set
                valfunc_pairwise = model_out_pairwise * 1.
            else:
                model_out_pairwise_sofar = np.clip(
                    pairwise_set + model_out_pairwise_sofar, 0., 1.)
                valfunc_pairwise = np.minimum(valfunc_pairwise,
                                              model_out_pairwise * 1.0)

        ax.imshow(model_out_pairwise_sofar.T,
                  cmap='seismic',
                  alpha=0.5,
                  origin='lower',
                  vmin=-1.,
                  vmax=1.,
                  extent=(-1., 1., -1., 1.))

        # Error plot: 1 where the full-game set is not covered pairwise
        error = np.clip(model_out - model_out_pairwise_sofar, 0., 1.)
        ax_error.imshow(error.T,
                        cmap='bwr',
                        origin='lower',
                        vmin=-1.,
                        vmax=1.,
                        extent=(-1., 1., -1., 1.))

        # Append the value functions
        val_functions['pairwise'].append(valfunc_pairwise)
        val_functions['full'].append(valfunc)

    return fig, fig_error, fig_valfunc, val_functions
def write_helmholtz_summary(model,
                            model_input,
                            gt,
                            model_output,
                            writer,
                            total_steps,
                            prefix='train_'):
    """Log Helmholtz wavefield predictions, ground truth and velocity images.

    The model is evaluated on a 256x256 grid. When
    ``model_input['coords_sub']`` / ``['img_sub']`` are present, at most two
    context sets are used. Even output channels are treated as real parts and
    odd channels as imaginary parts. Real, imaginary, phase and magnitude
    images are logged, together with min/max statistics.

    When ``gt`` contains ``'pretrain'``, the last output channel (+1) is stored
    in ``gt['squared_slowness_grid']`` (this mutates ``gt``) and removed from
    the wavefield channels. The following are logged only when present:
    - ground-truth fields (``gt['gt']``),
    - the velocity ``sqrt(1 / squared_slowness)`` and the squared slowness,
    - ``model_input['img_sub']``.
    For ``meta_modules.NeuralProcessImplicit2DHypernetBVP`` models,
    hypernetwork activations are logged as well.
    """
    sl = 256
    coords = dataio.get_mgrid(sl)[None, ...].cuda()

    def scale_percentile(pred, min_perc=1, max_perc=99):
        """Clip ``pred`` to its [min_perc, max_perc] percentiles, rescale to [0, 1]."""
        values = pred.detach().cpu().numpy()
        lo = float(np.percentile(values, min_perc))
        hi = float(np.percentile(values, max_perc))
        pred = torch.clamp(pred, lo, hi)
        if hi <= lo:
            # Constant input: avoid 0/0 (NaN) and return all zeros.
            return pred - lo
        return (pred - lo) / (hi - lo)

    with torch.no_grad():
        if 'coords_sub' in model_input:
            summary_model_input = {
                'coords':
                coords.repeat(min(2, model_input['coords_sub'].shape[0]), 1, 1)
            }
            summary_model_input['coords_sub'] = model_input['coords_sub'][:2,
                                                                          ...]
            summary_model_input['img_sub'] = model_input['img_sub'][:2, ...]
            pred = model(summary_model_input)['model_out']
        else:
            pred = model({'coords': coords})['model_out']

        if 'pretrain' in gt:
            # The last output channel is a squared-slowness prediction: store it
            # (offset by +1) in ``gt`` and drop it from the wavefield channels.
            gt['squared_slowness_grid'] = pred[..., -1, None].clone() + 1.
            if torch.all(gt['pretrain'] == -1):
                # Keep the value >= 0.001 and force it to 1 wherever
                # |x| > 0.75 or |y| > 0.75.
                gt['squared_slowness_grid'] = torch.clamp(
                    pred[..., -1, None].clone(), min=-0.999) + 1.
                gt['squared_slowness_grid'] = torch.where(
                    (torch.abs(coords[..., 0, None]) > 0.75) |
                    (torch.abs(coords[..., 1, None]) > 0.75),
                    torch.ones_like(gt['squared_slowness_grid']),
                    gt['squared_slowness_grid'])
            pred = pred[..., :-1]

        pred = dataio.lin2img(pred)

        # Even channels = real part, odd channels = imaginary part.
        pred_cmpl = pred[..., 0::2, :, :].cpu().numpy(
        ) + 1j * pred[..., 1::2, :, :].cpu().numpy()
        pred_angle = torch.from_numpy(np.angle(pred_cmpl))
        pred_mag = torch.from_numpy(np.abs(pred_cmpl))

        min_max_summary(prefix + 'coords', model_input['coords'], writer,
                        total_steps)
        min_max_summary(prefix + 'pred_real', pred[..., 0::2, :, :], writer,
                        total_steps)
        min_max_summary(
            prefix + 'pred_abs',
            torch.sqrt(pred[..., 0::2, :, :]**2 + pred[..., 1::2, :, :]**2),
            writer, total_steps)
        # Guarded the same way as the later squared-slowness image: the key
        # only exists when the dataset or the 'pretrain' branch supplies it.
        if 'squared_slowness_grid' in gt:
            min_max_summary(prefix + 'squared_slowness',
                            gt['squared_slowness_grid'], writer, total_steps)

        pred = scale_percentile(pred)
        pred_angle = scale_percentile(pred_angle)
        pred_mag = scale_percentile(pred_mag)

        pred = pred.permute(1, 0, 2, 3)
        pred_mag = pred_mag.permute(1, 0, 2, 3)
        pred_angle = pred_angle.permute(1, 0, 2, 3)

    writer.add_image(prefix + 'pred_real',
                     make_grid(pred[0::2, :, :, :],
                               scale_each=False,
                               normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'pred_imaginary',
                     make_grid(pred[1::2, :, :, :],
                               scale_each=False,
                               normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'pred_angle',
                     make_grid(pred_angle, scale_each=False, normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'pred_mag',
                     make_grid(pred_mag, scale_each=False, normalize=True),
                     global_step=total_steps)

    if 'gt' in gt:
        gt_field = dataio.lin2img(gt['gt'])
        gt_field_cmpl = gt_field[..., 0, :, :].cpu().numpy(
        ) + 1j * gt_field[..., 1, :, :].cpu().numpy()
        gt_angle = torch.from_numpy(np.angle(gt_field_cmpl))
        gt_mag = torch.from_numpy(np.abs(gt_field_cmpl))

        gt_field = scale_percentile(gt_field)
        gt_angle = scale_percentile(gt_angle)
        gt_mag = scale_percentile(gt_mag)

        writer.add_image(prefix + 'gt_real',
                         make_grid(gt_field[..., 0, :, :],
                                   scale_each=False,
                                   normalize=True),
                         global_step=total_steps)
        writer.add_image(prefix + 'gt_imaginary',
                         make_grid(gt_field[..., 1, :, :],
                                   scale_each=False,
                                   normalize=True),
                         global_step=total_steps)
        writer.add_image(prefix + 'gt_angle',
                         make_grid(gt_angle, scale_each=False, normalize=True),
                         global_step=total_steps)
        writer.add_image(prefix + 'gt_mag',
                         make_grid(gt_mag, scale_each=False, normalize=True),
                         global_step=total_steps)
        min_max_summary(prefix + 'gt_real', gt_field[..., 0, :, :], writer,
                        total_steps)

    if 'squared_slowness_grid' in gt:
        # Velocity = 1 / sqrt(squared slowness); only the first item is shown.
        velocity = torch.sqrt(
            1 / dataio.lin2img(gt['squared_slowness_grid']))[:1]
        min_max_summary(prefix + 'velocity', velocity[..., 0, :, :], writer,
                        total_steps)
        velocity = scale_percentile(velocity)
        writer.add_image(prefix + 'velocity',
                         make_grid(velocity[..., 0, :, :],
                                   scale_each=False,
                                   normalize=True),
                         global_step=total_steps)

        writer.add_image(prefix + 'squared_slowness',
                         make_grid(dataio.lin2img(
                             gt['squared_slowness_grid'])[:2, :1],
                                   scale_each=False,
                                   normalize=True),
                         global_step=total_steps)

    if 'img_sub' in model_input:
        writer.add_image(prefix + 'img',
                         make_grid(dataio.lin2img(
                             model_input['img_sub'])[:2, :1],
                                   scale_each=False,
                                   normalize=True),
                         global_step=total_steps)

    if isinstance(model, meta_modules.NeuralProcessImplicit2DHypernetBVP):
        hypernet_activation_summary(model, model_input, gt, model_output,
                                    writer, total_steps, prefix)
Example #7
0
def plot_sds_with_gradients(gt_contour, gt_normals, mgrid_sds, mgrid_grads):
    '''Plot a ground-truth contour, predicted signed distances and gradients.

    Draws three side-by-side panels and then calls ``plt.show()``:
      (a) the ground-truth contour with its normals,
      (b) the predicted signed distances as an image with the zero level set,
      (c) the predicted gradient directions as a quiver plot.

    Only the first batch element (index 0) of every input is plotted. In
    panels (a) and (c), points and vectors are drawn with component 1 on the
    horizontal axis and the negated component 0 on the vertical axis.

    gt_contour: tensor of shape (B, N, 2) with coordinates of the gt contour.
    gt_normals: tensor with the normals at those points; presumably the same
        shape as gt_contour (only components 0 and 1 are read).
    mgrid_sds: signed distances on a square meshgrid; each batch element must
        hold sidelen**2 values.
    mgrid_grads: gradients on the same meshgrid, shape (B, sidelen**2, 2).

    Raises ValueError if the number of grid points is not a perfect square.
    '''
    # Images are square, but flattened - recover the side length.
    num_pixels = mgrid_grads.shape[1]
    sidelen = int(round(np.sqrt(num_pixels)))
    if sidelen * sidelen != num_pixels:
        raise ValueError(
            'expected a square grid, got %d points' % num_pixels)

    mgrid = dataio.get_mgrid(sidelen).detach().cpu().numpy()
    lo, hi = mgrid.min(), mgrid.max()

    grads = mgrid_grads[0, ...].detach().cpu().numpy()
    contour = gt_contour[0, ...].detach().cpu().numpy()
    normals = gt_normals[0, ...].detach().cpu().numpy()
    sds = mgrid_sds[0, ...].detach().cpu().numpy().reshape(sidelen, sidelen)

    _, (axa, axb, axc) = plt.subplots(nrows=1, ncols=3, figsize=(30, 10))

    # PLOT A: Ground truth level set with normals
    axa.set_xlim([lo, hi])
    axa.set_ylim([lo, hi])
    axa.plot(contour[..., 1], contour[..., 0] * -1.)
    axa.quiver(contour[..., 1],
               contour[..., 0] * -1.,
               normals[..., 1],
               normals[..., 0] * -1.,
               scale=25.)
    axa.set_title('Ground truth level set')

    # PLOT B: Predicted signed distances and their zero level set
    axb.imshow(sds)
    axb.contour(sds, levels=[0], colors='k', linestyles='-')
    axb.set_title('Predicted Signed Distance')

    # PLOT C: Gradient directions (subsampled every `grad_subsample` pixels)
    grad_subsample = 1
    quiver_coords = mgrid.reshape(
        sidelen, sidelen,
        2)[::grad_subsample, ::grad_subsample, :].reshape(-1, 2)
    quiver_grads = grads.reshape(
        sidelen, sidelen,
        2)[::grad_subsample, ::grad_subsample, :].reshape(-1, 2)

    axc.set_xlim([lo, hi])
    axc.set_ylim([lo, hi])
    axc.set_xlabel('x')
    axc.set_ylabel('y')
    axc.set_xticks([lo, 0., hi])
    axc.set_yticks([lo, 0., hi])
    axc.quiver(quiver_coords[..., 1], quiver_coords[..., 0] * -1,
               quiver_grads[..., 1], quiver_grads[..., 0] * -1)
    axc.set_title('Orientations of Gradients')

    plt.show()
Example #8
0
    in_features=img_dataset_test.img_channels,
    out_features=img_dataset_test.img_channels,
    image_resolution=image_resolution,
    partial_conv=opt.partial_conv)
model.cuda()
model.eval()

root_path = os.path.join(opt.logging_root, opt.experiment_name)
utils.cond_mkdir(root_path)

# Load checkpoint
model.load_state_dict(torch.load(opt.checkpoint_path))

# First experiment: Upsample training image
model_input = {
    'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
    'img_sparse': coord_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda()
}
model_output = model(model_input)

# The lin2img result is treated as (1, C, H, W). Drop only the batch axis:
# a bare .squeeze() also removed the channel axis of single-channel images,
# which made permute(1, 2, 0) fail.
out_img = dataio.lin2img(model_output['model_out'],
                         image_resolution)[0].permute(
                             1, 2, 0).detach().cpu().numpy()
if out_img.shape[-1] == 1:
    # Grayscale: write an (H, W) image rather than (H, W, 1).
    out_img = out_img[..., 0]
# Map the model output from [-1, 1] to [0, 1].
out_img += 1
out_img /= 2.
out_img = np.clip(out_img, 0., 1.)

imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img)

# Second experiment: sample larger range
model_input = {