Beispiel #1
0
def predict_tree(img, model, device, word_dict, max_child=6):
    """Greedily decode a tree for ``img``, predicting one node per model call.

    Nodes are expanded breadth-first: the node at the front of the queue
    keeps receiving predicted children until the model predicts the 'end'
    word or the node has collected ``max_child`` children, at which point an
    'end' child is appended and the node is finished.

    Args:
        img: Image input forwarded unchanged to ``model``.
        model: Callable invoked as ``model(img, [root])``; its flattened
            output is compared element-wise against tensors shaped like the
            embedded root value.  It is moved to ``device`` before use.
        device: Torch device on which the model and node values are placed.
        word_dict: Mapping handed to ``WordEmbedding`` and ``Vec2Word``.
        max_child: Number of children after which a node is closed with an
            'end' child.

    Returns:
        The result of ``Vec2Word(word_dict)`` applied to the decoded tree.
    """
    tranf = transforms.Compose([WordEmbedding(word_dict), TreeToTensor()])
    end_value = tranf(Tree('end')).value.to(device)
    model.to(device)

    root = tranf(Tree('root'))
    root.value = root.value.to(device)
    out_size = root.value.size()

    # Loop-invariant templates used to binarise the prediction.
    ones = torch.ones(out_size).to(device)
    zeros = torch.zeros(out_size).to(device)

    queue = [root]
    while queue:
        current = queue[0]

        # Use the model passed by the caller (the previous code ignored the
        # argument and called a module-level global instead).  The whole
        # partially built tree is fed back in on every step.
        sub_tree = Tree(model(img, [root]).flatten().detach())
        max_value = torch.max(sub_tree.value)
        # Every position equal to the maximum becomes 1, all others 0.
        sub_tree.value = torch.where(sub_tree.value >= max_value, ones, zeros)
        current.add_child(sub_tree)

        # NOTE(review): when the cap is hit, the child predicted just above
        # stays in the tree but is never queued for expansion, because
        # ``sub_tree`` is rebound to the forced 'end' node.  Kept as-is.
        if len(current.children) >= max_child:
            sub_tree = Tree(end_value.clone().detach())
            current.add_child(sub_tree)

        if torch.equal(end_value, sub_tree.value):
            # The current node is complete; move on to the next queued node.
            queue.pop(0)
        else:
            queue.append(sub_tree)

    # Convert every node value to a NumPy array before mapping back to words.
    root.for_each_value(lambda x: x.cpu().numpy())
    vec2word = Vec2Word(word_dict)
    root = vec2word(root)
    return root
Beispiel #2
0
def predict_tree_with_rule(img, model, device, word_dict, env):
    """Decode a tree for ``img`` while restricting each step to legal words.

    ``env`` drives the decoding: it exposes the current tree, the node being
    expanded and the list of words allowed next.  At every step the model's
    scores are restricted to the allowed words and the best one is sent back
    to ``env`` as the action.  Decoding stops when ``env`` reports ``None``
    as the node to expand.

    Args:
        img: Image input forwarded unchanged to ``model``.
        model: Callable invoked as ``model(img, [tree])``; its flattened
            output is indexed like the vectors in ``word_dict``.  It is
            moved to ``device`` before use.
        device: Torch device used for the model and tensors.
        word_dict: Mapping from word to its vector (assumed one-hot, since
            the vector is reduced to a single index below -- TODO confirm).
        env: Object providing ``reset()``, ``state()`` and ``step(action)``;
            the latter two return ``(root, parent, choices)``.

    Returns:
        The tree returned by the last ``env.state()`` / ``env.step()`` call.
    """
    env.reset()

    # Map "index of the hot entry" -> word, so a predicted index can be
    # turned back into an action for the environment.
    vec_dict = dict()
    for k, v in word_dict.items():
        vec_dict[np.sum(np.multiply(v, np.arange(v.size)))] = k

    tranf = transforms.Compose([WordEmbedding(word_dict), TreeToTensor()])
    model.to(device)

    root, parent, choices = env.state()
    while parent is not None:
        # Positions of the words the environment currently allows.
        mask = np.sum([word_dict[c] for c in choices], axis=0)
        allowed = torch.from_numpy(mask).to(device) > 0

        # Embed a copy so the environment's own tree is left untouched.
        copy_root = root.copy()
        root_tensor = tranf(copy_root).for_each_value(lambda x: x.to(device))

        # Use the model passed by the caller (the previous code ignored the
        # argument and called a module-level global instead).
        pred_node = model(img, [root_tensor]).flatten().detach()

        # Filter out disallowed words.  Filling with -inf (rather than
        # multiplying by the mask) keeps the filter correct when all allowed
        # scores are <= 0, e.g. log-probabilities; argmax then yields exactly
        # one index even if several scores tie.
        pred_node = pred_node.masked_fill(~allowed, float('-inf'))
        best_index = int(torch.argmax(pred_node))

        action = vec_dict[float(best_index)]

        root, parent, choices = env.step(action)
    return root
