Пример #1
0
    def _evaluate(self, sess, model, batcher):
        """Evaluate ``model`` on every batch produced by ``batcher``.

        For each batch the model's ``answer_word_test`` op is run, generated and
        ground-truth answers are decoded to words (stopping at ``'EOS'``), and
        per-question-type WUPS@0, WUPS@0.9 and BLEU-1 style scores are
        accumulated. All three scores come from ``wups.compute_wups``, with
        thresholds 0.0, 0.9 and -1 respectively. Summary statistics are printed.

        Args:
            sess: session used to run the graph (``sess.run``).
            model: model whose placeholders are fed and whose
                ``answer_word_test`` op is fetched.
            batcher: object with ``reset()`` and ``generate()``. ``generate``
                yields 10-tuples; a tuple whose ``ans_vecs`` is None ends the
                evaluation.

        Returns:
            Overall BLEU-1 score (summed score / number of examples), or
            ``nan`` when no examples were evaluated.
        """
        n_types = self.params['n_types']
        max_a = self.params['max_n_a_words']
        max_q = self.params['max_n_q_words']

        def decode(ids, limit, stop):
            # Map token ids to words, stopping at ``stop`` or after ``limit`` ids.
            words = []
            for idx in ids[:limit]:
                word = self.index2word[idx]
                if word == stop:
                    break
                words.append(word)
            return words

        def ratio(num, den):
            # Mean score; nan (without a numpy warning) when nothing was counted.
            return float(num / den) if den else float('nan')

        batcher.reset()
        type_count = np.zeros(n_types, dtype=float)
        bleu1_count = np.zeros(n_types, dtype=float)
        wups_count = np.zeros(n_types, dtype=float)
        wups_count2 = np.zeros(n_types, dtype=float)
        i_batch = 0
        all_batch_time = 0

        for (img_frame_vecs, img_frame_n, ques_vecs, ques_n, ques_word,
             ans_vecs, ans_n, ans_word, type_vec,
             batch_size) in batcher.generate():
            if ans_vecs is None:
                break

            # 1 for the first ans_n[k] answer positions of row k, 0 elsewhere.
            mask_matrix = (np.arange(max_a)[None, :]
                           < np.asarray(ans_n)[:, None]).astype(np.int32)

            batch_data = {
                model.input_q: ques_vecs,
                model.y: ans_word,
                model.input_x: img_frame_vecs,
                model.input_x_len: img_frame_n,
                model.input_q_len: ques_n,
                model.is_training: False,
                model.batch_size: batch_size,
                model.ans_vec: ans_vecs,
                model.y_mask: mask_matrix,
            }

            batch_t1 = time.perf_counter()
            # Fetch from ``model`` (not ``self.model``) so the op matches the
            # placeholders being fed.
            test_ans = sess.run(model.answer_word_test, feed_dict=batch_data)
            all_batch_time += time.perf_counter() - batch_t1
            # Transpose so rows index examples and columns index answer steps.
            test_ans = np.transpose(np.array(test_ans), (1, 0))

            ground_a = generate_a = None
            for i, q_type in enumerate(type_vec):
                ground_a = decode(ans_word[i], max_a, 'EOS')
                generate_a = decode(test_ans[i], max_a, 'EOS')
                type_count[q_type] += 1
                wups_count[q_type] += wups.compute_wups(ground_a, generate_a, 0.0)
                wups_count2[q_type] += wups.compute_wups(ground_a, generate_a, 0.9)
                bleu1_count[q_type] += wups.compute_wups(ground_a, generate_a, -1)

            i_batch += 1
            if i_batch % 100 == 0 and ground_a is not None:
                # Qualitative sample: the last example of this batch.
                question = decode(ques_word[len(type_vec) - 1], max_q, '<PAD>')
                print('batch index:', i_batch)
                print('question:    ', question)
                print('ground_a:    ', ground_a)
                print('generated:    ', generate_a)

        total = type_count.sum()
        wup_acc = ratio(wups_count.sum(), total)
        wup_acc2 = ratio(wups_count2.sum(), total)
        bleu1_acc = ratio(bleu1_count.sum(), total)
        print('Overall Wup (@0):', wup_acc, '[', wups_count.sum(), '/',
              total, ']')
        print('Overall Wup (@0.9):', wup_acc2, '[', wups_count2.sum(), '/',
              total, ']')
        print('Overall Bleu1:', bleu1_acc, '[', bleu1_count.sum(), '/',
              total, ']')
        type_wup_acc = [ratio(s, c) for s, c in zip(wups_count, type_count)]
        type_wup_acc2 = [ratio(s, c) for s, c in zip(wups_count2, type_count)]
        type_bleu1_acc = [ratio(s, c) for s, c in zip(bleu1_count, type_count)]
        print('Wup@0 for each type:', type_wup_acc)
        print('Wup@0.9 for each type:', type_wup_acc2)
        print('Bleu1 for each type:', type_bleu1_acc)
        print('type count:        ', type_count)
        print('all test time:  ', all_batch_time)
        return bleu1_acc
