예제 #1
0
파일: oteann.py 프로젝트: marxav/oteann2
def get_model(config):
    """Rebuild the GPT model described by ``config`` and load its saved weights.

    The training text is re-read only to build a ``CharDataset`` whose
    ``vocab_size`` and ``block_size`` size the model the same way ``train()``
    does; otherwise ``load_state_dict`` would reject mismatched tensor shapes.

    Args:
        config: mapping that must provide ``'block_size'``, ``'train_filename'``,
            ``'chars'``, ``'n_layer'``, ``'n_head'``, ``'n_embd'`` and
            ``'model_filename'``.

    Returns:
        The ``GPT`` model with the weights from ``config['model_filename']``,
        switched to evaluation mode. It is not moved to any device here.
    """
    # following lines mirror the dataset construction in train()
    with open(config['train_filename'], 'r') as train_file:
        text = train_file.read()
    train_dataset = CharDataset(config['chars'], text, config['block_size'])

    mconf = GPTConfig(train_dataset.vocab_size,
                      train_dataset.block_size,
                      n_layer=config['n_layer'],
                      n_head=config['n_head'],
                      n_embd=config['n_embd'])
    model = GPT(mconf)
    # The model has not been moved to a GPU, so map the checkpoint to the CPU.
    # This also lets a checkpoint saved during GPU training load on a CPU-only
    # machine; load_state_dict copies the values into the model's parameters.
    state_dict = torch.load(config['model_filename'], map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
예제 #2
0
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

# The dataset's vocab_size and block_size size the model below, so they must
# match the checkpoint in model.pth or load_state_dict will fail.
block_size = 128  # context window size used by the dataset and the model
with open('chat/all.txt', 'r') as text_file:  # closed as soon as it is read
    text = text_file.read()
train_dataset = CharDataset(text, block_size)

# Load Model
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf).cuda()

# Load weight
model.load_state_dict(torch.load("model.pth"))

from mingpt.utils import sample
def run(context, *, steps=500, temperature=1.0, top_k=10):
    """Sample a continuation of ``context`` from the module-level ``model``.

    Args:
        context: non-empty seed string. Every character must be a key of
            ``train_dataset.stoi``, otherwise ``KeyError`` is raised.
        steps: number of tokens to sample (default 500, as before).
        temperature: passed through to ``sample`` (default 1.0).
        top_k: passed through to ``sample`` (default 10).

    Returns:
        The first sequence returned by ``sample``, decoded with
        ``train_dataset.itos``. It presumably includes the context as a
        prefix, since minGPT's ``sample`` returns the full sequence (confirm).

    Raises:
        ValueError: if ``context`` is empty (there is nothing to condition on).
    """
    if not context:
        raise ValueError("context must be a non-empty string")
    # Put the prompt on whatever device the model's weights live on.
    device = next(model.parameters()).device
    # Encode characters to token ids and add a leading batch dimension of 1.
    x = torch.tensor([train_dataset.stoi[s] for s in context],
                     dtype=torch.long, device=device)[None, ...]
    y = sample(model, x, steps, temperature=temperature, sample=True, top_k=top_k)[0]
    return ''.join(train_dataset.itos[int(i)] for i in y)

# Bot
import telebot
from telebot import types
import configparser
# Settings (presumably for the bot set up below) come from config.ini.
# ConfigParser.read() silently skips a file that does not exist.
config = configparser.ConfigParser()
config.read(filenames='config.ini')
예제 #3
0
tokens_per_epoch = len(train_data) * train_dataset.block_size
train_epochs = 20 # todo run a bigger model and longer, this is tiny
# Single source of truth for the checkpoint file: the trainer writes it and
# the reload below reads it, so the two can no longer drift apart.
CKPT_PATH = 'cifar10_model.pt'


# initialize a trainer instance and kick off training
# warmup spans one epoch's worth of tokens; final_tokens covers all epochs
tconf = TrainerConfig(max_epochs=train_epochs, batch_size=16*8, learning_rate=3e-3,
                      betas=(0.9, 0.95), weight_decay=0,
                      lr_decay=True, warmup_tokens=tokens_per_epoch,
                      final_tokens=train_epochs*tokens_per_epoch,
                      ckpt_path=CKPT_PATH,
                      num_workers=8)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

# reload the checkpoint the trainer wrote to CKPT_PATH
# (presumably the best model seen, based on early stopping — confirm in Trainer)
checkpoint = torch.load(CKPT_PATH)
model.load_state_dict(checkpoint)

# to sample we also have to technically "train" a separate model for the first token in the sequence
# we are going to do so below simply by calculating and normalizing the histogram of the first token
# NOTE(review): assumes every token id lies in [0, ncluster) — confirm against the dataset encoding
counts = torch.ones(ncluster) # start counts as 1 not zero, this is called "smoothing"
rp = torch.randperm(len(train_dataset))
# how many images to use for the estimation; capped at the dataset size so
# rp[i] never indexes past the permutation on small datasets
nest = min(5000, len(train_dataset))
for i in range(nest):
    a, _ = train_dataset[int(rp[i])]
    t = a[0].item() # index of first token in the sequence
    counts[t] += 1
prob = counts/counts.sum()

%%time

from mingpt.utils import sample
예제 #4
0
## set up model (TODO: better way to handle the model config)
# dropout is disabled (all pdrop = 0.0) for this model instance
mconf = GPTConfig(train_dataset.vocab_size,
                  train_dataset.block_size,
                  embd_pdrop=0.0,
                  resid_pdrop=0.0,
                  attn_pdrop=0.0,
                  n_layer=24,
                  n_head=8,
                  n_embd=512)
model = GPT(mconf)

# load the model
print("Loading model")
# Map the checkpoint to the CPU. The model has not been moved to a GPU yet,
# and this lets a checkpoint saved during GPU training load on a CPU-only
# machine, which the is_available() branch below is meant to support.
model_ckpt = torch.load(args.model_cache, map_location='cpu')
model.load_state_dict(model_ckpt['model_state_dict'])

if torch.cuda.is_available():
    model = model.cuda()

if args.condition == 'uncond':
    # generate some samples unconditionally
    print("Generating unconditional samples")
    generate_samples(model, train_dataset, 32)
elif args.condition == 'half' or args.condition == 'chimera':
    # generate samples conditioned on upper half
    img_dir = '/scratch/eo41/minGPT/frames_for_half_3'
    print("Generating samples from upper half of images at {}".format(img_dir))
    x_data = torchvision.datasets.ImageFolder(
        img_dir,
        torchvision.transforms.Resize(