def __call__(self,
             X,
             Y=None,
             batch_size=512,
             detect_convergence=True,
             thresh=0.025,
             n_samples=None,
             optimize_ordering=True,
             ordering_batches=1,
             verbose=False,
             bar=True):
    '''
    Estimate SAGE values.

    Args:
      X: input data.
      Y: target data. If None, model output will be used.
      batch_size: number of examples to be processed in parallel, should be set
        to a large value.
      detect_convergence: whether to stop when approximately converged.
      thresh: threshold for determining convergence.
      n_samples: number of samples to take per feature.
      optimize_ordering: whether to guess an ordering of features based on
        importance. May accelerate convergence.
      ordering_batches: number of minibatches while determining ordering.
      verbose: print progress messages.
      bar: display progress bar.

    The default behavior is to detect each feature's convergence based on the
    ratio of its standard deviation to the gap between the largest and smallest
    values. Since neither value is known initially, we begin with estimates
    (upper_val, lower_val) and update them as more features are analyzed.

    Returns: Explanation object.
    '''
    # Determine explanation type.
    if Y is not None:
        explanation_type = 'SAGE'
    else:
        explanation_type = 'Shapley Effects'

    # Verify model.
    N, _ = X.shape
    num_features = self.imputer.num_groups
    X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                   batch_size)

    # For setting up bar.
    estimate_convergence = n_samples is None
    if estimate_convergence and verbose:
        print('Estimating convergence time')

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Print message explaining parameter choices.
    if verbose:
        print('Batch size = batch * samples = {}'.format(
            batch_size * self.imputer.samples))

    # For detecting convergence.
    total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
    upper_val = max(total / num_features, 0)
    lower_val = 0

    # Feature ordering.
    if optimize_ordering:
        if verbose:
            print('Determining feature ordering...')
        holdout_importance = estimate_holdout_importance(
            self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
        if verbose:
            print('Done')
        # Use np.abs in case there are large negative contributors.
        ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
    else:
        ordering = list(range(num_features))

    # Set up bar.
    n_loops = int(n_samples / batch_size)
    if bar:
        if estimate_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size * num_features)

    # Iterated sampling.
    tracker_list = []
    for i, ind in enumerate(ordering):
        tracker = utils.ImportanceTracker()
        for it in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample subset of features.
            S = utils.sample_subset_feature(num_features, batch_size, ind)

            # Loss with feature excluded.
            y_hat = self.imputer(x, S)
            loss_discluded = self.loss_fn(y_hat, y)

            # Loss with feature included.
            S[:, ind] = 1
            y_hat = self.imputer(x, S)
            loss_included = self.loss_fn(y_hat, y)

            # Calculate delta sample.
            tracker.update(loss_discluded - loss_included)
            if bar and (not estimate_convergence):
                bar.update(batch_size)

            # Calculate progress.
            std = tracker.std
            gap = (max(upper_val, tracker.values.item())
                   - min(lower_val, tracker.values.item()))
            ratio = std / gap

            # Print progress message.
            if verbose:
                if detect_convergence:
                    print('StdDev Ratio = {:.4f} '
                          '(Converge at {:.4f})'.format(ratio, thresh))
                else:
                    print('StdDev Ratio = {:.4f}'.format(ratio))

            # Check for convergence.
            if detect_convergence:
                if ratio < thresh:
                    if verbose:
                        print('Detected feature convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = np.around(
                            bar.total * (i + 1) / num_features, 4)
                        bar.refresh()
                    break

            # Update convergence estimation.
            if bar and estimate_convergence:
                std_est = ratio * np.sqrt(it + 1)
                n_est = (std_est / thresh) ** 2
                bar.n = np.around((i + (it + 1) / n_est) / num_features, 4)
                bar.refresh()

        if verbose:
            print('Done with feature {}'.format(i))
        tracker_list.append(tracker)

        # Adjust min max value.
        upper_val = max(upper_val, tracker.values.item())
        lower_val = min(lower_val, tracker.values.item())

    if bar:
        bar.close()

    # Extract SAGE values.
    reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
    values = np.array(
        [tracker_list[ind].values.item() for ind in reverse_ordering])
    std = np.array(
        [tracker_list[ind].std.item() for ind in reverse_ordering])

    return core.Explanation(values, std, explanation_type)
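# Usage sketch (illustrative, not part of this module): this method is assumed
# to belong to an estimator class built from an imputer and a loss function,
# roughly following the public SAGE API. The names `sage.MarginalImputer` and
# `sage.IteratedEstimator` below are assumptions and may differ from the actual
# package layout.
#
#   imputer = sage.MarginalImputer(model, X_train[:512])
#   estimator = sage.IteratedEstimator(imputer, 'cross entropy')
#   explanation = estimator(X_test, Y_test, batch_size=512, thresh=0.025)
#
# Each feature converges independently: its standard deviation is compared to
# the running (lower_val, upper_val) gap, which starts from a rough estimate
# based on estimate_total and widens as more features are analyzed.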
def __call__(self,
             X,
             Y=None,
             batch_size=512,
             detect_convergence=True,
             thresh=0.01,
             n_samples=None,
             verbose=False,
             bar=True,
             check_every=5):
    '''
    Estimate SAGE values by fitting regression model (like KernelSHAP).

    Args:
      X: input data.
      Y: target data. If None, model output will be used.
      batch_size: number of examples to be processed in parallel, should be set
        to a large value.
      detect_convergence: whether to stop when approximately converged.
      thresh: threshold for determining convergence.
      n_samples: number of permutations to unroll.
      verbose: print progress messages.
      bar: display progress bar.
      check_every: number of batches between progress/convergence checks.

    The default behavior is to detect convergence based on the width of the
    SAGE values' confidence intervals. Convergence is defined by the ratio of
    the maximum standard deviation to the gap between the largest and smallest
    values.

    Returns: Explanation object.
    '''
    # Determine explanation type.
    if Y is not None:
        explanation_type = 'SAGE'
    else:
        explanation_type = 'Shapley Effects'

    # Verify model.
    N, _ = X.shape
    num_features = self.imputer.num_groups
    X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                   batch_size)

    # For setting up bar.
    estimate_convergence = n_samples is None
    if estimate_convergence and verbose:
        print('Estimating convergence time')

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Print message explaining parameter choices.
    if verbose:
        print('Batch size = batch * samples = {}'.format(
            batch_size * self.imputer.samples))

    # Weighting kernel (probability of each subset size).
    weights = np.arange(1, num_features)
    weights = 1 / (weights * (num_features - weights))
    weights = weights / np.sum(weights)

    # Estimate v({}) and v(D) for constraints.
    v0, v1 = estimate_constraints(self.imputer, X, Y, batch_size, self.loss_fn)

    # Exact form for A.
    p_coaccur = (
        np.sum((np.arange(2, num_features) - 1)
               / (num_features - np.arange(2, num_features)))
        / (num_features * (num_features - 1)
           * np.sum(1 / (np.arange(1, num_features)
                         * (num_features - np.arange(1, num_features))))))
    A = np.eye(num_features) * 0.5 + (1 - np.eye(num_features)) * p_coaccur

    # Set up bar.
    n_loops = int(n_samples / batch_size)
    if bar:
        if estimate_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup.
    n = 0
    b = 0
    b_sum_squares = 0

    # Sample subsets.
    for it in range(n_loops):
        # Sample data.
        mb = np.random.choice(N, batch_size)
        x = X[mb]
        y = Y[mb]

        # Sample subsets.
        S = np.zeros((batch_size, num_features), dtype=bool)
        num_included = np.random.choice(
            num_features - 1, size=batch_size, p=weights) + 1
        for row, num in zip(S, num_included):
            inds = np.random.choice(num_features, size=num, replace=False)
            row[inds] = 1

        # Make predictions.
        y_hat = self.imputer(x, S)
        loss = -self.loss_fn(y_hat, y) - v0
        b_temp1 = S.astype(float) * loss[:, np.newaxis]

        # Invert subset for variance reduction.
        S = np.logical_not(S)

        # Make predictions.
        y_hat = self.imputer(x, S)
        loss = -self.loss_fn(y_hat, y) - v0
        b_temp2 = S.astype(float) * loss[:, np.newaxis]

        # Covariance estimate (Welford's algorithm).
        n += batch_size
        b_temp = 0.5 * (b_temp1 + b_temp2)
        b_diff = b_temp - b
        b += np.sum(b_diff, axis=0) / n
        b_diff2 = b_temp - b
        b_sum_squares += np.sum(
            np.matmul(np.expand_dims(b_diff, 2),
                      np.expand_dims(b_diff2, 1)), axis=0)
        if bar and (not estimate_convergence):
            bar.update(batch_size)

        if (it + 1) % check_every == 0:
            # Calculate progress.
            values, std = calculate_result(A, b, v0, v1, b_sum_squares, n)
            gap = values.max() - values.min()
            ratio = np.max(std) / gap

            # Print progress message.
            if verbose:
                if detect_convergence:
                    print('StdDev Ratio = {:.4f} '
                          '(Converge at {:.4f})'.format(ratio, thresh))
                else:
                    print('StdDev Ratio = {:.4f}'.format(ratio))

            # Check for convergence.
            if detect_convergence:
                if ratio < thresh:
                    if verbose:
                        print('Detected convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = bar.total
                        bar.refresh()
                    break

            # Update convergence estimation.
            if bar and estimate_convergence:
                std_est = ratio * np.sqrt(it + 1)
                n_est = (std_est / thresh) ** 2
                bar.n = np.around((it + 1) / n_est, 4)
                bar.refresh()

    # Calculate SAGE values.
    values, std = calculate_result(A, b, v0, v1, b_sum_squares, n)

    return core.Explanation(np.squeeze(values), std, explanation_type)
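# Usage sketch (illustrative; class and constructor names are assumptions):
#
#   imputer = sage.MarginalImputer(model, X_train[:512])
#   estimator = sage.KernelEstimator(imputer, 'mse')
#   explanation = estimator(X_test, Y_test, batch_size=512, thresh=0.01)
#
# The estimator mirrors KernelSHAP: subset sizes are drawn with probability
# proportional to 1 / (|S| * (d - |S|)), the running vector b tracks the mean
# of 1{i in S} * v(S) via a Welford-style update, and calculate_result (not
# shown here) is assumed to solve the linear system A x = b subject to the
# efficiency constraint sum(x) = v1 - v0.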
def __call__(self,
             xy,
             batch_size,
             n_permutations=None,
             detect_convergence=None,
             convergence_threshold=0.01,
             verbose=False,
             bar=False):
    '''
    Estimate SAGE values.

    Args:
      xy: tuple of np.ndarrays for input and output.
      batch_size: number of examples to be processed at once. You should use
        as large a batch size as possible without exceeding available memory.
      n_permutations: number of permutations. If not specified, samples are
        taken until the estimates converge.
      detect_convergence: whether to detect convergence of SAGE estimates.
      convergence_threshold: confidence interval threshold for determining
        convergence. Represents portion of estimated sum of SAGE values.
      verbose: whether to print progress messages.
      bar: whether to display progress bar.

    Returns: SAGEValues object.
    '''
    X, Y = xy
    N, input_size = X.shape

    # Verify model.
    X, Y = utils.verify_model_data(self.model, X, Y, self.loss_fn,
                                   batch_size * self.imputer.samples)

    # For detecting convergence.
    total = estimate_total(self.model, xy,
                           batch_size * self.imputer.samples, self.loss_fn)

    if n_permutations is None:
        # Turn convergence detection on.
        if detect_convergence is None:
            detect_convergence = True
        elif not detect_convergence:
            detect_convergence = True
            print('Turning convergence detection on')

        # Turn bar off.
        if bar:
            bar = False
            print('Turning bar off')

        # Set n_permutations to an extremely large number.
        n_permutations = 1e20

    if detect_convergence:
        assert 0 < convergence_threshold < 1

    # Print message explaining parameter choices.
    if verbose:
        print('{} permutations, batch size (batch x samples) = {}'.format(
            n_permutations, batch_size * self.imputer.samples))

    # For updating scores.
    tracker = utils.ImportanceTracker()

    # Permutation sampling.
    n_loops = int(n_permutations / batch_size)
    if bar:
        bar = tqdm(total=n_loops * batch_size * input_size)
    for _ in range(n_loops):
        # Sample data.
        mb = np.random.choice(N, batch_size)
        x = X[mb]
        y = Y[mb]

        # Sample permutations.
        S = np.zeros((batch_size, input_size))
        permutations = np.tile(np.arange(input_size), (batch_size, 1))
        for i in range(batch_size):
            np.random.shuffle(permutations[i])

        # Make prediction with missing features.
        y_hat = self.model(self.imputer(x, S))
        y_hat = np.mean(
            y_hat.reshape(-1, self.imputer.samples, *y_hat.shape[1:]),
            axis=1)
        prev_loss = self.loss_fn(y_hat, y)

        # Setup.
        arange = np.arange(batch_size)
        scores = np.zeros((batch_size, input_size))

        for i in range(input_size):
            # Add next feature.
            inds = permutations[:, i]
            S[arange, inds] = 1.0

            # Make prediction with missing features.
            y_hat = self.model(self.imputer(x, S))
            y_hat = np.mean(
                y_hat.reshape(-1, self.imputer.samples, *y_hat.shape[1:]),
                axis=1)
            loss = self.loss_fn(y_hat, y)

            # Calculate delta sample.
            scores[arange, inds] = prev_loss - loss
            prev_loss = loss
            if bar:
                bar.update(batch_size)

        # Update tracker.
        tracker.update(scores)

        # Check for convergence.
        conf = np.max(tracker.std)
        if verbose:
            print('Conf = {:.4f}, Total = {:.4f}'.format(conf, total))
        if detect_convergence:
            if (conf / total) < convergence_threshold:
                if verbose:
                    print('Stopping early')
                break

    return utils.SAGEValues(tracker.values, tracker.std)
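# Usage sketch for this older tuple-based API (names assumed): the sampler is
# constructed from a model, an imputer that fills in masked features, and a
# loss function, then called with (X, Y) packed into a tuple.
#
#   sampler = PermutationSampler(model, imputer, 'cross entropy')
#   sage_values = sampler((X_test, Y_test), batch_size=256)
#   print(sage_values.values, sage_values.std)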
def __call__(self,
             xy,
             batch_size,
             n_samples=None,
             detect_convergence=False,
             convergence_threshold=0.01,
             verbose=False,
             bar=False):
    '''
    Estimate SAGE values.

    Args:
      xy: tuple of np.ndarrays for input and output.
      batch_size: number of examples to be processed at once. You should use
        as large a batch size as possible without exceeding available memory.
      n_samples: number of samples for each feature. If not specified, samples
        are taken until the estimates converge.
      detect_convergence: whether to detect convergence of SAGE estimates.
      convergence_threshold: confidence interval threshold for determining
        convergence. Represents portion of estimated sum of SAGE values.
      verbose: whether to print progress messages.
      bar: whether to display progress bar.

    Returns: SAGEValues object.
    '''
    X, Y = xy
    N, input_size = X.shape

    # Verify model.
    X, Y = utils.verify_model_data(self.model, X, Y, self.loss_fn,
                                   batch_size * self.imputer.samples)

    # For detecting convergence.
    total = estimate_total(self.model, xy,
                           batch_size * self.imputer.samples, self.loss_fn)

    if n_samples is None:
        # Turn convergence detection on.
        if detect_convergence is None:
            detect_convergence = True
        elif not detect_convergence:
            detect_convergence = True
            print('Turning convergence detection on')

        # Turn bar off.
        if bar:
            bar = False
            print('Turning bar off')

        # Set n_samples to an extremely large number.
        n_samples = 1e20

    if detect_convergence:
        assert 0 < convergence_threshold < 1

    if verbose:
        print('{} samples/feat, batch size (batch x samples) = {}'.format(
            n_samples, batch_size * self.imputer.samples))

    # For updating scores.
    tracker_list = []

    # Iterated sampling.
    n_loops = int(n_samples / batch_size)
    if bar:
        bar = tqdm(total=n_loops * batch_size * input_size)
    for ind in range(input_size):
        tracker = utils.ImportanceTracker()
        for _ in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample subset of features.
            S = utils.sample_subset_feature(input_size, batch_size, ind)

            # Loss with feature excluded.
            y_hat = self.model(self.imputer(x, S))
            y_hat = np.mean(
                y_hat.reshape(-1, self.imputer.samples, *y_hat.shape[1:]),
                axis=1)
            loss_discluded = self.loss_fn(y_hat, y)

            # Loss with feature included.
            S[:, ind] = 1.0
            y_hat = self.model(self.imputer(x, S))
            y_hat = np.mean(
                y_hat.reshape(-1, self.imputer.samples, *y_hat.shape[1:]),
                axis=1)
            loss_included = self.loss_fn(y_hat, y)

            # Calculate delta sample.
            tracker.update(loss_discluded - loss_included)
            if bar:
                bar.update(batch_size)

            # Check for convergence.
            conf = tracker.std
            if verbose:
                print('Imp = {:.4f}, Conf = {:.4f}, Total = {:.4f}'.format(
                    tracker.values, conf, total))
            if detect_convergence:
                if (conf / total) < convergence_threshold:
                    if verbose:
                        print('Stopping feature early')
                    break

        if verbose:
            print('Done with feature {}'.format(ind))
        tracker_list.append(tracker)

    return utils.SAGEValues(
        np.array([tracker.values.item() for tracker in tracker_list]),
        np.array([tracker.std.item() for tracker in tracker_list]))
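# Usage sketch for the per-feature variant of the older API (names assumed).
# Unlike the permutation sampler above, each feature is resolved one at a time
# against random subsets of the remaining features, so features can stop early
# individually but samples are not shared across features.
#
#   sampler = IteratedSampler(model, imputer, 'mse')
#   sage_values = sampler((X_test, Y_test), batch_size=256,
#                         detect_convergence=True, convergence_threshold=0.01)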
def __call__(self,
             X,
             Y=None,
             batch_size=512,
             detect_convergence=True,
             thresh=0.025,
             n_permutations=None,
             min_coalition=0.0,
             max_coalition=1.0,
             verbose=False,
             bar=True):
    '''
    Estimate SAGE values.

    Args:
      X: input data.
      Y: target data. If None, model output will be used.
      batch_size: number of examples to be processed in parallel, should be set
        to a large value.
      detect_convergence: whether to stop when approximately converged.
      thresh: threshold for determining convergence.
      n_permutations: number of permutations to unroll.
      min_coalition: minimum coalition size (int or float).
      max_coalition: maximum coalition size (int or float).
      verbose: print progress messages.
      bar: display progress bar.

    The default behavior is to detect convergence based on the width of the
    SAGE values' confidence intervals. Convergence is defined by the ratio of
    the maximum standard deviation to the gap between the largest and smallest
    values.

    Returns: Explanation object.
    '''
    # Determine explanation type.
    if Y is not None:
        explanation_type = 'SAGE'
    else:
        explanation_type = 'Shapley Effects'

    # Verify model.
    N, _ = X.shape
    num_features = self.imputer.num_groups
    X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                   batch_size)

    # Determine min/max coalition sizes.
    if isinstance(min_coalition, float):
        min_coalition = int(min_coalition * num_features)
    if isinstance(max_coalition, float):
        max_coalition = int(max_coalition * num_features)
    assert min_coalition >= 0
    assert max_coalition <= num_features
    assert min_coalition < max_coalition
    if min_coalition > 0 or max_coalition < num_features:
        relaxed = True
        explanation_type = 'Relaxed ' + explanation_type
    else:
        relaxed = False
        sample_counts = None

    # Possibly force convergence detection.
    if n_permutations is None:
        n_permutations = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Set up bar.
    n_loops = int(n_permutations / batch_size)
    if bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size * num_features)

    # Setup.
    arange = np.arange(batch_size)
    scores = np.zeros((batch_size, num_features))
    S = np.zeros((batch_size, num_features), dtype=bool)
    permutations = np.tile(np.arange(num_features), (batch_size, 1))
    tracker = utils.ImportanceTracker()

    # Permutation sampling.
    for it in range(n_loops):
        # Sample data.
        mb = np.random.choice(N, batch_size)
        x = X[mb]
        y = Y[mb]

        # Sample permutations.
        S[:] = 0
        for i in range(batch_size):
            np.random.shuffle(permutations[i])

        # Calculate sample counts.
        if relaxed:
            scores[:] = 0
            sample_counts = np.zeros(num_features, dtype=int)
            for i in range(batch_size):
                sample_counts[permutations[i, min_coalition:max_coalition]] = (
                    sample_counts[permutations[i, min_coalition:max_coalition]]
                    + 1)

        # Add necessary features to minimum coalition.
        for i in range(min_coalition):
            # Add next feature.
            inds = permutations[:, i]
            S[arange, inds] = 1

        # Make prediction with minimum coalition.
        y_hat = self.imputer(x, S)
        prev_loss = self.loss_fn(y_hat, y)

        # Add all remaining features.
        for i in range(min_coalition, max_coalition):
            # Add next feature.
            inds = permutations[:, i]
            S[arange, inds] = 1

            # Make prediction with missing features.
            y_hat = self.imputer(x, S)
            loss = self.loss_fn(y_hat, y)

            # Calculate delta sample.
            scores[arange, inds] = prev_loss - loss
            prev_loss = loss

            # Update bar (if not detecting convergence).
            if bar and (not detect_convergence):
                bar.update(batch_size)

        # Update tracker.
        tracker.update(scores, sample_counts)

        # Calculate progress.
        std = np.max(tracker.std)
        gap = max(tracker.values.max() - tracker.values.min(), 1e-12)
        ratio = std / gap

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} '
                      f'(Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Check for convergence.
        if detect_convergence:
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if bar:
                    bar.n = bar.total
                    bar.refresh()
                break

        # Update convergence estimation.
        if bar and detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2
            bar.n = np.around((it + 1) / N_est, 4)
            bar.refresh()

    if bar:
        bar.close()

    return core.Explanation(tracker.values, tracker.std, explanation_type)
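# Usage sketch (illustrative; names are assumptions). min_coalition and
# max_coalition accept either counts or fractions of the feature count, so
# restricting the unrolled permutation to mid-sized coalitions produces the
# 'Relaxed' explanation type:
#
#   estimator = sage.PermutationEstimator(imputer, 'cross entropy')
#   explanation = estimator(X_test, Y_test, batch_size=512,
#                           min_coalition=0.2, max_coalition=0.8)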
def __call__(self,
             X,
             Y=None,
             batch_size=512,
             sign_confidence=0.99,
             narrow_thresh=0.025,
             optimize_ordering=True,
             ordering_batches=1,
             verbose=False,
             bar=True):
    '''
    Estimate SAGE values.

    Args:
      X: input data.
      Y: target data. If None, model output will be used.
      batch_size: number of examples to be processed in parallel, should be set
        to a large value.
      sign_confidence: confidence level on sign.
      narrow_thresh: threshold for detecting that the standard deviation is
        small enough.
      optimize_ordering: whether to guess an ordering of features based on
        importance. May accelerate convergence.
      ordering_batches: number of minibatches while determining ordering.
      verbose: print progress messages.
      bar: display progress bar.

    Convergence for each SAGE value is detected when one of two conditions
    holds: (1) the sign is known with high confidence (given by
    sign_confidence), or (2) the standard deviation of the Gaussian confidence
    interval is sufficiently narrow (given by narrow_thresh).

    Returns: Explanation object.
    '''
    # Determine explanation type.
    if Y is not None:
        explanation_type = 'SAGE'
    else:
        explanation_type = 'Shapley Effects'

    # Verify model.
    N, _ = X.shape
    num_features = self.imputer.num_groups
    X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                   batch_size)

    # Verify thresholds.
    assert 0 < narrow_thresh < 1
    assert 0.9 <= sign_confidence < 1
    sign_thresh = 1 / norm.ppf(sign_confidence)

    # For detecting convergence.
    total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
    upper_val = max(total / num_features, 0)
    lower_val = min(total / num_features, 0)

    # Feature ordering.
    if optimize_ordering:
        if verbose:
            print('Determining feature ordering...')
        holdout_importance = estimate_holdout_importance(
            self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
        if verbose:
            print('Done')
        # Use np.abs in case there are large negative contributors.
        ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
    else:
        ordering = list(range(num_features))

    # Set up bar.
    if bar:
        bar = tqdm(total=1)

    # Iterated sampling.
    tracker_list = []
    for i, ind in enumerate(ordering):
        tracker = utils.ImportanceTracker()
        it = 0
        converged = False
        while not converged:
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample subset of features.
            S = utils.sample_subset_feature(num_features, batch_size, ind)

            # Loss with feature excluded.
            y_hat = self.imputer(x, S)
            loss_discluded = self.loss_fn(y_hat, y)

            # Loss with feature included.
            S[:, ind] = 1
            y_hat = self.imputer(x, S)
            loss_included = self.loss_fn(y_hat, y)

            # Calculate delta sample.
            tracker.update(loss_discluded - loss_included)

            # Calculate progress.
            val = tracker.values.item()
            std = tracker.std.item()
            gap = max(max(upper_val, val) - min(lower_val, val), 1e-12)
            converged_sign = (std / max(np.abs(val), 1e-12)) < sign_thresh
            converged_narrow = (std / gap) < narrow_thresh

            # Print progress message.
            if verbose:
                print('Sign Ratio = {:.4f} (Converge at {:.4f}), '
                      'Narrow Ratio = {:.4f} (Converge at {:.4f})'.format(
                          std / np.abs(val), sign_thresh,
                          std / gap, narrow_thresh))

            # Check for convergence.
            converged = converged_sign or converged_narrow
            if converged:
                if verbose:
                    print('Detected feature convergence')

                # Skip bar ahead.
                if bar:
                    bar.n = np.around(bar.total * (i + 1) / num_features, 4)
                    bar.refresh()

            # Update convergence estimation.
            elif bar:
                N_sign = (it + 1) * ((std / np.abs(val)) / sign_thresh) ** 2
                N_narrow = (it + 1) * ((std / gap) / narrow_thresh) ** 2
                N_est = min(N_sign, N_narrow)
                bar.n = np.around((i + (it + 1) / N_est) / num_features, 4)
                bar.refresh()

            # Increment iteration variable.
            it += 1

        if verbose:
            print('Done with feature {}'.format(i))
        tracker_list.append(tracker)

        # Adjust min max value.
        upper_val = max(upper_val, tracker.values.item())
        lower_val = min(lower_val, tracker.values.item())

    if bar:
        bar.close()

    # Extract SAGE values.
    reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
    values = np.array(
        [tracker_list[ind].values.item() for ind in reverse_ordering])
    std = np.array(
        [tracker_list[ind].std.item() for ind in reverse_ordering])

    return core.Explanation(values, std, explanation_type)
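# Usage sketch (illustrative; names are assumptions). The sign-based stopping
# rule relies on scipy.stats.norm for the z-value: with sign_confidence=0.99,
# norm.ppf(0.99) is roughly 2.33, so a feature stops once its estimate exceeds
# about 2.33 standard errors in magnitude, or once std / gap drops below
# narrow_thresh.
#
#   estimator = sage.SignEstimator(imputer, 'cross entropy')
#   explanation = estimator(X_test, Y_test, sign_confidence=0.99,
#                           narrow_thresh=0.025)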