Example #1
0
    def compile_graph(self, acc: float) -> None:
        """
        Freeze the current computation graph into a .pb model file.

        The accuracy value is used only as part of the output model's
        file name (scaled to an integer, e.g. 0.95 -> 9500).

        :param acc: accuracy of the model being compiled
        :return: None
        """
        # Build a fresh graph and session so the predict graph is fully
        # isolated from any prior training graph state.
        input_graph = tf.compat.v1.Graph()
        tf.compat.v1.keras.backend.clear_session()
        tf.compat.v1.reset_default_graph()
        predict_sess = tf.compat.v1.Session(graph=input_graph)
        tf.compat.v1.keras.backend.set_session(predict_sess)

        with predict_sess.graph.as_default():
            # Rebuild the network in predict mode inside the new graph.
            model = core.NeuralNetwork(model_conf=self.model_conf,
                                       mode=RunMode.Predict,
                                       backbone=self.model_conf.neu_cnn,
                                       recurrent=self.model_conf.neu_recurrent)
            model.build_graph()
            model.build_train_op()
            input_graph_def = predict_sess.graph.as_graph_def()
            saver = tf.compat.v1.train.Saver(
                var_list=tf.compat.v1.global_variables())
            tf.compat.v1.logging.info(
                tf.train.latest_checkpoint(self.model_conf.model_root_path))
            # Restore the latest training checkpoint into the predict graph.
            saver.restore(
                predict_sess,
                tf.train.latest_checkpoint(self.model_conf.model_root_path))

            # Fold the restored variables into constants so the exported
            # graph is self-contained; 'dense_decoded' is the output node.
            output_graph_def = convert_variables_to_constants(
                predict_sess,
                input_graph_def,
                output_node_names=['dense_decoded'])

        if not os.path.exists(self.model_conf.compile_model_path):
            os.makedirs(self.model_conf.compile_model_path)

        # e.g. "<name>.pb" becomes "<name>_9500.pb" for acc == 0.95.
        last_compile_model_path = (os.path.join(
            self.model_conf.compile_model_path,
            "{}.pb".format(self.model_conf.model_name))).replace(
                '.pb', '_{}.pb'.format(int(acc * 10000)))

        self.model_conf.output_config(target_model_name="{}_{}".format(
            self.model_conf.model_name, int(acc * 10000)))
        with tf.io.gfile.GFile(last_compile_model_path, mode='wb') as gf:
            gf.write(output_graph_def.SerializeToString())
