Пример #1
0
    def update_config(
        config: Optional[Dict[Any, Any]] = {},
        path: Optional[str] = None,
        filename: Optional[str] = "emmental-config.yaml",
        update_random_seed: Optional[bool] = True,
    ) -> None:
        """Update ``Meta.config`` from a dict and/or a yaml file.

        Note: There are two ways to update the config, and both may be used
        in the same call (the dict is applied first, then the file):
            (1) pass a config dict that is merged into the current config
            (2) pass a path (and filename); the file is searched for in
                ``path`` and then in its parent directories, for at most
                ``MAX_CONFIG_SEARCH_DEPTH`` levels, and the first match is
                loaded as yaml and merged into the current config

        Args:
          config: The new configuration, defaults to {}. ``None`` and ``{}``
            both mean "no dict update".
          path: The directory to start searching for the config file from,
            defaults to None, in which case no file lookup is performed.
          filename: The config file name, defaults to "emmental-config.yaml".
          update_random_seed: Whether update the random seed or not.
        """
        # Truthiness (rather than ``!= {}``) so that an explicit ``None`` is
        # treated as "nothing to merge". This also means the shared default
        # ``{}`` is never handed to ``merge``.
        if config:
            Meta.config = merge(Meta.config,
                                config,
                                specical_keys="checkpoint_metric")
            logger.info("Updating Emmental config from user provided config.")

        if path is not None:
            found = False
            tries = 0
            current_dir = path
            while current_dir and tries < MAX_CONFIG_SEARCH_DEPTH:
                potential_path = os.path.join(current_dir, filename)
                if os.path.exists(potential_path):
                    # NOTE(review): FullLoader accepts more than plain data
                    # types; if config files can come from untrusted sources,
                    # consider yaml.safe_load instead.
                    with open(potential_path, "r") as f:
                        loaded_config = yaml.load(f, Loader=yaml.FullLoader)
                    Meta.config = merge(
                        Meta.config,
                        loaded_config,
                        specical_keys="checkpoint_metric",
                    )
                    logger.info(
                        f"Updating Emmental config from {potential_path}.")
                    found = True
                    break

                new_dir = os.path.split(current_dir)[0]
                if current_dir == new_dir:
                    # Splitting no longer changes the path: nothing above.
                    break
                current_dir = new_dir
                tries += 1

            # Report every unsuccessful search, including the cases where the
            # depth limit was hit or a relative path ran out of components.
            if not found:
                logger.info("Unable to find config file. Using defaults.")

        if update_random_seed:
            set_random_seed(Meta.config["meta_config"]["seed"])

        Meta.check_config()
