Пример #1
0
    def load_codec(self, weights_prefix):
        """Load encoder/decoder weights and rebuild the decoder as its own model.

        Weights are read from ``<weights_prefix>_encoder.h5`` and
        ``<weights_prefix>_decoder.h5``.  Afterwards ``self.decoder`` is
        replaced by a new ``Sequential`` model that is built from the current
        decoder's config with the first entry dropped, takes the encoder's
        output shape as its input shape, and receives the weights of the
        corresponding layers of the current decoder.

        Both files are validated before anything is loaded, so a missing
        decoder file cannot leave the encoder with new weights while the
        decoder still holds its previous ones.

        Args:
            weights_prefix: Path prefix shared by the two weight files.

        Raises:
            Exception: If the encoder or the decoder weight file does not
                exist.
        """
        encoder_weight_filename = weights_prefix + "_encoder.h5"
        decoder_weight_filename = weights_prefix + "_decoder.h5"

        # Check every input first: loading is all-or-nothing with respect to
        # missing files.
        for part, filename in (("encoder", encoder_weight_filename),
                               ("decoder", decoder_weight_filename)):
            if not os.path.isfile(filename):
                raise Exception(
                    "The file for {} weights does not exist:{}".format(
                        part, filename))

        self.encoder.load_weights(encoder_weight_filename)
        self.decoder.load_weights(decoder_weight_filename)

        print("Encoder summaries")
        self.encoder.summary()

        # The unpacking requires a 4-element output shape.
        # NOTE(review): presumably (batch, height, width, channels) - confirm.
        _, encode_H, encode_W, numChannels = self.encoder.output_shape

        # NOTE(review): assumes get_config() returns a list of layer configs
        # (older Keras Sequential API) and that the first entry is the part
        # that must be stripped from the decoder - confirm against the model
        # definition.
        layer_configs = self.decoder.get_config()[1:]
        layer_configs[0]['config']['batch_input_shape'] = (
            None, encode_H, encode_W, numChannels)
        standalone_decoder = Sequential.from_config(
            layer_configs, custom_objects={"tf": tf})

        # Copy weights layer by layer; layer i + 1 of the current decoder
        # maps onto layer i of the rebuilt one because the first was dropped.
        for index, layer in enumerate(self.decoder.layers[1:]):
            standalone_decoder.layers[index].set_weights(layer.get_weights())

        self.decoder = standalone_decoder
        print("Decoder summaries")
        self.decoder.summary()
Пример #2
0
def load_codec(codec_prefix, print_summary=False):
    """Load a saved encoder/decoder pair and return a standalone decoder.

    Four files are read, all derived from ``codec_prefix``:
    ``<prefix>_encoder.json`` / ``<prefix>_encoder.h5`` for the encoder and
    ``<prefix>_decoder.json`` / ``<prefix>_decoder.h5`` for the decoder.

    The saved decoder is not returned as-is.  As a workaround, a new
    ``Sequential`` model is built from the saved decoder's config with the
    first entry dropped, its input shape is set to the encoder's output
    shape, and the weights are copied over layer by layer.

    All four files are checked before any model is parsed, so a missing file
    is reported before any loading work is done.

    Args:
        codec_prefix: Path prefix shared by the four model/weight files.
        print_summary: When true, print the summaries of the encoder and of
            the rebuilt decoder.

    Returns:
        A ``(encoder, decoder)`` tuple of the loaded encoder and the rebuilt
        decoder.

    Raises:
        Exception: If any of the four files does not exist.
    """
    saveFilePrefix = codec_prefix + '_'
    encoder_model_filename = saveFilePrefix + "encoder.json"
    decoder_model_filename = saveFilePrefix + "decoder.json"
    encoder_weight_filename = saveFilePrefix + "encoder.h5"
    decoder_weight_filename = saveFilePrefix + "decoder.h5"

    # Fail fast: validate every input, in the order they are consumed,
    # before doing any loading.
    required_files = (
        ("encoder model", encoder_model_filename),
        ("encoder weights", encoder_weight_filename),
        ("decoder model", decoder_model_filename),
        ("decoder weights", decoder_weight_filename),
    )
    for description, filename in required_files:
        if not os.path.isfile(filename):
            raise Exception("The file for {} does not exist:{}".format(
                description, filename))

    # Context managers guarantee the JSON files are closed even if parsing
    # the model description raises.
    with open(encoder_model_filename, 'r') as json_file:
        encoder = model_from_json(json_file.read(), custom_objects={"tf": tf})
    encoder.load_weights(encoder_weight_filename)

    with open(decoder_model_filename, 'r') as json_file:
        decoder_temp = model_from_json(json_file.read(),
                                       custom_objects={"tf": tf})
    decoder_temp.load_weights(decoder_weight_filename)

    if print_summary:
        print("Encoder summaries")
        encoder.summary()

    # The unpacking requires a 4-element output shape.
    # NOTE(review): presumably (batch, height, width, channels) - confirm.
    _, encode_H, encode_W, numChannels = encoder.output_shape

    # The workaround: the first entry of the saved decoder's config is the
    # sequential model for the encoder, so it is excluded and the remaining
    # layer configs are used to construct the decoder on its own.
    # NOTE(review): relies on get_config() returning a list of layer configs
    # (older Keras Sequential API) - confirm for the Keras version in use.
    layer_configs = decoder_temp.get_config()[1:]
    layer_configs[0]['config']['batch_input_shape'] = (
        None, encode_H, encode_W, numChannels)
    decoder = Sequential.from_config(layer_configs,
                                     custom_objects={"tf": tf})

    # Copy weights layer by layer; layer i + 1 of the saved decoder maps
    # onto layer i of the rebuilt one because the first entry was dropped.
    for index, layer in enumerate(decoder_temp.layers[1:]):
        decoder.layers[index].set_weights(layer.get_weights())

    if print_summary:
        print("Decoder summaries")
        decoder.summary()

    return encoder, decoder