# Code example #1 (コード例 #1)
# 0
def _latest_saved_step(logdir):
    """Return the highest step N among files named ``model_step_N.<ext>``.

    Searches *logdir* recursively with ``os.walk``. File names whose step part
    is not a plain decimal integer are ignored. Returns 0 when nothing matches.
    """
    prefix = 'model_step_'
    latest = 0
    for _root, _dirs, files in os.walk(logdir):
        for fname in files:
            if not fname.startswith(prefix):
                continue
            step_text = fname.split('.')[0][len(prefix):]
            if step_text.isdigit():
                latest = max(latest, int(step_text))
    return latest


def _validate(sess, g, num_batches):
    """Average the five evaluation losses over *num_batches* batches.

    Runs the loss tensors with ``g.train_stage`` fed ``False`` and without
    ``train_op``. Returns the tuple
    ``(mean_loss, score_loss, offset_loss, theta_loss, local_loss)``.
    *num_batches* must be >= 1.
    """
    fetches = [
        g.mean_loss, g.score_mean_loss, g.offset_mean_loss,
        g.theta_mean_loss, g.local_mean_loss
    ]
    sums = [0] * len(fetches)
    for _ in range(num_batches):
        values = sess.run(fetches, {g.train_stage: False})
        sums = [s + v for s, v in zip(sums, values)]
    return tuple(s / num_batches for s in sums)