Example #2
0
    def train_process(self):
        """
        Run the training task.

        Builds the network, loads the train/validation TFRecords datasets,
        then iterates over batches, periodically saving checkpoints and
        evaluating accuracy on the validation set, until the configured
        termination conditions (accuracy, epochs and cost) are all met.
        On success the graph is frozen to a .pb model via compile_graph().

        :return: None
        """
        # Print the important configuration parameters.
        self.model_conf.println()
        # Define the network structure.
        model = core.NeuralNetwork(model_conf=self.model_conf,
                                   mode=RunMode.Trains,
                                   cnn=self.model_conf.neu_cnn,
                                   recurrent=self.model_conf.neu_recurrent)
        model.build_graph()

        tf.compat.v1.logging.info('Loading Trains DataSet...')
        train_feeder = utils.data.DataIterator(model_conf=self.model_conf,
                                               mode=RunMode.Trains)
        train_feeder.read_sample_from_tfrecords(
            self.model_conf.trains_path[DatasetType.TFRecords])

        tf.compat.v1.logging.info('Loading Test DataSet...')
        validation_feeder = utils.data.DataIterator(model_conf=self.model_conf,
                                                    mode=RunMode.Validation)
        validation_feeder.read_sample_from_tfrecords(
            self.model_conf.validation_path[DatasetType.TFRecords])

        # Use tf.compat.v1.logging consistently (plain tf.logging was
        # removed in TF2 and is inconsistent with the rest of this file).
        tf.compat.v1.logging.info(
            'Total {} Trains DataSets'.format(train_feeder.size))
        tf.compat.v1.logging.info(
            'Total {} Test DataSets'.format(validation_feeder.size))
        if validation_feeder.size >= train_feeder.size:
            exception(
                "The number of training sets cannot be less than the test set.",
            )

        num_train_samples = train_feeder.size
        num_test_samples = validation_feeder.size
        if num_test_samples < self.model_conf.validation_batch_size:
            exception(
                "The number of test sets cannot be less than the test batch size.",
                ConfigException.INSUFFICIENT_SAMPLE)
        num_batches_per_epoch = int(num_train_samples /
                                    self.model_conf.batch_size)
        # Session configuration.
        sess_config = tf.compat.v1.ConfigProto(
            # allow_soft_placement=True,
            log_device_placement=False,
            gpu_options=tf.compat.v1.GPUOptions(
                allocator_type='BFC',
                allow_growth=True,  # it will cause fragmentation.
                per_process_gpu_memory_fraction=self.model_conf.memory_usage))
        accuracy = 0
        epoch_count = 1

        def _should_terminate(acc_value, cost_value):
            # Training stops only when accuracy, epoch count AND cost
            # targets are all reached, or after a hard epoch cap.
            achieve_accuracy = acc_value >= self.model_conf.trains_end_acc
            achieve_epochs = epoch_count >= self.model_conf.trains_end_epochs
            achieve_cost = cost_value <= self.model_conf.trains_end_cost
            over_epochs = epoch_count > 10000
            return (achieve_accuracy and achieve_epochs
                    and achieve_cost) or over_epochs

        with tf.compat.v1.Session(config=sess_config) as sess:
            tf.compat.v1.keras.backend.set_session(session=sess)
            init_op = tf.compat.v1.global_variables_initializer()
            sess.run(init_op)
            saver = tf.compat.v1.train.Saver(
                var_list=tf.compat.v1.global_variables(), max_to_keep=2)
            train_writer = tf.compat.v1.summary.FileWriter('logs', sess.graph)
            checkpoint_state = tf.train.get_checkpoint_state(
                self.model_conf.model_root_path)
            if checkpoint_state and checkpoint_state.model_checkpoint_path:
                # Resume an interrupted training task.
                saver.restore(sess, checkpoint_state.model_checkpoint_path)

            tf.compat.v1.logging.info('Start training...')

            # Main training task loop.
            while True:

                start_time = time.time()
                # Initialize so the end-of-epoch cost check below can never
                # raise UnboundLocalError when no batch ran this epoch
                # (stop_flag set early, or num_batches_per_epoch == 0).
                batch_cost = 65535

                # Per-batch loop.
                for cur_batch in range(num_batches_per_epoch):

                    if self.stop_flag:
                        break

                    batch_time = time.time()

                    trains_batch = train_feeder.generate_batch_by_tfrecords(
                        sess)

                    batch_inputs, batch_labels = trains_batch

                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                    }

                    summary_str, batch_cost, step, _, seq_len = sess.run(
                        [
                            model.merged_summary, model.cost,
                            model.global_step, model.train_op, model.seq_len
                        ],
                        feed_dict=feed)
                    train_writer.add_summary(summary_str, step)

                    if step % 100 == 0 and step != 0:
                        tf.compat.v1.logging.info(
                            'Step: {} Time: {:.3f} sec/batch, Cost = {:.8f}, BatchSize: {}, Shape[1]: {}'
                            .format(step,
                                    time.time() - batch_time, batch_cost,
                                    len(batch_inputs), seq_len[0]))

                    # Save a checkpoint every trains_save_steps steps.
                    if step % self.model_conf.trains_save_steps == 0 and step != 0:
                        saver.save(sess,
                                   self.model_conf.save_model,
                                   global_step=step)

                    # Periodically evaluate on the validation set.
                    if step % self.model_conf.trains_validation_steps == 0 and step != 0:

                        batch_time = time.time()
                        validation_batch = validation_feeder.generate_batch_by_tfrecords(
                            sess)

                        test_inputs, test_labels = validation_batch
                        val_feed = {
                            model.inputs: test_inputs,
                            model.labels: test_labels
                        }
                        dense_decoded, lr = sess.run(
                            [model.dense_decoded, model.lrn_rate],
                            feed_dict=val_feed)
                        # Compute the validation accuracy.
                        accuracy = self.validation.accuracy_calculation(
                            validation_feeder.labels,
                            dense_decoded,
                        )
                        log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \
                              "Time = {:.3f} sec/batch, LearningRate: {}"
                        tf.compat.v1.logging.info(
                            log.format(
                                epoch_count,
                                step,
                                accuracy,
                                batch_cost,
                                time.time() - batch_time,
                                lr / len(validation_batch),
                            ))

                        # Termination conditions met mid-epoch: leave the
                        # batch loop early.
                        if _should_terminate(accuracy, batch_cost):
                            break

                # Termination conditions met: leave the task loop.
                if self.stop_flag:
                    break
                if _should_terminate(accuracy, batch_cost):
                    self.compile_graph(accuracy)
                    tf.compat.v1.logging.info('Total Time: {} sec.'.format(
                        time.time() - start_time))
                    break
                epoch_count += 1
