def test_session_cv_callback_early_exit(tmpdir, device_id):
    """A cross-validation callback returning False must stop training early."""
    device = cntk_device(device_id)
    trainer, feature, label = create_sample_model(device)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}

    invocations = [0]

    def cv_callback(index, average_error, num_samples, num_mb):
        # With no cv source configured, the session reports zeroed stats.
        assert index == invocations[0]
        assert average_error == 0
        assert num_samples == 0
        assert num_mb == 0
        invocations[0] += 1
        # False after the first call requests early termination.
        return invocations[0] < 1

    session = C.training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=4,
        model_inputs_to_streams=stream_map,
        max_samples=60,
        cv_config=C.CrossValidationConfig(frequency=20, callback=cv_callback))
    session.train(device)

    # The callback must have fired exactly once before the session stopped.
    assert invocations == [1]
def test_usermbsource_training(tmpdir):
    """Train with a UserMinibatchSource and a second one on the cv path."""
    # NOTE(review): another test with this exact name appears later in this
    # file; Python keeps only the last definition, so one of the two is never
    # collected by pytest — confirm and rename one of them.
    input_dim = 1000
    num_output_classes = 5

    train_source = MyDataSource(input_dim, num_output_classes)
    # Using this for testing the UserMinibatchSource checkpointing
    cv_source = MyDataSource(input_dim, num_output_classes)

    from cntk import sequence, parameter, plus, cross_entropy_with_softmax, \
        classification_error, learning_rate_schedule, sgd, Trainer, \
        training_session, times, UnitType

    feature = sequence.input_variable(shape=(input_dim,))
    label = C.input_variable(shape=(num_output_classes,))

    # Tiny linear model: reduce the sequence, multiply by one parameter block.
    weights = parameter(shape=(input_dim, num_output_classes), init=10)
    z = times(sequence.reduce_sum(feature), weights, name='z')
    ce = cross_entropy_with_softmax(z, label)
    errs = classification_error(z, label)

    lr_per_sample = learning_rate_schedule([0.3, 0.2, 0.1, 0.0],
                                           UnitType.sample)
    learner = sgd(z.parameters, lr_per_sample)
    trainer = Trainer(z, (ce, errs), [learner])
    input_map = {feature: train_source.fsi, label: train_source.lsi}

    session = training_session(
        trainer=trainer,
        mb_source=train_source,
        model_inputs_to_streams=input_map,
        mb_size=4,
        max_samples=20,
        cv_config=C.CrossValidationConfig(source=cv_source,
                                          max_samples=10,
                                          mb_size=2))
    session.train()

    assert trainer.total_number_of_samples_seen == 20
def test_session_cv_callback_with_cross_validation_3_times(tmpdir, device_id):
    """The cv callback can drive its own evaluation loop over a cv source."""
    device = cntk_device(device_id)
    trainer, feature, label = create_sample_model(device)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    cv_source = mb_source(tmpdir, "cv")
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}

    def cv_callback(index, average_error, num_samples, num_mb):
        # Remember where the cv source is so it can be rewound afterwards.
        start_position = cv_source.current_position
        accumulated_error = 0
        while True:
            batch = cv_source.next_minibatch(2, input_map=stream_map)
            if not batch:
                break
            batch_error = trainer.test_minibatch(batch, device=device)
            accumulated_error += batch_error * batch[label].num_samples

        total_samples = 25  # Please see input data
        assert (accumulated_error * 100) / total_samples == 92
        # Rewind so the next invocation re-reads the same data.
        cv_source.current_position = start_position
        return True

    C.training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=4,
        model_inputs_to_streams=stream_map,
        max_samples=60,
        cv_config=C.CrossValidationConfig(
            frequency=20, callback=cv_callback)).train(device)

    assert trainer.total_number_of_samples_seen == 61
