def predict_tree(img, model, device, word_dict, max_child=6):
    """Greedily decode a tree from an image with a trained captioning model.

    Starting from a 'root' node, repeatedly run the model on the whole
    partial tree, take the argmax word (as a one-hot vector) and attach it
    as a child of the node at the front of the BFS queue.  A node is closed
    when the model emits the 'end' word or when it reaches `max_child`
    children (an 'end' child is then forced).

    Args:
        img: input image batch/tensor, forwarded to `model` unchanged.
        model: the image-caption model; called as ``model(img, [tree])``.
        device: torch device the tensors are moved to.
        word_dict: word -> one-hot numpy vector mapping.
        max_child: maximum children per node before termination is forced.

    Returns:
        The decoded tree with one-hot values mapped back to words.
    """
    tranf = transforms.Compose([WordEmbedding(word_dict), TreeToTensor()])
    end_value = tranf(Tree('end')).value.to(device)
    model.to(device)
    root = tranf(Tree('root'))
    root.value = root.value.to(device)
    out_size = root.value.size()
    queue = [root]
    while queue:
        # BUG FIX: the original called the module-global `image_caption_model`
        # here, silently ignoring the `model` argument passed by the caller.
        pred = model(img, [root]).flatten().detach()
        sub_tree = Tree(pred)
        # Turn the raw prediction into a one-hot vector at the argmax.
        max_value = torch.max(sub_tree.value)
        sub_tree.value = torch.where(sub_tree.value >= max_value,
                                     torch.ones(out_size).to(device),
                                     torch.zeros(out_size).to(device))
        queue[0].add_child(sub_tree)
        if len(queue[0].children) >= max_child:
            # Force termination of this node by appending an 'end' child.
            sub_tree = Tree(end_value.clone().detach())
            queue[0].add_child(sub_tree)
        if torch.equal(end_value, sub_tree.value):
            queue.pop(0)  # node finished: stop expanding it
        else:
            queue.append(sub_tree)  # expand the new child later
    # Back to numpy, then map one-hot vectors to their words.
    root.for_each_value(lambda x: x.cpu().numpy())
    vec2word = Vec2Word(word_dict)
    root = vec2word(root)
    return root
def predict_tree_with_rule(img, model, device, word_dict, env):
    """Decode a tree from an image, constrained by a grammar environment.

    At each step `env` exposes the partial tree, the node being expanded
    and the set of grammar-legal next words.  Illegal words are masked out
    of the model's prediction before taking the argmax, and the chosen
    word is fed back to the environment via `env.step`.

    Args:
        img: input image batch/tensor, forwarded to `model` unchanged.
        model: the image-caption model; called as ``model(img, [tree])``.
        device: torch device the tensors are moved to.
        word_dict: word -> one-hot numpy vector mapping.
        env: rule environment with ``reset()``, ``state()`` and ``step()``.

    Returns:
        The completed tree produced by the environment.
    """
    env.reset()
    # Invert word_dict: argmax index of the one-hot vector -> word.
    vec_dict = {}
    for word, vec in word_dict.items():
        vec_dict[np.sum(np.multiply(vec, np.arange(vec.size)))] = word
    tranf = transforms.Compose([WordEmbedding(word_dict), TreeToTensor()])
    model.to(device)
    root, parent, choices = env.state()
    out_size = len(word_dict)
    while parent is not None:
        # Mask of grammar-legal next words (sum of their one-hot vectors).
        mask = np.sum([word_dict[c] for c in choices], axis=0)
        copy_root = root.copy()
        root_tensor = tranf(copy_root).for_each_value(lambda x: x.to(device))
        # BUG FIX: the original called the module-global `image_caption_model`
        # here, silently ignoring the `model` argument passed by the caller.
        pred_node = model(img, [root_tensor]).flatten().detach()
        # Filter out illegal words, then one-hot the (masked) argmax.
        pred_node *= torch.from_numpy(mask).to(device).float()
        max_value = torch.max(pred_node)
        pred_node = torch.where(pred_node >= max_value,
                                torch.ones(out_size).to(device),
                                torch.zeros(out_size).to(device))
        action = vec_dict[np.sum(np.multiply(pred_node.cpu().numpy(),
                                             np.arange(out_size)))]
        root, parent, choices = env.step(action)
    return root
