Code Example #1
    # Assumes module-level imports: os, cPickle (pickle on Python 3),
    # and the project-local 'modules'.
    def __init__(self, batch, config, is_train=True):
        self.batch = batch
        self.config = config
        self.data_cfg = config.data_cfg
        self.data_dir = config.data_dir
        self.is_train = is_train

        self.losses = {}
        self.report = {}
        self.mid_result = {}
        self.vis_image = {}

        # Load the pickled word vocabulary.
        vocab_path = os.path.join(self.data_dir, 'vocab.pkl')
        with open(vocab_path, 'rb') as f:
            self.vocab = cPickle.load(f)

        # Answer vocabulary for the classifier head.
        answer_dict_path = os.path.join(self.data_dir, 'answer_dict.pkl')
        with open(answer_dict_path, 'rb') as f:
            self.answer_dict = cPickle.load(f)
        self.num_answer = len(self.answer_dict['vocab'])

        # Wordset dictionary.
        ws_dict_path = os.path.join(self.data_dir, 'wordset_dict5.pkl')
        with open(ws_dict_path, 'rb') as f:
            self.ws_dict = cPickle.load(f)
        self.num_ws = len(self.ws_dict['vocab'])

        # Trainable embedding maps; the GloVe variants are presumably
        # initialized from pretrained GloVe vectors, per their names.
        self.wordset_map = modules.learn_embedding_map(self.ws_dict,
                                                       scope='wordset_map')
        self.v_word_map = modules.LearnGloVe(self.vocab, scope='V_GloVe')
        self.l_word_map = modules.LearnGloVe(self.vocab, scope='L_GloVe')
        self.l_answer_word_map = modules.LearnAnswerGloVe(self.answer_dict)

        self.build()
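All four examples lean on helpers from the project-local `modules` that are not shown. As a rough sketch, `learn_embedding_map` plausibly returns a trainable embedding matrix with one row per vocabulary entry; the signature, dimension, and initializer below are assumptions for illustration, not the project's actual code:

import tensorflow as tf

def learn_embedding_map(word_dict, dim=300, scope='embedding_map'):
    # Trainable matrix with one row per vocabulary entry; 'dim' and the
    # initializer range are placeholder assumptions.
    with tf.variable_scope(scope):
        return tf.get_variable(
            'map', shape=[len(word_dict['vocab']), dim],
            initializer=tf.random_uniform_initializer(-0.01, 0.01))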
Code Example #2
    # Assumes module-level imports: json, h5py, tensorflow as tf,
    # and the project-local 'modules'.
    def __init__(self, batches, config, is_train=True):
        self.batches = batches
        self.config = config

        self.report = {}
        self.output = {}

        self.no_object = config.no_object
        self.no_region = config.no_region
        self.use_blank_fill = config.use_blank_fill

        self.object_batch_size = config.object_batch_size
        self.region_batch_size = config.region_batch_size

        self.object_num_k = config.object_num_k
        self.object_max_name_len = config.object_max_name_len

        self.region_max_len = config.region_max_len

        # model parameters
        self.finetune_enc_I = config.finetune_enc_I
        self.no_V_grad_enc_L = config.no_V_grad_enc_L
        self.no_V_grad_dec_L = config.no_V_grad_dec_L
        self.no_L_grad_dec_L = config.no_L_grad_dec_L
        self.use_embed_transform = config.use_embed_transform
        self.use_dense_predictor = config.use_dense_predictor
        self.no_glove = config.no_glove

        with open(config.vocab_path, 'r') as f:
            self.vocab = json.load(f)
        self.wordset = modules.used_wordset(config.used_wordset_path)
        self.wordset_vocab = {}
        with h5py.File(config.used_wordset_path, 'r') as f:
            # h5py's Dataset.value is deprecated; index with [()] instead.
            wordset = list(f['used_wordset'][()])
            self.wordset_vocab['vocab'] = [self.vocab['vocab'][w]
                                           for w in wordset]
            self.wordset_vocab['dict'] = {w: i for i, w in
                                          enumerate(self.wordset_vocab['vocab'])}

        if self.no_glove:
            self.glove_all = modules.learn_embedding_map(self.vocab)
        else:
            self.glove_all = modules.glove_embedding_map(self.vocab)
        self.glove_wordset = tf.nn.embedding_lookup(self.glove_all,
                                                    self.wordset)
        predictor_embed = self.glove_wordset
        if self.use_embed_transform:
            predictor_embed = modules.embedding_transform(
                predictor_embed, W_DIM, L_DIM, is_train=is_train)
        if self.use_dense_predictor:
            self.word_predictor = tf.layers.Dense(
                len(self.wordset_vocab['vocab']), use_bias=True, name='WordPredictor')
        else:
            self.word_predictor = modules.WordPredictor(predictor_embed,
                                                        trainable=is_train,
                                                        name='WordPredictor')

        self.build(is_train=is_train)
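When `use_dense_predictor` is off, the predictor is built from `predictor_embed`, which suggests weight tying: each wordset entry is scored by similarity to its (possibly transformed) embedding rather than through a separately learned output matrix. A minimal sketch of that idea; the function and shapes are assumptions, not the actual `modules.WordPredictor`:

import tensorflow as tf

def tied_word_predictor(features, predictor_embed):
    # features: [batch, L_DIM]; predictor_embed: [num_words, L_DIM].
    # Logits are dot products with each word's embedding, so no extra
    # output weights are learned (weight tying).
    return tf.matmul(features, predictor_embed, transpose_b=True)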
Code Example #3
    # Assumes module-level imports: os, cPickle (pickle on Python 3),
    # and the project-local 'modules'.
    def __init__(self, batch, config, is_train=True):
        self.batch = batch
        self.config = config
        self.data_cfg = config.data_cfg
        self.data_dir = config.data_dir
        self.is_train = is_train

        self.losses = {}
        self.report = {}
        self.mid_result = {}
        self.vis_image = {}

        # Load the pickled word vocabulary.
        vocab_path = os.path.join(self.data_dir, 'vocab.pkl')
        with open(vocab_path, 'rb') as f:
            self.vocab = cPickle.load(f)

        # Answer vocabulary for the classifier head.
        answer_dict_path = os.path.join(self.data_dir, 'answer_dict.pkl')
        with open(answer_dict_path, 'rb') as f:
            self.answer_dict = cPickle.load(f)
        self.num_answer = len(self.answer_dict['vocab'])

        # Wordset dictionary.
        ws_dict_path = os.path.join(self.data_dir, 'wordset_dict5.pkl')
        with open(ws_dict_path, 'rb') as f:
            self.ws_dict = cPickle.load(f)
        self.num_ws = len(self.ws_dict['vocab'])

        # Enwiki context dictionary; the filename encodes the
        # preprocessing variant.
        enwiki_dict_path = os.path.join(
            self.data_dir,
            'enwiki_context_dict_w3_p{}_n5.pkl'.format(
                config.enwiki_preprocessing))
        with open(enwiki_dict_path, 'rb') as f:
            self.enwiki_dict = cPickle.load(f)
        self.num_context_vocab = len(self.enwiki_dict['context_word_vocab'])
        self.max_context_len = self.enwiki_dict['max_context_len']
        self.enwiki_vocab_dict = {
            'vocab': self.enwiki_dict['context_word_vocab'],
            'dict': self.enwiki_dict['context_word_dict'],
        }

        self.wordset_map = modules.learn_embedding_map(
            self.ws_dict, scope='wordset_map')
        self.enwiki_map = modules.learn_embedding_map(
            self.enwiki_vocab_dict, scope='enwiki_map')
        self.v_word_map = modules.LearnGloVe(self.vocab, scope='V_GloVe')
        self.l_word_map = modules.LearnGloVe(self.vocab, scope='L_GloVe')
        self.l_answer_word_map = modules.LearnAnswerGloVe(self.answer_dict)

        self.build()
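Example #3 differs from #1 mainly in the enwiki context dictionary and its embedding map. Since `learn_embedding_map` returns a plain matrix, using it later reduces to an embedding lookup, presumably inside `build()`. A minimal sketch with placeholder names and sizes (all assumed for illustration):

import tensorflow as tf

num_context_vocab, embed_dim = 10000, 300  # placeholder sizes
enwiki_map = tf.get_variable('enwiki_map', [num_context_vocab, embed_dim])
context_ids = tf.placeholder(tf.int32, [None, None])  # [batch, max_context_len]
context_embed = tf.nn.embedding_lookup(enwiki_map, context_ids)
# context_embed: [batch, max_context_len, embed_dim]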
Code Example #4
    # Refuse to clobber an existing save directory (snippet assumes
    # module-level imports: os, cPickle, tensorflow as tf, modules).
    raise ValueError('Do not overwrite: {}'.format(config.save_dir))

vocab_path = os.path.join(config.data_dir, 'vocab.pkl')
with open(vocab_path, 'rb') as f:
    vocab = cPickle.load(f)

answer_dict_path = os.path.join(config.data_dir, 'answer_dict.pkl')
with open(answer_dict_path, 'rb') as f:
    answer_dict = cPickle.load(f)
num_answer = len(answer_dict['vocab'])

ws_dict_path = os.path.join(
    config.data_dir,
    'wordset_dict5_depth{}.pkl'.format(int(config.expand_depth)))
with open(ws_dict_path, 'rb') as f:
    ws_dict = cPickle.load(f)
num_ws = len(ws_dict['vocab'])

wordset_map = modules.learn_embedding_map(ws_dict, scope='wordset_map')

L_DIM = 1024

# Squash the wordset embeddings with tanh, then project to L_DIM through
# a layer-normalized fully connected layer.
wordset_embed = tf.tanh(wordset_map)
wordset_ft = modules.fc_layer(wordset_embed,
                              L_DIM,
                              use_bias=True,
                              use_bn=False,
                              use_ln=True,
                              activation_fn=tf.tanh,
                              is_training=False,
                              scope='wordset_ft')

session_config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=tf.GPUOptions(allow_growth=True))
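The snippet ends with the session config; a minimal usage sketch of how it would typically be passed on (the initializer call is an assumption, not shown above):

with tf.Session(config=session_config) as sess:
    sess.run(tf.global_variables_initializer())
    # Evaluate wordset_ft (and any other tensors) here via sess.run.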