Ejemplo n.º 1
0
def main(configuration, init_distributed=False, predict=False):
    """Build the config and trainer from ``configuration`` and run it.

    The steps run in this order:
    - refresh the registry imports and import the user directory;
    - pin the CUDA device when CUDA is available;
    - optionally initialise distributed training;
    - seed the RNGs;
    - register a ``Logger`` as the ``"writer"``;
    - load the trainer, then either train or run inference.

    Args:
        configuration: object exposing ``import_user_dir()`` and
            ``get_config()``; it is also handed to ``build_config``.
        init_distributed (bool): call ``distributed_init`` before seeding.
        predict (bool): run ``trainer.inference()`` instead of
            ``trainer.train()``.
    """
    # Imports may have to be refreshed before the user dir is loaded.
    setup_imports()
    configuration.import_user_dir()
    cfg = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(cfg)

    # A seed of -1 is passed to set_seed unchanged; any other seed is
    # offset by get_rank() first.
    base_seed = cfg.training.seed
    if base_seed != -1:
        base_seed = base_seed + get_rank()
    cfg.training.seed = set_seed(base_seed)
    registry.register("seed", cfg.training.seed)
    print(f"Using seed {cfg.training.seed}")

    cfg = build_config(configuration)

    # The Logger may only be registered once the config is registered.
    registry.register("writer", Logger(cfg, name="mmf.train"))

    trainer = build_trainer(cfg)
    trainer.load()
    run = trainer.inference if predict else trainer.train
    run()
Ejemplo n.º 2
0
def main(configuration, init_distributed=False, predict=False):
    """Build the config and trainer from ``configuration`` and run it.

    The steps run in this order:
    - refresh the registry imports and import the user directory;
    - pin the CUDA device when CUDA is available;
    - optionally initialise distributed training;
    - seed the RNGs and build the final config;
    - set up MMF logging and log some debugging information;
    - load the trainer, then either train or run inference.

    Args:
        configuration: object exposing ``import_user_dir()``,
            ``get_config()`` and ``args``; it is also handed to
            ``build_config``.
        init_distributed (bool): call ``distributed_init`` before seeding.
        predict (bool): run ``trainer.inference()`` instead of
            ``trainer.train()``.
    """
    # Imports may have to be refreshed before the user dir is loaded.
    setup_imports()
    configuration.import_user_dir()
    cfg = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(cfg)

    # A seed of -1 is passed to set_seed unchanged; any other seed is
    # offset by get_rank() first.
    base_seed = cfg.training.seed
    if base_seed != -1:
        base_seed = base_seed + get_rank()
    cfg.training.seed = set_seed(base_seed)
    registry.register("seed", cfg.training.seed)

    cfg = build_config(configuration)

    setup_logger(
        color=cfg.training.colored_logs,
        disable=cfg.training.should_not_log,
    )
    run_logger = logging.getLogger("mmf_cli.run")
    # Record the CLI arguments and environment to ease debugging.
    run_logger.info(configuration.args)
    run_logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    run_logger.info(f"Using seed {cfg.training.seed}")

    trainer = build_trainer(cfg)
    trainer.load()
    run = trainer.inference if predict else trainer.train
    run()
Ejemplo n.º 3
0
    def forward(self, sample_list: Dict[str, Tensor],
                model_output: Dict[str, Tensor]):
        """Symmetric cross-entropy contrastive loss between two embedding sets.

        Each local ``embedding_1`` row is scored against every row of
        ``embedding_2`` returned by ``gather_tensor_along_batch_with_backward``
        (presumably all ranks' batches — verify), and vice versa. The scores
        are divided by ``temperature``. The positive for local row ``i`` is
        taken to be at index ``per_gpu_batch_size * get_rank() + i``.

        NOTE(review): that label offset assumes every rank contributes the
        same batch size to the gathered tensor — confirm with the caller.

        Args:
            sample_list: not used by this loss.
            model_output: must contain ``embedding_1`` and ``embedding_2``
                (same size along dim 0) and ``temperature``.

        Returns:
            The mean of the 1->2 and 2->1 cross-entropy losses.

        Raises:
            AssertionError: if an embedding is missing or the batch sizes
                differ (only when assertions are enabled).
        """
        assert ("embedding_1" in model_output and "embedding_2" in model_output
                ), "Embedding names must be available before loss calculation"

        embedding_1 = model_output["embedding_1"]
        embedding_2 = model_output["embedding_2"]

        assert embedding_1.size(0) == embedding_2.size(
            0), "batch size must match"
        per_gpu_batch_size = embedding_1.size(0)

        embedding_1_all_gpus = gather_tensor_along_batch_with_backward(
            embedding_1)
        embedding_2_all_gpus = gather_tensor_along_batch_with_backward(
            embedding_2)

        temperature = model_output["temperature"]

        logits_1 = (
            torch.matmul(embedding_1, embedding_2_all_gpus.transpose(0, 1)) /
            temperature)
        logits_2 = (
            torch.matmul(embedding_2, embedding_1_all_gpus.transpose(0, 1)) /
            temperature)

        # Create the targets on the logits' device. cross_entropy requires
        # the targets and logits to be on the same device, and `temperature`
        # may be a 0-dim CPU tensor even when the embeddings are on GPU.
        labels = per_gpu_batch_size * get_rank() + torch.arange(
            per_gpu_batch_size, device=logits_1.device)

        loss_1 = F.cross_entropy(logits_1, labels)
        loss_2 = F.cross_entropy(logits_2, labels)

        return (loss_1 + loss_2) / 2
Ejemplo n.º 4
0
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and attach MMF's console and file handlers.

    The verbosity level is read from ``training.logger_level`` of the config
    registered under ``"config"``. If no config has been registered yet,
    the level falls back to "INFO".

    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want the root
    logger's handlers cleared, pass ``clear_handlers=False``.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory to save the log to.
            If it ends with ".txt" or ".log", it is assumed to be a file
            name. Otherwise the log is written to ``<output>/train.log``.
            Default: the folder returned by ``setup_output_folder()``.
            Non-zero ranks append ``.rank<N>`` to the file name.
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable (bool): If true, do nothing and return None.
        clear_handlers (bool): If false, won't clear the root logger's
            existing handlers.

    Returns:
        logging.Logger: the configured logger, or None when ``disable`` is
        true.
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    # Our handlers are attached directly, so don't also forward to root.
    logger.propagate = False

    # Route `warnings.warn` messages through the same handlers.
    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    # Without a registered config (e.g. when used outside the MMF CLI),
    # fall back to INFO instead of failing on `None.training`.
    config = registry.get("config")
    if config is None:
        logging_level = "INFO"
    else:
        logging_level = config.training.logger_level.upper()

    # Console logging: rank 0 only.
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        # A bare file name (e.g. "run.log") has no directory part; dirname
        # returns "" and creating "" would raise, so skip it.
        log_dir = os.path.dirname(filename)
        if log_dir:
            PathManager.mkdirs(log_dir)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers. basicConfig is a no-op if root still has
    # handlers (i.e. when clear_handlers is False).
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger