Exemple #1
0
 def _init(self):
     """
     Hook for subclasses to perform model-specific initialization.

     The baseline implementation hands ``self.model`` to
     ``UtilsFactory.prepare_model`` and stores the returned pair back on
     the instance as ``self.model`` and ``self.device``.
     NOTE(review): ``prepare_model`` is presumably where device selection
     and model placement happen; confirm against ``UtilsFactory``.

     :return: None
     """
     prepared = UtilsFactory.prepare_model(self.model)
     self.model, self.device = prepared
def main(args):
    """
    Encode every image listed in a CSV and save the stacked features.

    Steps, in order:
      1. Publish the square image size as the module-global ``IMG_SIZE``
         (presumably consumed by ``dict_transformer``; confirm at its
         definition).
      2. Build a ``ResnetEncoder`` in eval mode and let
         ``UtilsFactory.prepare_model`` return the model/device pair.
      3. Read ``args.in_csv`` into a list of per-row dicts.
      4. Run the encoder over the loader with gradients disabled.
      5. Concatenate the per-batch arrays along axis 0 and write them to
         ``args.out_npy`` via ``np.save``.

    :param args: namespace providing ``img_size``, ``arch``, ``pooling``,
        ``in_csv``, ``img_col``, ``datapath``, ``batch_size``,
        ``n_workers``, ``verbose`` and ``out_npy``.
    :return: None
    """
    global IMG_SIZE

    side = args.img_size
    IMG_SIZE = (side, side)

    encoder = ResnetEncoder(arch=args.arch, pooling=args.pooling).eval()
    encoder, device = UtilsFactory.prepare_model(encoder)

    # Turn the CSV into a plain list of row dicts for the loader.
    frame = pd.read_csv(args.in_csv)
    frame = frame.reset_index()
    frame = frame.drop("index", axis=1)
    rows = list(frame.to_dict("index").values())

    reader = ImageReader(
        row_key=args.img_col,
        dict_key="image",
        datapath=args.datapath,
    )

    batches = UtilsFactory.create_loader(
        rows,
        reader,
        batch_size=args.batch_size,
        workers=args.n_workers,
        dict_transform=dict_transformer,
    )
    if args.verbose:
        # Wrap in a progress bar only when asked to.
        batches = tqdm(batches)

    with torch.no_grad():
        chunks = [
            encoder(batch["image"].to(device)).cpu().detach().numpy()
            for batch in batches
        ]

    stacked = np.concatenate(chunks, axis=0)
    np.save(args.out_npy, stacked)
def main(args):
    """
    Predict keywords for every image listed in a CSV and write them out.

    The architecture named by ``args.arch`` is looked up in
    ``models.__dict__`` and instantiated with ``pretrained=True``, switched
    to eval mode, and passed through ``UtilsFactory.prepare_model``. It is
    then wrapped in ``Images2Keywords`` together with the labels loaded
    from the JSON file at ``args.labels``. Each batch from the loader is
    run with gradients disabled and the per-batch keywords are collected
    in order. Finally the collected keywords are stored in a new column
    ``args.keywords_col`` of the input table, which is written to
    ``args.out_csv`` without the index.

    :param args: namespace providing ``arch``, ``labels``, ``n_keywords``,
        ``in_csv``, ``img_col``, ``datapath``, ``batch_size``,
        ``n_workers``, ``verbose``, ``keywords_col`` and ``out_csv``.
    :return: None
    :raises ValueError: propagated from pandas if the number of collected
        keywords does not match the number of CSV rows.
    """
    model = models.__dict__[args.arch](pretrained=True)
    model = model.eval()
    model, device = UtilsFactory.prepare_model(model)

    # Use a context manager so the labels file handle is closed right away
    # instead of being left for the garbage collector.
    with open(args.labels) as labels_file:
        labels = json.load(labels_file)

    i2k = Images2Keywords(model, args.n_keywords, labels)

    # Parse the CSV once. reset_index()/drop() return new frames, so
    # ``input_csv`` stays exactly as read and can be reused for the output.
    input_csv = pd.read_csv(args.in_csv)
    images_df = input_csv.reset_index().drop("index", axis=1)
    images_df = list(images_df.to_dict("index").values())

    open_fn = ImageReader(row_key=args.img_col,
                          dict_key="image",
                          datapath=args.datapath)

    dataloader = UtilsFactory.create_loader(images_df,
                                            open_fn,
                                            batch_size=args.batch_size,
                                            workers=args.n_workers,
                                            dict_transform=dict_transformer)

    keywords = []
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for batch in dataloader:
            keywords.extend(i2k(batch["image"].to(device)))

    input_csv[args.keywords_col] = keywords
    input_csv.to_csv(args.out_csv, index=False)