def perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs, n_iterations,
                                                  batch_size=30,
                                                  seed=((2017, 7, 10))):
    """
    Average correlation between spectral perturbations and layer activation changes

    pert_fn: Called as pert_fn(amps, phases, rng=rng); must return the tuple
             (perturbed amplitudes, perturbed phases, perturbation values).
    diff_fn: Called as diff_fn(original_activations, perturbed_activations) for each layer
             (trailing axis already removed); its result is correlated with the perturbation values.
    pred_fn: Function that returns a list of activations.
             Each entry in the list corresponds to the output of 1 layer in a network
    n_layers: Number of layers pred_fn returns activations for.
    inputs: Original inputs that are used for perturbation [B,X,T,1]
            The spectrum is taken along axis 2 (T).
    n_iterations: Number of perturbation rounds that are averaged. The higher the better
    batch_size: Number of inputs that are used for one forward pass.
    seed: Seed handed to np.random.RandomState.

    Returns a list of length n_layers holding, per layer, the correlations
    summed over all rounds and divided by n_iterations.
    """
    rng = np.random.RandomState(seed)

    # Unshuffled batches; the same split is reused for every forward pass
    batches = get_balanced_batches(
        n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size)

    def _activations_per_layer(data):
        # Run every batch first, then stitch each layer back together along axis 0
        batch_outputs = [pred_fn(data[batch]) for batch in batches]
        return [np.concatenate([out[i_layer] for out in batch_outputs])
                for i_layer in range(n_layers)]

    reference_layers = _activations_per_layer(inputs)

    # Split the input spectrum into amplitudes and phases
    spectrum = np.fft.rfft(inputs, n=inputs.shape[2], axis=2)
    amps = np.abs(spectrum)
    phases = np.angle(spectrum)

    corr_sums = [0] * n_layers
    for _ in range(n_iterations):
        amps_pert, phases_pert, pert_vals = pert_fn(amps, phases, rng=rng)

        # Back to the time domain with the perturbed spectrum
        perturbed_spectrum = amps_pert * np.exp(1j * phases_pert)
        perturbed_inputs = np.fft.irfft(
            perturbed_spectrum, n=inputs.shape[2], axis=2).astype(np.float32)

        perturbed_layers = _activations_per_layer(perturbed_inputs)

        for i_layer, (reference, perturbed) in enumerate(
                zip(reference_layers, perturbed_layers)):
            # Change in feature map activations caused by the perturbation
            activation_change = diff_fn(reference[:, :, :, 0], perturbed[:, :, :, 0])

            # Correlate that change with the perturbation values over axis 0
            corr_sums[i_layer] += wrap_reshape_apply_fn(
                corr, pert_vals[:, :, :, 0], activation_change,
                axis_a=0, axis_b=0)

    # Mean over iterations
    return [layer_sum / n_iterations for layer_sum in corr_sums]
