def run(discriminative_regularization=True):
    """Build the CelebA VAE training main loop and run it.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Forwarded to ``create_training_computation_graphs``. It also
        selects the checkpoint file name:
        ``celeba_vae_regularization.zip`` when truthy,
        ``celeba_vae_no_regularization.zip`` otherwise.

    """
    trainable_brick_names = ('encoder_convnet', 'encoder_mlp',
                             'decoder_convnet', 'decoder_mlp')

    def _is_trainable(brick):
        # Predicate handed to find_bricks: keep only the named sub-networks.
        return brick.name in trainable_brick_names

    def _monitored_quantities(graph):
        # Name the three outputs of `graph` and return the quantities that
        # are monitored for it (cost, mean KL, negated mean reconstruction).
        nll, kl, reconstruction = graph.outputs
        nll.name = 'nll_upper_bound'
        mean_kl = kl.mean(axis=0)
        mean_kl.name = 'avg_kl_term'
        mean_reconstruction = -reconstruction.mean(axis=0)
        mean_reconstruction.name = 'avg_reconstruction_term'
        return [nll, mean_kl, mean_reconstruction]

    all_streams = create_celeba_streams(training_batch_size=100,
                                        monitoring_batch_size=500,
                                        include_targets=False)
    training_stream, train_eval_stream, valid_eval_stream = all_streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    graphs = create_training_computation_graphs(discriminative_regularization)
    cg, bn_cg, variance_parameters = graphs
    unique_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = []
    for p, m in unique_updates:
        extra_updates.append((p, m * decay_rate + p * (1 - decay_rate)))

    model = Model(bn_cg.outputs[0])
    selector = Selector(find_bricks(model.top_bricks, _is_trainable))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=Adam())
    algorithm.add_updates(extra_updates)

    # Prepare monitoring: the batch-normalized graph is monitored on the
    # training monitor stream, the other graph on the validation stream.
    train_quantities = _monitored_quantities(bn_cg)
    valid_quantities = _monitored_quantities(cg)
    train_monitoring = DataStreamMonitoring(
        train_quantities, train_eval_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=False,
        every_n_epochs=5)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_eval_stream, prefix="valid",
        after_epoch=False, before_first_epoch=False, every_n_epochs=5)

    # Prepare checkpoint
    if discriminative_regularization:
        tag = ''
    else:
        tag = 'no_'
    save_path = 'celeba_vae_{}regularization.zip'.format(tag)
    checkpoint = Checkpoint(save_path, every_n_epochs=5, use_cpickle=True)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=75),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    main_loop = MainLoop(data_stream=training_stream, algorithm=algorithm,
                         extensions=extensions)
    main_loop.run()
def test_find_first_level(self):
    """Searching from the MLP for ``Sequence`` bricks yields five bricks.

    The result must contain the MLP itself; the remaining four must be
    activations 0, 1 and 3 plus the first child of activation 3.
    """
    matches = set(
        find_bricks([self.mlp], lambda brick: isinstance(brick, Sequence)))
    assert len(matches) == 5
    assert self.mlp in matches
    matches.remove(self.mlp)
    expected = set(self.mlp.activations[0:2] +
                   [self.mlp.activations[3],
                    self.mlp.activations[3].children[0]])
    assert expected == matches
def test_find_first_level(self):
    """``find_bricks`` returns the MLP and four other ``Sequence`` bricks."""
    def _is_sequence(candidate):
        return isinstance(candidate, Sequence)

    located = set(find_bricks([self.mlp], _is_sequence))
    assert len(located) == 5
    assert self.mlp in located

    # Drop the root so that only the other Sequence bricks remain.
    located.remove(self.mlp)
    wanted = list(self.mlp.activations[0:2])
    wanted.append(self.mlp.activations[3])
    wanted.append(self.mlp.activations[3].children[0])
    assert set(wanted) == located
