Example #1
def calc_class_weights_from_dataloader(
        dataloader: 'torch.utils.data.DataLoader', num_classes: int,
        data_dir: str) -> List[float]:
    """
    Calculate the weights of each class to be used for weighted loss. This is similar to the function calc_class_weights
    in text_classification_dataset, but it gets the labels from a dataloader rather than from a file.
    Args:
        dataloader: the dataloader for the training set
        num_classes: number of classes in the dataset
    """
    labels = []
    for batch in dataloader:
        labels.extend(tensor2list(batch[-1]))
    logging.info('Calculating label frequency stats...')
    total_sents, sent_label_freq, max_id = get_label_stats(
        labels, os.path.join(data_dir, 'sentence_stats.tsv'), verbose=False)
    if max_id >= num_classes:
        raise ValueError(
            f'Found an invalid label {max_id}! Labels should be from [0, num_classes-1].'
        )

    class_weights_dict = get_freq_weights(sent_label_freq)

    logging.info(f'Total Sentence Pairs: {total_sents}')
    logging.info(f'Class Frequencies: {sent_label_freq}')
    logging.info(f'Class Weights: {class_weights_dict}')
    class_weights = fill_class_weights(weights=class_weights_dict,
                                       max_id=num_classes - 1)
    return class_weights
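A minimal usage sketch for the function above; the dataset and paths here are illustrative, not from the original source. It only assumes that the last element of each batch holds the integer class labels, which is what the function reads via batch[-1].

import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.randn(100, 16)              # dummy inputs
targets = torch.randint(0, 3, (100,))        # 3 classes; imbalanced in practice
loader = DataLoader(TensorDataset(features, targets), batch_size=32)

# Returns a list of length num_classes, aligned with label ids 0..num_classes-1.
weights = calc_class_weights_from_dataloader(loader, num_classes=3, data_dir='/tmp')
loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))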
Example #2
def calc_class_weights(file_path: str, num_classes: int):
    """
    iterates over a data file and calculate the weights of each class to be used for class_balancing
    Args:
        file_path: path to the data file
        num_classes: number of classes in the dataset
    """

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Could not find data file {file_path} to calculate the class weights!"
        )

    with open(file_path, 'r') as f:
        input_lines = f.readlines()

    labels = []
    for input_line in input_lines:
        parts = input_line.strip().split()
        try:
            label = int(parts[-1])
        except (ValueError, IndexError):
            raise ValueError(
                f'Could not parse a numerical label in {file_path}. Labels should be integers '
                f'separated by [TAB] at the end of each line.'
            )
        labels.append(label)

    logging.info(f'Calculating stats of {file_path}...')
    total_sents, sent_label_freq, max_id = get_label_stats(
        labels, f'{file_path}_sentence_stats.tsv', verbose=False)
    if max_id >= num_classes:
        raise ValueError(
            f'Found an invalid label in {file_path}! Labels should be from [0, num_classes-1].'
        )

    class_weights_dict = get_freq_weights(sent_label_freq)

    logging.info(f'Total Sentences: {total_sents}')
    logging.info(f'Sentence class frequencies: {sent_label_freq}')

    logging.info(f'Class Weights: {class_weights_dict}')
    class_weights = fill_class_weights(weights=class_weights_dict,
                                       max_id=num_classes - 1)

    return class_weights
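The helpers get_freq_weights and fill_class_weights are not shown in this listing. Below is a minimal sketch of inverse-frequency weighting that is consistent with how they are used here; the helper names and the exact formula are assumptions, not necessarily the library's implementation.

from typing import Dict, List

def inverse_freq_weights(label_freq: Dict[int, int]) -> Dict[int, float]:
    # Rarer classes receive proportionally larger weights.
    total = sum(label_freq.values())
    return {label: total / (len(label_freq) * freq) for label, freq in label_freq.items()}

def pad_missing_weights(weights: Dict[int, float], max_id: int) -> List[float]:
    # Classes absent from the data get a neutral weight of 1.0, so the
    # returned list covers every id in 0..max_id.
    return [weights.get(i, 1.0) for i in range(max_id + 1)]

print(inverse_freq_weights({0: 90, 1: 10}))   # {0: 0.555..., 1: 5.0}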
Example #3
def get_label_ids(
    label_file: str,
    is_training: bool = False,
    pad_label: str = 'O',
    label_ids_dict: Dict[str, int] = None,
    get_weights: bool = True,
    class_labels_file_artifact='label_ids.csv',
):
    """
    Generates str to int labels mapping for training data or checks correctness of the label_ids_dict
    file for non-training files or if label_ids_dict is specified

    Args:
        label_file: the path of the label file to process
        is_training: indicates whether the label_file is used for training
        pad_label: token used for padding
        label_ids_dict: str label name to int ids mapping. Required for non-training data.
            If specified, the check that all labels from label_file are present in label_ids_dict will be performed.
            For training data, if label_ids_dict is None, a new mapping will be generated from label_file.
        get_weights: set to True to calculate class weights, required for Weighted Loss.
        class_labels_file_artifact: name of the file to save in .nemo
    """
    if not os.path.exists(label_file):
        raise FileNotFoundError(f'File {label_file} was not found.')

    logging.info(f'Processing {label_file}')
    if not is_training and label_ids_dict is None:
        raise ValueError(
            'For non-training data, the label_ids_dict created during preprocessing '
            'of the training data must be provided.')

    # collect all labels from the label_file
    data_dir = os.path.dirname(label_file)
    unique_labels = {pad_label}  # note: set(pad_label) would split a multi-character label into chars
    all_labels = []
    with open(label_file, 'r') as f:
        for line in f:
            line = line.strip().split()
            all_labels.extend(line)
            unique_labels.update(line)

    # check that all labels from label_file are present in the specified label_ids_dict
    # or generate label_ids_dict from data (for training only)
    save_label_ids = True
    if label_ids_dict:
        logging.info(f'Using provided labels mapping {label_ids_dict}')
        save_label_ids = False
        for name in unique_labels:
            if name not in label_ids_dict:
                raise ValueError(
                    f'{name} class from {label_file} not found in the provided mapping: {label_ids_dict}'
                )
    else:
        label_ids_dict = {pad_label: 0}
        if pad_label in unique_labels:
            unique_labels.remove(pad_label)
        for label in sorted(unique_labels):
            label_ids_dict[label] = len(label_ids_dict)

    label_ids_filename = os.path.join(data_dir, class_labels_file_artifact)
    if is_training and save_label_ids:
        with open(label_ids_filename, 'w') as f:
            labels, _ = zip(
                *sorted(label_ids_dict.items(), key=lambda x: x[1]))
            f.write('\n'.join(labels))
        logging.info(
            f'Labels mapping {label_ids_dict} saved to: {label_ids_filename}')

    # calculate label statistics
    base_name = os.path.splitext(os.path.basename(label_file))[0]
    stats_file = os.path.join(data_dir, f'{base_name}_label_stats.tsv')
    if os.path.exists(stats_file) and not is_training and not get_weights:
        logging.info(f'{stats_file} found, skipping stats calculation.')
    else:
        all_labels = [label_ids_dict[label] for label in all_labels]
        logging.info(f'Three most popular labels in {label_file}:')
        total_labels, label_frequencies, max_id = get_label_stats(
            all_labels, stats_file)
        logging.info(
            f'Total labels: {total_labels}. Label frequencies - {label_frequencies}'
        )

    if get_weights:
        class_weights_pkl = os.path.join(data_dir, f'{base_name}_weights.p')
        if os.path.exists(class_weights_pkl):
            with open(class_weights_pkl, 'rb') as f:
                class_weights = pickle.load(f)
            logging.info(f'Class weights restored from {class_weights_pkl}')
        else:
            class_weights_dict = get_freq_weights(label_frequencies)
            logging.info(f'Class Weights: {class_weights_dict}')
            class_weights = fill_class_weights(class_weights_dict, max_id)

            with open(class_weights_pkl, "wb") as f:
                pickle.dump(class_weights, f)
            logging.info(f'Class weights saved to {class_weights_pkl}')
    else:
        class_weights = None

    return label_ids_dict, label_ids_filename, class_weights
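A hedged usage sketch for get_label_ids; the file paths and contents are illustrative only. Each line of the label file is expected to hold one space-separated tag sequence, matching how the function splits lines above.

# e.g. /tmp/labels_train.txt might contain lines such as: O O B-PER I-PER O
label_ids, ids_file, class_weights = get_label_ids(
    label_file='/tmp/labels_train.txt',   # hypothetical path
    is_training=True,                     # build the mapping from this file
    pad_label='O',
    get_weights=True,
)
# For dev/test data, reuse the training mapping so the ids stay consistent:
dev_ids, _, _ = get_label_ids('/tmp/labels_dev.txt', is_training=False,
                              label_ids_dict=label_ids, get_weights=False)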
Example #4
    def __init__(
        self,
        data_dir: str,
        modes: List[str] = ['train', 'test', 'dev'],
        none_slot_label: str = 'O',
        pad_label: int = -1,
    ):
        if not if_exist(data_dir, ['dict.intents.csv', 'dict.slots.csv']):
            raise FileNotFoundError(
                "Make sure that your data follows the standard format "
                "supported by JointIntentSlotDataset. Your data must "
                "contain dict.intents.csv and dict.slots.csv.")

        self.data_dir = data_dir
        self.intent_dict_file = self.data_dir + '/dict.intents.csv'
        self.slot_dict_file = self.data_dir + '/dict.slots.csv'

        self.intents_label_ids = IntentSlotDataDesc.label2idx(
            self.intent_dict_file)
        self.num_intents = len(self.intents_label_ids)
        self.slots_label_ids = IntentSlotDataDesc.label2idx(
            self.slot_dict_file)
        self.num_slots = len(self.slots_label_ids)

        infold = self.data_dir
        for mode in modes:
            if not if_exist(self.data_dir, [f'{mode}.tsv']):
                logging.info(f' Stats calculation for {mode} mode'
                             f' is skipped as {mode}.tsv was not found.')
                continue
            logging.info(f' Calculating stats for {mode} mode...')
            slot_file = f'{self.data_dir}/{mode}_slots.tsv'
            with open(slot_file, 'r') as f:
                slot_lines = f.readlines()

            input_file = f'{self.data_dir}/{mode}.tsv'
            with open(input_file, 'r') as f:
                input_lines = f.readlines()[1:]  # Skipping headers at index 0

            if len(slot_lines) != len(input_lines):
                raise ValueError(
                    "Make sure that the number of slot lines match the "
                    "number of intent lines. There should be a 1-1 "
                    "correspondence between every slot and intent lines.")

            dataset = list(zip(slot_lines, input_lines))

            raw_slots, raw_intents = [], []
            for slot_line, input_line in dataset:
                slot_list = [int(slot) for slot in slot_line.strip().split()]
                raw_slots.append(slot_list)
                parts = input_line.strip().split()
                raw_intents.append(int(parts[-1]))

            logging.info(f'Three most popular intents in {mode} mode:')
            total_intents, intent_label_freq, max_id = get_label_stats(
                raw_intents, infold + f'/{mode}_intent_stats.tsv')

            merged_slots = itertools.chain.from_iterable(raw_slots)
            logging.info(f'Three most popular slots in {mode} mode:')
            slots_total, slots_label_freq, max_id = get_label_stats(
                merged_slots, infold + f'/{mode}_slot_stats.tsv')

            logging.info(f'Total Number of Intents: {total_intents}')
            logging.info(f'Intent Label Frequencies: {intent_label_freq}')
            logging.info(f'Total Number of Slots: {slots_total}')
            logging.info(f'Slots Label Frequencies: {slots_label_freq}')

            if mode == 'train':
                intent_weights_dict = get_freq_weights(intent_label_freq)
                logging.info(f'Intent Weights: {intent_weights_dict}')
                slot_weights_dict = get_freq_weights(slots_label_freq)
                logging.info(f'Slot Weights: {slot_weights_dict}')

                # Fill the weights here so that a missing train split does not
                # leave the weight dicts undefined outside the loop.
                self.intent_weights = fill_class_weights(intent_weights_dict,
                                                         self.num_intents - 1)
                self.slot_weights = fill_class_weights(slot_weights_dict,
                                                       self.num_slots - 1)

        if pad_label != -1:
            self.pad_label = pad_label
        else:
            if none_slot_label not in self.slots_label_ids:
                raise ValueError(f'none_slot_label {none_slot_label} not '
                                 f'found in {self.slot_dict_file}.')
            self.pad_label = self.slots_label_ids[none_slot_label]
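A short usage sketch for the descriptor above; the directory layout is an assumption based on the file names the constructor checks for.

# Hypothetical layout under data_dir:
#   dict.intents.csv, dict.slots.csv, train.tsv, train_slots.tsv, dev.tsv, ...
data_desc = IntentSlotDataDesc(data_dir='/tmp/assistant_data')
print(data_desc.num_intents, data_desc.num_slots, data_desc.pad_label)
# The computed weights can feed weighted losses for the two heads, e.g.
# torch.nn.CrossEntropyLoss(weight=torch.tensor(data_desc.intent_weights))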
Example #5
    def __init__(
        self,
        data_dir: str,
        modes: List[str] = ["train", "test", "dev"],
        none_slot_label: str = "O",
        pad_label: int = -1,
    ):
        if not if_exist(data_dir, ["dict.intents.csv", "dict.slots.csv"]):
            raise FileNotFoundError(
                "Make sure that your data follows the standard format "
                "supported by MultiLabelIntentSlotDataset. Your data must "
                "contain dict.intents.csv and dict.slots.csv.")

        self.data_dir = data_dir
        self.intent_dict_file = self.data_dir + "/dict.intents.csv"
        self.slot_dict_file = self.data_dir + "/dict.slots.csv"

        self.intents_label_ids = get_labels_to_labels_id_mapping(
            self.intent_dict_file)
        self.num_intents = len(self.intents_label_ids)
        self.slots_label_ids = get_labels_to_labels_id_mapping(
            self.slot_dict_file)
        self.num_slots = len(self.slots_label_ids)

        infold = self.data_dir
        for mode in modes:
            if not if_exist(self.data_dir, [f"{mode}.tsv"]):
                logging.info(f" Stats calculation for {mode} mode"
                             f" is skipped as {mode}.tsv was not found.")
                continue
            logging.info(f" Stats calculating for {mode} mode...")
            slot_file = f"{self.data_dir}/{mode}_slots.tsv"
            with open(slot_file, "r") as f:
                slot_lines = f.readlines()

            input_file = f"{self.data_dir}/{mode}.tsv"
            with open(input_file, "r") as f:
                input_lines = f.readlines()[1:]  # Skipping headers at index 0

            if len(slot_lines) != len(input_lines):
                raise ValueError(
                    "Make sure that the number of slot lines match the "
                    "number of intent lines. There should be a 1-1 "
                    "correspondence between every slot and intent lines.")

            dataset = list(zip(slot_lines, input_lines))

            raw_slots, raw_intents = [], []
            for slot_line, input_line in dataset:
                slot_list = [int(slot) for slot in slot_line.strip().split()]
                raw_slots.append(slot_list)
                parts = input_line.strip().split("\t")[1:][0]
                parts = list(map(int, parts.split(",")))
                parts = [
                    1 if label in parts else 0
                    for label in range(self.num_intents)
                ]
                raw_intents.append(tuple(parts))

            logging.info(f"Three most popular intents in {mode} mode:")
            total_intents, intent_label_freq, max_id = get_multi_label_stats(
                raw_intents, infold + f"/{mode}_intent_stats.tsv")

            merged_slots = itertools.chain.from_iterable(raw_slots)
            logging.info(f"Three most popular slots in {mode} mode:")
            slots_total, slots_label_freq, max_id = get_label_stats(
                merged_slots, infold + f"/{mode}_slot_stats.tsv")

            logging.info(f"Total Number of Intent Labels: {total_intents}")
            logging.info(f"Intent Label Frequencies: {intent_label_freq}")
            logging.info(f"Total Number of Slots: {slots_total}")
            logging.info(f"Slots Label Frequencies: {slots_label_freq}")

            if mode == "train":
                intent_weights_dict = get_freq_weights_bce_with_logits_loss(
                    intent_label_freq)
                logging.info(f"Intent Weights: {intent_weights_dict}")
                slot_weights_dict = get_freq_weights(slots_label_freq)
                logging.info(f"Slot Weights: {slot_weights_dict}")

        self.intent_weights = fill_class_weights(intent_weights_dict,
                                                 self.num_intents - 1)
        self.slot_weights = fill_class_weights(slot_weights_dict,
                                               self.num_slots - 1)

        if pad_label != -1:
            self.pad_label = pad_label
        else:
            if none_slot_label not in self.slots_label_ids:
                raise ValueError(f"none_slot_label {none_slot_label} not "
                                 f"found in {self.slot_dict_file}.")
            self.pad_label = self.slots_label_ids[none_slot_label]
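The multi-hot conversion in the loop above is the main difference from Example #4. A standalone sketch of the same transformation, with a hypothetical helper name:

def to_multi_hot(intent_field: str, num_intents: int) -> tuple:
    # '0,2' with num_intents=4 -> (1, 0, 1, 0)
    active = set(map(int, intent_field.split(',')))
    return tuple(1 if i in active else 0 for i in range(num_intents))

print(to_multi_hot('0,2', num_intents=4))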
Example #6
    def __init__(self,
                 data_dir: str,
                 modes: List[str] = ['train', 'test', 'dev'],
                 pad_label='O',
                 label_ids_dict=None):
        """A descriptor class that reads all the data and calculates some stats of the data and also calculates
        the class weights to be used for class balancing
        Args:
            data_dir: the path to the data folder
            modes: list of the modes to read, it can be from ["train", "test", "dev"] by default.
            It is going to look for the data files at {data_dir}/{mode}.txt
            label_ids_dict: labels to ids mapping from pretrained model
        """
        self.data_dir = data_dir
        self.label_ids = None
        unique_labels = set()

        for mode in modes:
            all_labels = []
            label_file = os.path.join(data_dir, 'labels_' + mode + '.txt')
            if not os.path.exists(label_file):
                logging.info(
                    f'Stats calculation for {mode} mode is skipped as {label_file} was not found.'
                )
                continue

            with open(label_file, 'r') as f:
                for line in f:
                    line = line.strip().split()
                    all_labels.extend(line)
                    unique_labels.update(line)

            if mode == 'train':
                label_ids = {pad_label: 0}
                if pad_label in unique_labels:
                    unique_labels.remove(pad_label)
                for label in sorted(unique_labels):
                    label_ids[label] = len(label_ids)

                self.pad_label = pad_label
                if label_ids_dict:
                    # Every label found in the data must already be present
                    # in the provided mapping.
                    if len(set(label_ids_dict)
                           | set(label_ids)) != len(label_ids_dict):
                        raise ValueError(
                            f'Provided labels to ids map: {label_ids_dict} does not match the labels '
                            f'in the data: {label_ids}')
                self.label_ids = label_ids_dict if label_ids_dict else label_ids
                logging.info(f'Labels: {self.label_ids}')
                self.label_ids_filename = os.path.join(data_dir,
                                                       'label_ids.csv')
                with open(self.label_ids_filename, 'w') as out:
                    labels, _ = zip(
                        *sorted(self.label_ids.items(), key=lambda x: x[1]))
                    out.write('\n'.join(labels))
                logging.info(
                    f'Labels mapping saved to: {self.label_ids_filename}')

            all_labels = [self.label_ids[label] for label in all_labels]
            logging.info(f'Three most popular labels in {mode} dataset:')
            total_labels, label_frequencies, max_id = get_label_stats(
                all_labels, os.path.join(data_dir, mode + '_label_stats.tsv'))

            logging.info(f'Total labels: {total_labels}')
            logging.info(f'Label frequencies - {label_frequencies}')

            if mode == 'train':
                class_weights_dict = get_freq_weights(label_frequencies)
                logging.info(f'Class Weights: {class_weights_dict}')
                self.class_weights = fill_class_weights(
                    class_weights_dict, max_id)
                self.num_classes = max_id + 1
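A usage sketch for this descriptor; the enclosing class name is not shown in this listing, so DataDesc below is a placeholder, and the paths are illustrative.

# Expects files such as {data_dir}/labels_train.txt and {data_dir}/labels_dev.txt.
desc = DataDesc(data_dir='/tmp/ner_data', modes=['train', 'dev'], pad_label='O')
print(desc.label_ids)        # e.g. {'O': 0, 'B-LOC': 1, 'B-PER': 2, ...}
print(desc.class_weights)    # list aligned with label ids (train mode only)
print(desc.num_classes)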