Example #1
0
MAX_LEN = 128

mask_ids_test = []

# NOTE(review): padded_ids_test / padded_ids_train / mask_ids_train are expected to be
# created (and the train lists filled) earlier in the script -- confirm.
for sentence in tqdm(test_sentences):
    # truncation=True guarantees every row is exactly MAX_LEN long, so the lists
    # below stack into rectangular 2-D arrays; do_lower_case is a tokenizer
    # construction option and is not accepted by encode_plus.
    encoding = tokenizer.encode_plus(
        sentence, max_length=MAX_LEN, padding='max_length', truncation=True
    )
    padded_ids_test.append(encoding["input_ids"])
    mask_ids_test.append(encoding["attention_mask"])

train_id = np.array(padded_ids_train)
train_mask = np.array(mask_ids_train)
test_id = np.array(padded_ids_test)
test_mask = np.array(mask_ids_test)

input_1 = tf.keras.Input(shape=(MAX_LEN,), dtype=np.int32)
input_2 = tf.keras.Input(shape=(MAX_LEN,), dtype=np.int32)
transformer = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
# NOTE(review): output[0] is the output of the sequence-classification head, not the
# per-token hidden states; stacking Dense layers on it is unusual -- confirm intent.
output = transformer([input_1, input_2], training=False)
x = tf.keras.layers.Dense(128, activation=tf.nn.relu)(output[0])
x = tf.keras.layers.Dropout(0.15)(x)
# Six independent sigmoid outputs trained with binary cross-entropy (multi-label setup).
x = tf.keras.layers.Dense(6, activation=tf.nn.sigmoid)(x)
model = tf.keras.Model(inputs=[input_1, input_2], outputs=[x])
model.summary()

path = "model_distilBERT.h5"
# Keep only the weights of the epoch with the highest validation precision.
checkpoint = ModelCheckpoint(filepath=path, monitor='val_precision', verbose=1,
                             save_best_only=True, mode='max', save_weights_only=True)
model.compile(optimizer=Adam(learning_rate=3e-5),
              loss=tf.keras.losses.binary_crossentropy,
              metrics=[tf.keras.metrics.Precision()])

# Training model...
history = model.fit([train_id, train_mask], y_train, batch_size=32, epochs=5,
                    callbacks=[checkpoint], validation_split=0.1)

# Loading model...
Example #2
0
                select_data_and_label_from_record)

        tokenizer = None
        config = None
        model = None

        # This is required when launching many instances at once...  the urllib request seems to get denied periodically
        successful_download = False
        retries = 0
        while (retries < 5 and not successful_download):
            try:
                tokenizer = DistilBertTokenizer.from_pretrained(
                    'distilbert-base-uncased')
                config = DistilBertConfig.from_pretrained(
                    'distilbert-base-uncased', num_labels=len(CLASSES))
                model = TFDistilBertForSequenceClassification.from_pretrained(
                    'distilbert-base-uncased', config=config)
                successful_download = True
                print(
                    'Sucessfully downloaded after {} retries.'.format(retries))
            except:
                retries = retries + 1
                random_sleep = random.randint(1, 30)
                print('Retry #{}.  Sleeping for {} seconds'.format(
                    retries, random_sleep))
                time.sleep(random_sleep)

        if not tokenizer or not model or not config:
            print('Not properly initialized...')

        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
        if use_amp:
Example #3
0
def main():

    parser = argparse.ArgumentParser(
        description=
        'Script for running text topic classification with transformers package'
    )
    parser.add_argument(
        '-m',
        '--model',
        choices=[
            'bert-base-uncased', 'bert-large-uncased', 'roberta-base',
            'roberta-large', 'distilbert-base-uncased',
            'google/electra-base-discriminator'
        ],
        help='Class of Model Architecture to use for classification')
    parser.add_argument('-b',
                        '--BATCH_SIZE',
                        default=64,
                        type=int,
                        help='batch size to use per replica')
    parser.add_argument(
        '-l',
        '--SEQUENCE_LENGTH',
        default=128,
        type=int,
        help=
        'maximum sequence length. short sequences are padded. long are truncated'
    )
    parser.add_argument(
        '-e',
        '--EPOCHS',
        default=5,
        type=int,
        help=
        'the number of passes over the dataset to run. early stopping with 2 epoch patience is used'
    )

    args = parser.parse_args()

    if args.model[:4] == 'robe':
        # Use Roberta tokenizer
        TOKENIZER = RobertaTokenizer.from_pretrained(args.model)
    else:
        # Use Bert tokenizer
        TOKENIZER = BertTokenizer.from_pretrained(args.model)

    train_sentences, train_labels = gather_data(TRAINING_DATA)
    val_sentences, val_labels = gather_data(VAL_DATA)

    print(f'Length of Training Set: {len(train_sentences)}')
    print(f'Length of Test Set: {len(val_sentences)}')

    training_dataset = create_dataset(train_sentences, train_labels,
                                      args.SEQUENCE_LENGTH, TOKENIZER)
    val_dataset = create_dataset(val_sentences, val_labels,
                                 args.SEQUENCE_LENGTH, TOKENIZER)

    print(f'Maximum Sequence Length: {args.SEQUENCE_LENGTH}')

    mirrored_strategy = tf.distribute.MirroredStrategy()
    print(f'Number of devices: {mirrored_strategy.num_replicas_in_sync}')

    BATCH_SIZE_PER_REPLICA = args.BATCH_SIZE
    GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
    print(f'Global Batch Size: {GLOBAL_BATCH_SIZE}')

    batched_training_dataset = training_dataset.shuffle(1024).batch(
        GLOBAL_BATCH_SIZE, drop_remainder=True)
    batched_val_dataset = val_dataset.shuffle(1024).batch(GLOBAL_BATCH_SIZE,
                                                          drop_remainder=True)

    #dist_train_dataset = mirrored_strategy.experimental_distribute_dataset(batched_training_dataset)
    #dist_val_dataset = mirrored_strategy.experimental_distribute_dataset(batched_val_dataset)

    with mirrored_strategy.scope():
        if args.model[:4] == 'bert':
            model = TFBertForSequenceClassification.from_pretrained(
                args.model, num_labels=4)
        elif args.model[:4] == 'robe':
            model = TFRobertaForSequenceClassification.from_pretrained(
                args.model, num_labels=4)
        elif args.model[:5] == 'distil':
            model = TFDistilBertForSequenceClassification.from_pretrained(
                args.model, num_labels=4)
        else:
            model = TFElectraForSequenceClassification.from_pretrained(
                args.model, num_labels=4)

        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        METRICS = [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
        model.compile(optimizer=optimizer, loss=loss, metrics=METRICS)

    # Use an early stopping callback and our timing callback
    early_stop = tf.keras.callbacks.EarlyStopping(verbose=1,
                                                  patience=2,
                                                  min_delta=0.005,
                                                  restore_best_weights=True)

    time_callback = TimeHistory()

    history = model.fit(batched_training_dataset,
                        epochs=args.EPOCHS,
                        validation_data=batched_val_dataset,
                        callbacks=[early_stop, time_callback])

    df = pd.DataFrame(history.history)
    df['times'] = time_callback.times

    df.to_pickle(
        f'{args.model}_BS{args.BATCH_SIZE}_SEQ{args.SEQUENCE_LENGTH}.pkl')
    model.save_pretrained(
        f'./{args.model}_BS{args.BATCH_SIZE}_SEQ{args.SEQUENCE_LENGTH}/')
    def __init__(self):
        """Load the DistilBERT classifier, its tokenizer and the label lookup tables.

        On first use, ``models/tf_model.h5`` (next to this module) is downloaded
        if missing. The model is then loaded from that ``models`` directory and
        the tokenizer from ``models/tokenizer.joblib``.

        Attributes set:
            model: the ``TFDistilBertForSequenceClassification`` instance.
            tokenizer: the object unpickled from ``tokenizer.joblib``.
            target_names: list of 35 category names.
            inverse_transform: class index (0-126) -> category name.
            typeDict: category name -> one of 'B2B', 'B2C', 'B2BC', 'PUB'.
        """
        models_dir = path.join(path.dirname(__file__), 'models')
        # Download model if it does not exist
        model_location = path.join(models_dir, 'tf_model.h5')
        if not path.isfile(model_location):
            wget.download("https://lemay-images.nyc3.cdn.digitaloceanspaces.com/models/tf_model.h5", out=model_location)

        self.model = TFDistilBertForSequenceClassification.from_pretrained(models_dir)
        # NOTE: joblib.load unpickles arbitrary objects -- only load trusted files.
        self.tokenizer = joblib.load(path.join(models_dir, 'tokenizer.joblib'))

        self.target_names = ['Animal Services',
                             'Auto',
                             'Community Association',
                             'Computer Services',
                             'Contractor',
                             'Dealer',
                             'Electrical Contractor',
                             'Employment Agency',
                             'Entertainment Services',
                             'Financial Services',
                             'Food',
                             'Health and Beauty',
                             'Hotel',
                             'Instruction',
                             'Janitorial Services',
                             'Jeweller',
                             'Landscape Gardener',
                             'Liquor Establishment',
                             'Money Services',
                             'Moving/Transfer Service',
                             'Painter',
                             'Personal Services',
                             'Plumber and/or Gas Contractor',
                             'Printing Services',
                             'Real Estate Dealer',
                             'Referral Services',
                             'Repair/ Service/Maintenance',
                             'Residential/Commercial',
                             'Roofer',
                             'Scavenging',
                             'Seamstress/Tailor',
                             'Security Services',
                             'Studio',
                             'Travel Agent',
                             'Wholesale  Dealer']

        # Class index -> category name. Each name must stay on one source line:
        # a string literal broken across lines is a SyntaxError.
        self.inverse_transform = {
            0: 'Acupuncturist',
            1: 'Adult Entertainment Store',
            2: 'Animal Clinic/Hospital',
            3: 'Animal Services',
            4: 'Artist',
            5: 'Artist Live/Work Studio',
            6: 'Assembly Hall',
            7: 'Auctioneer',
            8: 'Auto Dealer',
            9: 'Auto Detailing',
            10: 'Auto Painter & Body Shop',
            11: 'Auto Parking Lot/Parkade',
            12: 'Auto Repairs',
            13: 'Auto Washer',
            14: 'Auto Wholesaler',
            15: 'Beauty Services',
            16: 'Bed and Breakfast',
            17: 'Boat Charter Services',
            18: 'Booking Agency',
            19: 'Boot & Shoe Repairs',
            20: 'Business Services',
            21: 'Carpet/Upholstery Cleaner',
            22: 'Caterer',
            23: 'Club',
            24: 'Community Association',
            25: 'Computer Services',
            26: 'Contractor',
            27: 'Contractor - Special Trades',
            28: 'Cosmetologist',
            29: 'Dance Hall',
            30: 'Dating Services',
            31: 'ESL Instruction',
            32: 'Educational',
            33: 'Electrical Contractor',
            34: 'Electrical-Security Alarm Installation',
            35: 'Employment Agency',
            36: 'Entertainment Services',
            37: 'Equipment Operator',
            38: 'Exhibitions/Shows/Concerts',
            39: 'Financial Institution',
            40: 'Financial Services',
            41: 'Fitness Centre',
            42: 'Food Processing',
            43: 'Gas Contractor',
            44: 'Gasoline Station',
            45: 'Hair Stylist/Hairdresser',
            46: 'Health Services',
            47: 'Health and Beauty',
            48: 'Home Business',
            49: 'Homecraft',
            50: 'Hotel',
            51: 'Instruction',
            52: 'Janitorial Services',
            53: 'Jeweller',
            54: 'Laboratory',
            55: 'Landscape Gardener',
            56: 'Late Night Dance Event',
            57: 'Laundry',
            58: 'Liquor Equipment',
            59: 'Liquor Establishment',
            60: 'Liquor License Application',
            61: 'Liquor Retail Store',
            62: 'Locksmith',
            63: 'Manufacturer',
            64: 'Manufacturer - Food',
            65: 'Marina Operator',
            66: 'Marine Services',
            67: 'Massage Therapist',
            68: 'Money Services',
            69: 'Moving/Transfer Service',
            70: 'Non-profit Housing',
            71: 'Office',
            72: 'Painter',
            73: 'Pawnbroker',
            74: 'Personal Care Home',
            75: 'Personal Services',
            76: 'Pest Control/Exterminator',
            77: 'Pet Store',
            78: 'Photo Services',
            79: 'Photographer',
            80: 'Physical Therapist',
            81: 'Plumber',
            82: 'Plumber & Gas Contractor',
            83: 'Plumber & Sprinkler Contractor',
            84: 'Plumber Sprinkler & Gas Contractor',
            85: 'Postal Rental Agency',
            86: 'Power/ Pressure Washing',
            87: 'Printing Services',
            88: 'Product Assembly',
            89: 'Production Company',
            90: 'Property Management',
            91: 'Psychic/Fortune Teller',
            92: 'Real Estate Dealer',
            93: 'Recycling Depot',
            94: 'Referral Services',
            95: 'Rentals',
            96: 'Repair/ Service/Maintenance',
            97: 'Residential/Commercial',
            98: 'Restaurant',
            99: 'Retail Dealer',
            100: 'Retail Dealer - Food',
            101: 'Retail Dealer - Grocery',
            102: 'Roofer',
            103: 'Rooming House',
            104: 'Scavenging',
            105: 'School (Business & Trade)',
            106: 'School (Private)',
            107: 'Seamstress/Tailor',
            108: 'Secondary Suite - Permanent',
            109: 'Secondhand Dealer',
            110: 'Security Services',
            111: 'Social Escort Services',
            112: 'Soliciting For Charity',
            113: 'Sprinkler Contractor',
            114: 'Studio',
            115: 'Talent Agency',
            116: 'Tanning Salon',
            117: 'Tattoo Parlour',
            118: 'Telecommunications',
            119: 'Theatre',
            120: 'Therapeutic Touch Technique',
            121: 'Travel Agent',
            122: 'Venue',
            123: 'Warehouse Operator',
            124: 'Wholesale  Dealer',
            125: 'Wholesale Dealer - Food',
            126: 'Window Cleaner',
        }

        # Category name -> business-type code ('B2B', 'B2C', 'B2BC' or 'PUB').
        self.typeDict = {
            "Plumber and/or Gas Contractor": 'B2BC',
            'Dealer': 'B2BC',
            'Auto': 'B2BC',
            'Food': 'B2BC',
            'Acupuncturist': 'B2C',
            'Adult Entertainment Store': 'B2C',
            'Animal Clinic/Hospital': 'B2C',
            'Animal Services': 'B2C',
            'Artist': 'B2BC',
            'Artist Live/Work Studio': 'B2BC',
            'Assembly Hall': 'PUB',
            'Auctioneer': 'B2BC',
            'Auto Dealer': 'B2BC',
            'Auto Detailing': 'B2C',
            'Auto Painter & Body Shop': 'B2C',
            'Auto Parking Lot/Parkade': 'B2C',
            'Auto Repairs': 'B2C',
            'Auto Washer': 'B2C',
            'Auto Wholesaler': 'B2B',
            'Beauty Services': 'B2C',
            'Bed and Breakfast': 'B2C',
            'Boat Charter Services': 'B2BC',
            'Booking Agency': 'B2BC',
            'Boot & Shoe Repairs': 'B2C',
            'Business Services': 'B2B',
            'Carpet/Upholstery Cleaner': 'B2BC',
            'Caterer': 'B2BC',
            'Club': 'B2BC',
            'Community Association': 'PUB',
            'Computer Services': 'B2BC',
            'Contractor': 'B2BC',
            'Contractor - Special Trades': 'B2BC',
            'Cosmetologist': 'B2C',
            'Dance Hall': 'B2C',
            'Dating Services': 'B2C',
            'ESL Instruction': 'B2C',
            'Educational': 'PUB',
            'Electrical Contractor': 'B2BC',
            'Electrical-Security Alarm Installation': 'B2BC',
            'Employment Agency': 'B2BC',
            'Entertainment Services': 'B2BC',
            'Equipment Operator': 'B2BC',
            'Exhibitions/Shows/Concerts': 'B2BC',
            'Financial Institution': 'B2BC',
            'Financial Services': 'B2BC',
            'Fitness Centre': 'B2C',
            'Food Processing': 'B2B',
            'Gas Contractor': 'B2BC',
            'Gasoline Station': 'B2C',
            'Hair Stylist/Hairdresser': 'B2C',
            'Health Services': 'B2C',
            'Health and Beauty': 'B2C',
            'Home Business': 'B2BC',
            'Homecraft': 'B2B',
            'Hotel': 'B2BC',
            'Instruction': 'B2BC',
            'Janitorial Services': 'B2B',
            'Jeweller': 'B2C',
            'Laboratory': 'B2B',
            'Landscape Gardener': 'B2BC',
            'Late Night Dance Event': 'B2C',
            'Laundry': 'B2C',
            'Liquor Equipment': 'B2B',
            'Liquor Establishment': 'B2C',
            'Liquor License Application': 'B2C',
            'Liquor Retail Store': 'B2C',
            'Locksmith': 'B2BC',
            'Manufacturer': 'B2B',
            'Manufacturer - Food': 'B2B',
            'Marina Operator': 'B2BC',
            'Marine Services': 'B2BC',
            'Massage Therapist': 'B2C',
            'Money Services': 'B2BC',
            'Moving/Transfer Service': 'B2BC',
            'Non-profit Housing': 'PUB',
            'Office': 'B2BC',
            'Painter': 'B2BC',
            'Pawnbroker': 'B2C',
            'Personal Care Home': 'B2C',
            'Personal Services': 'B2C',
            'Pest Control/Exterminator': 'B2BC',
            'Pet Store': 'B2C',
            'Photo Services': 'B2C',
            'Photographer': 'B2BC',
            'Physical Therapist': 'B2C',
            'Plumber': 'B2BC',
            'Plumber & Gas Contractor': 'B2BC',
            'Plumber & Sprinkler Contractor': 'B2BC',
            'Plumber Sprinkler & Gas Contractor': 'B2BC',
            'Postal Rental Agency': 'B2BC',
            'Power/ Pressure Washing': 'B2BC',
            'Printing Services': 'B2BC',
            'Product Assembly': 'B2B',
            'Production Company': 'B2BC',
            'Property Management': 'B2BC',
            'Psychic/Fortune Teller': 'B2C',
            'Real Estate Dealer': 'B2BC',
            'Recycling Depot': 'PUB',
            'Referral Services': 'B2BC',
            'Rentals': 'B2BC',
            'Repair/ Service/Maintenance': 'B2BC',
            'Residential/Commercial': 'B2BC',
            'Restaurant': 'B2C',
            'Retail Dealer': 'B2C',
            'Retail Dealer - Food': 'B2C',
            'Retail Dealer - Grocery': 'B2C',
            'Roofer': 'B2BC',
            'Rooming House': 'PUB',
            'Scavenging': 'B2BC',
            'School (Business & Trade)': 'PUB',
            'School (Private)': 'PUB',
            'Seamstress/Tailor': 'B2C',
            'Secondary Suite - Permanent': 'B2B',
            'Secondhand Dealer': 'B2C',
            'Security Services': 'B2BC',
            'Social Escort Services': 'B2C',
            'Soliciting For Charity': 'PUB',
            'Sprinkler Contractor': 'B2BC',
            'Studio': 'B2BC',
            'Talent Agency': 'B2BC',
            'Tanning Salon': 'B2C',
            'Tattoo Parlour': 'B2C',
            'Telecommunications': 'B2BC',
            'Theatre': 'B2C',
            'Therapeutic Touch Technique': 'B2C',
            'Travel Agent': 'B2BC',
            'Venue': 'B2BC',
            'Warehouse Operator': 'B2B',
            'Wholesale  Dealer': 'B2B',
            'Wholesale Dealer - Food': 'B2B',
            'Window Cleaner': 'B2BC',
        }
# Each row of these arrays is one tokenizer output field (input_ids, attention_mask, ...)
# in the encoding's key order.
train_dataset = np.array(list(dict(train_encodings).values()))
val_dataset = np.array(list(dict(val_encodings).values()))

BATCH_SIZE = 16

# Save the model's weights after every epoch (ModelCheckpoint's default save frequency).
checkpoint_path = "distilbert16_ckpt/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True)

# True: train and checkpoint. False: restore the most recent checkpoint instead.
save_model = True

model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3, return_dict=True)

