def setup_model(args):
    """Set up the model and optionally load a checkpoint."""
    model = get_model(args)
    # if args.deepspeed:
    #     print_rank_0("DeepSpeed is enabled.")
    #
    #     model, _, _, _ = deepspeed.initialize(
    #         model=model,
    #         model_parameters=model.parameters(),
    #         args=args,
    #         mpu=mpu,
    #         dist_init_required=False
    #     )
    if args.load is not None:
        if args.deepspeed:
            # DeepSpeed-format checkpoint: load the model-parallel rank 0
            # state file for the resolved iteration directly.
            iteration, release, success = get_checkpoint_iteration(args)
            print_rank_0(f"Loading checkpoint from iteration {iteration}")
            path = os.path.join(args.load, str(iteration),
                                "mp_rank_00_model_states.pt")
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint["module"])
        else:
            _ = load_checkpoint(model, None, None, args,
                                load_optimizer_states=False)
    # if args.deepspeed:
    #     model = model.module
    return model
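
# Hedged usage sketch (not part of the original file): `setup_model` expects the
# repo's parsed `args` namespace (with at least `deepspeed` and `load` set) and
# assumes torch.distributed / mpu are already initialized by the caller.
#
#     model = setup_model(args)   # builds the model; loads args.load if given
#     model.eval()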
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')
    # Unwrap DeepSpeed / DistributedDataParallel / FP16 wrappers so the state
    # dict is loaded into the bare model.
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    # Model.
    def extend_embedding_weights(state_weights, model_weights):
        # Copy the checkpoint's (shorter) embedding table into the model's
        # (longer) freshly initialized table; rows beyond the checkpoint's
        # length keep the model's initialization.
        original_length = state_weights.shape[0]
        assert original_length <= args.max_position_embeddings + 1
        new_weights = model_weights.clone()
        new_weights[:original_length] = state_weights
        return new_weights

    if args.block_lm:
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            position_weights = sd['module']["transformer.position_embeddings.weight"]
            if args.max_position_embeddings + 1 > position_weights.shape[0]:
                sd['module']["transformer.position_embeddings.weight"] = extend_embedding_weights(
                    position_weights,
                    model.state_dict()["transformer.position_embeddings.weight"].data)
                print_rank_0(
                    f"Extend position embedding to {args.max_position_embeddings + 1}")
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
            if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                sd['module']["transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                    block_position_weights,
                    model.state_dict()["transformer.block_position_embeddings.weight"].data)
                print_rank_0(
                    f"Extend block position embedding to {args.max_position_embeddings + 1}")
    missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(
            f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data,
                                          task_tokens)
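
# Hedged usage sketch (not part of the original file): `load_pretrained` is
# typically called after the model has been built and wrapped, with
# `checkpoint_path` pointing at a pretrained checkpoint directory. The
# attribute name `args.load_pretrained` below is an assumption for
# illustration only.
#
#     model = setup_model(args)
#     load_pretrained(model, args.load_pretrained, args)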