Example #1
def build_encoder(
    model_name_or_path,
    max_seq_length,
    pooling_mode,
    proj_emb_dim,
    drop_prob=0.1,
):
    base_layer = sent_trans.models.Transformer(model_name_or_path,
                                               max_seq_length=max_seq_length)
    pooling_layer = sent_trans.models.Pooling(
        base_layer.get_word_embedding_dimension(),
        pooling_mode=pooling_mode,
    )
    dense_layer = sent_trans.models.Dense(
        in_features=pooling_layer.get_sentence_embedding_dimension(),
        out_features=proj_emb_dim,
        activation_function=nn.Tanh(),
    )
    # normalize_layer = sent_trans.models.LayerNorm(proj_emb_dim)
    normalize_layer = sent_trans.models.Normalize()
    dropout_layer = sent_trans.models.Dropout(dropout=drop_prob)
    proj_layer = sent_trans.models.Dense(
        in_features=512,
        out_features=128,
        activation_function=nn.Tanh(),
    )
    # encoder = sent_trans.SentenceTransformer(
    #     modules=[base_layer, pooling_layer, dense_layer, normalize_layer, dropout_layer],
    # )
    encoder = sent_trans.SentenceTransformer(
        modules=[base_layer, pooling_layer, dense_layer],
    )
    return encoder
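A minimal usage sketch for build_encoder (the checkpoint name and dimensions below are placeholders; `sent_trans` is assumed to be `sentence_transformers` imported under that alias):

import sentence_transformers as sent_trans
from torch import nn

encoder = build_encoder(
    'sentence-transformers/paraphrase-mpnet-base-v2',  # placeholder checkpoint
    max_seq_length=64,
    pooling_mode='mean',
    proj_emb_dim=256,
)
embeddings = encoder.encode(['a short query', 'another query'])
print(embeddings.shape)  # (2, 256): one projected vector per input sentence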
Example #2
    def load_model(self):
        if self.model is not None:
            return
        model_dir = self.download()
        self.model = sentence_transformers.SentenceTransformer(
            model_name_or_path=str(model_dir), device=get_device(use_gpu=True)
        )
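`self.download()` and `get_device` are project-specific helpers not shown in the snippet; a plausible sketch of `get_device` (hypothetical, assuming it simply selects CUDA when requested and available):

import torch

def get_device(use_gpu: bool = True) -> str:
    """Pick 'cuda' when a GPU is requested and available, otherwise fall back to 'cpu'."""
    return 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'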
Example #3
    def calc_scores(self):
        super().calc_scores()

        # write input_tsv
        model_name = 'bert-large-nli-stsb-mean-tokens' # FIXME - hard coded
        model = sentence_transformers.SentenceTransformer(model_name)
        resp_list = []

        # read resps
        with open(self.config['input_path'], 'r+', encoding='utf-8') as f_in:
            reader = csv.DictReader(f_in)
            resp_keys = sorted([s for s in reader.fieldnames if
                                s.startswith('resp_') and utils.represents_int(s.split('resp_')[-1])])
            for row in reader:
                resps = [v for k, v in row.items() if k in resp_keys]
                resp_list += resps

        # calc embeds
        embeds = np.array(model.encode(resp_list)) # [num_sets * samples_per_set, embed_dim]
        assert len(embeds.shape) == 2
        assert embeds.shape[0] == self.config['num_sets'] * self.config['samples_per_set']
        embeds = np.reshape(embeds, [self.config['num_sets'], self.config['samples_per_set'], -1])

        # write a cache file compatible with the ordering in bert_score and bert_sts
        similarity_scores_list = [] # note: len() assertions are done in the get_similarity_scores method
        for set_i in range(self.config['num_sets']):
            for sample_i in range(self.config['samples_per_set']):
                for sample_j in range(sample_i):
                    similarity_scores_list.append(self.similarity_metric(
                        embeds[set_i, sample_i, :], embeds[set_i, sample_j, :]))
        with open(self.config['cache_file'], 'w') as cache_f:
            for score in similarity_scores_list:
                cache_f.write('{:0.3f}\n'.format(score))
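The snippet relies on a `self.similarity_metric` callable that is not shown here; a cosine-similarity sketch of what such a metric might look like (hypothetical helper, not part of the original class):

import numpy as np

def cosine_similarity_metric(vec_a, vec_b):
    """Cosine similarity between two 1-D embedding vectors."""
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return float(np.dot(vec_a, vec_b) / denom) if denom else 0.0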
Example #4
def get_model(local_rank=0) -> st.SentenceTransformer:
    word_embedding_model: st.BERT = st.models.BERT('bert-base-uncased')
    pooling_model = st.models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False)
    return st.SentenceTransformer(
        modules=[word_embedding_model, pooling_model], local_rank=local_rank)
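Note that `st.models.BERT` and the `local_rank` constructor argument come from an older, patched sentence-transformers variant; recent releases expose the same functionality via `st.models.Transformer` and handle device placement differently. A usage sketch under that assumption:

model = get_model(local_rank=0)
embeddings = model.encode(['sentence one', 'sentence two'])
print(embeddings.shape)  # (2, 768) for bert-base-uncased with mean pooling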
Example #5
def main():
    args = parse_args()

    dataset_path = 'examples/datasets/iambot-wikipedia-sections-triplets-all'

    output_path = 'output/bert-base-wikipedia-sections-mean-tokens-' + \
                  datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    batch_size = 17
    num_epochs = 1

    is_distributed = torch.cuda.device_count() > 1 and args.local_rank >= 0

    if is_distributed:
        torch.distributed.init_process_group(backend='nccl')

    model = get_model(local_rank=args.local_rank)

    logging.info('Read Triplet train dataset')
    train_data = get_triplet_dataset(dataset_path, 'train.csv', model)
    train_dataloader = get_data_loader(dataset=train_data,
                                       shuffle=True,
                                       batch_size=batch_size,
                                       distributed=is_distributed)

    logging.info('Read Wikipedia Triplet dev dataset')
    dev_dataloader = get_data_loader(dataset=get_triplet_dataset(
        dataset_path, 'validation.csv', model, 1000),
                                     shuffle=False,
                                     batch_size=batch_size)
    evaluator = TripletEvaluator(dev_dataloader)

    warmup_steps = int(len(train_data) * num_epochs / batch_size * 0.1)

    loss = st.losses.TripletLoss(model=model)

    model.fit(train_objectives=[(train_dataloader, loss)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              output_path=output_path,
              local_rank=args.local_rank)

    if args.local_rank == 0 or not is_distributed:
        del model
        torch.cuda.empty_cache()

        model = st.SentenceTransformer(output_path)
        test_data = get_triplet_dataset(dataset_path, 'test.csv', model)
        test_dataloader = get_data_loader(test_data,
                                          shuffle=False,
                                          batch_size=batch_size)
        evaluator = TripletEvaluator(test_dataloader)

        model.evaluate(evaluator)
Example #6
    def collect_sents(self, ranked_dates, collection, vectorizer,
                      include_titles):

        embedding_model = sentence_transformers.SentenceTransformer(
            'roberta-large-nli-stsb-mean-tokens')
        embedding_model.max_seq_length = 256

        date_to_pub, date_to_ment = self._first_pass(collection,
                                                     include_titles)

        for d, sents in self._second_pass(ranked_dates, date_to_pub,
                                          date_to_ment, embedding_model):
            yield d, sents
Example #7
    def __init__(self, language: str, embeddings_path='', dataset_split='val'):
        """initializes metric for language

        Args:
            language (str): german or english
        """

        assert language in ["english", "german"]
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.transformer = sentence_transformers.SentenceTransformer(
            "bert-base-nli-stsb-mean-tokens").to(self.device)
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

        if language == "english":
            print("Initialized semantic similarity metric for English texts.")
        else:
            print("Initialized semantic similarity metric for German texts.")
        self.saved_embeddings = self.init_embeddings(embeddings_path,
                                                     dataset_split)
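A sketch of how these attributes would typically be combined to score a pair of texts (hypothetical method, not part of the original snippet):

    def similarity(self, text_a: str, text_b: str) -> float:
        # Encode both texts and compare them with the cosine-similarity module set up in __init__.
        emb_a = torch.as_tensor(self.transformer.encode([text_a])).to(self.device)
        emb_b = torch.as_tensor(self.transformer.encode([text_b])).to(self.device)
        return self.cosine_similarity(emb_a, emb_b).item()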
Example #8
    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
        self.device = device
        self.model = st.SentenceTransformer(model_name, device=device)
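A minimal usage sketch; `Embedder` is a hypothetical name for the class this __init__ belongs to:

embedder = Embedder(model_name='all-MiniLM-L6-v2', device='cpu')  # hypothetical class name
vectors = embedder.model.encode(['hello world'], convert_to_numpy=True)
print(vectors.shape)  # (1, 384) for all-MiniLM-L6-v2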
Example #9
def load_textclassification_data(dataname, vecname='glove.42B.300d', shuffle=True,
            random_seed=None, num_workers=0, preembed_sentences=True,
            loading_method='sentence_transformers', device='cpu',
            embedding_model=None, batch_size=16, valid_size=0.1,
            maxsize=None, print_stats=False):
    """ Load torchtext datasets.

    Note: torchtext's TextClassification datasets are a bit different from the others:
        - they don't have split method.
        - no obvious creation of (nor access to) fields

    """



    def batch_processor_tt(batch, TEXT=None, sentemb=None, return_lengths=True, device=None):
        """ For torchtext data/models """
        labels, texts = zip(*batch)
        lens = [len(t) for t in texts]
        labels = torch.Tensor(labels)
        pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
        texttensor = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=pad_idx)
        if sentemb:
            texttensor = sentemb(texttensor)
        if return_lengths:
            return texttensor, labels, lens
        else:
            return texttensor, labels

    def batch_processor_st(batch, model, device=None):
        """ For sentence_transformers data/models """
        device = process_device_arg(device)
        with torch.no_grad():
            batch = model.smart_batching_collate(batch)
            ## Always run embedding model on gpu if available
            features, labels = st.util.batch_to_device(batch, device)
            emb = model(features[0])['sentence_embedding']
        return emb, labels


    if shuffle and random_seed:
        np.random.seed(random_seed)

    debug = False

    dataroot = '/tmp/' if debug else DATA_DIR #os.path.join(ROOT_DIR, 'data')
    veccache = os.path.join(dataroot,'.vector_cache')

    if loading_method == 'torchtext':
        ## TextClassification object datasets already do word to token mapping inside.
        DATASET = getattr(torchtext.datasets, dataname)
        train, test = DATASET(root=dataroot, ngrams=1)

        ## load_vectors reindexes embeddings so that they match the vocab's itos indices.
        train._vocab.load_vectors(vecname,cache=veccache,max_vectors = 50000)
        test._vocab.load_vectors(vecname,cache=veccache, max_vectors = 50000)

        ## Define Fields for Text and Labels
        text_field = torchtext.data.Field(sequential=True, lower=True,
                           tokenize=get_tokenizer("basic_english"),
                           batch_first=True,
                           include_lengths=True,
                           use_vocab=True)

        text_field.vocab = train._vocab

        if preembed_sentences:
            ## This will be used for distance computation
            vsize = len(text_field.vocab)
            edim  = text_field.vocab.vectors.shape[1]
            pidx  = text_field.vocab.stoi[text_field.pad_token]
            sentembedder = BoWSentenceEmbedding(vsize, edim, text_field.vocab.vectors, pidx)
            batch_processor = partial(batch_processor_tt,TEXT=text_field,sentemb=sentembedder,return_lengths=False)
        else:
            batch_processor = partial(batch_processor_tt,TEXT=text_field,return_lengths=True)
    elif loading_method == 'sentence_transformers':
        import sentence_transformers as st
        dpath  = os.path.join(dataroot,TEXTDATA_PATHS[dataname])
        reader = st.readers.LabelSentenceReader(dpath)
        if embedding_model is None:
            model  = st.SentenceTransformer('distilbert-base-nli-stsb-mean-tokens').eval()
        elif type(embedding_model) is str:
            model  = st.SentenceTransformer(embedding_model).eval()
        elif isinstance(embedding_model, st.SentenceTransformer):
            model = embedding_model.eval()
        else:
            raise ValueError('embedding model has wrong type')
        print('Reading and embedding {} train data...'.format(dataname))
        train  = st.SentencesDataset(reader.get_examples('train.tsv'), model=model)
        train.targets = train.labels
        print('Reading and embedding {} test data...'.format(dataname))
        test   = st.SentencesDataset(reader.get_examples('test.tsv'), model=model)
        test.targets = test.labels
        if preembed_sentences:
            batch_processor = partial(batch_processor_st, model=model, device=device)
        else:
            batch_processor = None

    ## Seems like torchtext already maps class ids to 0...n-1. Adapt class names to account for this.
    classes = torchtext.datasets.text_classification.LABELS[dataname]
    classes = [classes[k+1] for k in range(len(classes))]
    train.classes = classes
    test.classes  = classes

    train_idx, valid_idx = random_index_split(len(train), 1-valid_size, (maxsize, None)) # No maxsize for validation
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    dataloader_args = dict(batch_size=batch_size,num_workers=num_workers,collate_fn=batch_processor)
    train_loader = dataloader.DataLoader(train, sampler=train_sampler,**dataloader_args)
    valid_loader = dataloader.DataLoader(train, sampler=valid_sampler,**dataloader_args)
    dataloader_args['shuffle'] = False
    test_loader  = dataloader.DataLoader(test, **dataloader_args)

    if print_stats:
        print('Classes: {} (effective: {})'.format(len(train.classes), len(torch.unique(train.targets))))
        print('Fold Sizes: {}/{}/{} (train/valid/test)'.format(len(train_idx), len(valid_idx), len(test)))

    return train_loader, valid_loader, test_loader, train, test
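A hedged usage sketch (assumes the surrounding project defines DATA_DIR and TEXTDATA_PATHS for the chosen dataset; 'AG_NEWS' is used here as a placeholder dataset name):

train_loader, valid_loader, test_loader, train, test = load_textclassification_data(
    'AG_NEWS',
    loading_method='sentence_transformers',
    batch_size=32,
    valid_size=0.1,
    print_stats=True,
)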
Example #10
logger.info(sent_trans.__file__)

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Load pretrained model and tokenizer
if args.model_name_or_path in (
        'bert-base-uncased', 'sentence-transformers/paraphrase-mpnet-base-v2'):
    label_encoder = build_encoder(
        args.model_name_or_path,
        args.max_label_length,
        args.pooling_mode,
        args.proj_emb_dim,
    )
else:
    label_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

tokenizer = label_encoder._first_module().tokenizer

instance_encoder = label_encoder

model = DualEncoderModel(
    label_encoder,
    instance_encoder,
)
model = model.to(device)

# the whole label set
data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset', args.dataset)
all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
label_list = list(all_labels.title)
Example #11
import numpy
import sentence_transformers.models
import sklearn.metrics.pairwise
import sentence_transformers

cmb = sentence_transformers.models.CamemBERT('camembert-base')
pooling_model = sentence_transformers.models.Pooling(
    cmb.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False)

model = sentence_transformers.SentenceTransformer(modules=[cmb, pooling_model])


def sentences_embeddings(sentences):
    return model.encode(sentences)


def doc_sentence_sim(base_sentence_vecs, doc_sentences):
    return numpy.average(
        sklearn.metrics.pairwise.cosine_similarity(
            sentences_embeddings(doc_sentences), base_sentence_vecs))
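A short usage sketch for the two helpers above (the sentences are placeholders; CamemBERT expects French input):

base_vecs = sentences_embeddings(["Ceci est une phrase de référence."])
score = doc_sentence_sim(base_vecs, ["Une phrase du document.", "Une autre phrase."])
print(score)  # mean cosine similarity between the document sentences and the reference sentence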
Example #12
#from sentence_transformers import SentenceTransformer
import sentence_transformers
import scipy
#from sklearn.metrics.pairwise import cosine_similarity

model = sentence_transformers.SentenceTransformer('bert-base-nli-mean-tokens')


def sent_similarity(col_nouns, bok_noun):
    bok_nouns = bok_noun[0]
    sentence_embeddings_L1 = model.encode(bok_nouns[0])
    sentence_embeddings_L2 = model.encode(bok_nouns[1])
    sentence_embeddings_L3 = model.encode(bok_nouns[2])
    #sentence_embeddings_L4a = model.encode(bok_nouns[3])
    #sentence_embeddings_L4b = model.encode(bok_nouns[4])

    #sent_embed_bok_levels = [sentence_embeddings_L1, sentence_embeddings_L2, sentence_embeddings_L3, sentence_embeddings_L4a, sentence_embeddings_L4b]
    sent_embed_bok_levels = [
        sentence_embeddings_L1, sentence_embeddings_L2, sentence_embeddings_L3
    ]
    noun_levels = []
    unClassified = []
    for i in range(len(col_nouns)):
        tier_nouns = [[], [], [], [], []]
        for query in col_nouns[i]:
            queries = [query]
            query_embeddings = model.encode(queries)
            for query, query_embedding in zip(queries, query_embeddings):
                prev_dist = 0
                isClassified = False
                for indx, sentence_embeddings in enumerate(
Example #13
    def __init__(self, model_name_or_path, device=None):

        self.senttransf_model = sentence_transformers.SentenceTransformer(
            str(model_name_or_path), device=device
        )
Example #14
def init_models(model_name: str, device='cuda:0') -> st.SentenceTransformer:
    return st.SentenceTransformer(model_name_or_path=model_name, device=device)
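A one-line usage sketch (placeholder model name; pass device='cpu' when no GPU is available):

model = init_models('all-MiniLM-L6-v2', device='cpu')
embeddings = model.encode(['query text'])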
Example #15
def main():
    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    args = parse_args()
    distributed_args = accelerate.DistributedDataParallelKwargs(
        find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[distributed_args])
    device = accelerator.device
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        filename=f'xmc_{args.dataset}_{args.mode}_{args.log}.log',
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    ch = logging.StreamHandler(sys.stdout)
    logger.addHandler(ch)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    logger.info(sent_trans.__file__)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    if args.model_name_or_path in (
            'bert-base-uncased', 'sentence-transformers/paraphrase-mpnet-base-v2'):
        query_encoder = build_encoder(
            args.model_name_or_path,
            args.max_label_length,
            args.pooling_mode,
            args.proj_emb_dim,
        )
    else:
        query_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

    tokenizer = query_encoder._first_module().tokenizer

    block_encoder = query_encoder

    model = DualEncoderModel(query_encoder, block_encoder, args.mode)
    model = model.to(device)

    # the whole label set
    data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset',
                             args.dataset)
    all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
    label_list = list(all_labels.title)
    label_ids = list(all_labels.uid)
    label_data = SimpleDataset(label_list, transform=tokenizer.encode)

    # label dataloader for searching
    sampler = SequentialSampler(label_data)
    label_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 64)
    label_dataloader = DataLoader(label_data,
                                  sampler=sampler,
                                  batch_size=16,
                                  collate_fn=label_padding_func)

    # label dataloader for regularization
    reg_sampler = RandomSampler(label_data)
    reg_dataloader = DataLoader(label_data,
                                sampler=reg_sampler,
                                batch_size=4,
                                collate_fn=label_padding_func)

    if args.mode == 'ict':
        train_data = ICTXMCDataset(tokenizer=tokenizer, dataset=args.dataset)
    elif args.mode == 'self-train':
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode)
    elif args.mode == 'finetune-pair':
        train_path = os.path.join(data_path, 'trn.json')
        pos_pair = []
        with open(train_path) as fp:
            for i, line in enumerate(fp):
                inst = json.loads(line.strip())
                inst_id = inst['uid']
                for ind in inst['target_ind']:
                    pos_pair.append((inst_id, ind, i))
        dataset_size = len(pos_pair)
        indices = list(range(dataset_size))
        split = int(np.floor(args.ratio * dataset_size))
        np.random.shuffle(indices)
        train_indices = indices[:split]
        torch.distributed.broadcast_object_list(train_indices,
                                                src=0,
                                                group=None)
        sample_pairs = [pos_pair[i] for i in train_indices]
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)
    elif args.mode == 'finetune-label':
        label_index = []
        label_path = os.path.join(data_path, 'label_index.json')
        with open(label_path) as fp:
            for line in fp:
                label_index.append(json.loads(line.strip()))
        np.random.shuffle(label_index)
        sample_size = int(np.floor(args.ratio * len(label_index)))
        sample_label = label_index[:sample_size]
        torch.distributed.broadcast_object_list(sample_label,
                                                src=0,
                                                group=None)
        sample_pairs = []
        for i, label in enumerate(sample_label):
            ind = label['ind']
            for inst_id in label['instance']:
                sample_pairs.append((inst_id, ind, i))
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)

    train_sampler = RandomSampler(train_data)
    padding_func = lambda x: ICT_batchify(x, tokenizer.pad_token_id, 64, 288)
    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=args.per_device_train_batch_size,
        num_workers=4,
        pin_memory=False,
        collate_fn=padding_func)

    try:
        accelerator.print("load cache")
        all_instances = torch.load(
            os.path.join(data_path, 'all_passages_with_titles.json.cache.pt'))
        test_data = SimpleDataset(all_instances.values())
    except Exception:  # cache missing or unreadable; rebuild the test data from the raw files
        all_instances = {}
        test_path = os.path.join(data_path, 'tst.json')
        if args.mode == 'ict':
            train_path = os.path.join(data_path, 'trn.json')
            train_instances = {}
            valid_passage_ids = train_data.valid_passage_ids
            with open(train_path) as fp:
                for line in fp:
                    inst = json.loads(line.strip())
                    train_instances[
                        inst['uid']] = inst['title'] + '\t' + inst['content']
            for inst_id in valid_passage_ids:
                all_instances[inst_id] = train_instances[inst_id]
        test_ids = []
        with open(test_path) as fp:
            for line in fp:
                inst = json.loads(line.strip())
                all_instances[
                    inst['uid']] = inst['title'] + '\t' + inst['content']
                test_ids.append(inst['uid'])
        simple_transform = lambda x: tokenizer.encode(
            x, max_length=288, truncation=True)
        test_data = SimpleDataset(list(all_instances.values()),
                                  transform=simple_transform)
    inst_num = len(test_data)  # defined on both the cached and the rebuilt path

    sampler = SequentialSampler(test_data)
    sent_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 288)
    instance_dataloader = DataLoader(test_data,
                                     sampler=sampler,
                                     batch_size=128,
                                     collate_fn=sent_padding_func)

    # prepare pairs
    reader = csv.reader(open(os.path.join(data_path, 'all_pairs.txt'),
                             encoding="utf-8"),
                        delimiter=" ")
    qrels = {}
    for id, row in enumerate(reader):
        query_id, corpus_id, score = row[0], row[1], int(row[2])
        if query_id not in qrels:
            qrels[query_id] = {corpus_id: score}
        else:
            qrels[query_id][corpus_id] = score

    logging.info("| |ICT_dataset|={} pairs.".format(len(train_data)))

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=1e-8)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, label_dataloader, reg_dataloader, instance_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, label_dataloader, reg_dataloader,
        instance_dataloader)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    # args.max_train_steps = 100000
    args.num_train_epochs = math.ceil(args.max_train_steps /
                                      num_update_steps_per_epoch)
    args.num_warmup_steps = int(0.1 * args.max_train_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_data)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Learning Rate = {args.learning_rate}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0
    from torch.cuda.amp import autocast
    scaler = torch.cuda.amp.GradScaler()
    cluster_result = eval_and_cluster(args, logger, completed_steps,
                                      accelerator.unwrap_model(model),
                                      label_dataloader, label_ids,
                                      instance_dataloader, inst_num, test_ids,
                                      qrels, accelerator)
    reg_iter = iter(reg_dataloader)
    trial_name = f"dim-{args.proj_emb_dim}-bs-{args.per_device_train_batch_size}-{args.dataset}-{args.log}-{args.mode}"
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t for t in batch)
            label_tokens, inst_tokens, indices = batch
            if args.mode == 'ict':
                try:
                    reg_data = next(reg_iter)
                except StopIteration:
                    reg_iter = iter(reg_dataloader)
                    reg_data = next(reg_iter)

            if cluster_result is not None:
                pseudo_labels = cluster_result[indices]
            else:
                pseudo_labels = indices
            with autocast():
                if args.mode == 'ict':
                    label_emb, inst_emb, inst_emb_aug, reg_emb = model(
                        label_tokens, inst_tokens, reg_data)
                    loss, stats_dict = loss_function_reg(
                        label_emb, inst_emb, inst_emb_aug, reg_emb,
                        pseudo_labels, accelerator)
                else:
                    label_emb, inst_emb = model(label_tokens,
                                                inst_tokens,
                                                reg_data=None)
                    loss, stats_dict = loss_function(label_emb, inst_emb,
                                                     pseudo_labels,
                                                     accelerator)
                loss = loss / args.gradient_accumulation_steps

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            if step % args.gradient_accumulation_steps == 0 or step == len(
                    train_dataloader) - 1:
                scaler.step(optimizer)
                scaler.update()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps % args.logging_steps == 0:
                if args.mode == 'ict':
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}  Contrast Loss {:.6e}  Reg Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                            stats_dict["contrast_loss"].item(),
                            stats_dict["reg_loss"].item(),
                        ))
                else:
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                        ))
            if completed_steps % args.eval_steps == 0:
                cluster_result = eval_and_cluster(
                    args, logger, completed_steps,
                    accelerator.unwrap_model(model), label_dataloader,
                    label_ids, instance_dataloader, inst_num, test_ids, qrels,
                    accelerator)
                unwrapped_model = accelerator.unwrap_model(model)

                unwrapped_model.label_encoder.save(
                    f"{args.output_dir}/{trial_name}/label_encoder")
                unwrapped_model.instance_encoder.save(
                    f"{args.output_dir}/{trial_name}/instance_encoder")

            if completed_steps >= args.max_train_steps:
                break