Example #1
    def __init__(self, args):
        self.args = args

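        # Use the local rank as this worker's shard id and the world size as the
        # distributed factor; fall back to a single process when not distributed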
        self.shard_id = args.local_rank if args.local_rank != -1 else 0
        self.distributed_factor = args.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

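        # If a checkpoint file is found, restore the encoder parameters it was saved with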
        model_file = get_model_file(self.args, self.args.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_encoder_params_from_state(saved_state.encoder_params, args)

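        # Build the tensorizer, reader model and optimizer for the configured encoder type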
        tensorizer, reader, optimizer = init_reader_components(
            args.encoder_model_type, args)

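        # Wrap the reader and optimizer for multi-GPU / distributed / fp16 execution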
        reader, optimizer = setup_for_distributed_mode(reader, optimizer,
                                                       args.device, args.n_gpu,
                                                       args.local_rank,
                                                       args.fp16,
                                                       args.fp16_opt_level)
        self.reader = reader
        self.optimizer = optimizer
        self.tensorizer = tensorizer
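        # Training-progress counters and best-checkpoint bookkeeping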
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        if saved_state:
            self._load_saved_state(saved_state)
Example #2
    def __init__(self, cfg: DictConfig):
        self.cfg = cfg

        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        model_file = get_model_file(self.cfg, self.cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

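        # Optional gradient checkpointing flag; defaults to False when absent from the config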
        gradient_checkpointing = getattr(self.cfg, "gradient_checkpointing",
                                         False)
        tensorizer, reader, optimizer = init_reader_components(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=gradient_checkpointing,
        )

        reader, optimizer = setup_for_distributed_mode(
            reader,
            optimizer,
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.reader = reader
        self.optimizer = optimizer
        self.tensorizer = tensorizer
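        # Optional debugging flag plus fields that are populated later during training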
        self.debugging = getattr(self.cfg, "debugging", False)
        self.wiki_data = None
        self.dev_iterator = None
        self.start_epoch = 0
        self.start_batch = 0
        self.scheduler_state = None
        self.best_validation_result = None
        self.best_cp_name = None
        if saved_state:
            self._load_saved_state(saved_state)
Example #3
    def __init__(self, args, model_file):
        self.args = args

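        # Restore encoder parameters from the checkpoint's saved state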
        saved_state = load_states_from_checkpoint(model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

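        # Build reader components; disable padding to max length and drop the unused optimizer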
        tensorizer, reader, optimizer = init_reader_components(
            args.encoder_model_type, args)
        tensorizer.pad_to_max = False
        del optimizer

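        # Move the reader to GPU and switch it to evaluation mode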
        reader = reader.cuda()
        reader = reader.eval()
        self.reader = reader
        self.tensorizer = tensorizer

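        # Load the saved weights into the underlying (unwrapped) reader model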
        model_to_load = get_model_obj(self.reader)
        model_to_load.load_state_dict(saved_state.model_dict)
Example #4
    def __init__(self, cfg: DictConfig):

        self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
        self.distributed_factor = cfg.distributed_world_size or 1

        logger.info("***** Initializing components for training *****")

        # If a model file is specified, encoder parameters from its saved state are used for initialization
        model_file = get_model_file(cfg, cfg.checkpoint_file_name)
        saved_state = None
        if model_file:
            saved_state = load_states_from_checkpoint_legacy(model_file)
            set_cfg_params_from_state(saved_state.encoder_params, cfg)

            if isinstance(saved_state, CheckpointStateOFA):
                self.mode = "normal"
                # Initialize everything
                gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
                tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
                    cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=gradient_checkpointing,
                )

            else:
                # Only allowed during evaluation-only mode
                assert isinstance(saved_state, CheckpointState)
                assert cfg.train_datasets is None or len(cfg.train_datasets) == 0
                # Convert from old state to OFA state
                saved_state, self.mode = convert_from_old_state_to_ofa(saved_state)

                if self.mode == "biencoder":
                    # Sanity check
                    assert cfg.evaluate_retriever and (not cfg.evaluate_reader)
                    # Initialize everything
                    tensorizer, biencoder, _ = init_biencoder_components(
                        cfg.encoder.encoder_model_type, cfg, inference_only=True,
                    )
                    reader = None
                else:
                    # Sanity check
                    assert cfg.evaluate_reader and (not cfg.evaluate_retriever)
                    # Initialize everything
                    tensorizer, reader, _ = init_reader_components(
                        cfg.encoder.encoder_model_type, cfg, inference_only=True,
                    )
                    biencoder = None

                # Create a "fake" one-for-all model
                model = SimpleOneForAllModel(
                    biencoder=biencoder, reader=reader, tensorizer=tensorizer,
                )

                # Modify config
                cfg.ignore_checkpoint_optimizer = True
                cfg.ignore_checkpoint_offset = True
                cfg.gradient_checkpointing = False
                cfg.fp16 = False
                # Placeholders for backward compatibility
                gradient_checkpointing = False
                biencoder_optimizer = None
                reader_optimizer = None
                forward_fn = ofa_simple_fw_pass  # always the simplest

        else:
            self.mode = "normal"
            # Initialize everything
            gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
            tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
                cfg.encoder.encoder_model_type, cfg, gradient_checkpointing=gradient_checkpointing,
            )

        model, (biencoder_optimizer, reader_optimizer) = setup_for_distributed_mode(
            model,
            [biencoder_optimizer, reader_optimizer],
            cfg.device,
            cfg.n_gpu,
            cfg.local_rank,
            cfg.fp16,
            cfg.fp16_opt_level,
            gradient_checkpointing=gradient_checkpointing,
        )

        self.forward_fn = forward_fn
        self.model = model
        self.cfg = cfg
        self.ds_cfg = OneForAllDatasetsCfg(cfg)
        self.biencoder_optimizer = biencoder_optimizer
        self.biencoder_scheduler_state = None
        self.reader_optimizer = reader_optimizer
        self.reader_scheduler_state = None
        self.clustering = cfg.biencoder.clustering
        if self.clustering:
            cfg.global_loss_buf_sz = 72000000  # this requires a lot of memory

        self.tensorizer = tensorizer
        self.start_epoch = 0
        self.start_batch = 0
        self.best_validation_result = None
        self.best_cp_name = None

        # Biencoder loss function (note that reader loss is automatically computed)
        self.biencoder_loss_function: BiEncoderNllLoss = init_loss(cfg.encoder.encoder_model_type, cfg)

        if saved_state:
            self._load_saved_state(saved_state)

        self.dev_iterator = None