def samples_preprocessing():
    """Build fixed-length integer token sequences from the training split.

    Loads (or builds and caches to ``word_dict.npy``) the one-hot word
    dictionary, converts each training tree to a token sequence, shifts
    every token by +1 so that 0 can serve as padding, drops sequences
    longer than 100 tokens and zero-pads the rest to length 100.

    Returns:
        list[list[int]]: padded token sequences of length 100.
    """
    def count_word_dict(dataset):
        # Count every word appearing in the trees; 'root'/'end' always exist.
        word_count = {'root': 0, 'end': 0}

        def count_tree(tree, word_count):
            for child in tree.children:
                count_tree(child, word_count)
            if tree.value in word_count:
                word_count[tree.value] += 1
            else:
                word_count[tree.value] = 1

        for i in range(len(dataset)):
            count_tree(dataset[i]['tree'], word_count)
        # Assign each word a one-hot vector by enumeration order.
        word_dict = {}
        for idx, key in enumerate(word_count):
            vec = np.zeros(len(word_count))
            vec[idx] = 1.0
            word_dict[key] = vec
        return word_dict

    dataset = Pix2TreeDataset()
    # Cache the dictionary on disk; rebuild only when the file is missing.
    if not os.path.exists('word_dict.npy'):
        word_dict = count_word_dict(dataset)
        np.save('word_dict.npy', word_dict)
    else:
        word_dict = np.load('word_dict.npy', allow_pickle=True).item()

    # prepare dataset (first 80% of the samples)
    sample_dataset = Pix2TreeDataset(
        partition=range(int(len(dataset) * 0.8)),
        tree_transform=transforms.Compose(
            [WordEmbedding(word_dict), TreeToTensor()]),
        img_transform=transforms.Compose([Rescale(224), transforms.ToTensor()]))

    vec2word = Vec2Word(word_dict)
    sample_data = []
    for i in range(len(sample_dataset)):
        seq = vec2word(sample_dataset[i]['tree']).seq()
        # Shift tokens by +1 so index 0 is reserved for padding.
        seq = [int(tok) + 1 for tok in seq.split(' ')]
        if len(seq) > 100:
            continue  # drop over-long sequences
        # FIX: pad with extend() instead of an inner loop that shadowed
        # the outer loop variable `i`.
        seq.extend([0] * (100 - len(seq)))
        sample_data.append(seq)
    return sample_data
# ---- script setup: dataset, word dictionary, train/validation splits ----
dataset = Pix2TreeDataset(img_dir=dataset_img_dir, tree_dir=dataset_tree_dir)

# Cache the one-hot word dictionary on disk; rebuild only when missing.
# NOTE(review): this requires a module-level `count_word_dict` -- the copy
# visible inside samples_preprocessing() is local to that function; confirm
# one is defined elsewhere in this file.
if not os.path.exists('word_dict.npy'):
    word_dict = count_word_dict(dataset)
    np.save('word_dict.npy', word_dict)
else:
    word_dict = np.load('word_dict.npy', allow_pickle=True).item()

# prepare dataset: first 80% for training (trees converted to tensors);
# remaining 20% for validation.  NOTE(review): valid_data has no
# tree_transform, so its trees stay as raw word trees -- presumably
# intentional for evaluation; confirm against the consumer.
train_data = Pix2TreeDataset(
    img_dir=dataset_img_dir,
    tree_dir=dataset_tree_dir,
    partition=range(int(len(dataset) * 0.8)),
    tree_transform=transforms.Compose(
        [WordEmbedding(word_dict), TreeToTensor()]),
    img_transform=transforms.Compose([Rescale(224), transforms.ToTensor()]))
valid_data = Pix2TreeDataset(img_dir=dataset_img_dir,
                             tree_dir=dataset_tree_dir,
                             partition=range(int(len(dataset) * 0.8),
                                             len(dataset)),
                             img_transform=transforms.Compose(
                                 [Rescale(224), transforms.ToTensor()]))
#import matplotlib.pyplot as plt
#print(dataset[0]['tree'])
#plt.imshow(dataset[0]['img'])

# model