def perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs,
                             n_iterations, batch_size=30,
                             seed=((2017, 7, 10))):
    """Correlate spectral input perturbations with per-layer activation changes.

    Parameters
    ----------
    pert_fn : function
        Called as ``pert_fn(amps, phases, rng=rng)`` and must return the
        tuple ``(perturbed_amps, perturbed_phases, perturbation_values)``.
    diff_fn : function
        Called as ``diff_fn(original_activations, perturbed_activations)``
        (original first) and returns the activation change that is
        correlated with the perturbation values.
    pred_fn : function
        Returns a list of activations; entry ``i`` is the output of layer
        ``i`` of the network.
    n_layers : int
        Number of layers ``pred_fn`` returns activations for.
    inputs : numpy array
        Original inputs [B,X,T,1]. The real FFT is taken along axis 2.
    n_iterations : int
        Number of perturbation rounds that are averaged.
    batch_size : int
        Number of inputs per forward pass (outputs are concatenated).
    seed : int or sequence of int
        Seed for ``np.random.RandomState``.

    Returns
    -------
    list
        One entry per layer: the correlation returned by
        ``wrap_reshape_apply_fn(corr, ...)`` averaged over iterations.
    """
    rng = np.random.RandomState(seed)
    batch_inds = get_balanced_batches(
        n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size)

    def _activations_per_layer(data):
        # Forward every batch, then stitch batch outputs together per layer.
        per_batch = [pred_fn(data[inds]) for inds in batch_inds]
        return [
            np.concatenate([batch_out[i_layer] for batch_out in per_batch])
            for i_layer in range(n_layers)
        ]

    orig_layers = _activations_per_layer(inputs)

    n_times = inputs.shape[2]
    spectrum = np.fft.rfft(inputs, n=n_times, axis=2)
    amplitudes = np.abs(spectrum)
    phase_angles = np.angle(spectrum)

    corr_sums = [0] * n_layers
    for _ in range(n_iterations):
        amps_pert, phases_pert, pert_vals = pert_fn(
            amplitudes, phase_angles, rng=rng)
        # Rebuild time-domain inputs from the perturbed spectrum.
        perturbed = np.fft.irfft(
            amps_pert * np.exp(1j * phases_pert), n=n_times,
            axis=2).astype(np.float32)
        perturbed_layers = _activations_per_layer(perturbed)
        for i_layer, (orig_act, new_act) in enumerate(
                zip(orig_layers, perturbed_layers)):
            act_change = diff_fn(orig_act[:, :, :, 0], new_act[:, :, :, 0])
            # Correlate activation changes with the perturbation values
            # across the example axis.
            corr_sums[i_layer] += wrap_reshape_apply_fn(
                corr, pert_vals[:, :, :, 0], act_change, axis_a=0, axis_b=0)

    # Mean over iterations.
    return [total / n_iterations for total in corr_sums]
def spectral_perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers,
                                      inputs, n_iterations, batch_size=30,
                                      seed=((2017, 7, 10))):
    """Correlate amplitude/phase perturbations with per-layer activation changes.

    Parameters
    ----------
    pert_fn : function
        Called as ``pert_fn(amps, phases, rng=rng)``; returns
        ``(perturbed_amps, perturbed_phases, perturbation_values)``.
    diff_fn : function
        Called as ``diff_fn(perturbed_activations, original_activations)``
        (perturbed first) and returns their difference.
    pred_fn : function
        Returns a list of activations; entry ``i`` is the output of layer
        ``i`` of the network.
    n_layers : int
        Number of layers ``pred_fn`` returns activations for.
    inputs : numpy array
        Original inputs [B,X,T,1]; the real FFT is taken along axis 2.
    n_iterations : int
        Number of perturbation rounds averaged together.
    batch_size : int
        Number of inputs per forward pass (outputs are concatenated).
    seed : int or sequence of int
        Seed for ``np.random.RandomState``.

    Returns
    -------
    pert_corrs : list
        Length ``n_layers``; each entry is the perturbation correlation
        averaged over iterations (L x CxFrxFi: Channels, Frequencies,
        Filters).
    """
    rng = np.random.RandomState(seed)
    batch_inds = get_balanced_batches(n_trials=len(inputs), rng=rng,
                                      shuffle=False, batch_size=batch_size)

    log.info("Compute original predictions...")
    orig_batches = [pred_fn(inputs[inds]) for inds in batch_inds]

    # Pad every layer's activation shape to 4 dims and make the first axis
    # span all inputs so per-batch outputs can be stacked and reshaped.
    layer_shapes = []
    for i_layer in range(n_layers):
        shape = list(orig_batches[0][i_layer].shape)
        shape += [1] * (4 - len(shape))
        shape[0] = len(inputs)
        layer_shapes.append(shape)

    def _stack_layers(per_batch):
        # Concatenate batch outputs layer by layer, reshaped to 4d.
        stacked = []
        for i_layer in range(n_layers):
            joined = np.concatenate([batch[i_layer] for batch in per_batch])
            stacked.append(joined.reshape(layer_shapes[i_layer]))
        return stacked

    orig_layers = _stack_layers(orig_batches)

    n_times = inputs.shape[2]
    spectrum = np.fft.rfft(inputs, n=n_times, axis=2)
    amplitudes = np.abs(spectrum)
    phase_angles = np.angle(spectrum)

    corr_sums = [0] * n_layers
    for i_iter in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iter))
        log.info("Sample perturbation...")
        amps_pert, phases_pert, pert_vals = pert_fn(
            amplitudes, phase_angles, rng=rng)
        log.info("Compute perturbed complex inputs...")
        perturbed_spectrum = amps_pert * np.exp(1j * phases_pert)
        log.info("Compute perturbed real inputs...")
        perturbed_inputs = np.fft.irfft(
            perturbed_spectrum, n=n_times, axis=2).astype(np.float32)
        log.info("Compute new predictions...")
        new_layers = _stack_layers(
            [pred_fn(perturbed_inputs[inds]) for inds in batch_inds])
        for i_layer in range(n_layers):
            log.info("Layer {:d}...".format(i_layer))
            log.info("Compute activation difference...")
            act_diff = diff_fn(new_layers[i_layer][:, :, :, 0],
                               orig_layers[i_layer][:, :, :, 0])
            log.info("Compute correlation...")
            corr_sums[i_layer] += wrap_reshape_apply_fn(
                corr, pert_vals[:, :, :, 0], act_diff,
                axis_a=(0, ), axis_b=(0))

    # Mean over iterations.
    return [total / n_iterations for total in corr_sums]
