Example #1
    # Note: this __init__ belongs to a wrapper class that is not shown in the
    # snippet; it requires torch and, from transformers, BertConfig, BertModel,
    # BertLMHeadModel and EncoderDecoderModel.
    def __init__(self, device='cpu', model=None):
        vocabsize = 37
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        vocabsize = 33
        max_length = 50
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attention=True,
                                    is_decoder=True)
        decoder = BertLMHeadModel(config=decoder_config)

        # Define encoder decoder model
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model, map_location=device))
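
A quick usage sketch, assuming the enclosing class is named Seq2SeqWrapper (hypothetical; the snippet does not show the class) and that the model argument is a path to a checkpoint saved with torch.save(model.state_dict(), ...):

# Seq2SeqWrapper is a hypothetical name for the unnamed enclosing class.
wrapper = Seq2SeqWrapper(device='cpu')
# Resume from a state-dict checkpoint:
wrapper = Seq2SeqWrapper(device='cpu', model='checkpoint.pt')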
Example #2
# Project-local helpers (read_training_pipeline_params, get_dataset, split_data,
# _len_sort_key, logger) come from the surrounding repository and are not shown here.
import torch
from torch.utils.tensorboard import SummaryWriter
from torchtext.data import BucketIterator  # torchtext.legacy.data on newer torchtext
from transformers import (BertConfig, EncoderDecoderConfig, EncoderDecoderModel,
                          EarlyStoppingCallback, Trainer, TrainingArguments)


def train_model(config_path: str):
    writer = SummaryWriter()
    config = read_training_pipeline_params(config_path)
    logger.info("pretrained_emb {b}", b=config.net_params.pretrained_emb)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info("Device is {device}", device=device)
    SRC, TRG, dataset = get_dataset(config.dataset_path, False)
    train_data, valid_data, test_data = split_data(
        dataset, **config.split_ration.__dict__)
    SRC.build_vocab(train_data, min_freq=3)
    TRG.build_vocab(train_data, min_freq=3)
    torch.save(SRC.vocab, config.src_vocab_name)
    torch.save(TRG.vocab, config.trg_vocab_name)
    logger.info("Vocab saved")
    print(f"Unique tokens in source (ru) vocabulary: {len(SRC.vocab)}")
    print(f"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}")
    train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=config.BATCH_SIZE,
        device=device,
        sort_key=_len_sort_key,
    )
    INPUT_DIM = len(SRC.vocab)
    OUTPUT_DIM = len(TRG.vocab)

    # Configure the decoder for cross-attention before building the model once,
    # instead of instantiating EncoderDecoderModel twice. A separate name also
    # avoids shadowing the pipeline config loaded above.
    config_encoder = BertConfig(vocab_size=INPUT_DIM)
    config_decoder = BertConfig(vocab_size=OUTPUT_DIM)
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True
    enc_dec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config_encoder, config_decoder)
    model = EncoderDecoderModel(config=enc_dec_config)
    args = TrainingArguments(
        output_dir="output",
        evaluation_strategy="steps",
        eval_steps=500,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=128,
        num_train_epochs=10,
        save_steps=3000,
        seed=0,
        load_best_model_at_end=True,
    )
    # args.place_model_on_device = device
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_iterator,
        eval_dataset=valid_iterator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.train()

    model.save_pretrained("bert2bert")
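
After training, the saved directory can be reloaded for inference. A minimal sketch, assuming sos_idx is the TRG vocabulary's start-of-sequence index and input_ids is a LongTensor built from SRC.vocab (both hypothetical names; the torchtext numericalization step is elided here):

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_pretrained("bert2bert")
# sos_idx and input_ids are assumptions; see the note above.
generated = model.generate(input_ids,
                           decoder_start_token_id=sos_idx,
                           max_length=50)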
Example #3
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel


def get_model(vocab_size=30000):
    config_encoder = BertConfig()
    config_decoder = BertConfig()

    config_encoder.vocab_size = vocab_size
    config_decoder.vocab_size = vocab_size

    # Configure the decoder for causal masking and cross-attention
    # over the encoder's hidden states.
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True

    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config_encoder, config_decoder)
    model = EncoderDecoderModel(config=config)

    return model
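
A minimal smoke test of this factory on recent transformers versions, using random token ids; passing labels makes the LM head return a cross-entropy loss:

import torch

model = get_model(vocab_size=30000)
ids = torch.randint(0, 30000, (2, 16))  # dummy (batch, seq_len) token ids
out = model(input_ids=ids, decoder_input_ids=ids, labels=ids)
print(out.loss)  # scalar language-modeling loss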
Example #4
    # Note: this __init__ belongs to an nn.Module subclass that is not shown;
    # it requires torch.nn as nn and, from transformers, BertConfig and BertModel.
    def __init__(self):
        super().__init__()
        encoder_config = BertConfig(num_hidden_layers=6,
                                    vocab_size=30522,
                                    hidden_size=512,
                                    num_attention_heads=8)
        self.encoder = BertModel(encoder_config)

        decoder_config = BertConfig(num_hidden_layers=6,
                                    vocab_size=30522,
                                    hidden_size=512,
                                    num_attention_heads=8)

        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        self.decoder = BertModel(decoder_config)

        # Project decoder hidden states to vocabulary logits
        # (vocab size: 30522 for English BERT, 21128 for Chinese BERT)
        self.linear = nn.Linear(512, 30522, bias=False)
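
The snippet only builds the submodules; a plausible forward pass would feed the encoder's hidden states into the decoder's cross-attention and project with the linear head. A minimal sketch (hypothetical; the source does not show this module's forward):

    def forward(self, src_ids, tgt_ids):
        # Encode the source sequence
        enc_out = self.encoder(input_ids=src_ids).last_hidden_state
        # Decode with cross-attention over the encoder hidden states
        dec_out = self.decoder(input_ids=tgt_ids,
                               encoder_hidden_states=enc_out).last_hidden_state
        # Project decoder states to vocabulary logits
        return self.linear(dec_out)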