Example #1
    def __call__(self, x, y, norm):
        # Project decoder output to vocabulary log-probabilities, then
        # compute the loss normalized by the token count of the batch.
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm
        if self.opt is not None:
            # Backprop and update only in training mode; NoamOpt wraps the
            # real optimizer, so zero_grad lives one level down.
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm


if __name__ == '__main__':
    # Train a small copy task, then decode greedily
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, n=2)
    model = model.cuda()
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    for epoch in range(10):
        model.train()
        run_epoch(data_gen(V, 30, 20), model, SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        print(run_epoch(data_gen(V, 30, 5), model, SimpleLossCompute(model.generator, criterion, None)))

    model.eval()
    src = Variable(torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])).cuda()
    src_mask = Variable(torch.ones(1, 1, 10)).cuda()
    print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
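
The helpers used above (make_model, data_gen, LabelSmoothing, NoamOpt, greedy_decode) are not shown in this snippet; they follow The Annotated Transformer. As a reference, here is a minimal sketch of what greedy_decode could look like under that assumption: the model is expected to expose encode, decode and generator, and subsequent_mask is a helper (named here for illustration) that hides future positions.

import torch

def subsequent_mask(size):
    # Mask out future positions: position i may attend to positions <= i only.
    return torch.triu(torch.ones(1, size, size), diagonal=1) == 0

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Encode the source once, then extend the output one token at a time,
    # always taking the highest-probability next token.
    memory = model.encode(src, src_mask)
    ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=src.device)
    for _ in range(max_len - 1):
        out = model.decode(memory, src_mask, ys,
                           subsequent_mask(ys.size(1)).to(src.device))
        prob = model.generator(out[:, -1])  # log-probabilities over the vocab
        next_word = prob.argmax(dim=1).item()
        ys = torch.cat([ys, torch.full((1, 1), next_word,
                                       dtype=src.dtype, device=src.device)], dim=1)
    return ys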
Example #2
def train(
    train_path,
    val_path,
    save_path,
    n_layers=6,
    model_dim=512,
    feedforward_dim=2048,
    n_heads=8,
    dropout_rate=0.1,
    n_epochs=10,
    max_len=60,
    min_freq=10,
    max_val_outputs=20):

    train, val, TGT, SRC, EOS_WORD, BOS_WORD, BLANK_WORD = get_dataset(train_path, val_path, min_freq)

    # Vocabularies were saved on a previous run; reload them so token
    # indices stay consistent with the checkpoint loaded below.
    # torch.save(SRC.vocab, 'models/electronics/src_vocab.pt')
    # torch.save(TGT.vocab, 'models/electronics/trg_vocab.pt')
    SRC.vocab = torch.load('models/electronics/src_vocab.pt')
    TGT.vocab = torch.load('models/electronics/trg_vocab.pt')
    pad_idx = TGT.vocab.stoi[BLANK_WORD]

    # Resume from a saved checkpoint rather than building a fresh model:
    # model = make_model(len(SRC.vocab), len(TGT.vocab),
    #                    n=n_layers, d_model=model_dim,
    #                    d_ff=feedforward_dim, h=n_heads,
    #                    dropout=dropout_rate)
    model = torch.load('models/electronics/electronics_autoencoder_epoch3.pt')
    model.cuda()
    criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
    criterion.cuda()
    BATCH_SIZE = 2048  # Tokens per batch. Was 12000, but I only have 12 GB RAM on my single GPU.
    train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0, repeat=False,  # faster, despite the device warning
                            sort_key=lambda x: (len(x.src), len(x.trg)), batch_size_fn=batch_size_fn, train=True)
    valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0, repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)), batch_size_fn=batch_size_fn, train=False)
    model_par = nn.DataParallel(model, device_ids=devices)  # devices is assumed to be a module-level list of GPU ids

    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
                        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    for epoch in range(n_epochs):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
                  MultiGPULossCompute(model.generator, criterion, devices=devices, opt=model_opt))
        save_name = save_path + '_epoch' + str(epoch + 4) + '.pt'  # numbering continues from the loaded epoch-3 checkpoint
        torch.save(model, save_name)
        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
                         MultiGPULossCompute(model.generator, criterion, devices=devices, opt=None))
        print(loss)

    for i, batch in enumerate(valid_iter):
        if i > max_val_outputs:
            break
        src = batch.src.transpose(0, 1)[:1].cuda()
        src_mask = (src != SRC.vocab.stoi[BLANK_WORD]).unsqueeze(-2).cuda()
        out = greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=TGT.vocab.stoi[BOS_WORD])
        print('Translation:', end='\t')
        for k in range(1, out.size(1)):  # k, not i: avoid shadowing the outer batch index
            sym = TGT.vocab.itos[out[0, k]]
            if sym == EOS_WORD:
                break
            print(sym, end=' ')
        print()
        print('Target:', end='\t')
        for j in range(1, batch.trg.size(0)):  # start at 1 to skip the BOS token
            sym = TGT.vocab.itos[batch.trg.data[j, 0]]
            if sym == EOS_WORD:
                break
            print(sym, end=' ')
        print()
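
MyIterator and batch_size_fn above are assumed to come from The Annotated Transformer's data pipeline, in which case BATCH_SIZE counts tokens rather than sentences (hence the 12 GB remark). A sketch of that token-count batching function, under the same assumption:

max_src_in_batch, max_tgt_in_batch = 0, 0

def batch_size_fn(new, count, sofar):
    # Size the batch would have, in padded tokens, if example `new` were added;
    # torchtext keeps adding examples until this estimate exceeds batch_size.
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)  # +2 for BOS/EOS
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)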
Example #3
    valid_iter = MyIterator(val,
                            batch_size=BATCH_SIZE,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=False)
    model_par = nn.DataParallel(model, device_ids=devices)

    model_opt = NoamOpt(
        model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9))
    for epoch in range(10):
        model_par.train()
        run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
                  MultiGPULossCompute(model.generator,
                                      criterion,
                                      devices=devices,
                                      opt=model_opt))

        model_par.eval()
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
                         MultiGPULossCompute(model.generator,
                                             criterion,
                                             devices=devices,
                                             opt=None))
        print(loss)
else:
    # Training branch skipped: load a pretrained IWSLT checkpoint instead.
    model = torch.load('iwslt.pt')

for i, batch in enumerate(valid_iter):
    src = batch.src.transpose(0, 1)[:1]
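
Examples #2 and #3 both wrap torchtext batches with rebatch before run_epoch. Assuming the Batch class from The Annotated Transformer (and reusing the subsequent_mask helper sketched after Example #1), the idea is that torchtext yields (seq_len, batch) tensors while the model expects (batch, seq_len), plus precomputed masks and a token count for loss normalization:

class Batch:
    # Holds one batch of (batch, seq_len) tensors plus the masks the model needs.
    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]    # decoder input
            self.trg_y = trg[:, 1:]   # decoder target, shifted by one
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()  # normalizer for the loss

    @staticmethod
    def make_std_mask(tgt, pad):
        # Combine the padding mask with the no-peeking-ahead mask.
        tgt_mask = (tgt != pad).unsqueeze(-2)
        return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)

def rebatch(pad_idx, batch):
    # torchtext batches are (seq_len, batch); transpose to (batch, seq_len).
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    return Batch(src, trg, pad_idx)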
Example #4

if __name__ == "__main__":
    criterion = LabelSmoothing(size=args.V,
                               padding_idx=0,
                               smoothing=args.smoothing)
    model = make_model(args.V, args.V, args.N)
    model_opt = NoamOpt(model_size=model.src_embed[0].d_model,
                        factor=args.factor,
                        warm_up=args.warm_up,
                        optimizer=Adam(model.parameters(),
                                       lr=0,
                                       betas=args.betas,
                                       eps=args.eps))

    for epoch_index in range(args.num_epochs):
        model.train()
        run_epoch(
            data_generate(args.V, args.batch_size, args.num_batches_train),
            model, SimpleLossCompute(model.generator, criterion, model_opt))

        model.eval()
        print(run_epoch(
            data_generate(args.V, args.batch_size, args.num_batches_eval),
            model, SimpleLossCompute(model.generator, criterion, None)))

    model.eval()
    src = Variable(torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]))
    src_mask = Variable(torch.ones(1, 1, 10))
    print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))
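
All four examples hand Adam to NoamOpt with lr=0: the wrapper overwrites the learning rate on every step with the warmup schedule from "Attention Is All You Need", lr = factor * d_model^-0.5 * min(step^-0.5, step * warm_up^-1.5). A sketch matching the keyword signature used in this example:

class NoamOpt:
    # Optimizer wrapper that recomputes the learning rate before each update.
    def __init__(self, model_size, factor, warm_up, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warm_up = warm_up
        self._step = 0

    def rate(self, step=None):
        # Linear warmup for warm_up steps, then inverse-square-root decay.
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5)
                              * min(step ** (-0.5), step * self.warm_up ** (-1.5)))

    def step(self):
        self._step += 1
        for group in self.optimizer.param_groups:
            group['lr'] = self.rate()
        self.optimizer.step()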