Example #1
0
 def forward(self, x):
     """Return per-position log-probabilities over the token vocabulary.

     Args:
         x: token ids fed to ``self.token_embedding``. The embedding result is
             unpacked as ``(b, t, e)``, so ``x`` is presumably (batch, time).

     Returns:
         Tensor of shape ``(b, t, self.n_tokens)`` holding ``log_softmax``
         scores over the last dimension. Position ``j`` is predicted from the
         preceding tokens, assuming ``self.t_blocks`` masks future positions
         (TODO confirm).
     """
     tokens = self.token_embedding(x)
     b, t, e = tokens.size()
     # Build the position ids on the same device as the token embeddings so
     # the concatenation below cannot fail with a device mismatch.
     positions = self.pos_embedding(
         torch.arange(t, device=tokens.device))[None, :, :].expand(b, t, e)
     # Concatenate token and position embeddings, then project 2*e -> e.
     x = self.unify_embeddings(
         torch.cat((tokens, positions), dim=2).view(-1,
                                                    2 * e)).view(b, t, e)
     # Run the batch through the transformer blocks.
     x = self.t_blocks(x)
     # reshape (not view): the blocks' output is not guaranteed to be
     # contiguous, and view() raises on non-contiguous input.
     x = self.to_probs(x.reshape(b * t, e)).view(b, t, self.n_tokens)
     return F.log_softmax(x, dim=2)
Example #2
0
def train(hidden_dim_sweep=(5, 10, 25),
          n_epochs=20,
          out_dir='out',
          data_dir='data',
          device=util.device(),
          Optimizer=optim.Adam,
          seed=42,
          lr=1e-4):
    """Train one model per hidden size in the sweep and save the best one.

    Args:
        hidden_dim_sweep: hidden sizes to try; one ``Model`` is trained per value.
        n_epochs: number of epochs passed to ``Trainer.train_loop`` for each run.
        out_dir: directory that receives the trace plots, the trace JSON files
            and ``model.pt``. It is created (with parents) if missing.
        data_dir: directory holding ``vocab.txt``; also passed to ``train_loop``.
        device: passed to ``Trainer``. NOTE: the default is evaluated once,
            when this function is defined.
        Optimizer: optimizer class, constructed as ``Optimizer(params, lr=lr)``.
        seed: seed for the NumPy and torch (CPU and CUDA) RNGs. It is set once,
            before the sweep starts.
        lr: learning rate used for every run (previously hard-coded to 1e-4).

    Raises:
        ValueError: if no run reported a loss below ``util.INF`` (an empty sweep,
            or every loss was NaN/infinite), so there is no best model to save.
            Traces are still written before this is raised.
    """
    out_dir, data_dir = map(Path, (out_dir, data_dir))
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    sweep_traces = []
    best_trainer = None
    best_loss = util.INF
    vocab = util.Vocab.load(data_dir / 'vocab.txt')
    for hidden_dim in hidden_dim_sweep:
        model = Model(hidden_dim=hidden_dim, vocab=vocab, out_dim=2)
        loss_fn = nn.CrossEntropyLoss()
        optimizer = Optimizer(model.parameters(), lr=lr)
        trainer = Trainer(model, loss_fn, vocab, device)
        traces, loss_cur = trainer.train_loop(data_dir=data_dir,
                                              n_epochs=n_epochs,
                                              optimizer=optimizer,
                                              scheduler=None)
        # Strict '<' keeps the earliest run on ties; a NaN loss compares
        # False and therefore never becomes the best.
        if loss_cur < best_loss:
            best_trainer = trainer
            best_loss = loss_cur
        sweep_traces.append((hidden_dim, traces))

    # Write plots and traces before the best-model check so that a failed
    # sweep can still be inspected.
    out_dir.mkdir(parents=True, exist_ok=True)
    for h, traces in sweep_traces:
        plotting.plot_traces(traces,
                             out=out_dir / f'traces_{h}.png',
                             title=f'Loss,hidden_dim={h}')
        util.jsondump(traces, out_dir / f'traces.dim_{h}.seed_{seed}.json')

    if best_trainer is None:
        raise ValueError(
            'No model was selected: hidden_dim_sweep was empty or no run '
            'reported a loss below util.INF; nothing to save.')

    L.info('Best model loss: %s', best_loss)

    model_file = out_dir / 'model.pt'
    L.info('Saving best model to %s', model_file)
    torch.save(best_trainer.model.state_dict(), model_file)