def compute_amplitude_prediction_correlations_voltage(pred_fn, examples,
                                                      n_iterations,
                                                      perturb_fn=None,
                                                      batch_size=30,
                                                      seed=((2017, 7, 10))):
    """
    Perturb inputs directly in the time (voltage) domain and compute the
    correlation between the perturbations and the resulting prediction
    changes.

    Variant of compute_amplitude_prediction_correlations that perturbs the
    raw input values instead of spectral amplitudes. For more details on
    the method, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. Indexed as a 4d
        array (``perturbation[:, :, :, 0]``).
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Called as ``perturb_fn(examples, rng)``; must return an additive
        perturbation with the same shape as ``examples``. If None (default),
        standard Gaussian noise ``rng.randn(*examples.shape)`` is used.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: int or sequence of int, optional
        Random generator seed.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration: correlations between input perturbations
        and prediction changes.

    References
    ----------

    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., ... &
       Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(
        n_trials=len(examples), rng=None, shuffle=False, batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    rng = RandomState(seed)
    # No spectral decomposition is needed here: the perturbation is added
    # to the raw input values.
    amp_pred_corrs = []
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        log.info("Sample perturbation...")
        if perturb_fn is None:
            perturbation = rng.randn(*examples.shape)
        else:
            perturbation = np.asarray(perturb_fn(examples, rng))
        new_in = (examples + perturbation).astype('float32')
        log.info("Compute new predictions...")
        new_preds = [
            pred_fn(new_in[example_inds]) for example_inds in inds_per_batch
        ]
        new_preds_arr = np.concatenate(new_preds)

        diff_preds = new_preds_arr - orig_preds_arr

        log.info("Compute correlation...")
        amp_pred_corr = wrap_reshape_apply_fn(corr, perturbation[:, :, :, 0],
                                              diff_preds,
                                              axis_a=(0, ), axis_b=(0))
        amp_pred_corrs.append(amp_pred_corr)
    return amp_pred_corrs
def compute_amplitude_prediction_correlations(
    pred_fn, examples, n_iterations, perturb_fn=gaussian_perturbation,
    batch_size=30, seed=((2017, 7, 10)), original_y=None,
):
    """
    Perturb input amplitudes and compute correlation between amplitude
    perturbations and prediction changes when pushing perturbed input through
    the prediction function.

    For more details, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. The real FFT is
        taken along axis 2 and the array is indexed as 4d.
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Function accepting amplitude array and random generator and returning
        perturbation. Default is Gaussian perturbation.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: int or sequence of int, optional
        Random generator seed.
    original_y: ndarray, optional
        Labels of ``examples``. If given, accuracies of the original and
        perturbed predictions (argmax over axis 1) are computed as well.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration: correlations between amplitude perturbations
        and prediction changes for all sensors and frequency bins.
    orig_accuracy: float
        Only returned if ``original_y`` is given.
    new_accuracies: list of float
        Only returned if ``original_y`` is given; one per iteration.

    References
    ----------

    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., ... &
       Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(
        n_trials=len(examples), rng=None, shuffle=False, batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    if original_y is not None:
        orig_pred_labels = np.argmax(orig_preds_arr, axis=1)
        orig_accuracy = np.mean(orig_pred_labels == original_y)
        log.info("Original accuracy: {:.2f}...".format(orig_accuracy))
    rng = RandomState(seed)
    # Passed to irfft explicitly: without it, irfft returns 2 * (n_bins - 1)
    # samples, which is one sample short when the input length is odd.
    n_times = examples.shape[2]
    fft_input = np.fft.rfft(examples, axis=2).astype(np.complex64)
    amps = np.abs(fft_input).astype(np.float32)
    phases = np.angle(fft_input).astype(np.float32)
    del fft_input

    amp_pred_corrs = []
    new_accuracies = []
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        log.info("Sample perturbation...")
        perturbation = perturb_fn(amps, rng).astype(np.float32)
        log.info("Compute new amplitudes...")
        # do not allow perturbation to make amplitudes go below
        # zero
        perturbation = np.maximum(-amps, perturbation)
        new_amps = (amps + perturbation).astype(np.float32)
        log.info("Compute new complex inputs...")
        new_complex = _amplitude_phase_to_complex(new_amps,
                                                  phases).astype(np.complex64)
        log.info("Compute new real inputs...")
        new_in = np.fft.irfft(new_complex, n=n_times,
                              axis=2).astype(np.float32)
        del new_complex, new_amps
        log.info("Compute new predictions...")
        new_preds = [
            pred_fn(new_in[example_inds]) for example_inds in inds_per_batch
        ]
        new_preds_arr = np.concatenate(new_preds)
        if original_y is not None:
            new_pred_labels = np.argmax(new_preds_arr, axis=1)
            new_accuracy = np.mean(new_pred_labels == original_y)
            log.info("New accuracy: {:.2f}...".format(new_accuracy))
            new_accuracies.append(new_accuracy)

        diff_preds = new_preds_arr - orig_preds_arr

        log.info("Compute correlation...")
        amp_pred_corr = wrap_reshape_apply_fn(corr, perturbation[:, :, :, 0],
                                              diff_preds,
                                              axis_a=(0, ), axis_b=(0))
        # Route diagnostics through the module logger instead of stdout.
        log.info("max corr {:f}, min corr {:f}".format(
            np.max(amp_pred_corr), np.min(amp_pred_corr)))
        amp_pred_corrs.append(amp_pred_corr)
    if original_y is not None:
        return amp_pred_corrs, orig_accuracy, new_accuracies
    else:
        return amp_pred_corrs
def compute_amplitude_prediction_correlations_batchwise(
    pred_fn, examples, n_iterations, perturb_fn=gaussian_perturbation,
    batch_size=30, seed=((2017, 7, 10)), original_y=None,
):
    """
    Perturb input amplitudes and compute correlation between amplitude
    perturbations and prediction changes when pushing perturbed input through
    the prediction function.

    Unlike compute_amplitude_prediction_correlations, the FFT and the
    perturbation are computed batch by batch. Per-batch covariances,
    variances and means are merged incrementally (via ``combine_covs`` and
    ``combine_vars``) and turned into correlations once per iteration.

    For more details, see [EEGDeepLearning]_.

    Parameters
    ----------
    pred_fn: function
        Function accepting an numpy input and returning prediction.
    examples: ndarray
        Numpy examples, first axis should be example axis. The real FFT is
        taken along axis 2 and the array is indexed as 4d.
    n_iterations: int
        Number of iterations to compute.
    perturb_fn: function, optional
        Function accepting amplitude array and random generator and returning
        perturbation. Default is Gaussian perturbation.
    batch_size: int, optional
        Batch size for computing predictions.
    seed: int or sequence of int, optional
        Random generator seed.
    original_y: ndarray, optional
        Labels of ``examples``. If given, accuracies of the original and
        perturbed predictions (argmax over axis 1) are computed as well.

    Returns
    -------
    amplitude_pred_corrs: list of ndarray
        One entry per iteration: correlations between amplitude perturbations
        and prediction changes for all sensors and frequency bins.
    orig_accuracy: float
        Only returned if ``original_y`` is given.
    new_accuracies: list of float
        Only returned if ``original_y`` is given; one per iteration.

    References
    ----------

    .. [EEGDeepLearning] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., ... &
       Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       arXiv preprint arXiv:1703.05051.
    """
    inds_per_batch = get_balanced_batches(
        n_trials=len(examples), rng=None, shuffle=False, batch_size=batch_size)
    log.info("Compute original predictions...")
    orig_preds = [
        pred_fn(examples[example_inds]) for example_inds in inds_per_batch
    ]
    orig_preds_arr = np.concatenate(orig_preds)
    if original_y is not None:
        orig_pred_labels = np.argmax(orig_preds_arr, axis=1)
        orig_accuracy = np.mean(orig_pred_labels == original_y)
        log.info("Original accuracy: {:.2f}...".format(orig_accuracy))
    amp_pred_corrs = []
    new_accuracies = []
    rng = RandomState(seed)
    for i_iteration in range(n_iterations):
        log.info("Iteration {:d}...".format(i_iteration))
        # Running statistics merged across batches.
        size_so_far = 0
        mean_perturb_so_far = None
        mean_pred_diff_so_far = None
        var_perturb_so_far = None
        var_pred_diff_so_far = None
        covariance_so_far = None
        all_new_pred_labels = []
        for example_inds in inds_per_batch:
            n_this = len(example_inds)
            this_orig_preds = orig_preds_arr[example_inds]
            this_examples = examples[example_inds]
            fft_input = np.fft.rfft(this_examples, axis=2).astype(np.complex64)
            amps = np.abs(fft_input).astype(np.float32)
            phases = np.angle(fft_input).astype(np.float32)
            perturbation = perturb_fn(amps, rng).astype(np.float32)
            # do not allow perturbation to make amplitudes go below
            # zero
            perturbation = np.maximum(-amps, perturbation)
            new_amps = (amps + perturbation).astype(np.float32)
            new_complex = _amplitude_phase_to_complex(new_amps, phases).astype(
                np.complex64)
            # Pass n explicitly so inputs with an odd number of samples are
            # reconstructed at their original length.
            new_in = np.fft.irfft(new_complex, n=this_examples.shape[2],
                                  axis=2).astype(np.float32)
            new_preds_arr = pred_fn(new_in)
            if original_y is not None:
                all_new_pred_labels.append(np.argmax(new_preds_arr, axis=1))

            diff_preds = new_preds_arr - this_orig_preds
            this_amp_pred_cov = wrap_reshape_apply_fn(
                cov, perturbation[:, :, :, 0], diff_preds,
                axis_a=(0, ), axis_b=(0))
            var_perturb = np.var(perturbation, axis=0, ddof=1)
            var_pred_diff = np.var(diff_preds, axis=0, ddof=1)
            mean_perturb = np.mean(perturbation, axis=0)
            # Per-output mean over examples, matching var_pred_diff above;
            # the parallel combination below needs the means per output.
            mean_diff_pred = np.mean(diff_preds, axis=0)
            if mean_perturb_so_far is None:
                mean_perturb_so_far = mean_perturb
                mean_pred_diff_so_far = mean_diff_pred
                covariance_so_far = this_amp_pred_cov
                var_perturb_so_far = var_perturb
                var_pred_diff_so_far = var_pred_diff
            else:
                covariance_so_far = combine_covs(
                    covariance_so_far, size_so_far, mean_perturb_so_far,
                    mean_pred_diff_so_far, this_amp_pred_cov, n_this,
                    mean_perturb, mean_diff_pred)
                var_perturb_so_far = combine_vars(
                    var_perturb_so_far, size_so_far, mean_perturb_so_far,
                    var_perturb, n_this, mean_perturb,
                )
                var_pred_diff_so_far = combine_vars(
                    var_pred_diff_so_far, size_so_far, mean_pred_diff_so_far,
                    var_pred_diff, n_this, mean_diff_pred,
                )
                next_size = size_so_far + n_this
                mean_perturb_so_far = (
                    (mean_perturb_so_far * size_so_far / float(next_size)) +
                    (mean_perturb * n_this / float(next_size)))
                mean_pred_diff_so_far = (
                    (mean_pred_diff_so_far * size_so_far / float(next_size)) +
                    (mean_diff_pred * n_this / float(next_size)))
            size_so_far += n_this

        # Accuracy bookkeeping only makes sense (and only works) when labels
        # were supplied; otherwise there are no labels to concatenate.
        if original_y is not None:
            all_new_pred_labels = np.concatenate(all_new_pred_labels)
            assert len(original_y) == len(all_new_pred_labels)
            new_accuracy = np.mean(all_new_pred_labels == original_y)
            log.info("New accuracy: {:.2f}...".format(new_accuracy))
            new_accuracies.append(new_accuracy)

        # corr = cov / (std_perturb * std_pred_diff), laid out as
        # perturbation dims followed by prediction dims.
        divisor = np.outer(np.sqrt(var_perturb_so_far),
                           np.sqrt(var_pred_diff_so_far)).reshape(
            (var_perturb_so_far.shape + var_pred_diff_so_far.shape)).squeeze()
        this_amp_pred_corr = covariance_so_far / divisor
        amp_pred_corrs.append(this_amp_pred_corr)
    if original_y is not None:
        return amp_pred_corrs, orig_accuracy, new_accuracies
    else:
        return amp_pred_corrs