-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_saver.py
executable file
·104 lines (76 loc) · 4.41 KB
/
model_saver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""This module is used to save and restore models using json and checkpoints files"""
import datetime
import json
import os
from model_builder import ModelBuilder
from model_trainer import ModelTrainer
from model_predictor import ModelPredictor
def build_save_filename(prefix, n_videos, n_captions_per_video, batch_size):
    """Build a save-file name from a prefix, the model parameters and the current date.

    The result has the form
    ``<prefix>_<n_videos>_<n_captions_per_video>_<batch_size>_<YYYY-MM-DD>_<HH:MM:SS>``,
    where the timestamp is the local time at the moment of the call,
    truncated to whole seconds.
    """
    # ``prefix`` is joined as-is (it must already be a string); the numeric
    # parameters are converted explicitly.
    parts = [prefix, str(n_videos), str(n_captions_per_video), str(batch_size)]
    # Same layout as str(datetime) with the fractional seconds dropped and
    # the space between date and time replaced by an underscore.
    parts.append(datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))
    return '_'.join(parts)
def restore_model(checkpoint_filename, video_retriever_generator,
                  selector, extractor):
    """Rebuild a model builder and its saver from a checkpoint's JSON file.

    The JSON file is expected to sit next to the checkpoint and to carry the
    same name with a ``.json`` extension appended. The steps are:

    - load the model's parameters from that JSON file;
    - construct a ``ModelBuilder`` from those parameters and the given
      retriever generator, selector and extractor;
    - create the model and prepare training with the stored batch size;
    - create a ``ModelSaver`` pointing at the checkpoint's folder and name.

    Returns the tuple ``(builder, model_saver, params)`` where ``params`` is
    the dictionary decoded from the JSON file.
    """
    with open(checkpoint_filename + '.json', 'r', encoding='utf-8') as params_file:
        params = json.load(params_file)

    builder = ModelBuilder(
        params["training_videos_names"],
        params["testing_videos_names"],
        params["n_captions_per_video"],
        params["feature_n_frames"],
        video_retriever_generator,
        selector,
        extractor,
    )

    # The network hyper-parameters live under the "model" key and are passed
    # positionally, in exactly this order, to ModelBuilder.create_model.
    hyper_params = params["model"]
    create_model_keys = ("enc_units", "dec_units", "rnn_layers", "embedding_dims",
                         "learning_rate", "dropout_rate", "bi_encoder")
    builder.create_model(*[hyper_params[key] for key in create_model_keys])
    builder.prepare_training(params["batch_size"])

    # Split the checkpoint path into the (folder, file name) pair ModelSaver expects.
    folder, basename = os.path.split(checkpoint_filename)
    return builder, ModelSaver(folder, basename), params
def model_trainer_from_checkpoint(checkpoint_filename, video_retriever_generator,
                                  selector, extractor):
    """Restore a model from a checkpoint and wrap it in a ``ModelTrainer``.

    The trainer is created with the epoch and best loss stored in the
    checkpoint's JSON file (the loss is stored as text and converted back to
    ``float`` here), then asked to load the last checkpoint before being
    returned.
    """
    restored = restore_model(checkpoint_filename, video_retriever_generator,
                             selector, extractor)
    builder, saver, params = restored
    trainer = ModelTrainer(saver, builder, params["epoch"], float(params["best_loss"]))
    trainer.load_last_checkpoint()
    return trainer
def model_predictor_from_checkpoint(checkpoint_filename, videos_retriever_generator,
                                    selector, extractor):
    """Restore a model from a checkpoint and wrap it in a ``ModelPredictor``.

    The parameters read from the checkpoint's JSON file are only used to
    rebuild the model; the predictor itself receives the saver and the
    builder, and is asked to load the last checkpoint before being returned.
    """
    restored = restore_model(checkpoint_filename, videos_retriever_generator,
                             selector, extractor)
    builder, saver = restored[0], restored[1]
    predictor = ModelPredictor(saver, builder)
    predictor.load_last_checkpoint()
    return predictor
class ModelSaver:
    """Manage the output (saving) paths and JSON parameters of the video captioning model.

    A saver is defined by a checkpoint folder and a checkpoint file name.
    From these it derives the save-file path (folder joined with name) and
    the path of the JSON parameter file (save-file path + ``.json``).
    """

    def __init__(self, checkpoint_folder, checkpoint_filename):
        """Store the folder and derive the save-file and JSON paths.

        Args:
            checkpoint_folder: Directory holding the checkpoint files. May be
                an empty string, in which case paths are relative to the
                current working directory.
            checkpoint_filename: Name of the checkpoint inside that folder.
        """
        self._checkpoint_folder = checkpoint_folder
        self._savefile_name = os.path.join(self._checkpoint_folder, checkpoint_filename)
        self._savefile_json = self._savefile_name + '.json'

    @classmethod
    def from_generated_filename(cls, model_builder, checkpoint_folder, prefix=''):
        """Create a saver whose file name is generated by ``build_save_filename``.

        The name is built from ``prefix`` and the builder's number of videos,
        number of captions per video and batch size.

        This is a classmethod so that subclasses get an instance of their own
        type; it is called exactly as before
        (``ModelSaver.from_generated_filename(builder, folder, prefix)``).
        """
        savefile_name = build_save_filename(prefix,
                                            model_builder.get_n_videos(),
                                            model_builder.get_n_captions_per_video(),
                                            model_builder.get_batch_size())
        return cls(checkpoint_folder, savefile_name)

    def save_model_json(self, model_builder, epoch, best_loss):
        """Write all the model parameters, the epoch and the best loss to the JSON file.

        Args:
            model_builder: Object exposing ``model_to_dict()``, which returns
                the parameters to save.
            epoch: Stored under the ``'epoch'`` key as given.
            best_loss: Stored under the ``'best_loss'`` key as its ``str()``
                representation.
        """
        # Work on a shallow copy so the dictionary handed out by the builder
        # is not modified by the two keys added below.
        json_dict = dict(model_builder.model_to_dict())
        json_dict['epoch'] = epoch
        json_dict['best_loss'] = str(best_loss)

        # Make sure the destination folder exists; an empty directory part
        # means the current working directory and needs no creation.
        target_dir = os.path.dirname(self._savefile_json)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with open(self._savefile_json, 'w', encoding='utf-8') as json_file:
            json.dump(json_dict, json_file, ensure_ascii=False, indent=4)

    def get_savefile_name(self):
        """Return the save-file path (checkpoint folder joined with file name)."""
        return self._savefile_name