Beispiel #1
0
def visualize_filters():
    """Load the pre-trained MNIST convolutional model and plot its filters.

    The weights of ``conv1`` and ``conv2`` are detached, converted to NumPy
    arrays and passed together to ``plot_filter_layer``.
    """
    import os

    # Create models
    conv_net = ConvClassificationModel()

    # Load model. os.path.join keeps the path portable; the previous
    # hard-coded '\\' separators only resolved on Windows.
    model_path = os.path.join('models', 'pre_trained_models',
                              'mnist_digit_conv.model')
    conv_net.load(torch.load(model_path))

    # Plot the filters of both convolution layers
    filters = [
        conv_net.conv1.weight.detach().numpy(),
        conv_net.conv2.weight.detach().numpy()
    ]
    plot_filter_layer(filters)
def visualize_adv_affects_accuracy():
    """Plot the conv model's MNIST test accuracy under an adversarial attack.

    The attack is chosen by the ``attack_name`` setting below ('FGSM' or
    'OnePixel'). The unattacked accuracy is prepended at x = 0 so the curve
    starts from the clean baseline.

    Raises:
        NotImplementedError: if ``attack_name`` is not a supported attack.
    """
    import os

    # Hyper parameters
    seed = 2
    batch_size = 1
    data_name = 'MNIST'
    attack_name = 'OnePixel'

    # Resolve the attack up front so an unknown name fails before the
    # expensive data download, model load and baseline evaluation.
    if attack_name == 'FGSM':
        axis_title = 'Epsilon'
        evaluate_attack = evaluate_fgsm_attack
    elif attack_name == 'OnePixel':
        axis_title = 'Number of Pixels'
        evaluate_attack = evaluate_onepixel_attack
    else:
        raise NotImplementedError('Unknown attack type: {0}. Available attacks; [FGSM, OnePixel]'.format(attack_name))

    # Download MNIST data set
    test_set = get_data(data_name, False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size, shuffle=True)

    # Create model
    conv_net = ConvClassificationModel()

    # Load model; os.path.join instead of hard-coded '\\' separators, which
    # only resolve on Windows
    model_path = os.path.join('models', 'pre_trained_models', 'mnist_digit_conv.model')
    conv_net.load(torch.load(model_path))

    # Get base accuracy
    set_seed(seed)
    base_conv_acc = conv_net.eval_model(test_loader)

    # Evaluate the selected attack
    loss_func = torch.nn.NLLLoss()
    conv_x, conv_acc = evaluate_attack(conv_net, test_loader, loss_func, seed)

    # Insert base accuracy: x = 0 is the unattacked model
    conv_x.insert(0, 0)
    conv_acc.insert(0, base_conv_acc)

    # Plot results
    # NOTE(review): visualize_noisy_affects_accuracy passes a list of curve
    # labels as plot_accuracy's second argument, while a single string is
    # passed here -- confirm plot_accuracy accepts both forms.
    data = [(conv_x, conv_acc)]
    plot_accuracy(data, attack_name, axis_title)
Beispiel #3
0
def visualize_noisy_affects_accuracy():
    """Plot the conv model's MNIST test accuracy as image noise increases.

    Salt-and-pepper ('snp') and gray Gaussian ('gaussian_gray') noise are
    evaluated with ``evaluate_noise``. Each curve is prefixed with the
    clean-data accuracy at x = 0, and both are drawn in one plot.
    """
    import os

    # Hyper parameters
    seed = 2
    batch_size = 64
    data_name = 'MNIST'
    # (noise type passed to evaluate_noise, legend label for its curve);
    # kept together so the labels cannot drift out of order.
    noise_types = [
        ('snp', 'Salt-And-Pepper Noise'),
        ('gaussian_gray', 'Gaussian Noise'),
    ]

    # Download MNIST data set
    test_set = get_data(data_name, False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size,
                                              shuffle=True)

    conv_net = ConvClassificationModel()

    # Load model; os.path.join instead of hard-coded '\\' separators, which
    # only resolve on Windows
    model_path = os.path.join('models', 'pre_trained_models',
                              'mnist_digit_conv.model')
    conv_net.load(torch.load(model_path))

    # Get base accuracy
    set_seed(seed)
    base_conv_acc = conv_net.eval_model(test_loader)

    data = []
    for noise_type, _ in noise_types:
        # Evaluate noise on the same data set as the base accuracy (the
        # previous code passed a separate 'MNIST' literal here)
        conv_percent, conv_acc = evaluate_noise(conv_net, data_name,
                                                noise_type, batch_size, seed)

        # Insert base accuracy: 0 % noise corresponds to clean data
        conv_percent.insert(0, 0)
        conv_acc.insert(0, base_conv_acc)

        data.append((conv_percent, conv_acc))

    # Plot results
    plot_accuracy(data, [label for _, label in noise_types])
def train_mnist_digit_models():
    """Train and save the non-convolutional and convolutional MNIST models.

    Both models use the same hyper-parameters and are trained after
    re-seeding with the same seed. Each is saved under
    ``models/pre_trained_models``. ``run_training`` receives both a training
    loader and a test-set loader.
    """
    import os

    # Hyper parameters
    seed = 2
    num_epochs = 5
    learning_rate = 0.01
    momentum = 0.5
    batch_size = 64
    data_name = 'MNIST'

    # Download MNIST data sets
    train_set = get_data(data_name, True)
    # Previously `test_loader` was used below without ever being defined
    # (NameError); build it from the test split here.
    test_set = get_data(data_name, False)

    # MNIST digit dataset values
    input_size = np.prod(train_set.data.shape[1:])
    output_size = len(train_set.classes)

    # Use torch data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size,
                                               shuffle=True)
    # shuffle=True matches the test loaders built elsewhere in this file
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size,
                                              shuffle=True)

    # os.path.join instead of hard-coded '\\' separators, which only resolve
    # on Windows
    model_dir = os.path.join('models', 'pre_trained_models')

    # Train and save models
    set_seed(seed)
    print('Training Non-Convolutional Classification Network...')
    nonconv_net = NonConvClassificationModel(input_size, output_size)
    nonconv_net = run_training(nonconv_net, learning_rate, momentum,
                               num_epochs, train_loader, test_loader)
    torch.save(nonconv_net.save(),
               os.path.join(model_dir, 'mnist_digit_nonconv.model'))

    set_seed(seed)
    print('Training Convolutional Classification Network...')
    conv_net = ConvClassificationModel()
    conv_net = run_training(conv_net, learning_rate, momentum, num_epochs,
                            train_loader, test_loader)
    torch.save(conv_net.save(),
               os.path.join(model_dir, 'mnist_digit_conv.model'))
