Example #1
    def test_10(self):
        """
        Tests symmetric drive network.
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('drive',
            primary_input='tests/test_data/sources.csv',
            primary_hints=hints, secondary_hints=hints)
        transit_matrix_1.process()

        assert True
Example #2
    def test_26(self):
        """
        Tests reading an OTP csv.
        """
        transit_matrix = TransitMatrix('otp',
                                       primary_input='tests/test_data/sample_otp.csv')
        try:
            transit_matrix.process()
        except AssertionError:
            return
        assert False
Example #3
    def test_6(self):
        """
        Tests bike network and write_csv.
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('bike',
            primary_input='tests/test_data/sources.csv',
            secondary_input='tests/test_data/dests.csv',
            primary_hints=hints, secondary_hints=hints)
        transit_matrix_1.process()
        transit_matrix_1.write_csv(self.datapath + "test_6.csv")
        assert True
Example #4
    def test_19(self):
        """
        Tests UnableToParsePrimaryDataException.
        """
        transit_matrix_1 = TransitMatrix('drive', primary_input="tests/test_data/sources.csv",
                                         primary_hints={'idx': 'wrong_column_name', 'lat': 'y', 'lon': 'x'})
        try:
            transit_matrix_1.process()
        except UnableToParsePrimaryDataException:
            assert True
            return
        assert False
Example #5
    def test_17(self):
        """
        Tests PrimaryDataNotFoundException.
        """
        transit_matrix_1 = TransitMatrix('drive', primary_input="file_that_doesnt_exist.csv",
                                         primary_hints={'idx': 'name', 'lat': 'y', 'lon': 'x'})
        try:
            transit_matrix_1.process()
        except PrimaryDataNotFoundException:
            assert True
            return
        assert False
Example #6
    def test_8(self):
        """
        Tests string labels.
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('bike',
            primary_input='tests/test_data/sources.csv',
            secondary_input='tests/test_data/dests_a.csv',
            primary_hints=hints, secondary_hints=hints)
        transit_matrix_1.process()

        assert True
Example #7
def run_matrix_job(job):
    if 'primary_hints' not in job.orders['init_kwargs']:
        raise MissingHintsException('primary_hints')
    if (job.secondary_resource is not None
            and 'secondary_hints' not in job.orders['init_kwargs']):
        raise MissingHintsException('secondary_hints')
    job.orders['init_kwargs']['primary_input'] = job.primary_resource
    job.orders['init_kwargs']['secondary_input'] = job.secondary_resource
    matrix = TransitMatrix(**job.orders['init_kwargs'])
    matrix.process()
    output_filename = job.job_folder + 'output.csv'
    matrix.write_csv(output_filename)
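
Below is a minimal sketch of the job object run_matrix_job expects. The attribute and key names mirror the accesses in the function body above; the use of types.SimpleNamespace (and routing 'network_type' through init_kwargs) is an illustrative assumption, not the project's actual job class.

# Hypothetical stand-in for the real job object; any object exposing
# these fields would work with run_matrix_job as written above.
from types import SimpleNamespace

job = SimpleNamespace(
    primary_resource='tests/test_data/sources.csv',
    secondary_resource='tests/test_data/dests.csv',
    job_folder='jobs/example/',
    orders={'init_kwargs': {
        'network_type': 'walk',
        'primary_hints': {'idx': 'name', 'lat': 'y', 'lon': 'x'},
        'secondary_hints': {'idx': 'name', 'lat': 'y', 'lon': 'x'},
    }})
run_matrix_job(job)  # writes jobs/example/output.csv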
Example #8
    def test_5(self):
        """
        Tests the use_meters flag.
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        custom_configs = Configs()
        custom_configs.use_meters = True
        transit_matrix_1 = TransitMatrix('walk',
            primary_input='tests/test_data/sources.csv',
            secondary_input='tests/test_data/dests.csv',
            primary_hints=hints, secondary_hints=hints,
            configs=custom_configs)
        transit_matrix_1.process()
Example #9
    def test_16(self):
        """
        Tests write_csv (symmetric).
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('walk',
            primary_input='tests/test_data/sources.csv',
            primary_hints=hints)
        transit_matrix_1.process()
        filename = self.datapath + 'test_16_file.csv'
        transit_matrix_1.write_csv(filename)

        assert True
Example #10
    def test_28(self):
        """
        Test ImproperIndecesTypeException is raised
        when indeces are not strings or ints.
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('walk',
                                         primary_input='tests/test_data/sources_bad_indeces_type.csv',
                                         primary_hints=hints)

        try:
            transit_matrix_1.process()
            assert False
        except ImproperIndecesTypeException:
            return
Example #11
    def test_7(self):
        """
        Tests symmetric bike network and write_tmx
        """
        hints = {'idx':'name', 'lat':'y', 'lon':'x'}
        transit_matrix_1 = TransitMatrix('bike',
            primary_input='tests/test_data/sources.csv',
            primary_hints=hints, secondary_hints=hints)
        transit_matrix_1.process()
        filename = self.datapath + "test_7.tmx"
        transit_matrix_1.write_tmx(filename)
        transit_matrix_2 = TransitMatrix('walk',
                                         read_from_file=filename)

        assert True
Example #12
    def test_27(self):
        """
        Test write_csv/read_file
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('walk',
                                         primary_input='tests/test_data/sources.csv',
                                         primary_hints=hints)

        transit_matrix_1.process()
        transit_matrix_1.matrix_interface.print_data_frame()
        filename = self.datapath + 'test_27_file.csv'
        transit_matrix_1.write_csv(filename)

        transit_matrix_2 = TransitMatrix('walk', read_from_file=filename)
        transit_matrix_2.matrix_interface.print_data_frame()
Example #13
    def test_24(self):
        """
        Tests write_tmx (asymmetric).
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('bike',
            primary_input='tests/test_data/sources.csv',
            secondary_input='tests/test_data/dests.csv',
            primary_hints=hints, secondary_hints=hints)
        transit_matrix_1.process()
        filename = self.datapath + 'test_24_file.tmx'
        transit_matrix_1.write_tmx(filename)

        transit_matrix_2 = TransitMatrix('bike', read_from_file=filename)

        assert True
Example #14
    def test_25(self):
        """
        Tests write file with bad filename (symmetric).
        """
        hints = {'idx': 'name', 'lat': 'y', 'lon': 'x'}
        transit_matrix_1 = TransitMatrix('walk',
            primary_input='tests/test_data/sources.csv',
            primary_hints=hints)

        transit_matrix_1.process()
        bad_filename = self.datapath + 'test_25_file.ext'

        try:
            transit_matrix_1.write_tmx(bad_filename)
            assert False
        except WriteTMXFailedException:
            return
Example #15
class ModelData:
    """
    Common resources for spatial_access.Models.
    """
    def __init__(self,
                 network_type,
                 sources_filename,
                 destinations_filename,
                 source_column_names=None,
                 dest_column_names=None,
                 configs=None,
                 debug=False):
        """
        Args:
            network_type: string, one of {'walk', 'bike', 'drive', 'otp'}.
            sources_filename: string, csv filename.
            destinations_filename: string, csv filename.
            source_column_names: dictionary, map column names to expected values.
            dest_column_names: dictionary, map column names to expected values.
            configs: defaults to None, else pass in an instance of Configs to override
                default values for p2p.
            debug: boolean, enable to see more detailed logging output.
        """
        self.network_type = network_type
        self.transit_matrix = None
        self.dests = None
        self.sources = None
        self.model_results = None
        self.aggregated_results = None
        self.all_categories = {}
        self.focus_categories = {}

        self.configs = configs

        self._aggregation_args = {}
        self._is_source = True
        self._is_aggregatable = True
        self._requires_user_aggregation_type = False
        self._result_column_names = None

        self.sources_filename = sources_filename
        self.destinations_filename = destinations_filename

        # column_names and file_hints are similar: both map
        # intended_name -> actual_data_name. The difference is that
        # column_names must be complete (contain all needed fields).
        self.source_column_names = source_column_names
        self.dest_column_names = dest_column_names

        # hints are partial/potentially incomplete, and supplied
        # by p2p.TransitMatrix
        self._source_file_hints = None
        self._dest_file_hints = None

        self.sources_in_range = {}
        self.dests_in_range = {}

        # initialize logger
        self.debug = debug
        self.logger = None
        self.set_logging(debug)

    def write_transit_matrix_to_csv(self, filename=None):
        """
        Args:
            filename: string (or None, in which case a filename will
                be automatically generated).
        Write transit matrix to csv.
        """
        self.transit_matrix.write_csv(filename)

    def write_transit_matrix_to_tmx(self, filename=None):
        """
        Args:
            filename: string (or None, in which case a filename will
                be automatically generated).
        Write transit matrix to tmx.
        """
        self.transit_matrix.write_tmx(filename)

    @staticmethod
    def _get_output_filename(keyword, extension='csv', file_path='data/'):
        """
        Args:
            keyword: string such as "model_results" or "aggregated_results"
                to build the filename.
            extension: file type extension (without the ".").
            file_path: subdirectory.
        Returns: string, an unused filename.
        """
        if file_path is None:
            file_path = "data/"
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        filename = os.path.join(file_path,
                                '{}_0.{}'.format(keyword, extension))
        counter = 1
        while os.path.isfile(filename):
            filename = os.path.join(
                file_path, '{}_{}.{}'.format(keyword, counter, extension))
            counter += 1

        return filename
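
    # Worked example (hypothetical state of the data/ directory): if
    # 'data/model_0.csv' already exists, _get_output_filename('model')
    # probes 'data/model_1.csv', 'data/model_2.csv', ... and returns the
    # first path that is not yet taken.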

    def get_population(self, source_id):
        """
        Args:
            source_id: string/int
        Returns: the population at a source point.
        """
        return self.sources.loc[source_id, 'population']

    def get_capacity(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the capacity value at a dest point.
        """
        return self.dests.loc[dest_id, 'capacity']

    def get_category(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the category value at a dest point.
        """
        return self.dests.loc[dest_id, 'category']

    def get_all_dest_ids(self):
        """
        Returns: all ids of destination data frame.
        """
        return list(self.dests.index)

    def get_all_source_ids(self):
        """
        Returns: all ids of source data frame.
        """
        return list(self.sources.index)

    def get_ids_for_category(self, category):
        """
        Given a category, return a list of all indices
        which match. If category is 'all_categories', return all indices.
        """
        if category == 'all_categories':
            return self.get_all_dest_ids()
        return list(self.dests[self.dests['category'] == category].index)

    def set_logging(self, debug):
        """
        Args:
            debug: set to true for more detailed logging
                output.
        """

        if debug:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_transit_matrix(self, read_from_file=None):
        """
        Load the transit matrix (and sources/dests).
        Args:
            read_from_file: filename of a tmx or csv file to load.
                This allows the user to bypass computing the
                transit matrix from scratch. If read_from_file is
                None, the user will be directed to compute the
                transit matrix from given source/dest data.
        Raises:
            SourceDataNotFoundException: Cannot find source data.
            DestDataNotFoundException: Cannot find dest data.
        """
        if read_from_file:
            self.transit_matrix = TransitMatrix(self.network_type,
                                                read_from_file=read_from_file,
                                                debug=self.debug)
        else:
            self.transit_matrix = TransitMatrix(
                self.network_type,
                primary_input=self.sources_filename,
                secondary_input=self.destinations_filename,
                primary_hints=self.source_column_names,
                secondary_hints=self.dest_column_names,
                configs=self.configs,
                debug=self.debug)
            try:
                self.transit_matrix.process()
            except PrimaryDataNotFoundException:
                raise SourceDataNotFoundException()
            except SecondaryDataNotFoundException:
                raise DestDataNotFoundException()

            # borrow hints for use in load_sources() and load_dests() if not user supplied
            if self._source_file_hints is None:
                self._source_file_hints = self.transit_matrix.primary_hints
            if self._dest_file_hints is None:
                self._dest_file_hints = self.transit_matrix.secondary_hints

        self.reload_sources()
        self.reload_dests()
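
    # Usage sketch (hypothetical filename): computing from scratch versus
    # loading a previously saved matrix.
    #   model_data.load_transit_matrix()             # compute from source/dest csvs
    #   model_data.load_transit_matrix('saved.tmx')  # bypass computation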

    def reload_sources(self, filename=None):
        """
        Load the source points for the model (from csv).
        For each point, the table should contain:
        - unique identifier (integer or string)
        - latitude & longitude
        - population (integer) [only for some models]

        Args:
            filename: string
        Raises:
            SourceDataNotFoundException: Cannot find source
                data.
            SourceDataNotParsableException: Provided source_column_names
                do not correspond to column names.
        """

        if filename:
            self.sources_filename = filename
        try:
            self.sources = pd.read_csv(self.sources_filename)
        except FileNotFoundError:
            raise SourceDataNotFoundException()

        if self.source_column_names is None:
            # extract the column names from the table
            population = ''
            idx = ''
            lat = ''
            lon = ''

            if self._source_file_hints is not None:
                if 'idx' in self._source_file_hints:
                    idx = self._source_file_hints['idx']
                if 'population' in self._source_file_hints:
                    population = self._source_file_hints['population']
                if 'lat' in self._source_file_hints:
                    lat = self._source_file_hints['lat']
                if 'lon' in self._source_file_hints:
                    lon = self._source_file_hints['lon']

            # extract the column names from the table for whichever fields
            # were not gleaned from self.source_file_hints
            source_data_columns = self.sources.columns.values
            print('The variables in your data set are:')
            for var in source_data_columns:
                print('> ', var)
            while idx not in source_data_columns:
                idx = input('Enter the unique index variable: ')
            print(
                'If you have no population variable, write "skip" (no quotations)'
            )
            while population not in source_data_columns and population != 'skip':
                population = input('Enter the population variable: ')
            while lat not in source_data_columns:
                lat = input('Enter the latitude variable: ')
            while lon not in source_data_columns:
                lon = input('Enter the longitude variable: ')

            # store the col names for later use
            self.source_column_names = {
                'lat': lat,
                'lon': lon,
                'idx': idx,
                'population': population
            }

        try:
            # insert filler values for the population column if the
            # user does not want to include it (needed for coverage models)
            if self.source_column_names['population'] == 'skip':
                self.sources['population'] = 1

            # rename columns, clean the data frame
            rename_cols = {
                self.source_column_names['population']: 'population',
                self.source_column_names['lat']: 'lat',
                self.source_column_names['lon']: 'lon'
            }
            self.sources.set_index(self.source_column_names['idx'],
                                   inplace=True)
            self.sources.rename(columns=rename_cols, inplace=True)
        except KeyError:
            raise SourceDataNotParsableException()

        # drop unused columns
        columns_to_keep = list(rename_cols.values())
        self.sources = self.sources[columns_to_keep]

    def reload_dests(self, filename=None):
        """
        Load the dest points for the model (from csv).
        For each point, the table should contain:
        - unique identifier (integer or string)
        - latitude & longitude
        - category (string/int) [only for some models]
        - capacity (numeric) [only for some models]

        Args:
            filename: string
        Raises:
            DestDataNotFoundException: Cannot find dest
                data.
            DestDataNotParsableException: Provided dest_column_names
                do not correspond to column names.
        """

        if filename:
            self.destinations_filename = filename

        try:
            self.dests = pd.read_csv(self.destinations_filename)
        except FileNotFoundError:
            raise DestDataNotFoundException()

        if self.dest_column_names is None:
            # extract the column names from the table
            category = ''
            capacity = ''
            idx = ''
            lat = ''
            lon = ''

            if self._dest_file_hints is not None:
                if 'idx' in self._dest_file_hints:
                    idx = self._dest_file_hints['idx']
                if 'category' in self._dest_file_hints:
                    category = self._dest_file_hints['category']
                if 'capacity' in self._dest_file_hints:
                    capacity = self._dest_file_hints['capacity']
                if 'lat' in self._dest_file_hints:
                    lat = self._dest_file_hints['lat']
                if 'lon' in self._dest_file_hints:
                    lon = self._dest_file_hints['lon']

            # extract the column names from the table for whichever fields
            # were not gleaned from self.dest_file_hints
            dest_data_columns = self.dests.columns.values
            print('The variables in your data set are:')
            for var in dest_data_columns:
                print('> ', var)
            while idx not in dest_data_columns:
                idx = input('Enter the unique index variable: ')
            print(
                'If you have no capacity variable, write "skip" (no quotations)'
            )
            while capacity not in dest_data_columns and capacity != 'skip':
                capacity = input('Enter the capacity variable: ')
            print(
                'If you have no category variable, write "skip" (no quotations)'
            )
            while category not in dest_data_columns and category != 'skip':
                category = input('Enter the category variable: ')
            while lat not in dest_data_columns:
                lat = input('Enter the latitude variable: ')
            while lon not in dest_data_columns:
                lon = input('Enter the longitude variable: ')
            self.dest_column_names = {
                'lat': lat,
                'lon': lon,
                'idx': idx,
                'category': category,
                'capacity': capacity
            }

        try:
            # insert filler values for the capacity column if
            # user does not want to include it.
            if self.dest_column_names['capacity'] == 'skip':
                self.dests['capacity'] = 1

            # insert filler values for the category column if
            # user does not want to include it.
            if self.dest_column_names['category'] == 'skip':
                self.dests['category'] = 1

            # rename columns, clean the data frame
            rename_cols = {
                self.dest_column_names['lat']: 'lat',
                self.dest_column_names['lon']: 'lon'
            }
            if self.dest_column_names['capacity'] != 'skip':
                rename_cols[self.dest_column_names['capacity']] = 'capacity'
            if self.dest_column_names['category'] != 'skip':
                rename_cols[self.dest_column_names['category']] = 'category'

            self.dests.set_index(self.dest_column_names['idx'], inplace=True)
            self.dests.rename(columns=rename_cols, inplace=True)

        except KeyError:
            raise DestDataNotParsableException()

        # drop unused columns
        columns_to_keep = list(rename_cols.values())
        self.dests = self.dests[columns_to_keep]
        self.all_categories = set(self.dests['category'])

    def get_dests_in_range_of_source(self, source_id):
        """
        Args:
            source_id: string/int
        Returns: a list of dest ids in range of the source.
        """
        return self.dests_in_range[source_id]

    def get_sources_in_range_of_dest(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: a list of source ids in range of the dest.
        """
        return self.sources_in_range[dest_id]

    def calculate_dests_in_range(self, upper_threshold):
        """
        Args:
            upper_threshold: numeric, upper threshold of what
                points are considered to be in range.
        """
        self.dests_in_range = self.transit_matrix.matrix_interface.get_dests_in_range(
            upper_threshold)

    def calculate_sources_in_range(self, upper_threshold):
        """
        Args:
            upper_threshold: numeric, upper threshold of what
                points are considered to be in range.
        """
        self.sources_in_range = self.transit_matrix.matrix_interface.get_sources_in_range(
            upper_threshold)
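
    # Sketch: upper_threshold is in the units of the computed matrix
    # (travel time by default; distance if Configs.use_meters was set, an
    # inference from the use_meters flag in Example #8). For example:
    #   model_data.calculate_dests_in_range(upper_threshold=1800)
    #   model_data.get_dests_in_range_of_source(some_source_id)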

    def get_values_by_source(self, source_id, sort=False):
        """
        Args:
            source_id: string/int
            sort: boolean, set to true for the return value
                to be sorted in non-decreasing order.
        Returns: list of (dest_id, value) pairs, optionally
            sorted in non-decreasing order by value.
        """
        return self.transit_matrix.matrix_interface.get_values_by_source(
            source_id, sort)

    def get_values_by_dest(self, dest_id, sort=False):
        """
        Args:
            dest_id: string/int
            sort: boolean, set to true for the return value
                to be sorted in non-decreasing order.
        Returns: a list of (source_id, value) pairs, optionally
            sorted in non-decreasing order by value.
        """
        return self.transit_matrix.matrix_interface.get_values_by_dest(
            dest_id, sort)

    def get_population_in_range(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the total population of sources within range
            of the given destination id.
        """
        cumulative_population = 0
        for source_id in self.get_sources_in_range_of_dest(dest_id):
            source_population = self.get_population(source_id)
            if source_population > 0:
                cumulative_population += source_population

        return cumulative_population

    def _map_categories_to_sp_matrix(self):
        """
        Map all categories-> associated dest_ids.
        """
        for dest_id in self.get_all_dest_ids():
            associated_category = self.get_category(dest_id)
            self._add_to_category_map(dest_id, associated_category)

    def _add_to_category_map(self, dest_id, category):
        """
        Args:
            dest_id: string/int
            category: string
        Map the dest_id to the category in the
            transit matrix.
        """
        self.transit_matrix.matrix_interface.add_to_category_map(
            dest_id, category)

    def time_to_nearest_dest(self, source_id, category):
        """
        Args:
            source_id: string/int
            category: string
        Returns: the time to nearest destination for source_id
            of type category. If category is 'all_categories', return
            the time to nearest destination of any type.
        """
        if category == 'all_categories':
            return self.transit_matrix.matrix_interface.time_to_nearest_dest(
                source_id, None)
        else:
            return self.transit_matrix.matrix_interface.time_to_nearest_dest(
                source_id, category)

    def count_dests_in_range_by_categories(self, source_id, category,
                                           upper_threshold):
        """
        Args:
            source_id: int/string
            category: string
            upper_threshold: numeric, upper limit of what is
                considered to be 'in range'.
        Returns: the count of destinations in range of the
            source id for the given category. If category is
            'all_categories', count destinations of any type.
        """
        if category == 'all_categories':
            return self.transit_matrix.matrix_interface.count_dests_in_range(
                source_id, upper_threshold, None)
        else:
            return self.transit_matrix.matrix_interface.count_dests_in_range(
                source_id, upper_threshold, category)

    # TODO: optimize this method
    def count_sum_in_range_by_categories(self, source_id, category):
        """
        Args:
            source_id: int/string
            category: string
        Returns: the sum of the capacities of destinations in range
            of the source id for the given category. If category is
            'all_categories', sum over destinations of any type.
        """
        running_sum = 0
        for dest_id in self.get_dests_in_range_of_source(source_id):
            if self.get_category(
                    dest_id) == category or category == 'all_categories':
                running_sum += self.get_capacity(dest_id)
        return running_sum
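
    # Worked example (hypothetical data): if destinations 5 and 9 are in
    # range of source 1 with capacities 10 and 3, and both match the given
    # category, count_sum_in_range_by_categories(1, category) returns 13.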

    def _print_data_frame(self):
        """
        Print the transit matrix.
        Don't call this for anything other
        than trivially small matrices.
        """
        self.transit_matrix.matrix_interface.print_data_frame()

    def _spatial_join_community_index(
            self,
            dataframe,
            shapefile='data/chicago_boundaries/chicago_boundaries.shp',
            spatial_index='community',
            projection='epsg:4326'):
        """
        Join the dataframe with location data from shapefile.
        Args:
            dataframe: pandas dataframe with unique id.
            shapefile: shapefile containing geometry.
            spatial_index: column names of aggregation area in shapefile.
            projection: defaults to 'epsg:4326'
        Returns: dataframe.
        Raises:
            ShapefileNotFoundException: Shapefile not found.
            SpatialIndexNotMatchedException: spatial_index not found in shapefile.
        """
        geometry = [
            Point(xy) for xy in zip(dataframe['lon'], dataframe['lat'])
        ]
        crs = {'init': projection}
        geo_original = gpd.GeoDataFrame(dataframe, crs=crs, geometry=geometry)
        try:
            boundaries_gdf = gpd.read_file(shapefile)
        except FileNotFoundError:
            raise ShapefileNotFoundException(
                'shapefile not found: {}'.format(shapefile))

        geo_result = gpd.sjoin(boundaries_gdf,
                               geo_original,
                               how='right',
                               op='intersects')

        dataframe_columns = list(dataframe.columns)

        geo_result.rename(columns={spatial_index: 'spatial_index'},
                          inplace=True)
        dataframe_columns.append('spatial_index')
        dataframe_columns.append('geometry')
        try:
            geo_result = geo_result[dataframe_columns]
        except KeyError:
            raise SpatialIndexNotMatchedException(
                'Unable to match spatial_index:{}'.format(spatial_index))
        if len(geo_result) != len(dataframe):
            self.logger.warning(
                'Length of joined dataframe ({}) != length of input dataframe ({})'
                .format(len(geo_result), len(dataframe)))
        return geo_result

    def _rejoin_results_with_coordinates(self, model_results, is_source):
        """
        Args:
            model_results: dataframe
            is_source: boolean (tells where to draw lat/long data from)
        Returns: deep copy of dataframe with lat/long data.
        """
        model_results = model_results.copy(deep=True)
        if is_source:
            model_results['lat'] = self.sources['lat']
            model_results['lon'] = self.sources['lon']
        else:
            model_results['lat'] = self.dests['lat']
            model_results['lon'] = self.dests['lon']
        return model_results

    def _build_aggregate(self, data_frame, aggregation_args, shapefile,
                         spatial_index, projection):
        """
        Private method invoked to aggregate dataframe on spatial area.
        Args:
            data_frame: dataframe
            aggregation_args: dictionary mapping each column name to the method by which that
                column should be aggregated, e.g. mean, sum, etc...
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            projection: defaults to 'epsg:4326'

        Returns: aggregated data frame.

        """
        if ('lat' not in data_frame.columns
                or 'lon' not in data_frame.columns
                or 'spatial_index' not in data_frame.columns):
            data_frame = self._spatial_join_community_index(
                dataframe=data_frame,
                shapefile=shapefile,
                spatial_index=spatial_index,
                projection=projection)
        aggregated_data_frame = data_frame.groupby('spatial_index').agg(
            aggregation_args)
        return aggregated_data_frame

    def set_focus_categories(self, categories):
        """
        Set the categories that the model should perform computations for.
        Args:
            categories: list of categories.
        Raises:
            UnrecognizedCategoriesException: User passes categories not
                found in the dest data.
        """
        if categories is None:
            self.focus_categories = self.all_categories
        else:
            unrecognized_categories = set(categories) - self.all_categories
            if len(unrecognized_categories) > 0:
                raise UnrecognizedCategoriesException(
                    ','.join(unrecognized_categories))
            self.focus_categories = categories

    def aggregate(self,
                  aggregation_type=None,
                  shapefile='data/chicago_boundaries/chicago_boundaries.shp',
                  spatial_index='community',
                  projection='epsg:4326'):
        """

        Args:
            aggregation_type: string, required for models with multiple possiblities
                for aggregating.
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            projection: defaults to 'epsg:4326'

        Returns: aggregated data frame.

        Raises:
            ModelNotCalculatedException: If the model has not yet been
                calculated.
            UnexpectedAggregationTypeException: If the user passes an
                unexpected aggregation type, or no aggregation type
                when one is expected, or an aggregation type when none
                is expected.
        """
        if self.model_results is None:
            raise ModelNotCalculatedException()
        if not self._is_aggregatable:
            raise ModelNotAggregatableException()
        if self._requires_user_aggregation_type:
            if aggregation_type is None:
                raise UnexpectedAggregationTypeException(aggregation_type)
            else:
                if aggregation_type not in {'min', 'max', 'mean'}:
                    raise UnexpectedAggregationTypeException(aggregation_type)
                else:
                    self._aggregation_args = {}
                    for column in self.model_results.columns:
                        self._aggregation_args[column] = aggregation_type

        else:
            if aggregation_type is not None:
                raise UnexpectedAggregationTypeException(aggregation_type)

        results_with_coordinates = self._rejoin_results_with_coordinates(
            self.model_results, self._is_source)

        self.aggregated_results = self._build_aggregate(
            data_frame=results_with_coordinates,
            aggregation_args=self._aggregation_args,
            shapefile=shapefile,
            spatial_index=spatial_index,
            projection=projection)
        return self.aggregated_results
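
    # Sketch: a subclass that sets _requires_user_aggregation_type is called
    # as model_data.aggregate(aggregation_type='mean'), averaging every result
    # column per spatial_index area; a model that predefines _aggregation_args
    # is called as model_data.aggregate() with no aggregation_type.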

    def write_aggregated_results(self, filename=None, output_type='csv'):
        """
        Args:
            filename: file to write results. If None is given, a valid
                filename will be automatically generated.
            output_type: 'csv' or 'json'.

        Raises:
            ModelNotAggregatedException: Model has not been aggregated.
            AggregateOutputTypeNotExpectedException: output_type is not
                'csv' or 'json'.
        """
        if filename is not None:
            output_type = filename.split('.')[-1]
        else:
            filename = self._get_output_filename(keyword='aggregate',
                                                 extension=output_type,
                                                 file_path='data/')

        if self.aggregated_results is None:
            raise ModelNotAggregatedException()

        if output_type == 'csv':
            self.aggregated_results.to_csv(filename)
        elif output_type == 'json':
            output = {}
            for row in self.aggregated_results.itertuples():
                output[row[0]] = {}
                for i, column in enumerate(self.aggregated_results.columns):
                    output[row[0]][column] = row[i + 1]
            with open(filename, 'w') as file:
                json.dump(output, file)
        else:
            raise AggregateOutputTypeNotExpectedException(output_type)
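
    # e.g. write_aggregated_results('data/agg.json') (hypothetical filename)
    # infers output_type 'json' from the extension and writes one JSON object
    # keyed by spatial area.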

    def write_results(self, filename=None):
        """
        Write results to csv.
        Args:
            filename: file to write results. If None is given, a valid
                filename will be automatically generated.
        Raises:
            ModelNotCalculatedException: if model has not been calculated.
        """
        if self.model_results is None:
            raise ModelNotCalculatedException()
        if filename is None:
            filename = self._get_output_filename(keyword='model',
                                                 extension='csv',
                                                 file_path='data/')
        self.model_results.to_csv(filename)

    @staticmethod
    def _join_aggregated_data_with_boundaries(
            aggregated_results,
            spatial_index,
            shapefile='data/chicago_boundaries/chicago_boundaries.shp'):
        """
        Args:
            aggregated_results: dataframe
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile

        Returns: dataframe.

        Raises:
            ShapefileNotFoundException: shapefile not found.
        """
        try:
            boundaries_gdf = gpd.read_file(shapefile)
        except FileNotFoundError:
            raise ShapefileNotFoundException(
                'shapefile not found: {}'.format(shapefile))
        columns_to_keep = list(aggregated_results.columns)
        columns_to_keep.append('geometry')
        columns_to_keep.append(spatial_index)

        results = boundaries_gdf.merge(aggregated_results,
                                       left_on=spatial_index,
                                       right_index=True,
                                       how='outer')
        results.fillna(value=0, inplace=True)
        return results[columns_to_keep]

    def plot_cdf(self,
                 plot_type=None,
                 xlabel="xlabel",
                 ylabel="ylabel",
                 title="title",
                 bins=100,
                 is_density=False,
                 filename=None):
        """
        Args:
            plot_type: If the model has multiple possibilities to plot, specify which
                one.
            xlabel: xlabel for figure.
            ylabel: ylabel for figure.
            title: title for figure.
            bins: integer, number of bins.
            is_density: boolean, true for density plot.
            filename: filename to write to.

        Raises:
            ModelNotAggregatedException: Model is not aggregated.
            UnexpectedPlotColumnException: User passes unexpected plot type.
            TooManyCategoriesToPlotException: Too many categories to plot.
        """
        if self.aggregated_results is None:
            raise ModelNotAggregatedException()

        if self._is_source:
            cdf_eligible = self.model_results[self.sources['population'] > 0]
        else:
            cdf_eligible = self.model_results

        if isinstance(self._result_column_names, str):
            if plot_type is not None:
                raise UnexpectedPlotColumnException(plot_type)
            plot_type = self._result_column_names
        else:
            if plot_type is None:
                raise UnexpectedPlotColumnException(plot_type)

        # initialize block parameters
        mpl.pyplot.close()
        mpl.pyplot.rcParams['axes.facecolor'] = '#cfcfd1'
        fig, ax = mpl.pyplot.subplots(figsize=(8, 4))
        ax.grid(zorder=0)

        available_colors = [
            'black', 'magenta', 'lime', 'red', 'orange', 'grey',
            'yellow', 'brown', 'teal'
        ]
        color_keys = []
        for column in cdf_eligible.columns:
            if plot_type not in column:
                continue
            x = cdf_eligible[column]
            try:
                color = available_colors.pop(0)
            except IndexError:
                raise TooManyCategoriesToPlotException()
            patch = mpatches.Patch(color=color, label=column)
            color_keys.append(patch)
            n, bins, _patches = ax.hist(x,
                                        bins,
                                        density=is_density,
                                        histtype='step',
                                        cumulative=True,
                                        label=column,
                                        color=color,
                                        zorder=3)
        ax.legend(loc='right', handles=color_keys)

        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if filename is None:
            filename = self._get_output_filename(keyword='figure',
                                                 extension='png',
                                                 file_path='figures/')
        mpl.pyplot.savefig(filename, dpi=400)
        self.logger.info('Plot was saved to: {}'.format(filename))

    def plot_choropleth(
            self,
            column,
            include_destinations=True,
            title='Title',
            color_map='Greens',
            shapefile='data/chicago_boundaries/chicago_boundaries.shp',
            spatial_index='community',
            filename=None):
        """
        Args:
            column: Which column to plot.
            include_destinations: boolean, will plot circles for destinations if true.
            title: Figure title.
            color_map: See https://matplotlib.org/tutorials/colors/colormaps.html
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            filename: file to write figure to.

        Raises:
            ModelNotAggregatedException: Model is not aggregated.
            UnexpectedPlotColumnException: User passes unexpected column.
        """
        if self.aggregated_results is None:
            raise ModelNotAggregatedException()
        if include_destinations:
            categories = self.focus_categories
        else:
            categories = None
        results_with_geometry = self._join_aggregated_data_with_boundaries(
            aggregated_results=self.aggregated_results,
            spatial_index=spatial_index,
            shapefile=shapefile)
        if column not in results_with_geometry.columns:
            raise UnexpectedPlotColumnException(
                'Did not expect column argument: {}'.format(column))

        mpl.pyplot.close()

        mpl.pyplot.rcParams['axes.facecolor'] = '#cfcfd1'

        results_with_geometry.plot(column=column,
                                   cmap=color_map,
                                   edgecolor='black',
                                   linewidth=0.1)

        # add a scatter plot of the vendors over the choropleth
        if categories is not None:
            available_colors = [
                'magenta', 'lime', 'red', 'black', 'orange', 'grey', 'yellow',
                'brown', 'teal'
            ]
            # if we have too many categories of vendors, limit to using black dots
            monochrome = len(categories) > len(available_colors)
            color_keys = []
            max_dest_capacity = max(self.dests['capacity'])
            for category in categories:
                if monochrome:
                    color = 'black'
                else:
                    color = available_colors.pop(0)
                    patch = mpatches.Patch(color=color, label=category)
                    color_keys.append(patch)
                dest_subset = self.dests.loc[
                    self.dests['category'] == category]
                mpl.pyplot.scatter(
                    y=dest_subset['lat'],
                    x=dest_subset['lon'],
                    color=color,
                    marker='o',
                    s=50 * (dest_subset['capacity'] / max_dest_capacity),
                    label=category)
                if not monochrome:
                    mpl.pyplot.legend(loc='best', handles=color_keys)

        mpl.pyplot.title(title)
        if filename is None:
            filename = self._get_output_filename(keyword='figure',
                                                 extension='png',
                                                 file_path='figures/')
        mpl.pyplot.savefig(filename, dpi=400)

        self.logger.info('Figure was saved to: {}'.format(filename))
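
An end-to-end sketch tying the ModelData pieces together. Filenames, column names, and the threshold value are hypothetical, and ModelData is normally driven by the model classes in spatial_access.Models, which populate model_results before aggregate() or the plot_* methods can run.

source_columns = {'idx': 'name', 'lat': 'y', 'lon': 'x', 'population': 'pop'}
dest_columns = {'idx': 'name', 'lat': 'y', 'lon': 'x',
                'capacity': 'cap', 'category': 'type'}
model_data = ModelData('walk', 'sources.csv', 'dests.csv',
                       source_column_names=source_columns,
                       dest_column_names=dest_columns)
model_data.load_transit_matrix()
model_data.calculate_dests_in_range(upper_threshold=1800)
first_source = model_data.get_all_source_ids()[0]
nearby_dest_ids = model_data.get_dests_in_range_of_source(first_source)
model_data.write_transit_matrix_to_csv('transit_matrix.csv')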