Ejemplo n.º 1
0
    def train_model(self, sess, max_iters):
        """Network training loop.

        Builds the per-pixel classification loss, summary writer and a
        momentum optimizer, then runs SGD for `max_iters` iterations,
        printing progress and snapshotting the model periodically.

        Args:
            sess: open tf.Session the graph is run in.
            max_iters: total number of SGD iterations to run.
        """
        # data layer feeding ground-truth blobs
        data_layer = GtDataLayer(self.roidb, self.imdb.num_classes)

        # classification loss over per-pixel integer labels
        cls_score = self.net.upscore32
        label = tf.placeholder(tf.int32, shape=[None, None, None])
        # TF >= 1.0 requires named arguments here; the old positional
        # (logits, labels) call raises an error.
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=cls_score, labels=label))

        # add summary
        tf.summary.scalar('loss', loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(self.output_dir, sess.graph)

        # optimizer with a manually stepped learning rate
        lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
        momentum = cfg.TRAIN.MOMENTUM
        train_op = tf.train.MomentumOptimizer(lr, momentum).minimize(loss)

        # Build the lr-assign ops ONCE: calling tf.assign inside the loop
        # adds a new op to the graph on every iteration and leaks memory.
        lr_base_op = tf.assign(lr, cfg.TRAIN.LEARNING_RATE)
        lr_decay_op = tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA)

        # initialize variables (initialize_all_variables is deprecated)
        sess.run(tf.global_variables_initializer())

        last_snapshot_iter = -1
        timer = Timer()
        it = -1
        for it in range(max_iters):
            # step the learning rate down once STEPSIZE is reached
            if it >= cfg.TRAIN.STEPSIZE:
                sess.run(lr_decay_op)
            else:
                sess.run(lr_base_op)

            # get one batch
            blobs = data_layer.forward()

            # make one SGD update
            feed_dict = {self.net.data: blobs['data_depth'],
                         label: blobs['labels']}

            timer.tic()
            summary, loss_cls_value, _ = sess.run([merged, loss, train_op],
                                                  feed_dict=feed_dict)
            train_writer.add_summary(summary, it)
            timer.toc()

            print('iter: %d / %d, loss_cls: %.4f, lr: %.8f, time: %.2f' %
                  (it + 1, max_iters, loss_cls_value, lr.eval(), timer.diff))

            if (it + 1) % (10 * cfg.TRAIN.DISPLAY) == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if (it + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = it
                self.snapshot(sess, it)

        # final snapshot if the last iteration was not already saved
        # (guard against max_iters == 0, where no iteration ever ran)
        if max_iters > 0 and last_snapshot_iter != it:
            self.snapshot(sess, it)
Ejemplo n.º 2
0
    def __init__(self):
        """Load the imdb named on the command line and build a data layer."""
        args = args_setter()
        imdb = get_imdb(args.imdb_name)
        print('Loaded dataset `{:s}` for training'.format(args and imdb.name))

        roidb = get_training_roidb(imdb)

        # NOTE(review): every result above is bound only to locals and is
        # discarded when __init__ returns — presumably these should be
        # stored on self; confirm against the rest of the class.
        data_layer = GtDataLayer(roidb, imdb.num_classes)
Ejemplo n.º 3
0
def load_and_enqueue(sess, net, roidb, num_classes, coord):
    """Worker loop: pull blobs from the data layer and feed the network's
    input queue until the coordinator requests a stop.

    Args:
        sess: tf.Session used to run the enqueue op.
        net: network object exposing the input placeholders and `enqueue_op`.
        roidb: roidb list handed to the data layer.
        num_classes: number of classes for the data layer.
        coord: tf.train.Coordinator controlling shutdown.

    Raises:
        ValueError: if cfg.INPUT is not one of RGBD/COLOR/DEPTH/NORMAL.
    """
    # choose the data layer matching the training mode
    if cfg.TRAIN.SINGLE_FRAME:
        data_layer = GtSingleDataLayer(roidb, num_classes)
    else:
        data_layer = GtDataLayer(roidb, num_classes)

    while not coord.should_stop():
        blobs = data_layer.forward()

        # select the input modality
        data_p_blob = None
        if cfg.INPUT == 'RGBD':
            data_blob = blobs['data_image_color']
            data_p_blob = blobs['data_image_depth']
        elif cfg.INPUT == 'COLOR':
            data_blob = blobs['data_image_color']
        elif cfg.INPUT == 'DEPTH':
            data_blob = blobs['data_image_depth']
        elif cfg.INPUT == 'NORMAL':
            data_blob = blobs['data_image_normal']
        else:
            # previously fell through with data_blob undefined (NameError)
            raise ValueError('unknown cfg.INPUT: {}'.format(cfg.INPUT))

        # build the feed dict: common entries first, then mode-specific
        # extras — same contents as the old hand-written variants.
        if cfg.TRAIN.SINGLE_FRAME:
            feed_dict = {net.data: data_blob,
                         net.gt_label_2d: blobs['data_label'],
                         net.keep_prob: 0.5}
            if cfg.TRAIN.VERTEX_REG:
                feed_dict[net.vertex_targets] = blobs['data_vertex_targets']
                feed_dict[net.vertex_weights] = blobs['data_vertex_weights']
        else:
            feed_dict = {net.data: data_blob,
                         net.gt_label_2d: blobs['data_label'],
                         net.depth: blobs['data_depth'],
                         net.meta_data: blobs['data_meta_data'],
                         net.state: blobs['data_state'],
                         net.weights: blobs['data_weights'],
                         net.points: blobs['data_points'],
                         net.keep_prob: 0.5}
        if cfg.INPUT == 'RGBD':
            # depth stream goes through its own input placeholder
            feed_dict[net.data_p] = data_p_blob

        sess.run(net.enqueue_op, feed_dict=feed_dict)
Ejemplo n.º 4
0
def get_data_layer(roidb, num_classes):
    """Return the data layer matching the current configuration."""
    # GtDataLayer is used only when RPN training runs in multiscale mode;
    # every other configuration gets the RoI data layer.
    if cfg.TRAIN.HAS_RPN and cfg.IS_MULTISCALE:
        return GtDataLayer(roidb)
    return RoIDataLayer(roidb, num_classes)
Ejemplo n.º 5
0
def train_net(network,
              imdb,
              roidb,
              output_dir,
              pretrained_model=None,
              pretrained_ckpt=None,
              max_iters=40000):
    """Train a Fast R-CNN network.

    Builds the total training loss (classification, plus optional vertex /
    pose / domain-adaptation terms, plus regularization), sets up a momentum
    optimizer with a staircase-decayed learning rate, then dispatches to the
    SolverWrapper training method that matches the configured cfg flags.

    Args:
        network: network object exposing `get_output(...)` and input tensors.
        imdb: dataset object; also supplies extents/points/symmetry when the
            synthesize data layer is used.
        roidb: roidb list handed to the data layer.
        output_dir: directory where snapshots are written.
        pretrained_model: optional initial weights passed to SolverWrapper.
        pretrained_ckpt: optional checkpoint passed to SolverWrapper.
        max_iters: total number of training iterations.
    """

    # L2 regularization losses collected from the graph
    loss_regu = tf.add_n(tf.losses.get_regularization_losses(), 'regu')
    if cfg.TRAIN.SINGLE_FRAME:
        # classification loss
        if cfg.NETWORK == 'FCN8VGG':
            scores = network.prob
            labels = network.gt_label_2d_queue
            loss = loss_cross_entropy_single_frame(scores, labels) + loss_regu
        else:
            if cfg.TRAIN.VERTEX_REG_2D or cfg.TRAIN.VERTEX_REG_3D:
                # NOTE: loss_cls / loss_vertex (and below loss_pose,
                # loss_domain, domain_label, label_domain) are defined only
                # on these branches; the dispatch at the bottom of this
                # function references them under the same cfg flags.
                scores = network.get_output('prob')
                labels = network.get_output('gt_label_weight')
                loss_cls = loss_cross_entropy_single_frame(scores, labels)

                # vertex regression: smooth-L1 on masked vertex targets
                vertex_pred = network.get_output('vertex_pred')
                vertex_targets = network.get_output('vertex_targets')
                vertex_weights = network.get_output('vertex_weights')
                # loss_vertex = tf.div( tf.reduce_sum(tf.multiply(vertex_weights, tf.abs(tf.subtract(vertex_pred, vertex_targets)))), tf.reduce_sum(vertex_weights) + 1e-10 )
                loss_vertex = cfg.TRAIN.VERTEX_W * smooth_l1_loss_vertex(
                    vertex_pred, vertex_targets, vertex_weights)

                if cfg.TRAIN.POSE_REG:
                    # pose loss is produced inside the network itself
                    # pose_pred = network.get_output('poses_pred')
                    # pose_targets = network.get_output('poses_target')
                    # pose_weights = network.get_output('poses_weight')
                    # loss_pose = cfg.TRAIN.POSE_W * tf.div( tf.reduce_sum(tf.multiply(pose_weights, tf.abs(tf.subtract(pose_pred, pose_targets)))), tf.reduce_sum(pose_weights) )
                    # loss_pose = cfg.TRAIN.POSE_W * loss_quaternion(pose_pred, pose_targets, pose_weights)
                    loss_pose = cfg.TRAIN.POSE_W * network.get_output(
                        'loss_pose')[0]

                    if cfg.TRAIN.ADAPT:
                        # domain-adaptation classifier loss
                        domain_score = network.get_output("domain_score")
                        domain_label = network.get_output("domain_label")
                        label_domain = network.get_output("label_domain")
                        loss_domain = cfg.TRAIN.ADAPT_WEIGHT * tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=domain_score, labels=label_domain))
                        loss = loss_cls + loss_vertex + loss_pose + loss_domain + loss_regu
                    else:
                        loss = loss_cls + loss_vertex + loss_pose + loss_regu
                else:
                    loss = loss_cls + loss_vertex + loss_regu
            else:
                scores = network.get_output('prob')
                labels = network.get_output('gt_label_weight')
                loss = loss_cross_entropy_single_frame(scores,
                                                       labels) + loss_regu
    else:
        # classification loss (multi-frame mode)
        scores = network.get_output('outputs')
        labels = network.get_output('labels_gt_2d')
        loss = loss_cross_entropy(scores, labels) + loss_regu

    # optimizer: momentum SGD with staircase exponential lr decay
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = cfg.TRAIN.LEARNING_RATE
    learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                               global_step,
                                               cfg.TRAIN.STEPSIZE,
                                               0.1,
                                               staircase=True)
    momentum = cfg.TRAIN.MOMENTUM
    train_op = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(
        loss, global_step=global_step)

    #config = tf.ConfigProto()
    #config.gpu_options.per_process_gpu_memory_fraction = 0.85
    #config.gpu_options.allow_growth = True
    #with tf.Session(config=config) as sess:
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # data layer: synthesize layer for single-frame training,
        # plain ground-truth layer otherwise
        if cfg.TRAIN.SINGLE_FRAME:
            data_layer = GtSynthesizeLayer(roidb, imdb.num_classes,
                                           imdb._extents, imdb._points_all,
                                           imdb._symmetry, imdb.cache_path,
                                           imdb.name, imdb.data_queue, cfg.CAD,
                                           cfg.POSE)
        else:
            data_layer = GtDataLayer(roidb, imdb.num_classes)

        sw = SolverWrapper(sess,
                           network,
                           imdb,
                           roidb,
                           output_dir,
                           pretrained_model=pretrained_model,
                           pretrained_ckpt=pretrained_ckpt)

        print 'Solving...'
        # dispatch to the training loop that matches the losses built above
        if cfg.TRAIN.VERTEX_REG_2D or cfg.TRAIN.VERTEX_REG_3D:
            if cfg.TRAIN.POSE_REG:
                if cfg.TRAIN.ADAPT:
                    sw.train_model_vertex_pose_adapt(sess, train_op, loss, loss_cls, loss_vertex, loss_pose, \
                        loss_domain, label_domain, domain_label, learning_rate, max_iters, data_layer)
                else:
                    sw.train_model_vertex_pose(sess, train_op, loss, loss_cls,
                                               loss_vertex, loss_pose,
                                               learning_rate, max_iters,
                                               data_layer)
            else:
                sw.train_model_vertex(sess, train_op, loss, loss_cls,
                                      loss_vertex, loss_regu, learning_rate,
                                      max_iters, data_layer)
        else:
            sw.train_model(sess, train_op, loss, learning_rate, max_iters,
                           data_layer)
        print 'done solving'
# Sanity-check the test configuration before building anything.
# (`assert x == True` is redundant; assert the flag directly.)
assert cfg.TEST.SEGMENTATION
assert cfg.TEST.SINGLE_FRAME

# **** get network
network = get_network(args.network_name)
print('Use network `{:s}` in training'.format(args.network_name))

# **** define losses
losses = loss_definition(network, cfg)

# **** load data
sub_fact = sub_factory()
imdb = sub_fact.get_imdb(args.imdb_name)
roidb = get_training_roidb(imdb)
data_layer = GtDataLayer(roidb, imdb.num_classes)

# **** load the trained model
print('** try to restore trained weights')
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, args.trained_model)
    print("model restored.")

    coord = tf.train.Coordinator()
    # NOTE(review): these arguments do not match the
    # load_and_enqueue(sess, net, roidb, num_classes, coord) signature —
    # a data layer is passed where a roidb is expected and 0 where the
    # coordinator goes; confirm which variant of the helper is intended.
    load_and_enqueue(sess, network, data_layer, coord, 0)

    # NOTE(review): `loss` is not defined in this scope (only `losses` is);
    # as written this raises NameError — confirm the intended tensor.
    sess.run([loss])