# Shared imports for the collate-function variants and helpers below.
import copy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from emmental.data import EmmentalDataset
from emmental.meta import Meta
from emmental.utils.utils import list_to_tensor

logger = logging.getLogger(__name__)


def emmental_collate_fn(batch):
    """Collate a list of (x_dict, y_dict) pairs into batched X and Y dicts."""
    X_batch = defaultdict(list)
    Y_batch = defaultdict(list)

    for x_dict, y_dict in batch:
        for field_name, value in x_dict.items():
            X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            Y_batch[label_name].append(value)

    for field_name, values in X_batch.items():
        # Only merge list of tensors
        if isinstance(values[0], torch.Tensor):
            # list_to_tensor returns (padded_tensor, mask); keep the tensor
            X_batch[field_name] = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )[0]

    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
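

# Usage sketch (illustrative, not from the source): the collate function is
# passed straight to a torch DataLoader. It assumes emmental.init() has run so
# that Meta.config is populated; `make_loader` and `batch_size` are made-up names.
from torch.utils.data import DataLoader


def make_loader(dataset, batch_size=32):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=emmental_collate_fn,  # merges per-field tensors with padding
    )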


def test_list_to_tensor(caplog):
    """Unit test of list_to_tensor."""
    caplog.set_level(logging.INFO)

    # List of 1-D tensors with different lengths
    batch = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])]
    padded_batch, mask_batch = list_to_tensor(batch)
    assert torch.equal(padded_batch, torch.Tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]))
    assert torch.equal(
        mask_batch, mask_batch.new_tensor([[0, 0, 1], [0, 1, 1], [0, 0, 0]])
    )

    # List of 1-D tensors with the same length
    batch = [
        torch.Tensor([1, 2, 3]),
        torch.Tensor([4, 5, 6]),
        torch.Tensor([7, 8, 9]),
    ]
    padded_batch, mask_batch = list_to_tensor(batch)
    assert torch.equal(padded_batch, torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    assert torch.equal(
        mask_batch, mask_batch.new_tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
    )

    # List of 2-D tensors with the same size
    batch = [
        torch.Tensor([[1, 2, 3], [1, 2, 3]]),
        torch.Tensor([[4, 5, 6], [4, 5, 6]]),
        torch.Tensor([[7, 8, 9], [7, 8, 9]]),
    ]
    padded_batch, mask_batch = list_to_tensor(batch)
    assert torch.equal(
        padded_batch,
        torch.Tensor(
            [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6]], [[7, 8, 9], [7, 8, 9]]]
        ),
    )
    assert mask_batch is None

    # List of tensors with different sizes (flattened, then padded)
    batch = [
        torch.Tensor([[1, 2], [2, 3]]),
        torch.Tensor([4, 5, 6]),
        torch.Tensor([7, 8, 9, 0]),
    ]
    padded_batch, mask_batch = list_to_tensor(batch)
    assert torch.equal(
        padded_batch, torch.Tensor([[1, 2, 2, 3], [4, 5, 6, 0], [7, 8, 9, 0]])
    )
    assert torch.equal(
        mask_batch, mask_batch.new_tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1]])
    )
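

# A minimal sketch of the padding helper exercised by the test above, written
# only to mirror those assertions; the real list_to_tensor lives in
# emmental.utils.utils, and `list_to_tensor_sketch` is a hypothetical name.
def list_to_tensor_sketch(
    item_list: List[torch.Tensor], min_len: int = 0, max_len: int = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Multi-dimensional tensors with identical shapes are stacked; no mask.
    if all(
        item.dim() > 1 and item.size() == item_list[0].size() for item in item_list
    ):
        return torch.stack(item_list, dim=0), None
    # Otherwise flatten every item and right-pad to the longest sequence.
    flat = [item.reshape(-1) for item in item_list]
    seq_len = max(max(len(f) for f in flat), min_len)
    if max_len > 0:
        seq_len = min(seq_len, max_len)
    padded = item_list[0].new_zeros(len(flat), seq_len)
    mask = torch.ones(len(flat), seq_len, dtype=torch.bool)  # 1 marks padding
    for i, f in enumerate(flat):
        length = min(len(f), seq_len)
        padded[i, :length] = f[:length]
        mask[i, :length] = 0
    return padded, mask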


def emmental_collate_fn(
    batch: List[Tuple[Dict[str, Any], Dict[str, Tensor]]]
) -> Tuple[Dict[str, Any], Dict[str, Tensor]]:
    r"""Collate function.

    Args:
      batch(List[Tuple[Dict[str, Any], Dict[str, Tensor]]]): The batch to collate.

    Returns:
      Tuple[Dict[str, Any], Dict[str, Tensor]]: The collated batch.
    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    for x_dict, y_dict in batch:
        for field_name, value in x_dict.items():
            if isinstance(value, list):
                X_batch[field_name] += value
            else:
                X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            if isinstance(value, list):
                Y_batch[label_name] += value
            else:
                Y_batch[label_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))
    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
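

# Illustrative detail of the version above (assumes emmental.init() populated
# Meta.config; field names are made up): a list-valued field is extended rather
# than appended, so one example can contribute several rows to that field.
def demo_list_valued_field() -> None:
    batch = [
        (
            {"spans": [torch.Tensor([1]), torch.Tensor([2])]},
            {"label": torch.Tensor([0])},
        ),
        ({"spans": [torch.Tensor([3])]}, {"label": torch.Tensor([1])}),
    ]
    X_batch, Y_batch = emmental_collate_fn(batch)
    assert X_batch["spans"].size(0) == 3  # three rows from two examples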


def wrapped_f(dataset):
    """Apply the augmentation function to every example and build a new dataset."""
    # `f` is the augmentation function captured from the enclosing wrapper.
    X_dict = defaultdict(list)
    Y_dict = defaultdict(list)
    examples = []

    for x_dict, y_dict in dataset:
        # TODO: Consider making sure aug_x_dict is not x_dict!
        aug_x_dict, aug_y_dict = f(x_dict, y_dict)
        if aug_x_dict is not None and aug_y_dict is not None:
            examples.append((aug_x_dict, aug_y_dict))

    for x_dict, y_dict in examples:
        for k, v in x_dict.items():
            X_dict[k].append(v)
        for k, v in y_dict.items():
            Y_dict[k].append(v)

    for k, v in Y_dict.items():
        # list_to_tensor returns (padded_tensor, mask); keep only the tensor
        Y_dict[k] = list_to_tensor(v)[0]
    # X_dict, Y_dict = emmental_collate_fn(examples)

    aug_dataset = EmmentalDataset(name=f.__name__, X_dict=X_dict, Y_dict=Y_dict)

    logger.info(
        f"Total {len(aug_dataset)} augmented examples were created "
        f"from AF {f.__name__}"
    )

    return aug_dataset
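

# Usage sketch (hypothetical names): an augmentation function (AF) maps one
# (x_dict, y_dict) example to an augmented pair, or returns (None, None) to
# drop the example; `reverse_tokens` and the "token_ids" field are made up.
def reverse_tokens(x_dict, y_dict):
    aug_x_dict = dict(x_dict)  # shallow copy so the original dict is untouched
    aug_x_dict["token_ids"] = torch.flip(x_dict["token_ids"], dims=[0])
    return aug_x_dict, dict(y_dict)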


def emmental_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]], List[Dict[str, Any]]],
    min_data_len: int = 0,
    max_data_len: int = 0,
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function.

    Args:
      batch: The batch to collate.
      min_data_len: The minimal data sequence length, defaults to 0.
      max_data_len: The maximal data sequence length (0 means no limit),
        defaults to 0.

    Returns:
      The collated batch.
    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    for item in batch:
        # Check whether the item is a bare x_dict or an (x_dict, y_dict) pair
        if isinstance(item, dict):
            x_dict = item
            y_dict: Dict[str, Any] = dict()
        else:
            x_dict, y_dict = item
        for field_name, value in x_dict.items():
            if isinstance(value, list):
                X_batch[field_name] += value
            else:
                X_batch[field_name].append(value)
        for label_name, value in y_dict.items():
            if isinstance(value, list):
                Y_batch[label_name] += value
            else:
                Y_batch[label_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))
    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=min_data_len,
                max_len=max_data_len,
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=min_data_len,
            max_len=max_data_len,
        )[0]

    if len(Y_batch) != 0:
        return dict(X_batch), dict(Y_batch)
    else:
        return dict(X_batch)
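

# A small illustrative run of the parameterized variant above (field names are
# made up). Variable-length token tensors come back padded, and a
# "token_ids_mask" key is added alongside the padded field. With a DataLoader,
# the length bounds are typically bound first, e.g.
# functools.partial(emmental_collate_fn, min_data_len=0, max_data_len=128).
def demo_parameterized_collate() -> None:
    batch = [
        ({"token_ids": torch.Tensor([1, 2, 3])}, {"label": torch.Tensor([0])}),
        ({"token_ids": torch.Tensor([4, 5])}, {"label": torch.Tensor([1])}),
    ]
    X_batch, Y_batch = emmental_collate_fn(batch)
    assert X_batch["token_ids"].size() == (2, 3)
    assert "token_ids_mask" in X_batch
    assert Y_batch["label"].size() == (2, 1)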


def emmental_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]], List[Dict[str, Any]]]
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function.

    Args:
      batch: The batch to collate.

    Returns:
      The collated batch.
    """
    X_batch: defaultdict = defaultdict(list)
    Y_batch: defaultdict = defaultdict(list)

    # A learnable batch is a pair of dicts; a non-learnable batch is a single dict
    is_learnable = not isinstance(batch[0], dict)

    if is_learnable:
        for x_dict, y_dict in batch:
            if isinstance(x_dict, dict) and isinstance(y_dict, dict):
                for field_name, value in x_dict.items():
                    if isinstance(value, list):
                        X_batch[field_name] += value
                    else:
                        X_batch[field_name].append(value)
                for label_name, value in y_dict.items():
                    if isinstance(value, list):
                        Y_batch[label_name] += value
                    else:
                        Y_batch[label_name].append(value)
    else:
        for x_dict in batch:  # type: ignore
            for field_name, value in x_dict.items():  # type: ignore
                if isinstance(value, list):
                    X_batch[field_name] += value
                else:
                    X_batch[field_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))
    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    if is_learnable:
        for label_name, values in Y_batch.items():
            Y_batch[label_name] = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )[0]

    if is_learnable:
        return dict(X_batch), dict(Y_batch)
    else:
        return dict(X_batch)
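

# The non-learnable path of the version above (illustrative; assumes Meta.config
# is populated): a batch of bare x_dicts, as seen at inference time, collates to
# a single X dict instead of an (X, Y) pair.
def demo_non_learnable() -> None:
    batch = [{"feat": torch.Tensor([1.0, 2.0])}, {"feat": torch.Tensor([3.0])}]
    X_batch = emmental_collate_fn(batch)
    assert isinstance(X_batch, dict)
    assert X_batch["feat"].size() == (2, 2)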


def bootleg_collate_fn(
    batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]], List[Dict[str, Any]]]
) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]:
    """Collate function (modified from the Emmental collate function).

    The main difference is that this collate function handles the kg_adj
    dictionary items from the dataset: we collate each value of each dict key.

    Args:
      batch: The batch to collate.

    Returns:
      The collated batch.
    """
    X_batch: defaultdict = defaultdict(list)
    # In Bootleg, x_dict may contain a nested dictionary; we keep that
    # structure but collate the subtensors
    X_sub_batch: defaultdict = defaultdict(lambda: defaultdict(list))
    Y_batch: defaultdict = defaultdict(list)

    # A learnable batch is a pair of dicts; a non-learnable batch is a single dict
    is_learnable = not isinstance(batch[0], dict)

    if is_learnable:
        for x_dict, y_dict in batch:
            if isinstance(x_dict, dict) and isinstance(y_dict, dict):
                for field_name, value in x_dict.items():
                    if isinstance(value, list):
                        X_batch[field_name] += value
                    elif isinstance(value, dict):
                        # We reinstantiate the field_name here in case there is
                        # no kg adj data; this keeps the field_name key intact
                        if field_name not in X_sub_batch:
                            X_sub_batch[field_name] = defaultdict(list)
                        for sub_field_name, sub_value in value.items():
                            if isinstance(sub_value, list):
                                X_sub_batch[field_name][sub_field_name] += sub_value
                            else:
                                X_sub_batch[field_name][sub_field_name].append(
                                    sub_value
                                )
                    else:
                        X_batch[field_name].append(value)
                for label_name, value in y_dict.items():
                    if isinstance(value, list):
                        Y_batch[label_name] += value
                    else:
                        Y_batch[label_name].append(value)
    else:
        for x_dict in batch:  # type: ignore
            for field_name, value in x_dict.items():  # type: ignore
                if isinstance(value, list):
                    X_batch[field_name] += value
                elif isinstance(value, dict):
                    # We reinstantiate the field_name here in case there is
                    # no kg adj data; this keeps the field_name key intact
                    if field_name not in X_sub_batch:
                        X_sub_batch[field_name] = defaultdict(list)
                    for sub_field_name, sub_value in value.items():
                        if isinstance(sub_value, list):
                            X_sub_batch[field_name][sub_field_name] += sub_value
                        else:
                            X_sub_batch[field_name][sub_field_name].append(sub_value)
                else:
                    X_batch[field_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))
    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    field_names = copy.deepcopy(list(X_sub_batch.keys()))
    for field_name in field_names:
        sub_field_names = copy.deepcopy(list(X_sub_batch[field_name].keys()))
        for sub_field_name in sub_field_names:
            values = X_sub_batch[field_name][sub_field_name]
            # Only merge list of tensors
            if isinstance(values[0], Tensor):
                item_tensor, item_mask_tensor = list_to_tensor(
                    values,
                    min_len=Meta.config["data_config"]["min_data_len"],
                    max_len=Meta.config["data_config"]["max_data_len"],
                )
                X_sub_batch[field_name][sub_field_name] = item_tensor
                if item_mask_tensor is not None:
                    X_sub_batch[field_name][
                        f"{sub_field_name}_mask"
                    ] = item_mask_tensor

    # Add the collated sub batches back into the main batch
    for field_name in field_names:
        X_batch[field_name] = dict(X_sub_batch[field_name])

    if is_learnable:
        for label_name, values in Y_batch.items():
            Y_batch[label_name] = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )[0]

    if is_learnable:
        return dict(X_batch), dict(Y_batch)
    else:
        return dict(X_batch)
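

# Nested-dict behavior of bootleg_collate_fn (illustrative field names, and
# assumes Meta.config is populated): the "kg_adj"-style sub-dict survives
# collation, with its tensor values padded and masked in place.
def demo_bootleg_collate() -> None:
    batch = [
        (
            {"tokens": torch.Tensor([1, 2]), "kg_adj": {"adj": torch.Tensor([0, 1])}},
            {"label": torch.Tensor([0])},
        ),
        (
            {"tokens": torch.Tensor([3]), "kg_adj": {"adj": torch.Tensor([1, 0])}},
            {"label": torch.Tensor([1])},
        ),
    ]
    X_batch, Y_batch = bootleg_collate_fn(batch)
    assert isinstance(X_batch["kg_adj"], dict)
    assert X_batch["kg_adj"]["adj"].size() == (2, 2)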