示例#1
0
 def test_invalid_config_opt(self):
     """
     Check that the parser exits when given an invalid config option.

     The command line in ``self.cmd_args[11]`` carries an invalid config
     option; ``parser.parse_args`` must raise ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[11])
示例#2
0
 def test_exclusive_args(self):
     """
     Check that the parser exits when dataset and stats are both specified.

     ``self.cmd_args[3]`` supplies both options at once; the parser must
     reject the combination by raising ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[3])
示例#3
0
 def test_missing_dataset(self):
     """
     Check that the parser exits when the dataset name is missing.

     ``self.cmd_args[1]`` omits the dataset name; ``parser.parse_args``
     must raise ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[1])
示例#4
0
 def test_empty_arg(self):
     """
     Check that the parser exits when no argument is given.

     ``self.cmd_args[8]`` is the empty command line; ``parser.parse_args``
     must raise ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[8])
示例#5
0
 def test_missing_arg(self):
     """
     Check that the parser exits when a required argument is missing.

     ``self.cmd_args[7]`` lacks a required argument; ``parser.parse_args``
     must raise ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[7])
示例#6
0
 def test_incomplete_stats(self):
     """
     Check that the parser exits when incomplete stats are given.

     ``self.cmd_args[6]`` passes an incomplete stats specification;
     ``parser.parse_args`` must raise ``SystemExit``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     with self.assertRaises(SystemExit):
         parser.parse_args(self.cmd_args[6])
示例#7
0
 def test_multi_gpu_opt(self):
     """
     Check that a multi-GPU setting can be parsed via the -dev option.

     Parses ``self.cmd_args[4]``, converts it with ``parse_args`` and
     expects ``"GPU"`` under the ``general.device`` key.
     """
     # The config returned by set_args() is replaced by parse_args() below.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[4])
     config_dict = parse_args(args)
     # NOTE(review): test_device_default reads the device under "device",
     # while this test reads "general.device" — confirm which key
     # parse_args actually populates.
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(config_dict.get("general.device"), "GPU")
示例#8
0
 def test_unrecognized_dataset(self):
     """
     Check that an unrecognized dataset name is rejected.

     ``self.cmd_args[2]`` parses successfully, but converting the result
     with ``parse_args`` must raise ``RuntimeError``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[2])
     with self.assertRaises(RuntimeError):
         parse_args(args)
示例#9
0
 def test_missing_mode_arg(self):
     """
     Check that an error is raised when neither -d nor -s is specified.

     ``self.cmd_args[13]`` omits both options; converting the parsed
     arguments with ``parse_args`` must raise ``RuntimeError``.
     """
     # Only the parser is needed; the second value from set_args() is unused.
     parser, _ = set_args()
     # Parse outside the assertRaises block (as test_unrecognized_dataset
     # does) so only parse_args() is expected to raise; a failure in the
     # argparse stage then surfaces as an error instead of being mixed into
     # the expected-exception region.
     args = parser.parse_args(self.cmd_args[13])
     with self.assertRaises(RuntimeError):
         parse_args(args)
示例#10
0
 def test_config_opt_parsing(self):
     """
     Check that config options are parsed correctly.

     Parses ``self.cmd_args[10]`` and expects the chunk count, dataset
     name and embedding size to appear in the resulting config as the
     strings/values asserted below.
     """
     # The config returned by set_args() is replaced by parse_args() below.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[10])
     config_dict = parse_args(args)
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(config_dict.get("training.number_of_chunks"), "32")
     self.assertEqual(config_dict.get("dataset"), "live_journal")
     self.assertEqual(config_dict.get("model.embedding_size"), "128")
示例#11
0
 def test_data_path(self):
     """
     Check that the data path is set correctly when --data_directory is given.

     Parses ``self.cmd_args[12]`` and expects the data and output
     directories, compared via ``str()``, to be ``./data_dir`` and
     ``./output_dir``.
     """
     # The config returned by set_args() is replaced by parse_args() below.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[12])
     config_dict = parse_args(args)
     # str() lets the comparison work whether the value is a str or a path
     # object; assertEqual reports both sides on failure.
     self.assertEqual(str(config_dict.get("data_directory")), "./data_dir")
     self.assertEqual(str(config_dict.get("output_directory")),
                      "./output_dir")
示例#12
0
 def test_stats_parsing(self):
     """
     Check that the stats option is parsed correctly.

     Parses ``self.cmd_args[5]`` and expects node, relation and
     train/valid/test edge counts, compared via ``str()``, to match the
     values asserted below.
     """
     # The config returned by set_args() is replaced by parse_args() below.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[5])
     config_dict = parse_args(args)
     # Compare via str() so the check holds whether counts are stored as
     # ints or strings; assertEqual reports both sides on failure.
     self.assertEqual(str(config_dict.get("num_nodes")), "14")
     self.assertEqual(str(config_dict.get("num_relations")), "2")
     self.assertEqual(str(config_dict.get("num_train")), "14")
     self.assertEqual(str(config_dict.get("num_valid")), "7")
     self.assertEqual(str(config_dict.get("num_test")), "5")
示例#13
0
 def test_device_default(self):
     """
     Check that default values of -dev and other config options are
     assigned correctly.

     Parses ``self.cmd_args[0]`` and expects: device ``"GPU"``, dataset
     ``wn18`` with its node/relation/split counts, embedding size ``"128"``,
     no random seed, no data directory, and ``./output_dir`` as the output
     directory.
     """
     # The config returned by set_args() is replaced by parse_args() below.
     parser, _ = set_args()
     args = parser.parse_args(self.cmd_args[0])
     config_dict = parse_args(args)
     # NOTE(review): test_multi_gpu_opt reads the device under
     # "general.device", while this test reads "device" — confirm which key
     # parse_args actually populates.
     self.assertEqual(config_dict.get("device"), "GPU")
     self.assertEqual(config_dict.get("dataset"), "wn18")
     self.assertEqual(config_dict.get("model.embedding_size"), "128")
     self.assertIsNone(config_dict.get("general.random_seed"))
     # Counts are compared via str() so ints and strings both match.
     self.assertEqual(str(config_dict.get("num_train")), "141442")
     self.assertEqual(str(config_dict.get("num_nodes")), "40943")
     self.assertEqual(str(config_dict.get("num_relations")), "18")
     self.assertEqual(str(config_dict.get("num_valid")), "5000")
     self.assertEqual(str(config_dict.get("num_test")), "5000")
     self.assertIsNone(config_dict.get("data_directory"))
     self.assertEqual(config_dict.get("output_directory"), "./output_dir")