Пример #2
0
def test_embedding_module(caplog):
    """Check EmbeddingModule sizes and weights for several vocab setups."""
    caplog.set_level(logging.INFO)

    # Fix the random seed
    set_random_seed(1)

    counts = {"1": 1, "2": 3, "3": 1}
    expected_weights = torch.FloatTensor([
        [-0.4277, 0.7110, -0.3268, -0.7473, 0.3847],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [-0.2247, -0.7969, -0.4558, -0.3063, 0.4276],
        [2.0000, 2.0000, 2.0000, 2.0000, 2.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
    ])

    # Vocabulary taken from the word counter only
    module = EmbeddingModule(word_counter=counts, word_dim=10, max_size=10)
    assert module.dim == 10
    # <unk> and <pad> are default tokens
    assert module.embeddings.weight.size() == (5, 10)

    # Embedding file supplied: the asserted dim is 5 even though word_dim=10
    module = EmbeddingModule(
        word_counter=counts,
        word_dim=10,
        embedding_file="tests/shared/embeddings.vec",
        fix_emb=True,
    )
    assert module.dim == 5
    assert module.embeddings.weight.size() == (5, 5)

    weight_gap = torch.abs(module.embeddings.weight.data - expected_weights)
    assert torch.max(weight_gap) < 1e-4

    looked_up = module(torch.LongTensor([1, 2]))
    lookup_gap = torch.abs(looked_up - expected_weights[1:3, :])
    assert torch.max(lookup_gap) < 1e-4

    # With threshold
    counts = {"1": 3, "2": 1, "3": 1}
    module = EmbeddingModule(word_counter=counts, word_dim=10, threshold=2)
    assert module.embeddings.weight.size() == (3, 10)

    # No word counter
    module = EmbeddingModule(embedding_file="tests/shared/embeddings.vec")
    assert module.embeddings.weight.size() == (5, 5)
Пример #3
0
def test_set_random_seed(caplog):
    """Call set_random_seed with a small seed, no seed, and extreme seeds."""
    caplog.set_level(logging.INFO)

    # Each tuple holds the positional arguments for one call; the empty
    # tuple exercises the no-argument form.
    for seed_args in ((1,), (), (-999999999999,), (999999999999,)):
        set_random_seed(*seed_args)
Пример #4
0
def test_round_robin_scheduler(caplog):
    """Check batch counts and task order from RoundRobinScheduler."""
    caplog.set_level(logging.INFO)

    init()

    # Fix the random seed
    set_random_seed(2)

    # Random data is drawn in the order x1, y1, x2, y2.
    task1 = "task1"
    x1 = np.random.rand(20, 2)
    y1 = torch.from_numpy(np.random.rand(20))

    task2 = "task2"
    x2 = np.random.rand(30, 3)
    y2 = torch.from_numpy(np.random.rand(30))

    dataloaders = []
    for name, features, labels in ((task1, x1, y1), (task2, x2, y2)):
        dataset = EmmentalDataset(name=name,
                                  X_dict={"feature": features},
                                  Y_dict={"label": labels})
        loader = EmmentalDataLoader(
            task_to_label_dict={name: "label"},
            dataset=dataset,
            split="train",
            batch_size=10,
            shuffle=True,
        )
        dataloaders.append(loader)

    # Default scheduler
    scheduler = RoundRobinScheduler()
    assert scheduler.get_num_batches(dataloaders) == 5

    seen_tasks = []
    for batch_data in scheduler.get_batches(dataloaders):
        seen_tasks.append(batch_data[-2])
    assert seen_tasks == [task2, task1, task2, task2, task1]

    # Scheduler with fillup enabled
    scheduler = RoundRobinScheduler(fillup=True)
    assert scheduler.get_num_batches(dataloaders) == 6

    seen_tasks = []
    for batch_data in scheduler.get_batches(dataloaders):
        seen_tasks.append(batch_data[-2])
    assert seen_tasks == [task2, task1, task2, task2, task1, task1]
Пример #5
0
def init(
    log_dir: str = tempfile.gettempdir(),
    log_name: str = "emmental.log",
    use_exact_log_path: bool = False,
    format: str = "[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
    level: int = logging.INFO,
    config: Optional[Dict[Any, Any]] = {},
    config_dir: Optional[str] = None,
    config_name: Optional[str] = "emmental-config.yaml",
    local_rank: int = -1,
) -> None:
    """Set up logging, then the configuration, then the random seed.

    Args:
      log_dir: Directory in which logs are stored, defaults to
        tempfile.gettempdir().
      log_name: Name of the log file, defaults to "emmental.log".
      use_exact_log_path: Whether the exact log directory is used, defaults
        to False.
      format: Logging format string, defaults to
        "[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s".
      level: Logging level, defaults to logging.INFO.
      config: Configuration overrides, defaults to {}.
      config_dir: Path to the config file, defaults to None.
      config_name: Name of the config file, defaults to
        "emmental-config.yaml".
      local_rank: local_rank for distributed training on gpus.
    """
    init_logging(
        log_dir,
        log_name,
        use_exact_log_path,
        format,
        level,
        local_rank,
    )
    init_config()

    # Only touch the config when the caller supplied a non-empty dict or a
    # directory to look in.
    has_overrides = bool(config) or config_dir is not None
    if has_overrides:
        # Seeding is skipped here because it is done once below.
        Meta.update_config(
            config,
            config_dir,
            config_name,
            update_random_seed=False,
        )

    seed = Meta.config["meta_config"]["seed"]
    set_random_seed(seed)
    Meta.check_config()
Пример #6
0
def test_round_robin_scheduler_no_y_dict(caplog):
    """Check RoundRobinScheduler on datasets built without a Y_dict."""
    caplog.set_level(logging.INFO)

    init()

    # Fix the random seed
    set_random_seed(2)

    # Random data is drawn in the order x1, x2.
    task1 = "task1"
    x1 = np.random.rand(20, 2)

    task2 = "task2"
    x2 = np.random.rand(30, 3)

    dataloaders = []
    for name, features in ((task1, x1), (task2, x2)):
        dataset = EmmentalDataset(name=name, X_dict={"feature": features})
        loader = EmmentalDataLoader(
            task_to_label_dict={name: None},
            dataset=dataset,
            split="train",
            batch_size=10,
            shuffle=True,
        )
        dataloaders.append(loader)

    # Override the per-loader batch counts explicitly.
    for loader, batch_count in zip(dataloaders, (3, 4)):
        loader.n_batches = batch_count

    scheduler = RoundRobinScheduler()
    assert scheduler.get_num_batches(dataloaders) == 7

    collected_y_dicts = []
    for batch in scheduler.get_batches(dataloaders):
        collected_y_dicts.append(batch.Y_dict)

    # Every batch is expected to carry no labels.
    assert collected_y_dicts == [None] * 7