Exemple #1
0
# Add the legend to the current matplotlib axes and display the figure.
plt.legend()
plt.show()

# Print the per-label sample counts of the 'context' column for df_train,
# then for df_val.
for split_df in (df_train, df_val):
    print(split_df['context'].value_counts())

# Vectorize the prompt text.
# Prompts may repeat in the dataset; that is acceptable provided the labels
# stay consistent. With shuffling, seeing the same sample several times during
# NN training should not be a problem.
vectorizer = TextVectorization(
    max_tokens=20000,
    output_sequence_length=sequence_len,
)
# The vocabulary is learned from the training prompts only.
vectorizer.adapt(df_train['prompt'].values)
X_train, X_val = (
    vectorizer(frame['prompt']).numpy() for frame in (df_train, df_val)
)

# Read the adapted configuration and vocabulary once, then reuse them for both
# the console output and the saved artifacts.
vectorizer_config = vectorizer.get_config()
vectorizer_vocab = vectorizer.get_vocabulary()
print(vectorizer_config)  # configuration
print(vectorizer_vocab[:10])  # first 10 words
save_artifacts({'vectorizer_vocab': vectorizer_vocab})
save_artifacts({'vectorizer_config': vectorizer_config})

# Encode the context labels as integer ids, then one-hot encode them.
# The encoder is fitted on the training labels only, so a validation label that
# never occurs in df_train makes `le.transform` raise.
le = LabelEncoder()
le.fit(df_train['context'])
# Pin the one-hot width to the number of known classes. By default
# `to_categorical` infers the width from the largest id it is given, so a
# validation split lacking the highest-numbered class would otherwise produce
# a y_val with fewer columns than y_train.
y_train = to_categorical(le.transform(df_train['context']),
                         num_classes=len(le.classes_))
y_val = to_categorical(le.transform(df_val['context']),
                       num_classes=len(le.classes_))
print(y_train)
print(y_val)
print(le.classes_)
# The fitted encoder is saved so that ids can be mapped back to label names.
save_artifacts({'label_encoder': le})

# define our model
# Number of entries in the adapted vectorizer's vocabulary (presumably the
# input dimension of the embedding layer defined below — TODO confirm).
vocab_len = len(vectorizer.get_vocabulary())
Exemple #2
0
    # Configure loss, optimizer and metrics from the "model" section of cfg.
    model_cfg = cfg["model"]
    model.compile(
        loss=model_cfg["loss"],
        optimizer=model_cfg["optimizer"],
        metrics=model_cfg["metrics"],
    )

    # Stream the per-epoch results to a tab-separated log file under `out`.
    csv_logger = tensorflow.keras.callbacks.CSVLogger(
        out / const.TRAIN_LOG_OUTPUT_FILE, separator="\t"
    )

    # Training is the same in both modes; validation on `dev_ds` is added only
    # when `args.dev` is set (and `dev_ds` is only read in that case).
    fit_kwargs = {"epochs": cfg["exp"]["epochs"], "callbacks": [csv_logger]}
    if args.dev:
        fit_kwargs["validation_data"] = dev_ds
    model.fit(train_ds, **fit_kwargs)

    # Persist the run configuration, the trained model and the `vlayer` state
    # (its config and weights) in the `out/` folder.
    with open(out / const.CFG_OUTPUT_FILE, "w") as file:
        yaml.safe_dump(cfg, file, indent=2)
    model.save(out / "model")

    # Build the payload before opening the file so that a failure in
    # get_config()/get_weights() does not leave an empty file behind.
    vlayer_state = {"cfg": vlayer.get_config(), "w": vlayer.get_weights()}
    # A context manager guarantees the pickle file is flushed and closed,
    # instead of leaving the handle open until garbage collection.
    with open(out / const.VLAYER_OUTPUT_FILE, "wb") as vlayer_file:
        pickle.dump(vlayer_state, vlayer_file)