示例#1
0
def emmental_collate_fn(batch):
    """Collate a list of ``(x_dict, y_dict)`` examples into a single batch.

    Values are grouped per key in example order. A feature group whose first
    element is a ``torch.Tensor`` is merged with ``list_to_tensor``; any other
    feature group is returned as a plain list. Every label group is merged
    with ``list_to_tensor``.

    Args:
      batch: Iterable of ``(x_dict, y_dict)`` pairs.

    Returns:
      A ``(X_batch, Y_batch)`` tuple of plain dicts.
    """
    features = defaultdict(list)
    labels = defaultdict(list)

    for sample_x, sample_y in batch:
        for key, item in sample_x.items():
            features[key].append(item)
        for key, item in sample_y.items():
            labels[key].append(item)

    def _merge(items):
        # Length bounds are read from the global config only when a merge
        # actually happens.
        return list_to_tensor(
            items,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )

    # Only merge list of tensors; other feature groups stay as lists
    collated_x = {
        key: _merge(items) if isinstance(items[0], torch.Tensor) else items
        for key, items in features.items()
    }
    collated_y = {key: _merge(items) for key, items in labels.items()}

    return collated_x, collated_y
示例#2
0
def test_list_to_tensor(caplog):
    """Unit test of list to tensor.

    Each case is ``(raw rows, expected padded data, expected mask)``; an
    expected mask of ``None`` means ``list_to_tensor`` must return no mask.
    """
    caplog.set_level(logging.INFO)

    cases = [
        # list of 1-D tensor with the different length
        (
            [[1, 2], [3], [4, 5, 6]],
            [[1, 2, 0], [3, 0, 0], [4, 5, 6]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 0]],
        ),
        # list of 1-D tensor with the same length
        (
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
        ),
        # list of 2-D tensor with the same size
        (
            [
                [[1, 2, 3], [1, 2, 3]],
                [[4, 5, 6], [4, 5, 6]],
                [[7, 8, 9], [7, 8, 9]],
            ],
            [
                [[1, 2, 3], [1, 2, 3]],
                [[4, 5, 6], [4, 5, 6]],
                [[7, 8, 9], [7, 8, 9]],
            ],
            None,
        ),
        # list of tensor with the different size
        (
            [[[1, 2], [2, 3]], [4, 5, 6], [7, 8, 9, 0]],
            [[1, 2, 2, 3], [4, 5, 6, 0], [7, 8, 9, 0]],
            [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1]],
        ),
    ]

    for rows, expected_data, expected_mask in cases:
        tensors = [torch.Tensor(row) for row in rows]

        padded, mask = list_to_tensor(tensors)

        assert torch.equal(padded, torch.Tensor(expected_data))
        if expected_mask is None:
            assert mask is None
        else:
            assert torch.equal(mask, mask.new_tensor(expected_mask))
