def __init__(self, args):
    self.args = args
    self.shard_id = args.local_rank if args.local_rank != -1 else 0
    self.distributed_factor = args.distributed_world_size or 1

    logger.info("***** Initializing components for training *****")

    model_file = get_model_file(self.args, self.args.checkpoint_file_name)
    saved_state = None
    if model_file:
        saved_state = load_states_from_checkpoint(model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, reader, optimizer = init_reader_components(args.encoder_model_type, args)

    reader, optimizer = setup_for_distributed_mode(
        reader,
        optimizer,
        args.device,
        args.n_gpu,
        args.local_rank,
        args.fp16,
        args.fp16_opt_level,
    )
    self.reader = reader
    self.optimizer = optimizer
    self.tensorizer = tensorizer
    self.start_epoch = 0
    self.start_batch = 0
    self.scheduler_state = None
    self.best_validation_result = None
    self.best_cp_name = None
    if saved_state:
        self._load_saved_state(saved_state)
def __init__(self, cfg: DictConfig):
    self.cfg = cfg
    self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
    self.distributed_factor = cfg.distributed_world_size or 1

    logger.info("***** Initializing components for training *****")

    # if a model file is specified, encoder parameters from its saved state are used for initialization
    model_file = get_model_file(self.cfg, self.cfg.checkpoint_file_name)
    saved_state = None
    if model_file:
        saved_state = load_states_from_checkpoint(model_file)
        set_cfg_params_from_state(saved_state.encoder_params, cfg)

    gradient_checkpointing = getattr(self.cfg, "gradient_checkpointing", False)
    tensorizer, reader, optimizer = init_reader_components(
        cfg.encoder.encoder_model_type,
        cfg,
        gradient_checkpointing=gradient_checkpointing,
    )

    reader, optimizer = setup_for_distributed_mode(
        reader,
        optimizer,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
        gradient_checkpointing=gradient_checkpointing,
    )
    self.reader = reader
    self.optimizer = optimizer
    self.tensorizer = tensorizer

    self.debugging = getattr(self.cfg, "debugging", False)
    self.wiki_data = None
    self.dev_iterator = None
    self.start_epoch = 0
    self.start_batch = 0
    self.scheduler_state = None
    self.best_validation_result = None
    self.best_cp_name = None
    if saved_state:
        self._load_saved_state(saved_state)
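# Hedged usage sketch (not taken from this section): how a trainer built around the
# Hydra-based constructor above is typically wired into an entry point. The class
# name `ReaderTrainer`, the config_path/config_name values, and the method called
# on the trainer are assumptions for illustration only.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="extractive_reader_train_cfg")
def main(cfg: DictConfig):
    trainer = ReaderTrainer(cfg)  # hypothetical class wrapping the __init__ above
    trainer.run_train()  # hypothetical training entry method


if __name__ == "__main__":
    main()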
def __init__(self, args, model_file):
    self.args = args

    # restore encoder settings from the trained checkpoint before building the model
    saved_state = load_states_from_checkpoint(model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, reader, optimizer = init_reader_components(args.encoder_model_type, args)
    tensorizer.pad_to_max = False  # don't pad every inference batch to the max sequence length
    del optimizer  # the optimizer is not needed for inference

    reader = reader.cuda()
    reader = reader.eval()
    self.reader = reader
    self.tensorizer = tensorizer

    model_to_load = get_model_obj(self.reader)
    model_to_load.load_state_dict(saved_state.model_dict)
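# Hedged instantiation sketch for the inference-only constructor above. The class
# name `ReaderInference`, the arg-parsing helper, and the checkpoint path are
# assumptions; the constructor itself already moves the model to GPU, switches it
# to eval mode, and loads the saved weights, so no further setup is needed.
args = parse_reader_args()  # hypothetical arg-parsing helper
inferencer = ReaderInference(args, model_file="/path/to/reader_checkpoint.pt")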
def __init__(self, cfg: DictConfig):
    self.shard_id = cfg.local_rank if cfg.local_rank != -1 else 0
    self.distributed_factor = cfg.distributed_world_size or 1

    logger.info("***** Initializing components for training *****")

    # if model file is specified, encoder parameters from saved state should be used for initialization
    model_file = get_model_file(cfg, cfg.checkpoint_file_name)
    saved_state = None
    if model_file:
        saved_state = load_states_from_checkpoint_legacy(model_file)
        set_cfg_params_from_state(saved_state.encoder_params, cfg)

        if isinstance(saved_state, CheckpointStateOFA):
            self.mode = "normal"
            # Initialize everything
            gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
            tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
                cfg.encoder.encoder_model_type,
                cfg,
                gradient_checkpointing=gradient_checkpointing,
            )
        else:
            # Only allowed during evaluation-only mode
            assert isinstance(saved_state, CheckpointState)
            assert cfg.train_datasets is None or len(cfg.train_datasets) == 0

            # Convert from old state to OFA state
            saved_state, self.mode = convert_from_old_state_to_ofa(saved_state)

            if self.mode == "biencoder":
                # Sanity check
                assert cfg.evaluate_retriever and (not cfg.evaluate_reader)
                # Initialize everything
                tensorizer, biencoder, _ = init_biencoder_components(
                    cfg.encoder.encoder_model_type,
                    cfg,
                    inference_only=True,
                )
                reader = None
            else:
                # Sanity check
                assert cfg.evaluate_reader and (not cfg.evaluate_retriever)
                # Initialize everything
                tensorizer, reader, _ = init_reader_components(
                    cfg.encoder.encoder_model_type,
                    cfg,
                    inference_only=True,
                )
                biencoder = None

            # Create a "fake" one-for-all model
            model = SimpleOneForAllModel(
                biencoder=biencoder,
                reader=reader,
                tensorizer=tensorizer,
            )

            # Modify config
            cfg.ignore_checkpoint_optimizer = True
            cfg.ignore_checkpoint_offset = True
            cfg.gradient_checkpointing = False
            cfg.fp16 = False

            # Placeholders for backward compatibility
            gradient_checkpointing = False
            biencoder_optimizer = None
            reader_optimizer = None
            forward_fn = ofa_simple_fw_pass  # always the simplest
    else:
        self.mode = "normal"
        # Initialize everything
        gradient_checkpointing = getattr(cfg, "gradient_checkpointing", False)
        tensorizer, model, biencoder_optimizer, reader_optimizer, forward_fn = init_ofa_model(
            cfg.encoder.encoder_model_type,
            cfg,
            gradient_checkpointing=gradient_checkpointing,
        )

    model, (biencoder_optimizer, reader_optimizer) = setup_for_distributed_mode(
        model,
        [biencoder_optimizer, reader_optimizer],
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
        gradient_checkpointing=gradient_checkpointing,
    )

    self.forward_fn = forward_fn
    self.model = model
    self.cfg = cfg
    self.ds_cfg = OneForAllDatasetsCfg(cfg)

    self.biencoder_optimizer = biencoder_optimizer
    self.biencoder_scheduler_state = None
    self.reader_optimizer = reader_optimizer
    self.reader_scheduler_state = None

    self.clustering = cfg.biencoder.clustering
    if self.clustering:
        cfg.global_loss_buf_sz = 72000000  # this requires a lot of memory

    self.tensorizer = tensorizer
    self.start_epoch = 0
    self.start_batch = 0
    self.best_validation_result = None
    self.best_cp_name = None

    # Biencoder loss function (note that reader loss is automatically computed)
    self.biencoder_loss_function: BiEncoderNllLoss = init_loss(cfg.encoder.encoder_model_type, cfg)

    if saved_state:
        self._load_saved_state(saved_state)
    self.dev_iterator = None
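# Hedged configuration sketch for the one-for-all constructor above when resuming
# from a legacy (non-OFA) checkpoint: the asserts in that constructor require that
# no training datasets are configured and that exactly one of the two evaluation
# flags is set. The trainer class name `OneForAllTrainer` is an assumption.
cfg.train_datasets = []          # legacy checkpoints are evaluation-only
cfg.evaluate_retriever = True    # evaluate the biencoder half ...
cfg.evaluate_reader = False      # ... or the reader half, but not both
trainer = OneForAllTrainer(cfg)  # hypothetical class wrapping the __init__ above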