Пример #2
0
    def _train(self, sess):
        """Train ``self.model`` with Adam, evaluating and checkpointing per epoch.

        Builds an Adam optimizer whose learning rate decays exponentially
        (staircase) with a ``global_step`` variable, initializes all global
        variables, then runs up to ``params['max_epoches']`` epochs over
        ``self.train_batcher``. After each epoch it:

        - prints training WUPS@0 / WUPS@0.9 / BLEU-1 statistics;
        - evaluates the model;
        - saves a checkpoint to ``self.model_path + <timestamp>`` whenever the
          evaluation score improves, recording the path in
          ``self.last_checkpoint``;
        - stops once ``params['early_stopping']`` epochs pass without
          improvement.

        Args:
            sess: TensorFlow session in which the graph is built and run.
        """
        n_types = self.params['n_types']
        max_a = self.params['max_n_a_words']

        def decode(ids, limit, stop):
            # Map token ids to words, stopping at ``stop`` or after ``limit`` ids.
            words = []
            for idx in ids[:limit]:
                word = self.index2word[idx]
                if word == stop:
                    break
                words.append(word)
            return words

        def ratio(num, den):
            # Mean score; nan (without a numpy warning) when nothing was counted.
            return float(num / den) if den else float('nan')

        # Optimizer with a learning rate that decays as global_step advances;
        # minimize() increments global_step on every training step.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        learning_rates = tf.train.exponential_decay(
            self.params['learning_rate'],
            global_step,
            decay_steps=self.params['lr_decay_n_iters'],
            decay_rate=self.params['lr_decay_rate'],
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rates)
        train_proc = optimizer.minimize(self.model.train_loss,
                                        global_step=global_step)

        sess.run(tf.global_variables_initializer())
        best_epoch_acc = 0
        best_epoch_id = 0

        print('****************************')
        print('Trainning datetime:',
              time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        print('Trainning params')
        print(self.params)
        utils.count_total_variables()
        print('****************************')

        for i_epoch in range(self.params['max_epoches']):
            self.train_batcher.reset()
            all_batch_time = 0
            i_batch = 0
            loss_sum = 0
            learning_rate = None

            type_count = np.zeros(n_types, dtype=float)
            wups_count = np.zeros(n_types, dtype=float)
            wups_count2 = np.zeros(n_types, dtype=float)
            bleu1_count = np.zeros(n_types, dtype=float)

            for (img_frame_vecs, img_frame_n, ques_vecs, ques_n, ques_word,
                 ans_vecs, ans_n, ans_word, type_vec,
                 batch_size) in self.train_batcher.generate():
                if ans_vecs is None:
                    break

                # 1 for the first ans_n[k] answer positions of row k, else 0.
                mask_matrix = (np.arange(max_a)[None, :]
                               < np.asarray(ans_n)[:, None]).astype(np.int32)
                batch_data = {
                    self.model.y: ans_word,
                    self.model.input_x: img_frame_vecs,
                    self.model.input_x_len: img_frame_n,
                    self.model.input_q: ques_vecs,
                    self.model.input_q_len: ques_n,
                    self.model.ans_vec: ans_vecs,
                    self.model.is_training: True,
                    self.model.batch_size: batch_size,
                    self.model.y_mask: mask_matrix,
                }

                batch_t1 = time.perf_counter()
                _, train_loss, train_ans, learning_rate = sess.run(
                    [train_proc, self.model.train_loss,
                     self.model.answer_word_train, learning_rates],
                    feed_dict=batch_data)
                all_batch_time += time.perf_counter() - batch_t1

                # Training-set WUPS / BLEU-1 of the decoded answers; the
                # transpose makes rows index examples and columns index steps.
                train_ans = np.transpose(np.array(train_ans), (1, 0))
                for i, q_type in enumerate(type_vec):
                    ground_a = decode(ans_word[i], max_a, 'EOS')
                    generate_a = decode(train_ans[i], max_a, 'EOS')
                    type_count[q_type] += 1
                    wups_count[q_type] += wups.compute_wups(ground_a, generate_a, 0.0)
                    wups_count2[q_type] += wups.compute_wups(ground_a, generate_a, 0.9)
                    bleu1_count[q_type] += wups.compute_wups(ground_a, generate_a, -1)

                i_batch += 1
                loss_sum += train_loss
                if i_batch % self.params['display_batch_interval'] == 0:
                    print(
                        'Epoch %d, Batch %d, loss = %.4f, %.3f seconds/batch' %
                        (i_epoch, i_batch, train_loss,
                         all_batch_time / i_batch))

            if i_batch == 0:
                # No batches were trained; the per-epoch averages below would
                # divide by zero and learning_rate was never fetched.
                print('Epoch %d: no training batches, stopping.' % i_epoch)
                break

            print('****************************')
            total = type_count.sum()
            wup_acc = ratio(wups_count.sum(), total)
            wup_acc2 = ratio(wups_count2.sum(), total)
            bleu1_acc = ratio(bleu1_count.sum(), total)
            print('Overall Wup (@0):', wup_acc, '[', wups_count.sum(), '/',
                  total, ']')
            print('Overall Wup (@0.9):', wup_acc2, '[', wups_count2.sum(), '/',
                  total, ']')
            print('Overall Bleu1:', bleu1_acc, '[', bleu1_count.sum(), '/',
                  total, ']')
            type_wup_acc = [ratio(s, c) for s, c in zip(wups_count, type_count)]
            type_wup_acc2 = [ratio(s, c) for s, c in zip(wups_count2, type_count)]
            type_bleu1_acc = [ratio(s, c) for s, c in zip(bleu1_count, type_count)]
            print('Wup@0 for each type:', type_wup_acc)
            print('Wup@0.9 for each type:', type_wup_acc2)
            print('Bleu1 for each type:', type_bleu1_acc)
            print(type_count)

            avg_batch_loss = loss_sum / i_batch
            print('Epoch %d ends. Average loss %.3f. %.3f seconds/epoch' %
                  (i_epoch, avg_batch_loss, all_batch_time))
            print('learning_rate: ', learning_rate)

            # NOTE(review): every `evaluate_interval` epochs the selection score
            # is the second value returned by `self._test`; otherwise it is the
            # validation BLEU-1 from `_evaluate`. Scores from different splits
            # are compared against one best value -- confirm this is intended.
            if i_epoch % self.params['evaluate_interval'] == 0:
                print('****************************')
                print('Overall evaluation')
                print('****************************')
                _, valid_acc, _ = self._test(sess)
                print('****************************')
            else:
                print('****************************')
                print('Valid evaluation')
                print('****************************')
                valid_acc = self._evaluate(sess, self.model,
                                           self.valid_batcher)
                print('****************************')

            # Higher score is better: checkpoint on improvement, otherwise
            # count toward early stopping.
            if valid_acc > best_epoch_acc:
                best_epoch_acc = valid_acc
                best_epoch_id = i_epoch
                print('Saving new best model...')
                timestamp = time.strftime("%m%d%H%M%S", time.localtime())
                self.last_checkpoint = self.model_saver.save(
                    sess, self.model_path + timestamp, global_step=global_step)
                print('Saved at', self.last_checkpoint)
            elif i_epoch - best_epoch_id >= self.params['early_stopping']:
                print('Early stopped. Best score %.3f at epoch %d' %
                      (best_epoch_acc, best_epoch_id))
                break
    def _evaluate(self, sess, model, batcher):
        """Evaluate ``model`` on ``batcher`` using beam-search answer decoding.

        Each example is replicated ``self.beam_width`` times. Rows
        ``b * beam_width + j`` hold beam ``j`` of example ``b``.

        Answers are decoded one position at a time. At step ``i`` the current
        beam prefixes are fed through ``model.answer_vecs`` and
        ``model.answer_word_test`` is run. Its first two outputs are indexed
        ``[row][step][k]``: the first is treated as candidate probabilities
        (their negative log is accumulated), the second as the matching word
        ids. For each beam the first ``beam_width`` candidates are expanded.

        Beams are ranked by summed negative log-probability, lower being
        better. A beam whose last word is ``'EOS'`` is finished and kept with
        its score unchanged.

        The best beam of each example is decoded to words and scored against
        the ground truth with ``wups.compute_wups`` at thresholds 0.0 and 0.9,
        plus -1 for the BLEU-1 column. Scores are accumulated per question
        type and summary statistics are printed.

        Args:
            sess: session used to run the graph (``sess.run``).
            model: model whose placeholders are fed and whose
                ``answer_word_test`` op is fetched.
            batcher: object with ``reset()`` and ``generate()``. A yielded
                tuple whose ``ans_vecs`` is None ends the evaluation.

        Returns:
            Overall BLEU-1 score, or ``nan`` if no examples were evaluated.
        """
        beam_width = self.beam_width
        n_types = self.params['n_types']
        max_a = self.params['max_n_a_words']
        max_q = self.params['max_n_q_words']
        ans_dim = self.params['input_ques_dim']
        # NOTE(review): answer-word embeddings always come from the training
        # batcher, whatever ``batcher`` is; presumably all batchers share one
        # vocabulary -- confirm.
        word_vecs = self.train_batcher.all_word_vec

        def decode(ids, limit, stop):
            # Map token ids to words, stopping at ``stop`` or after ``limit`` ids.
            words = []
            for idx in ids[:limit]:
                word = self.index2word[idx]
                if word == stop:
                    break
                words.append(word)
            return words

        def ratio(num, den):
            # Mean score; nan (without a numpy warning) when nothing was counted.
            return float(num / den) if den else float('nan')

        def expand(arr):
            # Repeat each example beam_width times *consecutively*, so that rows
            # b * beam_width + j all belong to example b -- the layout that
            # beam_search indexes. np.tile would stack whole batch copies
            # instead and pair beams with the wrong examples.
            return np.repeat(arr, beam_width, axis=0)

        def is_finished(seq):
            # A beam is complete once its last word is the end token.
            return bool(seq) and self.index2word[seq[-1]] == 'EOS'

        def beam_search(batch_data, batch_size):
            # Returns (best token-id sequence per example, seconds in sess.run).
            # Each beam is [token ids, summed negative log-probability]. Only
            # beam 0 starts live; the others start at +inf so the first step
            # does not fill every slot with copies of the same prefix.
            beams = []
            for _ in range(batch_size):
                start = [[[], np.inf] for _ in range(beam_width)]
                start[0][1] = 0.0
                beams.append(start)

            cur_ans_vecs = np.zeros([batch_size * beam_width, max_a, ans_dim])
            # The same array object stays in batch_data and is updated in
            # place between steps.
            batch_data[model.answer_vecs] = cur_ans_vecs
            run_time = 0.0

            for step in range(max_a):
                t_start = time.perf_counter()
                test_ans = sess.run(model.answer_word_test, feed_dict=batch_data)
                run_time += time.perf_counter() - t_start
                ans_prods, ans_idxs = test_ans[0], test_ans[1]

                for b in range(batch_size):
                    candidates = []
                    for j, (seq, score) in enumerate(beams[b]):
                        if is_finished(seq):
                            # Keep a finished beam with an unchanged score; pad
                            # it with EOS so every prefix has length step + 1.
                            candidates.append([seq + [seq[-1]], score])
                            continue
                        row = b * beam_width + j
                        for k in range(beam_width):
                            candidates.append(
                                [seq + [ans_idxs[row][step][k]],
                                 score - np.log(ans_prods[row][step][k])])
                    candidates.sort(key=lambda cand: cand[1])
                    beams[b] = candidates[:beam_width]
                    for j, (seq, _) in enumerate(beams[b]):
                        cur_ans_vecs[b * beam_width + j, :step + 1, :] = word_vecs[seq]

            return [example_beams[0][0] for example_beams in beams], run_time

        batcher.reset()
        type_count = np.zeros(n_types, dtype=float)
        bleu1_count = np.zeros(n_types, dtype=float)
        wups_count = np.zeros(n_types, dtype=float)
        wups_count2 = np.zeros(n_types, dtype=float)
        i_batch = 0
        all_batch_time = 0

        for (img_frame_vecs, img_frame_n, ques_vecs, ques_n, ques_word,
             ans_vecs, ans_n, ans_word, type_vec,
             batch_size) in batcher.generate():
            if ans_vecs is None:
                break

            # 1 for the first ans_n[k] answer positions of row k, 0 elsewhere.
            mask_matrix = (np.arange(max_a)[None, :]
                           < np.asarray(ans_n)[:, None]).astype(np.int32)
            batch_data = {
                model.ques_vecs: expand(ques_vecs),
                model.target: expand(ans_word),
                model.frame_vecs: expand(img_frame_vecs),
                model.frame_len: expand(img_frame_n),
                model.ques_len: expand(ques_n),
                model.is_training: False,
                model.batch_size: batch_size * beam_width,
                # Expanded like ``target`` so both have one row per beam.
                model.target_mask: expand(mask_matrix),
            }

            best_seqs, run_time = beam_search(batch_data, batch_size)
            all_batch_time += run_time

            ground_a = generate_a = None
            for i, q_type in enumerate(type_vec):
                ground_a = decode(ans_word[i], max_a, 'EOS')
                generate_a = decode(best_seqs[i], max_a, 'EOS')
                type_count[q_type] += 1
                wups_count[q_type] += wups.compute_wups(ground_a, generate_a, 0.0)
                wups_count2[q_type] += wups.compute_wups(ground_a, generate_a, 0.9)
                bleu1_count[q_type] += wups.compute_wups(ground_a, generate_a, -1)

            i_batch += 1
            if i_batch % 100 == 0 and ground_a is not None:
                # Qualitative sample: the last example of this batch.
                question = decode(ques_word[len(type_vec) - 1], max_q, '<PAD>')
                print('Batch:', i_batch)
                print('question:    ', question)
                print('ground_a:    ', ground_a)
                print('generated:    ', generate_a)

        total = type_count.sum()
        wup_acc = ratio(wups_count.sum(), total)
        wup_acc2 = ratio(wups_count2.sum(), total)
        bleu1_acc = ratio(bleu1_count.sum(), total)
        print('Overall Wup (@0):', wup_acc, '[', wups_count.sum(), '/', total, ']')
        print('Overall Wup (@0.9):', wup_acc2, '[', wups_count2.sum(), '/', total, ']')
        print('Overall Bleu1:', bleu1_acc, '[', bleu1_count.sum(), '/', total, ']')
        type_wup_acc = [ratio(s, c) for s, c in zip(wups_count, type_count)]
        type_wup_acc2 = [ratio(s, c) for s, c in zip(wups_count2, type_count)]
        type_bleu1_acc = [ratio(s, c) for s, c in zip(bleu1_count, type_count)]
        print('Wup@0 for each type:', type_wup_acc)
        print('Wup@0.9 for each type:', type_wup_acc2)
        print('Bleu1 for each type:', type_bleu1_acc)
        print('type count:        ', type_count)
        print('all test time:  ', all_batch_time)

        return bleu1_acc