示例#1
0
    def testRunTraining(self):
        """Smoke-test ``training.run_training`` on MNIST with sequential data.

        Runs a short job (1000 steps, one report at step 1000) with dynamic
        expansion enabled and ``gen_replay_type='fixed'``. No assertion is
        made on the result: the test passes as long as the call completes
        without raising.
        """
        encoder_config = {
            'encoder_type': 'multi',
            'n_enc': [1200, 600, 300, 150],
            'enc_strides': [1],
        }
        decoder_config = {
            'decoder_type': 'single',
            'n_dec': [500, 500],
            'dec_up_strides': None,
        }

        training.run_training(
            # Data selection and ordering.
            dataset='mnist',
            training_data_type='sequential',
            n_concurrent_classes=1,
            blend_classes=False,
            output_type='bernoulli',
            # Model sizes and architecture.
            n_y=10,
            n_y_active=1,
            n_z=32,
            encoder_kwargs=encoder_config,
            decoder_kwargs=decoder_config,
            # Learning-rate settings.
            lr_init=1e-3,
            lr_factor=1.,
            lr_schedule=[1],
            # Run length, reporting and reproducibility.
            n_steps=1000,
            report_interval=1000,
            knn_values=[3],
            random_seed=1,
            # Remaining flags, passed through unchanged.
            train_supervised=False,
            dynamic_expansion=True,
            ll_thresh=-200.0,
            classify_with_samples=False,
            gen_replay_type='fixed',
            use_supervised_replay=False,
        )
示例#2
0
def main(unused_argv):
    """Run i.i.d. unsupervised training on the dataset named by FLAGS.dataset.

    Sizing depends only on ``FLAGS.dataset``:
      * ``'mnist'``: ``n_y = n_y_active = 25`` and ``n_z = 50``.
      * any other value (the original code labels this branch omniglot):
        ``n_y = n_y_active = n_z = 100``.
    The dataset name itself is not validated here; it is passed straight to
    ``training.run_training``.

    Args:
      unused_argv: Command-line arguments; ignored. NOTE(review): the
        signature suggests this is launched via absl ``app.run`` — confirm.
    """
    use_mnist_sizes = FLAGS.dataset == 'mnist'
    n_y = 25 if use_mnist_sizes else 100
    n_y_active = n_y  # Equal to n_y for both dataset branches.
    n_z = 50 if use_mnist_sizes else 100

    encoder_config = {
        'encoder_type': 'multi',
        'n_enc': [500, 500],
        'enc_strides': [1],
    }
    decoder_config = {
        'decoder_type': 'single',
        'n_dec': [500],
        'dec_up_strides': None,
    }

    training.run_training(
        # Data selection and ordering.
        dataset=FLAGS.dataset,
        training_data_type='iid',
        n_concurrent_classes=1,
        blend_classes=False,
        output_type='bernoulli',
        # Model sizes and architecture.
        n_y=n_y,
        n_y_active=n_y_active,
        n_z=n_z,
        encoder_kwargs=encoder_config,
        decoder_kwargs=decoder_config,
        # Learning-rate settings.
        lr_init=5e-4,
        lr_factor=1.,
        lr_schedule=[1],
        # Run length, reporting and reproducibility.
        n_steps=100000,
        report_interval=10000,
        knn_values=[3, 5, 10],
        random_seed=1,
        # Remaining flags, passed through unchanged.
        train_supervised=False,
        dynamic_expansion=False,
        ll_thresh=-0.0,
        classify_with_samples=True,
        gen_replay_type=None,
        use_supervised_replay=False,
    )