Beispiel #1
0
    def __init__(self, train_config):
        """Set up a trainer from *train_config*.

        Reads ``'learning_rate'`` and ``'train_iteration'`` from the config
        mapping, initialises the training logger, builds the AFD_RNN graph
        from ``./config/rnn_net.cfg`` and creates the label placeholder.
        """
        # Attribute spellings (learing_rate / train_iterior) are kept as-is
        # because the training code reads them under these names.
        self.learing_rate = float(train_config['learning_rate'])
        self.train_iterior = int(train_config['train_iteration'])
        self._train_logger_init()

        self.rnn_net = AFD_RNN(parser_cfg_file('./config/rnn_net.cfg'))
        self.predict = self.rnn_net.build_net_graph()

        # Labels are fed as float32 tensors of shape [batch, time_step, class_num].
        label_shape = [None, self.rnn_net.time_step, self.rnn_net.class_num]
        self.label = tf.placeholder(tf.float32, label_shape)
    def __init__(self, mode_dir, time_step=5, batch_size=1):
        """Build the inference graph and restore the checkpoint found in *mode_dir*.

        :param mode_dir: directory searched with ``tf.train.get_checkpoint_state``.
        :param time_step: sequence length the network is built with.
        :param batch_size: batch size the network is built with.
        :raises FileExistsError: when no checkpoint state is found in *mode_dir*
            (exception type kept for backward compatibility).
        """
        self.batch_size = batch_size
        self.time_step = time_step

        ckpt = tf.train.get_checkpoint_state(mode_dir)
        if ckpt is None:
            # Build the message explicitly: str(x, y) would treat y as an
            # encoding name and fail before the intended error is raised.
            raise FileExistsError('{} 没有模型可以加载'.format(mode_dir))

        net_config = parser_cfg_file('./config/rnn_net.cfg')
        self.rnn_net = AFD_RNN(net_config, batch_size, time_step)
        predict = self.rnn_net.build_net_graph()
        self._predict_tensor = tf.argmax(predict, axis=2)
        saver = tf.train.Saver()
        self._sess = tf.Session()
        # Load the trained parameters; don't leak the session if that fails.
        try:
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        except Exception:
            self._sess.close()
            raise
