import os

from blocks.algorithms import Adam
from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing
from blocks.extensions.monitoring import DataStreamMonitoring
from blocks.extensions.saveload import Checkpoint
from blocks.main_loop import MainLoop

# create_models, ali_algorithm, create_gaussian_mixture_data_streams,
# GraphLogger, MetricLogger and the upper-case hyperparameters
# (LEARNING_RATE, BETA1, BATCH_SIZE, ...) are assumed to be defined or
# imported elsewhere in the project.


def create_main_loop(work_dir):
    # ALI variant of the main-loop factory; the original referenced
    # self._work_dir inside a plain function, so the working directory is
    # taken as a parameter instead.
    model, bn_model, bn_updates = create_models()
    ali, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs

    # One Adam step rule is shared by the discriminator and generator.
    step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(discriminator_loss,
                              ali.discriminator_parameters,
                              step_rule,
                              generator_loss,
                              ali.generator_parameters,
                              step_rule)
    algorithm.add_updates(bn_updates)

    streams = create_gaussian_mixture_data_streams(
        batch_size=BATCH_SIZE,
        monitoring_batch_size=MONITORING_BATCH_SIZE,
        means=MEANS, variances=VARIANCES, priors=PRIORS)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

    # Monitor everything except the parameter-norm auxiliary variables.
    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=NUM_EPOCHS),
        DataStreamMonitoring(bn_monitored_variables, train_monitor_stream,
                             prefix="train", updates=bn_updates),
        DataStreamMonitoring(monitored_variables, valid_monitor_stream,
                             prefix="valid"),
        Checkpoint(os.path.join(work_dir, "main_loop.tar"),
                   after_epoch=True, after_training=True, use_cpickle=True),
        ProgressBar(),
        Printing(),
        # ModelLogger(folder=work_dir, after_epoch=True),
        GraphLogger(num_modes=1, num_samples=2500, dimension=2, r=0, std=1,
                    folder=work_dir, after_epoch=True, after_training=True),
        MetricLogger(means=MEANS, variances=VARIANCES, folder=work_dir,
                     after_epoch=True),
    ]
    main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    return main_loop
def create_main_loop(save_path):
    # GAN variant of the main-loop factory; identical to the ALI variant
    # above except that the top brick is a GAN and the graph/metric
    # loggers are omitted.
    model, bn_model, bn_updates = create_models()
    gan, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs

    step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(discriminator_loss,
                              gan.discriminator_parameters,
                              step_rule,
                              generator_loss,
                              gan.generator_parameters,
                              step_rule)
    algorithm.add_updates(bn_updates)

    streams = create_gaussian_mixture_data_streams(
        batch_size=BATCH_SIZE,
        monitoring_batch_size=MONITORING_BATCH_SIZE,
        means=MEANS, variances=VARIANCES, priors=PRIORS)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=NUM_EPOCHS),
        DataStreamMonitoring(bn_monitored_variables, train_monitor_stream,
                             prefix="train", updates=bn_updates),
        DataStreamMonitoring(monitored_variables, valid_monitor_stream,
                             prefix="valid"),
        Checkpoint(save_path, after_epoch=True, after_training=True,
                   use_cpickle=True),
        ProgressBar(),
        Printing(),
    ]
    main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    return main_loop
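# Minimal usage sketch (not in the original source): build the main loop and
# train. The checkpoint filename here is a placeholder.
if __name__ == "__main__":
    main_loop = create_main_loop(save_path="gan_main_loop.tar")
    # MainLoop.run() iterates the training stream, applies the adversarial
    # updates, and fires the monitoring/checkpoint extensions each epoch.
    main_loop.run()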
def get_data(main_loop, n_points=1000):
    # Draw a fresh batch from the same Gaussian mixture the model was
    # trained on, reusing the mixture parameters stored on the dataset.
    means = main_loop.data_stream.dataset.means
    variances = main_loop.data_stream.dataset.variances
    priors = main_loop.data_stream.dataset.priors
    _, _, stream = create_gaussian_mixture_data_streams(
        n_points, n_points, sources=('features', 'label'),
        means=means, variances=variances, priors=priors)
    originals, labels = next(stream.get_epoch_iterator())
    return {'originals': originals, 'labels': labels}
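# Hedged example (not in the original source): combine get_data() with
# matplotlib to eyeball the target mixture. matplotlib and the helper name
# plot_mixture_samples are assumptions of this sketch.
import matplotlib.pyplot as plt


def plot_mixture_samples(main_loop, n_points=1000):
    # Hypothetical helper: scatter the 2-D mixture samples returned by
    # get_data(), coloured by mixture component.
    data = get_data(main_loop, n_points=n_points)
    originals, labels = data['originals'], data['labels']
    plt.scatter(originals[:, 0], originals[:, 1], c=labels.ravel(), s=5)
    plt.title('Gaussian mixture samples')
    plt.savefig('mixture_samples.png')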