def search_for_filter_match():
    """Find training images whose layer-2 filter outputs best match a test image.

    For test image ``noisy_image_index``, the ``(2, 'Filter')`` output of
    ``get_plots`` is computed for the clean image and for its saved noisy
    version. Each training image is scored against a target by the mean
    ``calculate_similarity`` over feature maps. Training images whose label
    equals the model's prediction on the noisy image are scored against the
    noisy image. Those whose label equals the prediction on the clean image
    are scored against the clean image; an image can qualify for both. The
    best match for each target is kept.

    Two figures are shown:
    - the per-feature-map comparison, and
    - the raw clean, noisy and best-matching training images.

    Raises:
        RuntimeError: if no training image carries the predicted clean or
            noisy label, so there is nothing to compare against.
    """
    import os

    def _add_heatmap(fig, z, row, col, axis_num):
        """Add ``z`` as a heatmap on the shared coloraxis at (row, col).

        The subplot's x axis is scale-anchored to y axis ``axis_num``, and
        grid lines, tick labels and zero lines are hidden on both axes.
        """
        fig.add_trace(go.Heatmap(z=z,
                                 type='heatmap',
                                 coloraxis='coloraxis',
                                 showscale=False),
                      row=row,
                      col=col)
        fig.update_xaxes(showgrid=False,
                         showticklabels=False,
                         zeroline=False,
                         scaleanchor='y{}'.format(axis_num),
                         row=row,
                         col=col)
        fig.update_yaxes(showgrid=False,
                         showticklabels=False,
                         zeroline=False,
                         row=row,
                         col=col)

    # Parameters
    seed = 2
    data_name = 'MNIST'
    noise_type = 'snp'
    noisy_image_index = 118

    # Create models
    set_seed(seed)
    conv_net = ConvClassificationModel()

    # Load models; os.path.join instead of hard-coded '\\' separators, which
    # only resolve on Windows
    conv_net.load(
        torch.load(
            os.path.join('models', 'pre_trained_models',
                         'mnist_digit_conv.model')))

    # Get test data set (the ground-truth label is not needed here)
    test_set = get_data(data_name, False)
    clean_data, _ = test_set[noisy_image_index]
    clean_data = clean_data.detach().numpy()[0, :, :]

    # Get second feature representation of the clean image
    clean_data_torch = torch.from_numpy(clean_data)
    clean_data_torch = clean_data_torch.view(1, 1,
                                             *clean_data_torch.shape).float()
    clean_label = conv_net.get_label(clean_data_torch).detach().numpy()[0][0]
    clean_second_feature_set = get_plots(clean_data_torch,
                                         conv_net)[(2, 'Filter')][0]

    # Read noisy image file and undo its vertical flip (presumably it was
    # saved flipped for plotting -- confirm against the writer)
    filename = os.path.join('data', data_name, 'noisy_data_np',
                            '{}_{}.npy'.format(noise_type, noisy_image_index))
    noisy_data = np.flipud(np.load(filename)).copy()

    # Get second feature representation of the noisy image
    noisy_data_torch = torch.from_numpy(noisy_data)
    noisy_data_torch = noisy_data_torch.view(1, 1,
                                             *noisy_data_torch.shape).float()
    noisy_label = conv_net.get_label(noisy_data_torch).detach().numpy()[0][0]
    noisy_second_feature_set = get_plots(noisy_data_torch,
                                         conv_net)[(2, 'Filter')][0]

    # All feature sets come from the same layer of the same network, so
    # they share this count. Defining it here (instead of inside the loop)
    # keeps it valid even when no training image matches.
    num_features = noisy_second_feature_set.shape[0]

    def _mean_similarity(train_features, target_features):
        """Return calculate_similarity averaged over all feature maps."""
        total = 0.0
        for j in range(num_features):
            total += calculate_similarity(train_features[j, :, :],
                                          target_features[j, :, :])
        return total / float(num_features)

    # Get training data set
    train_set = get_data(data_name, True)

    # Best match per target: (similarity, train index, train image,
    # train feature set)
    sim_feature_clean = None
    sim_feature_noisy = None
    max_sim_clean = -float('inf')
    max_sim_noisy = -float('inf')
    for i in range(len(train_set)):
        train_data, train_label = train_set[i]

        # Independent checks, not if/elif: when the noise does not change
        # the prediction (noisy_label == clean_label), an image must be
        # scored against BOTH targets. Otherwise the clean match was never
        # searched and plotting crashed.
        matches_noisy = train_label == noisy_label
        matches_clean = train_label == clean_label
        if not (matches_noisy or matches_clean):
            continue

        # Compare second feature set
        train_data_torch = train_data.view(1, *train_data.shape).float()
        train_second_feature_set = get_plots(train_data_torch,
                                             conv_net)[(2, 'Filter')][0]

        if matches_noisy:
            similarity = _mean_similarity(train_second_feature_set,
                                          noisy_second_feature_set)
            if similarity > max_sim_noisy:
                sim_feature_noisy = (similarity, i, train_data,
                                     train_second_feature_set)
                max_sim_noisy = similarity

        if matches_clean:
            similarity = _mean_similarity(train_second_feature_set,
                                          clean_second_feature_set)
            if similarity > max_sim_clean:
                sim_feature_clean = (similarity, i, train_data,
                                     train_second_feature_set)
                max_sim_clean = similarity

    if sim_feature_clean is None or sim_feature_noisy is None:
        raise RuntimeError(
            'No training image matched the clean label {} and/or the noisy '
            'label {}'.format(clean_label, noisy_label))

    filter_fig = make_subplots(
        rows=num_features,
        cols=4,
        subplot_titles=[
            'Similar Clean<br>Class Training<br>Image', 'Clean Image',
            'Noisy Image', 'Similar Noisy<br>Class Training<br>Image'
        ] + [''] * ((num_features - 1) * 4))

    for i in range(num_features):
        # Column order must follow the subplot titles above
        row_plots = [
            sim_feature_clean[3][i, :, :],
            clean_second_feature_set[i, :, :],
            noisy_second_feature_set[i, :, :],
            sim_feature_noisy[3][i, :, :],
        ]
        for col, plot in enumerate(row_plots, start=1):
            # Subplot axes are numbered row-major starting at 1
            _add_heatmap(filter_fig, np.flipud(plot), i + 1, col,
                         (i * 4) + col)

    filter_fig.update_layout(autosize=False,
                             height=300 * num_features,
                             coloraxis={'colorscale': 'Gray'})

    filter_fig.show()

    raw_fig = make_subplots(rows=2,
                            cols=2,
                            subplot_titles=[
                                'Clean Image', 'Noisy Image',
                                'Similar Clean<br>Class Training<br>Image',
                                'Similar Noisy<br>Class Training<br>Image'
                            ])

    # Row-major order matching the subplot titles above
    raw_plots = [
        clean_data,
        noisy_data,
        sim_feature_clean[2].detach().numpy()[0, :, :],
        sim_feature_noisy[2].detach().numpy()[0, :, :],
    ]
    for k, plot in enumerate(raw_plots):
        row, col = divmod(k, 2)
        _add_heatmap(raw_fig, np.flipud(plot), row + 1, col + 1, k + 1)

    raw_fig.update_layout(autosize=False,
                          coloraxis={
                              'colorscale': 'Gray',
                              'showscale': False
                          })

    raw_fig.show()
    print('Done')
