Exemplo n.º 1
0
def train_and_test(network, trainer, train_source, test_source, max_epochs,
                   minibatch_size, epoch_size, restore, profiling):
    """Run a CNTK training session with checkpointing and a test configuration.

    Args:
        network: mapping holding the model input variables under the keys
            'feature' and 'label'.
        trainer: the CNTK trainer passed to ``training_session``.
        train_source: training minibatch source; its ``streams.features`` and
            ``streams.labels`` are bound to the network inputs.
        test_source: minibatch source passed to ``TestConfig``.
        max_epochs: not read by this function. NOTE(review): presumably the
            training length is already bounded by ``train_source`` itself —
            confirm with the caller.
        minibatch_size: minibatch size used for both training and testing.
        epoch_size: used as both the progress-reporting frequency and the
            checkpoint frequency.
        restore: passed as ``restore`` to ``CheckpointConfig``.
        profiling: when true, the CNTK profiler is started (with
            ``sync_gpu=True``) before training and stopped afterwards.
    """
    # Define mapping from input streams to network inputs.
    input_map = {
        network['feature']: train_source.streams.features,
        network['label']: train_source.streams.labels,
    }

    if profiling:
        start_profiler(sync_gpu=True)

    try:
        training_session(
            trainer=trainer,
            mb_source=train_source,
            model_inputs_to_streams=input_map,
            mb_size=minibatch_size,
            progress_frequency=epoch_size,
            checkpoint_config=CheckpointConfig(frequency=epoch_size,
                                               filename=os.path.join(
                                                   model_path,
                                                   "BN-Inception_CIFAR10"),
                                               restore=restore),
            test_config=TestConfig(test_source,
                                   minibatch_size=minibatch_size)).train()
    finally:
        # Stop the profiler even when training raises, so it is never left
        # running after this function exits.
        if profiling:
            stop_profiler()
Exemplo n.º 2
0
def main(params):
    """Train the network on CBF data, then save the final model.

    ``params`` is a dict. The keys read here are 'output_folder',
    'log_folder', 'input_folder', 'prefix', 'minibatch_size', 'epoch_size',
    'max_epochs' and 'restore'.

    Side effects:
        - Creates the output and log folders if they are missing.
        - Writes a progress log ('log.txt') and TensorBoard logs into the log
          folder.
        - Checkpoints to a file in the output folder.
        - Saves the trained model as 'final_model.dnn' in the output folder.
    """
    # Create output and log directories if they don't exist.
    for folder in (params['output_folder'], params['log_folder']):
        if not os.path.isdir(folder):
            os.makedirs(folder)

    # Create the network
    network = create_network()

    # Create readers. Training repeats the data indefinitely (the session's
    # max_samples bounds training); evaluation readers do one full sweep.
    input_folder = params['input_folder']
    prefix = params['prefix']
    train_path = os.path.join(input_folder, 'train{}.cbf'.format(prefix))
    # NOTE(review): cross-validation and test both read the same 'test' file;
    # confirm whether a separate validation file was intended.
    test_path = os.path.join(input_folder, 'test{}.cbf'.format(prefix))
    train_reader = cbf_reader(train_path, is_training=True,
                              max_samples=cntk.io.INFINITELY_REPEAT)
    cv_reader = cbf_reader(test_path, is_training=False,
                           max_samples=cntk.io.FULL_DATA_SWEEP)
    test_reader = cbf_reader(test_path, is_training=False,
                             max_samples=cntk.io.FULL_DATA_SWEEP)

    # Bind network inputs to reader streams; this same mapping is reused for
    # the cross-validation and test configurations below.
    input_map = {
        network['input']: train_reader.streams.front,
        network['target']: train_reader.streams.label
    }

    # Create learner: Adam with momentum 0.90 and L2 weight 0.0005.
    # The learning rate is piecewise constant: 0.1, then 0.01.
    # NOTE(review): the unit of the 40-step counts follows CNTK's schedule /
    # epoch_size semantics — confirm it matches the intended number of epochs.
    mb_size = params['minibatch_size']
    mm_schedule = momentum_schedule(0.90)
    lr_schedule = learning_parameter_schedule([(40, 0.1), (40, 0.01)], minibatch_size=mb_size)
    learner = cntk.adam(network['model'].parameters, lr_schedule, mm_schedule, l2_regularization_weight=0.0005,
                        epoch_size=params['epoch_size'], minibatch_size=mb_size)

    # Progress reporting: a text log ('log.txt') plus TensorBoard logs, both
    # written into the log folder every 10 updates.
    log_file = os.path.join(params['log_folder'], 'log.txt')
    pp_writer = cntk.logging.ProgressPrinter(freq=10, tag='Training', num_epochs=params['max_epochs'], log_to_file=log_file)
    tb_writer = cntk.logging.TensorBoardProgressWriter(freq=10, log_dir=params['log_folder'], model=network['model'])

    # Create trainer and training session: cross-validate every sweep,
    # checkpoint every 10 sweeps. model_name is a module-level name defined
    # outside this function.
    trainer = Trainer(network['model'], (network['loss'], network['metric']), [learner], [pp_writer, tb_writer])
    test_config = TestConfig(minibatch_source=test_reader, minibatch_size=mb_size, model_inputs_to_streams=input_map)
    cv_config = CrossValidationConfig(minibatch_source=cv_reader, frequency=(1, DataUnit.sweep),
                                      minibatch_size=mb_size, model_inputs_to_streams=input_map)
    checkpoint_config = CheckpointConfig(os.path.join(params['output_folder'], model_name), frequency=(10, DataUnit.sweep), restore=params['restore'])

    session = training_session(trainer=trainer,
                               mb_source=train_reader,
                               mb_size=mb_size,
                               model_inputs_to_streams=input_map,
                               max_samples=params['epoch_size'] * params['max_epochs'],
                               progress_frequency=(1, DataUnit.sweep),
                               checkpoint_config=checkpoint_config,
                               cv_config=cv_config,
                               test_config=test_config)

    cntk.logging.log_number_of_parameters(network['model'])
    try:
        session.train()
    finally:
        # Persist buffered TensorBoard events and release the event file,
        # even if training fails part-way.
        tb_writer.close()

    # Save the trained model
    path = os.path.join(params['output_folder'], 'final_model.dnn')
    network['model'].save(path)
    print('Saved final model to', path)