def get_model_and_data_loaders(
        config: ConfigParser,
        logger: logging.Logger,
        ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    """Build the model and evaluation data loaders from `config`, then restore
    model weights from a checkpoint.

    Args:
        config: parsed experiment configuration (also exposes CLI args).
        logger: logger forwarded to the data-loader factory and used for
            progress messages.
        ckpt_path: checkpoint to restore; falls back to
            ``config._args.resume`` when falsy.

    Returns:
        Tuple of (model with restored weights, expert data loaders).
    """
    expert_dims, raw_input_dims, text_dim = compute_dims(config)
    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_dim=text_dim,
        text_feat=config["experts"]["text_feat"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
        distil_params=config.get("distil_params", None),
        training_file=config.get("training_file", None),
        caption_masks=config.get("caption_masks", None),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
    )
    trn_config = compute_trn_config(config)
    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=text_dim,
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    # BUGFIX: the original unconditionally overwrote the `ckpt_path` argument
    # with config._args.resume, making the parameter dead. Honour the explicit
    # argument when given; fall back to the CLI resume path otherwise.
    ckpt_path = ckpt_path or config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    # Backwards compatibility: drop parameters of modules that no longer exist
    # in the current architecture so a strict load_state_dict succeeds.
    deprecated = ["ce.moe_fc_bottleneck1", "ce.moe_cg", "ce.moe_fc_proj"]
    for mod in deprecated:
        for suffix in ("weight", "bias"):
            key = f"{mod}.{suffix}"
            if key in state_dict:
                # Use the logger (not print) so the warning reaches log files.
                logger.warning(f"Removing deprecated key {key} from model")
                state_dict.pop(key)
    model.load_state_dict(state_dict)
    return model, data_loaders
def get_model_and_data_loaders(
        config: ConfigParser,
        logger: logging.Logger,
        ckpt_path: Path,
) -> Tuple[torch.nn.Module, module_data.ExpertDataLoader]:
    """Build the model and evaluation data loaders from `config`, then restore
    model weights from a checkpoint.

    Args:
        config: parsed experiment configuration (also exposes CLI args).
        logger: logger forwarded to the data-loader factory and used for
            progress messages.
        ckpt_path: checkpoint to restore; falls back to
            ``config._args.resume`` when falsy.

    Returns:
        Tuple of (model with restored weights, expert data loaders).
    """
    expert_dims, raw_input_dims = compute_dims(config)
    trn_config = compute_trn_config(config)
    data_loaders = config.init(
        name='data_loader',
        module=module_data,
        logger=logger,
        raw_input_dims=raw_input_dims,
        challenge_mode=config.get("challenge_mode", False),
        text_feat=config["experts"]["text_feat"],
        text_dim=config["experts"]["text_dim"],
        text_agg=config["experts"]["text_agg"],
        use_zeros_for_missing=config["experts"].get("use_zeros_for_missing", False),
        task=config.get("task", "retrieval"),
        eval_only=True,
    )
    model = config.init(
        name='arch',
        module=module_arch,
        trn_config=trn_config,
        expert_dims=expert_dims,
        text_dim=config["experts"]["text_dim"],
        disable_nan_checks=config["disable_nan_checks"],
        task=config.get("task", "retrieval"),
        ce_shared_dim=config["experts"].get("ce_shared_dim", None),
        feat_aggregation=config["data_loader"]["args"]["feat_aggregation"],
        trn_cat=config["data_loader"]["args"].get("trn_cat", 0),
    )
    # BUGFIX: the original unconditionally overwrote the `ckpt_path` argument
    # with config._args.resume, making the parameter dead. Honour the explicit
    # argument when given; fall back to the CLI resume path otherwise.
    ckpt_path = ckpt_path or config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)
    return model, data_loaders
# Module-level setup for a model-serving script: imports, Flask app creation,
# checkpoint restore, and one model replica per visible CUDA device.
# NOTE(review): `Flask`, `argparse`, `ConfigParser`, `module_arch`, `torch`,
# `time`, `os` and `canonical_state_dict_keys` are referenced below but not
# imported in this chunk — presumably imported earlier in the file; confirm.
from tensorflow.keras.utils import Progbar
from datas.list_generator import ListGenerator
from language_model.char_rnn_lm import CharRnnLmWrapperSingleton
from lip_model.training_graph import TransformerTrainGraph
from lip_model.inference_graph import TransformerInferenceGraph
import json
import shutil
import threading
import copy
import queue

# Flask application object that will serve inference requests.
app = Flask(__name__)

# Build the experiment configuration from CLI arguments and instantiate the
# network architecture plus a logger for this (test) run.
args = argparse.ArgumentParser()
config = ConfigParser(args)
model = config.init('arch', module_arch)
logger = config.get_logger('test')

tic = time.time()
# Restore pretrained weights; `canonical_state_dict_keys` normalises key names
# (e.g. stripping wrapper prefixes) before loading into the model.
with open(os.path.join('./misc/pretrained_models', 'KWS_Net.pth'), 'rb') as f:
    checkpoint = torch.load(f)
state_dict = canonical_state_dict_keys(checkpoint['state_dict'])
model.load_state_dict(state_dict)
logger.info(f"Finished loading ckpt in {time.time() - tic:.3f}s")

logger.info(f"CUDA device count: {torch.cuda.device_count()}")
device_count = torch.cuda.device_count()
# Keep an independent deep copy of the model on every GPU so requests can be
# dispatched to devices without sharing parameters/state between replicas.
models = []
for device_ind in range(device_count):
    device = f"cuda:{device_ind}"
    models.append(copy.deepcopy(model).to(device))