def test_usermbsource_training(tmpdir, with_checkpoint_impl):
    """Train with a UserMinibatchSource; optionally exercise its
    checkpoint implementation through the cross-validation source."""
    input_dim = 1000
    num_output_classes = 5

    train_source = MyDataSource(input_dim, num_output_classes)
    # Using this for testing the UserMinibatchSource checkpointing
    cv_class = (MyDataSourceWithCheckpoint if with_checkpoint_impl
                else MyDataSource)
    cv_source = cv_class(input_dim, num_output_classes)

    from cntk import sequence, parameter, plus, cross_entropy_with_softmax, \
        classification_error, learning_rate_schedule, sgd, Trainer, \
        training_session, times, UnitType

    feature = sequence.input_variable(shape=(input_dim,))
    label = C.input_variable(shape=(num_output_classes,))

    weights = parameter(shape=(input_dim, num_output_classes), init=10)
    z = times(sequence.reduce_sum(feature), weights, name='z')
    ce = cross_entropy_with_softmax(z, label)
    errs = classification_error(z, label)

    # A large learning rate keeps updates non-trivial so the session does
    # not converge (and stop) before all intended samples are consumed.
    lr_per_sample = learning_rate_schedule(0.3, UnitType.sample)
    learner = sgd(z.parameters, lr_per_sample)
    trainer = Trainer(z, (ce, errs), [learner])
    input_map = {feature: train_source.fsi, label: train_source.lsi}

    session = training_session(
        trainer=trainer,
        mb_source=train_source,
        model_inputs_to_streams=input_map,
        mb_size=4,
        max_samples=20,
        cv_config=C.CrossValidationConfig(minibatch_source=cv_source,
                                          max_samples=10,
                                          minibatch_size=2))
    session.train()

    assert trainer.total_number_of_samples_seen == 20
    if with_checkpoint_impl:
        assert cv_source._restore_from_checkpoint_calls == 1
def test_session_cross_validation_at_end(tmpdir, device_id):
    """A cv config without a frequency runs exactly once, at end of training."""
    device = cntk_device(device_id)
    writer = MockProgressWriter(expected_test_summary=[[92, 25]])
    trainer, feature, label = create_sample_model(device, writer)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    cv_source = mb_source(tmpdir, "cv")
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}

    C.training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=4,
        model_inputs_to_streams=stream_map,
        max_samples=20,
        cv_config=C.CrossValidationConfig(cv_source)).train(device)

    assert trainer.total_number_of_samples_seen == 21
    assert writer.test_summary_counter == 1
def test_session_cross_validation_3_times_checkpoints_2_save_all(
        tmpdir, device_id):
    """preserve_all must leave every intermediate checkpoint on disk."""
    device = cntk_device(device_id)
    writer = MockProgressWriter(
        expected_test_summary=[[92, 25], [92, 25], [92, 25]])
    trainer, feature, label = create_sample_model(device, writer)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    cv_source = mb_source(tmpdir, "cv")
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}
    test_dir = str(tmpdir)

    C.training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=4,
        model_inputs_to_streams=stream_map,
        max_samples=60,
        checkpoint_config=C.CheckpointConfig(
            frequency=35,
            preserve_all=True,
            filename=str(tmpdir / "checkpoint_save_all")),
        cv_config=C.CrossValidationConfig(
            cv_source, frequency=20)).train(device)

    candidates = [f for f in listdir(test_dir)
                  if isfile(join(test_dir, f))
                  and f.startswith("checkpoint_save_all")]

    # Two numbered intermediate checkpoints plus the final unnumbered one,
    # each with its companion .ckp file.
    for expected in ("checkpoint_save_all0", "checkpoint_save_all0.ckp",
                     "checkpoint_save_all1", "checkpoint_save_all1.ckp",
                     "checkpoint_save_all", "checkpoint_save_all.ckp"):
        assert expected in candidates
    assert writer.test_summary_counter == 3
