예제 #1
0
파일: fuzzy.py 프로젝트: senli2018/FCGAN
def initialize_UC(checkpoints_dir, C_initial, UC_name, source_dir, target_dir):
    """Build initial fuzzy memberships (U) and cluster centres (C) for both domains.

    The source and target image sets are loaded through ``dr.get_source_batch`` /
    ``dr.get_target_batch`` (batch size 0, 256x256). Each image is pushed one at a
    time through ``C_initial`` inside a fresh ``tf.Session`` whose variables are
    freshly initialized. The resulting feature rows go to the project helpers
    ``initialize_U`` / ``initialize_C``. The four matrices are written as
    comma-separated text (``%.20f``) to ``<checkpoints_dir>/{Ux,Uy,Cx,Cy}<UC_name>.txt``.

    Args:
        checkpoints_dir: Directory that receives the four text files.
        C_initial: Callable applied to each image placeholder to get a feature tensor.
        UC_name: Suffix appended to every output file name.
        source_dir: Directory passed to ``dr.get_source_batch``.
        target_dir: Directory passed to ``dr.get_target_batch``.

    Returns:
        ``(Ux, Uy, Cx, Cy, x_data, y_data, y_labels)``, where ``x_data`` and
        ``y_data`` are lists holding one feature row per image.
    """
    # Batch size 0 is requested for both sets; presumably that means "every image" -- confirm in dr.
    src_images, _, n_src, _ = dr.get_source_batch(0, 256, 256, source_dir=source_dir)
    tgt_images, _, n_tgt, y_labels = dr.get_target_batch(0, 256, 256, target_dir=target_dir)

    # Both placeholders are created first, then both feature heads.
    x = tf.placeholder(tf.float32, [None, 256, 256, 3])
    y = tf.placeholder(tf.float32, [None, 256, 256, 3])
    fx = C_initial(x)
    fy = C_initial(y)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Every image is fed as a batch of one; row 0 of the output is its feature vector.
        x_data = [sess.run(fx, feed_dict={x: [image]})[0] for image in src_images]
        y_data = [sess.run(fy, feed_dict={y: [image]})[0] for image in tgt_images]

    print('y_data:', np.sum(y_data, axis=1))

    Ux = initialize_U(n_src, 1)
    Uy = initialize_U(n_tgt, 2)
    Cx = initialize_C(x_data, Ux)
    Cy = initialize_C(y_data, Uy)

    for tag, matrix in (("/Ux", Ux), ("/Uy", Uy), ("/Cx", Cx), ("/Cy", Cy)):
        np.savetxt(checkpoints_dir + tag + UC_name + '.txt',
                   matrix,
                   fmt="%.20f",
                   delimiter=",")

    return Ux, Uy, Cx, Cy, x_data, y_data, y_labels
