Example #1
0
    # Initialize the loggers.
    visdom_config = VisdomConfiguration.from_yml(args.config_file, "visdom")
    exp = args.config_file.split("/")[-3:]
    if visdom_config.save_destination is not None:
        save_folder = visdom_config.save_destination + os.path.join(
            exp[0], exp[1],
            os.path.basename(os.path.normpath(visdom_config.env)))
    else:
        save_folder = "saves/{}".format(
            os.path.basename(os.path.normpath(visdom_config.env)))

    [
        os.makedirs("{}/{}".format(save_folder, model), exist_ok=True)
        for model in ["Discriminator", "Generator", "Segmenter"]
    ]
    visdom_logger = VisdomLogger(visdom_config)

    visdom_logger(
        VisdomData("Experiment", "Experiment Config", PlotType.TEXT_PLOT,
                   PlotFrequency.EVERY_EPOCH, None, config_html))
    visdom_logger(
        VisdomData(
            "Experiment",
            "Patch count",
            PlotType.BAR_PLOT,
            PlotFrequency.EVERY_EPOCH,
            x=[
                len(iSEG_train) if iSEG_train is not None else 0,
                len(MRBrainS_train) if MRBrainS_train is not None else 0,
                len(ABIDE_train) if ABIDE_train is not None else 0
            ],
Example #2
0
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))])),
                              batch_size=training_config.batch_size_train,
                              shuffle=True)

    test_loader = DataLoader(torchvision.datasets.MNIST(
        './files/',
        train=False,
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))])),
                             batch_size=training_config.batch_size_valid,
                             shuffle=True)

    # Initialize the loggers
    visdom_logger = VisdomLogger(
        VisdomConfiguration.from_yml(CONFIG_FILE_PATH))

    # Initialize the model trainers
    model_trainer = ModelTrainerFactory(model=SimpleNet()).create(
        model_trainer_config, RunConfiguration(use_amp=False))

    # Train with the training strategy
    trainer = SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer) \
        .with_event_handler(PrintTrainingStatus(every=100), Event.ON_TRAIN_BATCH_END) \
        .with_event_handler(PrintModelTrainersStatus(every=100), Event.ON_BATCH_END) \
        .with_event_handler(PlotAllModelStateVariables(visdom_logger), Event.ON_EPOCH_END) \
        .with_event_handler(PlotGradientFlow(visdom_logger, every=100), Event.ON_TRAIN_BATCH_END) \
        .train(training_config.nb_epochs)