def make_model(
    model_status: str = "ucf101_trained", weights_path: Optional[str] = None
) -> Tuple[torch.nn.DataParallel, optim.SGD]:
    """Build the video-classification model and its SGD optimizer.

    Args:
        model_status: ``"ucf101_trained"`` to load a finished UCF-101
            checkpoint, or ``"kinetics_pretrained"`` to fine-tune a
            Kinetics-400 backbone (101-class head, layers from index 4).
        weights_path: path to the checkpoint file. Required for
            ``"kinetics_pretrained"``; optional for ``"ucf101_trained"``.

    Returns:
        A ``(model, optimizer)`` pair from ``generate_model`` / ``optim.SGD``.

    Raises:
        ValueError: if ``model_status`` is unknown, or ``weights_path`` is
            missing for ``"kinetics_pretrained"``.
    """
    statuses = ("ucf101_trained", "kinetics_pretrained")
    if model_status not in statuses:
        raise ValueError(f"model_status {model_status} not in {statuses}")
    trained = model_status == "ucf101_trained"
    if not trained and weights_path is None:
        raise ValueError(
            "weights_path cannot be None for 'kinetics_pretrained'")

    # Base options: single-clip RGB evaluation defaults.
    opt = parse_opts(arguments=[])
    for attr, value in (
        ("dataset", "UCF101"),
        ("only_RGB", True),
        ("log", 0),
        ("batch_size", 1),
    ):
        setattr(opt, attr, value)
    opt.arch = f"{opt.model}-{opt.model_depth}"

    if trained:
        opt.n_classes = 101
    else:
        # Fine-tune setup: Kinetics head (400) swapped for UCF-101 (101).
        opt.n_classes = 400
        opt.n_finetune_classes = 101
        opt.batch_size = 32
        opt.ft_begin_index = 4
        opt.pretrain_path = weights_path

    logger.info(f"Loading model... {opt.model} {opt.model_depth}")
    model, parameters = generate_model(opt)

    # A fully-trained checkpoint is loaded verbatim into the wrapped model.
    if trained and weights_path is not None:
        checkpoint = torch.load(weights_path, map_location=DEVICE)
        model.load_state_dict(checkpoint["state_dict"])

    # Initializing the optimizer; pretrained backbones get a gentler schedule.
    if opt.pretrain_path:
        opt.weight_decay = 1e-5
        opt.learning_rate = 0.001
    dampening = 0 if opt.nesterov else opt.dampening
    optimizer = optim.SGD(
        parameters,
        lr=opt.learning_rate,
        momentum=opt.momentum,
        dampening=dampening,
        weight_decay=opt.weight_decay,
        nesterov=opt.nesterov,
    )
    return model, optimizer
def make_model(model_status="ucf101_trained", weights_file=None): statuses = ("ucf101_trained", "kinetics_pretrained") if model_status not in statuses: raise ValueError(f"model_status {model_status} not in {statuses}") trained = model_status == "ucf101_trained" if not trained and weights_file is None: raise ValueError( "weights_file cannot be None for 'kinetics_pretrained'") if weights_file: filepath = maybe_download_weights_from_s3(weights_file) opt = parse_opts(arguments=[]) opt.dataset = "UCF101" opt.only_RGB = True opt.log = 0 opt.batch_size = 1 opt.arch = f"{opt.model}-{opt.model_depth}" if trained: opt.n_classes = 101 else: opt.n_classes = 400 opt.n_finetune_classes = 101 opt.batch_size = 32 opt.ft_begin_index = 4 opt.pretrain_path = filepath logger.info(f"Loading model... {opt.model} {opt.model_depth}") model, parameters = generate_model(opt) if trained and weights_file is not None: checkpoint = torch.load(filepath, map_location=DEVICE) # Fit the robust model into the original resnext model state_dict_path = 'model' if not ('model' in checkpoint): state_dict_path = 'state_dict' sd = checkpoint[state_dict_path] sd = {k[len('module.'):]: v for k, v in sd.items()} items = list(sd.items()) for key, val in items: if key.startswith('attacker.'): sd.pop(key) if key.startswith('model.'): new_key = 'module.' + key[len('model.'):] sd[new_key] = val sd.pop(key) if key == 'normalizer.new_mean' or key == 'normalizer.new_std': sd.pop(key) model.load_state_dict(sd) # Initializing the optimizer if opt.pretrain_path: opt.weight_decay = 1e-5 opt.learning_rate = 0.001 if opt.nesterov: dampening = 0 else: dampening = opt.dampening optimizer = optim.SGD( parameters, lr=opt.learning_rate, momentum=opt.momentum, dampening=dampening, weight_decay=opt.weight_decay, nesterov=opt.nesterov, ) return model, optimizer