def train():
    """Train FCGAN (CycleGAN + fuzzy c-means clustering) on source/target images.

    Steps:
      1. Resolve the checkpoint directory. Resume from ``FLAGS.load_model``, or
         create ``checkpoints/<YYYYmmdd-HHMM>``.
      2. Build the CycleGAN graph, both optimizer groups, the merged summary op,
         a summary writer and a ``tf.train.Saver``.
      3. Initialize the FCM state, in one of two ways:
         - Restore the latest checkpoint plus the saved U/C text files.
         - Initialize all variables and compute fresh memberships (U) and
           centres (C) with ``fuzzy.initialize_UC_test`` from per-image features.
      4. Loop over training steps. Each step runs both optimizer groups on one
         source batch and one target batch.
         Every 100 steps:
         - re-extract target features;
         - update ``Uy``/``Cy`` via ``fuzzy.updata_U``;
         - compute accuracy against the target labels;
         - snapshot the best model under ``<checkpoints_dir>/max``.
         The loop ends when any of these happens:
         - accuracy reaches 1 (after the snapshot is written);
         - ``step`` exceeds 10000;
         - Ctrl-C is pressed;
         - the coordinator is asked to stop.
      5. Always save the final model and whichever U/C matrices exist.

    Raises:
        ValueError: ``FLAGS.load_model`` is set but no checkpoint is found there.
    """
    ckpt_prefix = "checkpoints/"

    def _uc_path(directory, tag):
        # File holding one FCM matrix, e.g. "<directory>/Ux<UC_name>.txt".
        return directory + "/" + tag + FLAGS.UC_name + '.txt'

    def _save_uc(directory, named_matrices):
        # Write each (tag, matrix) pair as comma-separated text. None entries are
        # skipped so a failure before FCM initialization cannot mask the real error.
        for tag, matrix in named_matrices:
            if matrix is not None:
                np.savetxt(_uc_path(directory, tag), matrix, fmt="%.20f", delimiter=",")

    def _extract_features(session, feature_tensor, placeholder, images):
        # Feed images one at a time and keep row 0 of each single-image batch.
        return [session.run(feature_tensor, feed_dict={placeholder: [img]})[0]
                for img in images]

    if FLAGS.load_model is not None:
        # Remove a literal leading "checkpoints/". str.lstrip would strip any of
        # those *characters* instead, turning e.g. "test_run" into "_run".
        model_name = FLAGS.load_model
        if model_name.startswith(ckpt_prefix):
            model_name = model_name[len(ckpt_prefix):]
        checkpoints_dir = ckpt_prefix + model_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            learning_rate2=FLAGS.learning_rate2,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, Disperse_loss, Fuzzy_loss, feature_x, feature_y, _, _ = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss, Disperse_loss)
        optimizers2 = cycle_gan.optimize2(Fuzzy_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are saved as "<path>/model.ckpt-<step>"; take the text
            # after the LAST dash so dashes elsewhere in the path do not matter.
            step = int(meta_graph_path.rsplit("-", 1)[-1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 1

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Bound up front so the finally-block can tell whether FCM state exists.
        Ux = Uy = Cx = Cy = None
        try:
            x_path = FLAGS.X + FLAGS.UC_name
            print('now is in FCM initializing!')
            if FLAGS.load_model is None:
                x_images, _, x_len, _, _, _ = get_source_batch(0, 256, 256, source_dir=x_path)
                y_images, _, y_len, _, _, _ = get_target_batch(0, 256, 256, target_dir=FLAGS.Y)
                print('x_len', len(x_images))
                print('y_len', len(y_images))
                x_data = _extract_features(sess, feature_x, cycle_gan.x, x_images)
                y_data = _extract_features(sess, feature_y, cycle_gan.y, y_images)
                Ux, Uy, Cx, Cy = fuzzy.initialize_UC_test(x_len, x_data, y_len, y_data, FLAGS.UC_name, checkpoints_dir)
                _save_uc(checkpoints_dir, [("Ux", Ux), ("Uy", Uy), ("Cx", Cx), ("Cy", Cy)])
            else:
                # Reshape the restored matrices to the nesting used at training time.
                Ux = np.loadtxt(_uc_path(checkpoints_dir, "Ux"), delimiter=",")
                Ux = [[u] for u in Ux]
                Uy = np.loadtxt(_uc_path(checkpoints_dir, "Uy"), delimiter=",")
                Cx = np.loadtxt(_uc_path(checkpoints_dir, "Cx"), delimiter=",")
                Cx = [Cx]
                Cy = np.loadtxt(_uc_path(checkpoints_dir, "Cy"), delimiter=",")
            print('FCM initialization is ended! Go to train')

            max_accuracy = 0
            while not coord.should_stop():
                # --- Optimize networks ---
                images_x, idx_list, _, _, _, _ = get_source_batch(FLAGS.batch_size, FLAGS.image_size,
                                                                  FLAGS.image_size, source_dir=x_path)
                subUx = fuzzy.getSubU(Ux, idx_list)
                label_x = [u[0] for u in subUx]
                images_y, idy_list, _, _, _, _ = get_target_batch(FLAGS.batch_size, FLAGS.image_size,
                                                                  FLAGS.image_size, target_dir=FLAGS.Y)
                subUy = fuzzy.getSubU(Uy, idy_list)
                label_y = [u[0] for u in subUy]
                (_, _, Fuzzy_loss_val, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val,
                 summary, Disperse_loss_val, feature_x_eval, feature_y_eval) = sess.run(
                    [optimizers, optimizers2, Fuzzy_loss, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op,
                     Disperse_loss, feature_x, feature_y],
                    feed_dict={cycle_gan.x: images_x, cycle_gan.y: images_y,
                               cycle_gan.Uy2x: subUy, cycle_gan.Ux2y: subUx,
                               cycle_gan.x_label: label_x, cycle_gan.y_label: label_y,
                               cycle_gan.ClusterX: Cx, cycle_gan.ClusterY: Cy}
                )
                train_writer.add_summary(summary, step)
                train_writer.flush()
                # Periodic loss logging (disabled):
                # if step % 10 == 0:
                #     print('-----------Step %d:-------------' % step)
                #     logging.info('  G_loss   : {}'.format(G_loss_val))
                #     logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                #     logging.info('  F_loss   : {}'.format(F_loss_val))
                #     logging.info('  D_X_loss : {}'.format(D_X_loss_val))
                #     logging.info('  Disperse_loss : {}'.format(Disperse_loss_val))
                #     logging.info('  Fuzzy_loss : {}'.format(Fuzzy_loss_val))

                # --- Optimize FCM algorithm ---
                if step % 100 == 0:
                    print('Now is in FCM training!')
                    y_images, _, _, y_labels, _, _ = get_target_batch(0, 256, 256, target_dir=FLAGS.Y)
                    print('y_len', len(y_images))
                    y_data = _extract_features(sess, feature_y, cycle_gan.y, y_images)

                    Uy, Cy = fuzzy.updata_U(checkpoints_dir, y_data, Uy, FLAGS.UC_name)
                    accuracy, tp, tn, fp, fn, f1_score, recall, precision, specificity = computeAccuracy(Uy, y_labels)

                    print("accuracy:%.4f\ttp:%d\ttn:%d\tfp %d\tfn:%d" %
                          (accuracy, tp, tn, fp, fn))
                    if accuracy >= max_accuracy:
                        max_accuracy = accuracy
                        max_dir = checkpoints_dir + "/max"
                        if not os.path.exists(max_dir):
                            os.makedirs(max_dir)
                        # Mode 'w' already truncates, so no seek/truncate is needed.
                        with open(max_dir + "/step.txt", 'w') as f:
                            f.write(str(step) + '\n')
                            f.write(str(accuracy) + '\taccuracy\n')
                        np.save(max_dir + "/feature_fcgan.npy", y_data)
                        _save_uc(max_dir, [("Uy", Uy), ("Cy", Cy), ("Ux", Ux), ("Cx", Cx)])
                        save_path = saver.save(sess, max_dir + "/model.ckpt", global_step=step)
                        print("Max model saved in file: %s" % save_path)
                    print('max_accuracy:', max_accuracy)
                    # NOTE(review): labelled "mean_U" but prints the per-column minimum -- confirm intent.
                    print('mean_U', np.min(Uy, 0))
                    # Stop on perfect accuracy only after the snapshot above has been written.
                    if accuracy == 1:
                        break
                step += 1
                if step > 10000:
                    logging.info('train stop!')
                    break

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Reported to the coordinator; coord.join below surfaces it.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
            # Any of U/C may still be None if the failure happened during FCM initialization.
            _save_uc(checkpoints_dir, [("Uy", Uy), ("Cy", Cy), ("Ux", Ux), ("Cx", Cx)])
            logging.info("Model saved in file: %s" % save_path)
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
예제 #3
0
파일: test.py 프로젝트: senli2018/FCGAN
def train():
    """Evaluate a trained FCGAN checkpoint and write visualisations under ./result.

    Despite its name, this entry point does not optimize anything. It:
      1. restores the checkpoint named by ``FLAGS.load_model`` (returns early if unset);
      2. loads the saved U/C text files;
      3. prints classification metrics from ``computeAccuracy(Uy, y_labels)``;
      4. saves ``plots_count`` t-SNE scatter plots of ``feature_fcgan.npy``;
      5. runs cross-domain generation and feature-map plotting for 10 image pairs.

    Raises:
        ValueError: no checkpoint is found in the resolved checkpoint directory.
    """
    ckpt_prefix = "checkpoints/"
    if FLAGS.load_model is not None:
        # Remove a literal leading "checkpoints/". str.lstrip would strip any of
        # those *characters* instead, turning e.g. "test_run" into "_run".
        model_name = FLAGS.load_model
        if model_name.startswith(ckpt_prefix):
            model_name = model_name[len(ckpt_prefix):]
        checkpoints_dir = ckpt_prefix + model_name
        print(checkpoints_dir)
    else:
        logging.info('No model to test, stopped!')
        return

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             learning_rate2=FLAGS.learning_rate2,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, Disperse_loss, Fuzzy_loss, feature_x, feature_y = cycle_gan.model(
        )

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()
    logging.info('Network Built!')

    with tf.Session(graph=graph) as sess:
        checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
        if checkpoint is None:
            raise ValueError("No checkpoint found in %s" % checkpoints_dir)
        meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
        restore = tf.train.import_meta_graph(meta_graph_path)
        restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
        # Checkpoints are named "<path>/model.ckpt-<step>"; take the text after the
        # LAST dash so dashes elsewhere in the path do not matter.
        step = int(meta_graph_path.rsplit("-", 1)[-1].split(".")[0])
        Ux = np.loadtxt(checkpoints_dir + "/Ux" + FLAGS.UC_name + '.txt',
                        delimiter=",")
        Ux = [[u] for u in Ux]
        Uy = np.loadtxt(checkpoints_dir + "/Uy" + FLAGS.UC_name + '.txt',
                        delimiter=",")
        Cx = np.loadtxt(checkpoints_dir + "/Cx" + FLAGS.UC_name + '.txt',
                        delimiter=",")
        Cx = [Cx]
        Cy = np.loadtxt(checkpoints_dir + "/Cy" + FLAGS.UC_name + '.txt',
                        delimiter=",")
        logging.info('Parameter Initialized!')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            plots_count = 10
            tsne_plot_count = 1000
            result_dir = './result'
            fake_dir = os.path.join(result_dir, 'fake_xy')
            roc_dir = os.path.join(result_dir, 'roc_curves')
            plot_dir = os.path.join(result_dir, 'tsne_pca')
            conv_dir = os.path.join(result_dir, 'convs')

            x_path = FLAGS.X + FLAGS.UC_name
            x_images, x_id_list, x_len, x_labels, oimg_xs, x_files = get_source_batch(
                0, 256, 256, source_dir=x_path)
            y_images, y_id_list, y_len, y_labels, oimg_ys, y_files = get_target_batch(
                0, 256, 256, target_dir=FLAGS.Y)
            # Compute accuracy, tp, tn, fp, fn, f1_score, recall, precision, specificity.
            accuracy, tp, tn, fp, fn, f1_score, recall, precision, specificity = computeAccuracy(
                Uy, y_labels)
            print(
                "accuracy:%.4f\ttp:%d\ttn:%d\tfp %d\tfn:%d\tf1_score:%.3f\trecall:%.3f\tprecision:%.3f\tspecificity:%.3f\t"
                % (accuracy, tp, tn, fp, fn, f1_score, recall, precision,
                   specificity))
            # ROC curve drawing (disabled):
            '''
            print('y_labels:',np.shape(y_labels))
            print('y_scores:',np.shape(Uy[:,0]))
            fpr,tpr,thresholds=roc_curve(y_labels,Uy[:,1])
            roc_auc=auc(fpr,tpr)
            plt.plot(fpr, tpr)
            plt.xticks(np.arange(0, 1, 0.1))
            plt.yticks(np.arange(0, 1, 0.1))
            plt.xlabel("False Positive Rate")
            plt.ylabel("True Positive Rate")
            # plt.title("A simple plot")
            plt.show()
            print('fpr:',np.shape(fpr))
            print('tpr:', np.shape(tpr))
            print('thresholds:', np.shape(thresholds))
            '''
            # --- t-SNE plots ---
            # The feature file does not change between plots, so load it once.
            feature_path = os.path.join(checkpoints_dir, 'feature_fcgan.npy')
            feature = np.load(feature_path)
            print('feature:', len(feature))
            # Sample only indices valid for both arrays, and never more than exist
            # (random.sample raises ValueError otherwise).
            population = min(len(feature), len(y_labels))
            sample_size = min(tsne_plot_count, population)
            utils.prepare_dir(plot_dir)
            for j in range(plots_count):
                rand_idx = random.sample(range(population), sample_size)
                t_features = [feature[k] for k in rand_idx]
                t_labels = [y_labels[k] for k in rand_idx]
                # Reduce the features to 2-D with t-SNE.
                tsne = TSNE(n_components=2,
                            learning_rate=100).fit_transform(t_features)
                plt.figure(figsize=(6, 6))
                plt.scatter(tsne[:, 0], tsne[:, 1], c=t_labels)
                # The colorbar shows which class each colour corresponds to.
                plt.colorbar()
                plt.savefig(os.path.join(plot_dir, 'plot{}.pdf'.format(j)))
                # Release the figure so repeated plots do not accumulate in memory.
                plt.close()

            # --- Cross-domain image generation and feature-map visualisation ---
            conv_outputs = tf.get_collection('conv_output')
            for i in range(10):
                # NOTE(review): x_img_path, y_img_path, x_name and y_name are not defined
                # in this function; they must exist at module scope or this raises NameError.
                x_img, _, x_oimg = get_single_img(x_img_path)
                y_path = os.path.join(y_img_path, str(i + 1) + '.jpg')
                y_img, _, y_oimg = get_single_img(y_path)
                fake_y_eval, fake_x_eval, conv_y_eval = sess.run(
                    [fake_y, fake_x, conv_outputs],
                    feed_dict={
                        cycle_gan.x: x_img,
                        cycle_gan.y: y_img
                    })
                plot_fake_xy(fake_y_eval[0], fake_x_eval[0], x_name,
                             str(i + 1), x_oimg, y_oimg, fake_dir)
                print('processing:', i)

                print('conv_len:', len(conv_y_eval))
                print('conv_shape:', np.shape(conv_y_eval[0]))
                id_y_dir = os.path.join(conv_dir, str(y_name))
                # Separate index name so the outer image counter is not shadowed.
                for layer_idx, conv in enumerate(conv_y_eval):
                    plot_conv_output(conv, layer_idx, id_y_dir)
                cv2.imwrite(os.path.join(id_y_dir, 'y.png'), y_oimg)

            # Occlusion test (disabled):
            '''
            if True:
                t_imgs, t_lbs, t_img = get_single_img(t_img_path)
                #s_imgs, s_lbs, t_img = get_single_img(s_img_path)
                width=np.shape(t_imgs[0])[0]
                height=np.shape(t_imgs[0])[1]
                #print('width:',width)
                #print('height:', height)
                data=generate_occluded_imageset(t_imgs[0],width=width,height=height,occluded_size=16)
                #print(data.shape[0])
                #print('Cy:',Cy)
                u_ys=np.empty([data.shape[0]],dtype='float64')
                occ_map=np.empty((width,height),dtype='float64')
                print(occ_map.shape)
                cnt=0
                feature_y_eval = sess.run(
                    feature_y,
                    feed_dict={cycle_gan.y: [data[0]]})
                # print(feature_y_eval)
                idx_u = 0
                u_y0 = cal_U(feature_y_eval[0], Cy, 2, 2)[idx_u]
                occ_value=0
                print('u_y0:',u_y0)
                for i in range(width):
                    print('collum idx:',i)
                    print(str(cnt) + ':' + str(occ_value))
                    for j in range(height):
                        feature_y_eval = sess.run(
                            feature_y,
                            feed_dict={cycle_gan.y: [data[cnt+1]]})
                        # print(feature_y_eval)
                        u_y = cal_U(feature_y_eval[0], Cy, 2, 2)[idx_u]
                        #print('u_y0:', u_y0)
                        #print('u_y:',u_y)
                        occ_value=u_y0-u_y
                        occ_map[i,j]=occ_value
                        #print(str(cnt)+':'+str(occ_value))
                        cnt+=1

                occ_map_path=os.path.join(occ_dir,'occlusion_map.txt')
                np.savetxt(occ_map_path, occ_map, fmt='%0.20f')
                cv2.imwrite(os.path.join(occ_dir, 'occ_test.png'), oimg_ys[id_y])
                draw_heatmap(occ_map_path=occ_map_path,ori_img=t_img,save_dir=os.path.join(occ_dir,'heatmap.png'))
            '''

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Reported to the coordinator; coord.join below surfaces it.
            coord.request_stop(e)
        finally:
            # Evaluation only: nothing is saved back to the checkpoint directory.
            logging.info("stopped")
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)