def __init__(self, log_dir, dataset_path, dataset_config, start_epoch=0,
             save_best_only=False, save_last_only=False, save_eval_plot=True,
             period=1):
    self.log_dir = log_dir
    self.start_epoch = start_epoch
    self.save_best_only = save_best_only
    self.save_last_only = save_last_only
    self.save_eval_plot = save_eval_plot
    self.period = period
    self.cloud_run = cst.BUCKET_NAME in log_dir
    if self.cloud_run:
        self.client = storage.Client()
        self.bucket = self.client.get_bucket(cst.BUCKET_NAME)
    # Only used for saving evaluation plots
    if self.save_eval_plot:
        self.validation_dataset = dp.create_dataset(
            data_dir=dataset_path,
            window_size=dataset_config["window_size"],
            shift=dataset_config["shift"],
            stride=dataset_config["stride"],
            batch_size=dataset_config["batch_size"],
            cycle_length=1,        # Has to be set for plotting
            num_parallel_calls=1,  # Has to be set for plotting
            shuffle=False,         # Has to be set for plotting
            repeat=False)          # Has to be set for plotting
def __init__(self, log_dir, dataset_path, dataset_config, start_epoch=0,
             save_best_only=False, save_last_only=False, save_eval_plot=True,
             period=1):
    self.log_dir = log_dir
    self.start_epoch = start_epoch
    self.save_best_only = save_best_only
    self.save_last_only = save_last_only
    self.save_eval_plot = save_eval_plot
    self.period = period
    if self.save_eval_plot:
        self.validation_dataset = dp.create_dataset(
            data_dir=dataset_path,
            window_size=dataset_config["window_size"],
            shift=dataset_config["shift"],
            stride=dataset_config["stride"],
            batch_size=dataset_config["batch_size"],
            cycle_length=1,        # Has to be set for plotting
            num_parallel_calls=1,  # Has to be set for plotting
            shuffle=False,         # Has to be set for plotting
            repeat=False)          # Has to be set for plotting
def calculate_steps_per_epoch(data_dir, dataset_config):
    # Count batches by iterating once; repeat must be False so the dataset is finite.
    temp_dataset = dp.create_dataset(data_dir=data_dir,
                                     window_size=dataset_config["window_size"],
                                     shift=dataset_config["shift"],
                                     stride=dataset_config["stride"],
                                     batch_size=dataset_config["batch_size"],
                                     repeat=False)
    steps_per_epoch = 0
    for batch in temp_dataset:
        steps_per_epoch += 1
    return steps_per_epoch
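A minimal usage sketch (the tfrecord glob and config values below are assumptions, not taken from the repo), showing how the step count would be derived before a later model.fit call:

# Hypothetical usage of calculate_steps_per_epoch; path and config are placeholders.
ds_config = dict(window_size=20, shift=1, stride=1, batch_size=32)
steps_train = calculate_steps_per_epoch(data_dir='./data/tfrecords/train/*tfrecord',
                                        dataset_config=ds_config)
print("Steps per epoch (train):", steps_train)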
import os
import json

import trainer.constants as cst
from trainer.data_pipeline import create_dataset
from server.constants import NUM_SAMPLES, SAMPLES_DIR

"""Create sample files in json format from test data and save them in the server module.

These can be used by the 'load random sample' button as examples on the website.
"""

samples_fullpath = os.path.join('server', SAMPLES_DIR)
if not os.path.exists(samples_fullpath):
    os.makedirs(samples_fullpath)

# dataset = create_dataset(cst.SECONDARY_TEST_SET,
dataset = create_dataset('./data/tfrecords/train/*tfrecord',
                         window_size=20,
                         shift=1,
                         stride=1,
                         batch_size=1)

rows = dataset.take(NUM_SAMPLES)
for i, row in enumerate(rows):
    sample = {key: str(value.numpy().tolist()) for key, value in row[0].items()}
    with open(os.path.join(samples_fullpath, 'sample_input_{}.json'.format(i + 1)),
              'w') as outfile:
        json.dump(sample, outfile)

print("Created {} sample files in server/static/samples".format(NUM_SAMPLES))
def train_and_evaluate(args, tboard_dir, hparams=None):
    """Trains and evaluates the Keras model.

    Uses the Keras model defined in model.py and trains on data loaded and
    preprocessed in data_pipeline.py. Saves the trained model in TensorFlow
    SavedModel format to the path defined in part by the --job-dir argument.

    Args:
        args: dictionary of arguments - see get_args() for details
        tboard_dir: directory used for TensorBoard logs and checkpoints
        hparams: optional hyperparameter config; if set, the run is treated as an HPO job
    """
    # Config datasets for consistent usage
    ds_config = dict(window_size=args.window_size,
                     shift=args.shift,
                     stride=args.stride,
                     batch_size=args.batch_size)
    ds_train_path = args.data_dir_train
    ds_val_path = args.data_dir_validate

    # create model
    if args.model == 'split_model':
        print("Using split model!")
        model = split_model.create_keras_model(window_size=ds_config["window_size"],
                                               loss=args.loss,
                                               hparams_config=hparams)
    elif args.model == 'full_cnn_model':
        print("Using full cnn model!")
        model = full_cnn_model.create_keras_model(window_size=ds_config["window_size"],
                                                  loss=args.loss,
                                                  hparams_config=hparams)

    # Calculate steps_per_epoch_train, steps_per_epoch_test
    # This is needed, since for counting repeat has to be false
    steps_per_epoch_train = calculate_steps_per_epoch(data_dir=ds_train_path,
                                                      dataset_config=ds_config)
    steps_per_epoch_validate = calculate_steps_per_epoch(data_dir=ds_val_path,
                                                         dataset_config=ds_config)

    # load datasets
    dataset_train = dp.create_dataset(data_dir=ds_train_path,
                                      window_size=ds_config["window_size"],
                                      shift=ds_config["shift"],
                                      stride=ds_config["stride"],
                                      batch_size=ds_config["batch_size"])
    dataset_validate = dp.create_dataset(data_dir=ds_val_path,
                                         window_size=ds_config["window_size"],
                                         shift=ds_config["shift"],
                                         stride=ds_config["stride"],
                                         batch_size=ds_config["batch_size"])

    # if hparams is passed, we're running a HPO-job
    if hparams:
        checkpoint_callback = CustomCheckpoints(save_last_only=True,
                                                log_dir=tboard_dir,
                                                dataset_path=ds_val_path,
                                                dataset_config=ds_config,
                                                save_eval_plot=False)
    else:
        checkpoint_callback = CustomCheckpoints(save_best_only=True,
                                                start_epoch=args.save_from,
                                                log_dir=tboard_dir,
                                                dataset_path=ds_val_path,
                                                dataset_config=ds_config,
                                                save_eval_plot=False)

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=tboard_dir,
                                       histogram_freq=0,
                                       write_graph=False),
        checkpoint_callback,
    ]

    model.summary()

    # train model
    history = model.fit(dataset_train,
                        epochs=args.num_epochs,
                        steps_per_epoch=steps_per_epoch_train,
                        validation_data=dataset_validate,
                        validation_steps=steps_per_epoch_validate,
                        verbose=2,
                        callbacks=callbacks)

    mae_current = min(history.history["val_mae_current_cycle"])
    mae_remaining = min(history.history["val_mae_remaining_cycles"])
    return mae_current, mae_remaining
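A sketch of how train_and_evaluate might be wired to an entrypoint; get_args() is only referenced in the docstring above, and the args.job_dir attribute is an assumption based on the --job-dir argument it mentions.

# Hypothetical entrypoint; get_args() fields and args.job_dir are assumptions.
if __name__ == '__main__':
    args = get_args()
    tboard_dir = os.path.join(args.job_dir, 'tensorboard')  # assumed layout under --job-dir
    mae_current, mae_remaining = train_and_evaluate(args, tboard_dir)
    print("Best val MAE (current cycle):", mae_current)
    print("Best val MAE (remaining cycles):", mae_remaining)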
import os
import json

import tensorflow as tf
import trainer.constants as cst
from trainer.data_pipeline import create_dataset
from server.constants import NUM_SAMPLES, SAMPLES_DIR

"""Create sample files in json format from test data and save them in the server module.

These can be used by the 'load random sample' button as examples on the website.
"""

samples_fullpath = os.path.join('server', SAMPLES_DIR)
if not os.path.exists(samples_fullpath):
    os.makedirs(samples_fullpath)

dataset = create_dataset(cst.SECONDARY_TEST_SET,
                         window_size=20,
                         shift=1,
                         stride=1,
                         batch_size=1)

rows = dataset.take(NUM_SAMPLES)
for i, row in enumerate(rows):
    sample = {key: str(value.numpy().tolist()) for key, value in row[0].items()}
    with open(os.path.join(samples_fullpath, 'sample_input_{}.json'.format(i + 1)),
              'w') as outfile:
        json.dump(sample, outfile)

print("Created {} sample files in server/static/samples".format(NUM_SAMPLES))
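A small round-trip sketch, assuming the script above has already written the sample files: because each feature is stored as str() of a list, ast.literal_eval is used here to recover the numeric values.

# Hypothetical reader for one generated sample file; not part of the original scripts.
import ast

with open(os.path.join(samples_fullpath, 'sample_input_1.json')) as infile:
    sample = json.load(infile)

# Each value was serialized with str(list), so literal_eval restores the nested lists.
decoded = {key: ast.literal_eval(value) for key, value in sample.items()}
print({key: len(value) for key, value in decoded.items()})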