Example No. 1
def test_session_with_test_own_inputs(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter(expected_test_summary=[[92, 25]])
    t, feature, label = create_sample_model(device, writer)

    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    mbs1 = mb_source(tmpdir, "test", ctf=ctf_data2, streams=['S4', 'S5'])

    input_map = {
        feature: mbs.streams.features,
        label: mbs.streams.labels
    }

    input_map1 = {
        feature: mbs1.streams.features,
        label: mbs1.streams.labels
    }

    C.training_session(
        trainer=t, mb_source=mbs,
        mb_size=4, model_inputs_to_streams=input_map,
        max_samples=60,
        test_config=C.TestConfig(mbs1, minibatch_size=2, model_inputs_to_streams=input_map1),
    ).train(device)

    assert(t.total_number_of_samples_seen == 61)
    assert(writer.test_summary_counter == 1)
Example No. 2
def test_training_session_with_infinite_samples(tmpdir, device_id):
    import pytest
    device = cntk_device(device_id)
    t, feature, label = create_sample_model(device)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    with pytest.raises(ValueError) as info1:
        C.training_session(trainer=t,
                           mb_source=mbs,
                           mb_size=4,
                           model_inputs_to_streams=input_map).train(device)
    assert 'Train minibatch source must have a limited number of samples or sweeps' in str(
        info1.value)

    with pytest.raises(ValueError) as info2:
        mbs1 = mb_source(tmpdir, "test", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=t,
            mb_source=mbs,
            mb_size=4,
            model_inputs_to_streams=input_map,
            max_samples=10,
            test_config=C.TestConfig(mbs1, minibatch_size=2),
        ).train(device)
    assert 'Test minibatch source must have a limited number of samples or sweeps' in str(
        info2.value)

    with pytest.raises(ValueError) as info3:
        mbs2 = mb_source(tmpdir, "cv", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=t,
            mb_source=mbs,
            mb_size=4,
            model_inputs_to_streams=input_map,
            max_samples=20,
            cv_config=C.CrossValidationConfig(mbs2)).train(device)
    assert 'Cross validation minibatch source must have a limited number of samples or sweeps' in str(
        info3.value)
Example No. 3
def test_session_with_legacy_api(tmpdir, device_id):
    run_simple_training(
        tmpdir,
        device_id,
        test_config_factory=lambda mbs, input_map: C.TestConfig(
            source=mbs, mb_size=2, model_inputs_to_streams=input_map))
Example No. 4
# Cross-validation callback that halves the learning rate when the metric stops
# improving, and stops training once the learning rate becomes too small. The opening
# lines of this callback are missing from the excerpt; the signature, the improvement
# check, and the learning-rate halving below are reconstructed assumptions, and only
# the lines from "return False" onward are original.
def adjust_lr_callback(index, average_error, cv_num_samples, cv_num_minibatches):
    global prev_metric
    if (prev_metric - average_error) / prev_metric < 0.05:  # metric improved by less than 5%?
        learner.reset_learning_rate(C.learning_parameter_schedule_per_sample(
            learner.learning_rate() / 2, epoch_size=epoch_size))
        if learner.learning_rate() < 1e-6:  # assumed lower bound; the original cutoff is not shown
            return False  # means we are done
        print(
            "Improvement of metric from {:.3f} to {:.3f} insufficient. Halving learning rate to {}."
            .format(prev_metric, average_error, learner.learning_rate()))
    prev_metric = average_error
    return True  # means continue


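# Cross-validation configuration: evaluate on the in-memory CV data every
# 3 * epoch_size samples and invoke adjust_lr_callback to adapt the learning rate.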
cv_callback_config = C.CrossValidationConfig((X_cv, Y_cv),
                                             3 * epoch_size,
                                             minibatch_size=256,
                                             callback=adjust_lr_callback,
                                             criterion=criterion)

# Callback for testing the final model.
test_callback_config = C.TestConfig((X_test, Y_test), criterion=criterion)

# Configure distributed training.
# For this, we wrap the learner in a distributed_learner object.
# Note that data_parallel_distributed_learner implements plain data-parallel SGD;
# for the BlockMomentum method, block_momentum_distributed_learner would be used instead.
# The Python script must be run using mpiexec in order to have any effect.
# For example, under Windows, the command is:
#   mpiexec -n 4 -lines python -u MNIST_Complex_Training.py
learner = C.train.distributed.data_parallel_distributed_learner(learner)
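# Note: a distributed run should call C.train.distributed.Communicator.finalize()
# before the process exits (that call is not part of this excerpt).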

# For distributed training, we should maximize the minibatch size so as to minimize
# communication cost and GPU underutilization. Hence, we use a "schedule"
# that increases the minibatch size after a few epochs. Because the learning rate
# is specified per sample, the per-sample contribution keeps the same scale without
# having to readjust the learning rate.
# For this MNIST model, larger minibatch sizes make it faster, because the
# model is too small to utilize a full GPU. Hence data-parallel training cannot
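The excerpt cuts off mid-comment; in the full script the schedule described above would be defined next and used to drive training with the configs defined earlier. A minimal sketch under assumptions: the sizes, max_epochs, and the in-memory arrays X_train/Y_train are illustrative and not taken from the source.

# Grow the minibatch size after the first few epochs (the sizes here are illustrative).
# The learning rate is specified per sample, so it does not need to be rescaled
# when the minibatch size changes.
minibatch_sizes = C.minibatch_size_schedule([256] * 3 + [1024], epoch_size=epoch_size)

# Drive training with the callbacks defined above; X_train/Y_train are assumed
# in-memory training arrays analogous to X_cv/Y_cv and X_test/Y_test.
progress = criterion.train((X_train, Y_train),
                           minibatch_size=minibatch_sizes,
                           max_epochs=40, epoch_size=epoch_size,
                           parameter_learners=[learner],
                           callbacks=[cv_callback_config, test_callback_config])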
Example No. 5
def test_session_with_own_test_inputs(tmpdir, device_id):
    run_simple_training(
        tmpdir, device_id,
        test_config_factory=lambda mbs, input_map: C.TestConfig(
            minibatch_source=mbs, minibatch_size=2, model_inputs_to_streams=input_map))
Example No. 6
    def train_and_test(self,
                       reader_train,
                       reader_test,
                       reader_cv,
                       restore_checkpoint=True):
        '''
        Train the model and validate the results.

        Args:
            reader_train (:class:`~cntk.io.MinibatchSource`): the dataset reader for training.
            reader_test (:class:`~cntk.io.MinibatchSource`): the dataset reader for evaluation.
            reader_cv (:class:`~cntk.io.MinibatchSource`): the dataset reader for cross validation.
            restore_checkpoint (bool, optional): continue training from the latest checkpoint if True (default).

        Returns:
            None
        '''
        from CapsNet import CapsNet

        self.input = ct.input_variable(self.input_dim_model,
                                       name='MINST_Input')
        self.labels = ct.input_variable(self.output_dim_model,
                                        name='MINST_Labels')
        self.perturbations = ct.input_variable(self.perturbations_dim,
                                               name='Perturbations')

        self.caps_net = CapsNet(self.input / 255.,
                                self.labels,
                                routings=3,
                                use_reconstruction=True)

        # models
        (self.training_model, self.digitcaps_model, self.prediction_model,
         self.reconstruction_model) = self.caps_net.models()
        self.manipulation_model = self.caps_net.manipulation(
            self.perturbations)

        # loss & error
        loss, error = self.caps_net.criterion()

        # Number of parameters in the network.
        # From the paper, Section 5 ("Capsules on MNIST"): "... CapsNet has 8.2M parameters
        # and 6.8M parameters without the reconstruction subnetwork."
        num_parameters, num_tensors = get_number_of_parameters(
            self.training_model)
        print(
            "DigitCaps contains {} learnable parameters in {} parameter tensors."
            .format(num_parameters, num_tensors))

        # Initialize the parameters for the trainer
        minibatch_size = 128
        num_samples_per_sweep = 60000
        num_sweeps_to_train_with = 30

        # Report & Checkpoint frequency
        print_frequency = (4, ct.DataUnit.minibatch)
        checkpoint_frequency = (100, ct.DataUnit.minibatch)
        cross_validation_frequency = (40, ct.DataUnit.minibatch)

        tensorboard_logdir = './tensorboard'

        # Map the data streams to the input and labels.
        self.input_map = {
            self.labels: reader_train.streams.labels,
            self.input: reader_train.streams.features
        }

        self.test_input_map = {
            self.labels: reader_test.streams.labels,
            self.input: reader_test.streams.features
        }

        self.cv_input_map = {
            self.labels: reader_cv.streams.labels,
            self.input: reader_cv.streams.features
        }

        # Instantiate progress writers.
        progress_writers = [
            ct.logging.ProgressPrinter(
                tag='Training',
                num_epochs=int(num_samples_per_sweep *
                               num_sweeps_to_train_with / minibatch_size /
                               print_frequency[0]))
        ]

        training_progress_output_freq = 1

        if tensorboard_logdir is not None:
            self.tb_printer = ct.logging.TensorBoardProgressWriter(
                freq=training_progress_output_freq,
                log_dir=tensorboard_logdir,
                model=self.training_model)
            progress_writers.append(self.tb_printer)

        # Instantiate the learning rate schedule
        learning_rate_schedule = [0.01] * 30 + [0.007]
        learning_rate_schedule = ct.learning_parameter_schedule(
            learning_rate_schedule,
            minibatch_size=minibatch_size,
            epoch_size=num_samples_per_sweep)

        # Instantiate the trainer object to drive the model training
        learner = ct.adam(self.training_model.parameters,
                          learning_rate_schedule,
                          momentum=[0.9],
                          variance_momentum=[0.999],
                          gaussian_noise_injection_std_dev=[0.0])
        trainer = ct.Trainer(self.training_model, (loss, error), [learner],
                             progress_writers)

        ct.training_session(
            trainer=trainer,
            mb_source=reader_train,
            mb_size=minibatch_size,
            model_inputs_to_streams=self.input_map,
            max_samples=num_samples_per_sweep * num_sweeps_to_train_with,
            progress_frequency=print_frequency,
            checkpoint_config=ct.CheckpointConfig(
                filename='./checkpoints/checkpoint',
                frequency=checkpoint_frequency,
                restore=restore_checkpoint),
            cv_config=ct.CrossValidationConfig(
                minibatch_size=128,
                minibatch_source=reader_cv,
                frequency=cross_validation_frequency,
                callback=self.cross_validation_callbackfunc,
                max_samples=1024,
                model_inputs_to_streams=self.cv_input_map),
            test_config=ct.TestConfig(
                minibatch_source=reader_test,
                minibatch_size=minibatch_size,
                model_inputs_to_streams=self.test_input_map)).train()

        # save models
        self.digitcaps_model.save('./models/digitcaps_model.cntk')
        self.training_model.save('./models/training_model.cntk')
        self.prediction_model.save('./models/prediction_model.cntk')
        if self.reconstruction_model:
            self.reconstruction_model.save(
                './models/reconstruction_model.cntk')
            self.manipulation_model.save('./models/manipulation_model.cntk')

        print('Done.')
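The cv_config in Example No. 6 references self.cross_validation_callbackfunc, which is not included in this excerpt. A minimal sketch, assuming CNTK's documented cross-validation callback signature, where returning True continues training and returning False stops the session:

    def cross_validation_callbackfunc(self, index, average_error,
                                      cv_num_samples, cv_num_minibatches):
        # Log the cross-validation result; returning False here would stop training early.
        print("Cross validation {}: average error {:.3f} over {} samples".format(
            index, average_error, cv_num_samples))
        return True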