def test_training_session_with_infinite_samples(tmpdir, device_id):
    """Infinitely repeating sources must be rejected wherever the session
    needs a bounded amount of data (train, test, and cross validation)."""
    import pytest
    device = cntk_device(device_id)
    trainer, feature, label = create_sample_model(device)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}

    # Unbounded training source with no max_samples on the session.
    with pytest.raises(ValueError) as info1:
        C.training_session(
            trainer=trainer,
            mb_source=train_source,
            mb_size=4,
            model_inputs_to_streams=stream_map).train(device)
    assert 'Train minibatch source must have a limited number of samples or sweeps' in str(
        info1.value)

    # Unbounded test source.
    with pytest.raises(ValueError) as info2:
        test_source = mb_source(tmpdir, "test", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=trainer,
            mb_source=train_source,
            mb_size=4,
            model_inputs_to_streams=stream_map,
            max_samples=10,
            test_config=C.TestConfig(test_source, minibatch_size=2),
        ).train(device)
    assert 'Test minibatch source must have a limited number of samples or sweeps' in str(
        info2.value)

    # Unbounded cross-validation source.
    with pytest.raises(ValueError) as info3:
        cv_source = mb_source(tmpdir, "cv", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=trainer,
            mb_source=train_source,
            mb_size=4,
            model_inputs_to_streams=stream_map,
            max_samples=20,
            cv_config=C.CrossValidationConfig(cv_source)).train(device)
    assert 'Cross validation minibatch source must have a limited number of samples or sweeps' in str(
        info3.value)
def test_session_cross_validation_3_times_on_minibatch_unit(tmpdir, device_id):
    """cv frequency may be expressed in minibatches instead of samples."""
    device = cntk_device(device_id)
    writer = MockProgressWriter(
        expected_test_summary=[[92, 25], [92, 25], [92, 25]])
    trainer, feature, label = create_sample_model(device, writer)
    train_source = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    cv_source = mb_source(tmpdir, "cv")
    stream_map = {feature: train_source.streams.features,
                  label: train_source.streams.labels}

    # 60 samples at mb_size 4 is 15 minibatches -> cv fires every 5, 3 times.
    C.training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=4,
        model_inputs_to_streams=stream_map,
        max_samples=60,
        cv_config=C.CrossValidationConfig(
            cv_source,
            frequency=(5, C.train.DataUnit.minibatch),
            minibatch_size=2),
    ).train(device)

    assert trainer.total_number_of_samples_seen == 61
    assert writer.test_summary_counter == 3