# Example no. 2
# 0
def spectral_perturbation_correlation(pert_fn,
                                      diff_fn,
                                      pred_fn,
                                      n_layers,
                                      inputs,
                                      n_iterations,
                                      batch_size=30,
                                      seed=((2017, 7, 10))):
    """Average correlation between spectral perturbations and activation changes per layer

    Parameters
    ----------
    pert_fn : function
        Called as ``pert_fn(amps, phases, rng=rng)``; must return the tuple
        (perturbed amplitudes, perturbed phases, perturbation values).
    diff_fn : function
        Called as ``diff_fn(perturbed_activations, original_activations)``
        for each layer, with the trailing axis removed.
    pred_fn : function
        Function that returns a list of activations.
        Each entry in the list corresponds to the output of 1 layer in a network
    n_layers : int
        Number of layers pred_fn returns activations for.
    inputs : numpy array
        Original inputs that are used for perturbation [B,X,T,1].
        The spectrum is taken along axis 2 (T).
    n_iterations : int
        Number of perturbation rounds that are averaged. The higher the better
    batch_size : int
        Number of inputs that are used for one forward pass.
    seed
        Seed handed to ``np.random.RandomState``.

    Returns
    -------
    pert_corrs : list
        List of length n_layers containing average perturbation correlations over iterations
        L  x  CxFrxFi (Channels,Frequencies,Filters)
    """
    rng = np.random.RandomState(seed)

    # Unshuffled batches; the same split is reused for every forward pass
    batch_inds = get_balanced_batches(n_trials=len(inputs),
                                      rng=rng,
                                      shuffle=False,
                                      batch_size=batch_size)

    def _padded_shape(layer_output):
        # Pad the layer shape with singleton axes up to 4 dims and set the
        # first axis to the total number of inputs
        dims = list(layer_output.shape)
        dims += [1] * (4 - len(dims))
        dims[0] = len(inputs)
        return dims

    def _stack_layers(batch_outputs, shapes):
        # Join the batches of each layer and bring them to the padded shape
        return [
            np.concatenate([out[i_layer]
                            for out in batch_outputs]).reshape(shapes[i_layer])
            for i_layer in range(n_layers)
        ]

    log.info("Compute original predictions...")
    orig_preds = [pred_fn(inputs[inds]) for inds in batch_inds]
    use_shape = [
        _padded_shape(orig_preds[0][i_layer]) for i_layer in range(n_layers)
    ]
    orig_preds_layers = _stack_layers(orig_preds, use_shape)

    # Split the input spectrum into amplitudes and phases
    spectrum = np.fft.rfft(inputs, n=inputs.shape[2], axis=2)
    amps = np.abs(spectrum)
    phases = np.angle(spectrum)

    corr_sums = [0] * n_layers
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        log.info("Sample perturbation...")
        amps_pert, phases_pert, pert_vals = pert_fn(amps, phases, rng=rng)

        log.info("Compute perturbed complex inputs...")
        perturbed_spectrum = amps_pert * np.exp(1j * phases_pert)
        log.info("Compute perturbed real inputs...")
        perturbed_inputs = np.fft.irfft(perturbed_spectrum,
                                        n=inputs.shape[2],
                                        axis=2).astype(np.float32)

        log.info("Compute new predictions...")
        new_preds = [pred_fn(perturbed_inputs[inds]) for inds in batch_inds]
        new_preds_layers = _stack_layers(new_preds, use_shape)

        for i_layer in range(n_layers):
            log.info("Layer {:d}...".format(i_layer))
            # Change in feature map activations caused by the perturbation
            log.info("Compute activation difference...")
            activation_change = diff_fn(
                new_preds_layers[i_layer][:, :, :, 0],
                orig_preds_layers[i_layer][:, :, :, 0])

            # Correlate that change with the perturbation values over axis 0
            log.info("Compute correlation...")
            corr_sums[i_layer] += wrap_reshape_apply_fn(
                corr,
                pert_vals[:, :, :, 0],
                activation_change,
                axis_a=(0, ),
                axis_b=0)

    # Mean over iterations
    return [layer_sum / n_iterations for layer_sum in corr_sums]
