コード例 #1
0
def train_and_test(network, trainer, train_source, test_source, max_epochs,
                   minibatch_size, epoch_size, restore, profiling):
    """Run a checkpointed training session with a configured test pass.

    Args:
        network: Mapping that provides the model's 'feature' and 'label'
            input variables.
        trainer: Trainer handed to ``training_session``.
        train_source: Training minibatch source; its ``streams.features`` and
            ``streams.labels`` are bound to the network inputs.
        test_source: Minibatch source passed to ``TestConfig``.
        max_epochs: Not used by this function; kept so existing callers keep
            working. NOTE(review): the amount of training is presumably
            bounded by ``train_source`` itself -- confirm with the caller.
        minibatch_size: Minibatch size used for both training and testing.
        epoch_size: Used as the progress frequency and the checkpoint
            frequency.
        restore: Forwarded to ``CheckpointConfig(restore=...)``.
        profiling: When truthy, the profiler is started before training and
            stopped afterwards.
    """
    # Define the mapping from input streams to network inputs.
    input_map = {
        network['feature']: train_source.streams.features,
        network['label']: train_source.streams.labels
    }

    if profiling:
        start_profiler(sync_gpu=True)

    try:
        training_session(
            trainer=trainer,
            mb_source=train_source,
            model_inputs_to_streams=input_map,
            mb_size=minibatch_size,
            progress_frequency=epoch_size,
            # Checkpoints go under the module-level ``model_path``.
            checkpoint_config=CheckpointConfig(frequency=epoch_size,
                                               filename=os.path.join(
                                                   model_path,
                                                   "BN-Inception_CIFAR10"),
                                               restore=restore),
            test_config=TestConfig(test_source,
                                   minibatch_size=minibatch_size)).train()
    finally:
        # Stop the profiler even when training fails, so a started profiler
        # is never left running after an exception.
        if profiling:
            stop_profiler()
コード例 #2
0
def train_and_test(network, trainer, train_source, test_source, max_epochs, minibatch_size, epoch_size, restore, profiler_dir, progress_printer, testing_parameters):
    """Train the model epoch by epoch, then evaluate it on the test source.

    Args:
        network: Mapping that provides the 'feature' and 'label' input
            variables and the 'output' node that is saved after each epoch.
        trainer: Object providing ``train_minibatch``, ``test_minibatch`` and
            ``previous_minibatch_sample_count``.
        train_source: Training minibatch source; its ``streams.features`` and
            ``streams.labels`` are bound to the network inputs.
        test_source: Minibatch source read during evaluation (using the same
            input map as training).
        max_epochs: Number of training epochs to run.
        minibatch_size: Upper bound on the samples requested per training
            minibatch.
        epoch_size: Number of samples that make up one training epoch.
        restore: Not used by this function; kept so existing callers keep
            working.
        profiler_dir: When truthy, profiling is started with this value as
            its first argument and stopped after training.
        progress_printer: Receives per-minibatch updates and epoch summaries.
        testing_parameters: Pair ``(test_epoch_size, test_minibatch_size)``.

    Returns:
        The sample-weighted average of the values returned by
        ``trainer.test_minibatch`` over the evaluated test minibatches.

    Raises:
        ZeroDivisionError: If no test samples were evaluated.
    """
    # Define the mapping from input streams to network inputs.
    input_map = {
        network['feature']: train_source.streams.features,
        network['label']: train_source.streams.labels
    }

    # Perform model training.
    if profiler_dir:
        start_profiler(profiler_dir, True)

    try:
        for epoch in range(max_epochs):       # loop over epochs
            sample_count = 0
            while sample_count < epoch_size:  # loop over minibatches in the epoch
                # Never request more samples than remain in this epoch.
                data = train_source.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map)
                trainer.train_minibatch(data)                                   # update model with it
                sample_count += trainer.previous_minibatch_sample_count         # count samples processed so far
                progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
            progress_printer.epoch_summary(with_metric=True)
            network['output'].save(os.path.join(model_path, "BN-Inception_CIFAR-10_{}.model".format(epoch)))
            if profiler_dir:
                # Begin to collect profiler data after the first epoch; only
                # meaningful when the profiler was actually started.
                enable_profiler()
    finally:
        # Stop the profiler even when training fails part-way through.
        if profiler_dir:
            stop_profiler()

    # Evaluation parameters.
    test_epoch_size, test_minibatch_size = testing_parameters
    label_var = network['label']

    # Process minibatches and evaluate the model.
    metric_numer    = 0
    metric_denom    = 0
    sample_count    = 0
    minibatch_count = 0

    while sample_count < test_epoch_size:
        requested = min(test_minibatch_size, test_epoch_size - sample_count)
        # Fetch the next test minibatch.
        data = test_source.next_minibatch(requested, input_map=input_map)
        # Weight by the samples actually delivered, which may be fewer than
        # requested; otherwise a short batch skews the average.
        actual = data[label_var].num_samples if data else 0
        if actual == 0:
            # No progress is possible (e.g. the source ran out of data);
            # stop instead of looping forever.
            break
        metric_numer += trainer.test_minibatch(data) * actual
        metric_denom += actual
        # Keep track of the number of samples processed so far.
        sample_count += actual
        minibatch_count += 1

    print("")
    # minibatch_count already equals the number of minibatches evaluated.
    print("Final Results: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(minibatch_count, (metric_numer*100.0)/metric_denom, metric_denom))
    print("")

    return metric_numer/metric_denom
コード例 #3
0
def train_and_test(network, trainer, train_source, test_source, minibatch_size, epoch_size, restore, profiling=False):
    """Run a checkpointed training session with a configured test pass.

    Args:
        network: Mapping that provides the model's 'feature' and 'label'
            input variables.
        trainer: Trainer handed to ``training_session``.
        train_source: Training minibatch source; its ``streams.features`` and
            ``streams.labels`` are bound to the network inputs.
        test_source: Minibatch source passed to ``TestConfig``.
        minibatch_size: Minibatch size used for both training and testing.
        epoch_size: Used as the progress frequency and the checkpoint
            frequency.
        restore: Forwarded to ``CheckpointConfig(restore=...)``.
        profiling: When truthy, the profiler is started before training and
            stopped afterwards. Defaults to False.
    """
    # Define the mapping from input streams to network inputs.
    input_map = {
        network['feature']: train_source.streams.features,
        network['label']: train_source.streams.labels
    }

    if profiling:
        start_profiler(sync_gpu=True)

    try:
        training_session(
            trainer=trainer, mb_source=train_source,
            model_inputs_to_streams=input_map,
            mb_size=minibatch_size,
            progress_frequency=epoch_size,
            # Checkpoint location comes from the module-level ``model_path``
            # and ``model_name``.
            checkpoint_config=CheckpointConfig(frequency=epoch_size, filename=os.path.join(model_path, model_name), restore=restore),
            test_config=TestConfig(source=test_source, mb_size=minibatch_size)
        ).train()
    finally:
        # Stop the profiler even when training fails, so a started profiler
        # is never left running after an exception.
        if profiling:
            stop_profiler()