2)) if learner.learning_rate() < lr_per_sample / ( 2**7 - 0.1): # we are done after the 6-th LR cut print("Learning rate {} too small. Training complete.".format( learner.learning_rate())) return False # means we are done print( "Improvement of metric from {:.3f} to {:.3f} insufficient. Halving learning rate to {}." .format(prev_metric, average_error, learner.learning_rate())) prev_metric = average_error return True # means continue cv_callback_config = C.CrossValidationConfig((X_cv, Y_cv), 3 * epoch_size, minibatch_size=256, callback=adjust_lr_callback, criterion=criterion) # Callback for testing the final model. test_callback_config = C.TestConfig((X_test, Y_test), criterion=criterion) # Configure distributed training. # For this, we wrap the learner in a distributed_learner object. # This specific example implements the BlockMomentum method. The Python script must be run # using mpiexec in order to have effect. For example, under Windows, the command is: # mpiexec -n 4 -lines python -u MNIST_Complex_Training.py learner = C.train.distributed.data_parallel_distributed_learner(learner) # For distributed training, we must maximize the minibatch size, as to minimize # communication cost and GPU underutilization. Hence, we use a "schedule"
def train_and_test(self, reader_train, reader_test, reader_cv, restore_checkpoint=True):
    '''
    Train the model and validate the results

    Args:
        reader_train (:class:`~cntk.io.MinibatchSource`): the dataset reader for training.
        reader_test (:class:`~cntk.io.MinibatchSource`): the dataset reader for evaluation.
        reader_cv (:class:`~cntk.io.MinibatchSource`): the dataset reader for cross validation.
        restore_checkpoint (bool, optional): Continue training from latest checkpoint if True (default)

    Returns:
        None
    '''
    from CapsNet import CapsNet

    # Network input variables. NOTE(review): 'MINST' looks like a typo for
    # 'MNIST', but these are runtime identifiers — left untouched here.
    self.input = ct.input_variable(self.input_dim_model, name='MINST_Input')
    self.labels = ct.input_variable(self.output_dim_model, name='MINST_Labels')
    self.perturbations = ct.input_variable(self.perturbations_dim, name='Perturbations')

    # Pixel values are scaled from [0, 255] down to [0, 1] before the net.
    self.caps_net = CapsNet(self.input / 255., self.labels, routings=3, use_reconstruction=True)

    # models
    self.training_model, self.digitcaps_model, self.prediction_model, self.reconstruction_model = self.caps_net.models()
    self.manipulation_model = self.caps_net.manipulation(self.perturbations)

    # loss & error
    loss, error = self.caps_net.criterion()

    # Number of parameters in the network
    # 5. Capsules on MNIST "... CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnetwork."
    num_parameters, num_tensors = get_number_of_parameters(self.training_model)
    print("DigitCaps contains {} learneable parameters in {} parameter tensors."
          .format(num_parameters, num_tensors))

    # Initialize the parameters for the trainer
    minibatch_size = 128
    num_samples_per_sweep = 60000
    num_sweeps_to_train_with = 30

    # Report & Checkpoint frequency, each expressed in minibatches.
    print_frequency = (4, ct.DataUnit.minibatch)
    checkpoint_frequency = (100, ct.DataUnit.minibatch)
    cross_validation_frequency = (40, ct.DataUnit.minibatch)
    tensorboard_logdir = './tensorboard'

    # Map the data streams to the input and labels.
    self.input_map = {
        self.labels: reader_train.streams.labels,
        self.input: reader_train.streams.features
    }
    self.test_input_map = {
        self.labels: reader_test.streams.labels,
        self.input: reader_test.streams.features
    }
    self.cv_input_map = {
        self.labels: reader_cv.streams.labels,
        self.input: reader_cv.streams.features
    }

    # Instantiate progress writers.
    progress_writers = [
        ct.logging.ProgressPrinter(
            tag='Training',
            num_epochs=int(num_samples_per_sweep * num_sweeps_to_train_with /
                           minibatch_size / print_frequency[0]))
    ]
    training_progress_output_freq = 1
    if tensorboard_logdir is not None:
        self.tb_printer = ct.logging.TensorBoardProgressWriter(
            freq=training_progress_output_freq,
            log_dir=tensorboard_logdir,
            model=self.training_model)
        progress_writers.append(self.tb_printer)

    # Instantiate the learning rate schedule: 0.01 for the first 30 sweeps,
    # then 0.007 afterwards (one schedule entry per epoch_size samples).
    learning_rate_schedule = [0.01] * 30 + [0.007]
    learning_rate_schedule = ct.learning_parameter_schedule(
        learning_rate_schedule,
        minibatch_size=minibatch_size,
        epoch_size=num_samples_per_sweep)

    # Instantiate the trainer object to drive the model training
    learner = ct.adam(self.training_model.parameters,
                      learning_rate_schedule,
                      momentum=[0.9],
                      variance_momentum=[0.999],
                      gaussian_noise_injection_std_dev=[0.0])
    trainer = ct.Trainer(self.training_model, (loss, error), [learner],
                         progress_writers)

    # Drive training; checkpointing, periodic cross validation (via
    # self.cross_validation_callbackfunc) and the final test pass are all
    # handled by the training session.
    ct.training_session(
        trainer=trainer,
        mb_source=reader_train,
        mb_size=minibatch_size,
        model_inputs_to_streams=self.input_map,
        max_samples=num_samples_per_sweep * num_sweeps_to_train_with,
        progress_frequency=print_frequency,
        checkpoint_config=ct.CheckpointConfig(
            filename='./checkpoints/checkpoint',
            frequency=checkpoint_frequency,
            restore=restore_checkpoint),
        cv_config=ct.CrossValidationConfig(
            minibatch_size=128,
            minibatch_source=reader_cv,
            frequency=cross_validation_frequency,
            callback=self.cross_validation_callbackfunc,
            max_samples=1024,
            model_inputs_to_streams=self.cv_input_map),
        test_config=ct.TestConfig(
            minibatch_source=reader_test,
            minibatch_size=minibatch_size,
            model_inputs_to_streams=self.test_input_map)).train()

    # save models
    self.digitcaps_model.save('./models/digitcaps_model.cntk')
    self.training_model.save('./models/training_model.cntk')
    self.prediction_model.save('./models/prediction_model.cntk')
    if self.reconstruction_model:
        self.reconstruction_model.save('./models/reconstruction_model.cntk')
    self.manipulation_model.save('./models/manipulation_model.cntk')
    print('Done.')