class Train:
    def __init__(self):
        # Create the network and prepare data loading
        if cfg.FLAGS.net == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("DIY_dataset")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):

        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.90
        sess = tf.Session(config=tfconfig)
        # Create and configure the session:
        # allow_soft_placement falls back to CPU for ops without a GPU kernel,
        # allow_growth lets GPU memory be allocated incrementally on demand

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)  # fix the random seed
            layers = self.net.create_architecture(sess,
                                                  "TRAIN",
                                                  self.imdb.num_classes,
                                                  tag='default')
            # Build the network; create_architecture is defined in network.py
            loss = layers['total_loss']

            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            # Hold the learning rate in a non-trainable tf.Variable so it can be reassigned later
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)
            # Use the momentum (SGD) optimizer

            gvs = optimizer.compute_gradients(loss)  # compute gradients of the loss

            # Double bias:
            # if cfg.FLAGS.double_bias is set, double the gradient of the bias terms
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            writer = tf.summary.FileWriter('default/', sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights
        # Fresh train directly from ImageNet weights:
        # the pretrained VGG16 model at cfg.FLAGS.pretrained_model
        print('Loading initial model weights from {:s}'.format(
            cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Get all variables so they can be initialized

        # print('###################')
        # print(variables)
        # print('###################')

        sess.run(tf.variables_initializer(variables, name='init'))
        # Initialize the variables
        var_keep_dic = self.get_variables_in_checkpoint_file(
            cfg.FLAGS.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic, sess, cfg.FLAGS.pretrained_model)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before training, so that the RGB weights are changed to BGR.
        # For VGG16 it also changes the convolutional weights of fc6 and fc7 to
        # fully connected weights.
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        # Assign the configured initial learning rate to lr
        last_snapshot_iter = 0

        timer = Timer()
        # Timer to measure training time
        iter = last_snapshot_iter + 1
        # last_summary_time = time.time()
        print('****************start training*****************')
        while iter < cfg.FLAGS.max_iters + 1:
            # learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add a snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()
            iter += 1
            # Run a training step; every 100 iterations also fetch a summary
            if iter % 100 == 0:
                # write a summary every 100 iterations
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = self.net.train_step_with_summary(
                    sess, blobs, train_op)
                timer.toc()

                run_metadata = tf.RunMetadata()
                writer.add_run_metadata(run_metadata, 'step%03d' % iter)
                writer.add_summary(summary, iter)
            else:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                    sess, blobs, train_op)
                timer.toc()

            # Print the training losses every cfg.FLAGS.display iterations
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.FLAGS.snapshot_iterations == 0:
                # Save a snapshot every cfg.FLAGS.snapshot_iterations iterations
                self.snapshot(sess, iter)

    def get_variables_in_checkpoint_file(self, file_name):
        # Read the pretrained model checkpoint
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:
            # If the checkpoint file is compressed, report a helpful error
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            # Make sure the output directory exists
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the data layer
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
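
snapshot() above has a natural counterpart that this listing never shows. Below is a minimal sketch of the restore path, assuming the same Saver and the RoIDataLayer attributes (_cur, _perm) used in this class; from_snapshot is a hypothetical name.

import pickle
import numpy as np

def from_snapshot(self, sess, sfile, nfile):
    # Restore the weights written by self.saver.save(...)
    print('Restoring model snapshot from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    # Read the meta info back in the same order snapshot() dumped it
    with open(nfile, 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)
    np.random.set_state(st0)      # numpy RNG state
    self.data_layer._cur = cur    # position in the database
    self.data_layer._perm = perm  # shuffled indices
    return last_snapshot_iter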
Example 2
class Train:
    def __init__(self, canting, data_riqi):

        self.canting = canting
        self.data_riqi = data_riqi

        # Create network
        if cfg.FLAGS.network == 'ghostnet':
            self.net = GhostNet(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("voc_2007_trainval", canting,
                                               data_riqi)

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(
            self.imdb, 'default_{}_{}'.format(self.canting, self.data_riqi))

    def train(self):

        # Create session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        losser = []

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            with TowerContext('', is_training=False):
                layers = self.net.create_architecture(sess,
                                                      "TRAIN",
                                                      self.imdb.num_classes,
                                                      tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)
            # optimizer = tf.train.AdamOptimizer(lr)

            gvs = optimizer.compute_gradients(loss)

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)

            # Write the train and validation information to tensorboard

            #writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            #valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights
        # Fresh train directly from ImageNet weights
        #print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        #
        pretrained_model = r'F:\GhostNet\ghostnet\models\ghostnet_checkpoint'
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(pretrained_model)
        #var_keep_dic = get_model_loader(pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()

        fig = plt.figure()
        arx = fig.add_subplot(1, 1, 1)

        while iter < cfg.FLAGS.max_iters + 1:
            # Learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                sess, blobs, train_op)
            timer.toc()
            iter += 1

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                losser.append(total_loss)
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

                #arx.cla()
                #arx.plot(losser,'bo-')
                #plt.pause(0.1)

            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)

        #arx.plot(losser,'bo-')

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'mobilenetv1_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'mobilenetv1_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
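
Example 2 collects total_loss into losser at every display interval but leaves the live plotting commented out. A small sketch of rendering the curve after training, assuming matplotlib (imported there as plt) and the cfg.FLAGS.display interval used above; plot_losses and the output filename are illustrative.

import matplotlib.pyplot as plt

def plot_losses(losser, display_interval):
    # x axis: the iteration at which each loss sample was taken
    iters = [i * display_interval for i in range(1, len(losser) + 1)]
    plt.plot(iters, losser, 'bo-')
    plt.xlabel('iteration')
    plt.ylabel('total loss')
    plt.savefig('training_loss.png')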
Example 3
class Train:
    def __init__(self):

        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")#default = voc_2007_trainval
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')
        #For Tensorboard yuqiyue 20171226
        #self.tbdir = cfg.get_output_dir(self.imdb, "tensorboard")

    def train(self):

        # Create session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        #tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)
        # modify these anchor parameters for your target
        _anchor_scales = (8, 16, 32)#default 8, 16, 32 : The target size
        _anchor_ratios = (0.5, 1, 2)#default 1:2, 1:1, 2:1 : The target ratio

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            # create_architecture(self, sess, mode, num_classes, tag=None, anchor_scales=(8, 16, 32), anchor_ratios=(0.5, 1, 2))
            layers = self.net.create_architecture(sess, "TRAIN", self.imdb.num_classes, tag='default', anchor_scales=_anchor_scales, anchor_ratios=_anchor_ratios)
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss)

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=None)
            # Write the train and validation information to tensorboard
            #self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
        # Load weights
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(cfg.FLAGS.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        # Notice: set this number if you want to continue training
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        record_file = open("Train_Record.csv", "w")
        record_file.write("iter,RPN_LOSS_CLS,RPN_LOSS_BOX,LOSS_CLS,LOSS_BOX,TOTAL_LOSS\n")
        while iter < cfg.FLAGS.max_iters + 1:
            #output iter for debugging
            #print("training : iter = {0}".format(iter))
            # Learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary ==> Fast Training
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(sess, blobs, train_op)
            # End Compute the graph without summary
            # Compute the graph with summary ==> slower training
            #rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = self.net.train_step_with_summary(sess, blobs, train_op)
            #self.writer.add_summary(summary, float(iter))
            # End compute the graph with summary
            # Compute the graph with nothing ==> fastest training
            #self.net.train_step_no_return(sess, blobs, train_op)
            # End compute the graph with nothing
            #write values to csv file
            record_file.write("{0},{1},{2},{3},{4},{5}\n".format(iter, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss))
            timer.toc()
            iter += 1
            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print("{0} iter / {1} :: total_loss = {2}".format(iter, cfg.FLAGS.max_iters, total_loss))
            if iter % (cfg.FLAGS.snapshot_iterations) == 0:
                self.snapshot(sess, iter)
                #output the values                
                #print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                #      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                #      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                #print('speed: {:.3f}s / iter'.format(timer.average_time))
        record_file.close()
        #close summary writer
        #self.writer.close()

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
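
Example 3 appends one CSV row per iteration to Train_Record.csv under the header written above. A minimal sketch for reading it back with the standard library; the column names follow that header, and load_train_record is a hypothetical helper.

import csv

def load_train_record(path='Train_Record.csv'):
    with open(path) as f:
        rows = list(csv.DictReader(f))
    iters = [int(r['iter']) for r in rows]
    total = [float(r['TOTAL_LOSS']) for r in rows]
    return iters, total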
Example 4
class Train:
    def __init__(self):
        # Create the network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)  # 1
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")
        # Data layer and output directory
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):
        # Create the session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():
            tf.set_random_seed(cfg.FLAGS.rng_seed)  # rng_seed=3
            # Build the network
            layers = self.net.create_architecture("TRAIN",
                                                  self.imdb.num_classes,
                                                  tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate,
                             trainable=False)  # learning_rate=0.001
            momentum = cfg.FLAGS.momentum  # momentum=0.9
            # Create the optimizer
            optimizer = tf.train.MomentumOptimizer(lr, momentum)
            # Compute the gradients
            gvs = optimizer.compute_gradients(loss)

            # Whether to double the bias gradients
            if cfg.FLAGS.double_bias:  # True
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            self.saver = tf.train.Saver(max_to_keep=100000)

            print('Loading initial model weights from {:s}'.format(
                cfg.FLAGS.pretrained_model))
            # Initialize all variables first
            variables = tf.global_variables()
            sess.run(tf.variables_initializer(variables, name='init'))
            # Load the downloaded pretrained parameters
            var_keep_dic = self.get_variables_in_checkpoint_file(
                cfg.FLAGS.pretrained_model)
            # Restore all variables except the ones that need fixing
            variables_to_restore = self.net.get_variables_to_restore(
                variables, var_keep_dic)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, cfg.FLAGS.pretrained_model)
            print('Loaded.')
            # Fix fc6, fc7 and conv1
            self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
            print('Fixed.')
            sess.run(tf.assign(lr,
                               cfg.FLAGS.learning_rate))  # learning_rate=0.001
            last_snapshot_iter = 0

            timer = Timer()
            iter = last_snapshot_iter + 1
            last_summary_time = time.time()
            while iter < cfg.FLAGS.max_iters + 1:  # 40000
                # As described in the paper, lower the learning rate after 30000 steps
                if iter == cfg.FLAGS.step_size + 1:  # 30000
                    sess.run(
                        tf.assign(lr, cfg.FLAGS.learning_rate *
                                  cfg.FLAGS.gamma))  # gamma=0.1

                timer.tic()
                # Get the data, gt_boxes and im_info blobs for this mini-batch
                blobs = self.data_layer.forward()

                # Run one training step and compute the losses
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                    sess, blobs, train_op)
                timer.toc()
                iter += 1

                # Print the losses every 10 iterations
                if iter % (cfg.FLAGS.display) == 0:  # 10
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                          (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                    print('speed: {:.3f}s / iter'.format(timer.average_time))
                # Save a snapshot every 5000 iterations
                if iter % cfg.FLAGS.snapshot_iterations == 0:  # 5000
                    self.snapshot(sess, iter)

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
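
A usage sketch for get_variables_in_checkpoint_file(): the returned dict maps variable names to shapes, which is handy for checking what a pretrained checkpoint actually contains before deciding what to restore. The checkpoint path here is illustrative.

from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader('vgg16.ckpt')  # hypothetical path
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print('{:60s} {}'.format(name, shape))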
Example 5
class Train:
    def __init__(self):

        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError
        '''
        This uses get_imdb(name) from factory.py: the value is __sets[name](),
        where __sets is a dict with __sets[name] = (lambda split=split, year=year: pascal_voc(split, year)),
        i.e. a function that instantiates a pascal_voc(split, year) object.
        This only determines the imdb's name, so it can largely be ignored here.
        '''
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")     
        
        # Shift/scale the raw image and gt_boxes to produce the network input
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        # Path where the model is saved
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')


    def train(self):

        # Create session
        # allow_soft_placement=True automatically moves ops that cannot run on the GPU back to the CPU
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        # let the GPU allocate memory on demand instead of claiming all of it
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(sess, "TRAIN", self.imdb.num_classes, tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)   # 0.001
            momentum = cfg.FLAGS.momentum                                # 0.9
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss) 

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            
            # Directory for the tensorboard log files
            self.tbdir=os.path.join(os.path.dirname(__file__),'log','log_RSDDs_3_40000_30000_0.1')
            if not os.path.exists(self.tbdir):
                os.makedirs(self.tbdir)
            # Write the train and validation information to tensorboard
            writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights

        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        # Branch here: if no ckpt file exists, load the VGG16 weights; otherwise load the previously trained model
        tfmodel = os.path.join(cfg.FLAGS2["root_dir"], 'Module', 'LJModule', 'Faster RCNN', 'vgg16', 'RSDDs_3', 'default')
        iter_add=0
        # Get the CheckpointState proto
        ckpt=tf.train.get_checkpoint_state(tfmodel)
        # ckpt.model_checkpoint_path holds the absolute filename of the latest model file
        if ckpt and ckpt.model_checkpoint_path:
             saver = tf.train.Saver()
             saver.restore(sess, ckpt.model_checkpoint_path)
             tfmodel_name=ckpt.model_checkpoint_path.split('\\')[-1]
             print('Loaded network parameters from {:s}.'.format(tfmodel_name))
             iter_add=int(ckpt.model_checkpoint_path.split('\\')[-1].split("_")[-1].split(".")[0])
             print(iter_add)
        else:
             """
             这里是将VGG16的预训练权重全部加载,原论文中是从conv3_1开始
             """
             # Fresh train directly from ImageNet weights
             print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))

             # Get the weight names and shapes from vgg_16.ckpt, returned as a dict
             var_keep_dic = self.get_variables_in_checkpoint_file(cfg.FLAGS.pretrained_model)

             # Get the variables to restore, ignoring the variables to fix
             variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

             restorer = tf.train.Saver(variables_to_restore)
             restorer.restore(sess, cfg.FLAGS.pretrained_model)  # load the conv1-conv5 parameters
             print('Loaded VGG16 parameters.')

             # Need to fix the variables before loading, so that the RGB weights are changed to BGR
             # For VGG16 it also changes the convolutional weights fc6 and fc7 to
             # fully connected weights
             self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
             print('Fixed.')

        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = iter_add  # when resuming from a trained model, this is the starting iteration count
        
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < cfg.FLAGS.max_iters + 1:
            # Learning rate
            # FLAGS.step_size=60000:Step size for reducing the learning rate, currently only support one step
            if iter == cfg.FLAGS.step_size_1 + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))  # 0.001*0.1
            '''
            if iter == cfg.FLAGS.step_size_2 + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma * cfg.FLAGS.gamma))  # 0.001*0.1*0.1
            '''
            timer.tic()
            # Get training data, one batch at a time
            '''
            Returns a dict blobs with keys:
            data        the input image, a 4-D array whose first dim indexes the images in the minibatch
            gt_boxes    gt_box coordinates mapped onto the rescaled image; the first four columns are coordinates, the fifth is the class
            im_info     the size of the proportionally rescaled image and the scale factor
            '''
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            try:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss= self.net.train_step(sess, blobs, train_op)
                
            except Exception:
                # if an error is encountered, the image is skipped without increasing the iteration count
                print('image invalid, skipping')
                continue

            timer.toc()

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Every 50 iterations, write all logs to file so tensorboard can pick up this run's info
            if iter % 50 == 0:
                summary = self.net.get_summary_2(sess, blobs)
                writer.add_summary(summary, iter)

            # Save the model every cfg.FLAGS.snapshot_iterations iterations
            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)

            iter += 1
            
    # Get the weight names and shapes from vgg_16.ckpt, returned as a dict, e.g.
    # {"global_step": [], "vgg_16/conv1/conv1_1/weights": [3, 3, 3, 64], ...}
    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
                      
    # Save the trained model
    def snapshot(self, sess, iter):
        net = self.net
        # self.output_dir is the model save path; create it if it does not exist
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)  # change the path in config.py
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
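
Example 5 recovers the resume iteration by splitting ckpt.model_checkpoint_path on '\\', which only works with Windows separators. The same idea as a small sketch using os.path and a regex, so it also handles '/' paths; iter_from_checkpoint_path is a hypothetical name.

import os
import re

def iter_from_checkpoint_path(model_checkpoint_path):
    # e.g. '.../vgg16_faster_rcnn_iter_30000.ckpt' -> 30000
    name = os.path.basename(model_checkpoint_path)
    m = re.search(r'iter_(\d+)', name)
    return int(m.group(1)) if m else 0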
Example 6
        print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
              .format(i + 1, num_images, _t['im_detect'].average_time,
                      _t['misc'].average_time))

    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)


if __name__ == '__main__':
    imdb, roidb = combined_roidb("voc_2007_trainval")
    data_layer = RoIDataLayer(roidb, imdb.num_classes)
    output_dir = cfg.get_output_dir(imdb, 'default')

    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    # tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
    tfmodel = os.path.join('default', DATASETS[dataset][0], 'default',
                           NETS[demonet][0])

    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(
            ('{:s} not found.\nDid you download the proper networks from '
             'our server and place them properly?').format(tfmodel))
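
A small counterpart sketch for the dump above: reload detections.pkl from the output directory (all_boxes is the per-class, per-image list of box arrays that evaluate_detections consumed); load_detections is a hypothetical helper.

import os
import pickle

def load_detections(output_dir):
    with open(os.path.join(output_dir, 'detections.pkl'), 'rb') as f:
        return pickle.load(f)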
Example 7
class Train:
    def __init__(self):

        # Create network
        if cfg.FLAGS.net == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError
        ######################
        self.imdb, self.roidb = combined_roidb("DIY_dataset")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):

        # Create session
        tfconfig = tf.ConfigProto(
            allow_soft_placement=True
        )  # allow_soft_placement=True: fall back to CPU for ops without a GPU kernel
        tfconfig.gpu_options.allow_growth = True
        # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.90
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(sess,
                                                  "TRAIN",
                                                  self.imdb.num_classes,
                                                  tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss)

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            writer = tf.summary.FileWriter('default/', sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        print('cfg.FLAGS.pretrained_model', cfg.FLAGS.pretrained_model)
        var_keep_dic = self.get_variables_in_checkpoint_file(
            cfg.FLAGS.pretrained_model)
        print('var to keep: ', var_keep_dic)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic, sess, cfg.FLAGS.pretrained_model)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        print('START TRAINING: ...')
        while iter < cfg.FLAGS.max_iters + 1:
            print('iteration number: ', iter)
            # Learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()
            iter += 1
            # Run a training step; every other iteration also write scalar summaries
            if iter % 2 == 0:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = self.net.train_step_with_summary(
                    sess, blobs, train_op)
                timer.toc()

                # run_metadata = tf.RunMetadata()
                # writer.add_run_metadata(run_metadata, 'step%03d' % iter)
                # writer.add_summary(summary, iter)
                tf_summary = tf.Summary()
                tf_summary.value.add(tag='total_loss',
                                     simple_value=float(total_loss))
                tf_summary.value.add(tag='rpn_loss_cls',
                                     simple_value=float(rpn_loss_cls))
                tf_summary.value.add(tag='rpn_loss_box',
                                     simple_value=float(rpn_loss_box))
                tf_summary.value.add(tag='loss_cls',
                                     simple_value=float(loss_cls))
                tf_summary.value.add(tag='loss_box',
                                     simple_value=float(loss_box))
                writer.add_summary(tf_summary, iter)
                writer.flush()
            else:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                    sess, blobs, train_op)
                timer.toc()

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            print('reader: ', reader)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
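
The summary branch above fills in a tf.Summary proto field by field. The same pattern factored into a small helper for logging arbitrary named scalars through a FileWriter, a sketch using only the TF1 API already used above; write_scalars is a hypothetical name.

import tensorflow as tf

def write_scalars(writer, step, **scalars):
    summary = tf.Summary()
    for tag, value in scalars.items():
        summary.value.add(tag=tag, simple_value=float(value))
    writer.add_summary(summary, step)
    writer.flush()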
class Train:
    # Initialization: build the pascal_voc and imdb classes to load the data
    def __init__(self):

        # Create network
        if cfg.FLAGS.network == 'vgg16':
            # Create the VGG16 network
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        # Load the data
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        # Where the model is saved
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):
        '''Main entry point: build the network and run iterative training.'''

        # Create session
        # Configure how the session runs (CPU or GPU)
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        # tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            # Build the network architecture
            layers = self.net.create_architecture(sess,
                                                  "TRAIN",
                                                  self.imdb.num_classes,
                                                  tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients
            gvs = optimizer.compute_gradients(loss)

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            # writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            cfg.FLAGS.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < cfg.FLAGS.max_iters + 1:
            # Learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            try:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                    sess, blobs, train_op)
            except Exception:
                # if an error is encountered, the image is skipped without increasing the iteration count
                print('image invalid, skipping')
                continue

            timer.toc()
            iter += 1

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            # reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            reader = NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def snapshot(self, sess, iter):
        '''
        Save a training snapshot.
        Args:
            sess: the session
            iter: the current iteration
        Returns: filename, nfilename
        '''
        net = self.net

        # Create the output directory if it does not exist
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        # Join the two strings as a path
        filename = os.path.join(self.output_dir, filename)
        # Save
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        # so the same random numbers can be reproduced via np.random.set_state()
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
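
Every variant above implements a one-step learning-rate schedule by reassigning lr at step_size + 1 and otherwise leaving it alone. For reference, the same schedule written as a plain function; the flag names mirror cfg.FLAGS and scheduled_lr is a hypothetical name.

def scheduled_lr(iteration, base_lr, step_size, gamma):
    # base_lr until step_size, then base_lr * gamma afterwards
    return base_lr if iteration <= step_size else base_lr * gamma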
class Train:
    def __init__(self, dataset):

        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        elif cfg.FLAGS.network == 'RESNET_v1_50':
            self.net = resnetv1(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        #The database
        #self.imdb, self.roidb = combined_roidb("voc_2007_trainval+test+Isabel")

        self.imdb, self.roidb = combined_roidb(dataset)
        #self.imdb, self.roidb = combined_roidb("Isabel")

        print(self.imdb.name)
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

        print(self.output_dir)

    def train(self):

        # Create session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(sess,
                                                  "TRAIN",
                                                  self.imdb.num_classes,
                                                  tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss)

            # Double bias
            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            # writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights
        # Fresh train directly from ImageNet weights
        if cfg.FLAGS.network == 'vgg16':
            pretrained_model = cfg.FLAGS.pretrained_model_vgg
        elif cfg.FLAGS.network == 'RESNET_v1_50':
            pretrained_model = cfg.FLAGS.pretrained_model_resnet_50

        print(
            'Loading initial model weights from {:s}'.format(pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, pretrained_model)
        print('Fixed.')
        for op in tf.get_default_graph().get_operations():
            print(str(op.name))
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < cfg.FLAGS.max_iters + 1:
            # Learning rate
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(
                sess, blobs, train_op)
            timer.toc()
            iter += 1

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                   '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                   (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            for itert in cfg.FLAGS2["iterations_to_save"]:
                if iter == itert:
                    self.snapshot(sess, iter)

        #in_image = tf.get_default_graph().get_tensor_by_name('vgg_16/conv1/conv1_1/biases:0')
        #inputs = {'image_bytes': tf.saved_model.utils.build_tensor_info(in_image)}

        #out_classes = tf.get_default_graph().get_tensor_by_name('vgg_16/fc7/biases:0')
        #outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
        # Use a saver_def to get the "magic" strings to restore
        #saver_def = self.saver.as_saver_def()
        #print (saver_def.filename_tensor_name)
        #print (saver_def.restore_op_name)
        # write out 3 files
        #self.saver.save(sess, '/Users/dwaithe/Documents/collaborators/WaitheD/Faster-RCNN-TensorFlow-Python3.5/tmp/trained_model.sd')
        #tf.train.write_graph(tf.get_default_graph(), '.', 'trained_model.proto', as_text=False)
        #tf.train.write_graph(tf.get_default_graph(), '.', 'trained_model.txt', as_text=True)
        #self.save_model(sess)

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def save_model(self, session):
        signature = tf.saved_model.signature_def_utils.build_signature_def()
        #inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
        #outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},)
        b = tf.saved_model.builder.SavedModelBuilder('tmp/model/')
        b.add_meta_graph_and_variables(
            session, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature
            })
        b.save()

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
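
save_model() above registers an empty SignatureDef, so the exported SavedModel cannot be served as-is. A sketch of wiring real inputs and outputs into the signature with the same TF1 API; the tensor names below are placeholders, not the actual network's tensors.

import tensorflow as tf

def build_serving_signature(graph):
    in_image = graph.get_tensor_by_name('Placeholder:0')        # hypothetical input tensor
    out_scores = graph.get_tensor_by_name('vgg_16/cls_prob:0')  # hypothetical output tensor
    return tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'image': tf.saved_model.utils.build_tensor_info(in_image)},
        outputs={'scores': tf.saved_model.utils.build_tensor_info(out_scores)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)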