Beispiel #6
0
def visualize_adv_affects_filter():
    """Serve a Dash app comparing clean and FGSM-perturbed MNIST test images.

    Every test image is attacked with ``FGSM(0.1)``. The dropdown offers
    images the model classifies correctly when clean but incorrectly after
    the attack, or every image when ``show_all`` is True. For the selected
    image the app shows:
    - the raw clean and adversarial images, with predicted labels and
      rounded ``exp(classify(...))`` values;
    - the chosen layer's ``get_plots`` outputs for both images, plus their
      difference.

    Selected images whose index is in ``save_indices`` are also written to
    ``data/<data_name>/adv_data_np`` as ``.npy`` files.

    Raises:
        RuntimeError: if no test image is selected for display.
    """
    import os

    def _add_heatmap(fig, z, row, col, axis_num):
        """Add ``z`` as a heatmap on the shared coloraxis at (row, col).

        The subplot's x axis is scale-anchored to y axis ``axis_num``, and
        grid lines, tick labels and zero lines are hidden on both axes.
        """
        fig.add_trace(
            go.Heatmap(
                z=z,
                type='heatmap',
                coloraxis='coloraxis',
                showscale=False
            ),
            row=row,
            col=col
        )
        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False, scaleanchor='y{}'.format(axis_num), row=row, col=col)
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False, row=row, col=col)

    # Parameters
    seed = 2
    data_name = 'MNIST'
    show_all = False

    # Set seed
    set_seed(seed)

    # Test-set indices whose clean/adversarial arrays are also saved to disk
    save_indices = [20]

    # Create attack
    attack = FGSM(0.1)

    # Create models
    conv_net = ConvClassificationModel()

    # Load models; os.path.join instead of hard-coded '\\' separators, which
    # only resolve on Windows
    conv_net.load(torch.load(os.path.join('models', 'pre_trained_models', 'mnist_digit_conv.model')))

    # Download MNIST data set and build its adversarial counterpart
    loss_func = torch.nn.NLLLoss()
    test_set = get_data(data_name, False)
    test_set_n, test_set, test_set_labels = attack.get_modified_data(conv_net, test_set, loss_func)
    save_dir = os.path.join('data', data_name, 'adv_data_np')

    # Pre-compute values
    raw_results = {}
    filter_results = {}
    for i in range(len(test_set)):
        clean_data_torch = test_set[i]
        original_label = test_set_labels[i]
        clean_label = conv_net.get_label(clean_data_torch).detach().numpy()[0][0]
        # exp() of classify's output -- presumably log-probabilities, since
        # the model is paired with NLLLoss; TODO confirm
        clean_p_dist = np.round(np.exp(conv_net.classify(clean_data_torch).detach().numpy()[0]), decimals=2)
        clean_data_np = np.flipud(clean_data_torch.detach().numpy()[0, 0, :, :])

        adv_data_torch = test_set_n[i]
        adv_label = conv_net.get_label(adv_data_torch).detach().numpy()[0][0]
        adv_p_dist = np.round(np.exp(conv_net.classify(adv_data_torch).detach().numpy()[0]), decimals=2)
        adv_data_np = np.flipud(adv_data_torch.detach().numpy()[0, 0, :, :])

        # Correct on the clean image but fooled by the adversarial one
        label_mismatch = original_label == clean_label and original_label != adv_label

        if show_all or label_mismatch:
            raw_results[i] = (clean_data_np, adv_data_np, original_label, clean_label, adv_label, clean_p_dist, adv_p_dist)
            filter_results[i] = (get_plots(clean_data_torch, conv_net), get_plots(adv_data_torch, conv_net))

            if i in save_indices:
                # Dump output (np.save appends the '.npy' extension)
                np.save(os.path.join(save_dir, '{}_{}_adv'.format('FGSM', i)), adv_data_np)
                np.save(os.path.join(save_dir, '{}_{}_clean'.format('FGSM', i)), clean_data_np)

    # The dropdown below needs an initial value; fail clearly instead of
    # with an IndexError when nothing was selected.
    if not raw_results:
        raise RuntimeError('No test images selected for display')

    # Create dash app
    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
    app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

    app.layout = html.Div([
        html.Div([
            html.Div([
                html.Label('Test Set Image Index:'),
                dcc.Dropdown(
                    id='data_set-index',
                    options=[{'label': i, 'value': i} for i in raw_results],
                    value=next(iter(raw_results))
                )
            ],
            style={'width': '33%', 'display': 'inline-block', 'vertical-align': 'top'}),
            html.Div([
                html.Label('Network Layer:'),
                dcc.Dropdown(
                    id='filter-index',
                    options=[{'label': i, 'value': i} for i in [1, 2]],
                    value=1
                ),
            ],
            style={'width': '33%', 'display': 'inline-block', 'vertical-align': 'top'}),
            html.Div([
                html.Label('Network Layer Type:'),
                dcc.RadioItems(
                    id='display-type',
                    options=[{'label': i, 'value': i} for i in ['Filter', 'Activation']],
                    value='Filter',
                    labelStyle={'display': 'inline-block'}
                )
            ],
            style={'width': '33%', 'display': 'inline-block', 'vertical-align': 'top'})
        ]),
        html.Div([
            html.Div([
                html.Div(id='original-label'),
                html.Div(id='clean-label'),
                html.Div(id='adv-label'),
                dcc.Graph(id='raw-graph')
            ],
            style={'width': '35%', 'display': 'inline-block', 'vertical-align': 'top'}),

            html.Div([
                dcc.Graph(id='filter-graph')
            ],
            style={'width': '60%', 'display': 'inline-block', 'vertical-align': 'top'})
        ])
    ])

    @app.callback(
        Output('filter-graph', 'figure'),
        [Input('data_set-index', 'value'),
         Input('display-type', 'value'),
         Input('filter-index', 'value')])
    def update_graph(selected_index, display_type, filter_index):
        """Build the clean / adversarial / difference grid for one layer."""
        clean_plot_set, adv_plot_set = filter_results[selected_index]

        clean_plots, labels = clean_plot_set[(filter_index, display_type)]
        adv_plots, _ = adv_plot_set[(filter_index, display_type)]

        num_rows = clean_plots.shape[0]
        fig = make_subplots(
            rows=num_rows,
            cols=3,
            subplot_titles=['{} on {}'.format(t, l) for l in labels for t in ['Clean Image', 'Adversarial Image', 'Difference']])

        for i in range(num_rows):
            clean_plot = np.flipud(clean_plots[i, :, :])
            adv_plot = np.flipud(adv_plots[i, :, :])
            # Column order must follow the subplot titles above
            row_plots = [clean_plot, adv_plot, clean_plot - adv_plot]
            for col, plot in enumerate(row_plots, start=1):
                # Subplot axes are numbered row-major starting at 1
                _add_heatmap(fig, plot, i + 1, col, (i * 3) + col)

        fig.update_layout(
            autosize=False,
            height=300 * num_rows,
            coloraxis={
                'colorscale': 'Gray'
            }
        )

        return fig

    # Renamed from a second `update_graph`, which shadowed the callback above
    @app.callback(
        [Output('raw-graph', 'figure'),
         Output('original-label', 'children'),
         Output('clean-label', 'children'),
         Output('adv-label', 'children')],
        [Input('data_set-index', 'value')])
    def update_raw_graph(selected_index):
        """Build the raw clean/adversarial figure and the label descriptions."""
        clean_data_np, adv_data_np, original_label, clean_label, adv_label, clean_p_dist, adv_p_dist = raw_results[selected_index]

        fig = make_subplots(
            rows=1,
            cols=2,
            horizontal_spacing=0.1,
            vertical_spacing=0.1,
            subplot_titles=['Clean Image', 'Adversarial Image'])

        _add_heatmap(fig, clean_data_np, 1, 1, 1)
        _add_heatmap(fig, adv_data_np, 1, 2, 2)

        fig.update_layout(
            autosize=False,
            coloraxis={
                'colorscale': 'Gray',
                'showscale': False
            }
        )

        clean_desc = 'Clean Label: {}, Clean Label Probabilities: {}'.format(clean_label, clean_p_dist)
        adv_desc = 'Adversarial Label: {}, Adversarial Label Probabilities: {}'.format(adv_label, adv_p_dist)

        return fig, 'Original Label: {}'.format(original_label), clean_desc, adv_desc

    # NOTE(review): with debug=True the Flask reloader is typically enabled,
    # which restarts the process and repeats the precomputation above --
    # confirm this is intended.
    app.run_server(port=8051, debug=True)