def _model(model_name): if model_name == 'motion': from STM.models.model_fusai import STM elif model_name == 'aspp': from STM.models.model_fusai_aspp import STM model = STM() # model.eval() # model.Decoder.train() elif model_name == 'enhanced': from STM.models.model_enhanced import STM model = STM() model.eval() model.KV_Q.train() elif model_name == 'standard': from STM.models.model import STM model = STM() elif model_name == 'enhanced_motion': from STM.models.model_enhanced_motion import STM model = STM() elif model_name == 'varysize': from STM.models.model_enhanced_varysize import STM model = STM() elif model_name == 'sp': from STM.models.model_fusai_spatial_prior import STM model = STM() # model.eval() # model.Decoder.Aspp.train() elif model_name == 'hkf': from STM.model_hkf import STM model = STM() return model
def init_stm_model(model_name, model_path): if model_name == 'motion': from STM.models.model_fusai import STM model = STM() elif model_name == 'aspp': from STM.models.model_fusai_aspp import STM model = STM() elif model_name == 'enhanced': from STM.models.model_enhanced import STM model = STM() elif model_name == 'enhanced_motion': from STM.models.model_enhanced_motion import STM model = STM() elif model_name == 'standard': from STM.models.model import STM model = STM() elif model_name == 'varysize': from STM.models.model_enhanced_varysize import STM model = STM() elif model_name == 'sp': from STM.models.model_fusai_spatial_prior import STM model = STM() else: raise ValueError # turn-off BN print('Loading weights:', model_path) model_ = torch.load(model_path, map_location=torch.device('cpu')) if 'state_dict' in model_.keys(): state_dict = model_['state_dict'] else: state_dict = model_ d = {} for k, v in state_dict.items(): d.setdefault(k.replace('module.', ''), v) state_dict = d model.load_state_dict(state_dict) model.eval() model.to(ipex.DEVICE) return model