def calc_class_weights_from_dataloader(
    dataloader: 'torch.utils.data.DataLoader', num_classes: int, data_dir: str
) -> List[float]:
    """
    Calculate the weights of each class to be used for weighted loss.

    This is similar to the function calc_class_weights in
    text_classification_dataset, but it gets the labels from a dataloader
    rather than from a file.

    Args:
        dataloader: the dataloader for the training set; the last element of
            each batch is expected to hold the label tensor
        num_classes: number of classes in the dataset
        data_dir: directory where the label-frequency stats file
            ('sentence_stats.tsv') is written

    Returns:
        A list of per-class weights of length ``num_classes``.

    Raises:
        ValueError: if any label falls outside [0, num_classes - 1].
    """
    labels = []
    for batch in dataloader:
        # Labels are taken from the last tensor of each batch.
        labels.extend(tensor2list(batch[-1]))

    # Plain string: the original used an f-string with no placeholders.
    logging.info('Calculating label frequency stats...')
    total_sents, sent_label_freq, max_id = get_label_stats(
        labels, os.path.join(data_dir, 'sentence_stats.tsv'), verbose=False
    )

    if max_id >= num_classes:
        raise ValueError('Found an invalid label! Labels should be from [0, num_classes-1].')

    class_weights_dict = get_freq_weights(sent_label_freq)
    logging.info(f'Total Sentence Pairs: {total_sents}')
    logging.info(f'Class Frequencies: {sent_label_freq}')
    logging.info(f'Class Weights: {class_weights_dict}')
    class_weights = fill_class_weights(weights=class_weights_dict, max_id=num_classes - 1)
    return class_weights
