def test_session_restart_from_end_checkpoint(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter()
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    test_dir = str(tmpdir)

    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60, progress_frequency=20,
        checkpoint_config=C.CheckpointConfig(
            frequency=20,
            filename=str(tmpdir / "restart_from_checkpoint"))).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f)) and f.startswith("restart_from_checkpoint")]

    assert len(candidates) == 2
    assert "restart_from_checkpoint" in candidates
    assert "restart_from_checkpoint.ckp" in candidates

    # reset the state of the mock printer
    writer.minibatch_info = []
    writer.training_summary_counter = 0
    writer.testing_summary_counter = 0

    # restoring from the end checkpoint should not cause any training
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60, progress_frequency=20,
        checkpoint_config=C.CheckpointConfig(
            frequency=35, restore=True,
            filename=str(tmpdir / "restart_from_checkpoint"))).train(device)

    assert len(writer.minibatch_info) == 0
    assert writer.training_summary_counter == 0
    assert writer.testing_summary_counter == 0
def train_and_test(network, trainer, train_source, test_source, minibatch_size,
                   epoch_size, restore, model_path=_MODEL_PATH, cv_config=None):
    """ Train and test """

    # define mapping from input streams to network inputs
    input_map = {
        network['feature']: train_source.streams.features,
        network['label']: train_source.streams.labels
    }

    cntk.training_session(
        trainer=trainer, mb_source=train_source,
        mb_size=minibatch_size, model_inputs_to_streams=input_map,
        checkpoint_config=cntk.CheckpointConfig(
            filename=os.path.join(model_path, _MODEL_NAME),
            restore=restore),
        progress_frequency=epoch_size,
        cv_config=cv_config).train()
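# Usage sketch (illustrative, not from the original source): build a
# CrossValidationConfig from test_source so the session periodically evaluates
# on held-out data. `network`, `trainer`, and the minibatch sources are assumed
# to come from the surrounding module's helpers.
cv = cntk.CrossValidationConfig(minibatch_source=test_source,
                                minibatch_size=64,
                                frequency=60000)  # validate once per epoch of 60000 samples
train_and_test(network, trainer, train_source, test_source,
               minibatch_size=64, epoch_size=60000,
               restore=False, cv_config=cv)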
def test_session_cross_validation_3_times_checkpoints_2_save_all(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter(expected_test_summary=[[92, 25], [92, 25], [92, 25]])
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    mbs1 = mb_source(tmpdir, "cv")
    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    test_dir = str(tmpdir)

    C.training_session(
        trainer=t, mb_source=mbs, mb_size=4,
        model_inputs_to_streams=input_map,
        max_samples=60,
        checkpoint_config=C.CheckpointConfig(
            frequency=35, preserve_all=True,
            filename=str(tmpdir / "checkpoint_save_all")),
        cv_config=C.CrossValidationConfig(mbs1, frequency=20)).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f)) and f.startswith("checkpoint_save_all")]

    assert "checkpoint_save_all0" in candidates
    assert "checkpoint_save_all0.ckp" in candidates
    assert "checkpoint_save_all1" in candidates
    assert "checkpoint_save_all1.ckp" in candidates
    assert "checkpoint_save_all" in candidates
    assert "checkpoint_save_all.ckp" in candidates

    assert writer.test_summary_counter == 3
def test_session_restart_from_checkpoint_preserve_all(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter()
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    test_dir = str(tmpdir)

    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60, progress_frequency=20,
        checkpoint_config=C.CheckpointConfig(
            frequency=20, preserve_all=True,
            filename=str(tmpdir / "restart_from_checkpoint"))).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f)) and f.startswith("restart_from_checkpoint")]

    assert "restart_from_checkpoint0" in candidates
    assert "restart_from_checkpoint0.ckp" in candidates
    assert "restart_from_checkpoint1" in candidates
    assert "restart_from_checkpoint1.ckp" in candidates
    assert "restart_from_checkpoint2" in candidates
    assert "restart_from_checkpoint2.ckp" in candidates
    assert "restart_from_checkpoint" in candidates
    assert "restart_from_checkpoint.ckp" in candidates

    # remove everything except for checkpoint 1
    for f in candidates:
        if f != "restart_from_checkpoint1" and f != "restart_from_checkpoint1.ckp":
            os.remove(str(tmpdir / f))

    # drop the information about epochs 0 and 1 from the mock printer
    first_run_minibatch_info = [i for i in writer.minibatch_info
                                if i[0] != 0 and i[0] != 1]
    writer.minibatch_info = []
    writer.training_summary_counter = 2

    # restoring from checkpoint 1 should retrain epoch 2 and again save
    # everything from that epoch on
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60, progress_frequency=20,
        checkpoint_config=C.CheckpointConfig(
            frequency=20, restore=True, preserve_all=True,
            filename=str(tmpdir / "restart_from_checkpoint"))).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f)) and f.startswith("restart_from_checkpoint")]

    assert "restart_from_checkpoint1" in candidates
    assert "restart_from_checkpoint1.ckp" in candidates
    assert "restart_from_checkpoint2" in candidates
    assert "restart_from_checkpoint2.ckp" in candidates
    assert "restart_from_checkpoint" in candidates
    assert "restart_from_checkpoint.ckp" in candidates
    assert len(candidates) == 6
    assert first_run_minibatch_info == writer.minibatch_info

    # again remove everything except for checkpoint 1
    for f in candidates:
        if f != "restart_from_checkpoint1" and f != "restart_from_checkpoint1.ckp":
            os.remove(str(tmpdir / f))

    # drop the information about epochs 0 and 1 from the mock printer
    writer.minibatch_info = []
    writer.training_summary_counter = 2

    # rename checkpoint 1 to the generic checkpoint name
    os.rename(str(tmpdir / "restart_from_checkpoint1"),
              str(tmpdir / "restart_from_checkpoint"))
    os.rename(str(tmpdir / "restart_from_checkpoint1.ckp"),
              str(tmpdir / "restart_from_checkpoint.ckp"))

    # restoring from the generic checkpoint should behave the same way
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60, progress_frequency=20,
        checkpoint_config=C.CheckpointConfig(
            frequency=20, restore=True, preserve_all=True,
            filename=str(tmpdir / "restart_from_checkpoint"))).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f)) and f.startswith("restart_from_checkpoint")]

    assert "restart_from_checkpoint2" in candidates
    assert "restart_from_checkpoint2.ckp" in candidates
    assert "restart_from_checkpoint" in candidates
    assert "restart_from_checkpoint.ckp" in candidates
    assert len(candidates) == 4
    assert first_run_minibatch_info == writer.minibatch_info
    epoch_size=epoch_size)

# Instantiate the trainer object to drive the model training.
learner = C.learners.momentum_sgd(model.parameters, lr_schedule, mm_schedule)

# Configure trainer callbacks. This is the main point that this sample illustrates.
# Trainer callbacks are the mechanism via which logging, check-pointing, learning-rate
# adjustment, early stopping, and final testing are configured.

# Callback for logging progress (loss and metric) at the end of each epoch.
progress_writer = C.logging.ProgressPrinter()

# Callback for checkpointing. This will save a model every 'epoch_size' samples.
# Change 'restore' to True to have training start from a prior checkpoint file if available.
checkpoint_callback_config = C.CheckpointConfig(model_path, epoch_size, restore=False)

# Callback for cross-validation.
# The cross-validation callback mechanism allows you to implement your own
# learning-rate control and early stopping.
# The following implements a simple callback that halves the learning rate if the
# metric has not improved by at least 5% relative. The cross-validation callback
# gets configured to call this every 3*epoch_size samples, i.e. only every 3rd epoch.
prev_metric = 1  # metric from the previous call to the callback; at the very beginning, the error rate is 100%

def adjust_lr_callback(index, average_error, cv_num_samples, cv_num_minibatches):
    global prev_metric
    if (
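# The snippet above is cut off mid-condition. Below is a plausible completion of
# the callback, following the rule described in the comments (halve the learning
# rate whenever the metric fails to improve by 5% relative, and stop once the
# rate becomes negligibly small). `learner` is the momentum_sgd learner created
# above; the exact stopping threshold is an illustrative assumption, not the
# original code.
def adjust_lr_callback(index, average_error, cv_num_samples, cv_num_minibatches):
    global prev_metric
    if (prev_metric - average_error) / prev_metric < 0.05:
        # metric did not improve by at least 5% relative: halve the learning rate
        learner.reset_learning_rate(
            C.learning_parameter_schedule(learner.learning_rate() / 2))
        if learner.learning_rate() < 1e-6:  # assumed cutoff for early stopping
            return False  # returning False stops training
        print("Metric improved only from {:.3f} to {:.3f}; halving learning rate to {}."
              .format(prev_metric, average_error, learner.learning_rate()))
    prev_metric = average_error
    return True  # returning True continues training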
def convnet_mnist(data_path, model_path, max_epochs=40, model_suffix=None,
                  hidden_layers_dim=96, feedforward_const=0.0039,
                  log_dir=None, tensorboard_logdir=None, debug_output=False):
    image_height = 28
    image_width = 28
    num_channels = 1
    input_dim = image_height * image_width * num_channels
    num_output_classes = 10

    # Input variables denoting the features and label data
    input_var = cntk.ops.input((num_channels, image_height, image_width), np.float32)
    label_var = cntk.ops.input(num_output_classes, np.float32)

    # Instantiate the feedforward classification model
    scaled_input = cntk.ops.element_times(cntk.ops.constant(feedforward_const), input_var)

    with cntk.layers.default_options(activation=cntk.ops.relu, pad=False):
        conv1 = cntk.layers.Convolution2D((5, 5), 32, pad=True)(scaled_input)
        pool1 = cntk.layers.MaxPooling((3, 3), (2, 2))(conv1)
        conv2 = cntk.layers.Convolution2D((3, 3), 48)(pool1)
        pool2 = cntk.layers.MaxPooling((3, 3), (2, 2))(conv2)
        conv3 = cntk.layers.Convolution2D((3, 3), 64)(pool2)
        f4 = cntk.layers.Dense(hidden_layers_dim)(conv3)
        drop4 = cntk.layers.Dropout(0.5)(f4)
        z = cntk.layers.Dense(num_output_classes, activation=None)(drop4)

    ce = cntk.losses.cross_entropy_with_softmax(z, label_var)
    pe = cntk.metrics.classification_error(z, label_var)

    reader_train = create_reader(
        os.path.join(data_path, 'Train-28x28_cntk_text.txt'), True,
        input_dim, num_output_classes)

    # training config
    epoch_size = 60000  # for now we manually specify epoch size
    minibatch_size = 64

    # Set learning parameters
    lr_per_sample = [0.001] * 10 + [0.0005] * 10 + [0.0001]
    lr_schedule = cntk.learning_rate_schedule(lr_per_sample,
                                              cntk.learners.UnitType.sample,
                                              epoch_size)
    mm_time_constant = [0] * 5 + [1024]
    mm_schedule = cntk.learners.momentum_as_time_constant_schedule(mm_time_constant, epoch_size)

    # Instantiate the trainer object to drive the model training
    learner = cntk.learners.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
    progress_writers = [
        cntk.logging.ProgressPrinter(
            # freq=training_progress_output_freq,
            tag='Training',
            log_to_file=log_dir,
            num_epochs=max_epochs)
    ]

    if tensorboard_logdir is not None:
        progress_writers.append(
            cntk.logging.TensorBoardProgressWriter(freq=10,
                                                   log_dir=tensorboard_logdir,
                                                   model=z))

    trainer = cntk.Trainer(z, (ce, pe), learner, progress_writers)

    model_name = "MNIST_Model"
    if model_suffix is not None:
        model_name += ("_" + model_suffix)

    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }

    cntk.training_session(
        trainer=trainer, mb_source=reader_train,
        mb_size=minibatch_size, model_inputs_to_streams=input_map,
        max_samples=epoch_size * max_epochs,
        checkpoint_config=cntk.CheckpointConfig(
            frequency=epoch_size,
            filename=os.path.join(model_path, model_name),
            restore=True),
        progress_frequency=epoch_size).train()
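# Illustrative extension (not part of the original function): the same session
# can run a final evaluation pass by supplying a TestConfig built from a test
# reader. The 'Test-28x28_cntk_text.txt' file name follows the convention of
# this sample's data layout.
reader_test = create_reader(
    os.path.join(data_path, 'Test-28x28_cntk_text.txt'), False,
    input_dim, num_output_classes)
test_config = cntk.TestConfig(
    minibatch_source=reader_test,
    minibatch_size=minibatch_size,
    model_inputs_to_streams={input_var: reader_test.streams.features,
                             label_var: reader_test.streams.labels})
# ...then pass test_config=test_config to the cntk.training_session(...) call above.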
def train_and_test(self, reader_train, reader_test, reader_cv, restore_checkpoint=True):
    '''
    Train the model and validate the results

    Args:
        reader_train (:class:`~cntk.io.MinibatchSource`): the dataset reader for training.
        reader_test (:class:`~cntk.io.MinibatchSource`): the dataset reader for evaluation.
        reader_cv (:class:`~cntk.io.MinibatchSource`): the dataset reader for cross-validation.
        restore_checkpoint (bool, optional): continue training from the latest checkpoint if True (default)

    Returns:
        None
    '''
    from CapsNet import CapsNet

    self.input = ct.input_variable(self.input_dim_model, name='MNIST_Input')
    self.labels = ct.input_variable(self.output_dim_model, name='MNIST_Labels')
    self.perturbations = ct.input_variable(self.perturbations_dim, name='Perturbations')

    self.caps_net = CapsNet(self.input / 255., self.labels,
                            routings=3, use_reconstruction=True)

    # models
    self.training_model, self.digitcaps_model, self.prediction_model, \
        self.reconstruction_model = self.caps_net.models()
    self.manipulation_model = self.caps_net.manipulation(self.perturbations)

    # loss & error
    loss, error = self.caps_net.criterion()

    # Number of parameters in the network
    # 5. Capsules on MNIST: "... CapsNet has 8.2M parameters and 6.8M parameters
    # without the reconstruction subnetwork."
    num_parameters, num_tensors = get_number_of_parameters(self.training_model)
    print("DigitCaps contains {} learnable parameters in {} parameter tensors.".format(
        num_parameters, num_tensors))

    # Initialize the parameters for the trainer
    minibatch_size = 128
    num_samples_per_sweep = 60000
    num_sweeps_to_train_with = 30

    # Report & checkpoint frequency
    print_frequency = (4, ct.DataUnit.minibatch)
    checkpoint_frequency = (100, ct.DataUnit.minibatch)
    cross_validation_frequency = (40, ct.DataUnit.minibatch)
    tensorboard_logdir = './tensorboard'

    # Map the data streams to the inputs and labels.
    self.input_map = {
        self.labels: reader_train.streams.labels,
        self.input: reader_train.streams.features
    }
    self.test_input_map = {
        self.labels: reader_test.streams.labels,
        self.input: reader_test.streams.features
    }
    self.cv_input_map = {
        self.labels: reader_cv.streams.labels,
        self.input: reader_cv.streams.features
    }

    # Instantiate progress writers.
    progress_writers = [
        ct.logging.ProgressPrinter(
            tag='Training',
            num_epochs=int(num_samples_per_sweep * num_sweeps_to_train_with /
                           minibatch_size / print_frequency[0]))
    ]

    training_progress_output_freq = 1
    if tensorboard_logdir is not None:
        self.tb_printer = ct.logging.TensorBoardProgressWriter(
            freq=training_progress_output_freq,
            log_dir=tensorboard_logdir,
            model=self.training_model)
        progress_writers.append(self.tb_printer)

    # Instantiate the learning rate schedule
    learning_rate_schedule = [0.01] * 30 + [0.007]
    learning_rate_schedule = ct.learning_parameter_schedule(
        learning_rate_schedule,
        minibatch_size=minibatch_size,
        epoch_size=num_samples_per_sweep)

    # Instantiate the trainer object to drive the model training
    learner = ct.adam(self.training_model.parameters,
                      learning_rate_schedule,
                      momentum=[0.9],
                      variance_momentum=[0.999],
                      gaussian_noise_injection_std_dev=[0.0])
    trainer = ct.Trainer(self.training_model, (loss, error), [learner], progress_writers)

    ct.training_session(
        trainer=trainer,
        mb_source=reader_train,
        mb_size=minibatch_size,
        model_inputs_to_streams=self.input_map,
        max_samples=num_samples_per_sweep * num_sweeps_to_train_with,
        progress_frequency=print_frequency,
        checkpoint_config=ct.CheckpointConfig(
            filename='./checkpoints/checkpoint',
            frequency=checkpoint_frequency,
            restore=restore_checkpoint),
        cv_config=ct.CrossValidationConfig(
            minibatch_size=128,
            minibatch_source=reader_cv,
            frequency=cross_validation_frequency,
            callback=self.cross_validation_callbackfunc,
            max_samples=1024,
            model_inputs_to_streams=self.cv_input_map),
        test_config=ct.TestConfig(
            minibatch_source=reader_test,
            minibatch_size=minibatch_size,
            model_inputs_to_streams=self.test_input_map)).train()

    # save models
    self.digitcaps_model.save('./models/digitcaps_model.cntk')
    self.training_model.save('./models/training_model.cntk')
    self.prediction_model.save('./models/prediction_model.cntk')
    if self.reconstruction_model:
        self.reconstruction_model.save('./models/reconstruction_model.cntk')
    self.manipulation_model.save('./models/manipulation_model.cntk')
    print('Done.')
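# Sketch of the cross-validation callback referenced above, which the original
# excerpt does not show. CNTK invokes it with the CV index, the averaged metric,
# and the sample/minibatch counts; returning True continues training, False
# stops early. The patience-based early-stopping rule is an illustrative
# assumption, not the original implementation.
def cross_validation_callbackfunc(self, index, average_error, cv_num_samples, cv_num_minibatches):
    print("CV {}: average error {:.4f} over {} samples".format(
        index, average_error, cv_num_samples))
    best = getattr(self, 'best_cv_error', float('inf'))
    if average_error < best:
        self.best_cv_error = average_error
        self.cv_stalled = 0  # metric improved: reset the patience counter
    else:
        self.cv_stalled = getattr(self, 'cv_stalled', 0) + 1
    return self.cv_stalled < 3  # stop after 3 checks without improvement (assumed patience)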