Example no. 1
def objective(hyper_params):
    """Hyper-parameter-search objective: train a SimpleConvNet with the
    sampled *hyper_params* and report the validation cross-entropy loss.

    Returns a dict in the hyperopt convention: {'loss': ..., 'status': ...}.
    """
    # Fold the sampled hyper-parameters into the shared model configuration.
    model_config.update(hyper_params)

    # Build the trainer for the convolutional model from the updated config.
    conv_net_trainer = ModelTrainerFactory(model=SimpleConvNet()).create(model_config)

    # Assemble the training strategy step by step instead of one fluent chain:
    # status is printed every 100 batches while training runs for nb_epochs.
    strategy = SimpleTrainer("MNIST Trainer", train_loader, valid_loader, None, conv_net_trainer,
                             RunConfiguration(use_amp=False))
    strategy = strategy.with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END)
    training_monitor = strategy.train(training_config.nb_epochs)

    # Hyperopt minimizes the reported value, so hand back the validation loss.
    return {'loss': training_monitor["SimpleNet"][Phase.VALIDATION][Monitor.LOSS]["CrossEntropy"],
            'status': STATUS_OK}
Example no. 2
    def test_should_create_model_trainer_with_config(self):
        """The factory should build a fully configured ModelTrainer from the
        stored configs: criterions, optimizer, scheduler and metric computers.
        """
        model_trainers = ModelTrainerFactory(model=self._model).create(
            model_trainer_configs=self._model_trainer_configs)
        simple_net = model_trainers[0]

        assert_that(simple_net, instance_of(ModelTrainer))
        assert_that(simple_net.name, is_(self.SIMPLE_NET_NAME))

        # Two loss criterions, each a torch loss instance.
        # (Plain for-loops here: the original wrapped these asserts in list
        # comprehensions, which builds a throwaway list purely for side effects.)
        assert_that(len(simple_net.criterions), is_(2))
        for criterion in simple_net.criterions.values():
            assert_that(criterion, instance_of(_Loss))

        assert_that(simple_net.optimizer, instance_of(torch.optim.SGD))
        assert_that(simple_net.scheduler,
                    instance_of(torch.optim.lr_scheduler.ReduceLROnPlateau))

        # Two metric computers, each a Metric instance.
        assert_that(len(simple_net._metric_computers), is_(2))
        for metric_computer in simple_net._metric_computers.values():
            assert_that(metric_computer, instance_of(Metric))
Example no. 3
    # Placeholders for the per-dataset splits, CSV handles and augmentation
    # strategies; each is filled in below only when the corresponding dataset
    # appears in dataset_configs.
    reconstruction_datasets = list()

    iSEG_train = None
    iSEG_CSV = None
    MRBrainS_train = None
    MRBrainS_CSV = None
    ABIDE_train = None
    ABIDE_CSV = None

    iSEG_augmentation_strategy = None
    MRBrainS_augmentation_strategy = None
    ABIDE_augmentation_strategy = None

    # Initialize the model trainers
    model_trainer_factory = ModelTrainerFactory(
        model_factory=CustomModelFactory(),
        criterion_factory=CustomCriterionFactory())
    model_trainers = model_trainer_factory.create(model_trainer_configs)
    # Normalize to a list so downstream code can iterate uniformly whether
    # the factory returned one trainer or several.
    if not isinstance(model_trainers, list):
        model_trainers = [model_trainers]

    # Create datasets
    # Build the iSEG train/valid/test/reconstruction splits only if an iSEG
    # entry exists in the config. (NOTE(review): this call is truncated in the
    # visible source; remaining keyword arguments follow below.)
    if dataset_configs.get("iSEG", None) is not None:
        iSEG_train, iSEG_valid, iSEG_test, iSEG_reconstruction = iSEGSliceDatasetFactory.create_train_valid_test(
            source_dir=dataset_configs["iSEG"].path,
            modalities=dataset_configs["iSEG"].modalities,
            dataset_id=ISEG_ID,
            test_size=dataset_configs["iSEG"].validation_split,
            max_subjects=dataset_configs["iSEG"].max_subjects,
            max_num_patches=dataset_configs["iSEG"].max_num_patches,
            augment=dataset_configs["iSEG"].augment,
Example no. 4
                           Normalize((0.1307, ), (0.3081, ))]))
    # MNIST test split, normalized with the standard MNIST mean/std
    # (0.1307 / 0.3081). The preceding (partially visible) statement builds
    # the training split the same way.
    test_dataset = torchvision.datasets.MNIST(
        './files/',
        train=False,
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))]))

    # Initialize loaders.
    train_loader, valid_loader = DataloaderFactory(
        train_dataset, test_dataset).create(run_config, training_config)

    # Initialize the loggers.
    # visdom_logger only exists on rank 0; it is referenced again below under
    # the same local_rank == 0 guard, so this is safe.
    if run_config.local_rank == 0:
        visdom_logger = VisdomLogger(
            VisdomConfiguration.from_yml(CONFIG_FILE_PATH))

    # Initialize the model trainers.
    model_trainer = ModelTrainerFactory(model=SimpleNet()).create(
        model_trainer_config, run_config)

    # Train with the training strategy.
    # Only rank 0 attaches console and visdom logging handlers; all other
    # ranks train without any event handlers.
    if run_config.local_rank == 0:
        trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \
            .with_event_handler(ConsoleLogger(), Event.ON_EPOCH_END) \
            .with_event_handler(visdom_logger, Event.ON_EPOCH_END, PlotAllModelStateVariables()) \
            .train(training_config.nb_epochs)
    else:
        trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \
            .train(training_config.nb_epochs)
Example no. 5
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))])),
                              batch_size=training_config.batch_size_train,
                              shuffle=True)

    # MNIST test split wrapped in a DataLoader; same normalization constants
    # (0.1307 / 0.3081) as the (partially visible) training loader above.
    # NOTE(review): shuffle=True on the evaluation loader is unusual — it does
    # not affect aggregate metrics, but confirm it is intentional.
    test_loader = DataLoader(torchvision.datasets.MNIST(
        './files/',
        train=False,
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))])),
                             batch_size=training_config.batch_size_valid,
                             shuffle=True)

    # Initialize the loggers
    visdom_logger = VisdomLogger(
        VisdomConfiguration.from_yml(CONFIG_FILE_PATH))

    # Initialize the model trainers
    model_trainer = ModelTrainerFactory(model=SimpleNet()).create(
        model_trainer_config, RunConfiguration(use_amp=False))

    # Train with the training strategy
    # Fluent setup: print training/trainer status every 100 batches, plot model
    # state variables each epoch and gradient flow every 100 train batches.
    trainer = SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer) \
        .with_event_handler(PrintTrainingStatus(every=100), Event.ON_TRAIN_BATCH_END) \
        .with_event_handler(PrintModelTrainersStatus(every=100), Event.ON_BATCH_END) \
        .with_event_handler(PlotAllModelStateVariables(visdom_logger), Event.ON_EPOCH_END) \
        .with_event_handler(PlotGradientFlow(visdom_logger, every=100), Event.ON_TRAIN_BATCH_END) \
        .train(training_config.nb_epochs)