示例#3
0
def emmental_collate_fn(
    batch: List[Tuple[Dict[str, Any], Dict[str, Tensor]]]
) -> Tuple[Dict[str, Any], Dict[str, Tensor]]:
    """Collate a list of ``(x_dict, y_dict)`` examples into one batch.

    Values are grouped per key in example order. A value that is itself a
    ``list`` is flattened into its group (extended); any other value is
    appended. A feature group whose first element is a ``Tensor`` is merged
    with ``list_to_tensor``; when that call also returns a mask, the mask is
    stored under ``"<field_name>_mask"``. Other feature groups, including
    empty ones, are returned as plain lists. Every label group is merged with
    ``list_to_tensor`` and only the merged tensor is kept.

    Args:
      batch(List[Tuple[Dict[str, Any], Dict[str, Tensor]]]): The batch to
        collate.

    Returns:
      Tuple[Dict[str, Any], Dict[str, Tensor]]: The collated batch.

    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    for x_dict, y_dict in batch:
        for field_name, value in x_dict.items():
            if isinstance(value, list):
                X_batch[field_name] += value
            else:
                X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            if isinstance(value, list):
                Y_batch[label_name] += value
            else:
                Y_batch[label_name].append(value)

    # Snapshot the keys: "<field_name>_mask" entries are inserted while looping.
    for field_name in list(X_batch):
        values = X_batch[field_name]
        # Only merge list of tensors. A field whose values were all empty
        # lists yields an empty group, so guard before looking at values[0].
        # len() is used instead of truthiness because a group may already
        # hold a tensor if a field name collides with a generated mask key.
        if len(values) > 0 and isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        # list_to_tensor returns (tensor, mask); labels keep the tensor only.
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
示例#4
0
 def wrapped_f(dataset):
     """Run the wrapped augmentation function ``f`` over ``dataset``.

     Each ``(x_dict, y_dict)`` example is passed to ``f``; results where
     either returned dict is ``None`` are dropped. The surviving examples are
     grouped per key, every label group is passed through ``list_to_tensor``,
     and the result is wrapped in an ``EmmentalDataset`` named after ``f``.
     """
     # TODO: Consider making sure aug_x_dict is not x_dict!
     augmented = (f(x_dict, y_dict) for x_dict, y_dict in dataset)
     kept = [
         (aug_x, aug_y)
         for aug_x, aug_y in augmented
         if aug_x is not None and aug_y is not None
     ]

     features = defaultdict(list)
     labels = defaultdict(list)
     for aug_x, aug_y in kept:
         for key, item in aug_x.items():
             features[key].append(item)
         for key, item in aug_y.items():
             labels[key].append(item)

     # Labels are merged here directly rather than via emmental_collate_fn.
     for key in labels:
         labels[key] = list_to_tensor(labels[key])

     aug_dataset = EmmentalDataset(
         name=f.__name__, X_dict=features, Y_dict=labels
     )
     logger.info(
         f"Total {len(aug_dataset)} augmented examples were created "
         f"from AF {f.__name__}")
     return aug_dataset
示例#5
0
文件: data.py 项目: SenWu/emmental
def emmental_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]],
                 List[Dict[str, Any]]],
    min_data_len: int = 0,
    max_data_len: int = 0,
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function.

    Each item of ``batch`` is either an ``x_dict`` alone or an
    ``(x_dict, y_dict)`` pair. Values are grouped per key in item order; a
    value that is itself a ``list`` is flattened into its group, any other
    value is appended. A feature group whose first element is a ``Tensor``
    is merged with ``list_to_tensor``; when that call also returns a mask, it
    is stored under ``"<field_name>_mask"``. Other feature groups, including
    empty ones, are returned as plain lists. Every label group is merged with
    ``list_to_tensor`` and only the merged tensor is kept.

    Args:
      batch: The batch to collate.
      min_data_len: The minimal data sequence length, defaults to 0.
      max_data_len: The maximal data sequence length (0 means no limit), defaults to 0.

    Returns:
      The collated batch: ``(X_batch, Y_batch)`` when at least one label was
      collected, otherwise ``X_batch`` alone.
    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    for item in batch:
        # Check if batch is (x_dict, y_dict) pair
        if isinstance(item, dict):
            x_dict = item
            y_dict: Dict[str, Any] = {}
        else:
            x_dict, y_dict = item
        for field_name, value in x_dict.items():
            if isinstance(value, list):
                X_batch[field_name] += value
            else:
                X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            if isinstance(value, list):
                Y_batch[label_name] += value
            else:
                Y_batch[label_name].append(value)

    # Snapshot the keys: "<field_name>_mask" entries are inserted while looping.
    for field_name in list(X_batch):
        values = X_batch[field_name]
        # Only merge list of tensors. A field whose values were all empty
        # lists yields an empty group, so guard before looking at values[0].
        # len() is used instead of truthiness because a group may already
        # hold a tensor if a field name collides with a generated mask key.
        if len(values) > 0 and isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=min_data_len,
                max_len=max_data_len,
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        # list_to_tensor returns (tensor, mask); labels keep the tensor only.
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=min_data_len,
            max_len=max_data_len,
        )[0]

    if Y_batch:
        return dict(X_batch), dict(Y_batch)
    return dict(X_batch)
示例#6
0
def emmental_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]],
                 List[Dict[str, Any]]]
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function.

    The first item of ``batch`` decides the mode: if it is a dict, the batch
    is treated as a list of ``x_dict`` (non learnable); otherwise it is
    treated as a list of ``(x_dict, y_dict)`` pairs (learnable), and pairs
    whose members are not both dicts are skipped.

    Values are grouped per key in item order; a value that is itself a
    ``list`` is flattened into its group, any other value is appended. A
    feature group whose first element is a ``Tensor`` is merged with
    ``list_to_tensor``; when that call also returns a mask, it is stored
    under ``"<field_name>_mask"``. Other feature groups, including empty
    ones, are returned as plain lists. In learnable mode every label group is
    merged with ``list_to_tensor`` and only the merged tensor is kept.

    Args:
      batch: The batch to collate.

    Returns:
      The collated batch: ``(X_batch, Y_batch)`` in learnable mode, otherwise
      ``X_batch`` alone.

    Raises:
      IndexError: If ``batch`` is empty, since its first item is inspected.
    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    def _collect(target: defaultdict, source: Dict[str, Any]) -> None:
        # List values are flattened into the group; anything else is appended.
        for name, value in source.items():
            if isinstance(value, list):
                target[name] += value
            else:
                target[name].append(value)

    # Learnable batch should be a pair of dict, while non learnable batch is a dict
    is_learnable = not isinstance(batch[0], dict)

    if is_learnable:
        for x_dict, y_dict in batch:  # type: ignore
            if isinstance(x_dict, dict) and isinstance(y_dict, dict):
                _collect(X_batch, x_dict)
                _collect(Y_batch, y_dict)
    else:
        for x_dict in batch:  # type: ignore
            _collect(X_batch, x_dict)  # type: ignore

    # Snapshot the keys: "<field_name>_mask" entries are inserted while looping.
    for field_name in list(X_batch):
        values = X_batch[field_name]
        # Only merge list of tensors. A field whose values were all empty
        # lists yields an empty group, so guard before looking at values[0].
        # len() is used instead of truthiness because a group may already
        # hold a tensor if a field name collides with a generated mask key.
        if len(values) > 0 and isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    if not is_learnable:
        return dict(X_batch)

    for label_name, values in Y_batch.items():
        # list_to_tensor returns (tensor, mask); labels keep the tensor only.
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
示例#7
0
def bootleg_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]],
                 List[Dict[str, Any]]]
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function (modified from emmental collate fn).

    The main difference is that this collate function handles dictionary
    items (e.g. the kg_adj dictionary) inside ``x_dict``: the nested
    structure is kept and each value of each nested key is collated.

    The first item of ``batch`` decides the mode: if it is a dict, the batch
    is treated as a list of ``x_dict`` (non learnable); otherwise it is
    treated as a list of ``(x_dict, y_dict)`` pairs (learnable), and pairs
    whose members are not both dicts are skipped.

    Values are grouped per key in item order; a value that is itself a
    ``list`` is flattened into its group, any other non-dict value is
    appended. A group whose first element is a ``Tensor`` is merged with
    ``list_to_tensor``; when that call also returns a mask, it is stored
    under ``"<name>_mask"`` next to the merged tensor. Other groups,
    including empty ones, are returned as plain lists. In learnable mode
    every label group is merged with ``list_to_tensor`` and only the merged
    tensor is kept.

    Args:
      batch: The batch to collate.

    Returns:
      The collated batch: ``(X_batch, Y_batch)`` in learnable mode, otherwise
      ``X_batch`` alone.

    Raises:
      IndexError: If ``batch`` is empty, since its first item is inspected.
    """
    X_batch: defaultdict = defaultdict(list)
    # In Bootleg, we may have a nested dictionary in x_dict; we want to keep this structure but
    # collate the subtensors
    X_sub_batch: defaultdict = defaultdict(lambda: defaultdict(list))
    Y_batch: defaultdict = defaultdict(list)

    def _extend_or_append(target: defaultdict, name: str, value: Any) -> None:
        # List values are flattened into the group; anything else is appended.
        if isinstance(value, list):
            target[name] += value
        else:
            target[name].append(value)

    def _collect_features(x_dict: Dict[str, Any]) -> None:
        for field_name, value in x_dict.items():
            if isinstance(value, dict):
                # Touch the key even when the nested dict is empty (e.g. no
                # kg adj data) so the field_name key is kept intact.
                sub_batch = X_sub_batch[field_name]
                for sub_field_name, sub_value in value.items():
                    _extend_or_append(sub_batch, sub_field_name, sub_value)
            else:
                _extend_or_append(X_batch, field_name, value)

    def _merge_tensor_groups(groups: Dict[str, Any]) -> None:
        # Snapshot the keys: "<name>_mask" entries are inserted while looping.
        for name in list(groups):
            values = groups[name]
            # Only merge list of tensors. A field whose values were all empty
            # lists yields an empty group, so guard before looking at
            # values[0]. len() is used instead of truthiness because a group
            # may already hold a tensor if a name collides with a mask key.
            if len(values) > 0 and isinstance(values[0], Tensor):
                item_tensor, item_mask_tensor = list_to_tensor(
                    values,
                    min_len=Meta.config["data_config"]["min_data_len"],
                    max_len=Meta.config["data_config"]["max_data_len"],
                )
                groups[name] = item_tensor
                if item_mask_tensor is not None:
                    groups[f"{name}_mask"] = item_mask_tensor

    # Learnable batch should be a pair of dict, while non learnable batch is a dict
    is_learnable = not isinstance(batch[0], dict)

    if is_learnable:
        for x_dict, y_dict in batch:  # type: ignore
            if isinstance(x_dict, dict) and isinstance(y_dict, dict):
                _collect_features(x_dict)
                for label_name, value in y_dict.items():
                    _extend_or_append(Y_batch, label_name, value)
    else:
        for x_dict in batch:  # type: ignore
            _collect_features(x_dict)  # type: ignore

    _merge_tensor_groups(X_batch)

    sub_field_names = list(X_sub_batch)
    for field_name in sub_field_names:
        _merge_tensor_groups(X_sub_batch[field_name])

    # Add sub batch to batch
    for field_name in sub_field_names:
        X_batch[field_name] = dict(X_sub_batch[field_name])

    if not is_learnable:
        return dict(X_batch)

    for label_name, values in Y_batch.items():
        # list_to_tensor returns (tensor, mask); labels keep the tensor only.
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)