Example #1
0
def main(args, _=None):
    """Compute image embeddings for every row of ``args.in_csv`` with a
    ResNet encoder and dump them to ``args.out_npy`` as one numpy array.
    """
    global IMG_SIZE

    # NOTE(review): ``dict_transformer`` presumably reads IMG_SIZE at call
    # time — publish the requested size before building the loader.
    IMG_SIZE = (args.img_size, args.img_size)

    encoder = ResnetEncoder(arch=args.arch, pooling=args.pooling).eval()
    encoder, _, _, _, device = utils.process_components(model=encoder)

    # Turn the csv into a list of per-row dicts, as the loader expects.
    rows = pd.read_csv(args.in_csv).reset_index().drop("index", axis=1)
    rows = list(rows.to_dict("index").values())

    reader = ImageReader(
        input_key=args.img_col,
        output_key="image",
        datapath=args.datapath,
    )

    loader = utils.get_loader(
        rows,
        reader,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dict_transform=dict_transformer,
    )
    if args.verbose:
        loader = tqdm(loader)

    chunks = []
    with torch.no_grad():
        for batch in loader:
            out = encoder(batch["image"].to(device))
            chunks.append(out.cpu().detach().numpy())

    np.save(args.out_npy, np.concatenate(chunks, axis=0))
Example #2
0
    def get_from_params(
        cls,
        image_size: int = None,
        encoder_params: Dict = None,
        embedding_net_params: Dict = None,
        heads_params: Dict = None,
    ) -> "MultiHeadNet":
        """Build a ``MultiHeadNet`` from plain config dicts.

        The encoder's flattened output size is probed with a dummy forward
        pass and prepended to the embedding net's ``hiddens`` list; one
        linear head per entry of ``heads_params`` is attached on top of the
        embedding.
        """
        # Copy the configs so the caller's dicts are never mutated.
        encoder_cfg = deepcopy(encoder_params)
        embedding_cfg = deepcopy(embedding_net_params)
        heads_cfg = deepcopy(heads_params)

        encoder_net = ResnetEncoder(**encoder_cfg)
        # Probe the encoder with a dummy (3, H, W) input to learn its
        # flattened output size.
        probe_shape = (3, image_size, image_size)
        probe_out = utils.get_network_output(encoder_net, probe_shape)
        enc_size = probe_out.nelement()

        embedding_cfg["hiddens"].insert(0, enc_size)
        embedding_net = SequentialNet(**embedding_cfg)
        emb_size = embedding_cfg["hiddens"][-1]

        # One linear head per task name, keyed exactly like heads_params.
        head_nets = nn.ModuleDict({
            name: nn.Linear(emb_size, out_size, bias=True)
            for name, out_size in heads_cfg.items()
        })

        return cls(
            encoder_net=encoder_net,
            embedding_net=embedding_net,
            head_nets=head_nets,
        )
Example #3
0
def prepare_tsn_base_model(partial_bn=None, **kwargs):
    """Create a ResNet encoder for TSN, optionally freezing BatchNorm layers.

    :param partial_bn: if not ``None``, every ``nn.BatchNorm2d`` module from
        the ``partial_bn``-th one onward (counting from 1) is switched to
        eval mode and its affine parameters are frozen; ``None`` leaves all
        BatchNorm layers fully trainable.
    :param kwargs: forwarded verbatim to ``ResnetEncoder``.
    :return: the (possibly partially frozen) ``ResnetEncoder`` instance.
    """
    base_model = ResnetEncoder(**kwargs)
    if partial_bn is None:
        return base_model

    bn_layers = (m for m in base_model.modules()
                 if isinstance(m, nn.BatchNorm2d))
    for index, bn in enumerate(bn_layers, start=1):
        if index >= partial_bn:
            # Frozen mode: stop running-stat updates and gradient flow
            # through the affine parameters.
            bn.eval()
            bn.weight.requires_grad = False
            bn.bias.requires_grad = False

    return base_model
Example #4
0
def main(args, _=None):
    """Run inference over the images listed in ``args.in_csv`` and save the
    resulting feature matrix to ``args.out_npy``.

    Depending on the CLI arguments, either a pre-traced TorchScript model
    (``args.traced_model``) or a freshly built ``ResnetEncoder`` is used.
    """
    global IMG_SIZE

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    # Published globally because ``dict_transformer`` appears to rely on it
    # at transform time.
    IMG_SIZE = (args.img_size, args.img_size)

    if args.traced_model is not None:
        # TorchScript path: load straight onto the target device.
        # NOTE(review): this branch does not call .eval() — presumably the
        # model was traced in eval mode; confirm before changing.
        device = utils.get_device()
        model = torch.jit.load(str(args.traced_model), map_location=device)
    else:
        model = ResnetEncoder(arch=args.arch, pooling=args.pooling)
        model = model.eval()
        model, _, _, _, device = utils.process_components(model=model)

    # One dict per csv row, in file order.
    samples = pd.read_csv(args.in_csv)
    samples = samples.reset_index().drop("index", axis=1)
    samples = list(samples.to_dict("index").values())

    open_fn = ImageReader(
        input_key=args.img_col,
        output_key="image",
        datapath=args.datapath,
    )

    loader = utils.get_loader(
        samples,
        open_fn,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dict_transform=dict_transformer,
    )
    if args.verbose:
        loader = tqdm(loader)

    batches = []
    with torch.no_grad():
        for batch in loader:
            out = model(batch["image"].to(device)).cpu().detach().numpy()
            batches.append(out)

    np.save(args.out_npy, np.concatenate(batches, axis=0))