def main(argv=None, train=True):
    """Build an RNN with a fixed, hard-coded configuration, then train or test it.

    Unlike a flag-driven entry point, every setting here is a literal:
    image directories ``data/train/``, ``data/val/`` and ``data/test/``,
    ``results`` as the base directory, 40 epochs, batch size 1000, and
    28x28x1 inputs with 10 classes.

    Args:
        argv: Unused; presumably supplied by whatever launcher calls
            ``main`` -- TODO confirm.
        train: If true (the default, which matches the previous hard-coded
            ``if True:``), call ``train_model``; otherwise call
            ``test_model``. Before this parameter existed, the test branch
            was unreachable.
    """
    model = RNN(train_images_dir='data/train/',
                val_images_dir='data/val/',
                test_images_dir='data/test/',
                num_epochs=40,
                train_batch_size=1000,
                val_batch_size=1000,
                test_batch_size=1000,
                height_of_image=28,
                width_of_image=28,
                num_channels=1,
                num_classes=10,
                learning_rate=0.001,
                base_dir='results',
                max_to_keep=2,
                model_name="RNN",
                model='RNN')

    model.create_network()
    model.initialize_network()

    if train:
        # NOTE(review): the positional values are presumably the display,
        # validation, checkpoint and summary step intervals -- confirm
        # against RNN.train_model's signature.
        model.train_model(1, 1, 1, 2)
    else:
        model.test_model()
# Example #2
def main(argv=None):
    """Construct the RNN from command-line FLAGS, then either train or test it.

    Each RNN constructor keyword argument takes its value from the ``FLAGS``
    attribute of the same name. ``FLAGS.train`` picks the action: a truthy
    value trains the model using the four step settings from ``FLAGS``, and
    a falsy value runs ``test_model``.

    Args:
        argv: Unused here; presumably supplied by the launcher that invokes
            ``main`` -- TODO confirm.
    """
    # Names of the RNN keyword arguments; each one is looked up on FLAGS.
    ctor_options = (
        'train_images_dir', 'val_images_dir', 'test_images_dir',
        'num_epochs',
        'train_batch_size', 'val_batch_size', 'test_batch_size',
        'height_of_image', 'width_of_image', 'num_channels',
        'num_classes', 'learning_rate',
        'base_dir', 'max_to_keep', 'model_name', 'keep_prob',
    )
    net = RNN(**{opt: getattr(FLAGS, opt) for opt in ctor_options})

    net.create_network()
    net.initialize_network()

    if not FLAGS.train:
        net.test_model()
        return
    net.train_model(FLAGS.display_step, FLAGS.validation_step,
                    FLAGS.checkpoint_step, FLAGS.summary_step)
# Example #3
def main(argv=None):
    """Dispatch on ``FLAGS.mode``: preprocess data, make a noisy signal, or run the RNN.

    Modes:
      * ``'preprocess'``     -- call ``prepare_dataset`` with FLAGS values, then exit.
      * ``'generate_noisy'`` -- call ``generate_noisy_signal(FLAGS.song_path)``, then exit.
      * ``'train'``          -- build the RNN and call ``train_model``.
      * ``'test'``           -- build the RNN and call ``test_model``.
      * ``'test_song'``      -- build the RNN and call ``estimate_test_song``.

    The RNN receives ``FLAGS.data_path`` joined with ``/train``, ``/valid``
    and ``/test`` as its data directories.

    Args:
        argv: Unused; presumably supplied by the launcher that invokes
            ``main`` -- TODO confirm.

    Raises:
        SystemExit: after the ``'preprocess'`` and ``'generate_noisy'`` modes.
        ValueError: if ``FLAGS.mode`` is not one of the modes listed above.
            The check runs before the network is built, so nothing is
            constructed for an invalid mode.
    """
    import sys

    mode = FLAGS.mode

    if mode == 'preprocess':
        prepare_dataset(FLAGS.data_path, FLAGS.subset, FLAGS.split_valid,
                        FLAGS.slice_duration, FLAGS.save_as)
        # The builtin `exit()` comes from the `site` module and is meant for
        # the interactive shell (it is absent under `python -S`); sys.exit()
        # raises the same SystemExit in a program-safe way.
        sys.exit()

    if mode == 'generate_noisy':
        generate_noisy_signal(FLAGS.song_path)
        sys.exit()

    # Previously an unrecognized mode built and initialized the whole network
    # and then silently did nothing; fail fast instead.
    model_modes = ('train', 'test', 'test_song')
    if mode not in model_modes:
        raise ValueError(
            "Unknown FLAGS.mode %r; expected one of %s"
            % (mode, ('preprocess', 'generate_noisy') + model_modes))

    model = RNN(
        train_dir=FLAGS.data_path + '/train',
        val_dir=FLAGS.data_path + '/valid',
        test_dir=FLAGS.data_path + '/test',
        train_batch_size=FLAGS.train_batch_size,
        valid_batch_size=FLAGS.valid_batch_size,
        test_batch_size=FLAGS.test_batch_size,
        n_inputs=FLAGS.n_inputs,
        seq_length=FLAGS.seq_length,
        num_epochs=FLAGS.num_epochs,
        learning_rate=FLAGS.learning_rate,
        base_dir=FLAGS.base_dir,
        max_to_keep=FLAGS.max_to_keep,
        model_name=FLAGS.model_name
    )

    model.create_network()
    model.initialize_network()

    if mode == 'train':
        model.train_model(FLAGS.display_step, FLAGS.validation_step,
                          FLAGS.checkpoint_step, FLAGS.summary_step)
    elif mode == 'test':
        model.test_model()
    else:  # 'test_song' -- guaranteed by the validation above
        model.estimate_test_song(FLAGS.noisy_song_path,
                                 FLAGS.output_estimated_path)