Exemplo n.º 1
0
    def test_optim_config_parse_arg_by_target(self):
        """Parse optimizer args from a config that names its params class via ``_target_``.

        The same configuration is parsed twice: first as a plain ``dict``, then
        wrapped in an OmegaConf ``DictConfig``. This mirrors the other
        ``parse_optimizer_args`` tests. Each time, the parsed ``weight_decay``
        and ``betas`` must match the values under ``params``. Finally, it checks
        that the optimizer name argument (``'sgd'``) does not decide the
        resulting keys once a class path is supplied.
        """
        basic_optim_config = {
            '_target_': 'nemo.core.config.NovogradParams',
            'params': {
                'weight_decay': 0.001,
                'betas': [0.8, 0.5]
            },
        }

        # Plain ``dict`` input. The raw mapping is kept as-is so this path is
        # exercised separately from the DictConfig path below.
        parsed_params = optim.parse_optimizer_args('novograd', basic_optim_config)
        assert parsed_params['weight_decay'] == basic_optim_config['params']['weight_decay']
        assert parsed_params['betas'][0] == basic_optim_config['params']['betas'][0]
        assert parsed_params['betas'][1] == basic_optim_config['params']['betas'][1]

        # OmegaConf ``DictConfig`` input built from the same mapping.
        dict_config = omegaconf.OmegaConf.create(basic_optim_config)
        parsed_params = optim.parse_optimizer_args('novograd', dict_config)
        assert parsed_params['weight_decay'] == dict_config['params']['weight_decay']
        assert parsed_params['betas'][0] == dict_config['params']['betas'][0]
        assert parsed_params['betas'][1] == dict_config['params']['betas'][1]

        # Names are ignored when passing class path
        # This will be captured during optimizer instantiation
        output_config = optim.parse_optimizer_args('sgd', dict_config)
        sgd_config = vars(config.SGDParams())
        novograd_config = vars(config.NovogradParams())

        # The parsed keys follow the ``_target_`` class (Novograd params),
        # not the ``'sgd'`` name that was passed in.
        assert set(output_config.keys()) != set(sgd_config.keys())
        assert set(output_config.keys()) == set(novograd_config.keys())
Exemplo n.º 2
0
    def test_optim_config_parse_bypass(self):
        """Parse a config that has neither a name nor a class path.

        The flat keyword mapping is parsed under the ``'novograd'`` optimizer
        name. This is done first as a plain ``dict`` and then as an OmegaConf
        ``DictConfig``. In both cases ``weight_decay`` and each ``betas`` entry
        must come back unchanged.
        """
        raw_cfg = {'weight_decay': 0.001, 'betas': [0.8, 0.5]}

        def _assert_round_trip(cfg):
            # Parse under the Novograd name and compare field-by-field with the input.
            parsed = optim.parse_optimizer_args('novograd', cfg)
            assert parsed['weight_decay'] == cfg['weight_decay']
            for idx in (0, 1):
                assert parsed['betas'][idx] == cfg['betas'][idx]

        _assert_round_trip(raw_cfg)
        _assert_round_trip(omegaconf.OmegaConf.create(raw_cfg))
Exemplo n.º 3
0
    def test_optim_config_parse_arg_by_name(self):
        """Parse a config that carries ``name: 'auto'`` alongside its keyword args.

        Both a plain ``dict`` and an OmegaConf ``DictConfig`` are parsed under
        the ``'novograd'`` optimizer name. Each must echo ``weight_decay`` and
        ``betas``. Re-parsing the ``DictConfig`` under ``'sgd'`` is expected to
        raise ``omegaconf.errors.ConfigKeyError``.
        """
        raw_cfg = {'name': 'auto', 'weight_decay': 0.001, 'betas': [0.8, 0.5]}

        def _assert_round_trip(cfg):
            # Parse under the Novograd name and compare field-by-field with the input.
            parsed = optim.parse_optimizer_args('novograd', cfg)
            assert parsed['weight_decay'] == cfg['weight_decay']
            for idx in (0, 1):
                assert parsed['betas'][idx] == cfg['betas'][idx]

        _assert_round_trip(raw_cfg)

        omega_cfg = omegaconf.OmegaConf.create(raw_cfg)
        _assert_round_trip(omega_cfg)

        # NOTE(review): presumably raises because 'betas' is not a field of the
        # SGD parameter schema — confirm against parse_optimizer_args.
        with pytest.raises(omegaconf.errors.ConfigKeyError):
            optim.parse_optimizer_args('sgd', omega_cfg)