if save_model:
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])

    # Feed the attention mask along with the ids: training on input_ids alone
    # (train_dataset[0]) lets the model attend to padding tokens.
    train_inputs = {
        name: np.asarray(train_encodings[name])
        for name in ('input_ids', 'attention_mask')
    }
    model.fit(
        train_inputs,
        np.array(y_list),
        epochs=5,
        batch_size=BATCH_SIZE,
        callbacks=[cp_callback]
        )
else:
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    # latest_checkpoint returns None when nothing was saved; fail with a clear message.
    if latest is None:
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
    model.load_weights(latest)
Example #6
0
        # This is required when launching many instances at once...  the urllib request seems to get denied periodically
        successful_download = False
        retries = 0
        while retries < 5 and not successful_download:
            try:
                tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
                config = DistilBertConfig.from_pretrained(
                    "distilbert-base-uncased",
                    num_labels=len(CLASSES),
                    id2label={0: 1, 1: 0, 2: -1},
                    label2id={1: 0, 0: 1, -1: 2},
                )

                transformer_model = TFDistilBertForSequenceClassification.from_pretrained(
                    "distilbert-base-uncased", config=config
                )

                input_ids = tf.keras.layers.Input(shape=(max_seq_length,), name="input_ids", dtype="int32")
                input_mask = tf.keras.layers.Input(shape=(max_seq_length,), name="input_mask", dtype="int32")

                embedding_layer = transformer_model.distilbert(input_ids, attention_mask=input_mask)[0]
                X = tf.keras.layers.Bidirectional(
                    tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1)
                )(embedding_layer)
                X = tf.keras.layers.GlobalMaxPool1D()(X)
                X = tf.keras.layers.Dense(50, activation="relu")(X)
                X = tf.keras.layers.Dropout(0.2)(X)
                X = tf.keras.layers.Dense(len(CLASSES), activation="softmax")(X)

                model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=X)
