"""Load a trained AlignmentModel checkpoint in eval mode and define batching helpers."""
import argparse
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the parent directory importable so the project-local ``model`` and
# ``cc_dataset`` modules resolve; this must run before those imports.
sys.path.append('../')

from model import AlignmentModel  # noqa: E402
from cc_dataset import CCDataset  # noqa: E402

device = "cuda:1"
model_path = "../checkpoints/2020_12_23_1/epch0_bidx6400.pt"

model = AlignmentModel().to(device)
# map_location restores the checkpoint tensors directly onto `device` instead
# of the device they were saved from (which may be a different/absent GPU).
model_data = torch.load(model_path, map_location=device)["model_state_dict"]
# This entry is dropped before the (strict) load_state_dict call — presumably
# a buffer the checkpoint has but the current model does not register
# (e.g. a library-version difference) — TODO confirm. The None default makes
# the drop a no-op when the checkpoint does not contain the key.
model_data.pop("text_encoder.model.embeddings.position_ids", None)
model.load_state_dict(model_data)
model.eval()


class Namespace:
    """Minimal attribute container exposing the keys of ``opts`` as attributes.

    Similar to ``argparse.Namespace(**opts)`` for attribute access.

    Args:
        opts: Mapping whose key/value pairs become instance attributes.
    """

    def __init__(self, opts):
        self.__dict__.update(opts)


def collate_fn(batch):
    """Collate dataset samples into batched tensors on ``device``.

    Args:
        batch: Iterable of samples, each indexable by ``"img_feat"`` and
            ``"text_feat"`` and holding tensors whose shapes match across
            samples (required by ``torch.stack``).

    The image features are stacked along a new dim 0 and then averaged over
    dim 1, i.e. over each sample's own leading dimension (presumably several
    feature vectors per image — TODO confirm). The text features are stacked
    unchanged. Both results are moved to ``device``.

    NOTE(review): no return statement is visible in this chunk; presumably
    ``(image_feats, text_feats)`` is returned after this — confirm.
    """
    image_feats = [sample["img_feat"] for sample in batch]
    text_feats = [sample["text_feat"] for sample in batch]
    image_feats = torch.mean(torch.stack(image_feats), dim=1).to(device)
    text_feats = torch.stack(text_feats).to(device)