def training_noise(*bricks):
    r"""Context manager to run noise layers in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which are searched with :func:`find_bricks`
        for instances of :class:`NoiseLayer`.

    Notes
    -----
    Every matching brick is entered (``__enter__``) before the body of
    the ``with`` statement runs and exited (``__exit__``) afterwards, in
    reverse order. Only bricks that were successfully entered are
    exited, so a failure while entering does not trigger an unmatched
    ``__exit__`` on the remaining bricks.

    This function is a generator; it presumably relies on a
    ``contextmanager`` decorator applied where it is defined (not
    visible here -- confirm).

    """
    noise_bricks = find_bricks(bricks, lambda b: isinstance(b, NoiseLayer))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the enter/exit bookkeeping is done by hand.
    entered = []
    try:
        for brick in noise_bricks:
            brick.__enter__()
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
def test_find_zeroth_level_repeated(self):
    """Passing the same root brick twice yields that brick exactly once."""
    roots = [self.mlp, self.mlp]
    matches = find_bricks(roots, lambda brick: isinstance(brick, MLP))
    assert len(matches) == 1
    assert matches[0] == self.mlp
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Bricks are exited in the reverse of the order in which they were
    entered, and only bricks that were successfully entered are exited.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the enter/exit bookkeeping is done by hand.
    entered = []
    try:
        for brick in bn:
            brick.__enter__()
            # Record the brick only once it is really entered, so that the
            # cleanup below never exits a brick that was not entered.
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
def run(discriminative_regularization=True):
    """Assemble and run the CelebA VAE training loop.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Passed on to ``create_training_computation_graphs`` and used to
        pick the checkpoint name (``celeba_vae_regularization.zip`` when
        truthy, ``celeba_vae_no_regularization.zip`` otherwise).

    """
    def _summaries(graph):
        # Returns [cost, mean KL term, negated mean reconstruction term]
        # for one computation graph, naming each variable for monitoring.
        cost, kl_term, reconstruction_term = graph.outputs
        cost.name = 'nll_upper_bound'
        kl_mean = kl_term.mean(axis=0)
        kl_mean.name = 'avg_kl_term'
        reconstruction_mean = -reconstruction_term.mean(axis=0)
        reconstruction_mean.name = 'avg_reconstruction_term'
        return [cost, kl_mean, reconstruction_mean]

    streams = create_celeba_streams(
        training_batch_size=100,
        monitoring_batch_size=500,
        include_targets=False)
    main_stream, train_monitor, valid_monitor = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    cg, bn_cg, variance_parameters = create_training_computation_graphs(
        discriminative_regularization)
    bn_updates = get_batch_normalization_updates(bn_cg, allow_duplicates=True)
    pop_updates = list(set(bn_updates))
    decay_rate = 0.05
    extra_updates = [
        (p, m * decay_rate + p * (1 - decay_rate)) for p, m in pop_updates
    ]

    training_cost = bn_cg.outputs[0]
    model = Model(training_cost)
    selected_names = ('encoder_convnet', 'encoder_mlp',
                      'decoder_convnet', 'decoder_mlp')
    selected_bricks = find_bricks(
        model.top_bricks, lambda brick: brick.name in selected_names)
    selector = Selector(selected_bricks)
    parameters = list(selector.get_parameters().values())
    parameters = parameters + variance_parameters

    # Prepare algorithm
    algorithm = GradientDescent(
        cost=bn_cg.outputs[0], parameters=parameters, step_rule=Adam())
    algorithm.add_updates(extra_updates)

    # Prepare monitoring
    train_quantities, valid_quantities = [
        _summaries(graph) for graph in (bn_cg, cg)
    ]
    monitoring_schedule = dict(after_epoch=False, before_first_epoch=False,
                               every_n_epochs=5)
    train_monitoring = DataStreamMonitoring(
        train_quantities, train_monitor, prefix="train",
        updates=extra_updates, **monitoring_schedule)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_monitor, prefix="valid",
        **monitoring_schedule)

    # Prepare checkpoint
    prefix = '' if discriminative_regularization else 'no_'
    checkpoint = Checkpoint(
        'celeba_vae_{}regularization.zip'.format(prefix),
        every_n_epochs=5, use_cpickle=True)

    extensions = [Timing(), FinishAfter(after_n_epochs=75)]
    extensions += [train_monitoring, valid_monitoring, checkpoint]
    extensions += [Printing(), ProgressBar()]
    main_loop = MainLoop(data_stream=main_stream, algorithm=algorithm,
                         extensions=extensions)
    main_loop.run()
