def main(configuration, init_distributed=False, predict=False):
    """Prepare the runtime from ``configuration`` and then train or predict.

    The work happens in this order: register imports and the user directory,
    pin the CUDA device when CUDA is available, optionally initialise
    distributed mode, seed the run, build the final config, configure
    logging, and finally build the trainer and run it.

    Args:
        configuration: Configuration object. It must expose
            ``import_user_dir()``, ``get_config()`` and an ``args`` attribute.
        init_distributed (bool): If True, ``distributed_init`` is called with
            the config before seeding.
        predict (bool): If True, ``trainer.inference()`` runs instead of
            ``trainer.train()``.
    """
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    cfg = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(cfg)

    # A seed of -1 is handed to set_seed unchanged (presumably "pick a random
    # seed" -- confirm in set_seed). Any other seed is offset by this
    # process's rank.
    requested_seed = cfg.training.seed
    if requested_seed != -1:
        requested_seed = requested_seed + get_rank()
    cfg.training.seed = set_seed(requested_seed)
    registry.register("seed", cfg.training.seed)

    cfg = build_config(configuration)

    training_cfg = cfg.training
    setup_logger(
        color=training_cfg.colored_logs, disable=training_cfg.should_not_log
    )
    run_logger = logging.getLogger("mmf_cli.run")

    # Log args for debugging purposes
    run_logger.info(configuration.args)
    run_logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    run_logger.info(f"Using seed {cfg.training.seed}")

    trainer = build_trainer(cfg)
    trainer.load()

    run = trainer.inference if predict else trainer.train
    run()
def setUp(self):
    """Build a CPU-only trainer stub and a ``LogisticsCallback`` for each test.

    The defaults config is merged with test overrides: a ``simple`` model,
    small intervals, batch size 16, and a fresh temporary ``env.save_dir``.
    A deep copy of the merged config is attached to the trainer and registered
    under ``"config"``. The callback receives the un-copied merged config.
    """
    self.tmpdir = tempfile.mkdtemp()
    self.trainer = argparse.Namespace()

    overrides = {
        "model": "simple",
        "model_config": {},
        "training": {
            "checkpoint_interval": 1,
            "evaluation_interval": 10,
            "early_stop": {"criteria": "val/total_loss"},
            "batch_size": 16,
            "log_interval": 10,
            "logger_level": "info",
        },
        "env": {"save_dir": self.tmpdir},
    }
    defaults = load_yaml(os.path.join("configs", "defaults.yaml"))
    self.config = OmegaConf.merge(defaults, overrides)

    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)
    setup_logger()

    report = Mock(spec=Report)
    self.report = report
    report.dataset_name = "abcd"
    report.dataset_type = "test"

    trainer = self.trainer
    trainer.model = SimpleModule()
    trainer.val_loader = torch.utils.data.DataLoader(
        NumbersDataset(), batch_size=self.config.training.batch_size
    )
    trainer.optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-01)
    trainer.device = "cpu"
    # These trainer attributes all start at zero (assigned in this order).
    for attr_name in (
        "num_updates",
        "current_iteration",
        "current_epoch",
        "max_updates",
    ):
        setattr(trainer, attr_name, 0)
    trainer.meter = Meter()

    self.cb = LogisticsCallback(self.config, trainer)
def interactive(opts: typing.Optional[typing.List[str]] = None):
    """Run inference interactively on an image URL and text entered by the user.

    The user is prompted for an image URL and a text input. The model's answer
    is logged, and the user can then supply further inputs (a blank image URL
    keeps the previous one). Entering ``exit`` at any prompt, or reaching
    end-of-input, ends the session.

    Inference can also be configured programmatically by passing an optlist
    as ``opts`` instead of reading command-line arguments.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist used to
            override opts programmatically. For example, passing
            ``opts = ["checkpoint_path=my/directory"]`` sets the checkpoint.
            When ``None`` (the default), arguments are parsed from the
            command line.
    """
    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    setup_logger()
    logger = logging.getLogger("mmf_cli.interactive")

    config = construct_config(args.opts)
    inference = Inference(checkpoint_path=config.checkpoint_path)

    def read_line():
        # Treat end-of-input (Ctrl-D, closed stdin) like "exit" rather than
        # letting EOFError crash the session.
        try:
            return input()
        except EOFError:
            return "exit"

    logger.info("Enter 'exit' at any point to terminate.")
    logger.info("Enter an image URL:")
    image_url = read_line()
    # Honour "exit" at the very first prompt too, instead of later treating
    # it as an image URL.
    if image_url == "exit":
        return
    logger.info("Got image URL.")

    logger.info("Enter text:")
    text = read_line()
    logger.info("Got text input.")

    while text != "exit":
        logger.info("Running inference on image and text input.")
        answer = inference.forward(image_url, {"text": text}, image_format="url")
        # %-formatting also works if the answer is not a str.
        logger.info("Model response: %s", answer)
        logger.info(
            f"Enter another image URL or leave it blank to continue using {image_url}:"
        )
        new_image_url = read_line()
        if new_image_url == "exit":
            break
        if new_image_url != "":
            image_url = new_image_url
        logger.info("Enter another text input:")
        text = read_line()
def setUpClass(cls) -> None:
    """Create one frozen cnn_lstm/clevr config shared by every test in the class.

    A fresh temporary directory is used as ``env.save_dir``. The resulting
    config is registered under ``"config"``, and the return value of
    ``setup_logger()`` is stored as ``cls.writer``.
    """
    cls._tmpdir = tempfile.mkdtemp()
    # Only the save dir needs interpolation; the other opts are plain
    # constants, so they are not f-strings.
    args = argparse.Namespace(
        opts=[f"env.save_dir={cls._tmpdir}", "model=cnn_lstm", "dataset=clevr"],
        config_override=None,
    )
    configuration = Configuration(args)
    configuration.freeze()
    cls.config = configuration.get_config()
    registry.register("config", cls.config)
    cls.writer = setup_logger()