def train():
    """Train ``Net`` with periodic validation, checkpointing and early stopping.

    If a checkpoint exists in ``g.config.logdir``, training resumes from it.
    The starting step is recovered from the ``model_step_N`` file names.

    Every ``config.display`` steps, the training losses accumulated since the
    previous print are printed, divided by ``display``. Validation runs once
    per pass over the training set and averages the losses over one pass of
    the validation set.

    A ``model_step_<step>`` checkpoint is saved whenever the validation loss,
    or the validation theta loss (``theta_mean_loss``), reaches a new minimum.
    Training stops after ``config.early_stop_count`` consecutive validations
    without improvement.
    """
    g = Net(config)
    g.build_net()
    logdir = g.config.logdir
    sv = tf.train.Supervisor(graph=g.graph, logdir=logdir)
    cfg = tf.ConfigProto()
    cfg.gpu_options.per_process_gpu_memory_fraction = 0.9
    cfg.gpu_options.allow_growth = True

    display = g.config.display
    # Validate once per pass over the training set. max(1, ...) avoids a
    # modulo-by-zero when the training set is smaller than one batch.
    valid_step = max(1, g.config.num_train_samples // g.config.batch_size)
    # One pass over the validation set per validation run.
    valid_batches = max(1, g.config.num_valid_samples // g.config.batch_size)

    with sv.managed_session(config=cfg) as sess:
        start_step = 0
        ckpt = tf.train.latest_checkpoint(logdir)
        if ckpt:
            # Load the checkpoint to resume from a break point.
            sv.saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
            start_step = _latest_saved_step(logdir)
            print('start_step=', start_step)

        best_loss = 1e8
        best_auto_loss = 1e8
        not_improve_count = 0
        # Running sums of the training losses since the last display.
        MLoss = MClsLoss = MRegLoss = MAutoLoss = MAttLoss = 0
        time_start = time.time()
        for st in range(start_step, g.config.total_steps):
            mloss, m_cls_loss, m_reg_loss, m_auto_loss, m_att_loss, _ = sess.run(
                [
                    g.mean_loss, g.score_mean_loss, g.offset_mean_loss,
                    g.theta_mean_loss, g.local_mean_loss, g.train_op
                ], {g.train_stage: True})
            MLoss += mloss
            MClsLoss += m_cls_loss
            MRegLoss += m_reg_loss
            MAutoLoss += m_auto_loss
            MAttLoss += m_att_loss

            # display
            if st % display == 0:
                print(
                    "step=%d, Loss=%f, score Loss=%f, offset Loss=%f, theta Loss=%f, local Loss=%f, time=%f"
                    %
                    (st, MLoss / display, MClsLoss / display,
                     MRegLoss / display, MAutoLoss / display,
                     MAttLoss / display, time.time() - time_start))
                MLoss = MClsLoss = MRegLoss = MAutoLoss = MAttLoss = 0
                time_start = time.time()

            if st % valid_step != 0:
                continue

            # Validation-set evaluation, used for tuning the network.
            VLoss, VClsLoss, VRegLoss, VAutoLoss, VAttLoss = _validate(
                sess, g, valid_batches)
            print(
                "validation --- Loss=%f, score Loss=%f, offset Loss=%f, theta Loss=%f, local Loss=%f"
                % (VLoss, VClsLoss, VRegLoss, VAutoLoss, VAttLoss))

            # Model selection and early stopping.
            if VLoss < best_loss or VAutoLoss < best_auto_loss:
                # Only one of the two metrics may have improved. Keep each
                # "best" value monotone instead of overwriting it with a worse one.
                best_loss = min(best_loss, VLoss)
                best_auto_loss = min(best_auto_loss, VAutoLoss)
                not_improve_count = 0
                sv.saver.save(sess, logdir + '/model_step_%d' % st)
            else:
                not_improve_count += 1
            if not_improve_count >= g.config.early_stop_count:
                print("training stopped, best Loss=%f" % (best_loss))
                # Must run before `break`; after it, it would be unreachable.
                sv.request_stop()
                break
# Code example #2 (コード例 #2)
# 0
def _decode_quads(score_map, geo_map, width, height, score_thresh=0.618):
    """Decode candidate quadrilaterals from one score map and geometry map.

    Args:
        score_map: 2-D array of per-cell scores, indexed ``[row, col]``.
        geo_map: array indexed ``[row, col, k]`` with at least 8 channels.
            Channels ``2j`` and ``2j + 1`` hold the radius and the angle
            (as a fraction of a full turn) of corner ``j``.
        width, height: pixel scale applied to normalised x / y coordinates.
        score_thresh: cells scoring above this value propose one quad each.

    Returns:
        A list of ``[x1, y1, x2, y2, x3, y3, x4, y4, score]`` rows.

        Corners are polar offsets from the cell's normalised position
        ``(col / cols, row / rows)``. A quad is dropped when:
        - any of its four corner triangles is not a valid Shapely polygon;
        - an edge is more than twice as long as its opposite edge; or
        - splitting it along either diagonal yields overlapping triangles.

        ``score`` is the cell score minus a penalty for unequal opposite edges.
    """
    two_pi = 2.0 * math.pi
    rows, cols = score_map.shape[0], score_map.shape[1]
    dets = []
    for r in range(rows):
        for c in range(cols):
            # `not >` (rather than `<=`) keeps NaN scores rejected.
            if not score_map[r, c] > score_thresh:
                continue
            tr = float(r) / float(rows)
            tc = float(c) / float(cols)
            corners = []
            for j in range(4):
                radius = geo_map[r, c, 2 * j]
                angle = geo_map[r, c, 2 * j + 1] * two_pi
                corners.append((int((tc + radius * math.cos(angle)) * width),
                                int((tr + radius * math.sin(angle)) * height)))
            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = corners

            # Use the triangles of both diagonal splits to filter out invalid boxes.
            tri_124 = Polygon([(x1, y1), (x2, y2), (x4, y4)])
            tri_234 = Polygon([(x2, y2), (x3, y3), (x4, y4)])
            tri_123 = Polygon([(x1, y1), (x2, y2), (x3, y3)])
            tri_134 = Polygon([(x1, y1), (x3, y3), (x4, y4)])
            if not (tri_124.is_valid and tri_234.is_valid and
                    tri_123.is_valid and tri_134.is_valid):
                continue

            edge1 = distance(x1, y1, x2, y2)
            edge2 = distance(x3, y3, x4, y4)
            edge3 = distance(x1, y1, x4, y4)
            edge4 = distance(x2, y2, x3, y3)
            # Opposite edges must be within a factor of 2 of each other.
            if (edge1 > 2 * edge2 or edge2 > 2 * edge1 or
                    edge3 > 2 * edge4 or edge4 > 2 * edge3):
                continue
            # The two triangles of each diagonal split must not overlap.
            if not (tri_124.intersection(tri_234).area == 0 and
                    tri_123.intersection(tri_134).area == 0):
                continue

            penalty = (abs(edge1 - edge2) / (edge1 + edge2) +
                       abs(edge3 - edge4) / (edge3 + edge4)) / 4
            dets.append([x1, y1, x2, y2, x3, y3, x4, y4,
                         score_map[r, c] - penalty])
    return dets


def test(image_dir):
    """Run the restored detector on the images in *image_dir* and save overlays.

    Steps:
    - Set ``config.batch_size = 1``. This mutates the shared config.
    - Restore the latest checkpoint from ``config.logdir``.
    - If there is at least one image, export the frozen inference graph
      (output nodes ``geo`` and ``prob``) to ``./pb/frozen_model.pb``.
    - For each image, decode quads from the predicted score and geometry maps
      and reduce them with ``standard_nms`` (threshold 0.146).
    - Draw the kept polygons onto the raw image and save it as
      ``tmp/{n}_{i}_check.png``.

    The ``pb/`` and ``tmp/`` directories are not created here and must
    already exist.
    """
    image_list, raw_image_list, rate_list = get_images(image_dir)
    config.batch_size = 1
    g = Net(config)
    g.build_net(is_training=False)
    print("Graph loaded.")
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session() as sess:
            ckpt = tf.train.latest_checkpoint(config.logdir)
            sv.saver.restore(sess, ckpt)
            print(ckpt)
            print("Restored!")

            if len(image_list) > 0:
                # Freeze the restored weights into a standalone GraphDef.
                # The result does not depend on the input image, so export once.
                out_pb_path = "./pb/frozen_model.pb"
                output_node_names = "geo,prob"
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, output_node_names.split(","))
                with tf.gfile.FastGFile(out_pb_path, mode='wb') as f:
                    f.write(constant_graph.SerializeToString())

            for n in range(len(image_list)):
                x = image_list[n]
                raw_x = raw_image_list[n]
                # NOTE(review): rate_list[n] presumably maps network-input
                # coordinates back to raw-image pixels; confirm in get_images.
                max_width = config.max_width * rate_list[n]
                max_height = config.max_height * rate_list[n]
                L_, G_ = sess.run([g.score_prob, g.geo_map], {g.x: x})
                for i in range(len(x)):
                    L = L_[i]
                    L = np.reshape(L, (L.shape[0], L.shape[1]))
                    dets = _decode_quads(L, G_[i], max_width, max_height)
                    draw = ImageDraw.Draw(raw_x)
                    if len(dets) > 0:
                        dets = np.array(dets)
                        print("\n{}_{}".format(n, i))
                        print("{} boxes before nms".format(dets.shape[0]))
                        keeps = standard_nms(dets, 0.146)
                        print("{} boxes after nms".format(keeps.shape[0]))
                        for k in range(keeps.shape[0]):
                            draw.polygon(list(keeps[k][:8]),
                                         outline=(0, 255, 0, 255))
                    raw_x.save("tmp/{}_{}_check.png".format(n, i))
# Code example #3 (コード例 #3)
# 0
def eval(save_path, total_batch, auto_vis=False):
    """Visually check the restored model on batches from ``read_and_decode``.

    NOTE: this module-level name shadows the builtin ``eval``.

    Args:
        save_path: passed to ``g.read_and_decode(..., is_training=False)``.
            Presumably the path of the evaluation data file; TODO confirm.
        total_batch: number of batches to run and visualise.
        auto_vis: NOTE(review): not used by the visualisation loop below;
            confirm whether it is still needed.

    For each image of each batch:
    - Decode quads from the predicted score map (cells scoring above 0.618)
      and geometry map.
    - Filter the quads and reduce them with ``standard_nms`` (threshold 0.1).
    - Draw them on the input image, which is scaled by 255 and inverted via
      ``255 - value``.
    - Draw a green dot at each above-threshold cell. The edge-ratio
      ``continue`` skips the dot for the cells it rejects.
    - Save the result as ``tmp/{n}_{i}_check.png``.
    """
    g = Net(config)
    g.build_net(is_training=False)
    print("Graph loaded.")

    with g.graph.as_default():
        # Input tensors: image batch plus ground-truth labels / geometry maps.
        image, Labels, GeoMaps = g.read_and_decode(save_path,
                                                   is_training=False)

        sv = tf.train.Supervisor()
        with sv.managed_session() as sess:
            sv.saver.restore(sess, tf.train.latest_checkpoint(config.logdir))
            print(tf.train.latest_checkpoint(config.logdir))
            print("Restored!")

            # Discard a random number of batches so the visualised batches
            # start at a random position in the input pipeline.
            # randint is inclusive: 0 .. num_valid_samples // batch_size.
            for n in range(
                    random.randint(
                        0, config.num_valid_samples // config.batch_size)):
                x, _L, _G = sess.run([image, Labels, GeoMaps])

            for n in range(total_batch):
                two_pi = 2.0 * math.acos(-1)  # == 2 * pi
                # _L / _G (ground truth) are only used by the commented-out
                # alternative below.
                x, _L, _G = sess.run([image, Labels, GeoMaps])
                # Feed the decoded batch back through the g.x input.
                L_, G_ = sess.run([g.score_prob, g.geo_map], {g.x: x})
                for i in range(len(x)):
                    img = x[i]
                    L = L_[i]
                    G = G_[i]
                    # Uncomment to visualise ground truth instead of predictions:
                    #L = _L[i]
                    #G = _G[i]
                    # Drop the trailing dimension of the score map -> 2-D [row, col].
                    L = np.reshape(L, (L.shape[0], L.shape[1]))
                    img = np.reshape(img,
                                     (img.shape[0], img.shape[1], 3)) * 255
                    img = Image.fromarray(255 - np.uint8(img)).convert('RGBA')
                    draw = ImageDraw.Draw(img)
                    #'''
                    dets = []
                    for r in range(L.shape[0]):
                        for c in range(L.shape[1]):
                            if L[r, c] > 0.618:
                                #if L[r, c] > 0.4:
                                # Normalised position of this cell.
                                tr = float(r) / float(L.shape[0])
                                tc = float(c) / float(L.shape[1])
                                # Corner j = cell position + polar offset
                                # (radius G[..., 2j], angle G[..., 2j+1] turns),
                                # scaled to config.max_width / max_height pixels.
                                x1 = int((tc + G[r, c, 0] *
                                          math.cos(G[r, c, 1] * two_pi)) *
                                         config.max_width)
                                y1 = int((tr + G[r, c, 0] *
                                          math.sin(G[r, c, 1] * two_pi)) *
                                         config.max_height)
                                x2 = int((tc + G[r, c, 2] *
                                          math.cos(G[r, c, 3] * two_pi)) *
                                         config.max_width)
                                y2 = int((tr + G[r, c, 2] *
                                          math.sin(G[r, c, 3] * two_pi)) *
                                         config.max_height)
                                x3 = int((tc + G[r, c, 4] *
                                          math.cos(G[r, c, 5] * two_pi)) *
                                         config.max_width)
                                y3 = int((tr + G[r, c, 4] *
                                          math.sin(G[r, c, 5] * two_pi)) *
                                         config.max_height)
                                x4 = int((tc + G[r, c, 6] *
                                          math.cos(G[r, c, 7] * two_pi)) *
                                         config.max_width)
                                y4 = int((tr + G[r, c, 6] *
                                          math.sin(G[r, c, 7] * two_pi)) *
                                         config.max_height)
                                # using triangle to filter out invalid box:
                                # test1/test2 split the quad along diagonal 2-4,
                                # test3/test4 along diagonal 1-3.
                                test1 = Polygon([(x1, y1), (x2, y2), (x4, y4)])
                                test2 = Polygon([(x2, y2), (x3, y3), (x4, y4)])
                                test3 = Polygon([(x1, y1), (x2, y2), (x3, y3)])
                                test4 = Polygon([(x1, y1), (x3, y3), (x4, y4)])
                                if test1.is_valid and test2.is_valid and test3.is_valid and test4.is_valid:
                                    # edge1/edge2 and edge3/edge4 are opposite edge pairs.
                                    edge1 = distance(x1, y1, x2, y2)
                                    edge2 = distance(x3, y3, x4, y4)
                                    edge3 = distance(x1, y1, x4, y4)
                                    edge4 = distance(x2, y2, x3, y3)
                                    # Reject quads whose opposite edges differ by
                                    # more than 2x. This `continue` also skips
                                    # the draw.point below.
                                    if edge1 > 2 * edge2 or edge2 > 2 * edge1 or edge3 > 2 * edge4 or edge4 > 2 * edge3:
                                        continue
                                    # Keep only quads whose diagonal splits do not
                                    # overlap (overlap means non-convex or
                                    # self-intersecting corner order).
                                    if test1.intersection(
                                            test2
                                    ).area == 0 and test3.intersection(
                                            test4).area == 0:
                                        score = L[r, c]
                                        # Penalty ("panalty") for unequal opposite edges.
                                        panalty = (abs(edge1 - edge2) /
                                                   (edge1 + edge2) +
                                                   abs(edge3 - edge4) /
                                                   (edge3 + edge4)) / 4
                                        score -= panalty
                                        dets.append([
                                            x1, y1, x2, y2, x3, y3, x4, y4,
                                            score
                                        ])
                                # Mark the above-threshold cell with a green dot.
                                draw.point((int(tc * config.max_width),
                                            int(tr * config.max_height)),
                                           fill=(0, 255, 0, 255))
                    if len(dets) > 0:
                        dets = np.array(dets)
                        print("\n{}_{}".format(n, i))
                        print("{} boxes before nms".format(dets.shape[0]))
                        keeps = standard_nms(dets, 0.1)
                        print("{} boxes after nms".format(keeps.shape[0]))
                        # The first 8 columns of each kept row are the corner coordinates.
                        for k in range(keeps.shape[0]):
                            draw.polygon(list(keeps[k][:8]),
                                         outline=(0, 255, 0, 255))
                    #'''
                    img.save("tmp/{}_{}_check.png".format(n, i))
                    #'''
                    '''