# Example no. 3
# 0
def compute_amplitude_prediction_correlations_voltage(pred_fn,
                                                      examples,
                                                      n_iterations,
                                                      perturb_fn=None,
                                                      batch_size=30,
                                                      seed=((2017, 7, 10))):
    """
    Correlate time-domain (voltage) perturbations with prediction changes.

    Variant of compute_amplitude_prediction_correlations that perturbs the
    raw inputs directly instead of their spectral amplitudes: standard-normal
    noise of the same shape as ``examples`` is added to the inputs, the
    perturbed inputs are pushed through the prediction function, and the
    noise is correlated with the resulting prediction changes.

    For more details, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. Indexed as a
        4-dimensional array whose last axis is dropped (index 0) before the
        correlation is computed.
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Not used by this function (the perturbation is always drawn with
        ``rng.randn``); kept so the signature matches the amplitude variant.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: optional
        Seed passed to ``RandomState``.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration with the correlations between the input
        perturbations and the prediction changes.

    References
    ----------
    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(n_trials=len(examples),
                                          rng=None,
                                          shuffle=False,
                                          batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    rng = RandomState(seed)
    # NOTE: the spectral decomposition (rfft / amplitudes / phases) of the
    # amplitude variant is intentionally absent here; nothing in the
    # time-domain perturbation uses it.

    amp_pred_corrs = []
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        log.info("Sample perturbation...")
        # Perturb the inputs directly in the time domain
        perturbation = rng.randn(*examples.shape)
        new_in = examples + perturbation
        log.info("Compute new predictions...")
        new_in = new_in.astype('float32')
        new_preds = [
            pred_fn(new_in[example_inds]) for example_inds in inds_per_batch
        ]

        new_preds_arr = np.concatenate(new_preds)

        diff_preds = new_preds_arr - orig_preds_arr

        log.info("Compute correlation...")
        # Correlate over the example axis (axis 0 of both arrays)
        amp_pred_corr = wrap_reshape_apply_fn(corr,
                                              perturbation[:, :, :, 0],
                                              diff_preds,
                                              axis_a=(0, ),
                                              axis_b=(0))
        amp_pred_corrs.append(amp_pred_corr)
    return amp_pred_corrs
def compute_amplitude_prediction_correlations(
        pred_fn,
        examples,
        n_iterations,
        perturb_fn=gaussian_perturbation,
        batch_size=30,
        seed=((2017, 7, 10)),
        original_y=None,
):
    """
    Perturb input amplitudes and compute correlation between amplitude
    perturbations and prediction changes when pushing perturbed input through
    the prediction function.

    For more details, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. The spectrum is
        taken along axis 2.
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Function accepting amplitude array and random generator and returning
        perturbation. Default is Gaussian perturbation.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: optional
        Seed passed to ``RandomState``.
    original_y: ndarray, optional
        Labels compared against ``argmax(predictions, axis=1)``. If given,
        accuracies before and after perturbation are logged and returned.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration with the correlations between amplitude
        perturbations and prediction changes
        for all sensors and frequency bins.
    orig_accuracy: float
        Only returned if ``original_y`` is given: accuracy of the original
        predictions.
    new_accuracies: list of float
        Only returned if ``original_y`` is given: accuracy of the perturbed
        predictions, one entry per iteration.

    References
    ----------

    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(n_trials=len(examples),
                                          rng=None,
                                          shuffle=False,
                                          batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    if original_y is not None:
        orig_pred_labels = np.argmax(orig_preds_arr, axis=1)
        orig_accuracy = np.mean(orig_pred_labels == original_y)
        log.info("Original accuracy: {:.2f}...".format(orig_accuracy))
    rng = RandomState(seed)
    # Length of the transformed axis; needed to invert the rfft exactly
    n_times = examples.shape[2]
    fft_input = np.fft.rfft(examples, axis=2).astype(np.complex64)
    amps = np.abs(fft_input).astype(np.float32)
    phases = np.angle(fft_input).astype(np.float32)
    # Free the complex spectrum; only amplitudes and phases are needed
    del fft_input

    amp_pred_corrs = []
    new_accuracies = []
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        log.info("Sample perturbation...")
        perturbation = perturb_fn(amps, rng).astype(np.float32)
        log.info("Compute new amplitudes...")
        # do not allow perturbation to make amplitudes go below
        # zero
        perturbation = np.maximum(-amps, perturbation)
        new_amps = amps + perturbation
        new_amps = new_amps.astype(np.float32)
        log.info("Compute new  complex inputs...")
        new_complex = _amplitude_phase_to_complex(new_amps,
                                                  phases).astype(np.complex64)
        log.info("Compute new real inputs...")
        # Pass n explicitly: without it irfft returns 2*(n_bins-1) samples,
        # which is one sample short for an odd number of time points
        new_in = np.fft.irfft(new_complex, n=n_times,
                              axis=2).astype(np.float32)
        del new_complex, new_amps
        log.info("Compute new predictions...")
        new_preds = [
            pred_fn(new_in[example_inds]) for example_inds in inds_per_batch
        ]

        new_preds_arr = np.concatenate(new_preds)
        if original_y is not None:
            new_pred_labels = np.argmax(new_preds_arr, axis=1)
            new_accuracy = np.mean(new_pred_labels == original_y)
            log.info("New accuracy: {:.2f}...".format(new_accuracy))
            new_accuracies.append(new_accuracy)
        diff_preds = new_preds_arr - orig_preds_arr

        log.info("Compute correlation...")
        # Correlate over the example axis (axis 0 of both arrays)
        amp_pred_corr = wrap_reshape_apply_fn(corr,
                                              perturbation[:, :, :, 0],
                                              diff_preds,
                                              axis_a=(0, ),
                                              axis_b=(0))
        log.info("max corr {}".format(np.max(amp_pred_corr)))
        log.info("min corr {}".format(np.min(amp_pred_corr)))
        amp_pred_corrs.append(amp_pred_corr)
    if original_y is not None:
        return amp_pred_corrs, orig_accuracy, new_accuracies
    else:
        return amp_pred_corrs
def compute_amplitude_prediction_correlations_batchwise(
        pred_fn,
        examples,
        n_iterations,
        perturb_fn=gaussian_perturbation,
        batch_size=30,
        seed=((2017, 7, 10)),
        original_y=None,
):
    """
    Perturb input amplitudes and compute correlation between amplitude
    perturbations and prediction changes when pushing perturbed input through
    the prediction function.

    Batchwise variant: the spectrum, perturbation and perturbed inputs are
    only ever built for one batch at a time. Per-batch covariances, variances
    and means are merged across batches with ``combine_covs`` /
    ``combine_vars``, and the correlation is formed at the end of each
    iteration as covariance divided by the product of standard deviations.

    For more details, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. The spectrum is
        taken along axis 2.
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Function accepting amplitude array and random generator and returning
        perturbation. Default is Gaussian perturbation.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: optional
        Seed passed to ``RandomState``.
    original_y: ndarray, optional
        Labels compared against ``argmax(predictions, axis=1)``. If given,
        accuracies before and after perturbation are logged and returned.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration with the correlations between amplitude
        perturbations and prediction changes
        for all sensors and frequency bins.
    orig_accuracy: float
        Only returned if ``original_y`` is given: accuracy of the original
        predictions.
    new_accuracies: list of float
        Only returned if ``original_y`` is given: accuracy of the perturbed
        predictions, one entry per iteration.

    References
    ----------

    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(n_trials=len(examples),
                                          rng=None,
                                          shuffle=False,
                                          batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    if original_y is not None:
        orig_pred_labels = np.argmax(orig_preds_arr, axis=1)
        orig_accuracy = np.mean(orig_pred_labels == original_y)
        log.info("Original accuracy: {:.2f}...".format(orig_accuracy))
    amp_pred_corrs = []
    new_accuracies = []
    rng = RandomState(seed)
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        # Running statistics over the batches seen so far in this iteration
        size_so_far = 0
        mean_perturb_so_far = None
        mean_pred_diff_so_far = None
        var_perturb_so_far = None
        var_pred_diff_so_far = None
        covariance_so_far = None
        all_new_pred_labels = []
        for example_inds in inds_per_batch:
            n_this_batch = len(example_inds)
            this_orig_preds = orig_preds_arr[example_inds]
            this_examples = examples[example_inds]
            fft_input = np.fft.rfft(this_examples, axis=2).astype(np.complex64)
            amps = np.abs(fft_input).astype(np.float32)
            phases = np.angle(fft_input).astype(np.float32)
            perturbation = perturb_fn(amps, rng).astype(np.float32)
            # do not allow perturbation to make amplitudes go below
            # zero
            perturbation = np.maximum(-amps, perturbation)
            new_amps = amps + perturbation
            new_amps = new_amps.astype(np.float32)
            new_complex = _amplitude_phase_to_complex(new_amps, phases).astype(
                np.complex64)
            # Pass n explicitly: without it irfft returns 2*(n_bins-1)
            # samples, one short for an odd number of time points
            new_in = np.fft.irfft(new_complex,
                                  n=this_examples.shape[2],
                                  axis=2).astype(np.float32)
            new_preds_arr = pred_fn(new_in)
            if original_y is not None:
                new_pred_labels = np.argmax(new_preds_arr, axis=1)
                all_new_pred_labels.append(new_pred_labels)
            diff_preds = new_preds_arr - this_orig_preds
            this_amp_pred_cov = wrap_reshape_apply_fn(cov,
                                                      perturbation[:, :, :, 0],
                                                      diff_preds,
                                                      axis_a=(0, ),
                                                      axis_b=(0))
            var_perturb = np.var(perturbation, axis=0, ddof=1)
            var_pred_diff = np.var(diff_preds, axis=0, ddof=1)
            mean_perturb = np.mean(perturbation, axis=0)
            # Mean over the example axis only, so it matches var_pred_diff
            # (one value per prediction output, not one global scalar)
            mean_diff_pred = np.mean(diff_preds, axis=0)
            if mean_perturb_so_far is None:
                # First batch: nothing to merge with yet
                mean_perturb_so_far = mean_perturb
                mean_pred_diff_so_far = mean_diff_pred
                covariance_so_far = this_amp_pred_cov
                var_perturb_so_far = var_perturb
                var_pred_diff_so_far = var_pred_diff
            else:
                # Merge covariance and variances first: they need the means
                # of the batches seen so far, which are updated afterwards
                covariance_so_far = combine_covs(covariance_so_far,
                                                 size_so_far,
                                                 mean_perturb_so_far,
                                                 mean_pred_diff_so_far,
                                                 this_amp_pred_cov,
                                                 n_this_batch,
                                                 mean_perturb, mean_diff_pred)
                var_perturb_so_far = combine_vars(
                    var_perturb_so_far,
                    size_so_far,
                    mean_perturb_so_far,
                    var_perturb,
                    n_this_batch,
                    mean_perturb,
                )
                var_pred_diff_so_far = combine_vars(
                    var_pred_diff_so_far,
                    size_so_far,
                    mean_pred_diff_so_far,
                    var_pred_diff,
                    n_this_batch,
                    mean_diff_pred,
                )
                # Size-weighted average of old and new means
                next_size = size_so_far + n_this_batch
                mean_perturb_so_far = (
                    (mean_perturb_so_far * size_so_far / float(next_size)) +
                    (mean_perturb * n_this_batch / float(next_size)))
                mean_pred_diff_so_far = (
                    (mean_pred_diff_so_far * size_so_far / float(next_size)) +
                    (mean_diff_pred * n_this_batch / float(next_size)))

            size_so_far += n_this_batch

        # Accuracy bookkeeping only makes sense (and only works) with labels
        if original_y is not None:
            all_new_pred_labels = np.concatenate(all_new_pred_labels)
            assert len(original_y) == len(all_new_pred_labels)
            new_accuracy = np.mean(all_new_pred_labels == original_y)
            log.info("New accuracy: {:.2f}...".format(new_accuracy))
            new_accuracies.append(new_accuracy)
        # Product of standard deviations for every (perturbation entry,
        # prediction output) pair, squeezed to drop singleton axes
        divisor = np.outer(np.sqrt(var_perturb_so_far),
                           np.sqrt(var_pred_diff_so_far)).reshape(
                               (var_perturb_so_far.shape +
                                var_pred_diff_so_far.shape)).squeeze()
        this_amp_pred_corr = covariance_so_far / divisor
        amp_pred_corrs.append(this_amp_pred_corr)
    if original_y is not None:
        return amp_pred_corrs, orig_accuracy, new_accuracies
    else:
        return amp_pred_corrs