def update_config(
    config: Optional[Dict[Any, Any]] = {},
    path: Optional[str] = None,
    filename: Optional[str] = "emmental-config.yaml",
    update_random_seed: Optional[bool] = True,
) -> None:
    """Update the config with the configs in root of project and its parents.

    Note: There are two ways to update the config:
        (1) use a config dict to update the config
        (2) use path and filename to load a yaml file to update the config

    Args:
        config: The new configuration, defaults to {}.
        path: The path to start searching for the config file, defaults to None.
        filename: The config file name, defaults to "emmental-config.yaml".
        update_random_seed: Whether to update the random seed or not.
    """
    if config != {}:
        Meta.config = merge(Meta.config, config, specical_keys="checkpoint_metric")
        logger.info("Updating Emmental config from user provided config.")

    if path is not None:
        tries = 0
        current_dir = path
        while current_dir and tries < MAX_CONFIG_SEARCH_DEPTH:
            potential_path = os.path.join(current_dir, filename)
            if os.path.exists(potential_path):
                with open(potential_path, "r") as f:
                    Meta.config = merge(
                        Meta.config,
                        yaml.load(f, Loader=yaml.FullLoader),
                        specical_keys="checkpoint_metric",
                    )
                logger.info(f"Updating Emmental config from {potential_path}.")
                break

            new_dir = os.path.split(current_dir)[0]
            if current_dir == new_dir:
                logger.info("Unable to find config file. Using defaults.")
                break
            current_dir = new_dir
            tries += 1

    if update_random_seed:
        set_random_seed(Meta.config["meta_config"]["seed"])

    Meta.check_config()
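# A minimal usage sketch of the two update paths described in the docstring
# above (illustrative only). It assumes `Meta` is importable from the emmental
# package, as used by init() below; the override values and the project path
# are hypothetical.
from emmental import Meta

# (1) Update from an in-memory config dict.
Meta.update_config(config={"meta_config": {"seed": 42}})

# (2) Search for emmental-config.yaml from the given path up toward the
#     filesystem root and merge the first match into Meta.config.
Meta.update_config(path="/path/to/project")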
def test_embedding_module(caplog):
    """Unit test of embedding module."""
    caplog.set_level(logging.INFO)

    # Set random seed
    set_random_seed(1)

    word_counter = {"1": 1, "2": 3, "3": 1}

    weight_tensor = torch.FloatTensor(
        [
            [-0.4277, 0.7110, -0.3268, -0.7473, 0.3847],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [-0.2247, -0.7969, -0.4558, -0.3063, 0.4276],
            [2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
            [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        ]
    )

    emb_layer = EmbeddingModule(word_counter=word_counter, word_dim=10, max_size=10)

    assert emb_layer.dim == 10
    # <unk> and <pad> are default tokens
    assert emb_layer.embeddings.weight.size() == (5, 10)

    emb_layer = EmbeddingModule(
        word_counter=word_counter,
        word_dim=10,
        embedding_file="tests/shared/embeddings.vec",
        fix_emb=True,
    )

    assert emb_layer.dim == 5
    assert emb_layer.embeddings.weight.size() == (5, 5)
    assert torch.max(torch.abs(emb_layer.embeddings.weight.data - weight_tensor)) < 1e-4
    assert (
        torch.max(
            torch.abs(emb_layer(torch.LongTensor([1, 2])) - weight_tensor[1:3, :])
        )
        < 1e-4
    )

    # With threshold
    word_counter = {"1": 3, "2": 1, "3": 1}
    emb_layer = EmbeddingModule(word_counter=word_counter, word_dim=10, threshold=2)
    assert emb_layer.embeddings.weight.size() == (3, 10)

    # No word counter
    emb_layer = EmbeddingModule(embedding_file="tests/shared/embeddings.vec")
    assert emb_layer.embeddings.weight.size() == (5, 5)
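# How the expected sizes above follow (explanatory note, not part of the test):
# the vocabulary always includes the special <unk> and <pad> tokens, so the
# three counted words yield 3 + 2 = 5 rows. With word_dim=10 and no embedding
# file, the weight is (5, 10); with the 5-dimensional embeddings.vec file, the
# layer takes its dimension from the file instead, giving (5, 5). With
# threshold=2, only "1" (count 3) is kept, so the weight shrinks to 1 + 2 = 3
# rows.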
def test_set_random_seed(caplog):
    """Unit test of setting random seed."""
    caplog.set_level(logging.INFO)

    set_random_seed(1)
    set_random_seed()
    set_random_seed(-999999999999)
    set_random_seed(999999999999)
def test_round_robin_scheduler(caplog):
    """Unit test of round robin scheduler."""
    caplog.set_level(logging.INFO)

    init()

    # Set random seed
    set_random_seed(2)

    task1 = "task1"
    x1 = np.random.rand(20, 2)
    y1 = torch.from_numpy(np.random.rand(20))

    task2 = "task2"
    x2 = np.random.rand(30, 3)
    y2 = torch.from_numpy(np.random.rand(30))

    dataloaders = [
        EmmentalDataLoader(
            task_to_label_dict={task_name: "label"},
            dataset=EmmentalDataset(
                name=task_name, X_dict={"feature": x}, Y_dict={"label": y}
            ),
            split="train",
            batch_size=10,
            shuffle=True,
        )
        for task_name, x, y in [(task1, x1, y1), (task2, x2, y2)]
    ]

    scheduler = RoundRobinScheduler()

    assert scheduler.get_num_batches(dataloaders) == 5

    batch_task_names = [
        batch_data[-2] for batch_data in scheduler.get_batches(dataloaders)
    ]

    assert batch_task_names == [task2, task1, task2, task2, task1]

    scheduler = RoundRobinScheduler(fillup=True)

    assert scheduler.get_num_batches(dataloaders) == 6

    batch_task_names = [
        batch_data[-2] for batch_data in scheduler.get_batches(dataloaders)
    ]

    assert batch_task_names == [task2, task1, task2, task2, task1, task1]
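# Why the expected counts above hold (explanatory note, not part of the test):
# task1 has 20 examples and task2 has 30, both with batch_size=10, so the
# per-task batch counts are 2 and 3. Without fillup the scheduler visits
# 2 + 3 = 5 batches per epoch; with fillup=True every task is topped up to the
# largest count, max(2, 3) = 3, giving 2 * 3 = 6 batches per epoch.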
def init(
    log_dir: str = tempfile.gettempdir(),
    log_name: str = "emmental.log",
    use_exact_log_path: bool = False,
    format: str = "[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
    level: int = logging.INFO,
    config: Optional[Dict[Any, Any]] = {},
    config_dir: Optional[str] = None,
    config_name: Optional[str] = "emmental-config.yaml",
    local_rank: int = -1,
) -> None:
    """Initialize the logging and configuration.

    Args:
        log_dir: The directory to store logs in, defaults to tempfile.gettempdir().
        log_name: The log file name, defaults to "emmental.log".
        use_exact_log_path: Whether to use the exact log directory, defaults to False.
        format: The logging format string to use, defaults to
            "[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s".
        level: The logging level to use, defaults to logging.INFO.
        config: The new configuration, defaults to {}.
        config_dir: The path to the config file, defaults to None.
        config_name: The config file name, defaults to "emmental-config.yaml".
        local_rank: The local rank for distributed training on GPUs, defaults to -1.
    """
    init_logging(log_dir, log_name, use_exact_log_path, format, level, local_rank)
    init_config()
    if config or config_dir is not None:
        Meta.update_config(config, config_dir, config_name, update_random_seed=False)
    set_random_seed(Meta.config["meta_config"]["seed"])
    Meta.check_config()
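# A minimal usage sketch for init() (illustrative, not part of the function
# above). It assumes the package-level entry point `emmental.init`; the log
# directory and seed override are arbitrary example values.
import emmental

emmental.init(
    log_dir="logs",
    config={"meta_config": {"seed": 1}},
)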
def test_round_robin_scheduler_no_y_dict(caplog):
    """Unit test of round robin scheduler with no y_dict."""
    caplog.set_level(logging.INFO)

    init()

    # Set random seed
    set_random_seed(2)

    task1 = "task1"
    x1 = np.random.rand(20, 2)

    task2 = "task2"
    x2 = np.random.rand(30, 3)

    dataloaders = [
        EmmentalDataLoader(
            task_to_label_dict={task_name: None},
            dataset=EmmentalDataset(name=task_name, X_dict={"feature": x}),
            split="train",
            batch_size=10,
            shuffle=True,
        )
        for task_name, x in [(task1, x1), (task2, x2)]
    ]
    dataloaders[0].n_batches = 3
    dataloaders[1].n_batches = 4

    scheduler = RoundRobinScheduler()

    assert scheduler.get_num_batches(dataloaders) == 7

    batch_y_dicts = [batch.Y_dict for batch in scheduler.get_batches(dataloaders)]

    assert batch_y_dicts == [None] * 7
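# Explanatory note (not part of the test): the dataloaders set n_batches
# explicitly (3 and 4), and the test expects the scheduler to honor those
# counts rather than deriving 2 and 3 from the dataset sizes, so one epoch
# yields 3 + 4 = 7 batches. Since the datasets carry no Y_dict, every
# scheduled batch reports Y_dict=None.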