Example #3
0
def test(hidden_dim=25,
         out_dir='out',
         data_dir='data',
         device=util.device(),
         out_dim=2,
         batch_size=8):
    """Load ``out_dir/model.pt`` and log its metrics on the dev and test splits.

    Args:
        hidden_dim: hidden size the checkpoint was trained with.
        out_dir: directory containing ``model.pt``.
        data_dir: directory containing ``vocab.txt``, ``dev.tsv`` and ``test.tsv``.
        device: passed to ``Trainer``; also used as ``map_location`` so that a
            checkpoint saved on another device (e.g. GPU) can still be loaded.
            NOTE: the default is evaluated once, when this function is defined.
        out_dim: output size of ``Model``. The default of 2 matches how
            ``train`` in this project builds its models; it must match the
            checkpoint or ``load_state_dict`` fails.
        batch_size: evaluation batch size for both splits.
    """
    out_dir, data_dir = map(Path, (out_dir, data_dir))
    vocab = util.Vocab.load(data_dir / 'vocab.txt')
    model = Model(hidden_dim=hidden_dim, vocab=vocab, out_dim=out_dim)
    model.load_state_dict(
        torch.load(out_dir / 'model.pt', map_location=device))
    loss_fn = nn.CrossEntropyLoss()
    trainer = Trainer(model, loss_fn, vocab, device)

    # Evaluate each split in order; the log format strings are unchanged.
    splits = (('dev.tsv', 'Dev performance: %s'),
              ('test.tsv', 'Test performance: %s'))
    for file_name, log_format in splits:
        with data.SentPairStream(data_dir / file_name) as split_data:
            loader = torchdata.DataLoader(split_data,
                                          shuffle=False,
                                          batch_size=batch_size)
            metrics = trainer.eval_(loader)
            L.info(log_format, metrics)
Example #4
0
 def initHidden(self):
     """Create a zero-filled hidden state.

     Returns:
         Tensor of shape ``(1, 1, self.hidden_size)`` filled with zeros and
         allocated on ``util.device()``.
     """
     hidden_shape = (1, 1, self.hidden_size)
     return torch.zeros(hidden_shape, device=util.device())
Example #5
0
import time
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

from lang import Lang
import util
from util import prepareData
import const
from encoder_rnn import EncoderRNN
from attn_decoder_rnn import AttnDecoderRNN

# Disabled debugging snippet: calls prepareData('eng', 'fra', True) and prints
# a random pair. Note that `random` is not imported in this excerpt.
#input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
#print(random.choice(pairs))
# Device chosen once, at import time, via the project helper util.device().
device = util.device()
# Presumably the probability of feeding ground-truth targets to the decoder
# (teacher forcing). It is not read anywhere in this excerpt; TODO confirm.
teacher_forcing_ratio = 0.5


