Code example #1 (score: 0)
File: main.py — Project: sami-ets/DeepNormalize
# Pin both NumPy's and the stdlib random module's global RNG state so
# training runs are reproducible across invocations.
SEED = 43
np.random.seed(SEED)
random.seed(SEED)

if __name__ == '__main__':
    # Basic settings
    logging.basicConfig(level=logging.INFO)
    # Let PyTorch's intra-op and inter-op thread pools use every available core.
    torch.set_num_threads(multiprocessing.cpu_count())
    torch.set_num_interop_threads(multiprocessing.cpu_count())
    # CLI arguments for a model-training run; the parser itself is built by the
    # project's factory, so the available flags are defined elsewhere.
    args = ArgsParserFactory.create_parser(
        ArgsParserType.MODEL_TRAINING).parse_args()

    # Create configurations.
    # Runtime options: mixed-precision (AMP) toggle, distributed local rank,
    # and the AMP optimization level.
    run_config = RunConfiguration(use_amp=args.use_amp,
                                  local_rank=args.local_rank,
                                  amp_opt_level=args.amp_opt_level)
    model_trainer_configs, training_config = YamlConfigurationParser.parse(
        args.config_file)
    # The "dataset" YAML section appears to map dataset name -> raw config dict;
    # each value is wrapped in a DatasetConfiguration object.
    # NOTE(review): `for k, v, in` has a stray trailing comma after `v` — legal
    # Python (tuple target with trailing comma) but worth cleaning up.
    dataset_configs = YamlConfigurationParser.parse_section(
        args.config_file, "dataset")
    dataset_configs = {
        k: DatasetConfiguration(v)
        for k, v, in dataset_configs.items()
    }
    data_augmentation_config = YamlConfigurationParser.parse_section(
        args.config_file, "data_augmentation")
    # HTML renderings of every configuration, presumably for experiment
    # logging/visualization — TODO confirm the consumer of config_html.
    config_html = [
        training_config.to_html(),
        list(map(lambda config: config.to_html(), dataset_configs.values())),
        list(map(lambda config: config.to_html(), model_trainer_configs))
    ]

    # Prepare the data.
Code example #2 (score: 0)
File: main.py — Project: banctilrobitaille/kerosene
        # NOTE(review): fragment — the enclosing parser-builder method starts
        # before this excerpt. Registers the GPU local-rank flag used by
        # distributed launchers (default 0, i.e. single-process / master).
        parser.add_argument("--local_rank",
                            dest="local_rank",
                            default=0,
                            type=int,
                            help="The local_rank of the GPU.")
        return parser


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Training configuration is read from a fixed YAML file next to the script.
    CONFIG_FILE_PATH = "config.yml"
    args = ArgsParserFactory.create_parser().parse_args()
    # NOTE(review): arguments are passed positionally as
    # (use_amp, amp_opt_level, local_rank) — verify this matches
    # RunConfiguration's parameter order, which is not visible here.
    run_config = RunConfiguration(args.use_amp, args.amp_opt_level,
                                  args.local_rank)

    model_trainer_config, training_config = YamlConfigurationParser.parse(
        CONFIG_FILE_PATH)

    # Initialize the dataset. This is the only part the user must define manually.
    # MNIST train/test splits: downloaded into ./files/ on first run, converted
    # to tensors, and normalized with MNIST's canonical mean/std (0.1307, 0.3081).
    train_dataset = torchvision.datasets.MNIST(
        './files/',
        train=True,
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))]))
    test_dataset = torchvision.datasets.MNIST(
        './files/',
        train=False,
        download=True,
        transform=Compose([ToTensor(),
                           Normalize((0.1307, ), (0.3081, ))]))