예제 #1
0
class BaseLendingEnv(core.FairnessEnv):
    """Base loan decision environment.

    In each step, the agent decides whether to accept or reject an
    application.

    The base class is abstract.
    """

    metadata = {'render.modes': ['human']}
    default_param_builder = lending_params.Params
    group_membership_var = 'group'
    # State updaters; `_step_impl` applies them in this order on every step.
    _cash_updater = _CashUpdater()
    _parameter_updater = core.NoUpdate()
    _applicant_updater = _ApplicantSampler()

    def __init__(self, params=None):
        """Builds the action/observation spaces and the initial state.

        Args:
            params: Environment parameters. When None, a new instance of
                `default_param_builder` is created and used instead.
        """
        params = (self.default_param_builder() if params is None else params
                  )  # type: lending_params.Params

        # The action space of the agent is Accept/Reject.
        self.action_space = spaces.Discrete(2)

        # Bank's cash is a scalar and cannot be negative.
        bank_cash_space = spaces.Box(low=0,
                                     high=params.max_cash,
                                     shape=(),
                                     dtype=np.float32)

        # Applicant features: one entry per dimension of the applicant
        # distribution, bounded by the configured min/max observation.
        loan_applicant_space = spaces.Box(
            params.min_observation,
            params.max_observation,
            dtype=np.float32,
            shape=(params.applicant_distribution.dim, ))

        group_space = spaces.MultiBinary(params.num_groups)

        self.observable_state_vars = {
            'bank_cash': bank_cash_space,
            'applicant_features': loan_applicant_space,
            'group': group_space
        }

        # NOTE(review): `_state_init` reads `self.initial_params`, which is
        # presumably set by the base-class constructor - confirm in core.
        super(BaseLendingEnv, self).__init__(params)
        self._state_init()

    def _state_init(self, rng=None):
        """Replaces `self.state` with a fresh state built from initial params.

        Args:
            rng: Random number generator to store in the new state. When
                falsy, a new `np.random.RandomState()` is created.
        """
        self.state = State(
            # Copy in case state.params get mutated, initial_params stays pristine.
            params=copy.deepcopy(self.initial_params),
            rng=rng or np.random.RandomState(),
            bank_cash=self.initial_params.bank_starting_cash)
        # Run the applicant updater once (with no action) so the new state
        # already has an applicant before the first step.
        self._applicant_updater.update(self.state, None)

    def reset(self):
        """Resets the environment.

        The current state's random number generator is carried over into the
        re-initialized state rather than being replaced.

        Returns:
            Whatever the base class `reset` returns.
        """
        self._state_init(self.state.rng)
        return super(BaseLendingEnv, self).reset()

    def _is_done(self):
        """Returns True if the bank cash is less than loan_amount."""
        return self.state.bank_cash < self.state.params.loan_amount

    def _step_impl(self, state, action):
        """Run one timestep of the environment's dynamics.

        In a single step, the agent decides whether to accept or reject an
        application.

        The potential payoffs of rejected application are always 0.
        If an application is accepted, the payoffs are:
          -loan_amount if the applicant defaults.
          +loan_amount*interest_rate if the applicant successfully pays back.

        Args:
            state: A `State` object containing the current state.
            action: An action in `action_space`.

        Returns:
            A `State` object containing the updated state.
        """
        # Operate on the state that was passed in, as documented, instead of
        # silently ignoring the argument in favour of `self.state`. The two
        # are the same object when the caller passes `self.state`.
        self._cash_updater.update(state, action)
        self._parameter_updater.update(state, action)
        self._applicant_updater.update(state, action)
        return state

    def render(self, mode='human'):
        """Renders the history and current state using matplotlib.

        Args:
            mode: string indicating the rendering mode. The only supported
                mode is `human`.

        Raises:
            NotImplementedError: If `mode` is `human` and the applicant
                features are not exactly 2 dimensional.
        """
        if mode != 'human':
            super(BaseLendingEnv,
                  self).render(mode)  # Raises NotImplementedError
            return

        feature_dim = self.state.params.applicant_distribution.dim
        if feature_dim != 2:
            raise NotImplementedError(
                'Cannot render if applicant features are not exactly 2 dimensional. '
                'Got %d dimensional applicant features.' % feature_dim)

        plt.figure(figsize=(12, 4))

        # Left panel: scatter of applicant features.
        plt.subplot(1, 2, 1)
        plt.xlim(-2, 2)
        plt.ylim(-2, 2)
        plt.title('Applicant Features')
        plt.xticks([], [])
        plt.yticks([], [])
        # Only past applicants whose recorded action was 1 are drawn; red if
        # they defaulted, blue otherwise. The marker shape encodes the group.
        for past_state, past_action in self.history:
            if past_action == 1:
                x, y = past_state.applicant_features
                color = 'r' if past_state.will_default else 'b'
                plt.plot([x], [y],
                         _MARKERS[past_state.group_id] + color,
                         markersize=12)
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')

        # The current applicant is drawn in black with a larger marker.
        x, y = self.state.applicant_features
        plt.plot([x], [y],
                 _MARKERS[self.state.group_id] + 'k',
                 markersize=15)

        # Right panel: bank cash over the recorded history plus the current
        # value.
        # NOTE(review): the series plotted is `bank_cash` while the axis label
        # reads '# loans available' - confirm the label is intended.
        plt.subplot(1, 2, 2)
        plt.title('Cash')
        cash_series = [past_state.bank_cash for past_state, _ in self.history]
        plt.plot(cash_series + [self.state.bank_cash])
        plt.ylabel('# loans available')
        plt.xlabel('Time')
        plt.tight_layout()
예제 #2
0
 def test_noop_state_updater_does_nothing(self):
     """Checks that `core.NoUpdate().update` leaves the given state equal to a prior copy."""
     dummy_env = test_util.DummyEnv()
     live_state = dummy_env._get_state()
     # Deep copy taken up front so the comparison target cannot be touched
     # by the updater.
     snapshot = copy.deepcopy(live_state)

     updater = core.NoUpdate()
     updater.update(live_state, dummy_env.action_space.sample())

     self.assertEqual(live_state, snapshot)