Example #1
0
 def test_required_args(self):
     """
     Check that generate_config is unset in the parsed argument dict
         when --generate_config is not specified.

     Parses self.cmd_args[5] and asserts that
     arg_dict.get("generate_config") is None.
     """
     parser, config_dict = set_args()
     args = parser.parse_args(self.cmd_args[5])
     config_dict, arg_dict = parse_args(config_dict, args)
     # assertIsNone reports the unexpected value on failure, whereas
     # assertTrue(... is None) only reports "False is not true".
     self.assertIsNone(arg_dict.get("generate_config"))
Example #2
0
 def test_inconsistent_training_config(self):
     """
     Check that SystemExit is raised if a training config is specified
         without --generate_config being specified.

     Uses self.cmd_args[4]. Both the argparse call and parse_args run
     inside assertRaises, so a SystemExit from either one satisfies
     the test.
     """
     parser, config_dict = set_args()
     with self.assertRaises(SystemExit):
         args = parser.parse_args(self.cmd_args[4])
         # The return value is irrelevant; only the raised exit matters.
         parse_args(config_dict, args)
Example #3
0
 def test_gpu(self):
     """
     Check if --gc can parse device choice correctly.

     Parses self.cmd_args[1] and expects dataset "wn18", device "GPU"
     and output_directory "./output_dir".
     """
     parser, config_dict = set_args()
     args = parser.parse_args(self.cmd_args[1])
     config_dict, arg_dict = parse_args(config_dict, args)
     # assertEqual shows both values on mismatch, unlike assertTrue(a == b).
     self.assertEqual(config_dict.get("dataset"), "wn18")
     self.assertEqual(config_dict.get("device"), "GPU")
     self.assertEqual(arg_dict.get("output_directory"), "./output_dir")
Example #4
0
 def test_generate_config_default(self):
     """
     Check if default value of --generate_config is assigned correctly.

     Parses self.cmd_args[0] and expects dataset "wn18", device "GPU"
     and num_partitions 5 (an int, not a string).
     """
     parser, config_dict = set_args()
     args = parser.parse_args(self.cmd_args[0])
     config_dict, arg_dict = parse_args(config_dict, args)
     # assertEqual shows both values on mismatch, unlike assertTrue(a == b).
     self.assertEqual(config_dict.get("dataset"), "wn18")
     self.assertEqual(config_dict.get("device"), "GPU")
     self.assertEqual(arg_dict.get("num_partitions"), 5)
Example #5
0
 def test_cpu_training_config(self):
     """
     Check if training configs can be parsed correctly.

     Parses self.cmd_args[2] and expects device "CPU" plus the
     model/training overrides. Those overrides are compared as strings,
     so they are kept in config_dict unconverted.
     """
     parser, config_dict = set_args()
     args = parser.parse_args(self.cmd_args[2])
     config_dict, arg_dict = parse_args(config_dict, args)
     # assertEqual shows both values on mismatch, unlike assertTrue(a == b).
     self.assertEqual(config_dict.get("dataset"), "wn18")
     self.assertEqual(config_dict.get("device"), "CPU")
     self.assertEqual(arg_dict.get("output_directory"), "./output_dir")
     self.assertEqual(config_dict.get("model.embedding_size"), "400")
     self.assertEqual(config_dict.get("training.batch_size"), "51200")
     self.assertEqual(config_dict.get("training.num_epochs"), "23")