Example #7
0
from transformers import DistilBertTokenizer
from transformers import TFDistilBertForSequenceClassification
import tensorflow as tf
import pandas as pd

# The fine-tuned tokenizer and classifier are both stored under ./weight
weights_dir = './weight'
loaded_tokenizer = DistilBertTokenizer.from_pretrained(weights_dir)
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(weights_dir)

# Sample taken from row 13861 of the CSV file; its manual annotation is [H:N	H:H]
test_text = "\"Oh, how nicely it is made,\" exclaimed the ladies."

predict_input = loaded_tokenizer.encode(
    test_text,
    truncation=True,
    padding=True,
    return_tensors="tf",
)

print("")

# First element of the model output holds the per-class scores that argmax ranks
output = loaded_model(predict_input)[0]
print(output)
prediction_value = tf.argmax(output, axis=1).numpy()[0]
print(prediction_value)

# Translate the predicted class index back to its emotion label
df = pd.read_csv('EncodedOutput.csv')
matches = df['EncodedValues'] == prediction_value
emotion = df.loc[matches, 'Emotion'].values

print(emotion)
Example #8
0
 def _load_huggingface_model(self, path_to_dir):
     """Load a DistilBERT sequence classifier from *path_to_dir* into ``self._model``."""
     loaded = TFDistilBertForSequenceClassification.from_pretrained(path_to_dir)
     self._model = loaded
