Example #1
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 detect_convergence=True,
                 thresh=0.025,
                 n_samples=None,
                 optimize_ordering=True,
                 ordering_batches=1,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; this
            should be set to a large value.
          detect_convergence: whether to stop when approximately converged.
          thresh: threshold for determining convergence.
          n_samples: number of samples to take per feature.
          optimize_ordering: whether to guess an ordering of features based on
            importance. May accelerate convergence.
          ordering_batches: number of minibatches while determining ordering.
          verbose: print progress messages.
          bar: display progress bar.

        The default behavior is to detect each feature's convergence based on
        the ratio of its standard deviation to the gap between the largest and
        smallest SAGE values. Since neither extreme is known initially, we
        begin with estimates (upper_val, lower_val) and update them as more
        features are analyzed.

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # For setting up bar.
        estimate_convergence = n_samples is None
        if estimate_convergence and verbose:
            print('Estimating convergence time')

        # Possibly force convergence detection.
        if n_samples is None:
            n_samples = 1e20
            if not detect_convergence:
                detect_convergence = True
                if verbose:
                    print('Turning convergence detection on')

        if detect_convergence:
            assert 0 < thresh < 1

        # Print message explaining parameter choices.
        if verbose:
            print('Batch size = batch * samples = {}'.format(
                batch_size * self.imputer.samples))

        # For detecting convergence.
        total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
        upper_val = max(total / num_features, 0)
        lower_val = 0

        # Feature ordering.
        if optimize_ordering:
            if verbose:
                print('Determining feature ordering...')
            holdout_importance = estimate_holdout_importance(
                self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
            if verbose:
                print('Done')
            # Use np.abs in case there are large negative contributors.
            ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
        else:
            ordering = list(range(num_features))

        # Set up bar.
        n_loops = int(n_samples / batch_size)
        if bar:
            if estimate_convergence:
                bar = tqdm(total=1)
            else:
                bar = tqdm(total=n_loops * batch_size * num_features)

        # Iterated sampling.
        tracker_list = []
        for i, ind in enumerate(ordering):
            tracker = utils.ImportanceTracker()
            for it in range(n_loops):
                # Sample data.
                mb = np.random.choice(N, batch_size)
                x = X[mb]
                y = Y[mb]

                # Sample subset of features.
                S = utils.sample_subset_feature(num_features, batch_size, ind)

                # Loss with feature excluded.
                y_hat = self.imputer(x, S)
                loss_discluded = self.loss_fn(y_hat, y)

                # Loss with feature included.
                S[:, ind] = 1
                y_hat = self.imputer(x, S)
                loss_included = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                tracker.update(loss_discluded - loss_included)
                if bar and (not estimate_convergence):
                    bar.update(batch_size)

                # Calculate progress.
                std = tracker.std
                gap = max(max(upper_val, tracker.values.item()) -
                          min(lower_val, tracker.values.item()), 1e-12)
                ratio = std / gap

                # Print progress message.
                if verbose:
                    if detect_convergence:
                        print('StdDev Ratio = {:.4f} '
                              '(Converge at {:.4f})'.format(ratio, thresh))
                    else:
                        print('StdDev Ratio = {:.4f}'.format(ratio))

                # Check for convergence.
                if detect_convergence:
                    if ratio < thresh:
                        if verbose:
                            print('Detected feature convergence')

                        # Skip bar ahead.
                        if bar:
                            bar.n = np.around(
                                bar.total * (i + 1) / num_features, 4)
                            bar.refresh()
                        break

                # Update convergence estimation.
                if bar and estimate_convergence:
                    std_est = ratio * np.sqrt(it + 1)
                    n_est = (std_est / thresh)**2
                    bar.n = np.around((i + (it + 1) / n_est) / num_features, 4)
                    bar.refresh()

            if verbose:
                print('Done with feature {}'.format(i))
            tracker_list.append(tracker)

            # Adjust min max value.
            upper_val = max(upper_val, tracker.values.item())
            lower_val = min(lower_val, tracker.values.item())

        if bar:
            bar.close()

        # Extract SAGE values.
        reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
        values = np.array(
            [tracker_list[ind].values.item() for ind in reverse_ordering])
        std = np.array(
            [tracker_list[ind].std.item() for ind in reverse_ordering])

        return core.Explanation(values, std, explanation_type)
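
A short usage sketch for the estimator above. The class and constructor names (IteratedEstimator, MarginalImputer) and the loss string are assumptions in the style of the sage-importance package, and the model and data are placeholders; treat this as a sketch, not the package's documented API.

import numpy as np
import sage  # assumed: the sage-importance package

# Placeholder data and model.
X = np.random.randn(512, 8)
Y = X.sum(axis=1) + 0.1 * np.random.randn(512)
model = lambda x: x.sum(axis=1)  # stand-in regression model

# Assumed API: marginalize held-out features over a background sample.
imputer = sage.MarginalImputer(model, X[:128])
estimator = sage.IteratedEstimator(imputer, 'mse')

# Per-feature convergence detection at a looser threshold.
explanation = estimator(X, Y, batch_size=128, thresh=0.05)
print(explanation.values)  # one SAGE value per feature
print(explanation.std)     # standard error of each estimate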
Example #2
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 detect_convergence=True,
                 thresh=0.025,
                 n_permutations=None,
                 min_coalition=0.0,
                 max_coalition=1.0,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; this
            should be set to a large value.
          detect_convergence: whether to stop when approximately converged.
          thresh: threshold for determining convergence.
          n_permutations: number of permutations to unroll.
          min_coalition: minimum coalition size (int or float).
          max_coalition: maximum coalition size (int or float).
          verbose: print progress messages.
          bar: display progress bar.

        The default behavior is to detect convergence based on the width of the
        SAGE values' confidence intervals. Convergence is defined by the ratio
        of the maximum standard deviation to the gap between the largest and
        smallest values.

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # Determine min/max coalition sizes.
        if isinstance(min_coalition, float):
            min_coalition = int(min_coalition * num_features)
        if isinstance(max_coalition, float):
            max_coalition = int(max_coalition * num_features)
        assert min_coalition >= 0
        assert max_coalition <= num_features
        assert min_coalition < max_coalition
        if min_coalition > 0 or max_coalition < num_features:
            relaxed = True
            explanation_type = 'Relaxed ' + explanation_type
        else:
            relaxed = False
            sample_counts = None

        # Possibly force convergence detection.
        if n_permutations is None:
            n_permutations = 1e20
            if not detect_convergence:
                detect_convergence = True
                if verbose:
                    print('Turning convergence detection on')

        if detect_convergence:
            assert 0 < thresh < 1

        # Set up bar.
        n_loops = int(n_permutations / batch_size)
        if bar:
            if detect_convergence:
                bar = tqdm(total=1)
            else:
                bar = tqdm(total=n_loops * batch_size * num_features)

        # Setup.
        arange = np.arange(batch_size)
        scores = np.zeros((batch_size, num_features))
        S = np.zeros((batch_size, num_features), dtype=bool)
        permutations = np.tile(np.arange(num_features), (batch_size, 1))
        tracker = utils.ImportanceTracker()

        # Permutation sampling.
        for it in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample permutations.
            S[:] = 0
            for i in range(batch_size):
                np.random.shuffle(permutations[i])

            # Calculate sample counts.
            if relaxed:
                scores[:] = 0
                sample_counts = np.zeros(num_features, dtype=int)
                for i in range(batch_size):
                    # Indices within a permutation slice are unique, so a
                    # fancy-indexed increment is safe here.
                    sample_counts[
                        permutations[i, min_coalition:max_coalition]] += 1

            # Add necessary features to minimum coalition.
            for i in range(min_coalition):
                # Add next feature.
                inds = permutations[:, i]
                S[arange, inds] = 1

            # Make prediction with minimum coalition.
            y_hat = self.imputer(x, S)
            prev_loss = self.loss_fn(y_hat, y)

            # Add all remaining features.
            for i in range(min_coalition, max_coalition):
                # Add next feature.
                inds = permutations[:, i]
                S[arange, inds] = 1

                # Make prediction with missing features.
                y_hat = self.imputer(x, S)
                loss = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                scores[arange, inds] = prev_loss - loss
                prev_loss = loss

                # Update bar (if not detecting convergence).
                if bar and (not detect_convergence):
                    bar.update(batch_size)

            # Update tracker.
            tracker.update(scores, sample_counts)

            # Calculate progress.
            std = np.max(tracker.std)
            gap = max(tracker.values.max() - tracker.values.min(), 1e-12)
            ratio = std / gap

            # Print progress message.
            if verbose:
                if detect_convergence:
                    print(f'StdDev Ratio = {ratio:.4f} '
                          f'(Converge at {thresh:.4f})')
                else:
                    print(f'StdDev Ratio = {ratio:.4f}')

            # Check for convergence.
            if detect_convergence:
                if ratio < thresh:
                    if verbose:
                        print('Detected convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = bar.total
                        bar.refresh()
                    break

            # Update convergence estimation.
            if bar and detect_convergence:
                N_est = (it + 1) * (ratio / thresh)**2
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()

        if bar:
            bar.close()

        return core.Explanation(tracker.values, tracker.std, explanation_type)
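
This estimator unrolls whole permutations, so a single tracker accumulates deltas for every feature per minibatch, and the min_coalition/max_coalition arguments restrict scoring to a band of coalition sizes. A hedged usage sketch, again assuming a sage-importance-style wrapper (class and imputer names are assumptions; model and data are placeholders):

import numpy as np
import sage  # assumed: the sage-importance package

X = np.random.randn(512, 8)
Y = X.sum(axis=1) + 0.1 * np.random.randn(512)
model = lambda x: x.sum(axis=1)  # stand-in regression model

imputer = sage.MarginalImputer(model, X[:128])
estimator = sage.PermutationEstimator(imputer, 'mse')

# Score only coalitions covering the middle 80% of sizes; per the code
# above, the returned explanation type is prefixed with 'Relaxed'.
explanation = estimator(X, Y, min_coalition=0.1, max_coalition=0.9)
print(explanation.values, explanation.std)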
Example #3
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 detect_convergence=True,
                 thresh=0.01,
                 n_samples=None,
                 verbose=False,
                 bar=True,
                 check_every=5):
        '''
        Estimate SAGE values by fitting regression model (like KernelSHAP).

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; this
            should be set to a large value.
          detect_convergence: whether to stop when approximately converged.
          thresh: threshold for determining convergence.
          n_samples: number of subsets to sample.
          verbose: print progress messages.
          bar: display progress bar.
          check_every: number of batches between progress/convergence checks.

        The default behavior is to detect convergence based on the width of the
        SAGE values' confidence intervals. Convergence is defined by the ratio
        of the maximum standard deviation to the gap between the largest and
        smallest values.

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # For setting up bar.
        estimate_convergence = n_samples is None
        if estimate_convergence and verbose:
            print('Estimating convergence time')

        # Possibly force convergence detection.
        if n_samples is None:
            n_samples = 1e20
            if not detect_convergence:
                detect_convergence = True
                if verbose:
                    print('Turning convergence detection on')

        if detect_convergence:
            assert 0 < thresh < 1

        # Print message explaining parameter choices.
        if verbose:
            print('Batch size = batch * samples = {}'.format(
                batch_size * self.imputer.samples))

        # Weighting kernel (probability of each subset size).
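        # For a subset of size k (1 <= k <= num_features - 1), the weight is
        # proportional to 1 / (k * (num_features - k)): the Shapley kernel
        # distribution over subset sizes, as in KernelSHAP.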
        weights = np.arange(1, num_features)
        weights = 1 / (weights * (num_features - weights))
        weights = weights / np.sum(weights)

        # Estimate v({}) and v(D) for constraints.
        v0, v1 = estimate_constraints(self.imputer, X, Y, batch_size,
                                      self.loss_fn)

        # Exact form for A.
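        # A[i, j] = P(features i and j both appear in a sampled subset): the
        # size-symmetric kernel includes each feature half the time (diagonal
        # 0.5), and distinct pairs co-occur with probability p_coaccur.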
        p_coaccur = ((np.sum((np.arange(2, num_features) - 1) /
                             (num_features - np.arange(2, num_features)))) /
                     (num_features * (num_features - 1) *
                      np.sum(1 /
                             (np.arange(1, num_features) *
                              (num_features - np.arange(1, num_features))))))
        A = np.eye(num_features) * 0.5 + (1 - np.eye(num_features)) * p_coaccur

        # Set up bar.
        n_loops = int(n_samples / batch_size)
        if bar:
            if estimate_convergence:
                bar = tqdm(total=1)
            else:
                bar = tqdm(total=n_loops * batch_size)

        # Setup.
        n = 0
        b = 0
        b_sum_squares = 0

        # Sample subsets.
        for it in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample subsets.
            S = np.zeros((batch_size, num_features), dtype=bool)
            num_included = np.random.choice(
                num_features - 1, size=batch_size, p=weights) + 1
            for row, num in zip(S, num_included):
                inds = np.random.choice(num_features, size=num, replace=False)
                row[inds] = 1

            # Make predictions.
            y_hat = self.imputer(x, S)
            loss = -self.loss_fn(y_hat, y) - v0
            b_temp1 = S.astype(float) * loss[:, np.newaxis]

            # Invert subset for variance reduction.
            S = np.logical_not(S)

            # Make predictions.
            y_hat = self.imputer(x, S)
            loss = -self.loss_fn(y_hat, y) - v0
            b_temp2 = S.astype(float) * loss[:, np.newaxis]

            # Covariance estimate (Welford's algorithm).
            n += batch_size
            b_temp = 0.5 * (b_temp1 + b_temp2)
            b_diff = b_temp - b
            b += np.sum(b_diff, axis=0) / n
            b_diff2 = b_temp - b
            b_sum_squares += np.sum(np.matmul(np.expand_dims(b_diff, 2),
                                              np.expand_dims(b_diff2, 1)),
                                    axis=0)
            if bar and (not estimate_convergence):
                bar.update(batch_size)

            if (it + 1) % check_every == 0:
                # Calculate progress.
                values, std = calculate_result(A, b, v0, v1, b_sum_squares, n)
                gap = max(values.max() - values.min(), 1e-12)
                ratio = np.max(std) / gap

                # Print progress message.
                if verbose:
                    if detect_convergence:
                        print('StdDev Ratio = {:.4f} '
                              '(Converge at {:.4f})'.format(ratio, thresh))
                    else:
                        print('StdDev Ratio = {:.4f}'.format(ratio))

                # Check for convergence.
                if detect_convergence:
                    if ratio < thresh:
                        if verbose:
                            print('Detected convergence')

                        # Skip bar ahead.
                        if bar:
                            bar.n = bar.total
                            bar.refresh()
                        break

                # Update convergence estimation.
                if bar and estimate_convergence:
                    std_est = ratio * np.sqrt(it + 1)
                    n_est = (std_est / thresh)**2
                    bar.n = np.around((it + 1) / n_est, 4)
                    bar.refresh()

        # Calculate SAGE values.
        values, std = calculate_result(A, b, v0, v1, b_sum_squares, n)

        return core.Explanation(np.squeeze(values), std, explanation_type)
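
The b vector accumulates per-subset regression targets, and calculate_result (not shown here) presumably solves the resulting least-squares system subject to the v0/v1 constraints. The self-contained sketch below double-checks the closed-form co-occurrence probability p_coaccur against a Monte Carlo estimate, using only the formulas visible in the code:

import numpy as np

num_features = 8

# Shapley kernel distribution over subset sizes, copied from the code above.
sizes = np.arange(1, num_features)
weights = 1 / (sizes * (num_features - sizes))
weights = weights / weights.sum()

# Closed-form P(two fixed features co-occur in a sampled subset).
p_coaccur = ((np.sum((np.arange(2, num_features) - 1) /
                     (num_features - np.arange(2, num_features)))) /
             (num_features * (num_features - 1) *
              np.sum(1 / (np.arange(1, num_features) *
                          (num_features - np.arange(1, num_features))))))

# Monte Carlo estimate of the same probability.
rng = np.random.default_rng(0)
hits, trials = 0, 100_000
for _ in range(trials):
    k = rng.choice(num_features - 1, p=weights) + 1
    subset = rng.choice(num_features, size=k, replace=False)
    hits += int(0 in subset and 1 in subset)

print(p_coaccur, hits / trials)  # the two estimates should closely agree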
Example #4
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 sign_confidence=0.99,
                 narrow_thresh=0.025,
                 optimize_ordering=True,
                 ordering_batches=1,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; this
            should be set to a large value.
          sign_confidence: confidence level on sign.
          narrow_thresh: threshold for detecting that the standard deviation
            is small enough.
          optimize_ordering: whether to guess an ordering of features based on
            importance. May accelerate convergence.
          ordering_batches: number of minibatches while determining ordering.
          verbose: print progress messages.
          bar: display progress bar.

        Convergence for each SAGE value is detected when one of two conditions
        holds: (1) the sign is known with high confidence (given by
        sign_confidence), or (2) the standard deviation of the Gaussian
        confidence interval is sufficiently narrow (given by narrow_thresh).

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # Verify thresholds.
        assert 0 < narrow_thresh < 1
        assert 0.9 <= sign_confidence < 1
        sign_thresh = 1 / norm.ppf(sign_confidence)
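        # std / |value| < sign_thresh is equivalent to the estimate lying more
        # than norm.ppf(sign_confidence) standard deviations from zero, fixing
        # the sign at the requested confidence level.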

        # For detecting convergence.
        total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
        upper_val = max(total / num_features, 0)
        lower_val = min(total / num_features, 0)

        # Feature ordering.
        if optimize_ordering:
            if verbose:
                print('Determining feature ordering...')
            holdout_importance = estimate_holdout_importance(
                self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
            if verbose:
                print('Done')
            # Use np.abs in case there are large negative contributors.
            ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
        else:
            ordering = list(range(num_features))

        # Set up bar.
        if bar:
            bar = tqdm(total=1)

        # Iterated sampling.
        tracker_list = []
        for i, ind in enumerate(ordering):
            tracker = utils.ImportanceTracker()
            it = 0
            converged = False
            while not converged:
                # Sample data.
                mb = np.random.choice(N, batch_size)
                x = X[mb]
                y = Y[mb]

                # Sample subset of features.
                S = utils.sample_subset_feature(num_features, batch_size, ind)

                # Loss with feature excluded.
                y_hat = self.imputer(x, S)
                loss_discluded = self.loss_fn(y_hat, y)

                # Loss with feature included.
                S[:, ind] = 1
                y_hat = self.imputer(x, S)
                loss_included = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                tracker.update(loss_discluded - loss_included)

                # Calculate progress.
                val = tracker.values.item()
                std = tracker.std.item()
                gap = max(max(upper_val, val) - min(lower_val, val), 1e-12)
                converged_sign = (std / max(np.abs(val), 1e-12)) < sign_thresh
                converged_narrow = (std / gap) < narrow_thresh

                # Print progress message.
                if verbose:
                    print('Sign Ratio = {:.4f} (Converge at {:.4f}), '
                          'Narrow Ratio = {:.4f} (Converge at {:.4f})'.format(
                              std / max(np.abs(val), 1e-12), sign_thresh,
                              std / gap, narrow_thresh))

                # Check for convergence.
                converged = converged_sign or converged_narrow
                if converged:
                    if verbose:
                        print('Detected feature convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = np.around(bar.total * (i + 1) / num_features,
                                          4)
                        bar.refresh()

                # Update convergence estimation.
                elif bar:
                    N_sign = (it + 1) * (
                        (std / max(np.abs(val), 1e-12)) / sign_thresh)**2
                    N_narrow = (it + 1) * ((std / gap) / narrow_thresh)**2
                    N_est = min(N_sign, N_narrow)
                    bar.n = np.around((i + (it + 1) / N_est) / num_features, 4)
                    bar.refresh()

                # Increment iteration variable.
                it += 1

            if verbose:
                print('Done with feature {}'.format(i))
            tracker_list.append(tracker)

            # Adjust min max value.
            upper_val = max(upper_val, tracker.values.item())
            lower_val = min(lower_val, tracker.values.item())

        if bar:
            bar.close()

        # Extract SAGE values.
        reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
        values = np.array(
            [tracker_list[ind].values.item() for ind in reverse_ordering])
        std = np.array(
            [tracker_list[ind].std.item() for ind in reverse_ordering])

        return core.Explanation(values, std, explanation_type)
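
A quick numeric illustration of how sign_confidence maps to the sign test's threshold: the estimate must sit at least norm.ppf(sign_confidence) standard deviations from zero before its sign is accepted.

from scipy.stats import norm

for confidence in (0.9, 0.95, 0.99):
    thresh = 1 / norm.ppf(confidence)
    print(f'{confidence}: converge when std / |val| < {thresh:.4f}')

# With the default sign_confidence=0.99, the threshold is about 0.4299,
# i.e. |val| must exceed roughly 2.33 standard deviations.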