コード例 #1
0
    def test_invalid_input_experiment(self):
        r"""Invalid `experiment` values must be rejected with a precise error.

        Every candidate is passed to `BaseConfig.load`. The call must raise
        `FileNotFoundError`, `TypeError` or `ValueError`, and the message of
        the raised exception is checked as well:

        - `FileNotFoundError`: must name the missing `config.json` under
          `lmp.path.DATA_PATH/<experiment>`;
        - `TypeError`: must state that `experiment` must be a `str`;
        - anything else (i.e. `ValueError`): must state that `experiment`
          must not be empty.
        """
        expected_types = (FileNotFoundError, TypeError, ValueError)
        raise_msg = (
            'Must raise `FileNotFoundError`, `TypeError` or `ValueError` when '
            'input `experiment` is invalid.')
        err_msg = 'Inconsistent error message.'
        bad_values = (
            False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan,
            math.inf, -math.inf,
            0j, 1j, '', 'I-DO-NOT-EXIST', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for bad_value in bad_values:
            with self.assertRaises(expected_types, msg=raise_msg) as cm:
                BaseConfig.load(experiment=bad_value)

            err = cm.exception
            if isinstance(err, FileNotFoundError):
                missing = os.path.join(lmp.path.DATA_PATH, bad_value,
                                       'config.json')
                want = f'File {missing} does not exist.'
            elif isinstance(err, TypeError):
                want = '`experiment` must be an instance of `str`.'
            else:
                want = '`experiment` must not be empty.'

            self.assertEqual(err.args[0], want, msg=err_msg)
コード例 #2
0
    def test_invalid_input_tokenizer_class(self):
        r"""Invalid `tokenizer_class` values must be rejected.

        `BaseConfig` is constructed with valid `dataset` and `experiment`
        values plus each candidate `tokenizer_class`. The call must raise
        `TypeError` or `ValueError`. A `TypeError` must state that
        `tokenizer_class` must be a `str`; any other raised exception must
        state that it must not be empty.
        """
        raise_msg = ('Must raise `TypeError` or `ValueError` when input '
                     '`tokenizer_class` is invalid.')
        err_msg = 'Inconsistent error message.'
        type_err_text = '`tokenizer_class` must be an instance of `str`.'
        value_err_text = '`tokenizer_class` must not be empty.'
        candidates = (
            False, True, 0, 1, -1, 0.0, 1.0, math.nan, -math.nan,
            math.inf, -math.inf, 0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for candidate in candidates:
            with self.assertRaises((TypeError, ValueError),
                                   msg=raise_msg) as cm:
                BaseConfig(dataset='test', experiment='test',
                           tokenizer_class=candidate)

            expected_text = (type_err_text
                             if isinstance(cm.exception, TypeError)
                             else value_err_text)
            self.assertEqual(cm.exception.args[0], expected_text,
                             msg=err_msg)
コード例 #3
0
    def test_invalid_input_min_count(self):
        r"""Invalid `min_count` values must be rejected.

        `BaseConfig` is constructed with valid `dataset` and `experiment`
        values plus each candidate `min_count`. The call must raise
        `TypeError` or `ValueError`. A `TypeError` must state that
        `min_count` must be an `int`; any other raised exception must state
        that it must be at least `1`.
        """
        raise_msg = (
            'Must raise `TypeError` or `ValueError` when input `min_count` is '
            'invalid.')
        err_msg = 'Inconsistent error message.'
        type_err_text = '`min_count` must be an instance of `int`.'
        value_err_text = '`min_count` must be bigger than or equal to `1`.'
        candidates = (
            False, 0, -1, 0.0, 1.0, math.nan, -math.nan,
            math.inf, -math.inf, 0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for candidate in candidates:
            with self.assertRaises((TypeError, ValueError),
                                   msg=raise_msg) as cm:
                BaseConfig(dataset='test', experiment='test',
                           min_count=candidate)

            expected_text = (type_err_text
                             if isinstance(cm.exception, TypeError)
                             else value_err_text)
            self.assertEqual(cm.exception.args[0], expected_text,
                             msg=err_msg)
コード例 #4
0
    def test_expected_return(self):
        r"""`BaseConfig.device` must be the CPU device or the CUDA device.

        The check is membership in `(torch.device('cpu'),
        torch.device('cuda'))`, so whichever device the machine provides is
        accepted.
        """
        allowed_devices = (torch.device('cpu'), torch.device('cuda'))
        device = BaseConfig(dataset='test', experiment='test').device
        self.assertIn(device, allowed_devices,
                      msg='Inconsistent `torch.device`.')
コード例 #5
0
    def test_invalid_json(self):
        r"""`BaseConfig.load` must raise `JSONDecodeError` on non-JSON content.

        A `config.json` holding plain text is written into the class-level
        `test_dir`. It is then loaded through the class-level `experiment`
        and removed afterwards whatever the outcome.
        NOTE(review): assumes `test_dir` is the directory `load` reads for
        that `experiment` — confirm against the fixture setup.
        """
        cls = type(self)
        raise_msg = (
            'Must raise `JSONDecodeError` when configuration is not in JSON '
            'format.')
        cfg_path = os.path.join(cls.test_dir, 'config.json')

        try:
            # Write a file that cannot be parsed as JSON.
            with open(cfg_path, 'w', encoding='utf-8') as cfg_file:
                cfg_file.write('Invalid JSON format.')

            with self.assertRaises(json.JSONDecodeError, msg=raise_msg):
                BaseConfig.load(experiment=cls.experiment)
        finally:
            # Always remove the fixture file.
            os.remove(cfg_path)
コード例 #6
0
    def test_invalid_input_is_uncased(self):
        r"""Non-`bool` values of `is_uncased` must raise `TypeError`.

        The candidates include `0` and `1`. They compare equal to `False` and
        `True` but are not `bool` instances, so they must be rejected as
        well. The message must state that `is_uncased` must be a `bool`.
        """
        raise_msg = ('Must raise `TypeError` when input `is_uncased` is '
                     'invalid.')
        err_msg = 'Inconsistent error message.'
        expected_text = '`is_uncased` must be an instance of `bool`.'
        non_bools = (
            0, 1, -1, 0.0, 1.0, math.nan, -math.nan,
            math.inf, -math.inf, 0j, 1j, '', b'', (), [], {}, set(),
            object(), lambda x: x, type, None, NotImplemented, ...,
        )

        for value in non_bools:
            with self.assertRaises(TypeError, msg=raise_msg) as cm:
                BaseConfig(dataset='test', experiment='test',
                           is_uncased=value)

            self.assertEqual(cm.exception.args[0], expected_text,
                             msg=err_msg)
コード例 #7
0
    def test_invalid_input_dropout(self):
        r"""Raise exception when input `dropout` is invalid.

        `BaseConfig` is constructed with valid `dataset` and `experiment`
        values plus each candidate `dropout`. The call must raise `TypeError`
        or `ValueError`. A `TypeError` must state that `dropout` must be a
        `float`; any other raised exception must state that `dropout` must
        range from `0.0` to `1.0`.
        """
        msg1 = (
            'Must raise `TypeError` or `ValueError` when input `dropout` is '
            'invalid.')
        msg2 = 'Inconsistent error message.'
        examples = (-1, -1.0, 1.1, math.nan, -math.nan,
                    math.inf, -math.inf, 0j, 1j, '', b'', (), [], {}, set(),
                    object(), lambda x: x, type, None, NotImplemented, ...)

        for invalid_input in examples:
            with self.assertRaises((TypeError, ValueError),
                                   msg=msg1) as ctx_man:
                # Pass `experiment` explicitly, as every sibling test does.
                # This keeps `dropout` as the only invalid argument.
                # Otherwise an error raised for `experiment` could be caught
                # here and misattributed to `dropout`.
                BaseConfig(dataset='test',
                           experiment='test',
                           dropout=invalid_input)

            if isinstance(ctx_man.exception, TypeError):
                self.assertEqual(ctx_man.exception.args[0],
                                 '`dropout` must be an instance of `float`.',
                                 msg=msg2)
            else:
                self.assertEqual(ctx_man.exception.args[0],
                                 '`dropout` must range from `0.0` to `1.0`.',
                                 msg=msg2)
コード例 #8
0
    def test_load_result(self):
        r"""Load result must be consistent.

        Each example dictionary is dumped as JSON to `config.json` in the
        class-level `test_dir` and loaded with
        `BaseConfig.load(experiment=...)`. The result must be a `BaseConfig`
        whose attributes match every dumped key in both type and value. The
        file is removed after each example whatever the outcome.
        """
        cls = type(self)
        msg = 'Inconsistent load result.'
        examples = (
            dict(
                batch_size=111,
                checkpoint_step=222,
                d_emb=333,
                d_hid=444,
                dataset='hello',
                dropout=0.42069,
                epoch=555,
                experiment='world',
                is_uncased=True,
                learning_rate=0.69420,
                max_norm=6.9,
                max_seq_len=666,
                min_count=777,
                model_class='HELLO',
                num_linear_layers=888,
                num_rnn_layers=999,
                optimizer_class='WORLD',
                seed=101010,
                tokenizer_class='hello world',
            ),
            dict(
                batch_size=101010,
                checkpoint_step=999,
                d_emb=888,
                d_hid=777,
                dataset='world',
                dropout=0.69420,
                epoch=666,
                experiment='hello',
                is_uncased=True,
                learning_rate=0.42069,
                max_norm=4.20,
                max_seq_len=555,
                min_count=444,
                model_class='hello world',
                num_linear_layers=333,
                num_rnn_layers=222,
                optimizer_class='WORLD',
                seed=111,
                tokenizer_class='HELLO',
            ),
        )

        # The fixture path does not depend on the example, so compute it once.
        cfg_path = os.path.join(cls.test_dir, 'config.json')

        for expected in examples:
            try:
                with open(cfg_path, 'w', encoding='utf-8') as cfg_file:
                    json.dump(expected, cfg_file)

                config = BaseConfig.load(experiment=cls.experiment)
                self.assertIsInstance(config, BaseConfig)

                for key, want in expected.items():
                    self.assertTrue(hasattr(config, key), msg=msg)
                    got = getattr(config, key)
                    self.assertIsInstance(got, type(want), msg=msg)
                    self.assertEqual(got, want, msg=msg)
            finally:
                os.remove(cfg_path)
コード例 #9
0
    def test_save_result(self):
        r"""Save result must be consistent.

        For each example, a `BaseConfig` is constructed from the expected
        attributes and saved. The test asserts that `config.json` exists in
        the class-level `test_dir`. Every attribute stored in that file must
        be an expected attribute whose value matches in type and value. The
        file is removed afterwards, if it was created.
        """
        msg1 = 'Must save as `config.json`.'
        msg2 = 'Inconsistent save result.'
        examples = (
            {
                'batch_size': 111,
                'checkpoint_step': 222,
                'd_emb': 333,
                'd_hid': 444,
                'dataset': 'hello',
                'dropout': 0.42069,
                'epoch': 555,
                'experiment': self.__class__.experiment,
                'is_uncased': True,
                'learning_rate': 0.69420,
                'max_norm': 6.9,
                'max_seq_len': 666,
                'min_count': 777,
                'model_class': 'HELLO',
                'num_linear_layers': 888,
                'num_rnn_layers': 999,
                'optimizer_class': 'WORLD',
                'seed': 101010,
                'tokenizer_class': 'hello world',
            },
            {
                'batch_size': 101010,
                'checkpoint_step': 999,
                'd_emb': 888,
                'd_hid': 777,
                'dataset': 'world',
                'dropout': 0.69420,
                'epoch': 666,
                'experiment': self.__class__.experiment,
                'is_uncased': True,
                'learning_rate': 0.42069,
                'max_norm': 4.20,
                'max_seq_len': 555,
                'min_count': 444,
                'model_class': 'hello world',
                'num_linear_layers': 333,
                'num_rnn_layers': 222,
                'optimizer_class': 'WORLD',
                'seed': 111,
                'tokenizer_class': 'HELLO',
            },
        )

        test_path = os.path.join(self.__class__.test_dir, 'config.json')

        for ans_attributes in examples:
            try:
                BaseConfig(**ans_attributes).save()
                self.assertTrue(os.path.exists(test_path), msg=msg1)

                # Read with the same explicit encoding the sibling tests use.
                with open(test_path, 'r', encoding='utf-8') as input_file:
                    attributes = json.load(input_file)

                for attr_key, attr_value in attributes.items():
                    self.assertIn(attr_key, ans_attributes, msg=msg2)
                    # The *saved* value must be an instance of the *expected*
                    # type. The reverse check would accept e.g. a `bool`
                    # that was saved as an `int`, because `bool` subclasses
                    # `int`.
                    self.assertIsInstance(attr_value,
                                          type(ans_attributes[attr_key]),
                                          msg=msg2)
                    self.assertEqual(attr_value,
                                     ans_attributes[attr_key],
                                     msg=msg2)
            finally:
                # Only clean up a file that was actually written. If `save()`
                # failed, an unconditional `os.remove` would raise
                # `FileNotFoundError` and hide the original failure.
                if os.path.exists(test_path):
                    os.remove(test_path)
コード例 #10
0
    def test_instance_attributes(self):
        r"""Declare required instance attributes.

        Each example is a sequence of `(name, value)` pairs. A `BaseConfig` is
        built twice from it: once from the values as positional arguments
        and once from the pairs as keyword arguments. Both instances must
        expose every named attribute with the given type and value.
        NOTE(review): the positional construction assumes the pairs are
        listed in the constructor's positional parameter order — confirm
        against `BaseConfig.__init__`.
        """
        missing_msg = 'Missing instance attribute `{}`.'
        type_msg = 'Instance attribute `{}` must be an instance of `{}`.'
        value_msg = 'Instance attribute `{}` must be `{}`.'

        examples = (
            (
                ('batch_size', 111),
                ('checkpoint_step', 222),
                ('d_emb', 333),
                ('d_hid', 444),
                ('dataset', 'hello'),
                ('dropout', 0.42069),
                ('epoch', 555),
                ('experiment', 'world'),
                ('is_uncased', True),
                ('learning_rate', 0.69420),
                ('max_norm', 6.9),
                ('max_seq_len', 666),
                ('min_count', 777),
                ('model_class', 'HELLO'),
                ('num_linear_layers', 888),
                ('num_rnn_layers', 999),
                ('optimizer_class', 'WORLD'),
                ('seed', 101010),
                ('tokenizer_class', 'hello world'),
            ),
            (
                ('batch_size', 101010),
                ('checkpoint_step', 999),
                ('d_emb', 888),
                ('d_hid', 777),
                ('dataset', 'world'),
                ('dropout', 0.69420),
                ('epoch', 666),
                ('experiment', 'hello'),
                ('is_uncased', True),
                ('learning_rate', 0.42069),
                ('max_norm', 4.20),
                ('max_seq_len', 555),
                ('min_count', 444),
                ('model_class', 'hello world'),
                ('num_linear_layers', 333),
                ('num_rnn_layers', 222),
                ('optimizer_class', 'WORLD'),
                ('seed', 111),
                ('tokenizer_class', 'HELLO'),
            ),
        )

        for parameters in examples:
            positional = [value for _, value in parameters]
            keywords = dict(parameters)

            # Both instances are built before any of them is inspected.
            both_ways = (BaseConfig(*positional), BaseConfig(**keywords))

            for config in both_ways:
                for name, expected in parameters:
                    self.assertTrue(hasattr(config, name),
                                    msg=missing_msg.format(name))
                    actual = getattr(config, name)
                    self.assertIsInstance(
                        actual,
                        type(expected),
                        msg=type_msg.format(name, type(expected).__name__))
                    self.assertEqual(actual, expected,
                                     msg=value_msg.format(name, expected))
コード例 #11
0
    def test_yield_value(self):
        r"""Is an iterable which yield attributes in order.

        A `BaseConfig` built from each example must be an `Iterable`. Every
        `(key, value)` pair it yields must name an expected attribute. The
        instance's attribute must match both the expected value and the
        yielded value, in type and in value.
        NOTE(review): despite the docstring title, the ordering of the
        yielded pairs is not asserted here.
        """
        msg = 'Must be an iterable which yield attributes in order.'
        examples = (
            dict(
                batch_size=111,
                checkpoint_step=222,
                d_emb=333,
                d_hid=444,
                dataset='hello',
                dropout=0.42069,
                epoch=555,
                experiment='world',
                is_uncased=True,
                learning_rate=0.69420,
                max_norm=6.9,
                max_seq_len=666,
                min_count=777,
                model_class='HELLO',
                num_linear_layers=888,
                num_rnn_layers=999,
                optimizer_class='WORLD',
                seed=101010,
                tokenizer_class='hello world',
            ),
            dict(
                batch_size=101010,
                checkpoint_step=999,
                d_emb=888,
                d_hid=777,
                dataset='world',
                dropout=0.69420,
                epoch=666,
                experiment='hello',
                is_uncased=True,
                learning_rate=0.42069,
                max_norm=4.20,
                max_seq_len=555,
                min_count=444,
                model_class='hello world',
                num_linear_layers=333,
                num_rnn_layers=222,
                optimizer_class='WORLD',
                seed=111,
                tokenizer_class='HELLO',
            ),
        )

        for expected in examples:
            config = BaseConfig(**expected)
            self.assertIsInstance(config, Iterable, msg=msg)

            for key, yielded in config:
                self.assertIn(key, expected, msg=msg)
                self.assertTrue(hasattr(config, key), msg=msg)

                actual = getattr(config, key)
                want = expected[key]
                # Types: actual vs. expected, then actual vs. yielded.
                self.assertIsInstance(actual, type(want), msg=msg)
                self.assertIsInstance(actual, type(yielded), msg=msg)
                # Values: same two comparisons.
                self.assertEqual(actual, want, msg=msg)
                self.assertEqual(actual, yielded, msg=msg)