# Variant of main() that registers a Logger instance as the registry "writer".
def main(configuration, init_distributed=False, predict=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    seed = config.training.seed
    config.training.seed = set_seed(seed if seed == -1 else seed + get_rank())
    registry.register("seed", config.training.seed)
    print(f"Using seed {config.training.seed}")

    config = build_config(configuration)

    # Logger should be registered after config is registered
    registry.register("writer", Logger(config, name="mmf.train"))

    trainer = build_trainer(config)
    trainer.load()

    if predict:
        trainer.inference()
    else:
        trainer.train()

# Variant of main() that configures logging through setup_logger() (defined below).
def main(configuration, init_distributed=False, predict=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    seed = config.training.seed
    config.training.seed = set_seed(seed if seed == -1 else seed + get_rank())
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    setup_logger(
        color=config.training.colored_logs, disable=config.training.should_not_log
    )
    logger = logging.getLogger("mmf_cli.run")
    # Log args for debugging purposes
    logger.info(configuration.args)
    logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    logger.info(f"Using seed {config.training.seed}")

    trainer = build_trainer(config)
    trainer.load()

    if predict:
        trainer.inference()
    else:
        trainer.train()
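
# A minimal, self-contained sketch of the per-rank seeding pattern both main()
# variants use: a base seed of -1 means "pick a random seed", otherwise each
# distributed worker offsets the base seed by its rank, so runs are reproducible
# but workers are not identical. set_seed_sketch is an illustrative name, not
# MMF's set_seed.
import random

def set_seed_sketch(seed: int, rank: int) -> int:
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)  # fully random run
    else:
        seed = seed + rank  # reproducible, but distinct per worker
    random.seed(seed)
    return seed

# e.g. with a base seed of 42, rank 0 seeds with 42, rank 1 with 43, and so on.
print(set_seed_sketch(42, rank=0), set_seed_sketch(42, rank=1))
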
def forward(self, sample_list: Dict[str, Tensor], model_output: Dict[str, Tensor]):
    assert (
        "embedding_1" in model_output and "embedding_2" in model_output
    ), "Embedding names must be available before loss calculation"
    embedding_1 = model_output["embedding_1"]
    embedding_2 = model_output["embedding_2"]

    assert embedding_1.size(0) == embedding_2.size(0), "batch size must match"
    per_gpu_batch_size = embedding_1.size(0)

    # Gather embeddings from every worker so each GPU contrasts its local
    # batch against the full global batch (gradients flow through the gather).
    embedding_1_all_gpus = gather_tensor_along_batch_with_backward(embedding_1)
    embedding_2_all_gpus = gather_tensor_along_batch_with_backward(embedding_2)

    temperature = model_output["temperature"]

    logits_1 = (
        torch.matmul(embedding_1, embedding_2_all_gpus.transpose(0, 1)) / temperature
    )
    logits_2 = (
        torch.matmul(embedding_2, embedding_1_all_gpus.transpose(0, 1)) / temperature
    )
    # The positive for local row i sits at column (rank * per_gpu_batch_size + i)
    # of the gathered batch.
    labels = per_gpu_batch_size * get_rank() + torch.arange(
        per_gpu_batch_size, device=temperature.device
    )
    loss_1 = F.cross_entropy(logits_1, labels)
    loss_2 = F.cross_entropy(logits_2, labels)

    return (loss_1 + loss_2) / 2
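
# A minimal single-process sketch of the symmetric contrastive (InfoNCE)
# objective that forward() computes above, with the cross-GPU gather elided:
# on one worker the gathered batch is just the local batch, so the positive
# pair for row i is column i. All names below are illustrative, not MMF API.
import torch
import torch.nn.functional as F

def contrastive_loss_sketch(
    emb_1: torch.Tensor, emb_2: torch.Tensor, temperature: float = 0.07
) -> torch.Tensor:
    # Similarity logits between every pair in the batch.
    logits_1 = emb_1 @ emb_2.t() / temperature
    logits_2 = emb_2 @ emb_1.t() / temperature
    labels = torch.arange(emb_1.size(0), device=emb_1.device)
    # Average the two cross-entropy directions, as in forward() above.
    return (F.cross_entropy(logits_1, labels) + F.cross_entropy(logits_2, labels)) / 2

# Example: two L2-normalized batches of 8 samples with 16-d embeddings.
emb_1 = F.normalize(torch.randn(8, 16), dim=-1)
emb_2 = F.normalize(torch.randn(8, 16), dim=-1)
print(contrastive_loss_sketch(emb_1, emb_2))
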
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their own
    logging handlers and setup. If they do call it and don't want existing
    handlers cleared, they should pass clear_handlers=False.

    The initial version of this function was taken from D2 and adapted for MMF.

    Args:
        output (str): a file name or a directory to save logs. If it ends
            with ".txt" or ".log", it is assumed to be a file name.
            Default: saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable (bool): If true, skip logger setup and return None.
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = registry.get("config").training.logger_level.upper()
    # Console logging: main process only.
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # File logging: all workers.
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger
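
# Hedged usage sketch for setup_logger(): it assumes a config exposing
# training.logger_level has already been built and registered (as main() does
# via build_config), so it is normally called from inside an MMF run rather
# than standalone. The output path below is illustrative only.
logger = setup_logger(output="./save/train.log", color=False)
if logger is not None:  # setup_logger returns None when disable=True
    logger.info("Handlers installed; records go to stdout and the log file.")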