Example #1
0
    def train(self):
        """Build the model, run the full train/eval procedure, then test.

        Creates a GPU-growth TF session, builds the model graph, prepares the
        checkpoint directory and data batchers, trains via ``self._train``,
        restores the best saved checkpoint (if any) for final testing via
        ``self._test``, and closes the session.
        """
        # Let the GPU allocator grow on demand instead of grabbing all memory.
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess = tf.Session(config=sess_config)

        # Build the graph before creating the Saver so all variables exist.
        self.model.build_model()

        # Checkpoint directory; exist_ok avoids the check-then-create race.
        self.model_path = os.path.join(self.params['cache_dir'])
        if not os.path.exists(self.model_path):
            print('create path: ', self.model_path)
        os.makedirs(self.model_path, exist_ok=True)

        # Set by _train when a new best model is saved.
        self.last_checkpoint = None

        # Batchers for each split; the train split may reshuffle every epoch.
        self.train_batcher = Batcher(self.params, self.params['train_json'],
                                     'train', self.params['epoch_reshuffle'])
        self.valid_batcher = Batcher(self.params, self.params['val_json'],
                                     'val')
        self.test_batcher = Batcher(self.params, self.params['test_json'],
                                    'test')
        self.model_saver = tf.train.Saver()

        print('Training begins......')  # fixed typo: was 'Trainnning'
        self._train(sess)

        # Final testing: only announce/restore when a checkpoint was saved
        # (previously the "Evaluating ..." line printed even for None).
        if self.last_checkpoint is not None:
            print('Evaluating best model in file', self.last_checkpoint, '...')
            self.model_saver.restore(sess, self.last_checkpoint)
            self._test(sess)
        else:
            print('ERROR: No checkpoint available!')
        sess.close()
