예제 #1
0
 def collate_fn(batch):
     """Collate ``((sentences, image_ids), label)`` samples into padded tensors.

     For every sample the sequences in ``sentences`` are concatenated into one
     flat token list, and the ``"x"`` arrays returned by
     ``load_rcnn_features`` for each image id are stacked row-wise with
     ``np.vstack``. Text is padded by ``padding_text_1d`` (limited by
     ``config["max_tokens"]``); image features are zero-padded to the largest
     row count in the batch.

     Args:
         batch: Sized iterable of ``((sentences, image_ids), label)`` items.

     Returns:
         A pair ``(inputs, y)`` where ``inputs`` is the tuple
         ``(padded_text, lens, text_key_padding_mask, padded_x,
         img_key_padding_mask)``:

         * ``padded_text``, ``lens`` -- the two values returned by
           ``padding_text_1d``.
         * ``text_key_padding_mask`` -- bool tensor ``(n, max(lens))``; ``True``
           marks padding positions.
         * ``padded_x`` -- float32 tensor
           ``(n, max_rows, config["img_input_size"])``, zero-padded.
         * ``img_key_padding_mask`` -- bool tensor ``(n, max_rows)``; ``True``
           marks padding rows.

         ``y`` is ``torch.tensor`` of the labels.
     """
     text, x, bbox_num, y = [], [], [], []
     n = len(batch)
     for (_text, _ids), _y in batch:
         # Flatten the nested sequences of one sample into a single token list.
         text.append([token for part in _text for token in part])
         # One stacked feature matrix per sample; its row count is what the
         # original code labelled the "bbox num".
         stacked = np.vstack([load_rcnn_features(_id)["x"] for _id in _ids])
         x.append(stacked)
         bbox_num.append(len(stacked))
         y.append(_y)

     padded_text, lens = padding_text_1d(text, config["max_tokens"])
     max_lens = max(lens)
     max_bbox = max(bbox_num)

     padded_x = torch.zeros((n, max_bbox, config["img_input_size"]),
                            dtype=torch.float32)
     # Masks start as all-True (everything is padding) and the valid prefix of
     # each row is cleared below.
     img_key_padding_mask = torch.ones((n, max_bbox), dtype=torch.bool)
     text_key_padding_mask = torch.ones((n, max_lens), dtype=torch.bool)
     for i, (text_len, num_rows) in enumerate(zip(lens, bbox_num)):
         padded_x[i, :num_rows] = torch.from_numpy(x[i])
         img_key_padding_mask[i, :num_rows] = False
         text_key_padding_mask[i, :text_len] = False

     y = torch.tensor(y)
     return (padded_text, lens, text_key_padding_mask, padded_x,
             img_key_padding_mask), y
예제 #2
0
 def collate_fn(batch):
     """Collate ``((text, image_id), label)`` samples into one batch.

     Each sample's text is handed unchanged to ``padding_text_1d``; each image
     id is resolved with ``load_vgg`` and the loaded values are packed into a
     single tensor with ``torch.tensor``.

     Args:
         batch: Iterable of ``((text, image_id), label)`` items.

     Returns:
         ``((padded_text, lens, imgs), y)`` where ``padded_text`` and ``lens``
         are the two values returned by ``padding_text_1d``, ``imgs`` is the
         tensor built from the loaded image values and ``y`` is the tensor of
         labels.
     """
     token_seqs = []
     vgg_feats = []
     labels = []
     for sample, label in batch:
         tokens, image_id = sample
         token_seqs.append(tokens)
         vgg_feats.append(load_vgg(image_id))
         labels.append(label)

     padded_text, lens = padding_text_1d(token_seqs)
     # Convert the collected image values, then the labels, to tensors.
     inputs = (padded_text, lens, torch.tensor(vgg_feats))
     return inputs, torch.tensor(labels)
예제 #3
0
파일: bigru_vgg.py 프로젝트: yang-233/mmsa
 def collate_fn(batch):
     """Collate ``((sentences, image_ids), label)`` samples into one batch.

     The sequences in ``sentences`` are concatenated into one flat token list
     per sample and padded by ``padding_text_1d`` (limited by the module-level
     ``config["max_tokens"]``). Every image id is resolved with
     ``load_vgg_features``; ids for which it returns ``None`` are skipped.

     Note that the image features of all samples go into one flat list, so
     the first dimension of ``imgs`` is the total number of loaded features in
     the batch, not the number of samples.

     Args:
         batch: Iterable of ``((sentences, image_ids), label)`` items.

     Returns:
         ``((padded_text, lens, imgs), y)`` where ``padded_text`` and ``lens``
         are the two values returned by ``padding_text_1d``, ``imgs`` is the
         tensor of loaded features and ``y`` is the tensor of labels.
     """
     token_seqs, vgg_feats, labels = [], [], []
     for sample, label in batch:
         sentences, image_ids = sample
         # Flatten the nested sequences of one sample into a single token list.
         token_seqs.append([token for part in sentences for token in part])
         # NOTE(review): the original author's note suggests each sample carries
         # a list of 4096-dimensional features -- TODO confirm.
         for image_id in image_ids:
             loaded = load_vgg_features(image_id)
             if loaded is None:
                 continue
             vgg_feats.append(loaded)
         labels.append(label)

     padded_text, lens = padding_text_1d(token_seqs, config["max_tokens"])
     # Convert the collected features, then the labels, to tensors.
     inputs = (padded_text, lens, torch.tensor(vgg_feats))
     return inputs, torch.tensor(labels)
예제 #4
0
 def collate_fn(batch):
     """Collate ``((text, image_id), label)`` samples, zero-padding features.

     Each sample's text is handed unchanged to ``padding_text_1d``. The ``"x"``
     array returned by ``load_rcnn`` for the sample's image id is copied into a
     zero-initialised float32 tensor whose second dimension is the largest
     row count in the batch.

     Args:
         batch: Sized iterable of ``((text, image_id), label)`` items.

     Returns:
         ``((padded_text, lens, padded_x), y)`` where ``padded_text`` and
         ``lens`` are the two values returned by ``padding_text_1d``,
         ``padded_x`` has shape ``(n, max_rows, config["img_input_size"])`` and
         ``y`` is ``torch.tensor`` of the labels.
     """
     text, x, y = [], [], []
     n = len(batch)
     for (_text, _id), _y in batch:
         text.append(_text)
         x.append(load_rcnn(_id)["x"])
         y.append(_y)

     padded_text, lens = padding_text_1d(text)

     # Row count per sample (labelled the "bbox num" in the original code).
     bbox_num = [len(feature) for feature in x]
     padded_x = torch.zeros((n, max(bbox_num), config["img_input_size"]),
                            dtype=torch.float32)
     for i, (feature, num_rows) in enumerate(zip(x, bbox_num)):
         # Rows past num_rows stay zero and act as padding.
         padded_x[i, :num_rows] = torch.from_numpy(feature)

     y = torch.tensor(y)
     return (padded_text, lens, padded_x), y