def main(argv):
    """Trains the model on the train split, validating on the test split.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    cfg = hyperparameters.get_config()

    # Both splits are read with the same sequence length: the observed steps
    # plus the predicted steps.
    num_timesteps = cfg.observed_steps + cfg.predicted_steps

    train_dataset, data_shapes = datasets.get_sequence_dataset(
        data_dir=os.path.join(cfg.data_dir, cfg.train_dir),
        batch_size=cfg.batch_size,
        num_timesteps=num_timesteps)

    test_dataset, _ = datasets.get_sequence_dataset(
        data_dir=os.path.join(cfg.data_dir, cfg.test_dir),
        batch_size=cfg.batch_size,
        num_timesteps=num_timesteps)

    model = build_model(cfg, data_shapes)
    # Pass `learning_rate`, not the legacy `lr` alias. Newer Keras optimizers
    # drop `lr` with only a warning (2.11+) or reject it (3.x), which would
    # silently ignore cfg.learning_rate.
    optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate,
                                         clipnorm=cfg.clipnorm)
    model.compile(optimizer)

    # validation_steps=1: only a single test batch is evaluated per epoch.
    model.fit(x=train_dataset,
              steps_per_epoch=cfg.steps_per_epoch,
              epochs=cfg.num_epochs,
              validation_data=test_dataset,
              validation_steps=1)
 def get_dataset(self, batch_size=None, random_offset=True, seed=0):
     """Builds a sequence dataset from this object's data settings.

     Args:
       batch_size: Batch size override. Any falsy value (None or 0) falls
         back to `self.batch_size`.
       random_offset: Forwarded unchanged to `datasets.get_sequence_dataset`.
       seed: Forwarded unchanged to `datasets.get_sequence_dataset`.

     Returns:
       The value returned by `datasets.get_sequence_dataset` for
       `self.data_dir`, `self.file_glob` and `self.num_timesteps`.
     """
     effective_batch_size = batch_size or self.batch_size
     return datasets.get_sequence_dataset(
         data_dir=self.data_dir,
         file_glob=self.file_glob,
         batch_size=effective_batch_size,
         num_timesteps=self.num_timesteps,
         random_offset=random_offset,
         seed=seed,
     )
 def testAutoencoderTrainingLossGoesDown(self):
   """Tests a minimal Keras training loop for the non-dynamic model parts.

   Fits the Autoencoder for three single-step epochs on files matching
   'acrobot*' in `self.cfg.train_dir`. It then asserts that the last epoch's
   loss is lower than the first epoch's.
   """
   dataset, data_shapes = datasets.get_sequence_dataset(
       data_dir=self.cfg.train_dir,
       file_glob='acrobot*',
       batch_size=self.cfg.batch_size,
       num_timesteps=self.cfg.observed_steps + self.cfg.predicted_steps,
       random_offset=True)
   autoencoder = Autoencoder(self.cfg, data_shapes)
   # Pass `learning_rate`, not the legacy `lr` alias. Newer Keras optimizers
   # drop `lr` with only a warning (2.11+) or reject it (3.x), so the
   # intended 1e-4 would not be applied.
   optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
   autoencoder.compile(optimizer)
   history = autoencoder.fit(dataset, steps_per_epoch=1, epochs=3)
   losses = history.history['loss']
   self.assertLess(losses[-1], losses[0])