Example #1
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 detect_convergence=True,
                 thresh=0.025,
                 n_samples=None,
                 optimize_ordering=True,
                 ordering_batches=1,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; should be
            as large as available memory allows.
          detect_convergence: whether to stop when approximately converged.
          thresh: threshold for determining convergence.
          n_samples: number of samples to take per feature.
          optimize_ordering: whether to guess an ordering of features based on
            importance. May accelerate convergence.
          ordering_batches: number of minibatches while determining ordering.
          verbose: print progress messages.
          bar: display progress bar.

        The default behavior is to detect each feature's convergence based on
        the ratio of its standard deviation to the gap between the largest and
        smallest values. Since neither value is known initially, we begin with
        estimates (upper_val, lower_val) and update them as more features are
        analyzed.

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # For setting up bar.
        estimate_convergence = n_samples is None
        if estimate_convergence and verbose:
            print('Estimating convergence time')

        # Possibly force convergence detection.
        if n_samples is None:
            n_samples = 1e20
            if not detect_convergence:
                detect_convergence = True
                if verbose:
                    print('Turning convergence detection on')

        if detect_convergence:
            assert 0 < thresh < 1

        # Print message explaining parameter choices.
        if verbose:
            print('Batch size = batch * samples = {}'.format(
                batch_size * self.imputer.samples))

        # For detecting convergence.
        total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
        upper_val = max(total / num_features, 0)
        lower_val = 0

        # Feature ordering.
        if optimize_ordering:
            if verbose:
                print('Determining feature ordering...')
            holdout_importance = estimate_holdout_importance(
                self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
            if verbose:
                print('Done')
            # Use np.abs in case there are large negative contributors.
            ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
        else:
            ordering = list(range(num_features))

        # Set up bar.
        n_loops = int(n_samples / batch_size)
        if bar:
            if estimate_convergence:
                bar = tqdm(total=1)
            else:
                bar = tqdm(total=n_loops * batch_size * num_features)

        # Iterated sampling.
        tracker_list = []
        for i, ind in enumerate(ordering):
            tracker = utils.ImportanceTracker()
            for it in range(n_loops):
                # Sample data.
                mb = np.random.choice(N, batch_size)
                x = X[mb]
                y = Y[mb]

                # Sample subset of features.
                S = utils.sample_subset_feature(num_features, batch_size, ind)

                # Loss with feature excluded.
                y_hat = self.imputer(x, S)
                loss_excluded = self.loss_fn(y_hat, y)

                # Loss with feature included.
                S[:, ind] = 1
                y_hat = self.imputer(x, S)
                loss_included = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                tracker.update(loss_excluded - loss_included)
                if bar and (not estimate_convergence):
                    bar.update(batch_size)

                # Calculate progress.
                std = tracker.std
                gap = max(max(upper_val, tracker.values.item()) -
                          min(lower_val, tracker.values.item()), 1e-12)
                ratio = std / gap

                # Print progress message.
                if verbose:
                    if detect_convergence:
                        print('StdDev Ratio = {:.4f} '
                              '(Converge at {:.4f})'.format(ratio, thresh))
                    else:
                        print('StdDev Ratio = {:.4f}'.format(ratio))

                # Check for convergence.
                if detect_convergence:
                    if ratio < thresh:
                        if verbose:
                            print('Detected feature convergence')

                        # Skip bar ahead.
                        if bar:
                            bar.n = np.around(
                                bar.total * (i + 1) / num_features, 4)
                            bar.refresh()
                        break

                # Update convergence estimation.
                if bar and estimate_convergence:
                    std_est = ratio * np.sqrt(it + 1)
                    n_est = (std_est / thresh)**2
                    bar.n = np.around((i + (it + 1) / n_est) / num_features, 4)
                    bar.refresh()

            if verbose:
                print('Done with feature {}'.format(i))
            tracker_list.append(tracker)

            # Adjust min max value.
            upper_val = max(upper_val, tracker.values.item())
            lower_val = min(lower_val, tracker.values.item())

        if bar:
            bar.close()

        # Extract SAGE values.
        reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
        values = np.array(
            [tracker_list[ind].values.item() for ind in reverse_ordering])
        std = np.array(
            [tracker_list[ind].std.item() for ind in reverse_ordering])

        return core.Explanation(values, std, explanation_type)
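
The stopping rule described in the docstring can be isolated as a small, self-contained sketch. The helper below is hypothetical (it does not appear in the code above); it only restates the std-to-gap ratio test from the loop, with the same 1e-12 guard against an empty gap.

import numpy as np

def feature_converged(std, value, upper_val, lower_val, thresh=0.025):
    """Hypothetical helper: True once the estimate's standard deviation is
    small relative to the gap between the largest and smallest values seen."""
    gap = max(max(upper_val, value) - min(lower_val, value), 1e-12)
    return std / gap < thresh

# Example: a feature estimated at 0.4 +/- 0.01 against a value range of
# roughly [0, 1] counts as converged, since 0.01 / 1.0 = 0.01 < 0.025.
print(feature_converged(std=0.01, value=0.4, upper_val=1.0, lower_val=0.0))
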
Example #2
    def __call__(self,
                 xy,
                 batch_size,
                 n_permutations=None,
                 detect_convergence=None,
                 convergence_threshold=0.01,
                 verbose=False,
                 bar=False):
        '''
        Estimate SAGE values.

        Args:
          xy: tuple of np.ndarrays for input and output.
          batch_size: number of examples to be processed at once. You should use
            as large of a batch size as possible without exceeding available
            memory.
          n_permutations: number of permutations. If not specified, permutations
            are sampled until the estimates converge.
          detect_convergence: whether to detect convergence of SAGE estimates.
          convergence_threshold: confidence interval threshold for determining
            convergence. Represents portion of estimated sum of SAGE values.
          verbose: whether to print progress messages.
          bar: whether to display progress bar.

        Returns: SAGEValues object.
        '''
        X, Y = xy
        N, input_size = X.shape

        # Verify model.
        X, Y = utils.verify_model_data(self.model, X, Y, self.loss_fn,
                                       batch_size * self.imputer.samples)

        # For detecting convergence.
        total = estimate_total(self.model, xy,
                               batch_size * self.imputer.samples, self.loss_fn)
        if n_permutations is None:
            # Turn convergence detection on.
            if detect_convergence is None:
                detect_convergence = True
            elif not detect_convergence:
                detect_convergence = True
                print('Turning convergence detection on')

            # Turn bar off.
            if bar:
                bar = False
                print('Turning bar off')

            # Set n_permutations to an extremely large number.
            n_permutations = 1e20

        if detect_convergence:
            assert 0 < convergence_threshold < 1

        # Print message explaining parameter choices.
        if verbose:
            print('{} permutations, batch size (batch x samples) = {}'.format(
                n_permutations, batch_size * self.imputer.samples))

        # For updating scores.
        tracker = utils.ImportanceTracker()

        # Permutation sampling.
        n_loops = int(n_permutations / batch_size)
        if bar:
            bar = tqdm(total=n_loops * batch_size * input_size)
        for _ in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample permutations.
            S = np.zeros((batch_size, input_size))
            permutations = np.tile(np.arange(input_size), (batch_size, 1))
            for i in range(batch_size):
                np.random.shuffle(permutations[i])

            # Make prediction with missing features.
            y_hat = self.model(self.imputer(x, S))
            y_hat = np.mean(y_hat.reshape(-1, self.imputer.samples,
                                          *y_hat.shape[1:]),
                            axis=1)
            prev_loss = self.loss_fn(y_hat, y)

            # Setup.
            arange = np.arange(batch_size)
            scores = np.zeros((batch_size, input_size))

            for i in range(input_size):
                # Add next feature.
                inds = permutations[:, i]
                S[arange, inds] = 1.0

                # Make prediction with missing features.
                y_hat = self.model(self.imputer(x, S))
                y_hat = np.mean(y_hat.reshape(-1, self.imputer.samples,
                                              *y_hat.shape[1:]),
                                axis=1)
                loss = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                scores[arange, inds] = prev_loss - loss
                prev_loss = loss
                if bar:
                    bar.update(batch_size)

            # Update tracker.
            tracker.update(scores)

            # Check for convergence.
            conf = np.max(tracker.std)
            if verbose:
                print('Conf = {:.4f}, Total = {:.4f}'.format(conf, total))
            if detect_convergence:
                if (conf / total) < convergence_threshold:
                    if verbose:
                        print('Stopping early')
                    break

        return utils.SAGEValues(tracker.values, tracker.std)
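
utils.ImportanceTracker is used throughout these examples but never shown. The sketch below is an inference from its call sites, not the package's implementation: it folds batches of per-sample scores into running first and second moments, exposing the mean as `values` and the standard error of that mean as `std`. (The weighted `update(scores, sample_counts)` variant used in Example #3 is omitted.)

import numpy as np

class ImportanceTrackerSketch:
    """Hypothetical stand-in for utils.ImportanceTracker."""

    def __init__(self):
        self.first_moment = None   # running mean of scores
        self.second_moment = None  # running mean of squared scores
        self.n = 0                 # number of samples seen

    def update(self, scores):
        # scores: (batch,) for a single feature, or (batch, num_features).
        scores = np.asarray(scores, dtype=float)
        if scores.ndim == 1:
            scores = scores[:, np.newaxis]
        batch = len(scores)
        if self.first_moment is None:
            self.first_moment = np.zeros(scores.shape[1])
            self.second_moment = np.zeros(scores.shape[1])
        # Fold the batch into the running moments.
        ratio = batch / (self.n + batch)
        self.first_moment = ((1 - ratio) * self.first_moment
                             + ratio * scores.mean(axis=0))
        self.second_moment = ((1 - ratio) * self.second_moment
                              + ratio * (scores ** 2).mean(axis=0))
        self.n += batch

    @property
    def values(self):
        return self.first_moment

    @property
    def std(self):
        # Standard error of the running mean; infinite before any update.
        if self.n == 0:
            return np.inf
        variance = self.second_moment - self.first_moment ** 2
        return np.sqrt(np.maximum(variance, 0) / self.n)
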
Example #3
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 detect_convergence=True,
                 thresh=0.025,
                 n_permutations=None,
                 min_coalition=0.0,
                 max_coalition=1.0,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; should be
            as large as available memory allows.
          detect_convergence: whether to stop when approximately converged.
          thresh: threshold for determining convergence.
          n_permutations: number of permutations to unroll.
          min_coalition: minimum coalition size (int or float).
          max_coalition: maximum coalition size (int or float).
          verbose: print progress messages.
          bar: display progress bar.

        The default behavior is to detect convergence based on the width of the
        SAGE values' confidence intervals. Convergence is defined by the ratio
        of the maximum standard deviation to the gap between the largest and
        smallest values.

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # Determine min/max coalition sizes.
        if isinstance(min_coalition, float):
            min_coalition = int(min_coalition * num_features)
        if isinstance(max_coalition, float):
            max_coalition = int(max_coalition * num_features)
        assert min_coalition >= 0
        assert max_coalition <= num_features
        assert min_coalition < max_coalition
        if min_coalition > 0 or max_coalition < num_features:
            relaxed = True
            explanation_type = 'Relaxed ' + explanation_type
        else:
            relaxed = False
        # Per-feature sample counts are only needed in the relaxed setting;
        # they are recomputed for each batch below.
        sample_counts = None

        # Possibly force convergence detection.
        if n_permutations is None:
            n_permutations = 1e20
            if not detect_convergence:
                detect_convergence = True
                if verbose:
                    print('Turning convergence detection on')

        if detect_convergence:
            assert 0 < thresh < 1

        # Set up bar.
        n_loops = int(n_permutations / batch_size)
        if bar:
            if detect_convergence:
                bar = tqdm(total=1)
            else:
                bar = tqdm(total=n_loops * batch_size * num_features)

        # Setup.
        arange = np.arange(batch_size)
        scores = np.zeros((batch_size, num_features))
        S = np.zeros((batch_size, num_features), dtype=bool)
        permutations = np.tile(np.arange(num_features), (batch_size, 1))
        tracker = utils.ImportanceTracker()

        # Permutation sampling.
        for it in range(n_loops):
            # Sample data.
            mb = np.random.choice(N, batch_size)
            x = X[mb]
            y = Y[mb]

            # Sample permutations.
            S[:] = 0
            for i in range(batch_size):
                np.random.shuffle(permutations[i])

            # In the relaxed setting, only features falling inside each
            # permutation's [min_coalition, max_coalition) window are scored
            # this batch, so reset the scores and count how often each
            # feature is sampled.
            if relaxed:
                scores[:] = 0
                sample_counts = np.zeros(num_features, dtype=int)
                np.add.at(sample_counts,
                          permutations[:, min_coalition:max_coalition], 1)

            # Add necessary features to minimum coalition.
            for i in range(min_coalition):
                # Add next feature.
                inds = permutations[:, i]
                S[arange, inds] = 1

            # Make prediction with minimum coalition.
            y_hat = self.imputer(x, S)
            prev_loss = self.loss_fn(y_hat, y)

            # Add all remaining features.
            for i in range(min_coalition, max_coalition):
                # Add next feature.
                inds = permutations[:, i]
                S[arange, inds] = 1

                # Make prediction with missing features.
                y_hat = self.imputer(x, S)
                loss = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                scores[arange, inds] = prev_loss - loss
                prev_loss = loss

                # Update bar (if not detecting convergence).
                if bar and (not detect_convergence):
                    bar.update(batch_size)

            # Update tracker.
            tracker.update(scores, sample_counts)

            # Calculate progress.
            std = np.max(tracker.std)
            gap = max(tracker.values.max() - tracker.values.min(), 1e-12)
            ratio = std / gap

            # Print progress message.
            if verbose:
                if detect_convergence:
                    print(f'StdDev Ratio = {ratio:.4f} '
                          f'(Converge at {thresh:.4f})')
                else:
                    print(f'StdDev Ratio = {ratio:.4f}')

            # Check for convergence.
            if detect_convergence:
                if ratio < thresh:
                    if verbose:
                        print('Detected convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = bar.total
                        bar.refresh()
                    break

            # Update convergence estimation.
            if bar and detect_convergence:
                N_est = (it + 1) * (ratio / thresh)**2
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()

        if bar:
            bar.close()

        return core.Explanation(tracker.values, tracker.std, explanation_type)
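
The relaxed coalition bookkeeping above is easiest to see in isolation. This is a minimal sketch using only NumPy: fractional coalition bounds become integer sizes, and only permutation positions inside the [min_coalition, max_coalition) window are scored, so each feature's estimate must be weighted by how often it lands in that window.

import numpy as np

num_features = 8
batch_size = 4
min_frac, max_frac = 0.25, 0.75  # fractional coalition bounds

# Convert fractional bounds to integer coalition sizes, as above.
min_coalition = int(min_frac * num_features)  # 2
max_coalition = int(max_frac * num_features)  # 6

rng = np.random.default_rng(0)
permutations = np.tile(np.arange(num_features), (batch_size, 1))
for i in range(batch_size):
    rng.shuffle(permutations[i])

# Count how often each feature falls inside the scored window.
sample_counts = np.zeros(num_features, dtype=int)
np.add.at(sample_counts, permutations[:, min_coalition:max_coalition], 1)

# The counts sum to batch_size * (max_coalition - min_coalition).
print(sample_counts, sample_counts.sum())
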
Example #4
    def __call__(self,
                 xy,
                 batch_size,
                 n_samples=None,
                 detect_convergence=False,
                 convergence_threshold=0.01,
                 verbose=False,
                 bar=False):
        '''
        Estimate SAGE values.

        Args:
          xy: tuple of np.ndarrays for input and output.
          batch_size: number of examples to be processed at once. You should use
            as large of a batch size as possible without exceeding available
            memory.
          n_samples: number of samples for each feature. If not specified,
            samples are taken until the estimates converge.
          detect_convergence: whether to detect convergence of SAGE estimates.
          convergence_threshold: confidence interval threshold for determining
            convergence. Represents portion of estimated sum of SAGE values.
          verbose: whether to print progress messages.
          bar: whether to display progress bar.

        Returns: SAGEValues object.
        '''
        X, Y = xy
        N, input_size = X.shape

        # Verify model.
        X, Y = utils.verify_model_data(self.model, X, Y, self.loss_fn,
                                       batch_size * self.imputer.samples)

        # For detecting convergence.
        total = estimate_total(self.model, xy,
                               batch_size * self.imputer.samples, self.loss_fn)
        if n_samples is None:
            # Turn convergence detection on.
            if detect_convergence is None:
                detect_convergence = True
            elif not detect_convergence:
                detect_convergence = True
                print('Turning convergence detection on')

            # Turn bar off.
            if bar:
                bar = False
                print('Turning bar off')

            # Set n_samples to an extremely large number.
            n_samples = 1e20

        if detect_convergence:
            assert 0 < convergence_threshold < 1

        if verbose:
            print('{} samples/feat, batch size (batch x samples) = {}'.format(
                n_samples, batch_size * self.imputer.samples))

        # For updating scores.
        tracker_list = []

        # Iterated sampling.
        n_loops = int(n_samples / batch_size)
        if bar:
            bar = tqdm(total=n_loops * batch_size * input_size)
        for ind in range(input_size):
            tracker = utils.ImportanceTracker()
            for _ in range(n_loops):
                # Sample data.
                mb = np.random.choice(N, batch_size)
                x = X[mb]
                y = Y[mb]

                # Sample subset of features.
                S = utils.sample_subset_feature(input_size, batch_size, ind)

                # Loss with feature excluded.
                y_hat = self.model(self.imputer(x, S))
                y_hat = np.mean(y_hat.reshape(-1, self.imputer.samples,
                                              *y_hat.shape[1:]),
                                axis=1)
                loss_excluded = self.loss_fn(y_hat, y)

                # Loss with feature included.
                S[:, ind] = 1.0
                y_hat = self.model(self.imputer(x, S))
                y_hat = np.mean(y_hat.reshape(-1, self.imputer.samples,
                                              *y_hat.shape[1:]),
                                axis=1)
                loss_included = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                tracker.update(loss_excluded - loss_included)
                if bar:
                    bar.update(batch_size)

                # Check for convergence.
                conf = tracker.std
                if verbose:
                    print('Imp = {:.4f}, Conf = {:.4f}, Total = {:.4f}'.format(
                        tracker.values, conf, total))
                if detect_convergence:
                    if (conf / total) < convergence_threshold:
                        if verbose:
                            print('Stopping feature early')
                        break

            if verbose:
                print('Done with feature {}'.format(ind))
            tracker_list.append(tracker)

        return utils.SAGEValues(
            np.array([tracker.values.item() for tracker in tracker_list]),
            np.array([tracker.std.item() for tracker in tracker_list]))
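
utils.sample_subset_feature appears in Examples #1, #4, and #5 without a definition. The sketch below is one plausible reading of its contract, assumed rather than taken from the package: for each row it draws a random coalition over the other features, leaving column `ind` at zero so the caller can flip it on and measure the loss difference.

import numpy as np

def sample_subset_feature_sketch(input_size, batch_size, ind):
    """Hypothetical version of utils.sample_subset_feature: random subsets
    of features that always exclude feature `ind`."""
    S = np.zeros((batch_size, input_size))
    others = np.array([j for j in range(input_size) if j != ind])
    for row in range(batch_size):
        # Pick a subset size uniformly, then that many features without
        # replacement (this sampling scheme is an assumption).
        k = np.random.randint(0, input_size)  # 0 <= k <= input_size - 1
        S[row, np.random.choice(others, size=k, replace=False)] = 1.0
    return S

# Each row satisfies S[row, ind] == 0, matching how the callers above set
# S[:, ind] = 1 afterwards to measure the inclusion effect.
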
Example #5
    def __call__(self,
                 X,
                 Y=None,
                 batch_size=512,
                 sign_confidence=0.99,
                 narrow_thresh=0.025,
                 optimize_ordering=True,
                 ordering_batches=1,
                 verbose=False,
                 bar=True):
        '''
        Estimate SAGE values.

        Args:
          X: input data.
          Y: target data. If None, model output will be used.
          batch_size: number of examples to be processed in parallel; should be
            as large as available memory allows.
          sign_confidence: confidence level on sign.
          narrow_thresh: threshold for detecting that the standard deviation is
            small enough.
          optimize_ordering: whether to guess an ordering of features based on
            importance. May accelerate convergence.
          ordering_batches: number of minibatches while determining ordering.
          verbose: print progress messages.
          bar: display progress bar.

        Convergence for each SAGE value is detected when one of two conditions
        holds: (1) the sign is known with high confidence (given by
        sign_confidence), or (2) the standard deviation of the Gaussian
        confidence interval is sufficiently narrow (given by narrow_thresh).

        Returns: Explanation object.
        '''
        # Determine explanation type.
        if Y is not None:
            explanation_type = 'SAGE'
        else:
            explanation_type = 'Shapley Effects'

        # Verify model.
        N, _ = X.shape
        num_features = self.imputer.num_groups
        X, Y = utils.verify_model_data(self.imputer, X, Y, self.loss_fn,
                                       batch_size)

        # Verify thresholds.
        assert 0 < narrow_thresh < 1
        assert 0.9 <= sign_confidence < 1
        sign_thresh = 1 / norm.ppf(sign_confidence)

        # For detecting convergence.
        total = estimate_total(self.imputer, X, Y, batch_size, self.loss_fn)
        upper_val = max(total / num_features, 0)
        lower_val = min(total / num_features, 0)

        # Feature ordering.
        if optimize_ordering:
            if verbose:
                print('Determining feature ordering...')
            holdout_importance = estimate_holdout_importance(
                self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
            if verbose:
                print('Done')
            # Use np.abs in case there are large negative contributors.
            ordering = list(np.argsort(np.abs(holdout_importance))[::-1])
        else:
            ordering = list(range(num_features))

        # Set up bar.
        if bar:
            bar = tqdm(total=1)

        # Iterated sampling.
        tracker_list = []
        for i, ind in enumerate(ordering):
            tracker = utils.ImportanceTracker()
            it = 0
            converged = False
            while not converged:
                # Sample data.
                mb = np.random.choice(N, batch_size)
                x = X[mb]
                y = Y[mb]

                # Sample subset of features.
                S = utils.sample_subset_feature(num_features, batch_size, ind)

                # Loss with feature excluded.
                y_hat = self.imputer(x, S)
                loss_discluded = self.loss_fn(y_hat, y)

                # Loss with feature included.
                S[:, ind] = 1
                y_hat = self.imputer(x, S)
                loss_included = self.loss_fn(y_hat, y)

                # Calculate delta sample.
                tracker.update(loss_discluded - loss_included)

                # Calculate progress.
                val = tracker.values.item()
                std = tracker.std.item()
                gap = max(max(upper_val, val) - min(lower_val, val), 1e-12)
                converged_sign = (std / max(np.abs(val), 1e-12)) < sign_thresh
                converged_narrow = (std / gap) < narrow_thresh

                # Print progress message.
                if verbose:
                    print('Sign Ratio = {:.4f} (Converge at {:.4f}), '
                          'Narrow Ratio = {:.4f} (Converge at {:.4f})'.format(
                              std / max(np.abs(val), 1e-12), sign_thresh,
                              std / gap, narrow_thresh))

                # Check for convergence.
                converged = converged_sign or converged_narrow
                if converged:
                    if verbose:
                        print('Detected feature convergence')

                    # Skip bar ahead.
                    if bar:
                        bar.n = np.around(bar.total * (i + 1) / num_features,
                                          4)
                        bar.refresh()

                # Update convergence estimation.
                elif bar:
                    ratio_sign = std / max(np.abs(val), 1e-12)
                    N_sign = (it + 1) * (ratio_sign / sign_thresh)**2
                    N_narrow = (it + 1) * ((std / gap) / narrow_thresh)**2
                    N_est = min(N_sign, N_narrow)
                    bar.n = np.around((i + (it + 1) / N_est) / num_features, 4)
                    bar.refresh()

                # Increment iteration variable.
                it += 1

            if verbose:
                print('Done with feature {}'.format(i))
            tracker_list.append(tracker)

            # Adjust min max value.
            upper_val = max(upper_val, tracker.values.item())
            lower_val = min(lower_val, tracker.values.item())

        if bar:
            bar.close()

        # Extract SAGE values.
        reverse_ordering = [ordering.index(ind) for ind in range(num_features)]
        values = np.array(
            [tracker_list[ind].values.item() for ind in reverse_ordering])
        std = np.array(
            [tracker_list[ind].std.item() for ind in reverse_ordering])

        return core.Explanation(values, std, explanation_type)
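
The sign criterion above follows from inverting a Gaussian tail bound: since sign_thresh = 1 / norm.ppf(sign_confidence), the test std / |val| < sign_thresh is equivalent to |val| / std > norm.ppf(sign_confidence), i.e. zero lies outside a one-sided confidence interval around the estimate. A minimal numeric check, assuming only NumPy and SciPy:

import numpy as np
from scipy.stats import norm

sign_confidence = 0.99
sign_thresh = 1 / norm.ppf(sign_confidence)  # ~ 1 / 2.326 ~ 0.430

# An estimate of 0.05 with standard deviation 0.02: |val| / std = 2.5 exceeds
# norm.ppf(0.99) ~ 2.326, so the sign is resolved at the 99% level.
val, std = 0.05, 0.02
converged_sign = (std / max(np.abs(val), 1e-12)) < sign_thresh
print(converged_sign)  # True
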