Example #1
0
    def __init__(self,
                 pretrained,
                 encoder_dir=None,
                 decoder_dir=None,
                 temp=1,
                 Resnet="r18",
                 color_switch=True,
                 coord_switch=True):
        """Build the model: a trainable ResNet gray encoder plus a frozen
        RGB encoder/decoder pair loaded from checkpoints.

        Args:
            pretrained: forwarded as ``pretrained`` to the ResNet gray-encoder
                constructor.
            encoder_dir: path handed to ``torch.load`` for the ``encoder3``
                state dict. Must be provided despite the ``None`` default,
                since it is loaded unconditionally.
            decoder_dir: path handed to ``torch.load`` for the ``decoder3``
                state dict. Must be provided for the same reason.
            temp: stored as ``self.temp``; presumably a softmax temperature
                used by other methods — TODO confirm.
            Resnet: gray-encoder backbone, exactly ``"r18"`` or ``"r50"``.
            color_switch: stored as ``self.color_switch``.
            coord_switch: stored as ``self.coord_switch``.

        Raises:
            ValueError: if ``Resnet`` is neither ``"r18"`` nor ``"r50"``.
        """
        super(track_match_comb, self).__init__()

        # Exact comparison, not a substring test: ``x in "r18"`` would accept
        # "r", "1", "" etc., and an unknown value would leave
        # ``self.gray_encoder`` unset until it blew up later.
        if Resnet == "r18":
            self.gray_encoder = encoder_res18(pretrained=pretrained,
                                              uselayer=4)
        elif Resnet == "r50":
            self.gray_encoder = encoder_res50(pretrained=pretrained,
                                              uselayer=4)
        else:
            raise ValueError(
                "Unsupported Resnet {!r}; expected 'r18' or 'r50'".format(
                    Resnet))
        self.rgb_encoder = encoder3(reduce=True)
        self.decoder = decoder3(reduce=True)

        # Load the pretrained RGB encoder/decoder weights and freeze them so
        # only the remaining modules receive gradients.
        self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
        self.decoder.load_state_dict(torch.load(decoder_dir))
        for param in self.decoder.parameters():
            param.requires_grad = False
        for param in self.rgb_encoder.parameters():
            param.requires_grad = False

        self.nlm = NLM_woSoft()
        # These mean/std values match the standard ImageNet RGB statistics.
        self.normalize = normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        self.softmax = nn.Softmax(dim=1)
        self.temp = temp
        # NOTE(review): both start as None here; presumably populated lazily
        # by other methods — TODO confirm.
        self.grid_flat = None
        self.grid_flat_crop = None
        self.color_switch = color_switch
        self.coord_switch = coord_switch
Example #2
0
	def __init__(self, encoder_dir = None, decoder_dir = None, fix_dec = True,
					   temp = None, pretrainRes = False, uselayer=3, model='resnet18'):
		'''
		Model for the switchable concentration loss.

		Builds a trainable ResNet gray encoder (ResNet-18 when ``model`` is
		'resnet18', ResNet-50 for any other value) together with an RGB
		encoder/decoder pair whose weights are loaded from checkpoints and
		frozen.

		Args:
			encoder_dir: path handed to ``torch.load`` for the ``encoder3``
				state dict; must be provided (it is loaded unconditionally).
			decoder_dir: path handed to ``torch.load`` for the ``decoder3``
				state dict; must be provided.
			fix_dec: NOTE(review): not read in this method -- the decoder is
				always frozen regardless of this value.
			temp: stored as ``self.temp``.
			pretrainRes: forwarded as ``pretrained`` to the ResNet constructor.
			uselayer: forwarded as ``uselayer`` to the ResNet constructor.
			model: 'resnet18' selects ResNet-18; any other value selects
				ResNet-50.
		'''
		super(Model_switchGTfixdot_swCC_Res, self).__init__()
		if(model == 'resnet18'):
			print('Use ResNet18.')
			self.gray_encoder = encoder_res18(pretrained = pretrainRes, uselayer=uselayer)
		else:
			print('Use ResNet50.')
			self.gray_encoder = encoder_res50(pretrained = pretrainRes, uselayer=uselayer)
		self.rgb_encoder = encoder3(reduce = True)
		self.nlm = NLM_woSoft()
		self.decoder = decoder3(reduce = True)
		self.temp = temp
		self.softmax = nn.Softmax(dim=1)
		# 40x40 2-D Hann window (outer product of two 1-D Hann windows).
		# Moved to the GPU only when one is available, so the model can still
		# be constructed on CPU-only machines instead of raising in .cuda().
		cos_window = torch.Tensor(np.outer(np.hanning(40), np.hanning(40)))
		if torch.cuda.is_available():
			cos_window = cos_window.cuda()
		self.cos_window = cos_window
		# These mean/std values match the standard ImageNet RGB statistics.
		self.normalize = normalize(mean=[0.485, 0.456, 0.406],
								   std=[0.229, 0.224, 0.225])

		self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
		self.decoder.load_state_dict(torch.load(decoder_dir))

		# Freeze the loaded decoder and RGB encoder so only the remaining
		# modules receive gradients.
		for param in self.decoder.parameters():
			param.requires_grad = False
		for param in self.rgb_encoder.parameters():
			param.requires_grad = False