def load_codec(self, weights_prefix, print_summary=True):
    """Load encoder/decoder weights and rebuild the decoder to consume encoder output.

    Weights are read from ``weights_prefix + "_encoder.h5"`` into
    ``self.encoder`` and from ``weights_prefix + "_decoder.h5"`` into
    ``self.decoder``. The decoder is then rebuilt from its config with its
    first layer dropped and its input shape set to the encoder's output shape.
    The remaining layers' weights are copied across, and ``self.decoder`` is
    replaced by the rebuilt model.

    Args:
        weights_prefix: Path prefix shared by the two ``.h5`` weight files.
        print_summary: If true (the default), print the encoder summary and
            the rebuilt decoder summary.

    Raises:
        FileNotFoundError: If either weight file is missing. Both files are
            checked before any weights are loaded.
    """
    encoder_weight_filename = weights_prefix + "_encoder.h5"
    decoder_weight_filename = weights_prefix + "_decoder.h5"

    # Validate both files up front so that a missing decoder file cannot leave
    # self.encoder holding new weights while self.decoder is still stale.
    if not os.path.isfile(encoder_weight_filename):
        raise FileNotFoundError(
            "The file for encoder weights does not exist:{}".format(encoder_weight_filename))
    if not os.path.isfile(decoder_weight_filename):
        raise FileNotFoundError(
            "The file for decoder weights does not exist:{}".format(decoder_weight_filename))

    self.encoder.load_weights(encoder_weight_filename)
    self.decoder.load_weights(decoder_weight_filename)

    if print_summary:
        print("Encoder summaries")
        self.encoder.summary()

    # Four-way unpack: the encoder output is expected to be (batch, H, W, channels).
    _, encode_H, encode_W, numChannels = self.encoder.output_shape

    # Rebuild the decoder without its first layer, and feed the new first
    # layer the encoder's output shape. The dropped layer is presumably the
    # nested encoder model -- TODO confirm. This assumes get_config() returns a
    # list of per-layer configs (older Keras Sequential API); confirm for the
    # Keras version in use.
    layer_configs = self.decoder.get_config()[1:]
    layer_configs[0]['config']['batch_input_shape'] = (None, encode_H, encode_W, numChannels)
    rebuilt = Sequential.from_config(layer_configs, custom_objects={"tf": tf})

    # Layer i (i >= 1) of the loaded decoder maps to layer i - 1 of the rebuilt one.
    for idx, layer in enumerate(self.decoder.layers[1:]):
        rebuilt.layers[idx].set_weights(layer.get_weights())
    self.decoder = rebuilt

    if print_summary:
        print("Decoder summaries")
        self.decoder.summary()
def load_codec(codec_prefix, print_summary=False):
    """Load a saved encoder/decoder pair and return ``(encoder, decoder)``.

    Four files are read, all named from ``codec_prefix + "_"``:
    ``encoder.json`` / ``encoder.h5`` (encoder architecture / weights) and
    ``decoder.json`` / ``decoder.h5`` (decoder architecture / weights).

    The loaded decoder is then rebuilt from its config with its first entry
    dropped and its input shape set to the encoder's output shape. The
    remaining layers' weights are copied across, so the returned decoder takes
    encoder output directly.

    Args:
        codec_prefix: Path prefix shared by the four model/weight files.
        print_summary: If true, print the encoder summary and the rebuilt
            decoder summary.

    Returns:
        Tuple ``(encoder, decoder)``. ``encoder`` comes from ``model_from_json``;
        ``decoder`` is the rebuilt model from ``Sequential.from_config``.

    Raises:
        FileNotFoundError: If any of the four files is missing. All files are
            checked before any model is loaded, in the order they are used.
    """
    save_file_prefix = codec_prefix + '_'
    encoder_model_filename = save_file_prefix + "encoder.json"
    encoder_weight_filename = save_file_prefix + "encoder.h5"
    decoder_model_filename = save_file_prefix + "decoder.json"
    decoder_weight_filename = save_file_prefix + "decoder.h5"

    # Fail fast before any (potentially slow) model loading. The order matches
    # the order of use, so the first reported missing file is the same as when
    # each check sat right before its load.
    for description, filename in (
        ("encoder model", encoder_model_filename),
        ("encoder weights", encoder_weight_filename),
        ("decoder model", decoder_model_filename),
        ("decoder weights", decoder_weight_filename),
    ):
        if not os.path.isfile(filename):
            raise FileNotFoundError(
                "The file for {} does not exist:{}".format(description, filename))

    # Context managers guarantee the JSON files are closed even if parsing fails.
    with open(encoder_model_filename, 'r') as json_file:
        encoder_json = json_file.read()
    encoder = model_from_json(encoder_json, custom_objects={"tf": tf})
    encoder.load_weights(encoder_weight_filename)

    with open(decoder_model_filename, 'r') as json_file:
        decoder_json = json_file.read()
    decoder_temp = model_from_json(decoder_json, custom_objects={"tf": tf})
    decoder_temp.load_weights(decoder_weight_filename)

    if print_summary:
        print("Encoder summaries")
        encoder.summary()

    # Four-way unpack: the encoder output is expected to be (batch, H, W, channels).
    _, encode_H, encode_W, numChannels = encoder.output_shape

    # Workaround: rebuild the decoder from its config and copy weights layer by
    # layer. config[0] describes the nested encoder (Sequential) model, so it is
    # excluded, and the next layer is given the encoder's output shape as input.
    # This assumes get_config() returns a list of per-layer configs (older Keras
    # Sequential API); confirm for the Keras version in use.
    layer_configs = decoder_temp.get_config()[1:]
    layer_configs[0]['config']['batch_input_shape'] = (None, encode_H, encode_W, numChannels)
    decoder = Sequential.from_config(layer_configs, custom_objects={"tf": tf})

    # Layer i (i >= 1) of the loaded decoder maps to layer i - 1 of the rebuilt one.
    for idx, layer in enumerate(decoder_temp.layers[1:]):
        decoder.layers[idx].set_weights(layer.get_weights())

    if print_summary:
        print("Decoder summaries")
        decoder.summary()
    return encoder, decoder