Example #1
import os

import numpy as np
import torch
from torch.autograd import Variable

# load_params, apply_preprocess, GANModel, extract, stats,
# generate_2d_samples, optimizer and plot are project-specific helpers
# assumed to be importable in the original module.


def box_experiment(data_params={}, model_params={}, stage='predict'):
    data, model = load_params(data_params, model_params)
    data = apply_preprocess(data)
    if stage == 'predict':
        mdl = GANModel(data, model)

        convergence_counter = 0
        for epoch in range(model['p_num_epochs']):
            for _ in range(model['d_steps']):
                mdl.D.zero_grad()
                d_error = 0
                for __ in range(model['minibatch_size']):
                    d_real_data = Variable(mdl.d_sampler(1))
                    d_real_decision = mdl.D(mdl.preprocess(d_real_data).t())
                    d_real_error = mdl.criterion(d_real_decision,
                                                 Variable(torch.ones(1)))
                    # accumulate the error; no parameter update yet

                    if np.random.rand() > 0.3:
                        d_gen_input = mdl.g_sampler(1, model['g_input_size'])
                        # detach to avoid training G
                        d_fake_data = mdl.G(d_gen_input).detach()
                        d_fake_decision = mdl.D(mdl.preprocess(d_fake_data))
                    else:
                        d_gen_input = Variable(mdl.adversarial_d_sampler(1))
                        d_fake_decision = mdl.D(
                            mdl.preprocess(d_gen_input).t())

                    d_fake_error = mdl.criterion(d_fake_decision,
                                                 Variable(torch.zeros(1)))
                    d_error += d_real_error + d_fake_error
                d_error.backward()

                mdl.d_optimizer.step()

            for _ in range(model['g_steps']):
                mdl.G.zero_grad()

                g_total_error = 0
                for __ in range(model['minibatch_size']):
                    gen_input = Variable(
                        mdl.g_sampler(1, model['g_input_size']))
                    g_fake_data = mdl.G(gen_input)
                    g_fake_decision = mdl.D(mdl.preprocess(g_fake_data))
                    g_error = mdl.criterion(g_fake_decision,
                                            Variable(torch.ones(1)))
                    g_total_error += g_error
                g_total_error.backward()

                mdl.g_optimizer.step()

            if epoch % model['print_interval'] == 0:
                print('epoch: {}, D: {}/{}, G: {}'.format(
                    epoch,
                    extract(d_real_error)[0],
                    extract(d_fake_error)[0],
                    extract(g_error)[0]))
                # Distribution summaries of the latest real/fake samples
                real_dist = stats(extract(d_real_data))
                fake_dist = stats(extract(d_fake_data))

        # final discriminator update
        for _ in range(50):
            mdl.D.zero_grad()
            d_error = 0
            for __ in range(model['minibatch_size']):
                d_real_data = Variable(mdl.d_sampler(1))
                d_real_decision = mdl.D(mdl.preprocess(d_real_data).t())
                d_real_error = mdl.criterion(d_real_decision,
                                             Variable(torch.ones(1)))

                if np.random.rand() > 0.3:
                    d_gen_input = mdl.g_sampler(1, model['g_input_size'])
                    # detach to avoid training G
                    d_fake_data = mdl.G(d_gen_input).detach()
                    d_fake_decision = mdl.D(mdl.preprocess(d_fake_data))

                else:
                    d_gen_input = Variable(mdl.adversarial_d_sampler(1))
                    d_fake_decision = mdl.D(mdl.preprocess(d_gen_input).t())
                d_fake_error = mdl.criterion(d_fake_decision,
                                             Variable(torch.zeros(1)))
                # d_fake_error.backward()
                d_error += d_real_error + d_fake_error
            d_error.backward()

            mdl.d_optimizer.step()

        save_filename = 'ganG_{}_id_{}.pth'.format(data['mode'], model['id'])
        save_path = os.path.join(model['save_dir'], save_filename)
        torch.save(mdl.G, save_path)

        save_filename = 'ganD_{}_id_{}.pth'.format(data['mode'], model['id'])
        save_path = os.path.join(model['save_dir'], save_filename)
        torch.save(mdl.D, save_path)
        generate_2d_samples(mdl, show_real=True)
        plot.show()

    elif stage == 'optimize':
        mdl = GANModel(data, model, load_model=True)
        # debug_discriminator(mdl)
        mdl.epoch = 'pre'
        generate_2d_samples(mdl, show_real=True, save_fig=True)
        # plot.show()

        optimizer(mdl)
        generate_2d_samples(mdl, show_real=True)
        # plot.show()

    elif stage == 'plot':
        mdl = GANModel(data, model, load_model=True)
        generate_2d_samples(mdl, show_real=True)
        plot.show()

    elif stage == 'compare':
        mdl = GANModel(data, model, load_model=True)
        mdl.epoch = 'pre'
        res_real = optimizer(mdl, compare='store')

        mdl = GANModel(data, model, load_model=True)
        mdl.epoch = 'pre'
        res_fake = optimizer(mdl, compare='unconstrained')

        for real, fake in zip(res_real, res_fake):
            save_filename = 'comparison_id_{}_loss_{}_epoch_{}.png'.format(
                mdl.model['id'], mdl.model['loss'], real[0])
            save_path = os.path.join(mdl.model['plots_dir'], save_filename)

            kwargs = {'alpha': 0.3, 'numticks': 5, 'save_fig': save_path}

            fig, ax = plot.plot_2d_samples(real[1], **kwargs)

            kwargs['color'] = '#be3392'
            kwargs['zorder'] = -5
            fig, ax = plot.plot_2d_samples(fake[1], fig, ax, **kwargs)

            gen_dist = mdl.d_sampler(1000).numpy()
            kwargs['color'] = '#ed7d31'
            kwargs['alpha'] = 0.1
            kwargs['zorder'] = -10
            kwargs['mode'] = mdl.data['mode']
            plot.plot_2d_samples(gen_dist, fig, ax, **kwargs)

    else:
        raise NotImplementedError('unrecognized stage: {}'.format(stage))
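A minimal usage sketch for the function above; the stage names come from the
function body, while the empty parameter dicts are illustrative assumptions:

# Train a GAN and save the G/D checkpoints, then reload and plot samples.
box_experiment(data_params={}, model_params={}, stage='predict')
box_experiment(data_params={}, model_params={}, stage='plot')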
Example #2
    plt.ylabel(r'$\Delta$ test accuracy')
    plt.legend()
    plt.tight_layout()
    plt.title(
        r'ResNet-50 $\Delta$ test accuracy after {idx}{suffix} pruning iteration ({sparsity:.2%} sparsity)'
        .format(
            idx=it,
            suffix=suffix_of_number(it),
            sparsity=1 - density,
        ))
    vals = plt.gca().get_yticks()
    plt.gca().set_yticks(vals)  # pin tick locations before relabeling
    plt.gca().set_yticklabels(['{:,.2%}'.format(x) for x in vals])

    ticks1 = np.linspace(0, 90, 11)
    int_ticks1 = [round(i) for i in ticks1]
    plt.gca().set_xticks(int_ticks1)

    vals = plt.gca().get_xticks()
    plt.gca().set_xticklabels(
        [r'${} \times {}$'.format(int(x), it) if x > 0 else '0' for x in vals])

    plt.tight_layout()

    if plot_utils.save():
        plt.savefig(
            os.path.join('results_iterative',
                         'resnet50_{}'.format(it) + '.png'))

if plot_utils.show():
    plt.show()
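The snippet above assumes a helper suffix_of_number that returns the English
ordinal suffix for an integer; a minimal sketch of such a helper:

def suffix_of_number(n):
    """Return the ordinal suffix of n, e.g. 1 -> 'st', 2 -> 'nd', 11 -> 'th'."""
    if 11 <= n % 100 <= 13:
        return 'th'
    return {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')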
Example #3
import tensorflow as tf
from tensorflow import keras

# Latent-vector size; the value 100 is an assumption (the variable is defined
# earlier in the original notebook). plot_utils is a project-specific helper.
num_features = 100

generator = keras.models.Sequential([
    keras.layers.Dense(7 * 7 * 128, input_shape=[num_features]),
    keras.layers.Reshape([7, 7, 128]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(64, (5, 5), (2, 2),
                                 padding="same",
                                 activation="selu"),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(1, (5, 5), (2, 2),
                                 padding="same",
                                 activation="tanh"),
])

noise = tf.random.normal(shape=[1, num_features])
generated_images = generator(noise, training=False)
plot_utils.show(generated_images, 1)

# Build the discriminator network for the DCGAN

discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64, (5, 5), (2, 2),
                        padding="same",
                        input_shape=[28, 28, 1]),
    keras.layers.LeakyReLU(0.2),
    keras.layers.Dropout(0.3),
    keras.layers.Conv2D(128, (5, 5), (2, 2), padding="same"),
    keras.layers.LeakyReLU(0.2),
    keras.layers.Dropout(0.3),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation='sigmoid')
])
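A typical next step, not part of the original excerpt, is to chain both
networks into a combined model for the generator's training phase; a minimal
sketch (the optimizer choice is an assumption):

# Train the discriminator on its own; freeze it inside the combined model
# so that training the combined model updates only the generator.
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan = keras.models.Sequential([generator, discriminator])
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")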
Example #4
# numpy (np), image_utils, plot_utils and the field classes are assumed to be
# imported in the original module.
def visualize_field(field, only_window=True, approximate_wavelength=None,
                    filter_label=None, use_autostretch=False,
                    white_point_scale=1, use_log=False, title='',
                    output_path=None):

    use_colors = False
    was_stretched = False

    wavelength = None

    vmin = None
    vmax = None

    if only_window:
        field_values = field.get_values_inside_window()
        extent = field.grid.get_window_bounds()
    else:
        field_values = field.values
        extent = field.grid.get_bounds()

    if isinstance(field, FilteredSpectralField):
        if filter_label is None:
            use_colors = True
            use_log = False
            field_values = np.moveaxis(field_values, 0, 2)
            if not use_autostretch:
                field_values = image_utils.perform_histogram_clip(
                    field_values, white_point_scale=white_point_scale)
                was_stretched = True
        else:
            channel_idx = field.channels[filter_label]
            field_values = field_values[channel_idx, :, :]
            wavelength = field.central_wavelengths[channel_idx]
    elif isinstance(field, SpectralField):
        wavelength_idx = (0 if approximate_wavelength is None else
                          np.argmin(np.abs(field.wavelengths -
                                           approximate_wavelength)))
        wavelength = field.wavelengths[wavelength_idx]
        field_values = field_values[wavelength_idx, :, :]

    if field.dtype == np.dtype('complex128') and not use_colors:
        phases = np.angle(field_values)
        field_values = np.abs(field_values)
        phases[field_values == 0] = np.nan
    else:
        field_values = field_values.astype('float64')

    if use_log:
        # log10 of non-positive values yields -inf/NaN
        field_values = np.log10(field_values)
        log_label = 'log10 '
    else:
        log_label = ''

    if not was_stretched:
        if use_autostretch:
            field_values = image_utils.perform_autostretch(field_values)
            vmin = 0
            vmax = 1
        else:
            vmin = np.min(field_values)
            vmax = np.max(field_values)*white_point_scale

    if wavelength is None:
        wavelength = 1
    else:
        wavelength_text = '{:g} nm'.format(wavelength*1e9)
        title = wavelength_text if title == '' else '{} ({})'.format(title, wavelength_text)

    colorbar = not (use_colors or use_autostretch)

    xlabel = r'$x$'
    ylabel = r'$y$'
    phase_clabel = 'Field phase [rad]'

    if field.grid.grid_type == 'source':
        xlabel = r'$d_x$'
        ylabel = r'$d_y$'
        clabel = '{}Field amplitude [sqrt(W/m^2/m)]'.format(log_label)
    elif field.grid.grid_type == 'aperture':
        xlabel = r'$x$ [m]'
        ylabel = r'$y$ [m]'
        clabel = '{}Field amplitude [sqrt(W/m^2/m)]'.format(log_label)
        extent *= wavelength
    elif field.grid.grid_type == 'image':
        xlabel = r'$x$ [focal lengths]'
        ylabel = r'$y$ [focal lengths]'
        clabel = '{}Flux [W/m^2/m]'.format(log_label)
    else:
        # Fallback so clabel is defined for any other grid type
        clabel = '{}Field values'.format(log_label)

    fig = plot_utils.figure()

    if field.dtype == np.dtype('complex128') and not use_colors:
        left_ax = fig.add_subplot(121)
        plot_utils.plot_image(fig, left_ax, field_values,
                              xlabel=xlabel, ylabel=ylabel, title=title,
                              extent=extent, clabel=clabel)
        right_ax = fig.add_subplot(122)
        plot_utils.plot_image(fig, right_ax, phases,
                              xlabel=xlabel, ylabel=ylabel, title=title,
                              extent=extent, clabel=phase_clabel)
    else:
        ax = fig.add_subplot(111)
        plot_utils.plot_image(fig, ax, field_values, vmin=vmin, vmax=vmax,
                              xlabel=xlabel, ylabel=ylabel, title=title,
                              extent=extent, clabel=clabel, colorbar=colorbar)

    plot_utils.tight_layout()

    if output_path is None:
        plot_utils.show()
    else:
        plot_utils.savefig(output_path)
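A minimal usage sketch; constructing a SpectralField is project-specific, so
the field object is assumed to exist already:

# Plot the log-amplitude of the channel closest to 550 nm and save the figure.
visualize_field(field,
                approximate_wavelength=550e-9,
                use_log=True,
                title='Aperture field',
                output_path='aperture_field.png')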
Example #5
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt

# ALL_PARAMETERS, plot_utils and _test_set_parameters are project-specific
# names assumed to be defined in the original module.


def eval_estimator(type_estimator, X, Y, simparams, cachefile=None):
    all_estimates = (list(ALL_PARAMETERS.split()) + [ALL_PARAMETERS])

    if cachefile is not None and os.path.exists(cachefile):
        with open(cachefile, "rb") as f:
            results = pickle.load(f)
        for estimates in all_estimates:
            print("Estimates", estimates)
            # Rounding disabled in favor of the raw errors:
            # X1 = np.round(results[estimates]["in"], 2)
            # X2 = np.round(results[estimates]["out"], 5)
            X1 = results[estimates]["in"]
            X2 = results[estimates]["out"]
            plot_utils.plot_error_density(
                X1, X2, title="Estimation error of {}".format(estimates))
            plot_utils.plot_error_confidence_interval2(
                X1,
                X2,
                title="Estimation error of {}".format(estimates),
                maxerror=0.5)
            plt.show()
            plot_utils.show()
        return results

    results = {}

    for estimates in all_estimates:
        simparams2 = [(a, b, c) for (a, b, c) in simparams if b == estimates]
        idxsimparams2 = [
            i for i, (a, b, c) in enumerate(simparams) if b == estimates
        ]
        factors = np.array([c for (a, b, c) in simparams if b == estimates])
        print("Estimates", estimates)

        inErrs, outErrs, bestEstParams, performance_metrics = _test_set_parameters(
            type_estimator, X, Y, simparams2)
        # Remove outliers: keep unfiltered copies, then drop non-finite results.
        xperformance_metrics = performance_metrics.copy()
        xfactors = factors.copy()
        xinErrs = inErrs.copy()
        xoutErrs = outErrs.copy()

        performance_metrics = performance_metrics[:, np.isfinite(outErrs)]
        inErrs = inErrs[np.isfinite(outErrs)]
        factors = factors[np.isfinite(outErrs)]
        outErrs = outErrs[np.isfinite(outErrs)]

        # Stricter outlier filter (disabled):
        # inErrs = inErrs[np.abs(outErrs - outErrs.mean()) < 7 * outErrs.std()]
        # factors = factors[np.abs(outErrs - outErrs.mean()) < 7 * outErrs.std()]
        # outErrs = outErrs[np.abs(outErrs - outErrs.mean()) < 7 * outErrs.std()]
        try:
            # Rounding disabled in favor of the raw errors:
            # X1 = np.round(inErrs, 2)
            # X2 = np.round(outErrs, 5)
            X1 = inErrs
            X2 = outErrs
            plot_utils.plot_error_density(
                X1, X2, title="Estimation error of {}".format(estimates))
            plot_utils.plot_error_confidence_interval2(
                X1,
                X2,
                title="Estimation error of {}".format(estimates),
                maxerror=0.5)
            plt.show()
            plot_utils.show()
        except Exception as e:
            print("ERROR: Plot failed!", e)
            plot_utils.close()

        results[estimates] = {
            "in": inErrs,
            "out": outErrs,
            "xin": xinErrs,
            "xout": xoutErrs,
            "estimator": bestEstParams,
            "idx": np.array(idxsimparams2),
            "factors": factors,
            "xfactors": xfactors,
            "performance": performance_metrics,
            "xperformance": xperformance_metrics,
        }

    if cachefile is not None:
        with open(cachefile, "wb") as f:
            pickle.dump(results, f)
    return results
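A minimal usage sketch; ALL_PARAMETERS, the estimator type string and the
(label, parameter, factor) layout of simparams are illustrative assumptions
inferred from how the function unpacks its arguments:

# Assuming ALL_PARAMETERS = 'alpha beta', every simparams entry is a
# (label, parameter_name, factor) triple grouped by parameter name.
simparams = [('run0', 'alpha', 1.0), ('run1', 'alpha', 2.0),
             ('run2', 'beta', 0.5)]
results = eval_estimator('ridge', X_data, Y_data, simparams,
                         cachefile='estimates.pkl')
print(results['alpha']['out'])  # out-of-sample estimation errors for 'alpha'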
Example #6
# Softmax comes from torch.nn; alexnet is a pretrained torchvision model
# loaded in an earlier cell of the original notebook.
alexnet.classifier.add_module("softmax", Softmax(dim=1))
alexnet.eval()

# ## Channelwise Regularization/Transformation Comparisons
# Regularization is essential for max mean activation to pick up the required information. In the following, we will examine the effect of different regularizers.
# ### Baseline: No Regularization
# For illustration, we optimize for one channel of the 9th feature layer in alexnet.

# In[3]:

from midnite.visualization.base import *
from plot_utils import show

show(
    PixelActivation(
        alexnet.features[:9],
        SplitSelector(ChannelSplit(), [1]),
    ).visualize())

# ### Weight Decay
# Decays the gradient during optimization, i.e. it causes less relevant parts of the optimized image to vanish.

# In[4]:

show(
    PixelActivation(alexnet.features[:9],
                    SplitSelector(ChannelSplit(), [1]),
                    regularization=[WeightDecay(decay_factor=1e-3)
                                    ]).visualize())

# ### Blur Filter
Example #7
show_normalized(img)


# ### Step 3: Predict Uncertainties

# In[8]:


from plot_utils import show, show_heatmap

with torch.no_grad():
    with midnite.device("cpu"): # GPU device if available, e.g. "cuda:0"!
        pred, pred_entropy, mutual_info = fcn_ensemble(img)

print("Max prediction:")
show(pred.argmax(dim=1))

print("Predictive entropy (total uncertainty):")
show_heatmap(pred_entropy, 1.2)

print("Mutual information (model uncertainty):")
show_heatmap(mutual_info, 1.2)


# ### Interpretation
# 
# If we overlay the original image and the mutual information, we can see which objects the segmentation model was unsure about:

# In[9]:
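# The original cell is truncated here; a minimal sketch of such an overlay,
# assuming img is a 1x3xHxW tensor with values roughly in [0, 1] and
# mutual_info is a 1xHxW uncertainty map, using matplotlib directly:

import matplotlib.pyplot as plt

# Drop batch/channel dimensions and overlay the uncertainty heatmap.
image = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
uncertainty = mutual_info.squeeze().cpu().numpy()
plt.imshow(image)
plt.imshow(uncertainty, cmap='jet', alpha=0.5)
plt.axis('off')
plt.show()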