# Load MRPC data
data = tensorflow_datasets.load('glue/mrpc')

# Pick GPU device (only pick 1 GPU). On CPU-only hosts the list is empty and the
# default device is used instead of crashing on gpus[0].
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

# Load tokenizer, model from pretrained model/vocabulary
bert_tokenizer = BertTokenizer.from_pretrained('mrpc/1')
bert_model = TFBertForSequenceClassification.from_pretrained('mrpc/1')

valid_dataset = glue_convert_examples_to_features(data['validation'], bert_tokenizer, max_length=128, task='mrpc')
valid_dataset = valid_dataset.batch(64)

# Evaluate time for bert_model (bigger model); perf_counter is the monotonic timer meant for this
start_time = time.perf_counter()
results = bert_model.predict(valid_dataset)
execution_time = time.perf_counter() - start_time
# Keep the BERT timing: execution_time is reused for DistilBERT below.
bert_execution_time = execution_time
print(f'BERT validation inference time: {bert_execution_time:.2f} s')

# Load tokenizer, model from pretrained model/vocabulary
distilbert_tokenizer = DistilBertTokenizer.from_pretrained('mrpc/2')
distilbert_model = TFDistilBertForSequenceClassification.from_pretrained('mrpc/2')

valid_dataset = glue_convert_examples_to_features(data['validation'], distilbert_tokenizer, max_length=128, task='mrpc')
valid_dataset = valid_dataset.batch(64)

