class Flickr8k(VisionDataset):
    """Image/caption dataset driven by a serialized annotation mapping.

    The annotation file is read with ``torch.load`` and used as a mapping
    whose keys are handed directly to ``Image.open`` and whose values are the
    captions of that image. A ``Vocab`` is built from every caption found in
    the mapping.
    """

    def __init__(
        self,
        args,
        root: str,
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Load the annotations and build the caption vocabulary.

        Args:
            args: Not referenced by this constructor.
            root: Forwarded to the ``VisionDataset`` constructor.
            ann_file: Path of the annotation file read with ``torch.load``.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each caption set.
        """
        super().__init__(root,
                         transform=transform,
                         target_transform=target_transform)

        self.annotations = torch.load(ann_file, map_location='cpu')
        # Sorted keys give a stable index -> image mapping.
        self.ids = sorted(self.annotations.keys())

        # Gather the captions of every image into one flat list.
        all_captions = []
        for captions in self.annotations.values():
            all_captions.extend(captions)

        self.vocab = Vocab(init_token='<sos>',
                           eos_token='<eos>',
                           pad_token='<pad>',
                           unk_token='<unk>')
        self.vocab.build_vocab(all_captions)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, captions)`` pair stored at ``index``."""
        key = self.ids[index]

        # The annotation key itself is the path that gets opened.
        image = Image.open(key).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[key]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        """Return the number of annotated images."""
        return len(self.ids)

    def get_vocab(self):
        """Return the vocabulary built from the captions."""
        return self.vocab
 def __init__(
     self,
     args,
     root: str,
     ann_file: str,
     transform: Optional[Callable] = None,
     target_transform: Optional[Callable] = None,
 ) -> None:
     """Load the annotations and build the caption vocabulary.

     Args:
         args: Not referenced by this constructor.
         root: Forwarded to the parent constructor.
         ann_file: Path of the annotation file read with ``torch.load``.
         transform: Optional callable forwarded to the parent constructor.
         target_transform: Optional callable forwarded to the parent
             constructor.
     """
     super().__init__(root,
                      transform=transform,
                      target_transform=target_transform)

     self.annotations = torch.load(ann_file, map_location='cpu')
     # Sorted keys give a stable index -> image mapping.
     self.ids = sorted(self.annotations.keys())

     # Gather the captions of every image into one flat list.
     all_captions = []
     for captions in self.annotations.values():
         all_captions.extend(captions)

     self.vocab = Vocab(init_token='<sos>',
                        eos_token='<eos>',
                        pad_token='<pad>',
                        unk_token='<unk>')
     self.vocab.build_vocab(all_captions)
Exemple #3
0
def main(args):
    """Train or evaluate a Transformer translation model.

    Reads ``args.name``, ``args.path``, ``args.load``, ``args.ckpt``,
    ``args.test``, ``args.batch_size``, ``args.epochs`` and
    ``args.save_period``. The logger, the device and the model
    hyper-parameters are stored back onto ``args`` before the model is built.

    In training mode each epoch is followed by a validation pass, and the
    model state plus both meters are saved every ``args.save_period`` epochs.
    In test mode the test set is decoded greedily, the sentences are written
    to ``results/pred.txt`` and the BLEU script is run on that file.
    """

    # 0. initial setting

    # set environment
    cudnn.benchmark = True

    # makedirs creates missing parents and tolerates existing directories,
    # which replaces the isdir/mkdir chain without the check-then-create race.
    os.makedirs(os.path.join('./ckpt', args.name), exist_ok=True)
    os.makedirs(os.path.join('./results', args.name, "log"), exist_ok=True)

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data

    args.logger.info("loading data...")
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup

    args.logger.info("setting up...")

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # transformer config
    d_e = 512  # embedding size
    d_q = 64  # query size (= key, value size)
    d_h = 2048  # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of encoder/decoder layers in encoder/decoder

    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers

    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test

    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)",
                                     save_all=True,
                                     save_dir=os.path.join(
                                         'results', args.name))
        train_loss_meter = AverageMeter(name="Loss",
                                        save_all=True,
                                        save_dir=os.path.join(
                                            'results', args.name))
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch = torch.LongTensor(src_batch).to(args.device)
                tgt_batch = torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                # both tensors already live on args.device
                pred = model(src_batch, tgt_batch_i)
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()

                # Record a plain float so the meter does not keep references
                # to the autograd graph of every batch.
                train_loss_tmp_meter.update(loss.item() / batch, weight=batch)

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            # validation
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()

            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(
                    src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                with torch.no_grad():
                    pred = model(src_batch.to(args.device),
                                 tgt_batch_i.to(args.device))

                corrects, total = val_check(
                    pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                acc_val_tmp_meter.update(100 * corrects / total, total)

            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] validation accuracy: {:.1f} %, took {} seconds".format(
                    epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)

            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test
        args.logger.info("starting test")
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()

        for src_batch, tgt_batch in test_loader:
            # src_batch: (batch x source_length)
            src_batch = torch.LongTensor(src_batch).to(args.device)
            batch = src_batch.shape[0]
            # every hypothesis starts with <sos>
            pred_batch = torch.full((batch, 1),
                                    sos_idx,
                                    dtype=torch.long,
                                    device=args.device)
            # marks the sentences that have already ended
            pred_mask = torch.zeros(batch,
                                    1,
                                    dtype=torch.bool,
                                    device=args.device)

            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(
                        src_batch,
                        pred_batch)  # (batch x length x tgt_vocab_size)
                    # -inf (not -1) guarantees <pad> can never win the argmax,
                    # whatever the range of the other scores is.
                    pred[:, :, pad_idx] = float('-inf')
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(
                        -1)  # next word prediction: (batch x 1)
                    pred = pred.masked_fill(
                        pred_mask,
                        pad_idx).long()  # fill out <pad> for ended sentences
                    pred_mask = pred.eq(eos_idx) | pred.eq(pad_idx)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if pred_mask.all():
                        break

            # close all sentences: <eos> for the ones still open, <pad> for
            # the ones that already ended
            closing = torch.full_like(pred_mask, eos_idx,
                                      dtype=torch.long).masked_fill(
                                          pred_mask, pad_idx)
            pred_batch = torch.cat([pred_batch, closing], dim=1)
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
Exemple #4
0
def main(args):
    """Train or evaluate an encoder-decoder translation model.

    Reads ``args.path``, ``args.test``, ``args.batch_size``, ``args.epochs``,
    ``args.model_dir`` and (in test mode) ``args.model_name``.

    Training keeps a snapshot of the model/optimizer state with the lowest
    summed validation loss, writes it to ``<model_dir>/intermediate.pt``
    every 100 epochs and to ``<model_dir>/best.pt`` at the end. Test mode
    loads a checkpoint, decodes the test set greedily for ``max_length - 1``
    steps, writes the sentences to ``results/pred.txt`` and runs the BLEU
    script on that file.
    """
    # Local import: only needed to snapshot the best state dictionaries.
    import copy

    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Set hyper parameter
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = make_model(src_vocab_size, tgt_vocab_size).to(device)
    optimizer = get_std_opt(model)
    criterion = LabelSmoothing(size=tgt_vocab_size,
                               padding_idx=pad_idx,
                               smoothing=0.1)
    train_criterion = SimpleLossCompute(model.generator, criterion, optimizer)
    valid_criterion = SimpleLossCompute(model.generator, criterion, None)
    print('Using device:', device)

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        best_loss = 987654321
        # state_dict() hands back references to the live tensors, so the
        # snapshots must be deep-copied or later training steps overwrite
        # the "best" weights. Initialising them here also keeps the final
        # save valid when no epoch runs or none improves on best_loss.
        best_model_state = copy.deepcopy(model.state_dict())
        best_optimizer_state = copy.deepcopy(optimizer.optimizer.state_dict())

        for epoch in range(args.epochs):
            train_total_loss, valid_total_loss = 0.0, 0.0
            start = time.time()
            tokens = 0  # tokens seen this epoch (train + valid)

            model.train()
            # Train
            for src_batch, tgt_batch in train_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)

                prediction = model(batch.src, batch.trg, batch.src_mask,
                                   batch.trg_mask)
                loss = train_criterion(prediction, batch.trg_y, batch.ntokens)

                train_total_loss += loss
                tokens += batch.ntokens

            # Valid
            # NOTE(review): not wrapped in torch.no_grad() because the loss
            # helper may call backward() internally - confirm against
            # SimpleLossCompute before changing.
            model.eval()
            for src_batch, tgt_batch in valid_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)

                prediction = model(batch.src, batch.trg, batch.src_mask,
                                   batch.trg_mask)
                loss = valid_criterion(prediction, batch.trg_y, batch.ntokens)
                valid_total_loss += loss
                tokens += batch.ntokens

            if valid_total_loss.item() < best_loss:
                best_loss = valid_total_loss
                best_model_state = copy.deepcopy(model.state_dict())
                best_optimizer_state = copy.deepcopy(
                    optimizer.optimizer.state_dict())

            elpsed = time.time() - start
            print(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + "|| [" +
                str(epoch) + "/" + str(args.epochs) + "], train_loss = " +
                str(train_total_loss.item()) + ", valid_loss = " +
                str(valid_total_loss.item()) + ", Tokens per Sec = " +
                str(tokens.item() / elpsed))

            if epoch % 100 == 0:
                # Save model
                # NOTE(review): 'epoch' stores the configured epoch count,
                # not the epoch the snapshot was taken at - kept as-is.
                torch.save(
                    {
                        'epoch': args.epochs,
                        'model_state_dict': best_model_state,
                        'optimizer_state': best_optimizer_state,
                        'loss': best_loss
                    }, args.model_dir + "/intermediate.pt")
                print("Model saved")

        # Save model
        torch.save(
            {
                'epoch': args.epochs,
                'model_state_dict': best_model_state,
                'optimizer_state': best_optimizer_state,
                'loss': best_loss
            }, args.model_dir + "/best.pt")
        print("Model saved")
    else:
        # Load the model
        checkpoint = torch.load(args.model_dir + "/" + args.model_name,
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state'])
        model.eval()
        print("Model loaded")

        # Test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        # <sos> and <pad> may never be generated; build the index tensor once.
        banned = torch.tensor([sos_idx, pad_idx]).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for src_batch, tgt_batch in test_loader:
                src_batch = torch.tensor(src_batch).to(device)
                tgt_batch = torch.tensor(tgt_batch).to(device)
                batch = Batch(src_batch, tgt_batch, pad_idx)
                n_sent = src_batch.size(0)

                # Get pred_batch
                memory = model.encode(batch.src, batch.src_mask)
                pred_batch = torch.ones(n_sent, 1)\
                                .fill_(sos_idx).type_as(batch.src.data).to(device)
                for _ in range(max_length - 1):
                    out = model.decode(
                        memory, batch.src_mask, pred_batch,
                        Batch.make_std_mask(pred_batch,
                                            pad_idx).type_as(batch.src.data))
                    prob = model.generator(out[:, -1])
                    prob.index_fill_(1, banned, -float('inf'))
                    _, next_word = torch.max(prob, dim=1)

                    pred_batch = torch.cat(
                        [pred_batch, next_word.unsqueeze(-1)], dim=1)

                # close every sentence with <eos>
                eos_col = torch.ones(n_sent, 1)\
                             .fill_(eos_idx).type_as(batch.src.data).to(device)
                pred_batch = torch.cat([pred_batch, eos_col], dim=1)

                # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
                # every <pad> token (index: 2) should be located after <eos> token (index: 1).
                # example of pred_batch:
                # [[0, 5, 6, 7, 1],
                #  [0, 4, 9, 1, 2],
                #  [0, 6, 1, 2, 2]]
                pred += seq2sen(pred_batch.tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
Exemple #5
0
def main(args):
    """Train or evaluate a separate Encoder/Decoder translation model.

    Reads ``args.path``, ``args.model_load``, ``args.test``,
    ``args.batch_size`` and ``args.epochs``.

    Training uses Adam with a warm-up learning-rate schedule that is updated
    after every step, evaluates on the validation set after each epoch and
    stores a checkpoint whenever the mean validation loss of the epoch does
    not exceed the best one seen so far. Test mode decodes the test set,
    writes the sentences to ``results/pred.txt`` and runs the BLEU script.
    """
    src, tgt = load_data(args.path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    N = 6
    dim = 512

    # MODEL Construction
    encoder = Encoder(N, dim, pad_idx, src_vocab_size, device).to(device)
    decoder = Decoder(N, dim, pad_idx, tgt_vocab_size, device).to(device)

    if args.model_load:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load("drive/My Drive/checkpoint/best.ckpt",
                          map_location=device)
        encoder.load_state_dict(ckpt["encoder"])
        decoder.load_state_dict(ckpt["decoder"])

    params = list(encoder.parameters()) + list(decoder.parameters())

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        warmup = 4000
        steps = 1

        def scheduled_lr(step):
            # warm-up schedule: linear increase, then inverse-sqrt decay
            return (dim**-0.5) * min(step**-0.5, step * (warmup**-1.5))

        optimizer = torch.optim.Adam(params,
                                     lr=scheduled_lr(steps),
                                     betas=(0.9, 0.98),
                                     eps=1e-09)

        train_losses = []  # one entry per training step
        val_losses = []  # one entry per epoch (mean validation loss)
        latest = 1e08  # best (lowest) validation loss seen so far

        start_epoch = 0

        if args.model_load:
            # NOTE(review): resumes at the stored epoch index and assumes 30
            # steps per epoch when restoring the step counter - confirm.
            start_epoch = ckpt["epoch"]
            optimizer.load_state_dict(ckpt["optim"])
            steps = start_epoch * 30

        for epoch in range(start_epoch, args.epochs):

            encoder.train()
            decoder.train()
            for src_batch, tgt_batch in train_loader:
                optimizer.zero_grad()
                tgt_batch = torch.LongTensor(tgt_batch)

                src_batch = torch.LongTensor(src_batch).to(device)
                gt = tgt_batch[:, 1:].to(device)
                tgt_batch = tgt_batch[:, :-1].to(device)

                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)

                # reshape (not view): the shifted slice is not contiguous
                # when it stays on the CPU, where view(-1) raises.
                gt = gt.reshape(-1)
                dec_output = dec_output.reshape(gt.size(0), -1)

                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                loss.backward()
                train_losses.append(loss.item())
                optimizer.step()

                steps += 1
                update_lr(optimizer, scheduled_lr(steps))

                if steps % 10 == 0:
                    print("loss : %f" % loss.item())

            encoder.eval()
            decoder.eval()
            epoch_val_losses = []
            # validation needs no gradients
            with torch.no_grad():
                for src_batch, tgt_batch in valid_loader:
                    src_batch = torch.LongTensor(src_batch).to(device)
                    tgt_batch = torch.LongTensor(tgt_batch)
                    gt = tgt_batch[:, 1:].to(device)
                    tgt_batch = tgt_batch[:, :-1].to(device)

                    enc_output, seq_mask = encoder(src_batch)
                    dec_output = decoder(tgt_batch, enc_output, seq_mask)

                    gt = gt.reshape(-1)
                    dec_output = dec_output.reshape(gt.size(0), -1)

                    loss = F.cross_entropy(dec_output,
                                           gt,
                                           ignore_index=pad_idx)
                    epoch_val_losses.append(loss.item())

            if epoch_val_losses:
                # Judge the epoch by its mean validation loss rather than by
                # whichever mini-batch happened to come last.
                epoch_val_loss = sum(epoch_val_losses) / len(epoch_val_losses)
                val_losses.append(epoch_val_loss)
                print("[EPOCH %d] Loss %f" % (epoch, epoch_val_loss))

                if epoch_val_loss <= latest:
                    checkpoint = {
                        'encoder': encoder.state_dict(),
                        'decoder': decoder.state_dict(),
                        'optim': optimizer.state_dict(),
                        'epoch': epoch,
                    }
                    torch.save(checkpoint,
                               "drive/My Drive/checkpoint/best.ckpt")
                    latest = epoch_val_loss

            if epoch % 20 == 0:
                plt.figure()
                plt.plot(val_losses)
                plt.xlabel("epoch")
                plt.ylabel("model loss")
                plt.show()

    else:
        # test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        # The checkpoint is only loaded above when args.model_load is set.
        encoder.eval()
        decoder.eval()

        pred = []
        # inference only: no autograd graph is needed
        with torch.no_grad():
            for src_batch, tgt_batch in test_loader:
                b_s = min(args.batch_size, len(src_batch))
                # decoder input starts with a single <sos> column
                tgt_batch = torch.full((b_s, 1),
                                       sos_idx,
                                       dtype=torch.long,
                                       device=device)
                src_batch = torch.LongTensor(src_batch).to(device)

                enc_output, seq_mask = encoder(src_batch)
                pred_batch = decoder(tgt_batch, enc_output, seq_mask)
                _, pred_batch = torch.max(pred_batch, 2)

                while not is_finished(pred_batch, max_length, eos_idx):
                    # feed <sos> plus everything predicted so far back in
                    next_input = torch.cat((tgt_batch, pred_batch.long()), 1)
                    pred_batch = decoder(next_input, enc_output, seq_mask)
                    _, pred_batch = torch.max(pred_batch, 2)
                # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
                # every <pad> token (index: 2) should be located after <eos> token (index: 1).
                # example of pred_batch:
                # [[0, 5, 6, 7, 1],
                #  [0, 4, 9, 1, 2],
                #  [0, 6, 1, 2, 2]]
                pred_batch = pred_batch.tolist()
                for line in pred_batch:
                    line[-1] = eos_idx  # force every sentence to be closed
                pred += seq2sen(pred_batch, tgt_vocab)

        # explicit encoding: predictions are German text with non-ASCII
        # characters and must not depend on the platform default
        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
Exemple #6
0
def main(args):
    """Train or evaluate a Transformer translation model.

    Reads ``args.path``, ``args.test``, ``args.batch_size``, ``args.epochs``
    and ``args.lr``. ``device`` is not defined in this function; it is
    expected to be available from the enclosing module.

    Training saves the whole model to ``data/ckpt/last_model`` after every
    epoch and to ``data/ckpt/best_model`` whenever the validation loss
    improves. Test mode loads ``data/ckpt/best_model``, decodes the test set,
    writes the sentences to ``data/results/pred.txt`` and runs the BLEU
    script on that file.
    """
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)

    if not args.test:

        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        # Make sure the checkpoint directory exists before an epoch of work
        # is spent and the save fails.
        os.makedirs('data/ckpt', exist_ok=True)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        # Start from +inf so the first epoch always produces a best model,
        # however large its validation loss is.
        best_valid_loss = float('inf')
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))
            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))
            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        # map_location lets a model saved on GPU be restored on a CPU-only
        # machine (and vice versa).
        net = torch.load('data/ckpt/best_model', map_location=device)
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        # inference only: no autograd graph is needed
        with torch.no_grad():
            for src_batch, tgt_batch in test_loader:
                source, src_mask = make_tensor(src_batch)
                source = source.to(device)
                src_mask = src_mask.to(device)
                res = net.decode(source, src_mask)
                pred_batch = res.tolist()
                # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
                # every <pad> token (index: 2) should be located after <eos> token (index: 1).
                # example of pred_batch:
                # [[0, 5, 6, 7, 1],
                #  [0, 4, 9, 1, 2],
                #  [0, 6, 1, 2, 2]]
                pred += seq2sen(pred_batch, tgt_vocab)

        os.makedirs('data/results', exist_ok=True)
        # explicit encoding: predictions are German text with non-ASCII
        # characters and must not depend on the platform default
        with open('data/results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
Exemple #7
0
def main(args):
    """Skeleton entry point for training/evaluating a translation model.

    Reads ``args.path``, ``args.test``, ``args.batch_size`` and
    ``args.epochs``. The training and validation loops are placeholders
    (``pass``), and test mode currently echoes the reference targets as
    predictions, writes them to ``results/pred.txt`` and runs the BLEU
    script on that file.
    """
    src, tgt = load_data(args.path)

    # Both vocabularies share the same special tokens; only the file differs.
    vocabs = {}
    for lang in ('en', 'de'):
        vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
        vocab.load(os.path.join(args.path, 'vocab.' + lang))
        vocabs[lang] = vocab
    src_vocab = vocabs['en']
    tgt_vocab = vocabs['de']

    # TODO: use these information.
    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    # TODO: use these values to construct embedding layers
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        # TODO: train
        for epoch in range(args.epochs):
            for src_batch, tgt_batch in train_loader:
                pass

            # TODO: validation
            for src_batch, tgt_batch in valid_loader:
                pass
    else:
        # test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        for src_batch, tgt_batch in test_loader:
            # TODO: predict pred_batch from src_batch with your model.
            pred_batch = tgt_batch

            # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
            # every <pad> token (index: 2) should be located after <eos> token (index: 1).
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)

        # explicit encoding: predictions are German text with non-ASCII
        # characters and must not depend on the platform default
        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')