Code example #1
0
File: oteann.py — Project: marxav/oteann2
def train(config):
    """Train a character-level GPT on the text file named in *config*.

    Args:
        config: dict expected to contain the keys 'block_size',
            'train_filename', 'chars', 'n_layer', 'n_head', 'n_embd',
            'batch_size', and 'model_filename'.

    Returns:
        The trained GPT model. Its ``state_dict`` is also written to
        ``config['model_filename']``.
    """
    training_t0 = datetime.datetime.now()

    block_size = config['block_size']

    print("config['train_filename']:", config['train_filename'])
    # Use a context manager so the file handle is closed deterministically
    # (the original `open(...).read()` never closed the file).
    with open(config['train_filename'], 'r') as f:
        text = f.read()
    train_dataset = CharDataset(config['chars'], text, block_size,
                                debug=True)  # one line is 63 characters

    # create model sized to the dataset's vocabulary and context length
    mconf = GPTConfig(train_dataset.vocab_size,
                      train_dataset.block_size,
                      n_layer=config['n_layer'],
                      n_head=config['n_head'],
                      n_embd=config['n_embd'])

    model = GPT(mconf)

    # Count only trainable parameters for the summary line.
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print('ANN parameters: %d' % pytorch_total_params)

    # train: 2 epochs with LR warmup then decay over the full token budget
    tconf = TrainerConfig(max_epochs=2,
                          batch_size=config['batch_size'],
                          learning_rate=6e-4,
                          lr_decay=True,
                          warmup_tokens=512 * 20,
                          final_tokens=2 * len(train_dataset) * block_size,
                          num_workers=4,
                          tqdm=False)  # not config['do_finetune']
    trainer = Trainer(model, train_dataset, None, tconf)
    trainer.train()
    training_t1 = datetime.datetime.now()
    training_duration = training_t1 - training_t0
    print('training_duration', training_duration)

    # Persist weights only (state_dict), not the full model object.
    torch.save(model.state_dict(), config['model_filename'])

    return model
Code example #2
0
# Build an 8-layer, 8-head GPT sized to the dataset's vocabulary and
# context length.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.block_size,
    n_layer=8,
    n_head=8,
    n_embd=512,
)
model = GPT(mconf)

from mingpt.trainer import Trainer, TrainerConfig

# Configure and kick off training: 4 epochs, LR warmup then decay over
# the full token budget.
trainer_settings = dict(
    max_epochs=4,
    batch_size=512,
    learning_rate=6e-4,
    lr_decay=True,
    warmup_tokens=512 * 20,
    final_tokens=2 * len(train_dataset) * block_size,
    num_workers=4,
)
tconf = TrainerConfig(**trainer_settings)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

# Persist the trained weights (state_dict only).
torch.save(model.state_dict(), "model.pth")

# alright, let's sample some character-level Shakespeare
from mingpt.utils import sample

context = "{Name}:\tDemo Text"
# Encode the prompt with the dataset's char->index table, then add a
# leading batch dimension and move it to the trainer's device.
encoded = [train_dataset.stoi[s] for s in context]
x = torch.tensor(encoded, dtype=torch.long)[None, ...].to(trainer.device)
y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
completion = ''.join(train_dataset.itos[int(i)] for i in y)
print(completion)