def main(argv=None):
    """Build an ``RNN`` with a fixed, hard-coded configuration and train it.

    Args:
        argv: Accepted for compatibility with the caller; never read here.
    """
    # Every setting is a literal, so this entry point ignores any
    # command-line configuration.
    settings = {
        'train_images_dir': 'data/train/',
        'val_images_dir': 'data/val/',
        'test_images_dir': 'data/test/',
        'num_epochs': 40,
        'train_batch_size': 1000,
        'val_batch_size': 1000,
        'test_batch_size': 1000,
        'height_of_image': 28,
        'width_of_image': 28,
        'num_channels': 1,
        'num_classes': 10,
        'learning_rate': 0.001,
        'base_dir': 'results',
        'max_to_keep': 2,
        'model_name': "RNN",
        'model': 'RNN',
    }
    model = RNN(**settings)
    model.create_network()
    model.initialize_network()

    # Hard-wired switch: as written, the test branch below is never taken.
    run_training = True
    if run_training:
        # Positional arguments are presumably the display / validation /
        # checkpoint / summary step intervals -- confirm against
        # RNN.train_model.
        model.train_model(1, 1, 1, 2)
    else:
        model.test_model()
def main(argv=None):
    """Build an ``RNN`` from ``FLAGS`` and either train or test it.

    Each constructor argument is read from the ``FLAGS`` attribute of the
    same name. When ``FLAGS.train`` is truthy the model is trained,
    otherwise ``test_model`` is run.

    Args:
        argv: Accepted for compatibility with the caller; never read here.
    """
    # Constructor keywords share their names with the flags they come from;
    # the tuple keeps the original lookup order.
    flag_names = (
        'train_images_dir',
        'val_images_dir',
        'test_images_dir',
        'num_epochs',
        'train_batch_size',
        'val_batch_size',
        'test_batch_size',
        'height_of_image',
        'width_of_image',
        'num_channels',
        'num_classes',
        'learning_rate',
        'base_dir',
        'max_to_keep',
        'model_name',
        'keep_prob',
    )
    rnn_kwargs = {name: getattr(FLAGS, name) for name in flag_names}

    model = RNN(**rnn_kwargs)
    model.create_network()
    model.initialize_network()

    if not FLAGS.train:
        model.test_model()
        return

    model.train_model(
        FLAGS.display_step,
        FLAGS.validation_step,
        FLAGS.checkpoint_step,
        FLAGS.summary_step,
    )
def main(argv=None):
    """Run the action selected by ``FLAGS.mode``.

    Modes handled:

    * ``'preprocess'``     -- call ``prepare_dataset`` and exit the process.
    * ``'generate_noisy'`` -- call ``generate_noisy_signal`` and exit the
      process.
    * ``'train'``          -- build the ``RNN`` and call ``train_model``.
    * ``'test'``           -- build the ``RNN`` and call ``test_model``.
    * ``'test_song'``      -- build the ``RNN`` and call
      ``estimate_test_song``.

    Any other mode value still builds and initialises the network but then
    returns without doing further work.

    Args:
        argv: Accepted for compatibility with the caller; never read here.

    Raises:
        SystemExit: After the ``'preprocess'`` and ``'generate_noisy'``
            modes complete.
    """
    # Local import keeps this block self-contained. ``sys.exit`` replaces
    # the bare ``exit()`` helper, which is injected by the ``site`` module
    # for interactive use and is not guaranteed to exist (e.g. ``python -S``
    # or frozen interpreters).
    import sys

    # Data-preparation modes do not need a model; finish before building one.
    if FLAGS.mode == 'preprocess':
        prepare_dataset(
            FLAGS.data_path,
            FLAGS.subset,
            FLAGS.split_valid,
            FLAGS.slice_duration,
            FLAGS.save_as,
        )
        sys.exit()

    if FLAGS.mode == 'generate_noisy':
        generate_noisy_signal(FLAGS.song_path)
        sys.exit()

    # The train/valid/test splits are sub-directories of FLAGS.data_path.
    model = RNN(
        train_dir=FLAGS.data_path + '/train',
        val_dir=FLAGS.data_path + '/valid',
        test_dir=FLAGS.data_path + '/test',
        train_batch_size=FLAGS.train_batch_size,
        valid_batch_size=FLAGS.valid_batch_size,
        test_batch_size=FLAGS.test_batch_size,
        n_inputs=FLAGS.n_inputs,
        seq_length=FLAGS.seq_length,
        num_epochs=FLAGS.num_epochs,
        learning_rate=FLAGS.learning_rate,
        base_dir=FLAGS.base_dir,
        max_to_keep=FLAGS.max_to_keep,
        model_name=FLAGS.model_name,
    )
    model.create_network()
    model.initialize_network()

    if FLAGS.mode == 'train':
        model.train_model(
            FLAGS.display_step,
            FLAGS.validation_step,
            FLAGS.checkpoint_step,
            FLAGS.summary_step,
        )
    elif FLAGS.mode == 'test':
        model.test_model()
    elif FLAGS.mode == 'test_song':
        model.estimate_test_song(
            FLAGS.noisy_song_path,
            FLAGS.output_estimated_path,
        )