# Extract image embeddings with a pretrained ResNet encoder and dump
# them to a .npy file. ResnetEncoder, ImageReader and utils come from
# the Catalyst library; their exact import paths vary between versions.
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


def main(args, _=None):
    global IMG_SIZE

    IMG_SIZE = (args.img_size, args.img_size)

    # build the encoder and move it to the available device
    model = ResnetEncoder(arch=args.arch, pooling=args.pooling)
    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    # read the csv and turn each row into a {column: value} dict
    images_df = pd.read_csv(args.in_csv)
    images_df = images_df.reset_index().drop("index", axis=1)
    images_df = list(images_df.to_dict("index").values())

    open_fn = ImageReader(
        input_key=args.img_col, output_key="image", datapath=args.datapath
    )

    dataloader = utils.get_loader(
        images_df,
        open_fn,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dict_transform=dict_transformer,
    )

    # run inference batch by batch and collect the embeddings
    features = []
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for batch in dataloader:
            features_ = model(batch["image"].to(device))
            features_ = features_.cpu().detach().numpy()
            features.append(features_)

    features = np.concatenate(features, axis=0)
    np.save(args.out_npy, features)
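# A hedged sketch of the CLI parser the script above expects: the flag
# names are assumptions inferred from the attributes read off `args`,
# not the original parser definition.
import argparse


def build_args_parser():
    parser = argparse.ArgumentParser(description="images -> embeddings")
    parser.add_argument("--arch", type=str, default="resnet18")
    parser.add_argument("--pooling", type=str, default="GlobalAvgPool2d")
    parser.add_argument("--in-csv", type=str, required=True)  # -> args.in_csv
    parser.add_argument("--img-col", type=str, default="filename")
    parser.add_argument("--img-size", type=int, default=224)
    parser.add_argument("--datapath", type=str, default=None)
    parser.add_argument("--out-npy", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--verbose", action="store_true")
    return parser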
# Factory for a multi-head network: encoder -> embedding MLP -> one
# linear head per task. `cls` is the MultiHeadNet class, so in the full
# source this is presumably decorated with @classmethod.
from copy import deepcopy
from typing import Dict

import torch.nn as nn


def get_from_params(
    cls,
    image_size: int = None,
    encoder_params: Dict = None,
    embedding_net_params: Dict = None,
    heads_params: Dict = None,
) -> "MultiHeadNet":
    # copy the config dicts so the caller's versions are not mutated
    encoder_params_ = deepcopy(encoder_params)
    embedding_net_params_ = deepcopy(embedding_net_params)
    heads_params_ = deepcopy(heads_params)

    encoder_net = ResnetEncoder(**encoder_params_)

    # run a dummy forward pass to infer the flattened encoder output size
    encoder_input_shape = (3, image_size, image_size)
    encoder_output = utils.get_network_output(encoder_net, encoder_input_shape)
    enc_size = encoder_output.nelement()

    # prepend the encoder size so the MLP input matches the encoder output
    embedding_net_params_["hiddens"].insert(0, enc_size)
    embedding_net = SequentialNet(**embedding_net_params_)
    emb_size = embedding_net_params_["hiddens"][-1]

    # one linear head per entry in heads_params: {head_name: num_outputs}
    head_kwargs_ = {}
    for key, value in heads_params_.items():
        head_kwargs_[key] = nn.Linear(emb_size, value, bias=True)
    head_nets = nn.ModuleDict(head_kwargs_)

    net = cls(
        encoder_net=encoder_net,
        embedding_net=embedding_net,
        head_nets=head_nets,
    )

    return net
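# Hedged usage sketch: the config values below are assumptions based on
# the constructor calls above (the ResnetEncoder kwargs are confirmed by
# the extraction script; the pooling name and sizes are illustrative).
net = MultiHeadNet.get_from_params(
    image_size=224,
    encoder_params={"arch": "resnet18", "pooling": "GlobalAvgPool2d"},
    embedding_net_params={"hiddens": [512, 128]},   # MLP: enc_size -> 512 -> 128
    heads_params={"class": 10, "coarse_class": 3},  # two heads, 10 and 3 logits
)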
def prepare_tsn_base_model(partial_bn=None, **kwargs):
    """Build the TSN base model, optionally freezing BatchNorm layers.

    :param partial_bn: 1-based index of the first BatchNorm2d layer to
        freeze; every BatchNorm layer from this index on is switched to
        eval mode and excluded from gradient updates, so ``partial_bn=2``
        keeps only the first BN layer trainable
    :param kwargs: keyword arguments forwarded to ``ResnetEncoder``
    :return: the prepared base model
    """
    base_model = ResnetEncoder(**kwargs)

    if partial_bn is not None:
        count = 0
        for m in base_model.modules():
            if isinstance(m, nn.BatchNorm2d):
                count += 1
                if count >= partial_bn:
                    m.eval()  # shut down running-stat updates in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False

    return base_model
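# Hedged usage sketch: freeze every BatchNorm layer except the first,
# the usual "partial BN" setting from the TSN paper; `arch` is a kwarg
# ResnetEncoder accepts in the extraction scripts above.
base_model = prepare_tsn_base_model(partial_bn=2, arch="resnet34")
# Note: requires_grad=False persists, but a later base_model.train()
# resets the eval() flag, so the freezing must be re-applied there.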
# Variant of the extraction script above with reproducibility controls
# and optional loading of a TorchScript (traced) model.
def main(args, _=None):
    global IMG_SIZE

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)
    IMG_SIZE = (args.img_size, args.img_size)

    if args.traced_model is not None:
        # load a model traced ahead of time with torch.jit
        device = utils.get_device()
        model = torch.jit.load(str(args.traced_model), map_location=device)
    else:
        model = ResnetEncoder(arch=args.arch, pooling=args.pooling)
        model = model.eval()
        model, _, _, _, device = utils.process_components(model=model)

    # read the csv and turn each row into a {column: value} dict
    df = pd.read_csv(args.in_csv)
    df = df.reset_index().drop("index", axis=1)
    df = list(df.to_dict("index").values())

    open_fn = ImageReader(
        input_key=args.img_col, output_key="image", datapath=args.datapath
    )

    dataloader = utils.get_loader(
        df,
        open_fn,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dict_transform=dict_transformer,
    )

    # run inference batch by batch and collect the embeddings
    features = []
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for batch in dataloader:
            features_ = model(batch["image"].to(device))
            features_ = features_.cpu().detach().numpy()
            features.append(features_)

    features = np.concatenate(features, axis=0)
    np.save(args.out_npy, features)
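# Hedged sketch of the `dict_transformer` both scripts reference: a
# dict -> dict transform that resizes and normalizes the "image" entry
# using the IMG_SIZE global set in main(). It assumes ImageReader yields
# a uint8 HWC numpy array; the original transform is not shown here.
import cv2


def dict_transformer(sample):
    image = sample["image"]
    image = cv2.resize(image, IMG_SIZE, interpolation=cv2.INTER_LINEAR)
    image = torch.from_numpy(image.astype(np.float32) / 255.0)
    image = image.permute(2, 0, 1)  # HWC -> CHW, as the encoder expects
    sample["image"] = image
    return sample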