import os
import pickle as pkl

from babi_loader import BabiDataset


def process_babi(path):
    dset_dict, vocab_dict = {}, {}
    for task_id in range(1, 21):
        dset_dict[task_id] = BabiDataset(task_id)
        vocab_dict[task_id] = len(dset_dict[task_id].QA.VOCAB)

    dset_path = os.path.join(path, 'dset.pkl')
    with open(dset_path, 'wb') as f:
        pkl.dump(dset_dict, f)
    vocab_path = os.path.join(path, 'vocab.pkl')
    with open(vocab_path, 'wb') as f:
        pkl.dump(vocab_dict, f)
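
# For reference, a minimal loading counterpart (not part of the original snippet),
# assuming the same os / pickle-as-pkl imports and the directory passed to process_babi:
def load_babi(path):
    # Read back the per-task datasets and vocabulary sizes written by process_babi.
    with open(os.path.join(path, 'dset.pkl'), 'rb') as f:
        dset_dict = pkl.load(f)
    with open(os.path.join(path, 'vocab.pkl'), 'rb') as f:
        vocab_dict = pkl.load(f)
    return dset_dict, vocab_dict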
Example #2
    def __init__(self, task_id=1, seed=123, mode='train'):

        np.random.seed(seed)
        self.seed = seed
        self.i = 0
        self.num_correct = 0
        self.task_id = task_id
        self.last_action = None

        self.data = BabiDataset(task_id, mode)
        self.current_qa = self.data[self.i]
        self.vocab_size = len(self.data.QA.VOCAB)

        self.action_space = list(self.data.QA.VOCAB.values())  #  0 = no answer
        self.state_space = list(
            self.data.QA.VOCAB.values())  #  0 = start of QA

        self.state = self.current_qa[0][0][0]
        self.pos_idx = {'q': 0, 'sen': 0, 'word': 0}
Example #3
        if len(var.size()) == 3:
            for n, sentences in enumerate(var):
                for i, sentence in enumerate(sentences):
                    s = ' '.join([self.qa.IVOCAB[elem.data[0]] for elem in sentence])
                    print(str(n) + 'th batch, ' + str(i) + 'th sentence, ' + str(s))

        elif len(var.size()) == 2:
            for n, sentence in enumerate(var):
                s = ' '.join([self.qa.IVOCAB[elem.data[0]] for elem in sentence])
                print(str(n) + 'th batch, ' + str(s))

        elif len(var.size()) == 1:
            for n, token in enumerate(var):
                s = self.qa.IVOCAB[token.data[0]]
                print(str(n) + 'th of batch, ' + str(s))


if __name__ == "__main__":
    from babi_loader import BabiDataset, pad_collate
    from torch.utils.data import DataLoader
    
    dset = BabiDataset(dataset_dir="../bAbI/tasks_1-20_v1-2/en-10k/", task_id=2)
    vocab_size = len(dset.QA.VOCAB)
    dset.set_mode('train')
    loader = DataLoader(dset, batch_size=2, shuffle=False, collate_fn=pad_collate)

    s, q, a = next(iter(loader))
    s, q, a = s.long().cuda(), q.long().cuda(), a.long().cuda()

    dmn = DMNPlus(30, vocab_size).cuda()
    ans = dmn(s, q)
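
    # A hedged follow-up sketch (not in the original example): turning ans into word
    # predictions, assuming ans holds one row of logits over the vocabulary per example
    # and that dset.QA.IVOCAB maps indices back to tokens, as the print helpers above suggest.
    pred_idx = ans.argmax(dim=-1)  # predicted vocabulary index per example
    for i, idx in enumerate(pred_idx.tolist()):
        # Compare the predicted token with the gold answer token a[i].
        print(f'example {i}: predicted "{dset.QA.IVOCAB[idx]}", target "{dset.QA.IVOCAB[a[i].item()]}"')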
Example #4
parser.add_argument("--batch_size", default=100)
parser.add_argument("--max_hops", default=3)
parser.add_argument("--num_head", default=4)
parser.add_argument("--embed_size", default=30)
parser.add_argument("--hidden_size", default=60)
parser.add_argument("--dataset_dir", default="../bAbI/tasks_1-20_v1-2/en-10k/")
parser.add_argument("--task", type=int, default=1)
parser.add_argument("--random_state", default=2033)
parser.add_argument("--epochs", default=100)
parser.add_argument("--weight_decay", default=0.001)
args = parser.parse_args()

device = torch.device(
    "cuda:0" if args.use_cuda and torch.cuda.is_available() else "cpu")

dset = BabiDataset(dataset_dir=args.dataset_dir, task_id=args.task)
vocab_size = len(dset.QA.VOCAB)

torch.manual_seed(args.random_state)
model = WMN(vocab_size=vocab_size,
            embed_size=args.embed_size,
            hidden_size=args.hidden_size,
            max_hops=args.max_hops,
            num_head=args.num_head,
            seqend_idx=1,
            pad_idx=0).to(device)
criterion = nn.CrossEntropyLoss(reduction='sum')
optims = [
    torch.optim.AdamW(
        [p for name, p in model.named_parameters() if 'embedding' not in name],
        weight_decay=args.weight_decay),
Example #5
                print(f'{n}th of batch, {i}th sentence, {s}')
    elif len(var.size()) == 2:
        # var -> n x #token
        for n, sentence in enumerate(var):
            s = ' '.join([qa.IVOCAB[elem.data[0]] for elem in sentence])
            print(f'{n}th of batch, {s}')
    elif len(var.size()) == 1:
        # var -> n (one token per batch)
        for n, token in enumerate(var):
            s = qa.IVOCAB[token.data[0]]
            print(f'{n}th of batch, {s}')

if __name__ == '__main__':
    for run in range(10):
        for task_id in range(1, 2):
            dset = BabiDataset(task_id)
            vocab_size = len(dset.QA.VOCAB)
            hidden_size = 80

            model = DMNPlus(hidden_size, vocab_size, num_hop=3, qa=dset.QA)
            model.cuda()
            early_stopping_cnt = 0
            early_stopping_flag = False
            best_acc = 0
            optim = torch.optim.Adam(model.parameters())


            for epoch in range(256):
                dset.set_mode('train')
                train_loader = DataLoader(
                    dset, batch_size=1, shuffle=True, collate_fn=pad_collate
Example #6
from babi_loader import BabiDataset, pad_collate
import os
import torch
from torch.utils.data import DataLoader
from models import DMNTrans

if __name__ == '__main__':
    dset_dict, vocab_dict = {}, {}
    for task_id in range(1, 21):
        dset_dict[task_id] = BabiDataset(task_id)
        vocab_dict[task_id] = len(dset_dict[task_id].QA.VOCAB)
    for run in range(10):
        for task_id in range(1, 21):
            dset = dset_dict[task_id]
            vocab_size = vocab_dict[task_id]
            # dset = BabiDataset(task_id)
            # vocab_size = len(dset.QA.VOCAB)
            hidden_size = 80

            model = DMNTrans(hidden_size, vocab_size, num_hop=3, qa=dset.QA)
            model.cuda()
            early_stopping_cnt = 0
            early_stopping_flag = False
            best_acc = 0
            optim = torch.optim.Adam(model.parameters())

            for epoch in range(256):
                dset.set_mode('train')
                train_loader = DataLoader(dset,
                                          batch_size=100,
                                          shuffle=True,