def load_pretrained_model(serialization_dir: str) -> models.Model:
    """Load a model with its pretrained weights from `serialization_dir`.

    Expects the directory to contain:
      - ``config.json``      : constructor kwargs plus a ``"type"`` key
        selecting the model class.
      - ``model.ckpt.index`` : TF checkpoint index file (the weight data
        files live alongside it under the ``model.ckpt`` prefix).

    Parameters
    ----------
    serialization_dir : str
        Directory holding the config and checkpoint files.

    Returns
    -------
    models.Model
        Instantiated model with the checkpoint weights restored.

    Raises
    ------
    FileNotFoundError
        If ``config.json`` or ``model.ckpt.index`` is missing.
        (Subclasses ``Exception``, so pre-existing callers that caught
        the old generic ``Exception`` still work.)
    Exception
        If the config's ``"type"`` value is not a supported model name.
    """
    config_path = os.path.join(serialization_dir, "config.json")
    checkpoint_index_path = os.path.join(serialization_dir, "model.ckpt.index")

    # Fail fast with a message that names exactly which files are absent.
    missing = [path for path in (config_path, checkpoint_index_path)
               if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            f"Model files in serialization_dir ({serialization_dir}) "
            f"are missing: {missing}. Cannot load the model.")

    # TF's load_weights expects the checkpoint *prefix*, not the .index
    # file itself. Strip only the trailing suffix — the old
    # .replace(".index", "") would also mangle a path that happened to
    # contain ".index" elsewhere.
    checkpoint_prefix = checkpoint_index_path[:-len(".index")]

    with open(config_path, "r") as file:
        config = json.load(file)

    # "type" selects the class; the remaining keys are constructor kwargs.
    model_name = config.pop("type")
    if model_name == "basic":
        from model import MyBasicAttentiveBiGRU  # To prevent circular imports
        model = MyBasicAttentiveBiGRU(**config)
    elif model_name == "advanced":
        from model import MyAdvancedModel  # To prevent circular imports
        model = MyAdvancedModel(**config)
    else:
        raise Exception(f"model_name: {model_name} is not supported.")

    model.load_weights(checkpoint_prefix)
    return model
# ---- Data loading -------------------------------------------------------
train_instances = read_instances(args.data_file, MAX_TOKENS)
print("\nReading Val Instances")
val_instances = read_instances(args.val_file, MAX_TOKENS)

# Common GloVe words cap the vocabulary to tokens we expect vectors for.
with open(GLOVE_COMMON_WORDS_PATH) as file:
    glove_common_words = []
    for line in file:
        word = line.strip()
        if word:  # drop blank lines
            glove_common_words.append(word)

# ---- Vocabulary & indexing ----------------------------------------------
vocab_token_to_id, vocab_id_to_token = build_vocabulary(
    train_instances, VOCAB_SIZE, glove_common_words)
train_instances = index_instances(train_instances, vocab_token_to_id)
val_instances = index_instances(val_instances, vocab_token_to_id)

# ---- Model setup --------------------------------------------------------
vocab_size = len(vocab_token_to_id)
config = {
    'vocab_size': vocab_size,
    'embed_dim': args.embed_dim,
    'training': True,
    'hidden_size': args.hidden_size,
}
model = MyBasicAttentiveBiGRU(**config)
# Record the model type only AFTER construction: the constructor does not
# accept a "type" kwarg, but the serialized config needs it for reloading.
config['type'] = 'basic'

optimizer = optimizers.Adam()

# Seed the embedding matrix with pretrained GloVe vectors.
embeddings = load_glove_embeddings(args.embed_file, args.embed_dim,
                                   vocab_id_to_token)
model.embeddings.assign(tf.convert_to_tensor(embeddings))

# ---- Training -----------------------------------------------------------
save_serialization_dir = os.path.join('serialization_dirs', 'basic')
if not os.path.exists(save_serialization_dir):
    os.makedirs(save_serialization_dir)

train_output = train(model, optimizer, train_instances, val_instances,
                     args.epochs, args.batch_size, save_serialization_dir)

config_path = os.path.join(save_serialization_dir, "config.json")