def main(configuration, init_distributed=False, predict=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    seed = config.training.seed
    config.training.seed = set_seed(seed if seed == -1 else seed + get_rank())
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    setup_logger(
        color=config.training.colored_logs, disable=config.training.should_not_log
    )
    logger = logging.getLogger("multimodelity_cli.run")
    # Log args for debugging purposes
    logger.info(configuration.args)
    logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    logger.info(f"Using seed {config.training.seed}")

    trainer = build_trainer(config)
    trainer.load()
    if predict:
        trainer.inference()
    else:
        trainer.train()
def get_data_t(self, t, data, batch_size_t, prev_output):
    if self.teacher_forcing:
        # Modify batch_size for timestep t
        batch_size_t = sum([l > t for l in data["decode_lengths"]])
    elif prev_output is not None and self.config.inference.type == "greedy":
        # Adding t-1 output words to data["texts"] for greedy decoding
        output_softmax = torch.log_softmax(prev_output, dim=1)
        _, indices = torch.max(output_softmax, dim=1, keepdim=True)
        data["texts"] = torch.cat(
            (data["texts"], indices.view(batch_size_t, 1)), dim=1
        )

    # Slice data based on batch_size at timestep t
    data["texts"] = data["texts"][:batch_size_t]
    if "state" in data:
        h1 = data["state"]["td_hidden"][0][:batch_size_t]
        c1 = data["state"]["td_hidden"][1][:batch_size_t]
        h2 = data["state"]["lm_hidden"][0][:batch_size_t]
        c2 = data["state"]["lm_hidden"][1][:batch_size_t]
    else:
        h1, c1 = self.init_hidden_state(data["texts"])
        h2, c2 = self.init_hidden_state(data["texts"])

    data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
    registry.register(f"{h1.device}_lstm_state", data["state"])

    return data, batch_size_t
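# Worked example (hypothetical values) of the teacher-forcing shrink above:
# captions in a batch are sorted by decode length, so at timestep t only
# captions longer than t steps remain active.
decode_lengths = [5, 3, 2]
t = 2
assert sum(l > t for l in decode_lengths) == 2  # two captions still decoding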
def update_meter(
    self, report: Dict[str, Any], meter: Type[Meter] = None, eval_mode: bool = False
) -> None:
    if meter is None:
        meter = self.meter

    if hasattr(report, "metrics"):
        metrics_dict = report.metrics
        reduced_metrics_dict = reduce_dict(metrics_dict)

    if not eval_mode:
        loss_dict = report.losses
        reduced_loss_dict = reduce_dict(loss_dict)

    with torch.no_grad():
        # Add metrics to meter only when mode is `eval`
        meter_update_dict = {}
        if not eval_mode:
            total_loss_key = report.dataset_type + "/total_loss"
            meter_update_dict, total_loss = self.update_dict(
                meter_update_dict, reduced_loss_dict
            )
            registry.register(total_loss_key, total_loss)
            meter_update_dict.update({total_loss_key: total_loss})

        if hasattr(report, "metrics"):
            meter_update_dict, _ = self.update_dict(
                meter_update_dict, reduced_metrics_dict
            )

        meter.update(meter_update_dict, report.batch_size)
def __call__(self, sample_list, model_output, *args, **kwargs):
    values = {}

    dataset_type = sample_list.dataset_type
    dataset_name = sample_list.dataset_name

    with torch.no_grad():
        for metric_name, metric_object in self.metrics.items():
            key = f"{dataset_type}/{dataset_name}/{metric_name}"
            values[key] = metric_object._calculate_with_checks(
                sample_list, model_output, *args, **kwargs
            )

            if not isinstance(values[key], torch.Tensor):
                values[key] = torch.tensor(values[key], dtype=torch.float)
            else:
                values[key] = values[key].float()

            if values[key].dim() == 0:
                values[key] = values[key].view(1)

    registry.register(
        "{}.{}.{}".format("metrics", sample_list.dataset_name, dataset_type), values
    )

    return values
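# Hypothetical shape of the dict returned above, for a batch from the "val"
# split of a dataset named "vqa2" with a single "accuracy" metric; names and
# value are assumed examples, the key formats come from the code above.
#   values == {"val/vqa2/accuracy": tensor([0.7500])}
# The same dict is also retrievable later from the registry:
#   registry.get("metrics.vqa2.val", no_warning=True)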
def __init__(self, args=None, default_only=False):
    self.config = {}

    if not args:
        import argparse

        args = argparse.Namespace(opts=[])
        default_only = True

    self.args = args
    self._register_resolvers()

    self._default_config = self._build_default_config()

    if default_only:
        other_configs = {}
    else:
        other_configs = self._build_other_configs()

    self.config = OmegaConf.merge(self._default_config, other_configs)

    self.config = self._merge_with_dotlist(self.config, args.opts)
    self._update_specific(self.config)
    self.upgrade(self.config)
    # Resolve the config here itself after full creation so that spawned workers
    # don't face any issues
    self.config = OmegaConf.create(
        OmegaConf.to_container(self.config, resolve=True)
    )
    registry.register("config", self.config)
def forward(self, sample_list: Dict[str, Tensor], model_output: Dict[str, Tensor]):
    """Takes in the original ``SampleList`` returned from DataLoader
    and `model_output` returned from the model and returns a Dict containing
    loss for each of the losses in `losses`.

    Args:
        sample_list (SampleList): SampleList given by the dataloader.
        model_output (Dict): Dict returned from model as output.

    Returns:
        Dict: Dictionary containing loss value for each of the losses.

    """
    output = {}

    if "targets" not in sample_list:
        if not self._evaluation_predict:
            warnings.warn(
                "Sample list has no field 'targets', are you "
                "sure that your ImDB has labels? You may have "
                "wanted to run with evaluation.predict=true"
            )
        return output

    for loss in self.losses:
        output.update(loss(sample_list, model_output))

    if not torch.jit.is_scripting():
        registry_loss_key = "{}.{}.{}".format(
            "losses", sample_list["dataset_name"], sample_list["dataset_type"]
        )
        # Register the losses to registry
        registry.register(registry_loss_key, output)

    return output
def update_registry_for_model(self, config):
    registry.register(
        self.dataset_name + "_text_vocab_size",
        self.dataset.text_processor.get_vocab_size(),
    )
    registry.register(
        self.dataset_name + "_num_final_outputs",
        self.dataset.answer_processor.get_vocab_size(),
    )
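# Hypothetical sketch of the consumer side: a model reading back the sizes
# registered above. The key format matches update_registry_for_model; "clevr"
# is an example dataset name (the clevr test setUp further below registers
# the same keys by hand).
vocab_size = registry.get("clevr_text_vocab_size")
num_final_outputs = registry.get("clevr_num_final_outputs")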
def get_multimodelity_root():
    from multimodelity.common.registry import registry

    multimodelity_root = registry.get("multimodelity_root", no_warning=True)
    if multimodelity_root is None:
        multimodelity_root = os.path.dirname(os.path.abspath(__file__))
        multimodelity_root = os.path.abspath(
            os.path.join(multimodelity_root, "..")
        )
        registry.register("multimodelity_root", multimodelity_root)
    return multimodelity_root
def get_global_config(key=None):
    config = registry.get("config")
    if config is None:
        configuration = Configuration()
        config = configuration.get_config()
        registry.register("config", config)

    if key:
        config = OmegaConf.select(config, key)

    return config
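# Minimal usage sketch for get_global_config: the dotted key follows
# OmegaConf.select semantics; "training.batch_size" is an assumed example key.
full_config = get_global_config()
batch_size = get_global_config("training.batch_size")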
def get_data_t(self, data, batch_size_t):
    data["texts"] = data["texts"][:batch_size_t]
    if "state" in data:
        h1 = data["state"]["td_hidden"][0][:batch_size_t]
        c1 = data["state"]["td_hidden"][1][:batch_size_t]
        h2 = data["state"]["lm_hidden"][0][:batch_size_t]
        c2 = data["state"]["lm_hidden"][1][:batch_size_t]
    else:
        h1, c1 = self.init_hidden_state(data["texts"])
        h2, c2 = self.init_hidden_state(data["texts"])

    data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
    registry.register(f"{h1.device}_lstm_state", data["state"])

    return data, batch_size_t
def setUpClass(cls) -> None:
    cls._tmpdir = tempfile.mkdtemp()
    args = argparse.Namespace()
    args.opts = [
        f"env.save_dir={cls._tmpdir}",
        "model=cnn_lstm",
        "dataset=clevr",
    ]
    args.config_override = None
    configuration = Configuration(args)
    configuration.freeze()
    cls.config = configuration.get_config()
    registry.register("config", cls.config)

    setup_output_folder.cache_clear()
    setup_logger.cache_clear()
    cls.writer = setup_logger()
def configure_device(self) -> None:
    self.local_rank = self.config.device_id
    self.device = self.local_rank
    self.distributed = False

    # Will be updated later based on distributed setup
    registry.register("global_device", self.device)

    if self.config.distributed.init_method is not None:
        self.distributed = True
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.local_rank)
    elif torch.cuda.is_available():
        self.device = torch.device("cuda")
        torch.cuda.set_device(0)
    else:
        self.device = torch.device("cpu")

    registry.register("global_device", self.config.distributed.rank)
def build_config(
    configuration: Type[Configuration], *args, **kwargs
) -> multimodelity_typings.DictConfig:
    """Builder function for config. Freezes the configuration and registers the
    configuration object and the config DictConfig object to the registry.

    Args:
        configuration (Configuration): Configuration object that will be
            used to create the config.

    Returns:
        (DictConfig): A config which is of type omegaconf.DictConfig
    """
    configuration.freeze()
    config = configuration.get_config()
    registry.register("config", config)
    registry.register("configuration", configuration)

    return config
def setUp(self):
    setup_imports()
    torch.manual_seed(1234)
    config_path = os.path.join(
        get_multimodelity_root(),
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "beam_search.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.trainer = argparse.Namespace()
    self.config = OmegaConf.create(
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "checkpoint_interval": 1,
                "evaluation_interval": 10,
                "early_stop": {"criteria": "val/total_loss"},
                "batch_size": 16,
                "log_interval": 10,
                "logger_level": "info",
            },
            "env": {"save_dir": self.tmpdir},
        }
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)

    setup_logger.cache_clear()
    setup_logger()

    self.report = Mock(spec=Report)
    self.report.dataset_name = "abcd"
    self.report.dataset_type = "test"

    self.trainer.model = SimpleModule()
    self.trainer.val_dataset = NumbersDataset()

    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )

    self.trainer.device = "cpu"
    self.trainer.num_updates = 0
    self.trainer.current_iteration = 0
    self.trainer.current_epoch = 0
    self.trainer.max_updates = 0
    self.trainer.meter = Meter()
    self.cb = LogisticsCallback(self.config, self.trainer)
def setUp(self):
    setup_imports()
    torch.manual_seed(1234)
    config_path = os.path.join(
        get_multimodelity_root(),
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "nucleus_sampling.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def test_caption_bleu4(self):
    path = os.path.join(
        os.path.abspath(__file__),
        "../../../multimodelity/configs/datasets/coco/defaults.yaml",
    )
    config = load_yaml(os.path.abspath(path))
    captioning_config = config.dataset_config.coco
    caption_processor_config = captioning_config.processors.caption_processor

    vocab_path = os.path.join(
        os.path.abspath(__file__), "..", "..", "data", "vocab.txt"
    )
    caption_processor_config.params.vocab.type = "random"
    caption_processor_config.params.vocab.vocab_file = os.path.abspath(vocab_path)
    caption_processor = CaptionProcessor(caption_processor_config.params)
    registry.register("coco_caption_processor", caption_processor)

    caption_bleu4 = metrics.CaptionBleu4Metric()
    expected = Sample()
    predicted = dict()

    # Test complete match
    expected.answers = torch.empty((5, 5, 10))
    expected.answers.fill_(4)
    predicted["scores"] = torch.zeros((5, 10, 19))
    predicted["scores"][:, :, 4] = 1.0

    self.assertEqual(caption_bleu4.calculate(expected, predicted).item(), 1.0)

    # Test partial match
    expected.answers = torch.empty((5, 5, 10))
    expected.answers.fill_(4)
    predicted["scores"] = torch.zeros((5, 10, 19))
    predicted["scores"][:, 0:5, 4] = 1.0
    predicted["scores"][:, 5:, 18] = 1.0

    self.assertAlmostEqual(
        caption_bleu4.calculate(expected, predicted).item(), 0.3928, 4
    )
def setUp(self):
    self.trainer = argparse.Namespace()
    self.config = OmegaConf.create(
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "lr_scheduler": True,
                "lr_ratio": 0.1,
                "lr_steps": [1, 2],
                "use_warmup": False,
            },
        }
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)
    self.trainer.model = SimpleModule()
    self.trainer.val_dataset = NumbersDataset()

    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )

    self.trainer.lr_scheduler_callback = LRSchedulerCallback(
        self.config, self.trainer
    )
def parallelize_model(self) -> None:
    registry.register("data_parallel", False)
    registry.register("distributed", False)

    if (
        "cuda" in str(self.device)
        and torch.cuda.device_count() > 1
        and not self.distributed
    ):
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            check_reduction=True,
            find_unused_parameters=self.config.training.find_unused_parameters,
        )
def setup_imports():
    from multimodelity.common.registry import registry

    # First, check if imports are already setup
    has_already_setup = registry.get("imports_setup", no_warning=True)
    if has_already_setup:
        return
    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("multimodelity_root", no_warning=True)

    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

        environment_multimodelity_path = os.environ.get(
            "multimodelity_PATH", os.environ.get("PYTHIA_PATH")
        )

        if environment_multimodelity_path is not None:
            root_folder = environment_multimodelity_path

        registry.register("pythia_path", root_folder)
        registry.register("multimodelity_path", root_folder)

    trainer_folder = os.path.join(root_folder, "trainers")
    trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
    datasets_folder = os.path.join(root_folder, "datasets")
    datasets_pattern = os.path.join(datasets_folder, "**", "*.py")
    model_folder = os.path.join(root_folder, "models")
    model_pattern = os.path.join(model_folder, "**", "*.py")

    importlib.import_module("multimodelity.common.meter")

    files = (
        glob.glob(datasets_pattern, recursive=True)
        + glob.glob(model_pattern, recursive=True)
        + glob.glob(trainer_pattern, recursive=True)
    )

    for f in files:
        f = os.path.realpath(f)
        if f.endswith(".py") and not f.endswith("__init__.py"):
            splits = f.split(os.sep)
            import_prefix_index = 0
            for idx, split in enumerate(splits):
                if split == "multimodelity":
                    import_prefix_index = idx + 1
            file_name = splits[-1]
            module_name = file_name[: file_name.find(".py")]
            module = ".".join(
                ["multimodelity"] + splits[import_prefix_index:-1] + [module_name]
            )
            importlib.import_module(module)

    registry.register("imports_setup", True)
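# Worked example (assumed path) of the module-name derivation in the loop
# above: for f == "/home/user/multimodelity/models/butd.py", splits ends with
# ["multimodelity", "models", "butd.py"], import_prefix_index points just past
# "multimodelity", and the computed module is "multimodelity.models.butd".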
def update_registry_for_pretrained(cls, config, checkpoint, full_output):
    from omegaconf import OmegaConf

    # Hack datasets using OmegaConf
    datasets = full_output["full_config"].datasets
    dataset = datasets.split(",")[0]
    config_mock = OmegaConf.create({"datasets": datasets})
    registry.register("config", config_mock)
    registry.register(
        f"{dataset}_num_final_outputs",
        # Need to add as it is subtracted
        checkpoint["classifier.module.weight"].size(0)
        + config.classifier.ocr_max_num,
    )
    # Fix this later, when processor pipeline is available
    answer_processor = OmegaConf.create({"BOS_IDX": 1})
    registry.register(f"{dataset}_answer_processor", answer_processor)
def _load_counts_and_lr_scheduler(self, ckpt):
    ckpt_config = self.trainer.config.checkpoint
    if "best_update" in ckpt:
        if ckpt_config.resume_best:
            self.trainer.num_updates = ckpt.get(
                "best_update", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        else:
            self.trainer.num_updates = ckpt.get(
                "num_updates", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "current_iteration", self.trainer.current_iteration
            )
        self.trainer.current_epoch = ckpt.get(
            "current_epoch", self.trainer.current_epoch
        )
    elif "best_iteration" in ckpt:
        # Preserve old behavior for old checkpoints where we always
        # load best iteration
        if ckpt_config.resume_best and "current_iteration" in ckpt:
            self.trainer.current_iteration = ckpt["current_iteration"]
        else:
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        self.trainer.num_updates = self.trainer.current_iteration

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
    if lr_scheduler is not None:
        if "lr_scheduler" in ckpt:
            lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        else:
            warnings.warn(
                "'lr_scheduler' key is not present in the "
                "checkpoint asked to be loaded. Setting lr_scheduler's "
                "last_epoch to current_iteration."
            )
            lr_scheduler.last_epoch = self.trainer.current_iteration

    registry.register("current_iteration", self.trainer.current_iteration)
    registry.register("num_updates", self.trainer.num_updates)

    self.trainer.current_epoch = ckpt.get("best_epoch", self.trainer.current_epoch)
    registry.register("current_epoch", self.trainer.current_epoch)
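# Hypothetical shape of a checkpoint dict accepted by the loader above; the
# key names come from the code, the values are made-up examples.
ckpt = {
    "num_updates": 1000,
    "current_iteration": 1000,
    "current_epoch": 3,
    "best_update": 800,
    "best_iteration": 800,
    "best_epoch": 2,
    "lr_scheduler": {},  # scheduler state_dict
}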
def setUp(self):
    torch.manual_seed(1234)
    registry.register("clevr_text_vocab_size", 80)
    registry.register("clevr_num_final_outputs", 32)
    config_path = os.path.join(
        get_multimodelity_root(),
        "..",
        "projects",
        "others",
        "cnn_lstm",
        "clevr",
        "defaults.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="cnn_lstm", dataset="clevr")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "clevr"
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def update_registry_for_model(self, config):
    registry.register(
        self.dataset_name + "_text_vocab_size",
        self.dataset.masked_token_processor.get_vocab_size(),
    )
def run_training_epoch(self) -> None:
    should_break = False
    while self.num_updates < self.max_updates and not should_break:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed
        self.dataset_loader.seed_sampler("train", self.current_epoch)

        # For iterable datasets we cannot determine length of dataset properly.
        # For those cases we set num_remaining_batches to be the (number of
        # updates remaining x update_frequency)
        num_remaining_batches = (
            (
                (self.max_updates - self.num_updates)
                * self.training_config.update_frequency
            )
            if isinstance(
                self.train_loader.current_dataset, torch.utils.data.IterableDataset
            )
            else len(self.train_loader)
        )

        combined_report = None
        num_batches_for_this_update = 1
        for idx, batch in enumerate(self.train_loader):
            # Reset the accumulated report at the start of each update window,
            # not at its end, so gradient accumulation keeps all of its batches
            if (idx + 1) % self.training_config.update_frequency == 1 or (
                self.training_config.update_frequency == 1
            ):
                combined_report = None
                num_batches_for_this_update = min(
                    self.training_config.update_frequency, num_remaining_batches
                )

                self._start_update()

            # batch execution starts here
            self.on_batch_start()
            self.profile("Batch load time")

            report = self.run_training_batch(batch, num_batches_for_this_update)

            # accumulate necessary params for metric calculation
            if combined_report is None:
                combined_report = report
            else:
                combined_report.accumulate_tensor_fields(
                    report, self.metrics.required_params
                )
                combined_report.batch_size += report.batch_size

            # batch execution ends here
            self.on_batch_end(report=combined_report, meter=self.meter)

            # check if an update has finished; if not, continue
            if (idx + 1) % self.training_config.update_frequency:
                continue

            self._finish_update()

            should_log = False
            if self.num_updates % self.logistics_callback.log_interval == 0:
                should_log = True
                # Calculate metrics every log interval for debugging
                if self.training_config.evaluate_metrics:
                    combined_report.metrics = self.metrics(
                        combined_report, combined_report
                    )
                self.update_meter(combined_report, self.meter)

            self.on_update_end(
                report=combined_report, meter=self.meter, should_log=should_log
            )

            num_remaining_batches -= num_batches_for_this_update

            # Check if training should be stopped
            should_break = False

            if self.num_updates % self.training_config.evaluation_interval == 0:
                # Validation begin callbacks
                self.on_validation_start()

                logger.info("Evaluation time. Running on full validation set...")
                # Validation and Early stopping
                # Create a new meter for this case
                report, meter = self.evaluation_loop(self.val_loader)

                # Validation end callbacks
                stop = self.early_stop_callback.on_validation_end(
                    report=report, meter=meter
                )
                self.on_validation_end(report=report, meter=meter)

                gc.collect()

                if "cuda" in str(self.device):
                    torch.cuda.empty_cache()

                if stop is True:
                    logger.info("Early stopping activated")
                    should_break = True

            if self.num_updates >= self.max_updates:
                should_break = True

            if should_break:
                break
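# Worked example (assumed numbers) of the gradient-accumulation bookkeeping
# above: with update_frequency = 4 and 10 batches remaining, each update
# consumes min(update_frequency, num_remaining_batches) batches, so the final
# update accumulates only the 2 leftover batches.
update_frequency, num_remaining_batches = 4, 10
batches_per_update = []
while num_remaining_batches > 0:
    n = min(update_frequency, num_remaining_batches)
    batches_per_update.append(n)
    num_remaining_batches -= n
assert batches_per_update == [4, 4, 2]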
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "multimodelity",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the multimodelity logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their own logging
    handlers and setup. If they do, and don't want to clear handlers, pass the
    clear_handlers option.

    The initial version of this function was taken from D2 and adapted for
    multimodelity.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name.
            Default: Saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "multimodelity".
        disable (bool): If true, won't set up the logger and returns None.
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    if distributed_rank == 0:
        logger.setLevel(logging.INFO)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.INFO)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_multimodelity_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging.INFO)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add multimodelity specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    registry.register("writer", logger)

    return logger
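# Minimal usage sketch for setup_logger: keyword names match the signature
# above; the output path is an assumed example.
logger = setup_logger(output="./save/logs/train.log", color=False)
logger.info("Logger is ready")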