Example #1
def test_session_cv_callback_early_exit(tmpdir, device_id):
    device = cntk_device(device_id)
    t, feature, label = create_sample_model(device)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    counter = [0]

    def cv_callback(index, average_error, num_samples, num_mb):
        assert (counter[0] == index)
        assert average_error == 0
        assert num_samples == 0
        assert num_mb == 0
        counter[0] += 1
        return counter[0] < 1

    C.training_session(trainer=t,
                       mb_source=mbs,
                       mb_size=4,
                       model_inputs_to_streams=input_map,
                       max_samples=60,
                       cv_config=C.CrossValidationConfig(
                           frequency=20, callback=cv_callback)).train(device)
    assert counter == [1]
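The tests in these examples rely on helpers defined elsewhere in the test module. A minimal sketch of what they might look like, assuming a two-dimensional CTF layout and an arbitrary learning rate (both are assumptions, not the original definitions):

import cntk as C
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs


def mb_source(tmpdir, fileprefix, max_samples=C.io.FULL_DATA_SWEEP):
    # Write a tiny CTF file and wrap it in a deterministic MinibatchSource.
    ctf_file = str(tmpdir / (fileprefix + ".ctf"))
    with open(ctf_file, "w") as f:
        f.write("0\t|features 1.0 2.0\t|labels 0 1\n")  # placeholder data
    deserializer = CTFDeserializer(ctf_file, StreamDefs(
        features=StreamDef(field="features", shape=2),
        labels=StreamDef(field="labels", shape=2)))
    return MinibatchSource(deserializer, randomize=False,
                           max_samples=max_samples)


def create_sample_model(device, writer=None):
    # A one-layer model plus trainer, just large enough for the session tests.
    feature = C.input_variable(shape=(2,))
    label = C.input_variable(shape=(2,))
    z = C.plus(feature, C.parameter(shape=(2,), init=10, device=device),
               name="z")
    ce = C.cross_entropy_with_softmax(z, label)
    errs = C.classification_error(z, label)
    lr = C.learning_rate_schedule(0.02, C.UnitType.sample)
    trainer = C.Trainer(z, (ce, errs), [C.sgd(z.parameters, lr)],
                        progress_writers=[writer] if writer else None)
    return trainer, feature, label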
Example #2
def test_usermbsource_training(tmpdir):
    input_dim = 1000
    num_output_classes = 5

    mbs = MyDataSource(input_dim, num_output_classes)
    # Using this for testing the UserMinibatchSource checkpointing
    mbs_cv = MyDataSource(input_dim, num_output_classes)

    from cntk import sequence, parameter, plus, cross_entropy_with_softmax, \
            classification_error, learning_rate_schedule, sgd, Trainer, \
            training_session, times, UnitType

    feature = sequence.input_variable(shape=(input_dim, ))
    label = C.input_variable(shape=(num_output_classes, ))
    p = parameter(shape=(input_dim, num_output_classes), init=10)
    z = times(sequence.reduce_sum(feature), p, name='z')
    ce = cross_entropy_with_softmax(z, label)
    errs = classification_error(z, label)

    lr_per_sample = learning_rate_schedule([0.3, 0.2, 0.1, 0.0],
                                           UnitType.sample)
    learner = sgd(z.parameters, lr_per_sample)
    trainer = Trainer(z, (ce, errs), [learner])
    input_map = {feature: mbs.fsi, label: mbs.lsi}

    session = training_session(trainer=trainer,
                               mb_source=mbs,
                               model_inputs_to_streams=input_map,
                               mb_size=4,
                               max_samples=20,
                               cv_config=C.CrossValidationConfig(
                                   source=mbs_cv, max_samples=10, mb_size=2))
    session.train()

    assert trainer.total_number_of_samples_seen == 20
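MyDataSource is not shown above. A minimal sketch of what it might look like, modeled on the UserMinibatchSource walkthrough in the CNTK manual; the random dense data served here is an assumption (a real source would read from storage):

import numpy as np
import cntk as C
from cntk.io import UserMinibatchSource, StreamInformation, MinibatchData


class MyDataSource(UserMinibatchSource):
    def __init__(self, f_dim, l_dim):
        self.f_dim, self.l_dim = f_dim, l_dim
        # Stream metadata that model_inputs_to_streams maps model inputs onto.
        self.fsi = StreamInformation("features", 0, "dense", np.float32,
                                     (f_dim,))
        self.lsi = StreamInformation("labels", 1, "dense", np.float32,
                                     (l_dim,))
        self.sample_count = 0
        super(MyDataSource, self).__init__()

    def stream_infos(self):
        return [self.fsi, self.lsi]

    def next_minibatch(self, num_samples, number_of_workers=1, worker_rank=0,
                       device=None):
        # Serve random one-step sequences until num_samples is reached.
        features, labels = [], []
        for _ in range(num_samples):
            features.append(np.random.rand(1, self.f_dim).astype(np.float32))
            onehot = np.zeros(self.l_dim, dtype=np.float32)
            onehot[self.sample_count % self.l_dim] = 1
            labels.append(onehot)
            self.sample_count += 1
        f_value = C.Value(batch=features, device=device)
        l_value = C.Value(batch=np.stack(labels), device=device)
        return {
            self.fsi: MinibatchData(f_value, num_samples, num_samples, False),
            self.lsi: MinibatchData(l_value, num_samples, num_samples, False),
        }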
Example #3
def test_session_cv_callback_with_cross_validation_3_times(tmpdir, device_id):
    device = cntk_device(device_id)
    t, feature, label = create_sample_model(device)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    cv_mbs = mb_source(tmpdir, "cv")

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    def cv_callback(index, average_error, num_samples, num_mb):
        initial_position = cv_mbs.current_position
        total_error = 0
        while True:
            mb = cv_mbs.next_minibatch(2, input_map=input_map)
            if not mb:
                break
            mb_error = t.test_minibatch(mb, device=device)
            total_error += mb_error * mb[label].num_samples

        total_samples = 25  # matches the cv input data
        assert ((total_error * 100) / total_samples == 92)
        cv_mbs.current_position = initial_position
        return True

    C.training_session(trainer=t,
                       mb_source=mbs,
                       mb_size=4,
                       model_inputs_to_streams=input_map,
                       max_samples=60,
                       cv_config=C.CrossValidationConfig(
                           frequency=20, callback=cv_callback)).train(device)

    assert (t.total_number_of_samples_seen == 61)
Example #4
def test_usermbsource_training(tmpdir, with_checkpoint_impl):
    input_dim = 1000
    num_output_classes = 5

    mbs = MyDataSource(input_dim, num_output_classes)
    # Using this for testing the UserMinibatchSource checkpointing
    if with_checkpoint_impl:
        MBS_CV_CLASS = MyDataSourceWithCheckpoint
    else:
        MBS_CV_CLASS = MyDataSource

    mbs_cv = MBS_CV_CLASS(input_dim, num_output_classes)

    from cntk import sequence, parameter, plus, cross_entropy_with_softmax, \
            classification_error, learning_rate_schedule, sgd, Trainer, \
            training_session, times, UnitType

    feature = sequence.input_variable(shape=(input_dim,))
    label = C.input_variable(shape=(num_output_classes,))
    p = parameter(shape=(input_dim, num_output_classes), init=10)
    z = times(sequence.reduce_sum(feature), p, name='z')
    ce = cross_entropy_with_softmax(z, label)
    errs = classification_error(z, label)

    # Use a large learning rate to keep the model from converging before all
    # the intended samples have been fed; note that the training session can
    # end early if there are no parameter updates.
    lr_per_sample = learning_rate_schedule(0.3, UnitType.sample)
    learner = sgd(z.parameters, lr_per_sample)
    trainer = Trainer(z, (ce, errs), [learner])
    input_map = {
        feature: mbs.fsi,
        label: mbs.lsi
    }

    session = training_session(
        trainer=trainer, mb_source=mbs,
        model_inputs_to_streams=input_map,
        mb_size=4, max_samples=20,
        cv_config=C.CrossValidationConfig(minibatch_source=mbs_cv,
                                          max_samples=10,
                                          minibatch_size=2)
    )
    session.train()

    assert trainer.total_number_of_samples_seen == 20
    if with_checkpoint_impl:
        assert mbs_cv._restore_from_checkpoint_calls == 1
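The checkpoint-aware variant exercised by the with_checkpoint_impl branch is also defined elsewhere. A sketch of its likely shape: UserMinibatchSource subclasses override get_checkpoint_state and restore_from_checkpoint, and the counter below is what the final assert inspects (the state payload is an assumption):

class MyDataSourceWithCheckpoint(MyDataSource):
    def __init__(self, f_dim, l_dim):
        super(MyDataSourceWithCheckpoint, self).__init__(f_dim, l_dim)
        self._restore_from_checkpoint_calls = 0

    def get_checkpoint_state(self):
        # Any picklable dictionary can serve as the checkpoint state.
        return {"sample_count": self.sample_count}

    def restore_from_checkpoint(self, state):
        self._restore_from_checkpoint_calls += 1
        self.sample_count = state["sample_count"]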
Example #5
def test_session_cross_validation_at_end(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter(expected_test_summary=[[92, 25]])
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    mbs1 = mb_source(tmpdir, "cv")

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    C.training_session(trainer=t,
                       mb_source=mbs,
                       mb_size=4,
                       model_inputs_to_streams=input_map,
                       max_samples=20,
                       cv_config=C.CrossValidationConfig(mbs1)).train(device)

    assert (t.total_number_of_samples_seen == 21)
    assert (writer.test_summary_counter == 1)
Example #6
def test_session_cross_validation_3_times_checkpoints_2_save_all(
        tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter(
        expected_test_summary=[[92, 25], [92, 25], [92, 25]])
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    mbs1 = mb_source(tmpdir, "cv")

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    test_dir = str(tmpdir)

    C.training_session(trainer=t,
                       mb_source=mbs,
                       mb_size=4,
                       model_inputs_to_streams=input_map,
                       max_samples=60,
                       checkpoint_config=C.CheckpointConfig(
                           frequency=35,
                           preserve_all=True,
                           filename=str(tmpdir / "checkpoint_save_all")),
                       cv_config=C.CrossValidationConfig(
                           mbs1, frequency=20)).train(device)

    candidates = [
        f for f in listdir(test_dir)
        if isfile(join(test_dir, f)) and f.startswith("checkpoint_save_all")
    ]

    assert ("checkpoint_save_all0" in candidates)
    assert ("checkpoint_save_all0.ckp" in candidates)

    assert ("checkpoint_save_all1" in candidates)
    assert ("checkpoint_save_all1.ckp" in candidates)

    assert ("checkpoint_save_all" in candidates)
    assert ("checkpoint_save_all.ckp" in candidates)

    assert (writer.test_summary_counter == 3)
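With preserve_all=True the session keeps every periodic checkpoint under an indexed name next to the final unsuffixed one, and each model file gets a companion .ckp file holding the minibatch source state. Restoring later is a single call; a sketch, reusing the helper above:

t2, _, _ = create_sample_model(device)
t2.restore_from_checkpoint(str(tmpdir / "checkpoint_save_all"))  # final state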
Example #7
def test_training_session_with_infinite_samples(tmpdir, device_id):
    import pytest
    device = cntk_device(device_id)
    t, feature, label = create_sample_model(device)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    with pytest.raises(ValueError) as info1:
        C.training_session(trainer=t,
                           mb_source=mbs,
                           mb_size=4,
                           model_inputs_to_streams=input_map).train(device)
    assert 'Train minibatch source must have a limited number of samples or sweeps' in str(
        info1.value)

    with pytest.raises(ValueError) as info2:
        mbs1 = mb_source(tmpdir, "test", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=t,
            mb_source=mbs,
            mb_size=4,
            model_inputs_to_streams=input_map,
            max_samples=10,
            test_config=C.TestConfig(mbs1, minibatch_size=2),
        ).train(device)
    assert 'Test minibatch source must have a limited number of samples or sweeps' in str(
        info2.value)

    with pytest.raises(ValueError) as info3:
        mbs2 = mb_source(tmpdir, "cv", max_samples=INFINITELY_REPEAT)
        C.training_session(
            trainer=t,
            mb_source=mbs,
            mb_size=4,
            model_inputs_to_streams=input_map,
            max_samples=20,
            cv_config=C.CrossValidationConfig(mbs2)).train(device)
    assert 'Cross validation minibatch source must have a limited number of samples or sweeps' in str(
        info3.value)
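All three errors disappear once every source involved is bounded. A sketch reusing the names above, assuming the helper's default of a single full data sweep for the auxiliary sources:

    mbs_test = mb_source(tmpdir, "test")  # finite: one sweep over the data
    mbs_cv2 = mb_source(tmpdir, "cv")     # finite: one sweep over the data
    C.training_session(trainer=t,
                       mb_source=mbs,
                       mb_size=4,
                       model_inputs_to_streams=input_map,
                       max_samples=10,  # bounds training on the infinite source
                       test_config=C.TestConfig(mbs_test, minibatch_size=2),
                       cv_config=C.CrossValidationConfig(mbs_cv2)).train(device)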
Example #8
def test_session_cross_validation_3_times_on_minibatch_unit(tmpdir, device_id):
    device = cntk_device(device_id)
    writer = MockProgressWriter(
        expected_test_summary=[[92, 25], [92, 25], [92, 25]])
    t, feature, label = create_sample_model(device, writer)
    mbs = mb_source(tmpdir, "training", max_samples=INFINITELY_REPEAT)
    mbs1 = mb_source(tmpdir, "cv")

    input_map = {feature: mbs.streams.features, label: mbs.streams.labels}

    C.training_session(
        trainer=t,
        mb_source=mbs,
        mb_size=4,
        model_inputs_to_streams=input_map,
        max_samples=60,
        cv_config=C.CrossValidationConfig(
            mbs1, frequency=(5, C.train.DataUnit.minibatch), minibatch_size=2),
    ).train(device)

    assert (t.total_number_of_samples_seen == 61)
    assert (writer.test_summary_counter == 3)
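frequency accepts either a plain sample count or a (value, DataUnit) pair, and CheckpointConfig plus the session's progress_frequency follow the same convention. A few equivalent spellings, as a sketch:

C.CrossValidationConfig(mbs1, frequency=20)                               # every 20 samples
C.CrossValidationConfig(mbs1, frequency=(5, C.train.DataUnit.minibatch))  # every 5 minibatches
C.CrossValidationConfig(mbs1, frequency=(1, C.train.DataUnit.sweep))      # once per data sweep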
Example #9
# Callback invoked after each cross-validation pass; it halves the learning
# rate when the metric stops improving and ends training once the rate is
# too small.
prev_metric = 1  # metric from the previous call; error rate is 100% at start


def adjust_lr_callback(index, average_error, cv_num_samples, cv_num_minibatches):
    global prev_metric
    # The metric must improve by at least 5% relative, or we cut the rate.
    if (prev_metric - average_error) / prev_metric < 0.05:
        learner.reset_learning_rate(
            C.learning_parameter_schedule_per_sample(learner.learning_rate() /
                                                     2))
        if learner.learning_rate() < lr_per_sample / (
                2**7 - 0.1):  # we are done after the 6-th LR cut
            print("Learning rate {} too small. Training complete.".format(
                learner.learning_rate()))
            return False  # means we are done
        print(
            "Improvement of metric from {:.3f} to {:.3f} insufficient. Halving learning rate to {}."
            .format(prev_metric, average_error, learner.learning_rate()))
    prev_metric = average_error
    return True  # means continue


cv_callback_config = C.CrossValidationConfig((X_cv, Y_cv),
                                             3 * epoch_size,
                                             minibatch_size=256,
                                             callback=adjust_lr_callback,
                                             criterion=criterion)

# Callback for testing the final model.
test_callback_config = C.TestConfig((X_test, Y_test), criterion=criterion)

# Configure distributed training.
# For this, we wrap the learner in a distributed_learner object.
# This specific example implements the BlockMomentum method. The Python script must be run
# using mpiexec in order to have effect. For example, under Windows, the command is:
#   mpiexec -n 4 -lines python -u MNIST_Complex_Training.py
learner = C.train.distributed.data_parallel_distributed_learner(learner)

# For distributed training, we want to maximize the minibatch size, so as to
# minimize communication cost and GPU underutilization. Hence, we use a
# "schedule" of minibatch sizes that grows over the course of training.
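# Such a schedule can be built with C.minibatch_size_schedule; the sizes and
# epoch boundaries below are illustrative assumptions, not the original values:
mb_size_schedule = C.minibatch_size_schedule([256] * 6 + [512] * 9 + [1024],
                                             epoch_size)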
Example #10
    def train_and_test(self,
                       reader_train,
                       reader_test,
                       reader_cv,
                       restore_checkpoint=True):
        '''
        Train the model and validate the results

        Args:
            reader_train (:class:`~cntk.io.MinibatchSource`): the dataset reader for training.
            reader_test (:class:`~cntk.io.MinibatchSource`): the dataset reader for evaluation.
            reader_cv (:class:`~cntk.io.MinibatchSource`): the dataset reader for cross validation.
            restore_checkpoint (bool, optional): continue training from the latest checkpoint if True (default)

        Returns:
            None
        '''
        from CapsNet import CapsNet

        self.input = ct.input_variable(self.input_dim_model,
                                       name='MINST_Input')
        self.labels = ct.input_variable(self.output_dim_model,
                                        name='MINST_Labels')
        self.perturbations = ct.input_variable(self.perturbations_dim,
                                               name='Perturbations')

        self.caps_net = CapsNet(self.input / 255.,
                                self.labels,
                                routings=3,
                                use_reconstruction=True)

        # models
        self.training_model, self.digitcaps_model, self.prediction_model, self.reconstruction_model = self.caps_net.models(
        )
        self.manipulation_model = self.caps_net.manipulation(
            self.perturbations)

        # loss & error
        loss, error = self.caps_net.criterion()

        # Number of parameters in the network
        # 5. Capsules on MNIST "... CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnetwork."
        num_parameters, num_tensors = get_number_of_parameters(
            self.training_model)
        print(
            "DigitCaps contains {} learnable parameters in {} parameter tensors."
            .format(num_parameters, num_tensors))

        # Initialize the parameters for the trainer
        minibatch_size = 128
        num_samples_per_sweep = 60000
        num_sweeps_to_train_with = 30

        # Report & Checkpoint frequency
        print_frequency = (4, ct.DataUnit.minibatch)
        checkpoint_frequency = (100, ct.DataUnit.minibatch)
        cross_validation_frequency = (40, ct.DataUnit.minibatch)

        tensorboard_logdir = './tensorboard'

        # Map the data streams to the input and labels.
        self.input_map = {
            self.labels: reader_train.streams.labels,
            self.input: reader_train.streams.features
        }

        self.test_input_map = {
            self.labels: reader_test.streams.labels,
            self.input: reader_test.streams.features
        }

        self.cv_input_map = {
            self.labels: reader_cv.streams.labels,
            self.input: reader_cv.streams.features
        }

        # Instantiate progress writers.
        progress_writers = [
            ct.logging.ProgressPrinter(
                tag='Training',
                num_epochs=int(num_samples_per_sweep *
                               num_sweeps_to_train_with / minibatch_size /
                               print_frequency[0]))
        ]

        training_progress_output_freq = 1

        if tensorboard_logdir is not None:
            self.tb_printer = ct.logging.TensorBoardProgressWriter(
                freq=training_progress_output_freq,
                log_dir=tensorboard_logdir,
                model=self.training_model)
            progress_writers.append(self.tb_printer)

        # Instantiate the learning rate schedule
        learning_rate_schedule = [0.01] * 30 + [0.007]
        learning_rate_schedule = ct.learning_parameter_schedule(
            learning_rate_schedule,
            minibatch_size=minibatch_size,
            epoch_size=num_samples_per_sweep)

        # Instantiate the trainer object to drive the model training
        learner = ct.adam(self.training_model.parameters,
                          learning_rate_schedule,
                          momentum=[0.9],
                          variance_momentum=[0.999],
                          gaussian_noise_injection_std_dev=[0.0])
        trainer = ct.Trainer(self.training_model, (loss, error), [learner],
                             progress_writers)

        ct.training_session(
            trainer=trainer,
            mb_source=reader_train,
            mb_size=minibatch_size,
            model_inputs_to_streams=self.input_map,
            max_samples=num_samples_per_sweep * num_sweeps_to_train_with,
            progress_frequency=print_frequency,
            checkpoint_config=ct.CheckpointConfig(
                filename='./checkpoints/checkpoint',
                frequency=checkpoint_frequency,
                restore=restore_checkpoint),
            cv_config=ct.CrossValidationConfig(
                minibatch_size=128,
                minibatch_source=reader_cv,
                frequency=cross_validation_frequency,
                callback=self.cross_validation_callbackfunc,
                max_samples=1024,
                model_inputs_to_streams=self.cv_input_map),
            test_config=ct.TestConfig(
                minibatch_source=reader_test,
                minibatch_size=minibatch_size,
                model_inputs_to_streams=self.test_input_map)).train()

        # save models
        self.digitcaps_model.save('./models/digitcaps_model.cntk')
        self.training_model.save('./models/training_model.cntk')
        self.prediction_model.save('./models/prediction_model.cntk')
        if self.reconstruction_model:
            self.reconstruction_model.save(
                './models/reconstruction_model.cntk')
            self.manipulation_model.save('./models/manipulation_model.cntk')

        print('Done.')
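The session above references self.cross_validation_callbackfunc, which is not shown. Its signature matches the cross-validation callbacks in the earlier examples; a minimal sketch (the logging and the unconditional continue are assumptions):

    def cross_validation_callbackfunc(self, index, average_error, num_samples,
                                      num_minibatches):
        # Log the intermediate validation error; returning False stops training.
        print("CV {}: error {:.2%} over {} samples".format(
            index, average_error, num_samples))
        return True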