Code example #1
import torch

from mmf.common.registry import registry
from mmf.utils.build import build_config, build_trainer
from mmf.utils.distributed import distributed_init
from mmf.utils.env import set_seed, setup_imports
from mmf.utils.logger import Logger  # Logger class from the older MMF logging API used here


def main(configuration, init_distributed=False, predict=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)
    print(f"Using seed {config.training.seed}")

    config = build_config(configuration)

    # Logger should be registered after config is registered
    registry.register("writer", Logger(config, name="mmf.train"))
    trainer = build_trainer(config)
    trainer.load()
    if predict:
        trainer.inference()
    else:
        trainer.train()
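
For orientation, this main() is normally reached from a small command-line driver. The sketch below is a simplified, assumed wiring: Configuration appears in example #2's imports, while the flags parser and the overall structure are assumptions about MMF's CLI rather than code taken from the snippets, and the real entry point additionally handles multi-GPU spawning.

from mmf.utils.configuration import Configuration
from mmf.utils.flags import flags


def run(predict=False):
    # Parse the standard key=value overrides from the command line and build
    # the Configuration object that main() expects.
    parser = flags.get_parser()
    args = parser.parse_args()
    configuration = Configuration(args)
    configuration.args = args
    main(configuration, predict=predict)


if __name__ == "__main__":
    run()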
Code example #2
import argparse

import torch


def get_model(device, opts):
    from mmf.common.registry import registry
    from mmf.utils.build import build_config, build_trainer
    from mmf.utils.configuration import Configuration
    from mmf.utils.env import set_seed, setup_imports
    from mmf.utils.logger import Logger
    args = argparse.Namespace(config_override=None)
    args.opts = opts
    configuration = Configuration(args)
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    config.device_id = 0
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    # Logger should be registered after config is registered
    registry.register("writer", Logger(config, name="mmf.train"))
    trainer = build_trainer(config)
    # trainer.load() is deliberately skipped here; ready_trainer() is assumed to be
    # a user-defined helper available elsewhere in the snippet's codebase.
    ready_trainer(trainer)
    trainer.model.to(device)
    return trainer.model
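
A hypothetical call site for get_model() above. The key=value strings in opts mirror the overrides MMF reads from the command line, but the specific config path, model, and dataset names are placeholders, not values taken from the snippet.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
opts = [
    "config=projects/visual_bert/configs/vqa2/defaults.yaml",  # placeholder config path
    "model=visual_bert",  # placeholder model key
    "dataset=vqa2",  # placeholder dataset key
]
model = get_model(device, opts)
model.eval()  # the returned trainer.model behaves like a regular torch.nn.Module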
Code example #3
import logging

import torch

from mmf.common.registry import registry
from mmf.utils.build import build_config, build_trainer
from mmf.utils.distributed import distributed_init, get_rank
from mmf.utils.env import set_seed, setup_imports
from mmf.utils.general import log_device_names  # assumed location of log_device_names
from mmf.utils.logger import setup_logger


def main(configuration, init_distributed=False, predict=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    seed = config.training.seed
    config.training.seed = set_seed(seed if seed == -1 else seed + get_rank())
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    setup_logger(color=config.training.colored_logs,
                 disable=config.training.should_not_log)
    logger = logging.getLogger("mmf_cli.run")
    # Log args for debugging purposes
    logger.info(configuration.args)
    logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    logger.info(f"Using seed {config.training.seed}")

    trainer = build_trainer(config)
    trainer.load()
    if predict:
        trainer.inference()
    else:
        trainer.train()
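
When init_distributed is set, main() above is typically entered once per GPU through a small wrapper. The sketch below approximates that wrapper under the assumption that the launcher passes a device_id and that the config exposes start_rank and distributed.rank; it is an illustration, not a verbatim copy of MMF's wrapper.

def distributed_main(device_id, configuration, predict=False):
    # Each spawned process receives its own device id and derives its global
    # rank from the start_rank chosen by the launcher.
    config = configuration.get_config()
    config.device_id = device_id
    if config.distributed.rank is None:
        config.distributed.rank = config.start_rank + device_id
    main(configuration, init_distributed=True, predict=predict)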
Code example #4
import torch

from mmf.common.registry import registry
from mmf.utils.build import build_trainer
from mmf.utils.distributed import distributed_init
from mmf.utils.env import set_seed, setup_imports
from mmf.utils.logger import Logger


def main(configuration, init_distributed=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)
    print("Using seed {}".format(config.training.seed))

    registry.register("writer", Logger(config, name="mmf.train"))

    # Note: this variant passes the Configuration object itself to build_trainer,
    # unlike the other examples, which pass the resolved config.
    trainer = build_trainer(configuration)
    trainer.load()
    trainer.train()