Example #1
 def __init__(self,
              log_dir,
              dataset_path,
              dataset_config,
              start_epoch=0,
              save_best_only=False,
              save_last_only=False,
              save_eval_plot=True,
              period=1):
     self.log_dir = log_dir
     self.start_epoch = start_epoch
     self.save_best_only = save_best_only
     self.save_last_only = save_last_only
     self.save_eval_plot = save_eval_plot
     self.period = period
     self.cloud_run = cst.BUCKET_NAME in log_dir  # True when log_dir points into the GCS bucket
     if self.cloud_run:
         self.client = storage.Client()
         self.bucket = self.client.get_bucket(
             cst.BUCKET_NAME)  # Only used for saving evaluation plots
     if self.save_eval_plot:
         self.validation_dataset = dp.create_dataset(
             data_dir=dataset_path,
             window_size=dataset_config["window_size"],
             shift=dataset_config["shift"],
             stride=dataset_config["stride"],
             batch_size=dataset_config["batch_size"],
             cycle_length=1,  # Has to be set for plotting
             num_parallel_calls=1,  # Has to be set for plotting
             shuffle=False,  # Has to be set for plotting
             repeat=False)  # Has to be set for plotting
Example #2
 def __init__(self,
              log_dir,
              dataset_path,
              dataset_config,
              start_epoch=0,
              save_best_only=False,
              save_last_only=False,
              save_eval_plot=True,
              period=1):
     self.log_dir = log_dir
     self.start_epoch = start_epoch
     self.save_best_only = save_best_only
     self.save_last_only = save_last_only
     self.save_eval_plot = save_eval_plot
     self.period = period
     if self.save_eval_plot:
         self.validation_dataset = dp.create_dataset(
             data_dir=dataset_path,
             window_size=dataset_config["window_size"],
             shift=dataset_config["shift"],
             stride=dataset_config["stride"],
             batch_size=dataset_config["batch_size"],
             cycle_length=1,  # Has to be set for plotting
             num_parallel_calls=1,  # Has to be set for plotting
             shuffle=False,  # Has to be set for plotting
             repeat=False)  # Has to be set for plotting
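Examples #1 and #2 show the constructor of what appears to be a custom Keras checkpoint callback; Example #5 below passes such an object to model.fit. The following is a minimal usage sketch, not part of the original examples: the CustomCheckpoints class name is taken from Example #5, and the paths and config values are placeholders.

# Hedged usage sketch; CustomCheckpoints and the paths/values below are
# assumptions based on Example #5, not code from this project.
dataset_config = dict(window_size=20, shift=1, stride=1, batch_size=32)

checkpoint_callback = CustomCheckpoints(
    log_dir="logs/run-1",                                # local dir, so no GCS client is created
    dataset_path="./data/tfrecords/validate/*tfrecord",  # validation TFRecords (placeholder)
    dataset_config=dataset_config,
    save_best_only=True,
    save_eval_plot=False)

# model.fit(train_dataset, epochs=10, callbacks=[checkpoint_callback])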
Example #3
def calculate_steps_per_epoch(data_dir, dataset_config):
    # Count batches once with repeat=False so model.fit gets an exact step count.
    temp_dataset = dp.create_dataset(data_dir=data_dir,
                                     window_size=dataset_config["window_size"],
                                     shift=dataset_config["shift"],
                                     stride=dataset_config["stride"],
                                     batch_size=dataset_config["batch_size"],
                                     repeat=False)
    steps_per_epoch = 0
    for _ in temp_dataset:
        steps_per_epoch += 1
    return steps_per_epoch
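A hypothetical call for illustration only; the glob path and config values are placeholders, and dp refers to the project's data pipeline module as in the examples above.

# Hypothetical usage of calculate_steps_per_epoch; path and values are placeholders.
ds_config = dict(window_size=20, shift=1, stride=1, batch_size=32)
steps = calculate_steps_per_epoch(data_dir="./data/tfrecords/train/*tfrecord",
                                  dataset_config=ds_config)
print("steps_per_epoch:", steps)  # later handed to model.fit(..., steps_per_epoch=steps)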
Example #4
import json
import os

import trainer.constants as cst
from trainer.data_pipeline import create_dataset
from server.constants import NUM_SAMPLES, SAMPLES_DIR
"""Create sample files in json format from test data and save it in the server module.
These can be used by the 'load random sample' button as examples on the website.
"""

samples_fullpath = os.path.join('server', SAMPLES_DIR)

if not os.path.exists(samples_fullpath):
    os.makedirs(samples_fullpath)

#dataset = create_dataset(cst.SECONDARY_TEST_SET,
dataset = create_dataset('./data/tfrecords/train/*tfrecord',
                         window_size=20,
                         shift=1,
                         stride=1,
                         batch_size=1)
rows = dataset.take(NUM_SAMPLES)
for i, row in enumerate(rows):
    sample = {
        key: str(value.numpy().tolist())
        for key, value in row[0].items()
    }
    with open(
            os.path.join(samples_fullpath,
                         'sample_input_{}.json'.format(i + 1)),
            'w') as outfile:
        json.dump(sample, outfile)
print("Created {} sample files in server/static/samples".format(NUM_SAMPLES))
Example #5
def train_and_evaluate(args, tboard_dir, hparams=None):
    """Trains and evaluates the Keras model.
    Uses the Keras model defined in model.py and trains on data loaded and
    preprocessed in data_pipeline.py. Saves the trained model in TensorFlow SavedModel
    format to the path defined in part by the --job-dir argument.
    Args:
    args: dictionary of arguments - see get_args() for details
    """
    # Config datasets for consistent usage
    ds_config = dict(window_size=args.window_size,
                     shift=args.shift,
                     stride=args.stride,
                     batch_size=args.batch_size)
    ds_train_path = args.data_dir_train
    ds_val_path = args.data_dir_validate

    # create model
    if args.model == 'split_model':
        print("Using split model!")
        model = split_model.create_keras_model(
            window_size=ds_config["window_size"],
            loss=args.loss,
            hparams_config=hparams)
    elif args.model == 'full_cnn_model':
        print("Using full cnn model!")
        model = full_cnn_model.create_keras_model(
            window_size=ds_config["window_size"],
            loss=args.loss,
            hparams_config=hparams)
    else:
        raise ValueError("Unknown model: {}".format(args.model))

    # Calculate steps_per_epoch_train, steps_per_epoch_test
    # This is needed, since for counting repeat has to be false
    steps_per_epoch_train = calculate_steps_per_epoch(data_dir=ds_train_path,
                                                      dataset_config=ds_config)

    steps_per_epoch_validate = calculate_steps_per_epoch(
        data_dir=ds_val_path, dataset_config=ds_config)

    # load datasets
    dataset_train = dp.create_dataset(data_dir=ds_train_path,
                                      window_size=ds_config["window_size"],
                                      shift=ds_config["shift"],
                                      stride=ds_config["stride"],
                                      batch_size=ds_config["batch_size"])

    dataset_validate = dp.create_dataset(data_dir=ds_val_path,
                                         window_size=ds_config["window_size"],
                                         shift=ds_config["shift"],
                                         stride=ds_config["stride"],
                                         batch_size=ds_config["batch_size"])

    # if hparams is passed, we're running a HPO-job
    if hparams:
        checkpoint_callback = CustomCheckpoints(save_last_only=True,
                                                log_dir=tboard_dir,
                                                dataset_path=ds_val_path,
                                                dataset_config=ds_config,
                                                save_eval_plot=False)
    else:
        checkpoint_callback = CustomCheckpoints(save_best_only=True,
                                                start_epoch=args.save_from,
                                                log_dir=tboard_dir,
                                                dataset_path=ds_val_path,
                                                dataset_config=ds_config,
                                                save_eval_plot=False)
    callbacks = [
        tf.keras.callbacks.TensorBoard(
            log_dir=tboard_dir,
            histogram_freq=0,
            write_graph=False,
        ),
        checkpoint_callback,
    ]

    model.summary()

    # train model
    history = model.fit(dataset_train,
                        epochs=args.num_epochs,
                        steps_per_epoch=steps_per_epoch_train,
                        validation_data=dataset_validate,
                        validation_steps=steps_per_epoch_validate,
                        verbose=2,
                        callbacks=callbacks)

    mae_current = min(history.history["val_mae_current_cycle"])
    mae_remaining = min(history.history["val_mae_remaining_cycles"])
    return mae_current, mae_remaining
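For reference, a hypothetical driver call; the attribute names mirror what train_and_evaluate reads from args above, but the real get_args() parser and the values shown (loss, paths, epochs) are placeholders and may differ from the project's defaults.

# Hypothetical driver; attribute names follow the args.* accesses above.
from types import SimpleNamespace

args = SimpleNamespace(window_size=20, shift=1, stride=1, batch_size=32,
                       data_dir_train="./data/tfrecords/train/*tfrecord",
                       data_dir_validate="./data/tfrecords/validate/*tfrecord",
                       model="split_model", loss="mse",
                       num_epochs=50, save_from=10)

mae_current, mae_remaining = train_and_evaluate(args, tboard_dir="logs/run-1")
print("val MAE (current cycle):", mae_current)
print("val MAE (remaining cycles):", mae_remaining)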
Example #6
import json
import os

import tensorflow as tf
import trainer.constants as cst
from trainer.data_pipeline import create_dataset
from server.constants import NUM_SAMPLES, SAMPLES_DIR
"""Create sample files in json format from test data and save it in the server module.
These can be used by the 'load random sample' button as examples on the website.
"""

samples_fullpath = os.path.join('server', SAMPLES_DIR)

if not os.path.exists(samples_fullpath):
    os.makedirs(samples_fullpath)

dataset = create_dataset(cst.SECONDARY_TEST_SET,
                         window_size=20,
                         shift=1,
                         stride=1,
                         batch_size=1)
rows = dataset.take(NUM_SAMPLES)
for i, row in enumerate(rows):
    sample = {
        key: str(value.numpy().tolist())
        for key, value in row[0].items()
    }
    with open(
            os.path.join(samples_fullpath,
                         'sample_input_{}.json'.format(i + 1)),
            'w') as outfile:
        json.dump(sample, outfile)
print("Created {} sample files in server/static/samples".format(NUM_SAMPLES))