def objective(hyper_params):
    """Train once with ``hyper_params`` and report the validation loss.

    The given hyper-parameters are applied to ``model_config`` through its
    ``update`` method, a model trainer is created from that configuration,
    and a ``SimpleTrainer`` is run for ``training_config.nb_epochs`` epochs.

    Returns:
        A dict whose ``'loss'`` entry is the "CrossEntropy" validation loss
        read from the training monitor for "SimpleNet", and whose
        ``'status'`` entry is ``STATUS_OK``.
    """
    # Apply the candidate hyper-parameters before anything is built from the config.
    model_config.update(hyper_params)

    # Build the model trainer from the updated configuration.
    factory = ModelTrainerFactory(model=SimpleConvNet())
    model_trainer = factory.create(model_config)

    # Assemble the trainer, attach the status printer, then train.
    trainer = SimpleTrainer(
        "MNIST Trainer",
        train_loader,
        valid_loader,
        None,
        model_trainer,
        RunConfiguration(use_amp=False),
    )
    trainer = trainer.with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END)
    monitor = trainer.train(training_config.nb_epochs)

    validation_loss = monitor["SimpleNet"][Phase.VALIDATION][Monitor.LOSS]["CrossEntropy"]
    return {'loss': validation_loss, 'status': STATUS_OK}
def test_should_create_model_trainer_with_config(self):
    """Check that the factory builds a fully configured ``ModelTrainer``.

    Trainers are created from ``self._model`` and
    ``self._model_trainer_configs``; the first one is then checked for its
    type and name, for two criterions that are ``_Loss`` instances, for an
    SGD optimizer, for a ``ReduceLROnPlateau`` scheduler, and for two metric
    computers that are ``Metric`` instances.
    """
    model_trainers = ModelTrainerFactory(model=self._model).create(
        model_trainer_configs=self._model_trainer_configs)
    simple_net = model_trainers[0]

    assert_that(simple_net, instance_of(ModelTrainer))
    assert_that(simple_net.name, is_(self.SIMPLE_NET_NAME))

    # The per-item assertions run for their side effect only, so they are
    # written as plain loops rather than as list comprehensions that would
    # build a discarded list.
    assert_that(len(simple_net.criterions), is_(2))
    for criterion in simple_net.criterions.keys():
        assert_that(simple_net.criterions[criterion], instance_of(_Loss))

    assert_that(simple_net.optimizer, instance_of(torch.optim.SGD))
    assert_that(simple_net.scheduler, instance_of(torch.optim.lr_scheduler.ReduceLROnPlateau))

    assert_that(len(simple_net._metric_computers.keys()), is_(2))
    for metric in simple_net._metric_computers.keys():
        assert_that(simple_net._metric_computers[metric], instance_of(Metric))
reconstruction_datasets = list() iSEG_train = None iSEG_CSV = None MRBrainS_train = None MRBrainS_CSV = None ABIDE_train = None ABIDE_CSV = None iSEG_augmentation_strategy = None MRBrainS_augmentation_strategy = None ABIDE_augmentation_strategy = None # Initialize the model trainers model_trainer_factory = ModelTrainerFactory( model_factory=CustomModelFactory(), criterion_factory=CustomCriterionFactory()) model_trainers = model_trainer_factory.create(model_trainer_configs) if not isinstance(model_trainers, list): model_trainers = [model_trainers] # Create datasets if dataset_configs.get("iSEG", None) is not None: iSEG_train, iSEG_valid, iSEG_test, iSEG_reconstruction = iSEGSliceDatasetFactory.create_train_valid_test( source_dir=dataset_configs["iSEG"].path, modalities=dataset_configs["iSEG"].modalities, dataset_id=ISEG_ID, test_size=dataset_configs["iSEG"].validation_split, max_subjects=dataset_configs["iSEG"].max_subjects, max_num_patches=dataset_configs["iSEG"].max_num_patches, augment=dataset_configs["iSEG"].augment,
Normalize((0.1307, ), (0.3081, ))])) test_dataset = torchvision.datasets.MNIST( './files/', train=False, download=True, transform=Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])) # Initialize loaders. train_loader, valid_loader = DataloaderFactory( train_dataset, test_dataset).create(run_config, training_config) # Initialize the loggers. if run_config.local_rank == 0: visdom_logger = VisdomLogger( VisdomConfiguration.from_yml(CONFIG_FILE_PATH)) # Initialize the model trainers. model_trainer = ModelTrainerFactory(model=SimpleNet()).create( model_trainer_config, run_config) # Train with the training strategy. if run_config.local_rank == 0: trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \ .with_event_handler(ConsoleLogger(), Event.ON_EPOCH_END) \ .with_event_handler(visdom_logger, Event.ON_EPOCH_END, PlotAllModelStateVariables()) \ .train(training_config.nb_epochs) else: trainer = MNISTTrainer(training_config, model_trainer, train_loader, valid_loader, run_config) \ .train(training_config.nb_epochs)
download=True, transform=Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])), batch_size=training_config.batch_size_train, shuffle=True) test_loader = DataLoader(torchvision.datasets.MNIST( './files/', train=False, download=True, transform=Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])), batch_size=training_config.batch_size_valid, shuffle=True) # Initialize the loggers visdom_logger = VisdomLogger( VisdomConfiguration.from_yml(CONFIG_FILE_PATH)) # Initialize the model trainers model_trainer = ModelTrainerFactory(model=SimpleNet()).create( model_trainer_config, RunConfiguration(use_amp=False)) # Train with the training strategy trainer = SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer) \ .with_event_handler(PrintTrainingStatus(every=100), Event.ON_TRAIN_BATCH_END) \ .with_event_handler(PrintModelTrainersStatus(every=100), Event.ON_BATCH_END) \ .with_event_handler(PlotAllModelStateVariables(visdom_logger), Event.ON_EPOCH_END) \ .with_event_handler(PlotGradientFlow(visdom_logger, every=100), Event.ON_TRAIN_BATCH_END) \ .train(training_config.nb_epochs)