コード例 #1
0
 def testGetSampledUciData(self, mock_read_csv):
   """Checks that a `sample` fraction is forwarded to the frame's `sample`.

   The CSV reader mock returns a fake frame; after the call under test, that
   frame's `sample` method must have been invoked exactly once with the
   requested fraction, with replacement, and with a fixed random state.

   NOTE(review): the test name refers to sampled UCI data, while the call
   goes through `get_uci_min_diff_datasets` -- confirm this is the intended
   entry point.

   Args:
     mock_read_csv: Mock standing in for the CSV reader (presumably injected
       by a patch decorator that is not visible here).
   """
   sample_fraction = 0.45  # Any value works; it only has to round-trip.

   fake_frame = mock.MagicMock()
   fake_frame.sample = mock.MagicMock()
   mock_read_csv.return_value = fake_frame

   utils.get_uci_min_diff_datasets(sample=sample_fraction)

   mock_read_csv.assert_called_once()
   expected_sample_kwargs = {
       'frac': sample_fraction,
       'replace': True,
       'random_state': 1,
   }
   fake_frame.sample.assert_called_once_with(**expected_sample_kwargs)
コード例 #2
0
    def testGetMinDiffDatasetsDefaults(self, mock_get_uci_data):
        """Checks the data-loader calls made when no arguments are given.

        With defaults, the loader mock is expected to have been called with
        `split='train', sample=None` and with `split='train'` alone; the
        relative order of those calls is not checked.

        Args:
          mock_get_uci_data: Mock standing in for the UCI data loader
            (presumably injected by a patch decorator not visible here).
        """
        # The return value is irrelevant here; only the recorded calls matter.
        utils.get_uci_min_diff_datasets()

        expected_loader_calls = [
            mock.call(split='train', sample=None),
            mock.call(split='train'),
        ]
        mock_get_uci_data.assert_has_calls(
            expected_loader_calls, any_order=True)
コード例 #3
0
    def testUciDataCached(self, mock_read_csv):
        """Checks that each split is read from CSV at most once.

        The test walks through four stages, resetting the reader mock
        between them so every assertion only sees the calls of its own stage.

        Args:
          mock_read_csv: Mock standing in for `pandas.read_csv` (presumably
            injected by a patch decorator not visible here).
        """
        build_datasets = utils.get_uci_min_diff_datasets

        # Stage 1: the first request for the default split hits the reader.
        build_datasets()
        mock_read_csv.assert_called_once()
        mock_read_csv.reset_mock()

        # Stage 2: pandas.read_csv should not be called again for the same
        # split ('train').
        build_datasets()
        mock_read_csv.assert_not_called()
        mock_read_csv.reset_mock()

        # Stage 3: pandas.read_csv should be called again for a different
        # split.
        build_datasets(split='test')
        mock_read_csv.assert_called_once()
        mock_read_csv.reset_mock()

        # Stage 4: pandas.read_csv should not be called again (both splits
        # have been cached).
        for cached_split in ('train', 'test'):
            build_datasets(split=cached_split)
        mock_read_csv.assert_not_called()
コード例 #4
0
    def testGetMinDiffDatasets(self, mock_get_uci_data, mock_df_to_dataset):
        """Checks outputs and collaborator calls for explicit arguments.

        The frame-to-dataset mock yields a plain value for the original
        dataset and two mocks for the min-diff datasets. The result must be
        the original value followed by what each min-diff mock's `batch`
        returns, and every collaborator must have been called with the
        arguments passed in.

        Args:
          mock_get_uci_data: Mock standing in for the UCI data loader.
          mock_df_to_dataset: Mock standing in for the frame-to-dataset
            converter. Both are presumably injected by patch decorators that
            are not visible here.
        """
        # Arbitrary inputs; the two batch sizes differ so that a mix-up
        # between them would be caught by the assertions below.
        split = 'fake_split'
        sample = 0.56
        original_batch_size = 19
        min_diff_batch_size = 23

        first_min_diff_ds = mock.MagicMock()
        first_min_diff_ds.batch.return_value = 'batched_d1'
        second_min_diff_ds = mock.MagicMock()
        second_min_diff_ds.batch.return_value = 'batched_d2'
        mock_df_to_dataset.side_effect = [
            'og_ds',
            first_min_diff_ds,
            second_min_diff_ds,
        ]

        result = utils.get_uci_min_diff_datasets(
            split=split,
            sample=sample,
            original_batch_size=original_batch_size,
            min_diff_batch_size=min_diff_batch_size)

        # Outputs must come from the right place.
        self.assertEqual(result, ('og_ds', 'batched_d1', 'batched_d2'))

        # The loader is hit twice: once with the sample, once without.
        self.assertEqual(mock_get_uci_data.call_count, 2)
        expected_loader_calls = [
            mock.call(split=split, sample=sample),
            mock.call(split=split),
        ]
        mock_get_uci_data.assert_has_calls(
            expected_loader_calls, any_order=True)

        # Only the first conversion receives a batch size directly.
        self.assertEqual(mock_df_to_dataset.call_count, 3)
        expected_conversion_calls = [
            mock.call(mock.ANY, shuffle=True, batch_size=original_batch_size),
            mock.call(mock.ANY, shuffle=True),
            mock.call(mock.ANY, shuffle=True),
        ]
        mock_df_to_dataset.assert_has_calls(expected_conversion_calls)

        # The min-diff datasets are batched afterwards, via `batch`.
        for min_diff_ds in (first_min_diff_ds, second_min_diff_ds):
            min_diff_ds.batch.assert_called_once_with(
                min_diff_batch_size, drop_remainder=True)
コード例 #5
0
 def testGetUciDataWithBadSplitRaisesError(self):
     """Checks that an unknown split name is rejected with a `ValueError`.

     The error message must match a pattern mentioning the accepted splits
     ('train', 'test') and the offending value that was given.
     """
     expected_message = 'split must be.*train.*test.*given.*bad_split'
     with self.assertRaisesRegex(ValueError, expected_message):
         utils.get_uci_min_diff_datasets('bad_split')
コード例 #6
0
 def testGetTestUciData(self, mock_read_csv):
     """Checks the reader arguments used when the 'test' split is requested.

     Args:
       mock_read_csv: Mock standing in for the CSV reader (presumably
         injected by a patch decorator that is not visible here).
     """
     utils.get_uci_min_diff_datasets(split='test')

     # The 'test' split fills the URL template with the literal 'test'.
     expected_url = utils._UCI_DATA_URL_TEMPLATE.format('test')
     mock_read_csv.assert_called_once_with(
         expected_url, names=utils._UCI_COLUMN_NAMES, header=None)
コード例 #7
0
 def testGetTrainUciDataAsDefault(self, mock_read_csv):
     """Checks the reader arguments used when no split is specified.

     Args:
       mock_read_csv: Mock standing in for the CSV reader (presumably
         injected by a patch decorator that is not visible here).
     """
     utils.get_uci_min_diff_datasets()

     # With no split given, the URL template is expected to be filled with
     # 'data' rather than with a split name.
     expected_url = utils._UCI_DATA_URL_TEMPLATE.format('data')
     mock_read_csv.assert_called_once_with(
         expected_url, names=utils._UCI_COLUMN_NAMES, header=None)