# Evaluate time for distilbert_model (smaller model)
start_time = time.perf_counter()
results = distilbert_model.predict(valid_dataset)
execution_time = time.perf_counter() - start_time
print(f'DistilBERT validation inference time: {execution_time:.2f} s')
Example #10
0
# Record the label of the first row seen for each encoded category; later
# duplicates leave the existing mapping untouched.
for index, row in df.iterrows():
    cat2label.setdefault(row['encoded_cat'], row['label'])


# In[6]:


tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
save_directory = save_model

# Load the fine-tuned model and tokenizer
loaded_tokenizer = DistilBertTokenizer.from_pretrained(save_directory)
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(save_directory)


# In[9]:


# Three positional arguments are required; fail with a usage message instead of an IndexError.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} PROD_DESC_FILE OUTPUT_FILE TOP5_MATCH_FILE")

imptr_prod_desc = sys.argv[1]
output_file = sys.argv[2]
top5_match = sys.argv[3]

# newline='' lets the csv writer control line endings itself, as the csv module requires.
# NOTE(review): these handles are expected to be closed later in the script -- confirm.
outfile = open(output_file, "w", newline='')
fo = csv.writer(outfile, lineterminator='\n')
outfile2 = open(top5_match, "w")
Example #11
0
 def load_model(self, path_to_model: Text):
     """Build and return a DistilBERT sequence classifier loaded from *path_to_model*."""
     model = TFDistilBertForSequenceClassification.from_pretrained(path_to_model)
     return model