예제 #1
0
    def __init__(self, hyper_params=None, params=None):
        """Build the HRED-decoder test graph, its saver and its optimizer.

        Constructs, in order, the shared embedding table, an utterance
        ``encoder.Encoder``, a ``KBscorer.BiKBScorer``, an ``HRED.HRED``
        context model and a ``decoder.Decoder`` (``d_type="MASK"``), logs
        their parameters, and gathers every created variable into
        ``self.params`` for checkpointing.

        Args:
            hyper_params: Forwarded to ``graph_base.GraphBase.__init__``.
                Afterwards ``self.hyper_params`` is indexed like a dict
                (e.g. ``"emb_dim"``, ``"common_vocab"``); presumably the base
                class populates it when None is passed -- TODO confirm.
            params: Forwarded to ``graph_base.GraphBase.__init__``. Note that
                ``self.params`` is overwritten below with the variables
                created by this constructor.
        """
        graph_base.GraphBase.__init__(self, hyper_params, params)

        # network design
        tf.logging.info(
            "============================================================")
        tf.logging.info("BUILDING NETWORK...")

        self.hyper_params["model_type"] = "Test HREDDECODER"
        # One table covering both the common and the KB vocabulary;
        # presumably a [common_vocab + kb_vocab, emb_dim] variable.
        self.embedding = graph_base.get_params([
            self.hyper_params["common_vocab"] + self.hyper_params["kb_vocab"],
            self.hyper_params["emb_dim"]
        ])
        self.encoder = encoder.Encoder(self.hyper_params["encoder_layer_num"],
                                       self.hyper_params["emb_dim"],
                                       self.hyper_params["encoder_h_dim"],
                                       norm=True)
        # Input size is emb_dim + hred_h_dim; one score slot per KB candidate
        # (FLAGS.candidate_num) -- output semantics assumed, confirm.
        self.kb_scorer = KBscorer.BiKBScorer(
            self.hyper_params["emb_dim"] + self.hyper_params["hred_h_dim"],
            FLAGS.candidate_num)
        self.hred = HRED.HRED(self.hyper_params["encoder_h_dim"],
                              self.hyper_params["hred_h_dim"],
                              self.hyper_params["emb_dim"],
                              norm=True)
        # Generator spec presumably [layers, input dim, hidden dim, context
        # dim, output vocab]; MLP spec presumably [layers, input dim, hidden
        # dim, 2] -- layout assumed, verify against decoder.Decoder.
        self.decoder = decoder.Decoder([
            self.hyper_params["decoder_gen_layer_num"],
            self.hyper_params["emb_dim"],
            self.hyper_params["decoder_gen_h_dim"],
            self.hyper_params["hred_h_dim"] + FLAGS.candidate_num,
            self.hyper_params["common_vocab"]
        ], [], [
            self.hyper_params["decoder_mlp_layer_num"],
            self.hyper_params["emb_dim"] + FLAGS.candidate_num * 2 +
            self.hyper_params["hred_h_dim"] +
            self.hyper_params["decoder_gen_h_dim"],
            self.hyper_params["decoder_mlp_h_dim"], 2
        ],
                                       d_type="MASK",
                                       norm=True,
                                       hyper_params=None,
                                       params=None)

        self.print_params()
        self.encoder.print_params()
        self.kb_scorer.print_params()
        self.hred.print_params()
        self.decoder.print_params()

        # Checkpoint keys are the positional indices "0", "1", ... of
        # self.params, so this concatenation order must stay fixed or
        # previously saved checkpoints will map onto the wrong variables.
        self.params = ([self.embedding] + self.encoder.params +
                       self.hred.params + self.kb_scorer.params +
                       self.decoder.params)
        params_dict = {str(i): param for i, param in enumerate(self.params)}
        self.saver = tf.train.Saver(params_dict)

        self.optimizer = self.get_optimizer()

        # Same variables as self.params but without the KB scorer's;
        # presumably used where the KB scorer is not trained or restored --
        # verify against callers.
        self.params_simple = ([self.embedding] + self.encoder.params +
                              self.hred.params + self.decoder.params)
예제 #2
0
    def __init__(self, hyper_params=None, params=None):
        """Build the autoencoder test graph, its saver and its optimizer.

        Constructs, in order, the shared embedding table, an utterance
        ``encoder.Encoder`` and a ``decoder.AuxDecoder``, logs their
        parameters, and gathers every created variable into ``self.params``
        for checkpointing.

        Args:
            hyper_params: Forwarded to ``graph_base.GraphBase.__init__``.
                Afterwards ``self.hyper_params`` is indexed like a dict
                (e.g. ``"emb_dim"``, ``"common_vocab"``); presumably the base
                class populates it when None is passed -- TODO confirm.
            params: Forwarded to ``graph_base.GraphBase.__init__``. Note that
                ``self.params`` is overwritten below with the variables
                created by this constructor.
        """
        graph_base.GraphBase.__init__(self, hyper_params, params)

        # network design
        tf.logging.info(
            "============================================================")
        tf.logging.info("BUILDING NETWORK...")

        self.hyper_params["model_type"] = "Test Autoencoder"
        # Both the embedding and the decoder output read the vocabulary size
        # from self.hyper_params, so the two cannot disagree.
        common_vocab = self.hyper_params["common_vocab"]
        # One table covering both the common and the KB vocabulary;
        # presumably a [common_vocab + kb_vocab, emb_dim] variable.
        self.embedding = graph_base.get_params([
            common_vocab + self.hyper_params["kb_vocab"],
            self.hyper_params["emb_dim"]
        ])
        self.encoder = encoder.Encoder(self.hyper_params["encoder_layer_num"],
                                       self.hyper_params["emb_dim"],
                                       self.hyper_params["encoder_h_dim"])
        # Spec presumably [layers, input dim, hidden dim, context dim, output
        # size] -- layout assumed, verify against decoder.AuxDecoder. The
        # output size was previously read from FLAGS.common_vocab, which could
        # diverge from self.hyper_params["common_vocab"] used above.
        self.aux_decoder = decoder.AuxDecoder([
            self.hyper_params["decoder_gen_layer_num"],
            self.hyper_params["emb_dim"],
            self.hyper_params["decoder_gen_h_dim"],
            self.hyper_params["encoder_h_dim"],
            common_vocab + FLAGS.candidate_num
        ])

        self.print_params()
        self.encoder.print_params()
        self.aux_decoder.print_params()

        # Checkpoint keys are the positional indices "0", "1", ... of
        # self.params, so this concatenation order must stay fixed or
        # previously saved checkpoints will map onto the wrong variables.
        self.params = ([self.embedding] + self.encoder.params +
                       self.aux_decoder.params)
        params_dict = {str(i): param for i, param in enumerate(self.params)}
        self.saver = tf.train.Saver(params_dict)

        self.optimizer = self.get_optimizer()

        # Identical content to self.params for this model (a fresh list).
        self.params_simple = ([self.embedding] + self.encoder.params +
                              self.aux_decoder.params)