Example #1
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Set up model, optimizer, and loss
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = make_model(src_vocab_size, tgt_vocab_size).to(device)
    optimizer = get_std_opt(model)
    criterion = LabelSmoothing(size=tgt_vocab_size,
                               padding_idx=pad_idx,
                               smoothing=0.1)
    train_criterion = SimpleLossCompute(model.generator, criterion, optimizer)
    valid_criterion = SimpleLossCompute(model.generator, criterion, None)
    print('Using device:', device)

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        best_loss = float('inf')
        for epoch in range(args.epochs):
            train_total_loss, valid_total_loss = 0.0, 0.0
            start = time.time()
            total_tokens = 0
            tokens = 0

            model.train()
            # Train
            for src_batch, tgt_batch in train_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)

                prediction = model(batch.src, batch.trg, batch.src_mask,
                                   batch.trg_mask)
                loss = train_criterion(prediction, batch.trg_y, batch.ntokens)

                train_total_loss += loss
                total_tokens += batch.ntokens
                tokens += batch.ntokens

            # Validation (no gradient tracking needed during evaluation)
            model.eval()
            with torch.no_grad():
                for src_batch, tgt_batch in valid_loader:
                    src_batch = torch.tensor(src_batch).to(device)
                    tgt_batch = torch.tensor(tgt_batch).to(device)
                    batch = Batch(src_batch, tgt_batch, pad_idx)

                    prediction = model(batch.src, batch.trg, batch.src_mask,
                                       batch.trg_mask)
                    loss = valid_criterion(prediction, batch.trg_y, batch.ntokens)
                    valid_total_loss += loss
                    total_tokens += batch.ntokens
                    tokens += batch.ntokens

            if valid_total_loss.item() < best_loss:
                best_loss = valid_total_loss.item()
                best_model_state = model.state_dict()
                best_optimizer_state = optimizer.optimizer.state_dict()

            elapsed = time.time() - start
            print(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + "|| [" +
                str(epoch) + "/" + str(args.epochs) + "], train_loss = " +
                str(train_total_loss.item()) + ", valid_loss = " +
                str(valid_total_loss.item()) + ", Tokens per Sec = " +
                str(tokens.item() / elapsed))
            tokens = 0
            start = time.time()

            if epoch % 100 == 0:
                # Save model
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': best_model_state,
                        'optimizer_state': best_optimizer_state,
                        'loss': best_loss
                    }, args.model_dir + "/intermediate.pt")
                print("Model saved")

        # Save model
        torch.save(
            {
                'epoch': args.epochs,
                'model_state_dict': best_model_state,
                'optimizer_state': best_optimizer_state,
                'loss': best_loss
            }, args.model_dir + "/best.pt")
        print("Model saved")
    else:
        # Load the model
        checkpoint = torch.load(args.model_dir + "/" + args.model_name,
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state'])
        model.eval()
        print("Model loaded")

        # Test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []

        for src_batch, tgt_batch in test_loader:
            src_batch = torch.tensor(src_batch).to(device)
            tgt_batch = torch.tensor(tgt_batch).to(device)
            batch = Batch(src_batch, tgt_batch, pad_idx)

            # Get pred_batch
            memory = model.encode(batch.src, batch.src_mask)
            pred_batch = torch.ones(src_batch.size(0), 1)\
                            .fill_(sos_idx).type_as(batch.src.data).to(device)
            for i in range(max_length - 1):
                out = model.decode(
                    memory, batch.src_mask, Variable(pred_batch),
                    Variable(
                        Batch.make_std_mask(pred_batch,
                                            pad_idx).type_as(batch.src.data)))
                prob = model.generator(out[:, -1])
                prob.index_fill_(1,
                                 torch.tensor([sos_idx, pad_idx]).to(device),
                                 -float('inf'))
                _, next_word = torch.max(prob, dim=1)

                pred_batch = torch.cat(
                    [pred_batch, next_word.unsqueeze(-1)], dim=1)
            pred_batch = torch.cat([pred_batch, torch.ones(src_batch.size(0), 1)\
                                                    .fill_(eos_idx).type_as(batch.src.data).to(device)], dim=1)

            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch.tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
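Example #1 leans on `get_std_opt(model)` from the Annotated Transformer, which wraps Adam in an inverse-square-root warmup schedule. Below is a minimal sketch of that rate rule, assuming the usual d_model=512 / warmup=4000 convention; the exact factor and wrapper class in the real helper may differ.

# Sketch of the warmup rule assumed behind get_std_opt (factor/warmup values are assumptions).
def noam_rate(step, d_model=512, factor=2.0, warmup=4000):
    """lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)"""
    step = max(step, 1)  # guard against step == 0 on the very first call
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

# The rate grows linearly for the first `warmup` steps, then decays as step**-0.5,
# which is why no manual LR decay appears in the training loop above.
for step in (1, 1000, 4000, 20000):
    print(step, round(noam_rate(step), 6))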
Example #2
def main(args):

    # 0. initial setting

    # set environment
    cudnn.benchmark = True

    if not os.path.isdir('./ckpt'):
        os.mkdir('./ckpt')
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    if not os.path.isdir(os.path.join('./ckpt', args.name)):
        os.mkdir(os.path.join('./ckpt', args.name))
    if not os.path.isdir(os.path.join('./results', args.name)):
        os.mkdir(os.path.join('./results', args.name))
    if not os.path.isdir(os.path.join('./results', args.name, "log")):
        os.mkdir(os.path.join('./results', args.name, "log"))

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data

    args.logger.info("loading data...")
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup

    args.logger.info("setting up...")

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # transformer config
    d_e = 512  # embedding size
    d_q = 64  # query size (= key, value size)
    d_h = 2048  # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of layers in the encoder/decoder stacks

    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers

    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test

    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)",
                                     save_all=True,
                                     save_dir=os.path.join(
                                         'results', args.name))
        train_loss_meter = AverageMeter(name="Loss",
                                        save_all=True,
                                        save_dir=os.path.join(
                                            'results', args.name))
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch, tgt_batch = torch.LongTensor(src_batch).to(
                    args.device), torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                pred = model(src_batch.to(args.device),
                             tgt_batch_i.to(args.device))
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()

                train_loss_tmp_meter.update(loss / batch, weight=batch)

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            # validation
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()

            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(
                    src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                with torch.no_grad():
                    pred = model(src_batch.to(args.device),
                                 tgt_batch_i.to(args.device))

                corrects, total = val_check(
                    pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                acc_val_tmp_meter.update(100 * corrects / total, total)

            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] validation accuracy: {:.1f} %, took {:.1f} seconds".format(
                    epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)

            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test
        args.logger.info("starting test")
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()

        for src_batch, tgt_batch in test_loader:
            #src_batch: (batch x source_length)
            src_batch = torch.Tensor(src_batch).long().to(args.device)
            batch = src_batch.shape[0]
            pred_batch = torch.zeros(batch, 1).long().to(args.device)
            pred_mask = torch.zeros(batch, 1).bool().to(
                args.device)  # marks whether each sentence has already ended

            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(
                        src_batch,
                        pred_batch)  # (batch x length x tgt_vocab_size)
                    pred[:, :, pad_idx] = -1  # suppress <pad> predictions
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(
                        -1)  # next-word prediction: (batch x 1)
                    pred = pred.masked_fill(
                        pred_mask,
                        pad_idx).long()  # emit <pad> for sentences that have already ended
                    pred_mask = pred.eq(eos_idx) | pred.eq(pad_idx)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if pred_mask.all():
                        break

            pred_batch = torch.cat([
                pred_batch,
                torch.ones(batch, 1).long().to(args.device) + pred_mask.long()
            ],
                                   dim=1)  # append <eos> (1) to unfinished sentences, <pad> (2) to finished ones
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
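Example #2 relies on an `AverageMeter` helper that is not shown here. The sketch below is compatible with the calls above (`update(value, weight)`, `.avg`, `save()`); the `save_all`/`save_dir` handling and the save format are guesses at what the real class does.

# Minimal AverageMeter sketch matching the calls in Example #2.
import os

class AverageMeter:
    def __init__(self, name='meter', save_all=False, save_dir='.'):
        self.name, self.save_all, self.save_dir = name, save_all, save_dir
        self.sum = 0.0
        self.count = 0.0
        self.history = []

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0

    def update(self, value, weight=1):
        value = float(value)  # accepts Python floats or 0-dim tensors
        self.sum += value * weight
        self.count += weight
        if self.save_all:
            self.history.append(value)

    def save(self):
        path = os.path.join(self.save_dir, '{}.txt'.format(self.name))
        with open(path, 'w') as f:
            f.write('\n'.join(str(v) for v in (self.history or [self.avg])))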
Example #3
def main(args):
    src, tgt = load_data(args.path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    N = 6  # number of encoder/decoder layers
    dim = 512  # model (embedding) dimension

    # MODEL Construction
    encoder = Encoder(N, dim, pad_idx, src_vocab_size, device).to(device)
    decoder = Decoder(N, dim, pad_idx, tgt_vocab_size, device).to(device)

    if args.model_load:
        ckpt = torch.load("drive/My Drive/checkpoint/best.ckpt")
        encoder.load_state_dict(ckpt["encoder"])
        decoder.load_state_dict(ckpt["decoder"])

    params = list(encoder.parameters()) + list(decoder.parameters())

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        warmup = 4000
        steps = 1
        lr = 1. * (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
        optimizer = torch.optim.Adam(params,
                                     lr=lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)

        train_losses = []
        val_losses = []
        latest = 1e08  # best (lowest) validation loss seen so far

        start_epoch = 0

        if (args.model_load):
            start_epoch = ckpt["epoch"]
            optimizer.load_state_dict(ckpt["optim"])
            steps = start_epoch * 30  # roughly resume the LR schedule (assumes ~30 batches per epoch)

        for epoch in range(start_epoch, args.epochs):

            for src_batch, tgt_batch in train_loader:
                encoder.train()
                decoder.train()
                optimizer.zero_grad()
                tgt_batch = torch.LongTensor(tgt_batch)

                src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                gt = Variable(tgt_batch[:, 1:]).to(device)
                tgt_batch = Variable(tgt_batch[:, :-1]).to(device)

                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)

                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)

                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                loss.backward()
                train_losses.append(loss.item())
                optimizer.step()

                steps += 1
                lr = (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
                update_lr(optimizer, lr)

                if (steps % 10 == 0):
                    print("loss : %f" % loss.item())

            # validation (no gradient tracking needed)
            with torch.no_grad():
                for src_batch, tgt_batch in valid_loader:
                    encoder.eval()
                    decoder.eval()

                    src_batch = Variable(torch.LongTensor(src_batch)).to(device)
                    tgt_batch = torch.LongTensor(tgt_batch)
                    gt = Variable(tgt_batch[:, 1:]).to(device)
                    tgt_batch = Variable(tgt_batch[:, :-1]).to(device)

                    enc_output, seq_mask = encoder(src_batch)
                    dec_output = decoder(tgt_batch, enc_output, seq_mask)

                    gt = gt.view(-1)
                    dec_output = dec_output.view(gt.size()[0], -1)

                    loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)

                    val_losses.append(loss.item())
            print("[EPOCH %d] valid loss (last batch): %f" % (epoch, loss.item()))

            if (val_losses[-1] <= latest):
                checkpoint = {'encoder':encoder.state_dict(), 'decoder':decoder.state_dict(), \
                    'optim':optimizer.state_dict(), 'epoch':epoch}
                torch.save(checkpoint, "drive/My Drive/checkpoint/best.ckpt")
                latest = val_losses[-1]

            if (epoch % 20 == 0):
                plt.figure()
                plt.plot(val_losses)
                plt.xlabel("epoch")
                plt.ylabel("model loss")
                plt.show()

    else:
        # test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        # LOAD CHECKPOINT

        pred = []
        for src_batch, tgt_batch in test_loader:
            encoder.eval()
            decoder.eval()

            b_s = min(args.batch_size, len(src_batch))
            tgt_batch = torch.zeros(b_s, 1).to(device).long()
            src_batch = Variable(torch.LongTensor(src_batch)).to(device)

            enc_output, seq_mask = encoder(src_batch)
            pred_batch = decoder(tgt_batch, enc_output, seq_mask)
            _, pred_batch = torch.max(pred_batch, 2)

            while (not is_finished(pred_batch, max_length, eos_idx)):
                # prepend <sos> to the current predictions and decode again (the sequence grows by one token per pass)
                next_input = torch.cat((tgt_batch, pred_batch.long()), 1)
                pred_batch = decoder(next_input, enc_output, seq_mask)
                _, pred_batch = torch.max(pred_batch, 2)
            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred_batch = pred_batch.tolist()
            for line in pred_batch:
                line[-1] = 1  # force the last token to be <eos>
            pred += seq2sen(pred_batch, tgt_vocab)
            # print(pred)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
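Example #3 calls two small helpers that are not shown, `update_lr` and `is_finished`. The sketches below are inferred from how they are used above; the actual stopping rule in the repository may differ.

# Sketches of the helpers used in Example #3; interfaces inferred from the call sites.
import torch

def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group (manual Noam-style schedule)."""
    for group in optimizer.param_groups:
        group['lr'] = lr

def is_finished(pred_batch, max_length, eos_idx):
    """Stop decoding once every row contains <eos> or the length limit is reached."""
    if pred_batch.size(1) >= max_length:
        return True
    return bool((pred_batch == eos_idx).any(dim=1).all())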
Example #4
def main(args):
    eos_idx = 1
    pad_idx = -1
    pad_val = 0.0

    feature_dim = 40
    listener_hidden_dim = 256
    num_of_pyramidal_layers = 3
    speller_hidden_dim = 512
    attention_hidden_dim = 128
    num_of_classes = 30

    learning_rate = 0.2
    geometric_decay = 0.98

    device = torch.device("cuda" if(torch.cuda.is_available()) else "cpu")
    listener = Listener(feature_dim, listener_hidden_dim, num_of_pyramidal_layers).to(device)
    speller = Speller(speller_hidden_dim, listener_hidden_dim, attention_hidden_dim, num_of_classes, device).to(device)
    #print(device, listener, speller)

    if not args.test:
        # train
        src, trg = load_train_data(args.path)
        train_loader = DataLoader(src, trg, args.batch_size, pad_idx, pad_val)

        criterion = nn.CrossEntropyLoss(ignore_index = pad_idx)
        #optimizer = torch.optim.ASGD([{'params':listener.parameters()}, {'params':speller.parameters()}], lr=learning_rate)
        optimizer = torch.optim.Adam([{'params':listener.parameters()}, {'params':speller.parameters()}], lr = 0.0001)

        print('Start training ...')
        for epoch in range(args.epochs):
            start_epoch = time.time()
            i = 0

            for src_batch, tgt_batch in train_loader:
                batch_start = time.time()

                src_batch = torch.tensor(src_batch).to(device)
                trg_batch = torch.tensor(tgt_batch).to(device)
                
                trg_input = trg_batch[:,:-1] 
                trg_output = trg_batch[:,1:].contiguous().view(-1)
                
                h = listener(src_batch)
                preds = speller(trg_input, h)
                
                # (optional) lr decay every 1/20 of an epoch
                #if (i+1) % ((train_loader.size//args.batch_size)//20) == 0:
                #    learning_rate = geometric_decay * learning_rate
                #    print('learning rate decayed : %.4f'%(learning_rate))
                #    for group in optimizer.param_groups:
                #        group['lr'] = learning_rate

                optimizer.zero_grad()

                loss = criterion(preds.view(-1, preds.size(-1)), trg_output)
                loss.backward()

                optimizer.step()

                i = i+1
                
                # flush the GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                batch_time = time.time() - batch_start
                print('[%d/%d][%d/%d] train loss : %.4f | time : %.2fs'%(epoch+1, args.epochs, i, train_loader.size//args.batch_size + 1, loss.item(), batch_time))
                
            epoch_time = time.time() - start_epoch
            print('Time taken for %d epoch : %.2fs'%(epoch+1, epoch_time))

            save_checkpoint(listener, speller, 'checkpoints/epoch_%d_'%(epoch+1))

        print('End of the training')
        save_checkpoint(listener, speller, 'checkpoints/final_')
    else:
        if os.path.exists(args.checkpoint + 'listener') and os.path.exists(args.checkpoint + 'speller'):
            listener_checkpoint = torch.load(args.checkpoint + 'listener')
            listener.load_state_dict(listener_checkpoint['state_dict'])
            print("trained model " + args.checkpoint + "listener is loaded")

            speller_checkpoint = torch.load(args.checkpoint + 'speller')
            speller.load_state_dict(speller_checkpoint['state_dict'])
            print("trained model " + args.checkpoint + "speller is loaded")

        # test
        src, trg = load_test_data(args.path)
        test_loader = DataLoader(src, trg, args.batch_size, pad_idx, pad_val)
        mapping = get_mapping(args.path)

        j = 0
        pred = []
        ref = []
        for src_batch, trg_batch in test_loader:
            # predict pred_batch from src_batch with your model.
            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            start_batch = time.time()

            src_batch = torch.tensor(src_batch).to(device)
            trg_batch = torch.tensor(trg_batch).to(device)
            
            max_length = trg_batch.size(1)
            
            pred_batch = torch.zeros(args.batch_size, 1, dtype = int).to(device) # [batch, 1] = [[0],[0],...,[0]]
            
            # eos_mask[i] = 1 means the i-th sentence has produced <eos>
            eos_mask = torch.zeros(args.batch_size, dtype = int)
            
            h = listener(src_batch)
            
            for k in range(max_length):
                start = time.time()
                output = speller(pred_batch, h) # [batch, k+1, num_class]

                # greedy search
                output = torch.argmax(F.softmax(output, dim = -1), dim = -1) # [batch_size, k+1]
                predictions = output[:,-1].unsqueeze(1)
                pred_batch = torch.cat([pred_batch, predictions], dim = -1)

                for i in range(args.batch_size):
                    if predictions[i] == eos_idx:
                        eos_mask[i] = 1

                # stop once every sentence has produced <eos>
                if eos_mask.sum() == args.batch_size :
                    break
                
                t = time.time() - start
                print("[%d/%d][%d/%d] prediction done | time : %.2fs"%(j, test_loader.size // args.batch_size + 1, k+1, max_length, t))
            j += 1

            # flush the GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("[%d/%d] prediction done | time : %.2fs"%(j, test_loader.size // args.batch_size + 1, time.time() - start_batch))
            pred += seq2sen(pred_batch.cpu().numpy().tolist(), mapping)
            ref += seq2sen(trg_batch.cpu().numpy().tolist(), mapping)

            if j % 10 == 0:
                WER = word_error_rate(ref, pred)
                print("Test [%d/%d] : WER %.2f%%"%(j, test_loader.size // args.batch_size + 1, WER))

            with open('results/pred_%d.txt'%(j), 'w') as f:
                for line in pred:
                    f.write('{}\n'.format(line))

            with open('results/ref_%d.txt'%(j), 'w') as f:
                for line in ref:
                    f.write('{}\n'.format(line))

        WER = word_error_rate(ref, pred)
        print("Test End : WER %.2f%%"%(WER))
Example #5
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)

    if not args.test:

        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        best_valid_loss = float('inf')
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))
            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))
            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        iter_cnt = 0
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)
            res = net.decode(source, src_mask)
            pred_batch = res.tolist()
            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)
            iter_cnt += 1
            #print(pred_batch)

        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
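Example #5 converts each raw batch with a `make_tensor` helper that is not shown. The sketch below pads variable-length index lists and builds a boolean source mask; the pad index (2) and the (batch, 1, src_len) mask shape are assumptions about what `net.decode` expects.

# Sketch of make_tensor: pad a list of token-id lists and build a padding mask.
import torch

def make_tensor(batch, pad_idx=2):
    max_len = max(len(seq) for seq in batch)
    padded = [list(seq) + [pad_idx] * (max_len - len(seq)) for seq in batch]
    source = torch.LongTensor(padded)              # (batch, src_len)
    src_mask = (source != pad_idx).unsqueeze(1)    # (batch, 1, src_len)
    return source, src_mask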
Example #6
def main(args):
    # constant definition
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    a_dim = 512
    h_dim = 512
    attn_dim = 512
    embed_dim = 512
    regularize_constant = 1.  # weight (lambda) of the doubly stochastic attention regularizer

    vocabulary = torch.load(args.voca_path)
    vocab_size = len(vocabulary)

    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    encoder = Encoder().to(device)
    decoder = Decoder(a_dim, h_dim, attn_dim, vocab_size, embed_dim).to(device)

    # We do not train the encoder
    encoder.eval()

    if not args.test:
        # train
        validation_term = 1
        best_bleu = 0.
        num_of_epochs_since_improvement = 0
        early_stop_criterion = 20

        train_loader = get_train_data_loader(args.path, args.token_path,
                                             args.voca_path, args.batch_size,
                                             pad_idx)
        valid_loader = get_test_data_loader(args.path,
                                            args.token_path,
                                            args.voca_path,
                                            args.batch_size,
                                            pad_idx,
                                            dataset_type='valid')

        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        optimizer = torch.optim.Adam(decoder.parameters(), lr=0.0001)

        print('Start training ...')
        for epoch in range(args.epochs):

            # early stopping
            if num_of_epochs_since_improvement > early_stop_criterion:
                print("There's no improvement on BLEU score while %d epochs" %
                      (num_of_epochs_since_improvement))
                print("Stop Training")
                break

            start_epoch = time.time()
            i = 0

            ############################################################################################################################################
            # training
            decoder.train()
            for src_batch, trg_batch in train_loader:
                batch_start = time.time()

                src_batch = src_batch.to(device)
                trg_batch = torch.tensor(trg_batch).to(device)

                trg_input = trg_batch[:, :-1]
                trg_output = trg_batch[:, 1:].contiguous().view(-1)

                a = encoder(src_batch)
                preds, alphas = decoder(
                    a, trg_input)  # [batch, C, vocab_size], [batch, C, L]

                optimizer.zero_grad()

                loss = criterion(preds.view(-1, preds.size(-1)),
                                 trg_output)  # NLL loss
                regularize_term = regularize_constant * (
                    (1. - torch.sum(alphas, dim=1))**2).mean()

                total_loss = loss + regularize_term
                total_loss.backward()

                optimizer.step()

                i = i + 1

                # flush the GPU cache
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                batch_time = time.time() - batch_start
                print(
                    '[%d/%d][%d/%d] train loss : %.4f (%.4f / %.4f) | time : %.2fs'
                    % (epoch + 1, args.epochs, i, train_loader.size //
                       args.batch_size + 1, total_loss.item(), loss.item(),
                       regularize_term.item(), batch_time))

            epoch_time = time.time() - start_epoch
            print('Time taken for %d epoch : %.2fs' % (epoch + 1, epoch_time))

            ############################################################################################################################################
            # validation
            if (epoch + 1) % validation_term == 0:
                decoder.eval()
                j = 0
                pred, ref = [], []

                for src_batch, trg_batch in valid_loader:
                    start = time.time()
                    batch_size = src_batch.size(0)

                    src_batch = src_batch.to(device)  # [batch, 3, 244, 244]
                    trg_batch = torch.tensor(trg_batch).to(
                        device)  # [batch * 5, C]
                    trg_batch = torch.split(trg_batch, 5)

                    batches = []
                    for k in range(batch_size):
                        batches.append(trg_batch[k].unsqueeze(0))

                    trg_batch = torch.cat(batches, dim=0)  # [batch, 5, C]

                    max_length = trg_batch.size(-1)

                    pred_batch = torch.zeros(batch_size, 1, dtype=int).to(
                        device)  # [batch, 1] = [[0],[0],...,[0]]

                    # eos_mask[i] = 1 means i-th sentence has eos
                    eos_mask = torch.zeros(batch_size, dtype=int)

                    a = encoder(src_batch)

                    for _ in range(max_length):

                        output, _ = decoder(
                            a, pred_batch)  # [batch, _+1, vocab_size]

                        # greedy search
                        output = torch.argmax(F.softmax(output, dim=-1),
                                              dim=-1)  # [batch_size, _+1]
                        predictions = output[:, -1].unsqueeze(1)
                        pred_batch = torch.cat([pred_batch, predictions],
                                               dim=-1)

                        for l in range(batch_size):
                            if predictions[l] == eos_idx:
                                eos_mask[l] = 1

                        # every sentence has eos
                        if eos_mask.sum() == batch_size:
                            break

                    # flush the GPU cache
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    pred += seq2sen(pred_batch.cpu().numpy().tolist(),
                                    vocabulary)
                    for m in range(batch_size):
                        ref += [
                            seq2sen(trg_batch[m].cpu().numpy().tolist(),
                                    vocabulary)
                        ]

                    t = time.time() - start
                    j += 1
                    print("[%d/%d] prediction done | time : %.2fs" %
                          (j, valid_loader.size // args.batch_size + 1, t))

                bleu_1 = corpus_bleu(ref, pred, weights=(1.,)) * 100
                bleu_2 = corpus_bleu(ref, pred, weights=(1. / 2., 1. / 2.)) * 100
                bleu_3 = corpus_bleu(ref, pred, weights=(1. / 3., 1. / 3., 1. / 3.)) * 100
                bleu_4 = corpus_bleu(ref, pred, weights=(1. / 4., 1. / 4., 1. / 4., 1. / 4.)) * 100

                print(f'BLEU-1: {bleu_1:.2f}')
                print(f'BLEU-2: {bleu_2:.2f}')
                print(f'BLEU-3: {bleu_3:.2f}')
                print(f'BLEU-4: {bleu_4:.2f}')

                if bleu_1 > best_bleu:
                    num_of_epochs_since_improvement = 0

                    best_bleu = bleu_1
                    print('Best BLEU-1 has been updated : %.2f' % (best_bleu))
                    save_checkpoint(decoder, 'checkpoints/best')
                else:
                    num_of_epochs_since_improvement += validation_term
                    print(
                        "No improvement in BLEU score for %d epochs"
                        % (num_of_epochs_since_improvement))

            ################################################################################################################################################################
        print('End of the training')
    else:
        if os.path.exists(args.checkpoint):
            decoder_checkpoint = torch.load(args.checkpoint)
            decoder.load_state_dict(decoder_checkpoint['state_dict'])
            print("trained decoder " + args.checkpoint + " is loaded")

        decoder.eval()

        # test
        test_loader = get_test_data_loader(args.path, args.token_path,
                                           args.voca_path, args.batch_size,
                                           pad_idx)

        j = 0
        pred, ref = [], []
        for src_batch, trg_batch in test_loader:
            # predict pred_batch from src_batch with your model.
            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            start = time.time()
            batch_size = src_batch.size(0)

            src_batch = src_batch.to(device)  # [batch, 3, 244, 244]
            trg_batch = torch.tensor(trg_batch).to(device)  # [batch * 5, C]
            trg_batch = torch.split(trg_batch, 5)

            batches = []
            for k in range(batch_size):
                batches.append(trg_batch[k].unsqueeze(0))

            trg_batch = torch.cat(batches, dim=0)  # [batch, 5, C]

            max_length = trg_batch.size(-1)

            pred_batch = torch.zeros(batch_size, 1, dtype=int).to(
                device)  # [batch, 1] = [[0],[0],...,[0]]

            # eos_mask[i] = 1 means i-th sentence has eos
            eos_mask = torch.zeros(batch_size, dtype=int)

            a = encoder(src_batch)

            for _ in range(max_length):

                output, _ = decoder(a, pred_batch)  # [batch, _+1, vocab_size]

                # greedy search
                output = torch.argmax(F.softmax(output, dim=-1),
                                      dim=-1)  # [batch_size, _+1]
                predictions = output[:, -1].unsqueeze(1)
                pred_batch = torch.cat([pred_batch, predictions], dim=-1)

                for l in range(batch_size):
                    if predictions[l] == eos_idx:
                        eos_mask[l] = 1

                # every sentence has eos
                if eos_mask.sum() == batch_size:
                    break

            # flush the GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            pred += seq2sen(pred_batch.cpu().numpy().tolist(), vocabulary)
            for m in range(batch_size):
                ref += [
                    seq2sen(trg_batch[m].cpu().numpy().tolist(), vocabulary)
                ]

            t = time.time() - start
            j += 1
            print("[%d/%d] prediction done | time : %.2fs" %
                  (j, test_loader.size // args.batch_size + 1, t))

        bleu_1 = corpus_bleu(ref, pred, weights=(1.,)) * 100
        bleu_2 = corpus_bleu(ref, pred, weights=(1. / 2., 1. / 2.)) * 100
        bleu_3 = corpus_bleu(ref, pred, weights=(1. / 3., 1. / 3., 1. / 3.)) * 100
        bleu_4 = corpus_bleu(ref, pred, weights=(1. / 4., 1. / 4., 1. / 4., 1. / 4.)) * 100

        print(f'BLEU-1: {bleu_1:.2f}')
        print(f'BLEU-2: {bleu_2:.2f}')
        print(f'BLEU-3: {bleu_3:.2f}')
        print(f'BLEU-4: {bleu_4:.2f}')

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        with open('results/ref.txt', 'w') as f:
            for lines in ref:
                for line in lines:
                    f.write('{}\n'.format(line))
                f.write('_' * 50 + '\n')
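The greedy-search loop in Example #6 appears twice, once for validation and once for test. Below is a sketch of the same logic factored into a single helper, assuming the decoder interface used above (`decoder(a, pred_batch)` returning `(logits, alphas)`); taking the argmax of the raw logits is equivalent to the softmax-then-argmax used in the example.

# Greedy caption decoding with an <eos> mask, factored out of Example #6 (a sketch).
import torch

def greedy_decode(decoder, a, batch_size, max_length, device,
                  sos_idx=0, eos_idx=1):
    pred_batch = torch.full((batch_size, 1), sos_idx, dtype=torch.long,
                            device=device)        # start every caption with <sos>
    finished = torch.zeros(batch_size, dtype=torch.bool)
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = decoder(a, pred_batch)    # (batch, cur_len, vocab_size)
            next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
            pred_batch = torch.cat([pred_batch, next_tok], dim=1)
            finished |= next_tok.squeeze(1).cpu().eq(eos_idx)
            if finished.all():                    # every caption produced <eos>
                break
    return pred_batch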
Example #7
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # TODO: use this information.
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    # TODO: use these values to construct embedding layers
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        # TODO: train
        for epoch in range(args.epochs):
            for src_batch, tgt_batch in train_loader:
                pass

            # TODO: validation
            for src_batch, tgt_batch in valid_loader:
                pass
    else:
        # test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        for src_batch, tgt_batch in test_loader:
            # TODO: predict pred_batch from src_batch with your model.
            pred_batch = tgt_batch

            # every sentence in pred_batch should start with the <sos> token (index: 0) and end with the <eos> token (index: 1).
            # every <pad> token (index: 2) should come after the <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
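Example #7 is the blank assignment template; its comment block defines the expected `pred_batch` convention. Below is a small illustrative checker for that convention (not part of the template itself), using the example matrix from the comments.

# Check the <sos>/<eos>/<pad> convention described in the template comments.
def check_pred_batch(pred_batch, sos_idx=0, eos_idx=1, pad_idx=2):
    for sent in pred_batch:
        assert sent[0] == sos_idx, "sentence must start with <sos>"
        assert eos_idx in sent, "sentence must contain <eos>"
        eos_pos = sent.index(eos_idx)
        assert all(tok == pad_idx for tok in sent[eos_pos + 1:]), \
            "only <pad> may follow <eos>"

check_pred_batch([[0, 5, 6, 7, 1],
                  [0, 4, 9, 1, 2],
                  [0, 6, 1, 2, 2]])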