def train(input_tensor,
          target_tensor,
          encoder,
          decoder,
          encoder_optimizer,
          decoder_optimizer,
          criterion,
          max_length=const.MAX_LENGTH):
    """Start a training step for an encoder / attention-decoder pair (excerpt).

    The visible body only does two things:
    - initialises the encoder hidden state via ``encoder.initHidden()``;
    - clears the gradients of both optimizers.
    It then implicitly returns None.

    NOTE(review): ``input_tensor``, ``target_tensor``, ``decoder``,
    ``criterion``, ``max_length`` and ``encoder_hidden`` are unused here.
    The rest of the step (forward pass, loss, backward, optimizer steps)
    appears to be cut off in this excerpt; confirm against the full source.
    """
    # Fresh initial hidden state for the encoder.
    encoder_hidden = encoder.initHidden()

    # Clear gradients left over from the previous step.
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
Example #6
0
File: generate.py, project: cxcd/CPS803
def _set_learning_rate(optimizer, lr):
    """Set ``lr`` on every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group['lr'] = lr


def _sample_training_batch(data_train, batch_size, seq_length):
    """Sample random windows; return (source, target) long tensors of shape (batch_size, seq_length).

    The target is the source shifted one element ahead.
    """
    starts = torch.randint(size=(batch_size, ),
                           low=0,
                           high=data_train.size(0) - seq_length - 1)
    source = torch.stack(
        [data_train[start:start + seq_length] for start in starts]).to(torch.long)
    target = torch.stack(
        [data_train[start + 1:start + seq_length + 1] for start in starts]).to(torch.long)
    return source, target


def _validation_bits_per_byte(model, data_sub, seq_length, batch_size):
    """Return abs(total log2-probability) / len(data_sub) for next-element predictions.

    Each position is predicted from up to ``seq_length`` preceding elements,
    left-padded with zeros. Windows are buffered and run through the model in
    batches of ``batch_size``. The caller is expected to wrap this in
    ``torch.no_grad()``.
    """
    n = data_sub.size(0)
    if n == 0:
        return float('nan')
    bits = 0.0
    pending = []  # (1, seq_length + 1) windows waiting to be scored
    for current in range(n):
        context = data_sub[max(0, current - seq_length):current + 1].to(torch.long)
        # Near the start of the data there is not enough history: left-pad with zeros.
        missing = seq_length + 1 - context.size(0)
        if missing > 0:
            pad = torch.zeros(size=(missing, ), dtype=torch.long)
            context = torch.cat([pad, context], dim=0)
        assert context.size(0) == seq_length + 1
        if util.use_cuda():
            context = context.cuda()
        pending.append(context[None, :])
        if len(pending) == batch_size or current == n - 1:
            windows = torch.cat(pending, dim=0)
            source, target = windows[:, :-1], windows[:, -1]
            output = model(source)
            # Output is assumed to be (batch, time, token) natural-log
            # probabilities; take the last position's score for each target.
            lnprobs = output[torch.arange(len(pending), device=target.device),
                             -1, target]
            bits += (lnprobs * math.log2(math.e)).sum().item()
            pending = []
    return abs(bits / n)


def train(n_heads=8,
          depth=4,
          seq_length=32,
          n_tokens=256,
          emb_size=128,
          n_batches=500,
          batch_size=64,
          test_every=50,
          lr=0.0001,
          warmup=100,
          seed=-1,
          data_sub=1000,
          output_path="genmodel.pt"):
    """Train a ``tf.GenTransformer`` on random subsequences and save it.

    Args:
        n_heads, depth, seq_length, n_tokens: passed to ``tf.GenTransformer``.
        emb_size: embedding size, passed to ``tf.GenTransformer`` as ``emb``.
        n_batches: number of optimisation steps.
        batch_size: sequences per training batch; also the validation batch size.
        test_every: validate every this many steps, and again on the last step.
        lr: peak learning rate for Adam.
        warmup: number of initial steps over which the learning rate rises
            linearly from ~0 to ``lr``.
        seed: torch seed; a negative value picks (and prints) a random seed.
        data_sub: number of leading elements of the validation data scored at
            intermediate checks. The last step scores all of it.
        output_path: path handed to ``util.save_model``.

    Returns:
        list of float: the training loss of every step.
    """
    if seed < 0:
        seed = random.randint(0, 1000000)
        print("Using seed: ", seed)
    # Seed in both cases so that the printed seed actually reproduces the run.
    torch.manual_seed(seed)

    data_train, data_valid = get_data()
    losses = []
    model = tf.GenTransformer(emb=emb_size,
                              n_heads=n_heads,
                              depth=depth,
                              seq_length=seq_length,
                              n_tokens=n_tokens)
    if util.use_cuda():
        model = model.cuda()
    opt = torch.optim.Adam(model.parameters(), lr)
    last_step = n_batches - 1
    for i in tqdm.trange(n_batches):  # tqdm is a nice progress bar
        # Linear warmup, always computed from the base lr. It is written into
        # param_groups because assigning opt.lr has no effect on Adam.
        if lr > 0 and i < warmup:
            step_lr = max((lr / warmup) * i, 1e-10)
        else:
            step_lr = lr
        _set_learning_rate(opt, step_lr)

        opt.zero_grad()
        source, target = _sample_training_batch(data_train, batch_size, seq_length)
        if util.use_cuda():
            source, target = source.cuda(), target.cuda()
        output = model(source)
        # nll_loss expects the class dimension second: (batch, tokens, time).
        loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
        loss.backward()
        losses.append(loss.item())
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        opt.step()

        # Validate periodically: compression rate, then a generated sample.
        if i != 0 and (i % test_every == 0 or i == last_step):
            upto = data_valid.size(0) if i == last_step else data_sub
            model.eval()
            with torch.no_grad():
                bits_per_byte = _validation_bits_per_byte(
                    model, data_valid[:upto], seq_length, batch_size)
                print(f' epoch {i}: {bits_per_byte:.4} bits per byte')
                print("Loss:", loss.item())
                # Monitor progress by generating from a random validation prompt.
                seedfr = random.randint(0, data_valid.size(0) - seq_length)
                prompt = data_valid[seedfr:seedfr + seq_length].to(torch.long)
                output_valid = gen(model, prompt)
                print("OUT:", output_valid[:30])
            model.train()

    # Save the model when we're done training it.
    util.save_model(model, output_path)
    print("Finished training. Model saved to", output_path)
    return losses