Beispiel #3
0
class AFD_RNN_Train(object):
    """Train an ``AFD_RNN`` network and periodically checkpoint it.

    Hyper-parameters come from a config mapping; the network itself is
    configured from ``./config/rnn_net.cfg``.
    """

    def __init__(self, train_config):
        """Read hyper-parameters, set up logging and build the training graph.

        :param train_config: mapping providing ``'learning_rate'`` and
            ``'train_iteration'`` (coerced with ``float`` / ``int``).
        """
        # Attribute spellings (learing_rate / train_iterior) are kept because
        # they are part of the instance's public state.
        self.learing_rate = float(train_config['learning_rate'])
        self.train_iterior = int(train_config['train_iteration'])
        self._train_logger_init()

        net_config = parser_cfg_file('./config/rnn_net.cfg')
        self.rnn_net = AFD_RNN(net_config)
        self.predict = self.rnn_net.build_net_graph()
        # Labels: float32 [batch, time_step, class_num] (presumably one-hot).
        self.label = tf.placeholder(
            tf.float32, [None, self.rnn_net.time_step, self.rnn_net.class_num])

    def _compute_loss(self):
        """Build the mean softmax cross-entropy loss and its Adam train op.

        :return: ``(loss, train_op)`` tensors.
        """
        with tf.name_scope('loss'):
            # Split both tensors into time_step slices of [batch, class_num].
            # The label is batch-major ([batch, time_step, class_num]) and is
            # unstacked on axis 1; self.predict is unstacked on axis 0, i.e. it
            # is treated as time-major ([time_step, batch, class_num]), which
            # matches the transpose in train_rnn -- TODO confirm in AFD_RNN.
            predict = tf.unstack(self.predict, axis=0)
            label = tf.unstack(self.label, axis=1)

            loss = [
                tf.nn.softmax_cross_entropy_with_logits(labels=label[i],
                                                        logits=predict[i])
                for i in range(self.rnn_net.time_step)
            ]
            loss = tf.reduce_mean(loss)
            train_op = tf.train.AdamOptimizer(self.learing_rate).minimize(loss)
        return loss, train_op

    def train_rnn(self):
        """Run the training loop, logging progress and saving checkpoints.

        Loss and accuracy are logged every 10 steps. The model is saved to
        ``./model/model.ckpt`` every 1000 steps, and once more after the last
        step if that step was not already saved.
        """
        import os

        loss, train_op = self._compute_loss()

        with tf.name_scope('accuracy'):
            # Transpose predict to the label's [batch, time_step, class_num]
            # layout before comparing argmax class indices.
            predict = tf.transpose(self.predict, [1, 0, 2])
            correct_pred = tf.equal(tf.argmax(self.label, 2),
                                    tf.argmax(predict, axis=2))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        dataset = DataLoad('./dataset/train/',
                           time_step=self.rnn_net.time_step,
                           class_num=self.rnn_net.class_num)
        saver = tf.train.Saver()
        save_file = './model/model.ckpt'
        # TF1's Saver.save raises when the parent directory is missing.
        os.makedirs(os.path.dirname(save_file), exist_ok=True)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            last_saved_step = 0
            state = None
            for step in range(1, self.train_iterior + 1):
                x, y = dataset.get_batch(self.rnn_net.batch_size)
                feed_dict = {self.rnn_net.input_tensor: x, self.label: y}
                if step > 1:
                    # Carry the RNN cell state over from the previous batch.
                    feed_dict[self.rnn_net.cell_state] = state
                _, compute_loss, state = sess.run(
                    [train_op, loss, self.rnn_net.cell_state],
                    feed_dict=feed_dict)

                if step % 10 == 0:
                    compute_accuracy = sess.run(accuracy, feed_dict=feed_dict)
                    self.train_logger.info(
                        'train step = %d,loss = %f,accuracy = %f',
                        step, compute_loss, compute_accuracy)
                if step % 1000 == 0:
                    self._save_model(saver, sess, save_file, step)
                    last_saved_step = step

            # Persist the steps after the last periodic save (or the whole
            # run when it was shorter than 1000 iterations).
            if self.train_iterior > last_saved_step:
                self._save_model(saver, sess, save_file, self.train_iterior)

    def _save_model(self, saver, sess, save_file, step):
        """Save a checkpoint of *sess* to *save_file* and log where it went."""
        save_path = saver.save(sess, save_file)
        self.train_logger.info("train step = %d ,model save to =%s",
                               step, save_path)

    def _train_logger_init(self):
        """Initialise the 'train' logger with file and console output.

        The file goes to ``./train_logs/<YYYYmmddHHMM>.logs``; the directory
        is created if needed. Handlers installed by an earlier call are closed
        and removed first.
        """
        import os

        self.train_logger = logging.getLogger('train')
        self.train_logger.setLevel(logging.DEBUG)

        # getLogger returns the same object for the same name, so without this
        # every new trainer would stack another pair of handlers (duplicate
        # lines and leaked file descriptors). Only our own handlers are removed.
        for handler in list(self.train_logger.handlers):
            if getattr(handler, '_afd_rnn_train_handler', False):
                self.train_logger.removeHandler(handler)
                handler.close()

        # File output
        log_dir = './train_logs'
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(
            log_dir,
            time.strftime('%Y%m%d%H%M', time.localtime(time.time())) + '.logs')
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        file_handler._afd_rnn_train_handler = True
        self.train_logger.addHandler(file_handler)

        # Console output
        consol_handler = logging.StreamHandler()
        consol_handler.setLevel(logging.DEBUG)
        consol_formatter = logging.Formatter('%(message)s')
        consol_handler.setFormatter(consol_formatter)
        consol_handler._afd_rnn_train_handler = True
        self.train_logger.addHandler(consol_handler)
