Exemplo n.º 1
0
    # Create one loader per split, in the order: train, validation, test.
    # All three loaders receive the same keyword settings (shuffle=True,
    # drop_last=True, pin_memory=True, sampler=None) and the same collate
    # function; only the dataset differs.
    # NOTE(review): shuffling and dropping the last incomplete batch are also
    # applied to the validation and test splits -- confirm this is intended.
    dataloaders = [
        DataLoader(
            split,
            training_config.batch_size,
            sampler=None,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=augmented_sample_collate,
            drop_last=True,
            pin_memory=True,
        )
        for split in (train_dataset, valid_dataset, test_dataset)
    ]

    # Initialize the loggers.
    visdom_config = VisdomConfiguration.from_yml(args.config_file, "visdom")
    # Keep (at most) the last three "/"-separated components of the config
    # path; the first two are used below as sub-folders of the destination.
    exp = args.config_file.split("/")[-3:]
    # The run folder is named after the final path component of the Visdom
    # environment (normpath first, so a trailing separator is ignored).
    env_name = os.path.basename(os.path.normpath(visdom_config.env))
    if visdom_config.save_destination is not None:
        # Use os.path.join for the destination as well: plain string
        # concatenation fused a destination lacking a trailing separator
        # with the first sub-folder name (e.g. "/out" + "a/b/env").
        save_folder = os.path.join(visdom_config.save_destination, exp[0],
                                   exp[1], env_name)
    else:
        # No explicit destination configured: fall back to a relative
        # "saves/<env>" folder.
        save_folder = "saves/{}".format(env_name)

    # Ensure one sub-folder of save_folder exists per model name. A plain
    # loop is used because the calls are made only for their side effect;
    # exist_ok=True makes re-runs against an existing folder harmless.
    for model_dir in ("Discriminator", "Generator", "Segmenter"):
        os.makedirs("{}/{}".format(save_folder, model_dir), exist_ok=True)
    visdom_logger = VisdomLogger(visdom_config)
Exemplo n.º 2
0
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))])),
                              batch_size=training_config.batch_size_train,
                              shuffle=True)

    # Held-out split: MNIST with train=False, stored under './files/' and
    # downloaded if missing. Samples are converted to tensors and then
    # normalized with the constants below (presumably the MNIST mean/std --
    # TODO confirm).
    mnist_test_transform = Compose([
        ToTensor(),
        Normalize((0.1307, ), (0.3081, )),
    ])
    mnist_test_set = torchvision.datasets.MNIST(
        './files/',
        train=False,
        download=True,
        transform=mnist_test_transform,
    )
    # NOTE(review): shuffle=True is set on the test loader as well.
    test_loader = DataLoader(
        mnist_test_set,
        batch_size=training_config.batch_size_valid,
        shuffle=True,
    )

    # Set up the Visdom logger from the YAML file at CONFIG_FILE_PATH.
    visdom_settings = VisdomConfiguration.from_yml(CONFIG_FILE_PATH)
    visdom_logger = VisdomLogger(visdom_settings)

    # Build the model trainer: wrap a fresh SimpleNet in the factory, then
    # create the trainer from the trainer configuration with AMP disabled.
    network = SimpleNet()
    trainer_factory = ModelTrainerFactory(model=network)
    model_trainer = trainer_factory.create(model_trainer_config,
                                           RunConfiguration(use_amp=False))

    # Assemble the trainer, attach the event handlers (console status and
    # gradient-flow plots every 100 batches, state-variable plots at each
    # epoch end), then run training for the configured number of epochs.
    # Note that `trainer` is bound to whatever `.train(...)` returns.
    trainer = (
        SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer)
        .with_event_handler(PrintTrainingStatus(every=100), Event.ON_TRAIN_BATCH_END)
        .with_event_handler(PrintModelTrainersStatus(every=100), Event.ON_BATCH_END)
        .with_event_handler(PlotAllModelStateVariables(visdom_logger), Event.ON_EPOCH_END)
        .with_event_handler(PlotGradientFlow(visdom_logger, every=100), Event.ON_TRAIN_BATCH_END)
        .train(training_config.nb_epochs)
    )