Ejemplo n.º 1
0
    def _init_model_and_optimizer(self):
        """Build the SimCLR model, restore pre-trained weights and create Adam.

        Uses ``HResNetSimCLR`` when ``self.config['hyperbolic']`` is truthy and
        ``ResNetSimCLR`` otherwise; either is constructed from the keyword
        arguments in ``self.config["model"]`` and moved to ``self.device``.
        Weights are then restored through ``self._load_pre_trained_weights``.

        Side effects:
            Sets ``self.model`` and ``self.optimizer``.

        Returns:
            The ``loaded_iter`` value reported by ``_load_pre_trained_weights``
            (presumably the iteration of the restored checkpoint -- confirm).
        """
        model_cls = HResNetSimCLR if self.config['hyperbolic'] else ResNetSimCLR
        model = model_cls(**self.config["model"]).to(self.device)
        model, loaded_iter = self._load_pre_trained_weights(model)

        # NOTE(review): a Riemannian optimizer (geoopt RiemannianAdam) was once
        # considered for the hyperbolic model; plain Adam is used for both.
        lr = float(self.config['lr'])
        # Parse weight_decay exactly like lr. The former eval() executed
        # arbitrary config text and raised TypeError whenever the config
        # loader had already produced a number instead of a string.
        weight_decay = float(self.config['weight_decay'])
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            lr,
            weight_decay=weight_decay)

        self.model = model
        self.optimizer = optimizer
        return loaded_iter
Ejemplo n.º 2
0
 def _save_dt_features(self):
     """Load the trained checkpoint and save features for LVIS detections.

     Builds ``HResNetSimCLR`` when ``self.config['hyperbolic']`` is truthy,
     otherwise ``ResNetSimCLR``, from ``config['model']['base_model']`` and
     ``config['model']['out_dim']``. It then restores the state dict stored
     at ``os.path.join(self.run_path, self.model_ckpt)`` and switches the
     model to eval mode. Finally it passes the model, ``self.lvis_dt`` and
     the module-level ``GT_FEATS`` to ``LvisSaver`` and calls ``save()``.
     """
     from eval.feature_saver import LvisSaver
     import torch
     config = self.config
     model_cfg = config['model']
     if config['hyperbolic']:
         from models.hyperbolic_resnet import HResNetSimCLR as model_cls
     else:
         from models.resnet_simclr import ResNetSimCLR as model_cls
     model = model_cls(model_cfg['base_model'], model_cfg['out_dim'])
     # The model is constructed on CPU and never moved here, so map the
     # checkpoint tensors to CPU as well. This lets a GPU-saved checkpoint
     # load on a CPU-only host and avoids a transient GPU copy of the weights.
     ckpt_path = os.path.join(self.run_path, self.model_ckpt)
     state_dict = torch.load(ckpt_path, map_location='cpu')
     model.load_state_dict(state_dict)
     model.eval()
     saver = LvisSaver(model, self.lvis_dt, GT_FEATS)
     saver.save()
Ejemplo n.º 3
0
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.manifold import TSNE
import os
import multiprocessing
from itertools import product
from models.hyperbolic_resnet import HResNetSimCLR
"""
cmodel = HResNetSimCLR('resnet18', 256)
state = torch.load(r'/scratch/users/zzweng/runs/checkpoints/freeze_coco_pretrain_hyp=True_zdim=64_loss=triplet/model_50000.pth')
cmodel.load_state_dict(state)
print(cmodel)
"""
PATH = 'features_lvis_dt_hyperbolic'
os.makedirs(PATH, exist_ok=True)
checkpoint_dir = r'/scratch/users/zzweng/runs/checkpoints/all_hyp=True_zdim=2_loss=triplet_maskloss=False/'
cmodel = HResNetSimCLR('resnet101', 2)
state_dict = torch.load(os.path.join(checkpoint_dir, 'model_11500.pth')) #, map_location=device)i
cmodel.load_state_dict(state_dict)
cmodel.eval()

#cmodel_ = torchvision.models.resnet101(pretrained=True)
#cmodel = nn.Sequential(*list(cmodel_.children())[:-1])
#cmodel.eval()

lvis = LVIS('/scratch/users/zzweng/datasets/lvis/lvis_v0.5_val.json')
lvis_dt = LVISResults(lvis, r'output/inference/lvis_instances_results.json')

def collect_features_from_dt(start, end, folder=PATH):
    print('Collecting {} to {}'.format(start, end))
    img_ids = lvis_dt.get_img_ids()[start:end]
    feats = []