Example #3
0
    def train_process(self):
        """
        Run the training task.

        Builds the network, optionally initializes the random captcha
        generator for data augmentation, loads the train/validation
        TFRecords datasets, then iterates over batches — saving
        checkpoints, logging progress and evaluating validation accuracy —
        until achieve_cond() reports the termination conditions are met,
        at which point the graph is frozen to a .pb via compile_graph().

        :return: None
        """
        # Print the important configuration parameters.
        self.model_conf.println()
        # Define the network structure.
        model = core.NeuralNetwork(mode=RunMode.Trains,
                                   model_conf=self.model_conf,
                                   backbone=self.model_conf.neu_cnn,
                                   recurrent=self.model_conf.neu_recurrent)
        model.build_graph()

        ran_captcha = RandomCaptcha()

        if self.model_conf.da_random_captcha['Enable']:
            self.init_captcha_gennerator(ran_captcha=ran_captcha)

        tf.compat.v1.logging.info('Loading Trains DataSet...')
        train_feeder = utils.data.DataIterator(model_conf=self.model_conf,
                                               mode=RunMode.Trains,
                                               ran_captcha=ran_captcha)
        train_feeder.read_sample_from_tfrecords(
            self.model_conf.trains_path[DatasetType.TFRecords])

        tf.compat.v1.logging.info('Loading Validation DataSet...')
        validation_feeder = utils.data.DataIterator(model_conf=self.model_conf,
                                                    mode=RunMode.Validation,
                                                    ran_captcha=ran_captcha)
        validation_feeder.read_sample_from_tfrecords(
            self.model_conf.validation_path[DatasetType.TFRecords])

        tf.compat.v1.logging.info('Total {} Trains DataSets'.format(
            train_feeder.size))
        tf.compat.v1.logging.info('Total {} Validation DataSets'.format(
            validation_feeder.size))
        if validation_feeder.size >= train_feeder.size:
            exception(
                "The number of training sets cannot be less than the validation set.",
            )
        if validation_feeder.size < self.model_conf.validation_batch_size:
            exception(
                "The number of validation sets cannot be less than the validation batch size.",
            )

        num_train_samples = train_feeder.size
        num_validation_samples = validation_feeder.size

        if num_validation_samples < self.model_conf.validation_batch_size:
            # NOTE(review): this branch looks unreachable if exception()
            # above always raises for the same condition — confirm.
            self.model_conf.validation_batch_size = num_validation_samples
            # Fixed: the message previously had no '{}' placeholder, so the
            # format() argument was silently dropped from the log output.
            tf.compat.v1.logging.warn(
                'The number of validation sets ({}) is less than the '
                'validation batch size, will use validation set size as '
                'validation batch size.'.format(validation_feeder.size))

        num_batches_per_epoch = int(num_train_samples /
                                    self.model_conf.batch_size)

        model.build_train_op(num_train_samples)

        accuracy = 0
        epoch_count = 1

        # Small datasets get tighter save/validation intervals so progress
        # is visible early.
        if num_train_samples < 500:
            save_step = 10
            trains_validation_steps = 50
        else:
            save_step = 100
            trains_validation_steps = self.model_conf.trains_validation_steps

        sess = tf.compat.v1.Session()

        init_op = tf.compat.v1.global_variables_initializer()
        sess.run(init_op)
        saver = tf.compat.v1.train.Saver(
            var_list=tf.compat.v1.global_variables(), max_to_keep=2)
        train_writer = tf.compat.v1.summary.FileWriter('logs', sess.graph)
        checkpoint_state = tf.train.get_checkpoint_state(
            self.model_conf.model_root_path)
        if checkpoint_state and checkpoint_state.model_checkpoint_path:
            # Resume an interrupted training task.
            saver.restore(sess, checkpoint_state.model_checkpoint_path)

        tf.compat.v1.logging.info('Start training...')

        # Main training task loop.
        while True:

            start_time = time.time()
            # Sentinel so the cost-based termination check cannot see an
            # unbound value when no batch ran this epoch.
            batch_cost = 65535
            # Per-batch loop.
            for cur_batch in range(num_batches_per_epoch):

                if self.stop_flag:
                    break

                batch_time = time.time()

                trains_batch = train_feeder.generate_batch_by_tfrecords(sess)

                batch_inputs, batch_labels = trains_batch

                feed = {
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                    model.utils.is_training: True
                }

                summary_str, batch_cost, step, _, seq_len = sess.run(
                    [
                        model.merged_summary, model.cost, model.global_step,
                        model.train_op, model.seq_len
                    ],
                    feed_dict=feed)
                train_writer.add_summary(summary_str, step)

                if step % save_step == 0 and step != 0:
                    tf.compat.v1.logging.info(
                        'Step: {} Time: {:.3f} sec/batch, Cost = {:.8f}, BatchSize: {}, Shape[1]: {}'
                        .format(step,
                                time.time() - batch_time, batch_cost,
                                len(batch_inputs), seq_len[0]))

                # Save a checkpoint every save_step steps.
                if step % save_step == 0 and step != 0:
                    saver.save(sess,
                               self.model_conf.save_model,
                               global_step=step)

                # Periodically evaluate on the validation set.
                if step % trains_validation_steps == 0 and step != 0:

                    batch_time = time.time()
                    validation_batch = validation_feeder.generate_batch_by_tfrecords(
                        sess)

                    test_inputs, test_labels = validation_batch
                    val_feed = {
                        model.inputs: test_inputs,
                        model.labels: test_labels,
                        model.utils.is_training: False
                    }
                    dense_decoded, lr = sess.run(
                        [model.dense_decoded, model.lrn_rate],
                        feed_dict=val_feed)
                    # Compute the validation accuracy.
                    accuracy = self.validation.accuracy_calculation(
                        validation_feeder.labels,
                        dense_decoded,
                    )
                    log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \
                          "Time = {:.3f} sec/batch, LearningRate: {}"
                    tf.compat.v1.logging.info(
                        log.format(
                            epoch_count,
                            step,
                            accuracy,
                            batch_cost,
                            time.time() - batch_time,
                            lr / len(validation_batch),
                        ))

                    # Termination conditions met mid-epoch: leave the
                    # batch loop early.
                    if self.achieve_cond(acc=accuracy,
                                         cost=batch_cost,
                                         epoch=epoch_count):
                        break

            # Termination conditions met: leave the task loop.
            if self.stop_flag:
                break
            if self.achieve_cond(acc=accuracy,
                                 cost=batch_cost,
                                 epoch=epoch_count):
                # Release the training session before compiling the
                # predict graph, which builds its own session.
                tf.compat.v1.keras.backend.clear_session()
                sess.close()
                self.compile_graph(accuracy)
                # The total time is logged once after the loop; the
                # previous duplicate log here was removed.
                break
            epoch_count += 1
        tf.compat.v1.logging.info('Total Time: {} sec.'.format(time.time() -
                                                               start_time))