punct_classifier = TokenClassifier(
    hidden_size=bert_model.hidden_size,
    num_classes=len(punct_label_ids),
    dropout=CLASSIFICATION_DROPOUT,
    num_layers=PUNCT_NUM_FC_LAYERS,
    name='Punctuation',
)
capit_classifier = TokenClassifier(
    hidden_size=bert_model.hidden_size,
    num_classes=len(capit_label_ids),
    dropout=CLASSIFICATION_DROPOUT,
    name='Capitalization',
)

# If you don't want to use a weighted loss for the punctuation task, use class_weights=None
punct_label_freqs = train_data_layer.dataset.punct_label_frequencies
class_weights = calc_class_weights(punct_label_freqs)

# define losses
punct_loss = CrossEntropyLossNM(logits_ndim=3, weight=class_weights)
capit_loss = CrossEntropyLossNM(logits_ndim=3)
task_loss = LossAggregatorNM(num_inputs=2)

input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, punct_labels, capit_labels = train_data_layer()

hidden_states = bert_model(
    input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
)

punct_logits = punct_classifier(hidden_states=hidden_states)
capit_logits = capit_classifier(hidden_states=hidden_states)

punct_loss = punct_loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
# Wire up the remaining losses so the aggregated task loss can be trained on.
capit_loss = capit_loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
task_loss = task_loss(loss_1=punct_loss, loss_2=capit_loss)
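# --- Hypothetical illustration (not part of the original script) ---
# calc_class_weights is imported from NeMo and its implementation is not shown
# here. A common scheme, and a plausible reading of what it does, is
# inverse-frequency weighting; the stand-in below sketches that idea only.
from typing import Dict, List

def inverse_frequency_weights(label_freqs: Dict[int, int]) -> List[float]:
    """Weight each class by total/(num_classes * count) so rare labels count more."""
    total = sum(label_freqs.values())
    num_classes = len(label_freqs)
    return [total / (num_classes * label_freqs[c]) for c in sorted(label_freqs)]

# Example: the 'O' (no punctuation) label dominates a typical corpus.
print(inverse_frequency_weights({0: 9000, 1: 700, 2: 300}))
# -> [0.370..., 4.761..., 11.111...]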
encoder = EncoderRNN(vocab_size, args.emb_dim, args.hid_dim, args.dropout, args.n_layers)
decoder = TRADEGenerator(
    data_desc.vocab,
    encoder.embedding,
    args.hid_dim,
    args.dropout,
    data_desc.slots,
    len(data_desc.gating_dict),
    teacher_forcing=args.teacher_forcing,
)

gate_loss_fn = CrossEntropyLossNM(logits_ndim=3)
ptr_loss_fn = MaskedLogLoss()
total_loss_fn = LossAggregatorNM(num_inputs=2)


def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefix, is_training):
    logging.info(f"Loading {data_prefix} data...")
    shuffle = args.shuffle_data if is_training else False
    data_layer = MultiWOZDataLayer(
        abs_data_dir,
        data_desc.domains,
        all_domains=data_desc.all_domains,
        vocab=data_desc.vocab,
        slots=data_desc.slots,
        gating_dict=data_desc.gating_dict,
        num_samples=num_samples,
        # ... (remaining data-layer arguments continue in the original script)
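# --- Hypothetical illustration (not part of the original script) ---
# MaskedLogLoss (the pointer loss above) is provided by NeMo; its exact
# implementation is not shown here. A plausible reading is a padding-masked
# negative log-likelihood over the generated slot-value tokens. The tensor
# shapes below are assumptions for illustration only.
import torch

def masked_log_loss(log_probs, targets, length_mask):
    """Mean NLL over real (non-padded) target positions.

    log_probs:   (batch, slots, seq_len, vocab) log-softmax outputs
    targets:     (batch, slots, seq_len) gold token ids (int64)
    length_mask: (batch, slots, seq_len) 1.0 for real tokens, 0.0 for padding
    """
    nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return (nll * length_mask).sum() / length_mask.sum().clamp(min=1.0)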
def create_pipeline(
    pad_label=args.none_label,
    max_seq_length=args.max_seq_length,
    batch_size=args.batch_size,
    num_gpus=args.num_gpus,
    mode='train',
    punct_label_ids=None,
    capit_label_ids=None,
    ignore_extra_tokens=args.ignore_extra_tokens,
    ignore_start_end=args.ignore_start_end,
    overwrite_processed_files=args.overwrite_processed_files,
    dropout=args.fc_dropout,
    punct_num_layers=args.punct_num_fc_layers,
    capit_num_layers=args.capit_num_fc_layers,
    classifier=PunctCapitTokenClassifier,
):
    logging.info(f"Loading {mode} data...")
    shuffle = args.shuffle_data if mode == 'train' else False

    text_file = f'{args.data_dir}/text_{mode}.txt'
    label_file = f'{args.data_dir}/labels_{mode}.txt'

    if not (os.path.exists(text_file) and os.path.exists(label_file)):
        raise FileNotFoundError(
            f'{text_file} or {label_file} not found. The data should be split '
            'into 2 files: text.txt and labels.txt. Each line of the text.txt '
            'file contains text sequences, where words are separated with '
            'spaces. The labels.txt file contains corresponding labels for '
            'each word in text.txt; the labels are separated with spaces. '
            'Each line of the files should follow the format: '
            '[WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and '
            '[LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).'
        )

    data_layer = PunctuationCapitalizationDataLayer(
        tokenizer=tokenizer,
        text_file=text_file,
        label_file=label_file,
        pad_label=pad_label,
        punct_label_ids=punct_label_ids,
        capit_label_ids=capit_label_ids,
        max_seq_length=max_seq_length,
        batch_size=batch_size,
        shuffle=shuffle,
        ignore_extra_tokens=ignore_extra_tokens,
        ignore_start_end=ignore_start_end,
        overwrite_processed_files=overwrite_processed_files,
        num_workers=args.num_workers,
        pin_memory=args.enable_pin_memory,
    )

    (input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask,
     punct_labels, capit_labels) = data_layer()

    if mode == 'train':
        punct_label_ids = data_layer.dataset.punct_label_ids
        capit_label_ids = data_layer.dataset.capit_label_ids

        class_weights = None
        if args.use_weighted_loss_punct:
            logging.info("Using weighted loss for punctuation task")
            punct_label_freqs = data_layer.dataset.punct_label_frequencies
            class_weights = calc_class_weights(punct_label_freqs)

        classifier = classifier(
            hidden_size=hidden_size,
            punct_num_classes=len(punct_label_ids),
            capit_num_classes=len(capit_label_ids),
            dropout=dropout,
            punct_num_layers=punct_num_layers,
            capit_num_layers=capit_num_layers,
        )

        punct_loss = CrossEntropyLossNM(logits_ndim=3, weight=class_weights)
        capit_loss = CrossEntropyLossNM(logits_ndim=3)
        task_loss = LossAggregatorNM(
            num_inputs=2, weights=[args.punct_loss_weight, 1.0 - args.punct_loss_weight]
        )

    hidden_states = model(
        input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
    )
    punct_logits, capit_logits = classifier(hidden_states=hidden_states)

    if mode == 'train':
        punct_loss = punct_loss(logits=punct_logits, labels=punct_labels, loss_mask=loss_mask)
        capit_loss = capit_loss(logits=capit_logits, labels=capit_labels, loss_mask=loss_mask)
        task_loss = task_loss(loss_1=punct_loss, loss_2=capit_loss)

        steps_per_epoch = len(data_layer) // (batch_size * num_gpus)
        losses = [task_loss, punct_loss, capit_loss]
        logits = [punct_logits, capit_logits]
        return losses, logits, steps_per_epoch, punct_label_ids, capit_label_ids, classifier
    else:
        tensors_to_evaluate = [
            punct_logits, capit_logits, punct_labels, capit_labels, subtokens_mask
        ]
        return tensors_to_evaluate, data_layer
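# --- Hypothetical usage sketch (not part of the original excerpt) ---
# create_pipeline is typically called twice: once for training, which builds
# the label maps and instantiates the classifier, and once for evaluation,
# which reuses both. The 'dev' mode value (-> text_dev.txt / labels_dev.txt)
# and the variable names here are assumptions.
(losses, train_logits, steps_per_epoch,
 punct_label_ids, capit_label_ids, classifier) = create_pipeline()

eval_tensors, eval_data_layer = create_pipeline(
    mode='dev',
    punct_label_ids=punct_label_ids,
    capit_label_ids=capit_label_ids,
    classifier=classifier,
)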
classifier = JointIntentSlotClassifier(
    hidden_size=hidden_size,
    num_intents=data_desc.num_intents,
    num_slots=data_desc.num_slots,
    dropout=args.fc_dropout,
)

if args.class_balancing == 'weighted_loss':
    # To tackle imbalanced classes, you may use a weighted loss
    intent_loss_fn = CrossEntropyLossNM(logits_ndim=2, weight=data_desc.intent_weights)
    slot_loss_fn = CrossEntropyLossNM(logits_ndim=3, weight=data_desc.slot_weights)
else:
    intent_loss_fn = CrossEntropyLossNM(logits_ndim=2)
    slot_loss_fn = CrossEntropyLossNM(logits_ndim=3)

total_loss_fn = LossAggregatorNM(
    num_inputs=2, weights=[args.intent_loss_weight, 1.0 - args.intent_loss_weight]
)


def create_pipeline(num_samples=-1, batch_size=32, data_prefix='train', is_training=True, num_gpus=1):
    logging.info(f"Loading {data_prefix} data...")
    data_file = f'{data_desc.data_dir}/{data_prefix}.tsv'
    slot_file = f'{data_desc.data_dir}/{data_prefix}_slots.tsv'
    shuffle = args.shuffle_data if is_training else False

    data_layer = BertJointIntentSlotDataLayer(
        input_file=data_file,
        # ... (remaining data-layer arguments continue in the original script)
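# --- Hypothetical illustration (not part of the original script) ---
# LossAggregatorNM with weights=[w, 1 - w] forms a convex combination of the
# two task losses, so args.intent_loss_weight trades intent accuracy off
# against slot accuracy. A plain-PyTorch sketch of that arithmetic (the
# function name is illustrative):
import torch

def aggregate_losses(intent_loss, slot_loss, intent_weight):
    """total = w * intent_loss + (1 - w) * slot_loss"""
    return intent_weight * intent_loss + (1.0 - intent_weight) * slot_loss

# With intent_loss_weight=0.6: 0.6 * 1.2 + 0.4 * 0.5 = 0.92
total = aggregate_losses(torch.tensor(1.2), torch.tensor(0.5), 0.6)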