def build_encoder(
    model_name_or_path,
    max_seq_length,
    pooling_mode,
    proj_emb_dim,
    drop_prob=0.1,
):
    base_layer = sent_trans.models.Transformer(model_name_or_path,
                                               max_seq_length=max_seq_length)
    pooling_layer = sent_trans.models.Pooling(
        base_layer.get_word_embedding_dimension(),
        pooling_mode=pooling_mode,
    )
    dense_layer = sent_trans.models.Dense(
        in_features=pooling_layer.get_sentence_embedding_dimension(),
        out_features=proj_emb_dim,
        activation_function=nn.Tanh(),
    )
    # Unused alternative modules (referenced only by the commented-out encoder below).
    # normalize_layer = sent_trans.models.LayerNorm(proj_emb_dim)
    normalize_layer = sent_trans.models.Normalize()
    dropout_layer = sent_trans.models.Dropout(dropout=drop_prob)
    proj_layer = sent_trans.models.Dense(
        in_features=512,
        out_features=128,
        activation_function=nn.Tanh(),
    )
    # encoder = sent_trans.SentenceTransformer(
    #     modules=[base_layer, pooling_layer, dense_layer, normalize_layer, dropout_layer],
    # )
    encoder = sent_trans.SentenceTransformer(
        modules=[base_layer, pooling_layer, dense_layer],
    )
    return encoder
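# A minimal usage sketch of build_encoder; the checkpoint name and hyperparameter values
# below are illustrative, not taken from the original code, and the `sent_trans` / `nn`
# aliases are assumed to be imported as used by build_encoder above.
label_encoder = build_encoder(
    'sentence-transformers/paraphrase-mpnet-base-v2',
    max_seq_length=64,
    pooling_mode='mean',
    proj_emb_dim=512,
)
print(label_encoder.encode(['an example label']).shape)  # (1, 512) after the Dense projection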
def load_model(self):
    if self.model is not None:
        return
    model_dir = self.download()
    self.model = sentence_transformers.SentenceTransformer(
        model_name_or_path=str(model_dir), device=get_device(use_gpu=True)
    )
def calc_scores(self):
    super().calc_scores()
    # write input_tsv
    model_name = 'bert-large-nli-stsb-mean-tokens'  # FIXME - hard-coded
    model = sentence_transformers.SentenceTransformer(model_name)
    resp_list = []
    # read resps
    with open(self.config['input_path'], 'r+', encoding='utf-8') as f_in:
        reader = csv.DictReader(f_in)
        resp_keys = sorted([
            s for s in reader.fieldnames
            if s.startswith('resp_') and utils.represents_int(s.split('resp_')[-1])
        ])
        for row in reader:
            resps = [v for k, v in row.items() if k in resp_keys]
            resp_list += resps

    # calc embeds
    embeds = np.array(model.encode(resp_list))  # [num_contexts * samples_per_context, embed_dim]
    assert len(embeds.shape) == 2
    assert embeds.shape[0] == self.config['num_sets'] * self.config['samples_per_set']
    embeds = np.reshape(embeds,
                        [self.config['num_sets'], self.config['samples_per_set'], -1])

    # write a cache file compatible with the ordering in bert_score and bert_sts
    similarity_scores_list = []
    # note: len() assertions are done in the get_similarity_scores method
    for set_i in range(self.config['num_sets']):
        for sample_i in range(self.config['samples_per_set']):
            for sample_j in range(sample_i):
                similarity_scores_list.append(self.similarity_metric(
                    embeds[set_i, sample_i, :], embeds[set_i, sample_j, :]))
    with open(self.config['cache_file'], 'w') as cache_f:
        for score in similarity_scores_list:
            cache_f.write('{:0.3f}\n'.format(score))
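# self.similarity_metric is configured elsewhere in this class; a minimal sketch, assuming
# it is plain cosine similarity over the sentence embeddings (the helper name below is
# hypothetical, not part of the original code):
import numpy as np

def cosine_similarity_metric(emb_a, emb_b):
    """Cosine similarity between two 1-D embedding vectors."""
    return float(np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))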
def get_model(local_rank=0) -> st.SentenceTransformer:
    word_embedding_model: st.models.BERT = st.models.BERT('bert-base-uncased')
    pooling_model = st.models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False)
    return st.SentenceTransformer(
        modules=[word_embedding_model, pooling_model],
        local_rank=local_rank)
def main():
    args = parse_args()

    dataset_path = 'examples/datasets/iambot-wikipedia-sections-triplets-all'
    output_path = 'output/bert-base-wikipedia-sections-mean-tokens-' + \
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    batch_size = 17
    num_epochs = 1

    is_distributed = torch.cuda.device_count() > 1 and args.local_rank >= 0
    if is_distributed:
        torch.distributed.init_process_group(backend='nccl')

    model = get_model(local_rank=args.local_rank)

    logging.info('Read Triplet train dataset')
    train_data = get_triplet_dataset(dataset_path, 'train.csv', model)
    train_dataloader = get_data_loader(dataset=train_data,
                                       shuffle=True,
                                       batch_size=batch_size,
                                       distributed=is_distributed)

    logging.info('Read Wikipedia Triplet dev dataset')
    dev_dataloader = get_data_loader(
        dataset=get_triplet_dataset(dataset_path, 'validation.csv', model, 1000),
        shuffle=False,
        batch_size=batch_size)
    evaluator = TripletEvaluator(dev_dataloader)

    warmup_steps = int(len(train_data) * num_epochs / batch_size * 0.1)
    loss = st.losses.TripletLoss(model=model)

    model.fit(train_objectives=[(train_dataloader, loss)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              output_path=output_path,
              local_rank=args.local_rank)

    if args.local_rank == 0 or not is_distributed:
        del model
        torch.cuda.empty_cache()
        model = st.SentenceTransformer(output_path)
        test_data = get_triplet_dataset(dataset_path, 'test.csv', model)
        test_dataloader = get_data_loader(test_data,
                                          shuffle=False,
                                          batch_size=batch_size)
        evaluator = TripletEvaluator(test_dataloader)
        model.evaluate(evaluator)
def collect_sents(self, ranked_dates, collection, vectorizer, include_titles):
    embedding_model = sentence_transformers.SentenceTransformer(
        'roberta-large-nli-stsb-mean-tokens')
    embedding_model.max_seq_length = 256
    date_to_pub, date_to_ment = self._first_pass(collection, include_titles)
    for d, sents in self._second_pass(ranked_dates, date_to_pub, date_to_ment,
                                      embedding_model):
        yield d, sents
def __init__(self, language: str, embeddings_path='', dataset_split='val'):
    """Initializes the metric for the given language.

    Args:
        language (str): "german" or "english"
    """
    assert language in ["english", "german"]
    self.device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    self.transformer = sentence_transformers.SentenceTransformer(
        "bert-base-nli-stsb-mean-tokens").to(self.device)
    self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    if language == "english":
        print("Initialized semantic similarity metric for English texts.")
    else:
        print("Initialized semantic similarity metric for German texts.")
    self.saved_embeddings = self.init_embeddings(embeddings_path, dataset_split)
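# The scoring method of this class is not shown here; a minimal sketch of how the pieces set
# up above would typically be combined (the method name `semantic_similarity` is hypothetical):
def semantic_similarity(self, hypothesis: str, reference: str) -> float:
    """Cosine similarity of the two sentence embeddings (hypothetical helper)."""
    emb = self.transformer.encode([hypothesis, reference], convert_to_tensor=True)
    # CosineSimilarity(dim=1) expects batched vectors, hence the unsqueeze to shape (1, dim).
    return self.cosine_similarity(emb[0].unsqueeze(0), emb[1].unsqueeze(0)).item()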
def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
    self.device = device
    self.model = st.SentenceTransformer(model_name, device=device)
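# The surrounding wrapper class is not shown; a minimal sketch of how the wrapped model is
# typically used (standard sentence-transformers API, illustrative sentences):
import sentence_transformers as st

encoder = st.SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
vectors = encoder.encode(['first example sentence', 'second example sentence'])
print(vectors.shape)  # (2, 384) for all-MiniLM-L6-v2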
def load_textclassification_data(dataname, vecname='glove.42B.300d', shuffle=True,
                                 random_seed=None, num_workers=0,
                                 preembed_sentences=True,
                                 loading_method='sentence_transformers',
                                 device='cpu', embedding_model=None,
                                 batch_size=16, valid_size=0.1, maxsize=None,
                                 print_stats=False):
    """ Load torchtext datasets.

        Note: torchtext's TextClassification datasets are a bit different from the others:
            - they don't have a split method.
            - no obvious creation of (nor access to) fields

    """
    def batch_processor_tt(batch, TEXT=None, sentemb=None, return_lengths=True,
                           device=None):
        """ For torchtext data/models """
        labels, texts = zip(*batch)
        lens = [len(t) for t in texts]
        labels = torch.Tensor(labels)
        pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
        texttensor = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True,
                                                     padding_value=pad_idx)
        if sentemb:
            texttensor = sentemb(texttensor)
        if return_lengths:
            return texttensor, labels, lens
        else:
            return texttensor, labels

    def batch_processor_st(batch, model, device=None):
        """ For sentence_transformers data/models """
        device = process_device_arg(device)
        with torch.no_grad():
            batch = model.smart_batching_collate(batch)
            ## Always run embedding model on gpu if available
            features, labels = st.util.batch_to_device(batch, device)
            emb = model(features[0])['sentence_embedding']
        return emb, labels

    if shuffle == True and random_seed:
        np.random.seed(random_seed)

    debug = False
    dataroot = '/tmp/' if debug else DATA_DIR  # os.path.join(ROOT_DIR, 'data')
    veccache = os.path.join(dataroot, '.vector_cache')

    if loading_method == 'torchtext':
        ## TextClassification dataset objects already do word-to-token mapping inside.
        DATASET = getattr(torchtext.datasets, dataname)
        train, test = DATASET(root=dataroot, ngrams=1)

        ## load_vectors reindexes embeddings so that they match the vocab's itos indices.
        train._vocab.load_vectors(vecname, cache=veccache, max_vectors=50000)
        test._vocab.load_vectors(vecname, cache=veccache, max_vectors=50000)

        ## Define Fields for Text and Labels
        text_field = torchtext.data.Field(sequential=True,
                                          lower=True,
                                          tokenize=get_tokenizer("basic_english"),
                                          batch_first=True,
                                          include_lengths=True,
                                          use_vocab=True)
        text_field.vocab = train._vocab

        if preembed_sentences:
            ## This will be used for distance computation
            vsize = len(text_field.vocab)
            edim = text_field.vocab.vectors.shape[1]
            pidx = text_field.vocab.stoi[text_field.pad_token]
            sentembedder = BoWSentenceEmbedding(vsize, edim,
                                                text_field.vocab.vectors, pidx)
            batch_processor = partial(batch_processor_tt, TEXT=text_field,
                                      sentemb=sentembedder, return_lengths=False)
        else:
            batch_processor = partial(batch_processor_tt, TEXT=text_field,
                                      return_lengths=True)
    elif loading_method == 'sentence_transformers':
        import sentence_transformers as st
        dpath = os.path.join(dataroot, TEXTDATA_PATHS[dataname])
        reader = st.readers.LabelSentenceReader(dpath)
        if embedding_model is None:
            model = st.SentenceTransformer('distilbert-base-nli-stsb-mean-tokens').eval()
        elif type(embedding_model) is str:
            model = st.SentenceTransformer(embedding_model).eval()
        elif isinstance(embedding_model, st.SentenceTransformer):
            model = embedding_model.eval()
        else:
            raise ValueError('embedding model has wrong type')
        print('Reading and embedding {} train data...'.format(dataname))
        train = st.SentencesDataset(reader.get_examples('train.tsv'), model=model)
        train.targets = train.labels
        print('Reading and embedding {} test data...'.format(dataname))
        test = st.SentencesDataset(reader.get_examples('test.tsv'), model=model)
        test.targets = test.labels
        if preembed_sentences:
            batch_processor = partial(batch_processor_st, model=model, device=device)
        else:
            batch_processor = None

    ## Seems like torchtext already maps class ids to 0...n-1. Adapt class names to account for this.
    classes = torchtext.datasets.text_classification.LABELS[dataname]
    classes = [classes[k + 1] for k in range(len(classes))]
    train.classes = classes
    test.classes = classes

    train_idx, valid_idx = random_index_split(len(train), 1 - valid_size,
                                              (maxsize, None))  # No maxsize for validation
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    dataloader_args = dict(batch_size=batch_size, num_workers=num_workers,
                           collate_fn=batch_processor)
    train_loader = dataloader.DataLoader(train, sampler=train_sampler, **dataloader_args)
    valid_loader = dataloader.DataLoader(train, sampler=valid_sampler, **dataloader_args)
    dataloader_args['shuffle'] = False
    test_loader = dataloader.DataLoader(test, **dataloader_args)

    if print_stats:
        print('Classes: {} (effective: {})'.format(len(train.classes),
                                                   len(torch.unique(train.targets))))
        print('Fold Sizes: {}/{}/{} (train/valid/test)'.format(len(train_idx),
                                                               len(valid_idx),
                                                               len(test)))

    return train_loader, valid_loader, test_loader, train, test
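# A minimal usage sketch of load_textclassification_data; 'AG_NEWS' is one of torchtext's
# TextClassification datasets and the keyword values are illustrative. Assumes DATA_DIR,
# BoWSentenceEmbedding, random_index_split, etc. are available in this module as used above.
train_loader, valid_loader, test_loader, train_ds, test_ds = load_textclassification_data(
    'AG_NEWS',
    loading_method='torchtext',
    preembed_sentences=True,
    batch_size=16,
    valid_size=0.1,
    print_stats=True,
)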
logger.info(sent_trans.__file__)

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Load pretrained model and tokenizer
if args.model_name_or_path in ('bert-base-uncased',
                               'sentence-transformers/paraphrase-mpnet-base-v2'):
    label_encoder = build_encoder(
        args.model_name_or_path,
        args.max_label_length,
        args.pooling_mode,
        args.proj_emb_dim,
    )
else:
    label_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

tokenizer = label_encoder._first_module().tokenizer
instance_encoder = label_encoder

model = DualEncoderModel(
    label_encoder,
    instance_encoder,
)
model = model.to(device)

# the whole label set
data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset', args.dataset)
all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
label_list = list(all_labels.title)
import numpy
import sentence_transformers
import sentence_transformers.models
import sklearn.metrics.pairwise

cmb = sentence_transformers.models.CamemBERT('camembert-base')
pooling_model = sentence_transformers.models.Pooling(
    cmb.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False)
model = sentence_transformers.SentenceTransformer(modules=[cmb, pooling_model])


def sentences_embeddings(sentences):
    return model.encode(sentences)


def doc_sentence_sim(base_sentence_vecs, doc_sentences):
    return numpy.average(
        sklearn.metrics.pairwise.cosine_similarity(
            sentences_embeddings(doc_sentences), base_sentence_vecs))
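# A minimal usage sketch of doc_sentence_sim: score a document's sentences against a
# pre-embedded reference sentence (the French sentences below are illustrative).
base_vecs = sentences_embeddings(["Le chat dort sur le canapé."])
doc = ["Un chat fait la sieste.", "Il pleut aujourd'hui."]
print(doc_sentence_sim(base_vecs, doc))  # average cosine similarity in [-1, 1]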
# from sentence_transformers import SentenceTransformer
import sentence_transformers
import scipy
# from sklearn.metrics.pairwise import cosine_similarity

model = sentence_transformers.SentenceTransformer('bert-base-nli-mean-tokens')


def sent_similarity(col_nouns, bok_noun):
    bok_nouns = bok_noun[0]
    sentence_embeddings_L1 = model.encode(bok_nouns[0])
    sentence_embeddings_L2 = model.encode(bok_nouns[1])
    sentence_embeddings_L3 = model.encode(bok_nouns[2])
    # sentence_embeddings_L4a = model.encode(bok_nouns[3])
    # sentence_embeddings_L4b = model.encode(bok_nouns[4])
    # sent_embed_bok_levels = [sentence_embeddings_L1, sentence_embeddings_L2,
    #                          sentence_embeddings_L3, sentence_embeddings_L4a,
    #                          sentence_embeddings_L4b]
    sent_embed_bok_levels = [
        sentence_embeddings_L1,
        sentence_embeddings_L2,
        sentence_embeddings_L3
    ]
    noun_levels = []
    unClassified = []
    for i in range(len(col_nouns)):
        tier_nouns = [[], [], [], [], []]
        for query in col_nouns[i]:
            queries = [query]
            query_embeddings = model.encode(queries)
            for query, query_embedding in zip(queries, query_embeddings):
                prev_dist = 0
                isClassified = False
                for indx, sentence_embeddings in enumerate(
def __init__(self, model_name_or_path, device=None):
    self.senttransf_model = sentence_transformers.SentenceTransformer(
        str(model_name_or_path), device=device
    )
def init_models(model_name: str, device='cuda:0') -> st.SentenceTransformer:
    return st.SentenceTransformer(model_name_or_path=model_name, device=device)
def main():
    # Initialize the accelerator. We will let the accelerator handle device placement
    # for us in this example.
    args = parse_args()
    distributed_args = accelerate.DistributedDataParallelKwargs(
        find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[distributed_args])
    device = accelerator.device
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        filename=f'xmc_{args.dataset}_{args.mode}_{args.log}.log',
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    ch = logging.StreamHandler(sys.stdout)
    logger.addHandler(ch)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
    logger.info(sent_trans.__file__)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    if args.model_name_or_path in ('bert-base-uncased',
                                   'sentence-transformers/paraphrase-mpnet-base-v2'):
        query_encoder = build_encoder(
            args.model_name_or_path,
            args.max_label_length,
            args.pooling_mode,
            args.proj_emb_dim,
        )
    else:
        query_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

    tokenizer = query_encoder._first_module().tokenizer
    block_encoder = query_encoder

    model = DualEncoderModel(query_encoder, block_encoder, args.mode)
    model = model.to(device)

    # the whole label set
    data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset', args.dataset)
    all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
    label_list = list(all_labels.title)
    label_ids = list(all_labels.uid)
    label_data = SimpleDataset(label_list, transform=tokenizer.encode)

    # label dataloader for searching
    sampler = SequentialSampler(label_data)
    label_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 64)
    label_dataloader = DataLoader(label_data,
                                  sampler=sampler,
                                  batch_size=16,
                                  collate_fn=label_padding_func)

    # label dataloader for regularization
    reg_sampler = RandomSampler(label_data)
    reg_dataloader = DataLoader(label_data,
                                sampler=reg_sampler,
                                batch_size=4,
                                collate_fn=label_padding_func)

    if args.mode == 'ict':
        train_data = ICTXMCDataset(tokenizer=tokenizer, dataset=args.dataset)
    elif args.mode == 'self-train':
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode)
    elif args.mode == 'finetune-pair':
        train_path = os.path.join(data_path, 'trn.json')
        pos_pair = []
        with open(train_path) as fp:
            for i, line in enumerate(fp):
                inst = json.loads(line.strip())
                inst_id = inst['uid']
                for ind in inst['target_ind']:
                    pos_pair.append((inst_id, ind, i))
        dataset_size = len(pos_pair)
        indices = list(range(dataset_size))
        split = int(np.floor(args.ratio * dataset_size))
        np.random.shuffle(indices)
        train_indices = indices[:split]
        torch.distributed.broadcast_object_list(train_indices, src=0, group=None)
        sample_pairs = [pos_pair[i] for i in train_indices]
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)
    elif args.mode == 'finetune-label':
        label_index = []
        label_path = os.path.join(data_path, 'label_index.json')
        with open(label_path) as fp:
            for line in fp:
                label_index.append(json.loads(line.strip()))
        np.random.shuffle(label_index)
        sample_size = int(np.floor(args.ratio * len(label_index)))
        sample_label = label_index[:sample_size]
        torch.distributed.broadcast_object_list(sample_label, src=0, group=None)
        sample_pairs = []
        for i, label in enumerate(sample_label):
            ind = label['ind']
            for inst_id in label['instance']:
                sample_pairs.append((inst_id, ind, i))
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)

    train_sampler = RandomSampler(train_data)
    padding_func = lambda x: ICT_batchify(x, tokenizer.pad_token_id, 64, 288)
    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=args.per_device_train_batch_size,
        num_workers=4,
        pin_memory=False,
        collate_fn=padding_func)

    try:
        accelerator.print("load cache")
        all_instances = torch.load(
            os.path.join(data_path, 'all_passages_with_titles.json.cache.pt'))
        test_data = SimpleDataset(all_instances.values())
    except:
        all_instances = {}
        test_path = os.path.join(data_path, 'tst.json')
        if args.mode == 'ict':
            train_path = os.path.join(data_path, 'trn.json')
            train_instances = {}
            valid_passage_ids = train_data.valid_passage_ids
            with open(train_path) as fp:
                for line in fp:
                    inst = json.loads(line.strip())
                    train_instances[
                        inst['uid']] = inst['title'] + '\t' + inst['content']
            for inst_id in valid_passage_ids:
                all_instances[inst_id] = train_instances[inst_id]
        test_ids = []
        with open(test_path) as fp:
            for line in fp:
                inst = json.loads(line.strip())
                all_instances[
                    inst['uid']] = inst['title'] + '\t' + inst['content']
                test_ids.append(inst['uid'])
        simple_transform = lambda x: tokenizer.encode(
            x, max_length=288, truncation=True)
        test_data = SimpleDataset(list(all_instances.values()),
                                  transform=simple_transform)

    inst_num = len(test_data)
    sampler = SequentialSampler(test_data)
    sent_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 288)
    instance_dataloader = DataLoader(test_data,
                                     sampler=sampler,
                                     batch_size=128,
                                     collate_fn=sent_padding_func)

    # prepare pairs
    reader = csv.reader(open(os.path.join(data_path, 'all_pairs.txt'),
                             encoding="utf-8"),
                        delimiter=" ")
    qrels = {}
    for id, row in enumerate(reader):
        query_id, corpus_id, score = row[0], row[1], int(row[2])
        if query_id not in qrels:
            qrels[query_id] = {corpus_id: score}
        else:
            qrels[query_id][corpus_id] = score

    logging.info("| |ICT_dataset|={} pairs.".format(len(train_data)))

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, label_dataloader, reg_dataloader, instance_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, label_dataloader, reg_dataloader,
        instance_dataloader)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    # args.max_train_steps = 100000
    args.num_train_epochs = math.ceil(args.max_train_steps /
                                      num_update_steps_per_epoch)
    args.num_warmup_steps = int(0.1 * args.max_train_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_data)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Learning Rate = {args.learning_rate}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0

    from torch.cuda.amp import autocast
    scaler = torch.cuda.amp.GradScaler()

    cluster_result = eval_and_cluster(args, logger, completed_steps,
                                      accelerator.unwrap_model(model),
                                      label_dataloader, label_ids,
                                      instance_dataloader, inst_num, test_ids,
                                      qrels, accelerator)
    reg_iter = iter(reg_dataloader)
    trial_name = f"dim-{args.proj_emb_dim}-bs-{args.per_device_train_batch_size}-{args.dataset}-{args.log}-{args.mode}"

    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t for t in batch)
            label_tokens, inst_tokens, indices = batch
            if args.mode == 'ict':
                try:
                    reg_data = next(reg_iter)
                except StopIteration:
                    reg_iter = iter(reg_dataloader)
                    reg_data = next(reg_iter)
            if cluster_result is not None:
                pseudo_labels = cluster_result[indices]
            else:
                pseudo_labels = indices

            with autocast():
                if args.mode == 'ict':
                    label_emb, inst_emb, inst_emb_aug, reg_emb = model(
                        label_tokens, inst_tokens, reg_data)
                    loss, stats_dict = loss_function_reg(label_emb, inst_emb,
                                                         inst_emb_aug, reg_emb,
                                                         pseudo_labels,
                                                         accelerator)
                else:
                    label_emb, inst_emb = model(label_tokens, inst_tokens,
                                                reg_data=None)
                    loss, stats_dict = loss_function(label_emb, inst_emb,
                                                     pseudo_labels, accelerator)
                loss = loss / args.gradient_accumulation_steps

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            if step % args.gradient_accumulation_steps == 0 or step == len(
                    train_dataloader) - 1:
                scaler.step(optimizer)
                scaler.update()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps % args.logging_steps == 0:
                if args.mode == 'ict':
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e} Contrast Loss {:.6e} Reg Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                            stats_dict["contrast_loss"].item(),
                            stats_dict["reg_loss"].item(),
                        ))
                else:
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                        ))

            if completed_steps % args.eval_steps == 0:
                cluster_result = eval_and_cluster(
                    args, logger, completed_steps,
                    accelerator.unwrap_model(model),
                    label_dataloader, label_ids, instance_dataloader, inst_num,
                    test_ids, qrels, accelerator)
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.label_encoder.save(
                    f"{args.output_dir}/{trial_name}/label_encoder")
                unwrapped_model.instance_encoder.save(
                    f"{args.output_dir}/{trial_name}/instance_encoder")

            if completed_steps >= args.max_train_steps:
                break