def testRunTraining(self):
  """Smoke-tests `training.run_training` with a short MNIST configuration.

  The run uses 'sequential' training data with one concurrent class and
  Bernoulli outputs for 1000 steps. It enables dynamic expansion and uses
  'fixed' generative replay. The test passes as long as the call completes
  without raising; no return value is inspected.
  """
  # Encoder / decoder architecture settings, forwarded verbatim as
  # `encoder_kwargs` / `decoder_kwargs`.
  encoder_config = {
      'encoder_type': 'multi',
      'n_enc': [1200, 600, 300, 150],
      'enc_strides': [1],
  }
  decoder_config = {
      'decoder_type': 'single',
      'n_dec': [500, 500],
      'dec_up_strides': None,
  }

  # Data / schedule settings.
  data_config = dict(
      dataset='mnist',
      output_type='bernoulli',
      training_data_type='sequential',
      n_concurrent_classes=1,
      blend_classes=False,
      train_supervised=False,
  )
  optim_config = dict(
      lr_init=1e-3,
      lr_factor=1.,
      lr_schedule=[1],
      n_steps=1000,
      report_interval=1000,
      random_seed=1,
  )
  # Latent sizes, expansion / replay behaviour and evaluation settings.
  model_config = dict(
      n_y=10,
      n_y_active=1,
      n_z=32,
      dynamic_expansion=True,
      ll_thresh=-200.0,
      classify_with_samples=False,
      gen_replay_type='fixed',
      use_supervised_replay=False,
      knn_values=[3],
  )

  training.run_training(
      encoder_kwargs=encoder_config,
      decoder_kwargs=decoder_config,
      **data_config,
      **optim_config,
      **model_config,
  )
def main(unused_argv):
  """Trains on `FLAGS.dataset` with i.i.d. data and Bernoulli outputs.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Latent sizes (`n_y`, `n_y_active`, `n_z`) depend on the dataset. They are
  25/25/50 for 'mnist' and 100/100/100 for any other value, which the
  original code annotates as omniglot. The run has dynamic expansion
  disabled and no generative replay (`gen_replay_type=None`).
  """
  # NOTE(review): every non-'mnist' dataset falls through to the omniglot
  # sizes; confirm FLAGS.dataset is restricted to these two values.
  is_mnist = FLAGS.dataset == 'mnist'
  n_y, n_y_active, n_z = (25, 25, 50) if is_mnist else (100, 100, 100)

  encoder_config = {
      'encoder_type': 'multi',
      'n_enc': [500, 500],
      'enc_strides': [1],
  }
  decoder_config = {
      'decoder_type': 'single',
      'n_dec': [500],
      'dec_up_strides': None,
  }

  training.run_training(
      dataset=FLAGS.dataset,
      output_type='bernoulli',
      training_data_type='iid',
      n_concurrent_classes=1,
      blend_classes=False,
      train_supervised=False,
      n_y=n_y,
      n_y_active=n_y_active,
      n_z=n_z,
      encoder_kwargs=encoder_config,
      decoder_kwargs=decoder_config,
      lr_init=5e-4,
      lr_factor=1.,
      lr_schedule=[1],
      n_steps=100000,
      report_interval=10000,
      random_seed=1,
      knn_values=[3, 5, 10],
      dynamic_expansion=False,
      ll_thresh=-0.0,  # Kept exactly as originally written (equals 0.0).
      classify_with_samples=True,
      gen_replay_type=None,
      use_supervised_replay=False,
  )