def __init__(self, device='cpu', model=None):
    # Encoder: small BERT (4 layers, 4 heads, hidden size 128).
    vocabsize = 37
    max_length = 50
    encoder_config = BertConfig(vocab_size=vocabsize,
                                max_position_embeddings=max_length + 64,
                                num_attention_heads=4,
                                num_hidden_layers=4,
                                hidden_size=128,
                                type_vocab_size=1)
    encoder = BertModel(config=encoder_config)

    # Decoder: same architecture, configured as a causal LM with
    # cross-attention over the encoder outputs.
    vocabsize = 33
    max_length = 50
    decoder_config = BertConfig(vocab_size=vocabsize,
                                max_position_embeddings=max_length + 64,
                                num_attention_heads=4,
                                num_hidden_layers=4,
                                hidden_size=128,
                                type_vocab_size=1,
                                is_decoder=True,
                                add_cross_attention=True)
    decoder = BertLMHeadModel(config=decoder_config)

    # Tie encoder and decoder together into a seq2seq model.
    self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
    self.model.to(device)
    self.device = device
    if model is not None:
        self.model.load_state_dict(torch.load(model))
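# Hypothetical inference helper, not part of the original snippet: a sketch of
# how the wrapped EncoderDecoderModel could be used for generation. The special
# token ids below (pad=0, start=1, eos=2) are assumptions for illustration,
# not values taken from the source.
def generate(self, input_ids, max_length=50):
    self.model.config.decoder_start_token_id = 1  # assumed start-of-sequence id
    self.model.config.pad_token_id = 0            # assumed padding id
    self.model.config.eos_token_id = 2            # assumed end-of-sequence id
    return self.model.generate(input_ids=input_ids.to(self.device),
                               max_length=max_length)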
def train_model(config_path: str):
    writer = SummaryWriter()
    config = read_training_pipeline_params(config_path)
    logger.info("pretrained_emb {b}", b=config.net_params.pretrained_emb)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info("Device is {device}", device=device)

    # Build the dataset, split it, and construct the vocabularies.
    SRC, TRG, dataset = get_dataset(config.dataset_path, False)
    train_data, valid_data, test_data = split_data(
        dataset, **config.split_ration.__dict__)
    SRC.build_vocab(train_data, min_freq=3)
    TRG.build_vocab(train_data, min_freq=3)
    torch.save(SRC.vocab, config.src_vocab_name)
    torch.save(TRG.vocab, config.trg_vocab_name)
    logger.info("Vocab saved")
    print(f"Unique tokens in source (ru) vocabulary: {len(SRC.vocab)}")
    print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")

    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=config.BATCH_SIZE,
        device=device,
        sort_key=_len_sort_key,
    )

    INPUT_DIM = len(SRC.vocab)
    OUTPUT_DIM = len(TRG.vocab)

    # BERT2BERT: the decoder config must be marked as a decoder and given
    # cross-attention before the EncoderDecoderModel is instantiated.
    config_encoder = BertConfig(vocab_size=INPUT_DIM)
    config_decoder = BertConfig(vocab_size=OUTPUT_DIM)
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True
    model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config_encoder, config_decoder)
    model = EncoderDecoderModel(config=model_config)

    args = TrainingArguments(
        output_dir="output",
        evaluation_strategy="steps",
        eval_steps=500,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=10,
        save_steps=3000,
        seed=0,
        load_best_model_at_end=True,
    )
    # args.place_model_on_device = device
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_iterator,
        eval_dataset=valid_iterator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.train()
    model.save_pretrained("bert2bert")
def get_model(vocab_size=30000):
    # Symmetric BERT2BERT: encoder and decoder share the same vocabulary size;
    # the decoder additionally needs is_decoder and add_cross_attention set.
    config_encoder = BertConfig()
    config_decoder = BertConfig()
    config_encoder.vocab_size = vocab_size
    config_decoder.vocab_size = vocab_size
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config_encoder, config_decoder)
    model = EncoderDecoderModel(config=config)
    return model
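# Usage sketch (my own illustration, not from the original code): run a forward
# pass with dummy token ids. Recent transformers versions derive
# decoder_input_ids from `labels`, which requires decoder_start_token_id and
# pad_token_id to be set; the ids 101 and 0 below are assumptions based on the
# standard English BERT tokenizer.
import torch

model = get_model(vocab_size=30000)
model.config.decoder_start_token_id = 101  # assumed [CLS]-style start token
model.config.pad_token_id = 0              # assumed [PAD] id

src = torch.randint(0, 30000, (2, 16))     # (batch, source length)
tgt = torch.randint(0, 30000, (2, 12))     # (batch, target length)
outputs = model(input_ids=src, labels=tgt)
print(outputs.loss, outputs.logits.shape)  # logits: (2, 12, 30000)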
def __init__(self):
    super().__init__()
    encoder_config = BertConfig(num_hidden_layers=6, vocab_size=30522,
                                hidden_size=512, num_attention_heads=8)
    self.encoder = BertModel(encoder_config)

    decoder_config = BertConfig(num_hidden_layers=6, vocab_size=30522,
                                hidden_size=512, num_attention_heads=8)
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    self.decoder = BertModel(decoder_config)

    # Project decoder hidden states to vocabulary logits
    # (vocab_size is 21128 for Chinese BERT, 30522 for English BERT).
    self.linear = nn.Linear(512, 30522, bias=False)
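# Minimal forward sketch (my own addition; the original snippet only shows
# __init__). It assumes the standard BertModel call signature: the decoder,
# configured with add_cross_attention=True, attends to the encoder outputs via
# encoder_hidden_states, and the linear layer maps hidden states to logits.
def forward(self, src_ids, src_mask, tgt_ids, tgt_mask):
    enc_out = self.encoder(input_ids=src_ids,
                           attention_mask=src_mask).last_hidden_state
    dec_out = self.decoder(input_ids=tgt_ids,
                           attention_mask=tgt_mask,
                           encoder_hidden_states=enc_out,
                           encoder_attention_mask=src_mask).last_hidden_state
    return self.linear(dec_out)  # (batch, target length, vocab_size) logits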