Beispiel #4
0
class Run_AFD_RNN(object):
    """Restore a trained ``AFD_RNN`` checkpoint and run inference with it."""

    def __init__(self, mode_dir, time_step=5, batch_size=1):
        """Build the inference graph and restore the checkpoint found in *mode_dir*.

        :param mode_dir: directory searched with ``tf.train.get_checkpoint_state``.
        :param time_step: sequence length the network is built with.
        :param batch_size: batch size the network is built with.
        :raises FileExistsError: when no checkpoint state is found in *mode_dir*
            (exception type kept for backward compatibility).
        """
        self.batch_size = batch_size
        self.time_step = time_step

        ckpt = tf.train.get_checkpoint_state(mode_dir)
        if ckpt is None:
            # Build the message explicitly: str(x, y) would treat y as an
            # encoding name and fail before the intended error is raised.
            raise FileExistsError('{} 没有模型可以加载'.format(mode_dir))

        net_config = parser_cfg_file('./config/rnn_net.cfg')

        self.rnn_net = AFD_RNN(net_config, batch_size, time_step)
        predict = self.rnn_net.build_net_graph()
        self._predict_tensor = tf.argmax(predict, axis=2)
        saver = tf.train.Saver()
        self._sess = tf.Session()
        # Load the trained parameters; don't leak the session if that fails.
        try:
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        except Exception:
            self._sess.close()
            raise

    def run(self, data):
        """Predict class indices for *data*.

        :param data: array-like reshaped to
            ``[batch_size, time_step, rnn_net.senor_data_num]``.
        :return: argmax over axis 2 of the network output, as evaluated by the
            session (presumably ``[time_step, batch_size]`` -- TODO confirm).
        """
        data = np.reshape(
            data,
            [self.batch_size, self.time_step, self.rnn_net.senor_data_num])
        return self._sess.run(self._predict_tensor,
                              feed_dict={self.rnn_net.input_tensor: data})

    def run_stop(self):
        """Close the TensorFlow session; ``run`` cannot be used afterwards."""
        self._sess.close()

    def _update_show_data(self, data, step, update_data):
        """Scroll list *data* left by *step* items, appending ``update_data[:step]``."""
        for i in range(step):
            data.pop(0)
            data.append(update_data[i])

    def draw_flow(self, test_data, test_label):
        """Replay *test_data* chunk by chunk, classifying it as it streams.

        Rows are consumed ``run_step`` (10) at a time. Once at least
        ``self.time_step`` rows precede the current chunk, that window is
        classified and the prediction for its last step is compared with
        ``test_label``. Prints the chunk count, per-window label names, the
        accuracy over classified windows and the elapsed time. The rolling
        ``ax``/``ay``/``az`` buffers feed the (currently disabled) plot.

        :param test_data: 2-D array; columns 0-2 are buffered for plotting.
        :param test_label: per-row class indices into the module-level ``Label``.
        """
        data_size = test_data.shape[0]

        x = list(range(150))
        ax = [0] * 150
        ay = [0] * 150
        az = [0] * 150

        correct = 0
        evaluated = 0
        run_step = 10
        num = data_size // run_step
        print(num)

        start_time = time.time()

        #plt.axis([0, 151, -20, 20])
        #plt.ion()
        for i in range(num):
            end = i * run_step

            # A full window of self.time_step rows must precede `end`.
            if end >= self.time_step:
                predict = self.run(test_data[end - self.time_step:end, :])
                # NOTE(review): the window's last row is end - 1, but the label
                # compared is test_label[end] -- confirm this offset is intended.
                true_name = Label[test_label[end]]
                pred_name = Label[predict[self.time_step - 1][0]]
                title = 'correct:' + true_name + '     predict:' + pred_name
                print(true_name)
                print(pred_name)
                if true_name == pred_name:
                    correct += 1
                evaluated += 1
            else:
                title = 'correct:' + Label[test_label[end]] + '     predict:' + 'unknow'

            self._update_show_data(ax, run_step, test_data[end:end + run_step, 0])
            self._update_show_data(ay, run_step, test_data[end:end + run_step, 1])
            self._update_show_data(az, run_step, test_data[end:end + run_step, 2])

            # plt.cla()
            # plt.plot(x, ax)
            # plt.plot(x, ay)
            # plt.plot(x, az)
            #
            # plt.title(title)
            # plt.draw()
            # plt.pause(0.001)

        during = str(time.time() - start_time)
        # Accuracy over the windows actually classified; NaN (not a crash)
        # when the data was too short to classify any window.
        print(correct / evaluated if evaluated else float('nan'))
        print('检测耗时=', during)