Example #1
0
def plot_conv_output(conv_img, i, id_y_dir):
    """
    Saves every channel of one convolutional feature map as a grayscale PNG

    Only layers whose index is in the hard-coded selection below are written;
    for any other index the call does nothing. Each channel is min-max scaled
    to the range 0..255 on its own and stored as
    ``<id_y_dir>/layer_<i>/_conv_<channel>.png``.

    :param conv_img: numpy array of rank 4, read as ``conv_img[0, :, :, channel]``
                     (only the first entry along axis 0 is plotted)
    :param i: int, index of the layer that produced ``conv_img``
    :param id_y_dir: string, directory that receives the ``layer_<i>`` folders
    :return: nothing, plots are saved on the disk
    """
    selected_layers = (0, 1, 2, 10, 11, 12, 13, 14, 15, 23, 24, 25)
    if i not in selected_layers:
        return

    num_channels = np.shape(conv_img)[3]
    if num_channels == 0:
        return

    # The output folder is the same for every channel, so prepare it once
    # instead of once per channel.
    conv_dir = os.path.join(id_y_dir, 'layer_' + str(i))
    utils.prepare_dir(conv_dir)

    for j in range(num_channels):
        file_name = os.path.join(conv_dir, '_conv_' + str(j) + '.png')
        gray_img = np.array(conv_img[0, :, :, j])
        v_min = np.min(gray_img)
        v_max = np.max(gray_img)
        value_range = v_max - v_min
        if value_range > 0:
            img = (gray_img - v_min) / value_range * 255
        else:
            # A constant (or NaN-only) map has no contrast; dividing by the
            # zero range would fill the image with NaN, so write black instead.
            img = np.zeros(gray_img.shape, dtype=np.float32)
        cv2.imwrite(file_name, img)
    # create directory if does not exist, otherwise empty it
    '''
def plot_conv_weights(weights, name, channels_all=True):
    """
    Plots convolutional filters

    One figure is saved per plotted input channel as
    ``<PLOT_DIR>/conv_weights/<name>/<name>-<channel>.png``. Every figure
    shows that channel's slice of each filter, drawn on a colour scale shared
    by the whole weight tensor.

    :param weights: numpy array of rank 4, read as ``weights[:, :, channel, filter]``
    :param name: string, name of convolutional layer
    :param channels_all: boolean, optional; plot every channel when True,
                         only channel 0 otherwise
    :return: nothing, plots are saved on the disk
    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'conv_weights', name)

    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)

    # shared colour limits so that all filters are comparable
    w_min = np.min(weights)
    w_max = np.max(weights)

    channels = range(weights.shape[2]) if channels_all else [0]

    # get number of convolutional filters
    num_filters = weights.shape[3]

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; without it
    # a single filter yields a bare Axes object that has no `.flat`.
    fig, axes = plt.subplots(min([grid_r, grid_c]),
                             max([grid_r, grid_c]),
                             squeeze=False)
    try:
        for channel in channels:
            for filter_idx, ax in enumerate(axes.flat):
                # the axes are reused for every channel: drop the previous image
                ax.clear()
                if filter_idx >= num_filters:
                    # spare grid cell with no filter behind it
                    ax.axis('off')
                    continue
                ax.imshow(weights[:, :, channel, filter_idx],
                          vmin=w_min,
                          vmax=w_max,
                          interpolation='nearest',
                          cmap='seismic')
                # remove any labels from the axes
                ax.set_xticks([])
                ax.set_yticks([])
            # save figure
            fig.savefig(os.path.join(plot_dir,
                                     '{}-{}.png'.format(name, channel)),
                        bbox_inches='tight')
    finally:
        # release the figure; pyplot otherwise keeps it alive between calls
        plt.close(fig)
Example #3
0
def plot_masked(masked_img, step, idx, U_value, i):
    """
    Saves one masked image in a per-step, per-sample folder

    The image is written to
    ``<MASKed_dir>/<step>/<idx>/<idx>_<int(U_value*1000)>_<i>.png``.

    :param masked_img: image handed to ``cv2.imwrite`` unchanged
    :param step: string or int, name of the first-level sub-folder
    :param idx: sample identifier, used as second-level folder and file prefix
    :param U_value: number; multiplied by 1000 and truncated to int for the
                    file name
    :param i: value appended as the last component of the file name
    :return: string, directory the image was written to
    """
    # step goes through str() like idx and i, so an integer step is accepted too
    plot_dir = os.path.join(MASKed_dir, str(step), str(idx))

    # empty=False: presumably creates the directory without clearing it —
    # confirm against utils.prepare_dir
    utils.prepare_dir(plot_dir, empty=False)

    base_name = '_'.join([str(idx), str(int(U_value * 1000)), str(i)]) + '.png'
    cv2.imwrite(os.path.join(plot_dir, base_name), masked_img)
    return plot_dir
Example #4
0
def plot_fake_xy(fake_y, fake_x, id_x, id_y, ori_x, ori_y, fake_dir):
    """
    Saves two original images and the two generated images next to them

    Generated images are rescaled with ``(value + 1.0) * 127.5`` and converted
    with ``cv2.COLOR_RGB2BGR`` before writing; the originals are written as
    given. Four files are produced in ``fake_dir``:

    * ``<id_x>_oriX.png``  - ``ori_x``
    * ``<id_x>_fakeX.png`` - the image derived from ``fake_y``
    * ``<id_y>_oriY.png``  - ``ori_y``
    * ``<id_y>_fakeY.png`` - the image derived from ``fake_x``

    :param fake_y: generated image stored under the ``id_x`` prefix
    :param fake_x: generated image stored under the ``id_y`` prefix
    :param id_x: identifier used as file prefix for the X pair
    :param id_y: identifier used as file prefix for the Y pair
    :param ori_x: original X image, written unchanged
    :param ori_y: original Y image, written unchanged
    :param fake_dir: string, output directory
    :return: nothing, images are saved on the disk
    """

    def to_bgr(generated):
        rescaled = (np.array(generated) + 1.0) * 127.5
        return cv2.cvtColor(rescaled, cv2.COLOR_RGB2BGR)

    bgr_from_fake_x = to_bgr(fake_x)
    bgr_from_fake_y = to_bgr(fake_y)

    utils.prepare_dir(fake_dir, empty=False)

    # (file name, image) pairs in the order they are written
    outputs = (
        (str(id_x) + '_oriX.png', ori_x),
        (str(id_x) + '_fakeX.png', bgr_from_fake_y),
        (str(id_y) + '_oriY.png', ori_y),
        (str(id_y) + '_fakeY.png', bgr_from_fake_x),
    )
    for base_name, image in outputs:
        cv2.imwrite(os.path.join(fake_dir, base_name), image)
Example #5
0
def _checkpoints_dir_from_flag(load_model):
    """Returns ``checkpoints/<model>`` for a model given with or without that prefix."""
    prefix = "checkpoints/"
    # str.lstrip(prefix) would treat the prefix as a *set of characters* and
    # keep eating leading letters of the model name (e.g. "cgan" -> "gan"),
    # so the prefix is removed explicitly instead.
    if load_model.startswith(prefix):
        load_model = load_model[len(prefix):]
    return prefix + load_model


def _save_tsne_plots(features, labels, plot_dir, plots_count, sample_size):
    """Saves ``plots_count`` t-SNE scatter plots, each drawn from a fresh random sample."""
    # Indices must be valid for both sequences, and random.sample() raises if
    # asked for more items than the population holds.
    population = min(len(features), len(labels))
    sample_size = min(sample_size, population)
    if sample_size == 0:
        logging.warning('No features/labels available, t-SNE plots skipped')
        return

    utils.prepare_dir(plot_dir)
    for j in range(plots_count):
        sample_idx = random.sample(range(population), sample_size)
        t_features = [features[k] for k in sample_idx]
        t_labels = [labels[k] for k in sample_idx]
        # Reduce the features to 2 dimensions with t-SNE (the original note
        # says they are 100-dimensional — not verifiable from this function).
        tsne = TSNE(n_components=2,
                    learning_rate=100).fit_transform(t_features)
        # set the canvas size
        plt.figure(figsize=(6, 6))
        plt.scatter(tsne[:, 0], tsne[:, 1], c=t_labels)
        # the colour bar shows which label value each colour stands for
        plt.colorbar()
        plt.savefig(os.path.join(plot_dir, 'plot{}.pdf'.format(j)))
        # close the figure so repeated plots do not pile up in memory
        plt.close()


def train():
    """
    Evaluates a restored CycleGAN checkpoint and writes result artefacts

    Despite its name this function runs no optimisation step. It

    1. rebuilds the CycleGAN graph and restores the latest checkpoint from
       ``checkpoints/<FLAGS.load_model>``,
    2. prints classification metrics computed by ``computeAccuracy`` from the
       stored ``Uy`` matrix and the target labels,
    3. saves t-SNE plots of the stored features to ``./result/tsne_pca``,
    4. for ten target images, saves the generated images to
       ``./result/fake_xy`` and the feature maps to ``./result/convs``.

    If ``FLAGS.load_model`` is not set, or the directory holds no checkpoint,
    a message is logged and the function returns without doing anything.

    :return: nothing, results are printed and saved on the disk
    """
    if FLAGS.load_model is None:
        logging.info('No model to test, stopped!')
        return

    checkpoints_dir = _checkpoints_dir_from_flag(FLAGS.load_model)
    print(checkpoints_dir)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             learning_rate2=FLAGS.learning_rate2,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        # model() returns ten values; only the two generator outputs are
        # evaluated below, the rest are named for reference.
        (_g_loss, _d_y_loss, _f_loss, _d_x_loss, fake_y, fake_x,
         _disperse_loss, _fuzzy_loss, _feature_x,
         _feature_y) = cycle_gan.model()

        # Not used further down, but kept so that the graph contents and the
        # summary writer side effect stay exactly as before.
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()
    logging.info('Network Built!')

    with tf.Session(graph=graph) as sess:
        checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
        if checkpoint is None or not checkpoint.model_checkpoint_path:
            logging.error('No checkpoint found in %s, stopped!',
                          checkpoints_dir)
            return
        meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
        restore = tf.train.import_meta_graph(meta_graph_path)
        restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))

        Uy = np.loadtxt(checkpoints_dir + "/Uy" + FLAGS.UC_name + '.txt',
                        delimiter=",")
        logging.info('Parameter Initialized!')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            plots_count = 10
            tsne_plot_count = 1000
            result_dir = './result'
            fake_dir = os.path.join(result_dir, 'fake_xy')
            plot_dir = os.path.join(result_dir, 'tsne_pca')
            conv_dir = os.path.join(result_dir, 'convs')

            # only the labels of the target set are needed here
            _, _, _, y_labels, _, _ = get_target_batch(0,
                                                       256,
                                                       256,
                                                       target_dir=FLAGS.Y)

            # accuracy, tp, tn, fp, fn, f1_score, recall, precision, specificity
            accuracy, tp, tn, fp, fn, f1_score, recall, precision, specificity = computeAccuracy(
                Uy, y_labels)
            print(
                "accuracy:%.4f\ttp:%d\ttn:%d\tfp %d\tfn:%d\tf1_score:%.3f\trecall:%.3f\tprecision:%.3f\tspecicity:%.3f\t"
                % (accuracy, tp, tn, fp, fn, f1_score, recall, precision,
                   specificity))

            # t-SNE plots; the feature file is the same for every plot, so it
            # is loaded once
            feature = np.load(
                os.path.join(checkpoints_dir, 'feature_fcgan.npy'))
            print('feature:', len(feature))
            _save_tsne_plots(feature, y_labels, plot_dir, plots_count,
                             tsne_plot_count)

            # tensors registered under 'conv_output'; fetched once, reused
            # for every image
            conv_outputs = tf.get_collection('conv_output')

            # NOTE(review): x_img_path, y_img_path, x_name and y_name are not
            # assigned anywhere in this function; presumably module-level
            # globals — confirm.
            for i in range(10):
                # cross domain image generation
                x_img, _, x_oimg = get_single_img(x_img_path)
                y_path = os.path.join(y_img_path, str(i + 1) + '.jpg')
                y_img, _, y_oimg = get_single_img(y_path)
                fake_y_eval, fake_x_eval, conv_y_eval = sess.run(
                    [fake_y, fake_x, conv_outputs],
                    feed_dict={
                        cycle_gan.x: x_img,
                        cycle_gan.y: y_img
                    })
                plot_fake_xy(fake_y_eval[0], fake_x_eval[0], x_name,
                             str(i + 1), x_oimg, y_oimg, fake_dir)
                print('processing:', i)

                # feature map visualization
                print('conv_len:', len(conv_y_eval))
                print('conv_shape:', np.shape(conv_y_eval[0]))
                id_y_dir = os.path.join(conv_dir, str(y_name))
                # a separate index name keeps the outer loop variable intact
                for layer_idx, conv_maps in enumerate(conv_y_eval):
                    plot_conv_output(conv_maps, layer_idx, id_y_dir)
                cv2.imwrite(os.path.join(id_y_dir, 'y.png'), y_oimg)

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # hand the error to the coordinator; coord.join() below re-raises it
            coord.request_stop(e)
        finally:
            logging.info("stopped")
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)