def collate_fn(batch):
    """Collate (tokens, image-ids) samples into padded batch tensors with masks.

    Each sample is ``((_text, _ids), _y)`` where ``_text`` is a list of
    per-sentence token lists, ``_ids`` is a list of image ids whose R-CNN
    region features are loaded from disk, and ``_y`` is the label.

    Returns:
        ``((padded_text, lens, text_key_padding_mask, padded_x,
        img_key_padding_mask), y)`` — text padded to ``config["max_tokens"]``,
        region features zero-padded to the batch's max box count, and boolean
        masks where ``True`` marks padding positions (the convention expected
        by ``torch.nn.Transformer``'s ``key_padding_mask`` arguments).
    """
    text, x, bbox_num, y = [], [], [], []
    n = len(batch)
    for (_text, _ids), _y in batch:
        # Flatten the per-sentence token lists into one token sequence.
        tokens = []
        for sent in _text:
            tokens += sent
        text.append(tokens)
        # Stack the R-CNN region features of every image in this sample
        # into a single (total_boxes, img_input_size) array.
        feats = [load_rcnn_features(_id)["x"] for _id in _ids]
        x.append(np.vstack(feats))
        bbox_num.append(len(x[-1]))  # number of bounding boxes in this sample
        y.append(_y)
    padded_text, lens = padding_text_1d(text, config["max_tokens"])
    max_lens = max(lens)
    max_bbox = max(bbox_num)
    padded_x = torch.zeros(
        (n, max_bbox, config["img_input_size"]), dtype=torch.float32
    )
    # Start with everything masked (True = padding), then unmask real positions.
    img_key_padding_mask = torch.ones((n, max_bbox), dtype=torch.bool)
    text_key_padding_mask = torch.ones((n, max_lens), dtype=torch.bool)
    for i, (l, s) in enumerate(zip(lens, bbox_num)):
        padded_x[i, :s] = torch.from_numpy(x[i])
        img_key_padding_mask[i, :s] = False
        text_key_padding_mask[i, :l] = False
    # NOTE: the original also built torch.LongTensor(bbox_num) here, but the
    # result was never used or returned — dead code, removed.
    y = torch.tensor(y)
    return (
        padded_text,
        lens,
        text_key_padding_mask,
        padded_x,
        img_key_padding_mask,
    ), y
def collate_fn(batch):
    """Collate (text, image-id) samples: pad the text, stack VGG features.

    Each sample is ``((_text, _id), _y)``; the VGG feature vector for
    ``_id`` is loaded via ``load_vgg``. Returns
    ``((padded_text, lens, imgs), y)`` with ``imgs`` and ``y`` as tensors.
    """
    texts, images, labels = [], [], []
    for (sample_text, sample_id), label in batch:
        texts.append(sample_text)
        images.append(load_vgg(sample_id))
        labels.append(label)
    padded_text, lens = padding_text_1d(texts)
    # Convert the stacked features and labels to tensors.
    images = torch.tensor(images)
    labels = torch.tensor(labels)
    return (padded_text, lens, images), labels
def collate_fn(batch):
    """Collate (tokens, image list) samples: pad the text, stack VGG features.

    Each sample is ``((_text, _imgs), _y)`` where ``_text`` is a list of
    per-sentence token lists and ``_imgs`` is a list of images whose VGG
    features are loaded via ``load_vgg_features`` (images with missing
    features are skipped). Returns ``((padded_text, lens, imgs), y)``.

    NOTE(review): because samples with missing features contribute fewer
    rows to ``imgs`` while still contributing a label to ``y``, the image
    rows may not align one-to-one with labels — confirm downstream usage.
    """
    # The original declared `global config`, but config is only read here,
    # so no global statement is needed.
    text, imgs, y = [], [], []
    for (_text, _imgs), _y in batch:
        # Flatten the per-sentence token lists into one token sequence.
        tokens = []
        for sent in _text:
            tokens += sent
        text.append(tokens)
        for img in _imgs:
            feature = load_vgg_features(img)
            if feature is not None:  # skip images whose features are missing
                imgs.append(feature)
        y.append(_y)
    padded_text, lens = padding_text_1d(text, config["max_tokens"])
    # Convert the stacked features and labels to tensors.
    imgs = torch.tensor(imgs)
    y = torch.tensor(y)
    return (padded_text, lens, imgs), y
def collate_fn(batch):
    """Collate (text, image-id) samples with variable-size R-CNN features.

    Each sample is ``((_text, _id), _y)``; ``load_rcnn(_id)["x"]`` yields a
    per-image array of region-box features (variable box count). Boxes are
    zero-padded to the batch maximum.

    Returns:
        ``((padded_text, lens, padded_x), y)`` where ``padded_x`` has shape
        ``(batch, max_bbox, config["img_input_size"])``.
    """
    text, x, bbox_num, y = [], [], [], []
    n = len(batch)
    for (_text, _id), _y in batch:
        text.append(_text)
        feature = load_rcnn(_id)["x"]
        x.append(feature)
        bbox_num.append(len(feature))  # number of region boxes in this sample
        y.append(_y)
    padded_text, lens = padding_text_1d(text)
    max_bbox = max(bbox_num)
    padded_x = torch.zeros(
        (n, max_bbox, config["img_input_size"]), dtype=torch.float32
    )
    # Copy each sample's boxes into the zero-padded batch tensor.
    # (The original also zipped `lens` here but never used it, and computed
    # max(lens) and torch.LongTensor(bbox_num) whose results were never
    # used or returned — dead code, removed.)
    for i, boxes in enumerate(bbox_num):
        padded_x[i, :boxes] = torch.from_numpy(x[i])
    y = torch.tensor(y)
    return (padded_text, lens, padded_x), y