Beispiel #3
0
def samples_preprocessing(max_len=100):
    """Turn the training trees into fixed-length integer sequences.

    The word dictionary is loaded from ``word_dict.npy`` when that file
    exists, otherwise it is built from the full dataset and saved there.
    The first 80% of the dataset is then converted: each tree is serialised
    with ``Vec2Word(...)(tree).seq()``, every token is parsed as an integer
    and shifted up by one, and the sequence is right-padded with zeros
    (presumably so that 0 can serve as the padding id -- confirm with the
    consumer of this data).

    Args:
        max_len: Length of every returned sequence.  Trees whose sequence is
            longer than this are dropped.

    Returns:
        A list of lists of ints, each exactly ``max_len`` long.
    """
    def count_word_dict(dataset):
        """Return a dict mapping each distinct tree value to a one-hot array."""
        # 'root' and 'end' are registered up front, so they take the first
        # two one-hot positions even if they never occur in the trees.
        word_count = {'root': 0, 'end': 0}

        def count_tree(tree, word_count):
            # Children are visited before the node itself; this order
            # determines the one-hot position of newly seen values.
            for child in tree.children:
                count_tree(child, word_count)
            word_count[tree.value] = word_count.get(tree.value, 0) + 1

        for idx in range(len(dataset)):
            count_tree(dataset[idx]['tree'], word_count)

        word_dict = {}
        for position, key in enumerate(word_count):
            one_hot = np.zeros(len(word_count))
            one_hot[position] = 1.0
            word_dict[key] = one_hot
        return word_dict

    dataset = Pix2TreeDataset()
    if not os.path.exists('word_dict.npy'):
        word_dict = count_word_dict(dataset)
        np.save('word_dict.npy', word_dict)
    else:
        # The cache holds a pickled dict, hence allow_pickle=True.
        # NOTE: unpickling executes arbitrary code -- only load a cache file
        # that this project wrote itself.
        word_dict = np.load('word_dict.npy', allow_pickle=True).item()

    # prepare dataset
    sample_dataset = Pix2TreeDataset(
        partition=range(int(len(dataset) * 0.8)),
        tree_transform=transforms.Compose(
            [WordEmbedding(word_dict),
             TreeToTensor()]),
        img_transform=transforms.Compose([Rescale(224),
                                          transforms.ToTensor()]))

    Vec2Word_t = Vec2Word(word_dict)

    sample_data = []
    for idx in range(len(sample_dataset)):
        seq = Vec2Word_t(sample_dataset[idx]['tree']).seq()
        # Parse the space-separated tokens and shift them up by one.
        tokens = [int(token) + 1 for token in seq.split(' ')]
        if len(tokens) > max_len:
            continue
        # Right-pad with zeros up to the fixed length.
        tokens.extend([0] * (max_len - len(tokens)))
        sample_data.append(tokens)

    return sample_data
Beispiel #4
0
    dataset = Pix2TreeDataset(img_dir=dataset_img_dir,
                              tree_dir=dataset_tree_dir)
    if not os.path.exists('word_dict.npy'):
        word_dict = count_word_dict(dataset)
        np.save('word_dict.npy', word_dict)
    else:
        word_dict = np.load('word_dict.npy', allow_pickle=True).item()

    # prepare dataset
    train_data = Pix2TreeDataset(
        img_dir=dataset_img_dir,
        tree_dir=dataset_tree_dir,
        partition=range(int(len(dataset) * 0.8)),
        tree_transform=transforms.Compose(
            [WordEmbedding(word_dict),
             TreeToTensor()]),
        img_transform=transforms.Compose([Rescale(224),
                                          transforms.ToTensor()]))

    valid_data = Pix2TreeDataset(img_dir=dataset_img_dir,
                                 tree_dir=dataset_tree_dir,
                                 partition=range(int(len(dataset) * 0.8),
                                                 len(dataset)),
                                 img_transform=transforms.Compose(
                                     [Rescale(224),
                                      transforms.ToTensor()]))
    #import matplotlib.pyplot as plt
    #print(dataset[0]['tree'])
    #plt.imshow(dataset[0]['img'])

    # model