def test_find_first_and_second_and_third_level(self):
    """``Logistic`` bricks are found at three different nesting depths.

    The search starts at the MLP and must return activation 2 itself,
    the first child of activation 1, and the first child of the first
    child of activation 3.
    """
    found = set(find_bricks([self.mlp], lambda x: isinstance(x, Logistic)))
    assert self.mlp.activations[2] in found
    assert self.mlp.activations[1].children[0] in found
    # Membership must be checked here as well; asserting the brick alone
    # would only test its truthiness and could never fail meaningfully.
    assert self.mlp.activations[3].children[0].children[0] in found
def test_find_second_and_third_level(self):
    """Exactly two ``Identity`` bricks are found below the MLP.

    They are the first child of activation 0 and the second child of
    activation 1.
    """
    matches = set(
        find_bricks([self.mlp], lambda brick: isinstance(brick, Identity)))
    assert len(matches) == 2
    first_identity = self.mlp.activations[0].children[0]
    assert first_identity in matches
    second_identity = self.mlp.activations[1].children[1]
    assert second_identity in matches
def run(batch_size, save_path, z_dim, oldmodel, discriminative_regularization,
        classifier, vintage, monitor_every, monitor_before, checkpoint_every,
        dataset, color_convert, image_size, net_depth, subdir,
        reconstruction_factor, kl_factor, discriminative_factor, disc_weights,
        num_epochs):
    """Build the training main loop and run it.

    Parameters
    ----------
    batch_size
        Used as both the training and the monitoring batch size.
    save_path
        Path handed to the ``Checkpoint`` extension.
    z_dim, image_size, net_depth, discriminative_regularization, \
classifier, vintage, reconstruction_factor, kl_factor, \
discriminative_factor, disc_weights
        Forwarded unchanged to ``create_training_computation_graphs``.
        ``z_dim`` (halved) and ``image_size`` are also given to
        ``SampleCheckpoint``.
    oldmodel
        If not ``None``, a saved main loop is loaded from this path and
        its model parameter values are copied into the new model before
        training starts.
    monitor_every, monitor_before
        ``every_n_epochs`` and ``before_first_epoch`` of both monitoring
        extensions.
    checkpoint_every
        ``every_n_epochs`` of the ``Checkpoint`` extension.
    dataset
        If truthy, streams come from ``create_custom_streams`` with this
        file name; otherwise from ``create_celeba_streams``.
    color_convert
        Forwarded to ``create_custom_streams`` (custom datasets only).
    subdir
        ``save_subdir`` of the ``SampleCheckpoint`` extension.
    num_epochs
        Training stops after this many epochs.

    Notes
    -----
    Raises the interpreter recursion limit via ``sys.setrecursionlimit``
    as a process-wide side effect.

    """
    if dataset:
        streams = create_custom_streams(filename=dataset,
                                        training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False,
                                        color_convert=color_convert)
    else:
        streams = create_celeba_streams(training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    rval = create_training_computation_graphs(
        z_dim, image_size, net_depth, discriminative_regularization,
        classifier, vintage, reconstruction_factor, kl_factor,
        discriminative_factor, disc_weights)
    cg, bn_cg, variance_parameters = rval
    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = [(p, m * decay_rate + p * (1 - decay_rate))
                     for p, m in pop_updates]

    model = Model(bn_cg.outputs[0])
    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in ('encoder_convnet', 'encoder_mlp',
                                         'decoder_convnet', 'decoder_mlp')))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # Prepare monitoring
    sys.setrecursionlimit(1000000)

    monitored_quantities_list = []
    for graph in [bn_cg, cg]:
        # The first four outputs are fixed terms; any outputs after them
        # are treated as per-layer discriminative terms.
        cost, kl_term, reconstruction_term, discriminative_term = \
            graph.outputs[:4]
        discriminative_layer_terms = graph.outputs[4:]

        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        avg_discriminative_term = discriminative_term.mean(axis=0)
        avg_discriminative_term.name = 'avg_discriminative_term'

        avg_discriminative_layer_terms = []
        for i, term in enumerate(discriminative_layer_terms):
            avg_term = term.mean(axis=0)
            avg_term.name = "avg_discriminative_term_layer_{:02d}".format(i)
            avg_discriminative_layer_terms.append(avg_term)

        monitored_quantities_list.append(
            [cost, avg_kl_term, avg_reconstruction_term,
             avg_discriminative_term] + avg_discriminative_layer_terms)

    train_monitoring = DataStreamMonitoring(
        monitored_quantities_list[0], train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False,
        before_first_epoch=monitor_before, every_n_epochs=monitor_every)
    valid_monitoring = DataStreamMonitoring(
        monitored_quantities_list[1], valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=monitor_before,
        every_n_epochs=monitor_every)

    # Prepare checkpoint
    checkpoint = Checkpoint(save_path, every_n_epochs=checkpoint_every,
                            before_training=True, use_cpickle=True)

    # Floor division keeps the halved dimension an integer under Python 3
    # as well (true division would hand SampleCheckpoint a float).
    # TODO: why does z_dim=foo become foo/2?
    sample_checkpoint = SampleCheckpoint(
        interface=DiscGenModel, z_dim=z_dim // 2,
        image_size=(image_size, image_size), channels=3,
        dataset=dataset, split="valid", save_subdir=subdir,
        before_training=True, after_epoch=True)

    extensions = [Timing(),
                  FinishAfter(after_n_epochs=num_epochs),
                  checkpoint,
                  sample_checkpoint,
                  train_monitoring, valid_monitoring,
                  Printing(),
                  ProgressBar()]
    main_loop = MainLoop(model=model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)

    if oldmodel is not None:
        print("Initializing parameters with old model {}".format(oldmodel))
        try:
            saved_model = load(oldmodel)
        except AttributeError:
            # newer version of blocks expects an open file, not a path
            with open(oldmodel, 'rb') as src:
                saved_model = load(src)
        main_loop.model.set_parameter_values(
            saved_model.model.get_parameter_values())
        # Drop the loaded main loop before training starts.
        del saved_model

    main_loop.run()
