import logging

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sentence_transformers import SentencesDataset, SentenceTransformer, models
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import TripletDistanceMetric, TripletLoss
from sentence_transformers.readers import TripletReader


def test(conf: "TrainConfig"):
    logger = logging.getLogger(__name__)
    logger.info("Start test")

    # Load the fine-tuned model and swap in the tokenizer of the base transformer
    model = SentenceTransformer(str(conf.model_dir))
    model.tokenizer = AutoTokenizer.from_pretrained(conf.transformer_model)
    logger.info(f"model: {type(model)}")
    logger.info(f"tokenizer: {type(model.tokenizer)}")

    # Sanity check: how the tokenizer splits a Japanese sentence
    encode_result = model.tokenizer(["日本語のトークナイゼーションの問題"], return_tensors='pt', padding=True)
    logger.info(model.tokenizer.convert_ids_to_tokens(encode_result.input_ids.flatten().tolist()))

    # Evaluate triplet accuracy on the test split
    triplet_reader = TripletReader(str(conf.train_triplets_tsv.parent))
    evaluator = TripletEvaluator.from_input_examples(
        triplet_reader.get_examples(conf.test_triplets_tsv.name), name="test"
    )
    evaluator(model, output_path=str(conf.model_dir))
def train(conf: "TrainConfig"):
    logger = logging.getLogger(__name__)
    logger.info("Initialize model")

    # Build a SentenceTransformer from a transformer encoder plus mean pooling
    transformer = models.Transformer(conf.transformer_model)
    pooling = models.Pooling(
        transformer.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    model = SentenceTransformer(modules=[transformer, pooling])
    model.tokenizer = AutoTokenizer.from_pretrained(conf.transformer_model)
    logger.info(f"model: {type(model)}")
    logger.info(f"tokenizer: {type(model.tokenizer)}")

    # Sanity check: how the tokenizer splits a Japanese sentence
    encode_result = model.tokenizer(["日本語のトークナイゼーションの問題"], return_tensors='pt', padding=True)
    logger.info(model.tokenizer.convert_ids_to_tokens(encode_result.input_ids.flatten().tolist()))

    logger.info("Read training data")
    triplet_reader = TripletReader(str(conf.train_triplets_tsv.parent))
    train_data = SentencesDataset(
        triplet_reader.get_examples(conf.train_triplets_tsv.name), model=model
    )
    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=conf.batch_size)
    train_loss = TripletLoss(
        model=model, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin=1
    )
    evaluator = TripletEvaluator.from_input_examples(
        triplet_reader.get_examples(conf.dev_triplets_tsv.name), name="dev"
    )

    logger.info("Start training")
    # Warm up the learning rate over the first 10% of one epoch's steps
    warmup_steps = int(len(train_data) // conf.batch_size * 0.1)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=conf.epochs,
        evaluation_steps=conf.eval_steps,
        warmup_steps=warmup_steps,
        output_path=str(conf.model_dir),
    )
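The train and test functions above take a TrainConfig object whose definition is not shown here. The sketch below is one possible shape, inferred only from the attributes the functions access (model_dir, transformer_model, the triplet TSV paths, batch_size, epochs, eval_steps); the default values, including the base model name, are illustrative assumptions.

from dataclasses import dataclass
from pathlib import Path


# Hypothetical config container: attribute names are taken from train()/test()
# above, but every default value is a placeholder, not the author's setting.
@dataclass
class TrainConfig:
    transformer_model: str = "cl-tohoku/bert-base-japanese-whole-word-masking"  # any Japanese BERT checkpoint
    model_dir: Path = Path("output/sbert")
    train_triplets_tsv: Path = Path("data/train_triplets.tsv")
    dev_triplets_tsv: Path = Path("data/dev_triplets.tsv")
    test_triplets_tsv: Path = Path("data/test_triplets.tsv")
    batch_size: int = 16
    epochs: int = 1
    eval_steps: int = 1000

With a config like this, the whole pipeline is simply train(conf) followed by test(conf).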
                              shuffle=True, batch_size=train_batch_size)

### Triplet losses ####################
### There are 3 triplet loss variants:
### - BatchHardTripletLoss
### - BatchHardSoftMarginTripletLoss
### - BatchSemiHardTripletLoss
#######################################
# train_loss = losses.BatchHardTripletLoss(sentence_embedder=model)
# train_loss = losses.BatchHardSoftMarginTripletLoss(sentence_embedder=model)
train_loss = losses.BatchSemiHardTripletLoss(sentence_embedder=model)

logging.info("Read TREC val dataset")
dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='dev')

logging.info("Performance before fine-tuning:")
dev_evaluator(model)

warmup_steps = int(len(train_dataset) * num_epochs / train_batch_size * 0.1)  # 10% of train data

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    output_path=output_path,
)
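Unlike losses.TripletLoss, which consumes explicit (anchor, positive, negative) triplets, the Batch*TripletLoss variants above mine triplets inside each batch from single sentences tagged with a class label. A minimal sketch of the input format they expect; the sentences and labels here are made up for illustration.

from sentence_transformers import InputExample

# Each example is one sentence plus an integer class label; the batch-hard
# losses form triplets from same-label / different-label pairs within a batch.
train_examples = [
    InputExample(texts=["What is the capital of France?"], label=0),
    InputExample(texts=["Name the capital city of France."], label=0),
    InputExample(texts=["How tall is Mount Everest?"], label=1),
]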
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

logging.info("Read Triplet train dataset")
train_dataset = SentencesDataset(
    examples=triplet_reader.get_examples('train.csv', max_examples=100000), model=model
)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.TripletLoss(model=model)

logging.info("Read Wikipedia Triplet dev dataset")
evaluator = TripletEvaluator.from_input_examples(
    triplet_reader.get_examples('validation.csv', 1000), name='dev'
)

warmup_steps = int(len(train_dataset) * num_epochs / train_batch_size * 0.1)  # 10% of train data

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=output_path)

##############################################################################
#
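Once training finishes, the model saved at output_path can be loaded back and used directly for encoding. A minimal sketch; the sentences are illustrative only.

from sentence_transformers import SentenceTransformer

# Load the fine-tuned model from the training output directory
trained_model = SentenceTransformer(output_path)
embeddings = trained_model.encode(["This is a sample sentence.", "This is another one."])
print(embeddings.shape)  # (2, embedding_dim)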