コード例 #1
0
def main(configuration, init_distributed=False, predict=False):
    """Build a trainer from ``configuration`` and run training or inference.

    Args:
        configuration: Configuration object; its ``import_user_dir``,
            ``get_config`` and ``args`` members are used here.
        init_distributed (bool): When True, ``distributed_init`` is called
            with the initial config before the seed is set.
        predict (bool): When True, ``trainer.inference()`` is run; otherwise
            ``trainer.train()`` is run.
    """
    # Imports may need to be (re)registered, e.g. after a reload, before the
    # user directory and config are loaded.
    setup_imports()
    configuration.import_user_dir()
    cfg = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(cfg)

    # A seed of -1 is handed to set_seed unchanged; any other seed is offset
    # by get_rank() (presumably the distributed rank) so processes differ.
    base_seed = cfg.training.seed
    if base_seed == -1:
        requested_seed = base_seed
    else:
        requested_seed = base_seed + get_rank()
    cfg.training.seed = set_seed(requested_seed)
    registry.register("seed", cfg.training.seed)

    # Build the final config; presumably this reflects the seed written
    # above (NOTE(review): confirm against build_config).
    cfg = build_config(configuration)

    setup_logger(
        color=cfg.training.colored_logs,
        disable=cfg.training.should_not_log,
    )
    logger = logging.getLogger("mmf_cli.run")
    # Log the raw args for debugging purposes.
    logger.info(configuration.args)
    logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    logger.info(f"Using seed {cfg.training.seed}")

    trainer = build_trainer(cfg)
    trainer.load()
    run = trainer.inference if predict else trainer.train
    run()
コード例 #2
0
    def setUp(self):
        """Create a minimal trainer stand-in and the LogisticsCallback under test.

        The trainer is an ``argparse.Namespace`` carrying only the attributes
        assigned below. The config is ``configs/defaults.yaml`` merged with
        small training overrides and a temporary ``env.save_dir``.
        """
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = argparse.Namespace()

        training_overrides = {
            "checkpoint_interval": 1,
            "evaluation_interval": 10,
            "early_stop": {"criteria": "val/total_loss"},
            "batch_size": 16,
            "log_interval": 10,
            "logger_level": "info",
        }
        overrides = {
            "model": "simple",
            "model_config": {},
            "training": training_overrides,
            "env": {"save_dir": self.tmpdir},
        }
        defaults_path = os.path.join("configs", "defaults.yaml")
        self.config = OmegaConf.merge(load_yaml(defaults_path), overrides)

        # The trainer gets its own deep copy so self.config stays pristine.
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)
        setup_logger()

        report = Mock(spec=Report)
        report.dataset_name = "abcd"
        report.dataset_type = "test"
        self.report = report

        trainer = self.trainer
        trainer.model = SimpleModule()
        trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(),
            batch_size=self.config.training.batch_size,
        )
        trainer.optimizer = torch.optim.Adam(
            trainer.model.parameters(),
            lr=1e-01,
        )
        trainer.device = "cpu"
        # Progress counters and the update limit all start at zero.
        trainer.num_updates = 0
        trainer.current_iteration = 0
        trainer.current_epoch = 0
        trainer.max_updates = 0
        trainer.meter = Meter()

        self.cb = LogisticsCallback(self.config, trainer)
コード例 #3
0
ファイル: interactive.py プロジェクト: vishalbelsare/pythia
def interactive(opts: typing.Optional[typing.List[str]] = None):
    """Run interactive inference on an image and text provided by the user.

    The user is prompted for an image URL and a text input, and the model's
    response is logged. Entering ``exit`` at any prompt, or reaching
    end-of-input on stdin, ends the session. Inference can optionally be
    driven programmatically by passing an optlist as ``opts``.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist which can
            be used to override opts programmatically. For e.g. if you pass
            opts = ["checkpoint_path=my/directory"], this will set the
            checkpoint. When ``None``, options are parsed from the command line.
    """
    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    setup_logger()
    logger = logging.getLogger("mmf_cli.interactive")

    def read_input() -> typing.Optional[str]:
        # Return the next stdin line, or None when the user typed "exit" or
        # stdin is exhausted (input() raises EOFError), so every prompt
        # honours the "exit at any point" promise uniformly.
        try:
            value = input()
        except EOFError:
            return None
        return None if value == "exit" else value

    config = construct_config(args.opts)
    inference = Inference(checkpoint_path=config.checkpoint_path)
    logger.info("Enter 'exit' at any point to terminate.")
    logger.info("Enter an image URL:")
    image_url = read_input()
    if image_url is None:
        return
    logger.info("Got image URL.")
    logger.info("Enter text:")
    text = read_input()
    if text is None:
        return
    logger.info("Got text input.")
    while True:
        logger.info("Running inference on image and text input.")
        answer = inference.forward(image_url, {"text": text}, image_format="url")
        # f-string avoids a TypeError should forward() return a non-str.
        logger.info(f"Model response: {answer}")
        logger.info(
            f"Enter another image URL or leave it blank to continue using {image_url}:"
        )
        new_image_url = read_input()
        if new_image_url is None:
            break
        if new_image_url != "":
            image_url = new_image_url
        logger.info("Enter another text input:")
        text = read_input()
        if text is None:
            break
コード例 #4
0
ファイル: test_logger.py プロジェクト: SunYanCN/pythia
 def setUpClass(cls) -> None:
     """Build and register a frozen test config once for the whole class.

     A fresh temporary directory is used as ``env.save_dir``. The config
     selects the ``cnn_lstm`` model and the ``clevr`` dataset. It is frozen
     and registered under ``"config"`` before ``setup_logger`` is called.
     """
     cls._tmpdir = tempfile.mkdtemp()
     args = argparse.Namespace()
     # Only the save_dir override interpolates a value; the others are plain
     # literals (no placeholder-free f-strings, ruff F541).
     args.opts = [
         f"env.save_dir={cls._tmpdir}",
         "model=cnn_lstm",
         "dataset=clevr",
     ]
     args.config_override = None
     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)
     # setup_logger's return value is kept on the class as ``writer``.
     cls.writer = setup_logger()