Example #2
0
class Trainer(object):
    """Drives training, validation, and testing of a question-answering model.

    Wires a TF graph-building ``model`` to ``Batcher`` data sources and scores
    generated answers against ground truth with ``wups.compute_wups``.
    """

    def __init__(self, params, model):
        """Store configuration and the model; load the id->word vocabulary.

        Args:
            params: dict of hyper-parameters and file paths. Keys read in this
                class include 'index2word', 'cache_dir', 'train_json',
                'val_json', 'test_json', 'epoch_reshuffle', 'learning_rate',
                'lr_decay_n_iters', 'lr_decay_rate', 'max_epoches', 'n_types',
                'max_n_a_words', 'max_n_q_words', 'display_batch_interval',
                'evaluate_interval', 'early_stopping'.
            model: object exposing build_model(), placeholder attributes, and
                loss/output tensors used below.
        """
        self.params = params
        self.model = model
        # Maps word ids to word strings; used to detokenize answers/questions
        # and to detect the 'EOS' / '<PAD>' sentinels.
        self.index2word = load_file(self.params['index2word'])

    def train(self):
        """Build the model, run training, then test the best checkpoint.

        Creates a TF session (GPU memory growth enabled), prepares the
        checkpoint directory and the train/val/test batchers, trains via
        ``_train``, and finally restores the best checkpoint (if any) for
        ``_test`` before closing the session.
        """
        # training
        sess_config = tf.ConfigProto()
        # Grow GPU memory on demand rather than reserving it all up front.
        sess_config.gpu_options.allow_growth = True
        sess = tf.Session(config=sess_config)
        # module initialization
        self.model.build_model()
        # Directory where checkpoints are written; created if missing.
        self.model_path = os.path.join(self.params['cache_dir'])
        if not os.path.exists(self.model_path):
            print('create path: ', self.model_path)
            os.makedirs(self.model_path)

        # Set by _train whenever a new best model is saved.
        self.last_checkpoint = None

        # main procedure including training & testing
        self.train_batcher = Batcher(self.params, self.params['train_json'],
                                     'train', self.params['epoch_reshuffle'])
        self.valid_batcher = Batcher(self.params, self.params['val_json'],
                                     'val')
        self.test_batcher = Batcher(self.params, self.params['test_json'],
                                    'test')
        # Saver must be created after build_model so graph variables exist.
        self.model_saver = tf.train.Saver()

        print('Trainnning begins......')
        self._train(sess)
        # testing
        # NOTE(review): this announcement prints even when last_checkpoint is
        # None (it then says "... None ..." before the ERROR line below).
        print('Evaluating best model in file', self.last_checkpoint, '...')
        if self.last_checkpoint is not None:
            self.model_saver.restore(sess, self.last_checkpoint)
            self._test(sess)
        else:
            print('ERROR: No checkpoint available!')
        sess.close()

    def _train(self, sess):
        """Run the epoch loop: optimize, log metrics, checkpoint, early-stop.

        Builds the Adam optimizer with a staircase exponentially-decayed
        learning rate, then for each epoch: iterates training batches,
        accumulates WUPS/BLEU-1 style scores per question type, evaluates on
        validation (full evaluation every 'evaluate_interval' epochs), saves a
        checkpoint when validation accuracy improves, and stops early after
        'early_stopping' epochs without improvement.

        Args:
            sess: an open tf.Session in which the model graph lives.
        """
        # tensorflow initialization
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        # Staircase decay: lr drops by lr_decay_rate every lr_decay_n_iters
        # optimizer steps.
        learning_rates = tf.train.exponential_decay(
            self.params['learning_rate'],
            global_step,
            decay_steps=self.params['lr_decay_n_iters'],
            decay_rate=self.params['lr_decay_rate'],
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rates)
        train_proc = optimizer.minimize(self.model.train_loss,
                                        global_step=global_step)
        # train_proc_rl = optimizer.minimize(self.model.loss_rl, global_step=global_step)

        # training
        init_proc = tf.global_variables_initializer()
        sess.run(init_proc)
        # self.model_saver.restore(sess, '../results/0817221611-2670')
        # Best validation accuracy so far and the epoch it occurred in
        # (used for checkpointing and early stopping).
        best_epoch_acc = 0
        best_epoch_id = 0

        print('****************************')
        print('Trainning datetime:',
              time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        print('Trainning params')
        print(self.params)
        utils.count_total_variables()
        print('****************************')

        all_epoch_time = 0
        epoch_time_list = list()
        for i_epoch in range(self.params['max_epoches']):
            # train an epoch

            all_batch_time = 0
            self.train_batcher.reset()
            i_batch = 0
            loss_sum = 0

            # Per-question-type accumulators: counts and summed scores at
            # several WUPS thresholds.
            type_count = np.zeros(self.params['n_types'], dtype=float)
            wups_count = np.zeros(self.params['n_types'], dtype=float)
            wups_count2 = np.zeros(self.params['n_types'], dtype=float)
            bleu1_count = np.zeros(self.params['n_types'], dtype=float)

            for img_frame_vecs, img_frame_n, ques_vecs, ques_n, ques_word, ans_vecs, ans_n, ans_word, type_vec, batch_size in self.train_batcher.generate(
            ):
                # The batcher signals exhaustion with a None answer batch.
                if ans_vecs is None:
                    break

                batch_data = dict()
                batch_data[self.model.y] = ans_word
                batch_data[self.model.input_x] = img_frame_vecs
                batch_data[self.model.input_x_len] = img_frame_n
                batch_data[self.model.input_q] = ques_vecs
                batch_data[self.model.input_q_len] = ques_n
                batch_data[self.model.ans_vec] = ans_vecs
                batch_data[self.model.is_training] = True
                batch_data[self.model.batch_size] = batch_size

                # Mask marking the first ans_n[ind] answer positions as valid
                # (1) and the padded tail as ignored (0).
                mask_matrix = np.zeros(
                    [np.shape(ans_n)[0], self.params['max_n_a_words']],
                    np.int32)
                for ind, row in enumerate(mask_matrix):
                    row[:ans_n[ind]] = 1
                batch_data[self.model.y_mask] = mask_matrix

                batch_t1 = time.time()
                _, train_loss, train_ans, learning_rate = sess.run(
                    [
                        train_proc, self.model.train_loss,
                        self.model.answer_word_train, learning_rates
                    ],
                    feed_dict=batch_data)
                # train_ans, learning_rate = sess.run([self.model.answer_word_train, learning_rates], feed_dict=batch_data)
                batch_t2 = time.time()
                batch_time = batch_t2 - batch_t1
                all_batch_time += batch_time

                # get word and calculate WUPS
                # Transpose so train_ans[i] is the i-th example's word-id
                # sequence (presumably time-major output — TODO confirm).
                train_ans = np.transpose(np.array(train_ans), (1, 0))
                # Per-example reward; computed but only fed to the (commented
                # out) RL objective below.
                reward = np.ones(len(ans_vecs), dtype=float)
                for i in range(len(type_vec)):
                    type_count[type_vec[i]] += 1
                    # Detokenize ground-truth answer, stopping at 'EOS'.
                    ground_a = list()
                    for l in range(self.params['max_n_a_words']):
                        word = ans_word[i][l]
                        if self.index2word[word] == 'EOS':
                            break
                        ground_a.append(self.index2word[word])

                    # Detokenize generated answer, stopping at 'EOS'.
                    generate_a = list()
                    for l in range(self.params['max_n_a_words']):
                        word = train_ans[i][l]
                        if self.index2word[word] == 'EOS':
                            break
                        generate_a.append(self.index2word[word])

                    # Detokenize the question, stopping at '<PAD>'.
                    question = list()
                    for l in range(self.params['max_n_q_words']):
                        word = ques_word[i][l]
                        if self.index2word[word] == '<PAD>':
                            break
                        question.append(self.index2word[word])

                    # Thresholds 0.0 and 0.9 are standard WUPS settings;
                    # -1 apparently selects BLEU-1 scoring inside
                    # wups.compute_wups — TODO confirm in the wups module.
                    wups_value = wups.compute_wups(ground_a, generate_a, 0.0)
                    wups_value2 = wups.compute_wups(ground_a, generate_a, 0.9)
                    bleu1_value = wups.compute_wups(ground_a, generate_a, -1)
                    wups_count[type_vec[i]] += wups_value
                    wups_count2[type_vec[i]] += wups_value2
                    bleu1_count[type_vec[i]] += bleu1_value

                    reward[i] = bleu1_value

                # batch_data[self.model.reward] = reward
                # _, train_loss = sess.run([train_proc, self.model.train_loss], feed_dict=batch_data)

                # display batch info
                i_batch += 1
                loss_sum += train_loss
                if i_batch % self.params['display_batch_interval'] == 0:
                    print(
                        'Epoch %d, Batch %d, loss = %.4f, %.3f seconds/batch' %
                        (i_epoch, i_batch, train_loss,
                         all_batch_time / i_batch))
                    # print('question:    ', question)
                    # print('ground_a:    ', ground_a)
                    # print('generated:    ', generate_a)
                    # print('wups_value:  ', wups_value)
                    # print('wups_value2: ', wups_value2)
                    # print('Bleu1 value: ', bleu1_value)

            # Epoch summary: overall and per-type accuracies.
            # NOTE(review): per-type divisions below raise/NaN if some type
            # never occurs in the epoch (type_count[i] == 0).
            print('****************************')
            wup_acc = wups_count.sum() / type_count.sum()
            wup_acc2 = wups_count2.sum() / type_count.sum()
            bleu1_acc = bleu1_count.sum() / type_count.sum()
            print('Overall Wup (@0):', wup_acc, '[', wups_count.sum(), '/',
                  type_count.sum(), ']')
            print('Overall Wup (@0.9):', wup_acc2, '[', wups_count2.sum(), '/',
                  type_count.sum(), ']')
            print('Overall Bleu1:', bleu1_acc, '[', bleu1_count.sum(), '/',
                  type_count.sum(), ']')
            type_wup_acc = [
                wups_count[i] / type_count[i]
                for i in range(self.params['n_types'])
            ]
            type_wup_acc2 = [
                wups_count2[i] / type_count[i]
                for i in range(self.params['n_types'])
            ]
            type_bleu1_acc = [
                bleu1_count[i] / type_count[i]
                for i in range(self.params['n_types'])
            ]
            print('Wup@0 for each type:', type_wup_acc)
            print('[email protected] for each type:', type_wup_acc2)
            print('Bleu1 for each type:', type_bleu1_acc)
            print(type_count)

            # print info

            # NOTE(review): both lines below assume the epoch produced at
            # least one batch; otherwise i_batch is 0 (ZeroDivisionError)
            # and learning_rate is unbound.
            avg_batch_loss = loss_sum / i_batch
            all_epoch_time += all_batch_time
            epoch_time_list.append(all_epoch_time)
            print('Epoch %d ends. Average loss %.3f. %.3f seconds/epoch' %
                  (i_epoch, avg_batch_loss, all_batch_time))
            print('learning_rate: ', learning_rate)

            # Full (valid + test) evaluation on the interval; otherwise
            # validation only.
            if i_epoch % self.params['evaluate_interval'] == 0:
                print('****************************')
                print('Overall evaluation')
                print('****************************')
                _, valid_acc, _ = self._test(sess)
                print('****************************')
            else:
                print('****************************')
                print('Valid evaluation')
                print('****************************')
                valid_acc = self._evaluate(sess, self.model,
                                           self.valid_batcher)
                print('****************************')

            # save model and early stop
            if valid_acc > best_epoch_acc:
                best_epoch_acc = valid_acc
                best_epoch_id = i_epoch
                print('Saving new best model...')
                timestamp = time.strftime("%m%d%H%M%S", time.localtime())
                # NOTE(review): plain string concatenation — the timestamp is
                # appended directly to model_path; confirm 'cache_dir' ends
                # with a separator or this prefix naming is intended.
                self.last_checkpoint = self.model_saver.save(
                    sess, self.model_path + timestamp, global_step=global_step)
                print('Saved at', self.last_checkpoint)
            else:
                # Stop when no improvement for 'early_stopping' epochs.
                if i_epoch - best_epoch_id >= self.params['early_stopping']:
                    print('Early stopped. Best loss %.3f at epoch %d' %
                          (best_epoch_acc, best_epoch_id))
                    break

    def _evaluate(self, sess, model, batcher):
        """Score the model on one data split and return overall BLEU-1.

        Runs inference (is_training=False) over all batches from ``batcher``,
        accumulates WUPS@0, WUPS@0.9 and BLEU-1 per question type, prints a
        summary, and returns the overall BLEU-1 accuracy.

        Args:
            sess: open tf.Session.
            model: model object providing placeholders and answer_word_test.
            batcher: Batcher yielding evaluation batches.

        Returns:
            float: overall BLEU-1 accuracy (sum of scores / example count).
        """
        # evaluate the model in a set
        batcher.reset()
        # Per-question-type accumulators (same layout as in _train).
        type_count = np.zeros(self.params['n_types'], dtype=float)
        bleu1_count = np.zeros(self.params['n_types'], dtype=float)
        wups_count = np.zeros(self.params['n_types'], dtype=float)
        wups_count2 = np.zeros(self.params['n_types'], dtype=float)
        i_batch = 0
        all_batch_time = 0

        for img_frame_vecs, img_frame_n, ques_vecs, ques_n, ques_word, ans_vecs, ans_n, ans_word, type_vec, batch_size in batcher.generate(
        ):
            # The batcher signals exhaustion with a None answer batch.
            if ans_vecs is None:
                break

            batch_data = {
                model.input_q: ques_vecs,
                model.y: ans_word,
                model.input_x: img_frame_vecs,
                model.input_x_len: img_frame_n,
                model.input_q_len: ques_n,
                model.is_training: False,
                model.batch_size: batch_size,
                model.ans_vec: ans_vecs
            }

            # Valid-position mask for the padded answer sequences.
            mask_matrix = np.zeros(
                [np.shape(ans_n)[0], self.params['max_n_a_words']], np.int32)
            for ind, row in enumerate(mask_matrix):
                row[:ans_n[ind]] = 1
            batch_data[model.y_mask] = mask_matrix

            batch_t1 = time.time()
            # NOTE(review): runs self.model's tensor, not the `model`
            # argument's — harmless while both are the same object.
            test_ans = sess.run(self.model.answer_word_test,
                                feed_dict=batch_data)
            batch_t2 = time.time()
            batch_time = batch_t2 - batch_t1
            all_batch_time += batch_time
            # Transpose so test_ans[i] is the i-th example's word-id sequence
            # (presumably time-major output — TODO confirm).
            test_ans = np.transpose(np.array(test_ans), (1, 0))

            for i in range(len(type_vec)):
                type_count[type_vec[i]] += 1
                # Detokenize ground truth, generated answer, and question
                # (stopping at 'EOS' / '<PAD>'), then score the pair.
                ground_a = list()
                for l in range(self.params['max_n_a_words']):
                    word = ans_word[i][l]
                    if self.index2word[word] == 'EOS':
                        break
                    ground_a.append(self.index2word[word])

                generate_a = list()
                for l in range(self.params['max_n_a_words']):
                    word = test_ans[i][l]
                    if self.index2word[word] == 'EOS':
                        break
                    generate_a.append(self.index2word[word])

                question = list()
                for l in range(self.params['max_n_q_words']):
                    word = ques_word[i][l]
                    if self.index2word[word] == '<PAD>':
                        break
                    question.append(self.index2word[word])

                # Threshold -1 apparently selects BLEU-1 scoring inside
                # wups.compute_wups — TODO confirm in the wups module.
                wups_value = wups.compute_wups(ground_a, generate_a, 0.0)
                wups_value2 = wups.compute_wups(ground_a, generate_a, 0.9)
                bleu1_value = wups.compute_wups(ground_a, generate_a, -1)
                # bleu1_value = bleu.calculate_bleu(' '.join(ground_a), ' '.join(generate_a))
                wups_count[type_vec[i]] += wups_value
                wups_count2[type_vec[i]] += wups_value2
                bleu1_count[type_vec[i]] += bleu1_value

            i_batch += 1
            # Periodic progress sample: prints the last example of the batch.
            if i_batch % 100 == 0:
                print('batch index:', i_batch)
                print('question:    ', question)
                print('ground_a:    ', ground_a)
                print('generated:    ', generate_a)

        # Split summary (overall and per-type).
        # NOTE(review): per-type divisions raise/NaN if some type never
        # occurs in this split (type_count[i] == 0).
        wup_acc = wups_count.sum() / type_count.sum()
        wup_acc2 = wups_count2.sum() / type_count.sum()
        bleu1_acc = bleu1_count.sum() / type_count.sum()
        print('Overall Wup (@0):', wup_acc, '[', wups_count.sum(), '/',
              type_count.sum(), ']')
        print('Overall Wup (@0.9):', wup_acc2, '[', wups_count2.sum(), '/',
              type_count.sum(), ']')
        print('Overall Bleu1:', bleu1_acc, '[', bleu1_count.sum(), '/',
              type_count.sum(), ']')
        type_wup_acc = [
            wups_count[i] / type_count[i]
            for i in range(self.params['n_types'])
        ]
        type_wup_acc2 = [
            wups_count2[i] / type_count[i]
            for i in range(self.params['n_types'])
        ]
        type_bleu1_acc = [
            bleu1_count[i] / type_count[i]
            for i in range(self.params['n_types'])
        ]
        print('Wup@0 for each type:', type_wup_acc)
        print('[email protected] for each type:', type_wup_acc2)
        print('Bleu1 for each type:', type_bleu1_acc)
        print('type count:        ', type_count)
        print('all test time:  ', all_batch_time)
        return bleu1_acc

    def _test(self, sess):
        """Evaluate on validation and test splits.

        Returns:
            tuple: (0.0, valid_acc, test_acc) — the first element is a
            placeholder; the accuracies are the BLEU-1 values returned by
            _evaluate for each split.
        """
        print('Validation set:')
        valid_acc = self._evaluate(sess, self.model, self.valid_batcher)
        print('Test set:')
        test_acc = self._evaluate(sess, self.model, self.test_batcher)
        return 0.0, valid_acc, test_acc