def calc_class_weights(file_path: str, num_classes: int):
    """
    Iterates over a data file and calculates the weights of each class to be
    used for class balancing.

    Each line is expected to end with an integer class label as its last
    whitespace-separated token.

    Args:
        file_path: path to the data file
        num_classes: number of classes in the dataset

    Returns:
        A list of per-class weights of length ``num_classes``.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: if a line carries no trailing integer label, or if any
            label falls outside [0, num_classes - 1].
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Could not find data file {file_path} to calculate the class weights!"
        )

    labels = []
    # Stream the file line by line instead of loading it all with readlines().
    with open(file_path, 'r') as f:
        for input_line in f:
            parts = input_line.strip().split()
            try:
                label = int(parts[-1])
            except (ValueError, IndexError) as err:
                # IndexError covers blank lines (parts is empty); chain the
                # original error so the root cause stays visible.
                raise ValueError(
                    f'No numerical labels found for {file_path}. Labels should be integers and separated by [TAB] at the end of each line.'
                ) from err
            labels.append(label)

    logging.info(f'Calculating stats of {file_path}...')
    total_sents, sent_label_freq, max_id = get_label_stats(
        labels, f'{file_path}_sentence_stats.tsv', verbose=False
    )
    if max_id >= num_classes:
        raise ValueError(
            f'Found an invalid label in {file_path}! Labels should be from [0, num_classes-1].'
        )

    class_weights_dict = get_freq_weights(sent_label_freq)
    logging.info(f'Total Sentence: {total_sents}')
    logging.info(f'Sentence class frequencies: {sent_label_freq}')
    logging.info(f'Class Weights: {class_weights_dict}')
    class_weights = fill_class_weights(weights=class_weights_dict, max_id=num_classes - 1)
    return class_weights
def get_label_ids(
    label_file: str,
    is_training: bool = False,
    pad_label: str = 'O',
    label_ids_dict: Dict[str, int] = None,
    get_weights: bool = True,
    class_labels_file_artifact='label_ids.csv',
):
    """
    Generates str to int labels mapping for training data or checks correctness
    of the label_ids_dict file for non-training files or if label_ids_dict is
    specified

    Args:
        label_file: the path of the label file to process
        is_training: indicates whether the label_file is used for training
        pad_label: token used for padding
        label_ids_dict: str label name to int ids mapping. Required for
            non-training data. If specified, the check that all labels from
            label_file are present in label_ids_dict will be performed. For
            training data, if label_ids_dict is None, a new mapping will be
            generated from label_file.
        get_weights: set to True to calculate class weights, required for
            Weighted Loss.
        class_labels_file_artifact: name of the file to save in .nemo

    Returns:
        Tuple of (label_ids_dict, label_ids_filename, class_weights), where
        class_weights is None when get_weights is False.
    """
    if not os.path.exists(label_file):
        raise ValueError(f'File {label_file} was not found.')

    logging.info(f'Processing {label_file}')
    if not is_training and label_ids_dict is None:
        raise ValueError(
            f'For non training data, label_ids_dict created during preprocessing of the training data '
            f'should be provided'
        )

    # collect all labels from the label_file
    data_dir = os.path.dirname(label_file)
    # Fix: the original `set(pad_label)` iterated the string and seeded the set
    # with its individual characters (wrong for multi-character pad labels);
    # seed with the pad label itself instead.
    unique_labels = {pad_label}
    all_labels = []
    with open(label_file, 'r') as f:
        for line in f:
            line = line.strip().split()
            all_labels.extend(line)
            unique_labels.update(line)

    # check that all labels from label_file are present in the specified
    # label_ids_dict or generate label_ids_dict from data (for training only)
    save_label_ids = True
    if label_ids_dict:
        logging.info(f'Using provided labels mapping {label_ids_dict}')
        save_label_ids = False
        for name in unique_labels:
            if name not in label_ids_dict:
                raise ValueError(
                    f'{name} class from {label_file} not found in the provided mapping: {label_ids_dict}'
                )
    else:
        # pad label always gets id 0; remaining labels are ordered alphabetically
        label_ids_dict = {pad_label: 0}
        if pad_label in unique_labels:
            unique_labels.remove(pad_label)
        for label in sorted(unique_labels):
            label_ids_dict[label] = len(label_ids_dict)

    label_ids_filename = os.path.join(data_dir, class_labels_file_artifact)
    if is_training and save_label_ids:
        with open(label_ids_filename, 'w') as f:
            labels, _ = zip(*sorted(label_ids_dict.items(), key=lambda x: x[1]))
            f.write('\n'.join(labels))
        logging.info(f'Labels mapping {label_ids_dict} saved to : {label_ids_filename}')

    # calculate label statistics
    base_name = os.path.splitext(os.path.basename(label_file))[0]
    stats_file = os.path.join(data_dir, f'{base_name}_label_stats.tsv')
    if os.path.exists(stats_file) and not is_training and not get_weights:
        logging.info(f'{stats_file} found, skipping stats calculation.')
    else:
        all_labels = [label_ids_dict[label] for label in all_labels]
        logging.info(f'Three most popular labels in {label_file}:')
        total_labels, label_frequencies, max_id = get_label_stats(all_labels, stats_file)
        logging.info(f'Total labels: {total_labels}. Label frequencies - {label_frequencies}')

    if get_weights:
        class_weights_pkl = os.path.join(data_dir, f'{base_name}_weights.p')
        if os.path.exists(class_weights_pkl):
            # Fix: use a context manager so the pickle file handle is closed
            # (the original passed a bare open() to pickle.load).
            with open(class_weights_pkl, 'rb') as pkl_file:
                class_weights = pickle.load(pkl_file)
            logging.info(f'Class weights restored from {class_weights_pkl}')
        else:
            class_weights_dict = get_freq_weights(label_frequencies)
            logging.info(f'Class Weights: {class_weights_dict}')
            class_weights = fill_class_weights(class_weights_dict, max_id)
            with open(class_weights_pkl, "wb") as pkl_file:
                pickle.dump(class_weights, pkl_file)
            logging.info(f'Class weights saved to {class_weights_pkl}')
    else:
        class_weights = None

    return label_ids_dict, label_ids_filename, class_weights
def __init__(
    self,
    data_dir: str,
    modes: List[str] = ['train', 'test', 'dev'],
    none_slot_label: str = 'O',
    pad_label: int = -1,
):
    """Read joint intent/slot data files, log per-mode label statistics, and
    compute intent/slot class weights from the training split.

    Args:
        data_dir: folder containing dict.intents.csv, dict.slots.csv and the
            per-mode {mode}.tsv / {mode}_slots.tsv files
        modes: dataset splits to process; missing splits are skipped
        none_slot_label: slot label whose id is used for padding when
            pad_label is not given
        pad_label: explicit padding label id; -1 means "use none_slot_label"
    """
    if not if_exist(data_dir, ['dict.intents.csv', 'dict.slots.csv']):
        raise FileNotFoundError(
            "Make sure that your data follows the standard format "
            "supported by JointIntentSlotDataset. Your data must "
            "contain dict.intents.csv and dict.slots.csv.")

    self.data_dir = data_dir
    self.intent_dict_file = self.data_dir + '/dict.intents.csv'
    self.slot_dict_file = self.data_dir + '/dict.slots.csv'

    self.intents_label_ids = IntentSlotDataDesc.label2idx(self.intent_dict_file)
    self.num_intents = len(self.intents_label_ids)
    self.slots_label_ids = IntentSlotDataDesc.label2idx(self.slot_dict_file)
    self.num_slots = len(self.slots_label_ids)

    for mode in modes:
        if not if_exist(self.data_dir, [f'{mode}.tsv']):
            logging.info(f' Stats calculation for {mode} mode'
                         f' is skipped as {mode}.tsv was not found.')
            continue
        logging.info(f' Stats calculating for {mode} mode...')

        with open(f'{self.data_dir}/{mode}_slots.tsv', 'r') as slots_fh:
            slot_lines = slots_fh.readlines()
        with open(f'{self.data_dir}/{mode}.tsv', 'r') as inputs_fh:
            input_lines = inputs_fh.readlines()[1:]  # Skipping headers at index 0

        if len(slot_lines) != len(input_lines):
            raise ValueError(
                "Make sure that the number of slot lines match the "
                "number of intent lines. There should be a 1-1 "
                "correspondence between every slot and intent lines.")

        raw_slots = []
        raw_intents = []
        for slot_line, input_line in zip(slot_lines, input_lines):
            # Slot ids for every token, intent id from the line's last column.
            raw_slots.append([int(tok) for tok in slot_line.strip().split()])
            raw_intents.append(int(input_line.strip().split()[-1]))

        logging.info(f'Three most popular intents in {mode} mode:')
        total_intents, intent_label_freq, max_id = get_label_stats(
            raw_intents, self.data_dir + f'/{mode}_intent_stats.tsv')

        logging.info(f'Three most popular slots in {mode} mode:')
        slots_total, slots_label_freq, max_id = get_label_stats(
            itertools.chain.from_iterable(raw_slots),
            self.data_dir + f'/{mode}_slot_stats.tsv')

        logging.info(f'Total Number of Intents: {total_intents}')
        logging.info(f'Intent Label Frequencies: {intent_label_freq}')
        logging.info(f'Total Number of Slots: {slots_total}')
        logging.info(f'Slots Label Frequencies: {slots_label_freq}')

        if mode == 'train':
            intent_weights_dict = get_freq_weights(intent_label_freq)
            logging.info(f'Intent Weights: {intent_weights_dict}')
            slot_weights_dict = get_freq_weights(slots_label_freq)
            logging.info(f'Slot Weights: {slot_weights_dict}')
            self.intent_weights = fill_class_weights(intent_weights_dict, self.num_intents - 1)
            self.slot_weights = fill_class_weights(slot_weights_dict, self.num_slots - 1)

    if pad_label == -1:
        # Fall back to the id of the designated "none" slot label.
        if none_slot_label not in self.slots_label_ids:
            raise ValueError(f'none_slot_label {none_slot_label} not '
                             f'found in {self.slot_dict_file}.')
        self.pad_label = self.slots_label_ids[none_slot_label]
    else:
        self.pad_label = pad_label
def __init__(
    self,
    data_dir: str,
    modes: List[str] = ["train", "test", "dev"],
    none_slot_label: str = "O",
    pad_label: int = -1,
):
    """Read multi-label intent/slot data files, log per-mode label statistics,
    and compute intent/slot class weights from the training split.

    Args:
        data_dir: folder containing dict.intents.csv, dict.slots.csv and the
            per-mode {mode}.tsv / {mode}_slots.tsv files
        modes: dataset splits to process; missing splits are skipped
        none_slot_label: slot label whose id is used for padding when
            pad_label is not given
        pad_label: explicit padding label id; -1 means "use none_slot_label"
    """
    if not if_exist(data_dir, ["dict.intents.csv", "dict.slots.csv"]):
        raise FileNotFoundError(
            "Make sure that your data follows the standard format "
            "supported by MultiLabelIntentSlotDataset. Your data must "
            "contain dict.intents.csv and dict.slots.csv.")

    self.data_dir = data_dir
    self.intent_dict_file = self.data_dir + "/dict.intents.csv"
    self.slot_dict_file = self.data_dir + "/dict.slots.csv"

    self.intents_label_ids = get_labels_to_labels_id_mapping(self.intent_dict_file)
    self.num_intents = len(self.intents_label_ids)
    self.slots_label_ids = get_labels_to_labels_id_mapping(self.slot_dict_file)
    self.num_slots = len(self.slots_label_ids)

    for mode in modes:
        if not if_exist(self.data_dir, [f"{mode}.tsv"]):
            logging.info(f" Stats calculation for {mode} mode"
                         f" is skipped as {mode}.tsv was not found.")
            continue
        logging.info(f" Stats calculating for {mode} mode...")

        with open(f"{self.data_dir}/{mode}_slots.tsv", "r") as slots_fh:
            slot_lines = slots_fh.readlines()
        with open(f"{self.data_dir}/{mode}.tsv", "r") as inputs_fh:
            input_lines = inputs_fh.readlines()[1:]  # Skipping headers at index 0

        if len(slot_lines) != len(input_lines):
            raise ValueError(
                "Make sure that the number of slot lines match the "
                "number of intent lines. There should be a 1-1 "
                "correspondence between every slot and intent lines.")

        raw_slots = []
        raw_intents = []
        for slot_line, input_line in zip(slot_lines, input_lines):
            raw_slots.append([int(tok) for tok in slot_line.strip().split()])
            # The second tab-separated column holds a comma-separated list of
            # active intent ids; expand it into a 0/1 multi-hot tuple.
            intent_field = input_line.strip().split("\t")[1]
            active_ids = {int(tok) for tok in intent_field.split(",")}
            raw_intents.append(
                tuple(1 if idx in active_ids else 0 for idx in range(self.num_intents))
            )

        logging.info(f"Three most popular intents in {mode} mode:")
        total_intents, intent_label_freq, max_id = get_multi_label_stats(
            raw_intents, self.data_dir + f"/{mode}_intent_stats.tsv")

        logging.info(f"Three most popular slots in {mode} mode:")
        slots_total, slots_label_freq, max_id = get_label_stats(
            itertools.chain.from_iterable(raw_slots),
            self.data_dir + f"/{mode}_slot_stats.tsv")

        logging.info(f"Total Number of Intent Labels: {total_intents}")
        logging.info(f"Intent Label Frequencies: {intent_label_freq}")
        logging.info(f"Total Number of Slots: {slots_total}")
        logging.info(f"Slots Label Frequencies: {slots_label_freq}")

        if mode == "train":
            # Intents use a BCE-with-logits weighting; slots use frequency weights.
            intent_weights_dict = get_freq_weights_bce_with_logits_loss(intent_label_freq)
            logging.info(f"Intent Weights: {intent_weights_dict}")
            slot_weights_dict = get_freq_weights(slots_label_freq)
            logging.info(f"Slot Weights: {slot_weights_dict}")
            self.intent_weights = fill_class_weights(intent_weights_dict, self.num_intents - 1)
            self.slot_weights = fill_class_weights(slot_weights_dict, self.num_slots - 1)

    if pad_label == -1:
        # Fall back to the id of the designated "none" slot label.
        if none_slot_label not in self.slots_label_ids:
            raise ValueError(f"none_slot_label {none_slot_label} not "
                             f"found in {self.slot_dict_file}.")
        self.pad_label = self.slots_label_ids[none_slot_label]
    else:
        self.pad_label = pad_label
def __init__(self,
             data_dir: str,
             modes: List[str] = ['train', 'test', 'dev'],
             pad_label='O',
             label_ids_dict=None):
    """A descriptor class that reads all the data and calculates some stats of
    the data and also calculates the class weights to be used for class balancing

    Args:
        data_dir: the path to the data folder
        modes: list of the modes to read, it can be from ["train", "test", "dev"]
            by default. It is going to look for the data files at
            {data_dir}/labels_{mode}.txt (the original docstring wrongly said
            {data_dir}/{mode}.txt)
        pad_label: label used for padding; always mapped to id 0 for training data
        label_ids_dict: labels to ids mapping from pretrained model
    """
    self.data_dir = data_dir
    self.label_ids = None
    unique_labels = set()

    for mode in modes:
        all_labels = []
        label_file = os.path.join(data_dir, 'labels_' + mode + '.txt')
        if not os.path.exists(label_file):
            logging.info(
                f'Stats calculation for {mode} mode is skipped as {label_file} was not found.'
            )
            continue

        with open(label_file, 'r') as f:
            for line in f:
                line = line.strip().split()
                all_labels.extend(line)
                unique_labels.update(line)

        if mode == 'train':
            # Build the label->id mapping: pad label gets id 0, the rest are
            # assigned alphabetically.
            label_ids = {pad_label: 0}
            if pad_label in unique_labels:
                unique_labels.remove(pad_label)
            for label in sorted(unique_labels):
                label_ids[label] = len(label_ids)
            self.pad_label = pad_label

            if label_ids_dict:
                if len(set(label_ids_dict) | set(label_ids)) != len(label_ids_dict):
                    raise ValueError(
                        f'Provided labels to ids map: {label_ids_dict} does not match the labels '
                        f'in the data: {label_ids}')

            self.label_ids = label_ids_dict if label_ids_dict else label_ids
            logging.info(f'Labels: {self.label_ids}')
            self.label_ids_filename = os.path.join(data_dir, 'label_ids.csv')
            # Fix: the original opened this file without ever closing it;
            # use a context manager so the handle is released deterministically.
            with open(self.label_ids_filename, 'w') as out:
                labels, _ = zip(*sorted(self.label_ids.items(), key=lambda x: x[1]))
                out.write('\n'.join(labels))
            logging.info(f'Labels mapping saved to : {out.name}')

        # NOTE(review): if 'train' is not among the processed modes,
        # self.label_ids is still None here and this lookup fails —
        # presumably 'train' is always present and first; verify with callers.
        all_labels = [self.label_ids[label] for label in all_labels]
        logging.info(f'Three most popular labels in {mode} dataset:')
        total_labels, label_frequencies, max_id = get_label_stats(
            all_labels, os.path.join(data_dir, mode + '_label_stats.tsv'))
        logging.info(f'Total labels: {total_labels}')
        logging.info(f'Label frequencies - {label_frequencies}')

        if mode == 'train':
            class_weights_dict = get_freq_weights(label_frequencies)
            logging.info(f'Class Weights: {class_weights_dict}')
            self.class_weights = fill_class_weights(class_weights_dict, max_id)
            self.num_classes = max_id + 1