Exemplo n.º 1
0
# Sistema que nos permite marcar checkpoints en
# el entrenamiento para que cada cierto tiempo
# se vaya guardando la información y podamos
# continuar entrenamientos o añadir nuevos.
checkpoint_path = "./ckpt"
ckpt = tf.train.Checkpoint(Dcnn=Dcnn)

ckpt_manager = tf.train.CheckpointManager(
    ckpt, checkpoint_path, max_to_keep=5)  # Guarda sólo los cinco últimos

# Si hay checkpoint, restauro
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Último checkpoint restaurado!!')

Dcnn.fit(train_inputs, train_labels, batch_size=BATCH_SIZE,
         epochs=NB_EPOCHS)  # pasará por cada tweet cinco veces.
ckpt_manager.save()

# FASE DE EVALUACIÓN:
print('Evaluación')

results = Dcnn.evaluate(test_inputs, test_labels, batch_size=BATCH_SIZE)
print(results)

# Hacemos una predicción:
Dcnn(np.array([tokenizer.encode("I hate you")]), training=False).numpy()
# Devuelve la probabilidad de que sea positivo
tokenizer.encode("bad")