def test_find_none(self):
    """A predicate that rejects every brick produces an empty result."""
    def _reject(_):
        return False

    matches = find_bricks([self.mlp], _reject)
    assert len(matches) == 0
def test_find_all_unique(self):
    """Every brick is reported once even when the roots overlap.

    The roots are the MLP twice plus all of its direct children, and the
    predicate accepts everything.
    """
    roots = [self.mlp, self.mlp]
    roots = roots + list(self.mlp.children)
    matches = find_bricks(roots, lambda _: True)
    # 12 activations plus 4 linear transformations
    assert len(matches) == 16
def train_extractive_qa(new_training_job, config, save_path, params,
                        fast_start, fuel_server, seed):
    """Build the extractive QA training main loop and run it.

    Parameters
    ----------
    new_training_job
        When falsy, previously saved training state is loaded from
        ``<save_path>/training_state.tar`` before training.
    config
        Mapping of configuration values, read by key throughout.
    save_path
        Directory used for the training state archives, the pickled
        stream, dumped predictions and summaries.
    params
        If truthy, path of a parameter archive whose values are loaded
        into the model before training.
    fast_start
        When truthy, the before-training validation, prediction dump,
        checkpoint and retrieval statistics are skipped, and so is the
        after-training checkpoint.
    fuel_server
        When truthy, training data is read through a ``ServerDataStream``
        and the ``StartFuelServer`` extension runs before training.
    seed
        If truthy, used as the default random seed of Fuel and Blocks.

    Raises
    ------
    ValueError
        If dropout is enabled and ``dropout_type`` or
        ``emb_dropout_type`` is neither ``'same_mask'`` nor
        ``'regular'``.

    """
    if seed:
        logger.debug("Changing Fuel random seed to {}".format(seed))
        fuel.config.default_seed = seed
        logger.debug("Changing Blocks random seed to {}".format(seed))
        blocks.config.config.default_seed = seed

    root_path = os.path.join(save_path, 'training_state')
    extension = '.tar'
    tar_path = root_path + extension
    best_tar_path = root_path + '_best' + extension

    c = config
    data, qam = initialize_data_and_model(c)

    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', shuffle=True, batch_size=4,
                            max_length=5).get_epoch_iterator(as_dict=True))
        for var in qam.input_vars.values():
            var.tag.test_value = test_value_data[var.name]

    costs = qam.apply_with_default_vars()
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        # The parameter archive is binary data; text mode can corrupt it
        # or fail to decode it.
        with open(params, 'rb') as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(qam.contexts.shape[1], 'length')
    batch_size = rename(qam.contexts.shape[0], 'batch_size')
    predicted_begins, = VariableFilter(name='predicted_begins')(cg)
    predicted_ends, = VariableFilter(name='predicted_ends')(cg)
    exact_match, = VariableFilter(name='exact_match')(cg)
    exact_match_ratio = rename(exact_match.mean(), 'exact_match_ratio')
    context_unk_ratio, = VariableFilter(name='context_unk_ratio')(cg)
    monitored_vars = [
        length, batch_size, cost, exact_match_ratio, context_unk_ratio
    ]
    if c['dict_path']:
        def_unk_ratio, = VariableFilter(name='def_unk_ratio')(cg)
        num_definitions = rename(qam.input_vars['defs'].shape[0],
                                 'num_definitions')
        max_definition_length = rename(qam.input_vars['defs'].shape[1],
                                       'max_definition_length')
        monitored_vars.extend(
            [def_unk_ratio, num_definitions, max_definition_length])
        if c['def_word_gating'] == 'self_attention':
            def_gates = VariableFilter(name='def_gates')(cg)
            def_gates_min = tensor.minimum(*[x.min() for x in def_gates])
            def_gates_max = tensor.maximum(*[x.max() for x in def_gates])
            monitored_vars.extend([
                rename(def_gates_min, 'def_gates_min'),
                rename(def_gates_max, 'def_gates_max')
            ])
    text_match_ratio = TextMatchRatio(
        data_path=os.path.join(fuel.config.data_path[0],
                               'squad/dev-v1.1.json'),
        requires=[
            predicted_begins, predicted_ends,
            tensor.ltensor3('contexts_text'),
            tensor.lmatrix('q_ids')
        ],
        name='text_match_ratio')

    parameters = cg.get_parameter_dict()
    trained_parameters = parameters.values()
    if c['embedding_path']:
        logger.debug("Exclude word embeddings from the trained parameters")
        trained_parameters = [
            p for p in trained_parameters if not p == qam.embeddings_var()
        ]
    if c['train_only_def_part']:
        def_reading_parameters = qam.def_reading_parameters()
        trained_parameters = [
            p for p in trained_parameters if p in def_reading_parameters
        ]

    # Log the initialization scheme (every ``*_init`` attribute) of each
    # Initializable brick found under the model.
    bricks = find_bricks([qam], lambda brick: isinstance(brick, Initializable))
    init_schemes = {}
    for brick in bricks:
        brick_name = "/".join([b.name for b in brick.get_unique_path()])
        init_schemes[brick_name] = {}
        for arg, value in brick.__dict__.items():
            if arg.endswith('_init'):
                init_schemes[brick_name][arg] = repr(value)
    logger.info("Initialization schemes:")
    logger.info(json.dumps(init_schemes, indent=2))

    # One line per parameter: name, shape, trained/frozen, and
    # sqrt(6 / sum(shape)). The float numerator keeps the ratio from being
    # truncated by Python 2 integer division.
    logger.info("Cost parameters" + "\n" + pprint.pformat([
        " ".join(
            (key, str(parameters[key].get_value().shape),
             'trained' if parameters[key] in trained_parameters else 'frozen',
             str((6.0 / sum(parameters[key].get_value().shape))**0.5)))
        for key in sorted(parameters.keys())
    ], width=120))

    # apply dropout to the training cost and to all the variables
    # that we monitor during training
    train_cost = cost
    train_monitored_vars = list(monitored_vars)
    if c['dropout']:
        regularized_cg = ComputationGraph([cost] + train_monitored_vars)
        # TODO: don't access private attributes
        bidir_outputs, = VariableFilter(bricks=[qam._bidir],
                                        roles=[OUTPUT])(cg)
        readout_layers = VariableFilter(bricks=[Rectifier],
                                        roles=[OUTPUT])(cg)
        dropout_vars = [bidir_outputs] + readout_layers
        logger.debug("applying dropout to {}".format(", ".join(
            [v.name for v in dropout_vars])))
        if c['dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg, dropout_vars,
                                            c['dropout'])
        elif c['dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, dropout_vars,
                                           c['dropout'])
        else:
            raise ValueError(
                "unknown dropout_type: {!r}".format(c['dropout_type']))

        # a new dropout with exactly same mask at different steps
        emb_vars = VariableFilter(roles=[EMBEDDINGS])(regularized_cg)
        emb_dropout_mask = get_dropout_mask(emb_vars[0], c['emb_dropout'])
        if c['emb_dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg, emb_vars,
                                            c['emb_dropout'],
                                            dropout_mask=emb_dropout_mask)
        elif c['emb_dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, emb_vars,
                                           c['emb_dropout'])
        else:
            raise ValueError(
                "unknown emb_dropout_type: {!r}".format(
                    c['emb_dropout_type']))
        train_cost = regularized_cg.outputs[0]
        train_monitored_vars = regularized_cg.outputs[1:]

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=train_cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)
    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      shuffle=True,
                                      max_length=c['max_length'])
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    extensions = [
        LoadNoUnpickling(tar_path, load_iteration_state=True,
                         load_log=True).set_conditions(
                             before_training=not new_training_job),
        StartFuelServer(original_training_stream,
                        os.path.join(save_path, 'stream.pkl'),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train']),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ]
    validation = DataStreamMonitoring(
        [text_match_ratio] + monitored_vars,
        data.get_stream('dev',
                        batch_size=c['batch_size_valid'],
                        raw_text=True,
                        q_ids=True),
        prefix="dev").set_conditions(before_training=not fast_start,
                                     after_epoch=True)
    dump_predictions = DumpPredictions(save_path,
                                       text_match_ratio,
                                       before_training=not fast_start,
                                       after_epoch=True)
    track_the_best_exact = TrackTheBest(
        validation.record_name(exact_match_ratio),
        choose_best=max).set_conditions(before_training=True,
                                        after_epoch=True)
    track_the_best_text = TrackTheBest(
        validation.record_name(text_match_ratio),
        choose_best=max).set_conditions(before_training=True,
                                        after_epoch=True)
    extensions.extend([
        validation, dump_predictions, track_the_best_exact,
        track_the_best_text
    ])
    # We often use pretrained word embeddings and we don't want
    # to load and save them every time. To avoid that, we use
    # save_main_loop=False, we only save the trained parameters,
    # and we save the log and the iterations state separately
    # in the tar file.
    extensions.extend([
        Checkpoint(tar_path,
                   parameters=trained_parameters,
                   save_main_loop=False,
                   save_separately=['log', 'iteration_state'],
                   before_training=not fast_start,
                   every_n_epochs=c['save_freq_epochs'],
                   every_n_batches=c['save_freq_batches'],
                   after_training=not fast_start).add_condition(
                       ['after_batch', 'after_epoch'],
                       OnLogRecord(track_the_best_text.notification_name),
                       (best_tar_path, )),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        RetrievalPrintStats(retrieval=data._retrieval,
                            every_n_batches=c['mon_freq_train'],
                            before_training=not fast_start),
        Printing(after_epoch=True, every_n_batches=c['mon_freq_train']),
        FinishAfter(after_n_batches=c['n_batches'],
                    after_n_epochs=c['n_epochs']),
        Annealing(c['annealing_learning_rate'],
                  after_n_epochs=c['annealing_start_epoch']),
        LoadNoUnpickling(best_tar_path,
                         after_n_epochs=c['annealing_start_epoch'])
    ])

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()
def run(batch_size, save_path, z_dim, oldmodel, discriminative_regularization,
        classifier, vintage, monitor_every, monitor_before, checkpoint_every,
        dataset, color_convert, image_size, net_depth, subdir,
        reconstruction_factor, kl_factor, discriminative_factor, disc_weights,
        num_epochs):
    """Build the training main loop and run it.

    Parameters
    ----------
    batch_size
        Used as both the training and the monitoring batch size.
    save_path
        Path handed to the ``Checkpoint`` extension.
    z_dim, image_size, net_depth, discriminative_regularization, \
classifier, vintage, reconstruction_factor, kl_factor, \
discriminative_factor, disc_weights
        Forwarded unchanged to ``create_training_computation_graphs``.
        ``z_dim`` (halved) and ``image_size`` are also given to
        ``SampleCheckpoint``.
    oldmodel
        If not ``None``, a saved main loop is loaded from this path and
        its model parameter values are copied into the new model before
        training starts.
    monitor_every, monitor_before
        ``every_n_epochs`` and ``before_first_epoch`` of both monitoring
        extensions.
    checkpoint_every
        ``every_n_epochs`` of the ``Checkpoint`` extension.
    dataset
        If truthy, streams come from ``create_custom_streams`` with this
        file name; otherwise from ``create_celeba_streams``.
    color_convert
        Forwarded to ``create_custom_streams`` (custom datasets only).
    subdir
        ``save_subdir`` of the ``SampleCheckpoint`` extension.
    num_epochs
        Training stops after this many epochs.

    Notes
    -----
    Raises the interpreter recursion limit via ``sys.setrecursionlimit``
    as a process-wide side effect.

    """
    if dataset:
        streams = create_custom_streams(filename=dataset,
                                        training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False,
                                        color_convert=color_convert)
    else:
        streams = create_celeba_streams(training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    rval = create_training_computation_graphs(
        z_dim, image_size, net_depth, discriminative_regularization,
        classifier, vintage, reconstruction_factor, kl_factor,
        discriminative_factor, disc_weights)
    cg, bn_cg, variance_parameters = rval
    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = [(p, m * decay_rate + p * (1 - decay_rate))
                     for p, m in pop_updates]

    model = Model(bn_cg.outputs[0])
    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in ('encoder_convnet', 'encoder_mlp',
                                         'decoder_convnet', 'decoder_mlp')))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # Prepare monitoring
    sys.setrecursionlimit(1000000)

    monitored_quantities_list = []
    for graph in [bn_cg, cg]:
        # Outputs 0-3 are the fixed terms; whatever follows is handled as
        # a list of per-layer discriminative terms.
        cost, kl_term, reconstruction_term, discriminative_term = \
            graph.outputs[:4]
        discriminative_layer_terms = graph.outputs[4:]

        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        avg_discriminative_term = discriminative_term.mean(axis=0)
        avg_discriminative_term.name = 'avg_discriminative_term'

        avg_discriminative_layer_terms = []
        for i, term in enumerate(discriminative_layer_terms):
            layer_average = term.mean(axis=0)
            layer_average.name = \
                "avg_discriminative_term_layer_{:02d}".format(i)
            avg_discriminative_layer_terms.append(layer_average)

        monitored_quantities_list.append(
            [cost, avg_kl_term, avg_reconstruction_term,
             avg_discriminative_term] + avg_discriminative_layer_terms)

    train_monitoring = DataStreamMonitoring(
        monitored_quantities_list[0], train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False,
        before_first_epoch=monitor_before, every_n_epochs=monitor_every)
    valid_monitoring = DataStreamMonitoring(
        monitored_quantities_list[1], valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=monitor_before,
        every_n_epochs=monitor_every)

    # Prepare checkpoint
    checkpoint = Checkpoint(save_path, every_n_epochs=checkpoint_every,
                            before_training=True, use_cpickle=True)

    # Use floor division so the halved dimension stays an integer on
    # Python 3 too (``/`` would produce a float there).
    # TODO: why does z_dim=foo become foo/2?
    sample_checkpoint = SampleCheckpoint(
        interface=DiscGenModel, z_dim=z_dim // 2,
        image_size=(image_size, image_size), channels=3,
        dataset=dataset, split="valid", save_subdir=subdir,
        before_training=True, after_epoch=True)

    extensions = [Timing(),
                  FinishAfter(after_n_epochs=num_epochs),
                  checkpoint,
                  sample_checkpoint,
                  train_monitoring, valid_monitoring,
                  Printing(),
                  ProgressBar()]
    main_loop = MainLoop(model=model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)

    if oldmodel is not None:
        print("Initializing parameters with old model {}".format(oldmodel))
        try:
            saved_model = load(oldmodel)
        except AttributeError:
            # newer version of blocks expects an open file, not a path
            with open(oldmodel, 'rb') as src:
                saved_model = load(src)
        main_loop.model.set_parameter_values(
            saved_model.model.get_parameter_values())
        # Drop the loaded main loop before training starts.
        del saved_model

    main_loop.run()