示例#1
0
def train():
    """Train a CycleGAN model from scratch and checkpoint it periodically.

    A new directory ``checkpoints/<YYYYmmdd-HHMM>`` receives the TensorBoard
    event file and the model checkpoints. The loop runs until the queue-runner
    coordinator is asked to stop, an exception is raised, or the user presses
    Ctrl-C. Losses are printed every 100 steps, a numbered checkpoint is
    written every 1000 steps, and an un-numbered final checkpoint is always
    written on exit.
    """
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN()
        G_loss, D_Y_loss, F_loss, D_X_loss = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())

        # Start the queue-runner threads; `coord` lets the loop and the
        # runner threads signal each other to stop.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            while not coord.should_stop():
                # One training step: run all optimizers and fetch the losses
                # plus the summary for TensorBoard.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run([
                        optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                        cycle_gan.summary
                    ]))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    print('-----------Step %d:-------------' % step)
                    print('  G_loss   : {}'.format(G_loss_val))
                    print('  D_Y_loss   : {}'.format(D_Y_loss_val))
                    print('  F_loss   : {}'.format(F_loss_val))
                    print('  D_X_loss   : {}'.format(D_X_loss_val))

                if step % 1000 == 0:
                    save_path = cycle_gan.saver.save(sess,
                                                     checkpoints_dir +
                                                     "/model.ckpt",
                                                     global_step=step)
                    print("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            print('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report the error to the coordinator; per the tf.train.Coordinator
            # docs, coord.join() re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            save_path = cycle_gan.saver.save(sess,
                                             checkpoints_dir + "/model.ckpt")
            print("Model saved in file: %s" % save_path)
            # Close the event file so pending summaries are written and the
            # file handle is released.
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
示例#2
0
def train():
    """Train (or resume training) a CycleGAN using image-history pools.

    If ``FLAGS.load_model`` is set, training resumes from the latest
    checkpoint in ``checkpoints/<load_model>``. A leading ``checkpoints/`` in
    the flag value is accepted and removed. Otherwise a new timestamped
    directory is created.

    Each step first generates fake images. It then runs the optimizers with
    fakes drawn from two ``ImagePool`` buffers. Losses are logged every 100
    steps, a checkpoint is saved every 5000 steps, and a final numbered
    checkpoint is always saved on exit.

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but no checkpoint is found
            in the resolved directory.
    """
    if FLAGS.load_model is not None:
        # str.lstrip() strips a *set of characters*, not a prefix, so it also
        # ate leading letters of run names (e.g. "horse2zebra" -> "rse2zebra").
        # Remove the literal prefix instead.
        prefix = "checkpoints/"
        model_name = FLAGS.load_model
        if model_name.startswith(prefix):
            model_name = model_name[len(prefix):]
        checkpoints_dir = prefix + model_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        # exist_ok tolerates a directory created within the same minute while
        # still surfacing real failures (e.g. permission errors).
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are saved as ".../model.ckpt-<step>". Take the
            # number after the last hyphen of the file name, so that hyphens
            # in the directory path do not shift the index.
            ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
            step = int(ckpt_name.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train, feeding fakes drawn from the history pools
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 5000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # coord.join() below re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # Flush pending summaries and release the event file.
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train (or resume training) a CycleGAN using image-history pools.

    If ``FLAGS.load_model`` is set, training resumes from the latest
    checkpoint in ``checkpoints/<load_model>``. Otherwise a new timestamped
    directory is created. Summaries and losses are recorded every 100 steps,
    a checkpoint is saved every 10000 steps, and a final numbered checkpoint
    is always saved on exit.

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but no checkpoint is found
            in the resolved directory.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        # exist_ok tolerates an existing directory without hiding real errors.
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # Was FLAGS.lambda1 (copy-paste bug): the lambda2 flag was ignored.
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are ".../model.ckpt-<step>". Parse the number after
            # the last hyphen of the file name, independent of the directory.
            ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
            step = int(ckpt_name.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train, feeding fakes drawn from the history pools
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }
                    )
                )

                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # coord.join() below re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # Flush pending summaries and release the event file.
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train the CycleGAN + teacher/student model and periodically evaluate it.

    Each step draws one labelled batch per domain ("X" and "Y") from
    ``./dataset/`` via ``get_train_batch`` and runs all optimizers. Losses
    are printed every 100 steps.

    From step 100 on, every 100 steps the model is also evaluated one image
    at a time on the "Y" test split. The fraction of images for which
    ``fake_y_correct`` is truthy is printed. A numbered checkpoint is saved
    when training stops.

    If ``FLAGS.load_model`` is set, training resumes from the latest
    checkpoint in ``checkpoints/<load_model>``. A leading ``checkpoints/`` in
    the flag value is accepted and removed. Otherwise a new timestamped
    directory is created.

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but no checkpoint is found
            in the resolved directory.
    """
    if FLAGS.load_model is not None:
        # str.lstrip() strips a set of characters, not a prefix; remove the
        # literal "checkpoints/" prefix instead.
        prefix = "checkpoints/"
        model_name = FLAGS.load_model
        if model_name.startswith(prefix):
            model_name = model_name[len(prefix):]
        checkpoints_dir = prefix + model_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        # No X_train_file / Y_train_file: batches are fed through
        # placeholders from get_train_batch() in the loop below.
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        (G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss,
         learning_loss, x_correct, y_correct, fake_y_correct) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        teacher_loss, student_loss,
                                        learning_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Restore the latest checkpoint of checkpoints_dir (rather than a
            # hard-coded run/step) so that --load_model is honoured.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            print('meta_graph_path', meta_graph_path)
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are ".../model.ckpt-<step>": parse the trailing step.
            ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
            step = int(ckpt_name.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                x_image, x_label = get_train_batch("X", FLAGS.batch_size,
                                                   FLAGS.image_size,
                                                   FLAGS.image_size,
                                                   "./dataset/")
                y_image, y_label = get_train_batch("Y", FLAGS.batch_size,
                                                   FLAGS.image_size,
                                                   FLAGS.image_size,
                                                   "./dataset/")

                # train
                (_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val,
                 teacher_loss_eval, student_loss_eval, learning_loss_eval,
                 summary) = sess.run(
                    [
                        optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                        teacher_loss, student_loss, learning_loss,
                        summary_op
                    ],
                    feed_dict={
                        cycle_gan.x: x_image,
                        cycle_gan.y: y_image,
                        cycle_gan.x_label: x_label,
                        cycle_gan.y_label: y_label
                    })

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    print('-----------Step %d:-------------' % step)
                    print('  G_loss   : {}'.format(G_loss_val))
                    print('  D_Y_loss : {}'.format(D_Y_loss_val))
                    print('  F_loss   : {}'.format(F_loss_val))
                    print('  D_X_loss : {}'.format(D_X_loss_val))
                    print('teacher_loss: {}'.format(teacher_loss_eval))
                    print('student_loss: {}'.format(student_loss_eval))
                    print('learning_loss: {}'.format(learning_loss_eval))

                # Evaluate every 100 steps, skipping step 0.
                if step % 100 == 0 and step > 0:
                    print('Now is in testing! Please wait result...')
                    test_images_y, test_labels_y = get_test_batch(
                        "Y", FLAGS.image_size, FLAGS.image_size, "./dataset/")
                    fake_y_correct_cout = 0
                    # Feed one image per run (batch of size 1).
                    for test_image, test_label in zip(test_images_y,
                                                      test_labels_y):
                        fake_y_correct_eval = sess.run(
                            fake_y_correct,
                            feed_dict={
                                cycle_gan.y: [test_image],
                                cycle_gan.y_label: [test_label]
                            })
                        if fake_y_correct_eval:
                            fake_y_correct_cout += 1

                    num_test = len(test_labels_y)
                    if num_test > 0:
                        print('fake_y_accuracy: {}'.format(
                            fake_y_correct_cout / num_test))
                    else:
                        print('fake_y_accuracy: n/a (empty test set)')

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # coord.join() below re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            print("Model saved in file: %s" % save_path)
            # Flush pending summaries and release the event file.
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def _save_fcm_matrices(directory, uc_name, named_matrices):
    """Save FCM matrices as ``<directory>/<name><uc_name>.txt`` text files.

    Args:
        directory: Target directory (must already exist).
        uc_name: Suffix appended to every file name (``FLAGS.UC_name``).
        named_matrices: Iterable of ``(name, matrix)`` pairs, e.g.
            ``[("Ux", Ux), ("Cx", Cx)]``. Pairs whose matrix is ``None`` are
            skipped, so a partially initialised state can be saved safely.
    """
    for name, matrix in named_matrices:
        if matrix is None:
            continue
        np.savetxt(directory + "/" + name + uc_name + '.txt', matrix,
                   fmt="%.20f", delimiter=",")


def train():
    """Jointly train a CycleGAN and a fuzzy C-means (FCM) model.

    FCM state consists of memberships ``Ux``/``Uy`` and centres ``Cx``/``Cy``.

    - Fresh run: ``checkpoints/<timestamp>`` is created. The FCM state is
      initialised from network features of all source/target images and
      saved as text files.
    - Resumed run (``FLAGS.load_model`` set): the latest checkpoint and the
      saved FCM text files are loaded from ``checkpoints/<load_model>``.

    Every step trains both optimizer groups on one source/target batch. Every
    100 steps the target features are re-extracted and ``Uy``/``Cy`` updated
    via ``fuzzy.updata_U``. Accuracy is computed against the target labels.
    The best state so far (model, features, FCM matrices, step) is saved
    under ``<checkpoints_dir>/max``.

    Training stops at perfect accuracy, after step 10000, on Ctrl-C, or on
    error. The final model and FCM matrices are always saved.

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but no checkpoint is found
            in the resolved directory.
    """
    if FLAGS.load_model is not None:
        # str.lstrip() strips a set of characters, not a prefix; remove the
        # literal "checkpoints/" prefix instead.
        prefix = "checkpoints/"
        model_name = FLAGS.load_model
        if model_name.startswith(prefix):
            model_name = model_name[len(prefix):]
        checkpoints_dir = prefix + model_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            learning_rate2=FLAGS.learning_rate2,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        (G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, Disperse_loss,
         Fuzzy_loss, feature_x, feature_y, _, _) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        Disperse_loss)
        optimizers2 = cycle_gan.optimize2(Fuzzy_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are ".../model.ckpt-<step>": parse the trailing step.
            ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
            step = int(ckpt_name.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 1

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Pre-bind the FCM state so the `finally` block can tell what was
        # actually initialised. Without this, an error during initialisation
        # raised a NameError in `finally`, which masked the real exception
        # and skipped the thread shutdown.
        Ux = Uy = Cx = Cy = None
        try:
            x_path = FLAGS.X + FLAGS.UC_name
            print('now is in FCM initializing!')
            if FLAGS.load_model is None:
                # Batch size 0 presumably means "all images" -- TODO confirm
                # against get_source_batch / get_target_batch.
                x_images, _, x_len, _, _, _ = get_source_batch(
                    0, 256, 256, source_dir=x_path)
                y_images, _, y_len, _, _, _ = get_target_batch(
                    0, 256, 256, target_dir=FLAGS.Y)
                print('x_len', len(x_images))
                print('y_len', len(y_images))
                # Extract features one image at a time.
                x_data = [sess.run(feature_x, feed_dict={cycle_gan.x: [x]})[0]
                          for x in x_images]
                y_data = [sess.run(feature_y, feed_dict={cycle_gan.y: [y]})[0]
                          for y in y_images]
                Ux, Uy, Cx, Cy = fuzzy.initialize_UC_test(
                    x_len, x_data, y_len, y_data, FLAGS.UC_name,
                    checkpoints_dir)
                _save_fcm_matrices(checkpoints_dir, FLAGS.UC_name,
                                   [("Ux", Ux), ("Uy", Uy),
                                    ("Cx", Cx), ("Cy", Cy)])
            else:
                Ux = np.loadtxt(checkpoints_dir + "/Ux" + FLAGS.UC_name + '.txt',
                                delimiter=",")
                # Re-wrap each row/the whole array; presumably np.loadtxt
                # collapses the saved single-column Ux / single-row Cx to
                # 1-D -- TODO confirm the expected shapes.
                Ux = [[u] for u in Ux]
                Uy = np.loadtxt(checkpoints_dir + "/Uy" + FLAGS.UC_name + '.txt',
                                delimiter=",")
                Cx = np.loadtxt(checkpoints_dir + "/Cx" + FLAGS.UC_name + '.txt',
                                delimiter=",")
                Cx = [Cx]
                Cy = np.loadtxt(checkpoints_dir + "/Cy" + FLAGS.UC_name + '.txt',
                                delimiter=",")
            print('FCM initialization is ended! Go to train')

            max_accuracy = 0
            while not coord.should_stop():
                images_x, idx_list, _, _, _, _ = get_source_batch(
                    FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    source_dir=x_path)
                subUx = fuzzy.getSubU(Ux, idx_list)
                label_x = [u[0] for u in subUx]
                images_y, idy_list, _, _, _, _ = get_target_batch(
                    FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    target_dir=FLAGS.Y)
                subUy = fuzzy.getSubU(Uy, idy_list)
                label_y = [u[0] for u in subUy]

                # Optimize networks: run both optimizer groups in one step.
                # The loss values are fetched for optional logging.
                (_, _, Fuzzy_loss_val, G_loss_val, D_Y_loss_val, F_loss_val,
                 D_X_loss_val, summary, Disperse_loss_val) = sess.run(
                    [optimizers, optimizers2, Fuzzy_loss, G_loss, D_Y_loss,
                     F_loss, D_X_loss, summary_op, Disperse_loss],
                    feed_dict={cycle_gan.x: images_x, cycle_gan.y: images_y,
                               cycle_gan.Uy2x: subUy, cycle_gan.Ux2y: subUx,
                               cycle_gan.x_label: label_x,
                               cycle_gan.y_label: label_y,
                               cycle_gan.ClusterX: Cx, cycle_gan.ClusterY: Cy})
                train_writer.add_summary(summary, step)
                train_writer.flush()

                # Optimize the FCM algorithm every 100 steps.
                if step % 100 == 0:
                    print('Now is in FCM training!')
                    y_images, _, _, y_labels, _, _ = get_target_batch(
                        0, 256, 256, target_dir=FLAGS.Y)
                    print('y_len', len(y_images))
                    y_data = [
                        sess.run(feature_y, feed_dict={cycle_gan.y: [y]})[0]
                        for y in y_images
                    ]
                    Uy, Cy = fuzzy.updata_U(checkpoints_dir, y_data, Uy,
                                            FLAGS.UC_name)
                    accuracy, tp, tn, fp, fn, _, _, _, _ = computeAccuracy(
                        Uy, y_labels)

                    print("accuracy:%.4f\ttp:%.4f\ttn:%.4f\tfp %d\tfn:%d" %
                          (accuracy, tp, tn, fp, fn))
                    if accuracy >= max_accuracy:
                        max_accuracy = accuracy
                        max_dir = checkpoints_dir + "/max"
                        os.makedirs(max_dir, exist_ok=True)
                        # Mode 'w' already truncates the file.
                        with open(max_dir + "/step.txt", 'w') as f:
                            f.write(str(step) + '\n')
                            f.write(str(accuracy) + '\taccuracy\n')
                        np.save(max_dir + "/feature_fcgan.npy", y_data)
                        _save_fcm_matrices(max_dir, FLAGS.UC_name,
                                           [("Uy", Uy), ("Cy", Cy),
                                            ("Ux", Ux), ("Cx", Cx)])
                        save_path = saver.save(sess, max_dir + "/model.ckpt",
                                               global_step=step)
                        print("Max model saved in file: %s" % save_path)
                    print('max_accuracy:', max_accuracy)
                    # NOTE(review): labelled 'mean_U' but prints the
                    # column-wise minimum of Uy -- confirm which was intended.
                    print('mean_U', np.min(Uy, 0))
                    # Stop only after the best-model save above, so that a
                    # perfect-accuracy model is also written to /max.
                    if accuracy == 1:
                        break
                step += 1
                if step > 10000:
                    logging.info('train stop!')
                    break

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # coord.join() below re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            _save_fcm_matrices(checkpoints_dir, FLAGS.UC_name,
                               [("Uy", Uy), ("Cy", Cy), ("Ux", Ux), ("Cx", Cx)])
            logging.info("Model saved in file: %s" % save_path)
            # Flush pending summaries and release the event file.
            train_writer.close()
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
示例#6
0
def train():
    """Train (or resume training) a CycleGAN using image-history pools.

    If ``FLAGS.load_model`` is set, training resumes from the latest
    checkpoint in ``checkpoints/<load_model>``. Otherwise a new directory
    named after the current time is created. Summaries and losses are
    recorded every 100 steps, a checkpoint is saved every 10000 steps, and a
    final numbered checkpoint is always saved when training stops.

    Raises:
        ValueError: if ``FLAGS.load_model`` is set but no checkpoint is found
            in the resolved directory.
    """
    if FLAGS.load_model is not None:
        # Resume: the checkpoint directory comes from the command-line flag.
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        # Fresh run: create a checkpoint directory named after the current time.
        # NOTE(review): this format contains spaces ("YYYYmmdd - HHMM"), unlike
        # the "%Y%m%d-%H%M" used by the other variants -- confirm intended.
        current_time = datetime.now().strftime("%Y%m%d - %H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()  # build a dedicated computation graph
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             # Was FLAGS.lambda1 (copy-paste bug).
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        # Generator losses (G, F), discriminator losses (D_Y, D_X) and the
        # generated images fake_y, fake_x.
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        # Optimizers for the four losses.
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        # Merge all summaries defined in the graph into one op; its output is
        # written to TensorBoard through train_writer below.
        summary_op = tf.summary.merge_all()
        # Writes the graph and the event file into checkpoints_dir.
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Resume from the latest checkpoint: graph structure from the
            # .meta file, then the parameter values.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are ".../model.ckpt-<step>". Parse the number after
            # the last hyphen of the file name, independent of the directory
            # name (which here contains its own hyphen and spaces).
            ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
            step = int(ckpt_name.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        # Coordinator for the queue-runner threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # History buffers for generated images. Was `FLASG.pool_size`,
            # a NameError that aborted training immediately.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # Generate fake images with the current generators.
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # Train. The optimizers receive fakes drawn from the history
                # pools via ImagePool.query rather than only the freshly
                # generated ones -- presumably the CycleGAN image-buffer
                # technique; confirm against ImagePool.query.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                # Every 100 steps: write the summary and log the losses.
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                    logging.info('----------step %d:--------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # coord.join() below re-raises exceptions reported this way.
            coord.request_stop(e)
        finally:
            # Always save the model when training ends.
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # Flush pending summaries and release the event file.
            train_writer.close()
            coord.request_stop()
            coord.join(threads)
示例#7
0
def main():
    """Train a CycleGAN on the picF/picG image sets for a fixed step count.

    When ``useCopyfile`` is set, each training archive is fetched from
    ``FLAGS.buckets`` with ``utils.pai_copy`` and unpacked with
    ``utils.Unzip``; the image directories are then read from ``temp/``.
    Losses and summaries are emitted every 25 steps, and a checkpoint
    (without meta graph) is written to ``FLAGS.checkpointDir`` every 1000
    steps. Queue-runner threads are always stopped and the summary writer
    closed, even if a training step raises.
    """

    num_epoch = 40000  # number of training steps (one batch each), not epochs
    pool_size = 20
    batch_size = 1
    oldpath = FLAGS.buckets
    picFpath = 'picF'
    picGpath = 'picG'
    useCopyfile = True

    if useCopyfile:
        trainfiles = ['picf1.zip', 'picf2.zip', 'picg1.zip']
        # More archives can be added: 'picf3.zip', 'picf4.zip', 'picg2.zip'.

        print(trainfiles)

        for f in trainfiles:
            fn = utils.pai_copy(f, oldpath)
            utils.Unzip(fn)

        # NOTE(review): assumes utils.Unzip extracts under ./temp — confirm.
        picFpath = os.path.join('temp', picFpath)
        picGpath = os.path.join('temp', picGpath)

    print(picFpath)
    print(picGpath)

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # Domain X is read from picG, domain Y from picF.
    cycle_gan = CycleGAN(X_train_file=picGpath,
                         Y_train_file=picFpath,
                         batch_size=batch_size,
                         image_size=(270, 480),
                         use_lsgan=True,
                         lossfunc='wgan',
                         norm='instance',
                         learning_rate=3e-3,
                         start_decay_step=5000,
                         decay_steps=350000
                         # optimizer='RMSProp'
                         )

    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.build()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.checkpointDir)
    # max_to_keep=0: no checkpoint files are ever deleted.
    saver = tf.train.Saver(max_to_keep=0)

    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # History buffers of generated images fed to the discriminators.
    fake_Y_pool = ImagePool(pool_size)
    fake_X_pool = ImagePool(pool_size)
    print('start train')
    start_time = time.time()

    try:
        for step in range(1, num_epoch + 1):
            # get previously generated images
            fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

            # train
            _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                sess.run(
                    [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                     summary_op],
                    feed_dict={
                        cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                        cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                    }))

            # Wall time of the current step only (timer is reset each step).
            elapsed_time = time.time() - start_time
            start_time = time.time()

            if step % 25 == 0:
                print('G_loss : %s--D_Y_loss : %s--F_loss : %s--D_X_loss : %s--' %
                      (G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val))

                print('step : %s --elapsed_time : %s' % (step, elapsed_time))
                print('adding summary...')
                train_writer.add_summary(summary, step)
                train_writer.flush()

            if step % 1000 == 0:
                save_path = saver.save(sess,
                                       os.path.join(FLAGS.checkpointDir,
                                                    "model.ckpt"),
                                       global_step=step,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    finally:
        # Previously skipped when a step raised: stop the input threads and
        # flush/close the event file in every case.
        coord.request_stop()
        coord.join(threads)
        train_writer.close()
示例#8
0
def main(unused_argv):
    """Train (or resume) the real->cartoon CycleGAN until the input queues stop.

    If ``checkpoints_dir`` holds a checkpoint, its weights are restored and the
    step counter resumes from the last number in the checkpoint file name;
    otherwise all global variables are freshly initialised. A checkpoint is
    saved every 10000 steps and once more when training ends for any reason.

    Args:
        unused_argv: ignored (presumably supplied by ``tf.app.run``).
    """
    total_step = 0
    checkpoints_dir = './models/real2cartoon'
    summary_dir = './summary'
    # The save target must exist before the first ``saver.save`` call,
    # including the unconditional save in ``finally``.
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=256,
                             use_mse=FLAGS.use_mse,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             filters=FLAGS.filters,
                             beta1=FLAGS.beta1,
                             mse_label=FLAGS.mse_label,
                             file_x=FLAGS.file_x,
                             file_y=FLAGS.file_y)

        G_loss, F_loss, D_X_loss, D_Y_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, F_loss, D_X_loss, D_Y_loss)

        summarys = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summary_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ckpt = tf.train.get_checkpoint_state(checkpoints_dir)

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Resume from the last run of digits in the name, e.g.
            # "model.ckpt-12000" -> 12000. Raw string avoids the invalid
            # "\d" escape; a name without digits restarts at step 0.
            match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
            total_step = int(match.group(0)) if match else 0
            logger.info('load model success' + ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
            logger.info('start new model')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_X_pool = utils.ImagePool(FLAGS.pool_size)
            fake_Y_pool = utils.ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # NOTE(review): pooled *generated* images are fed into
                # cycle_gan.x / cycle_gan.y (the real-image inputs, judging by
                # name), whereas other trainers feed cycle_gan.fake_x/fake_y —
                # confirm this is intended.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summarys
                        ],
                        feed_dict={
                            cycle_gan.x: fake_X_pool.query(fake_x_val),
                            cycle_gan.y: fake_Y_pool.query(fake_y_val)
                        }))

                train_writer.add_summary(summary, total_step)
                train_writer.flush()

                logger.info('step: {}'.format(total_step))
                if total_step > 1e5:
                    # NOTE(review): if learning_rate_decay_op() builds a new op
                    # on each call, the graph grows every step — consider
                    # creating the op once before the loop.
                    sess.run(cycle_gan.learning_rate_decay_op())

                if total_step % 100 == 0:
                    logger.info('-----------Step %d:-------------' %
                                total_step)
                    logger.info('  G_loss   : {}'.format(G_loss_val))
                    logger.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logger.info('  F_loss   : {}'.format(F_loss_val))
                    logger.info('  D_X_loss : {}'.format(D_X_loss_val))
                    # NOTE(review): logs the attribute itself, not a value
                    # evaluated with sess.run — may print a tensor repr.
                    logger.info('  learning_rate : {}'.format(
                        cycle_gan.learning_rate))

                if total_step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=total_step)
                    logger.info("Model saved in file: %s" % save_path)

                total_step += 1
        except KeyboardInterrupt:
            logger.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=total_step)
            logger.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop (join re-raises any error
            # reported to the coordinator).
            coord.request_stop()
            coord.join(threads)
def train():
    """Train a CycleGAN, optionally resuming from ``checkpoints/<FLAGS.load_model>``.

    Losses are logged every step. Every 1000 steps, sample images are written
    with ``ops.save_img_result`` to ``FLAGS.res_im_path`` and a checkpoint is
    saved. Training stops once ``step`` reaches ``FLAGS.epho`` (when set) or
    the input queues stop; a final checkpoint is always written.
    """

    # Resume from an existing run directory when FLAGS.load_model is set;
    # otherwise create a new directory named after the current time.
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        os.makedirs(checkpoints_dir, exist_ok=True)
    # Created on its own: previously an already-existing checkpoint directory
    # aborted the shared try before this ran, and it was never created when
    # resuming, although sample images are written here in both cases.
    os.makedirs(FLAGS.res_im_path, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(FLAGS)

        # Build the graph.
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, real_y, real_x = (
            cycle_gan.model())
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver(max_to_keep=10)

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            checkpoint_path = checkpoint.model_checkpoint_path
            # The model is already built in this graph, so restore into it
            # with its own saver rather than importing a second copy of the
            # meta graph.
            saver.restore(sess, checkpoint_path)
            # Prefixes look like ".../model.ckpt-<step>": take the text after
            # the LAST dash so dashes in the run directory name don't matter.
            step = int(checkpoint_path.rsplit("-", 1)[1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        # Start the input queue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Pre-bound so the final image dump in ``finally`` can tell whether a
        # batch was ever produced (otherwise a NameError masked the real error
        # and skipped coord.join).
        fake_y_val = fake_x_val = real_y_in = real_x_in = None
        try:
            # History pools of generated images fed to the discriminators.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, real_y_in, real_x_in = sess.run(
                    [fake_y, fake_x, real_y, real_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                # Log the current losses (every step).
                logging.info('-----------Step %d:-------------' % step)
                logging.info('  G_loss   : {}'.format(G_loss_val))
                logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                logging.info('  F_loss   : {}'.format(F_loss_val))
                logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 1000 == 0:
                    ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                        real_x_in, FLAGS.res_im_path, step)
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
                # ">=" so a run resumed past the limit still stops; a falsy
                # FLAGS.epho means "no limit", as before.
                if FLAGS.epho and step >= FLAGS.epho:
                    coord.request_stop()  # signal training to stop

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            if fake_y_val is not None:
                ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                    real_x_in, FLAGS.res_im_path, step)
            logging.info("Model saved in file: %s" % save_path)

            coord.request_stop()  # stop the queue threads
            coord.join(threads)
def train():
    """Train the CycleGAN with teacher/student classifiers and keep the best model.

    Each step reads a labelled X and Y batch with ``get_train_batch`` from
    ``./dataset/``. Losses are printed every 500 steps. Every 2000 steps
    (after step 0) the model is evaluated one Y image at a time, a checkpoint
    is saved, and when the ``fake_x_correct`` accuracy exceeds the best value
    so far (initially 0.90) the step and accuracy are written to ``step.txt``
    and a copy of the model is saved under ``bestmodel/``.
    """
    max_accuracy = 0.90  # only accuracies above this are kept as "best"
    if FLAGS.load_model is not None:
        # Accept both "<run>" and "checkpoints/<run>". str.lstrip() strips a
        # *set of characters*, not a prefix, and mangled run names that begin
        # with any of "checkpoints/" letters.
        run_name = FLAGS.load_model
        prefix = "checkpoints/"
        if run_name.startswith(prefix):
            run_name = run_name[len(prefix):]
        checkpoints_dir = prefix + run_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, \
        x_correct, y_correct, fake_x_correct, softmax3, fake_x_pre, f_fakeX, fake_x, fake_y_= cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        teacher_loss, student_loss,
                                        learning_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # NOTE(review): the resume point is hard-coded and ignores
            # FLAGS.load_model / checkpoints_dir (new checkpoints still go to
            # checkpoints_dir). Confirm whether the latest checkpoint in
            # checkpoints_dir should be used instead.
            resume_path = "checkpoints/20190611-1650/model.ckpt-90000"
            print('restoring from', resume_path)
            # The graph is already built; restore into it directly instead of
            # importing a duplicate copy of the meta graph.
            saver.restore(sess, resume_path)
            step = 90000
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():

                x_image, x_label = get_train_batch("X", FLAGS.batch_size,
                                                   FLAGS.image_size,
                                                   FLAGS.image_size,
                                                   "./dataset/")

                y_image, y_label = get_train_batch("Y", FLAGS.batch_size,
                                                   FLAGS.image_size,
                                                   FLAGS.image_size,
                                                   "./dataset/")

                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, teacher_loss_eval, student_loss_eval, learning_loss_eval, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            teacher_loss, student_loss, learning_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.x: x_image,
                            cycle_gan.y: y_image,
                            cycle_gan.x_label: x_label,
                            cycle_gan.y_label: y_label
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 500 == 0:

                    print('-----------Step %d:-------------' % step)
                    print('  G_loss   : {}'.format(G_loss_val))
                    print('  D_Y_loss : {}'.format(D_Y_loss_val))
                    print('  F_loss   : {}'.format(F_loss_val))
                    print('  D_X_loss vb: {}'.format(D_X_loss_val))
                    print('teacher_loss: {}'.format(teacher_loss_eval))
                    print('student_loss: {}'.format(student_loss_eval))
                    print('learning_loss: {}'.format(learning_loss_eval))

                if step % 2000 == 0 and step > 0:
                    print('Now is in testing! Please wait result...')

                    test_images_x, test_labels_x, _ = get_test_batch1(
                        'X', 1000, FLAGS.image_size, FLAGS.image_size,
                        "./dataset/")
                    test_images_y, test_labels_y = get_roc_batch(
                        FLAGS.image_size, FLAGS.image_size, "./dataset/Y")
                    print(len(test_images_y))
                    print(len(test_images_x))
                    # Only Y images are fed; the X set just caps how many
                    # samples are evaluated.
                    num_test = min(len(test_images_y), len(test_images_x))

                    y_correct_cout = 0
                    fake_x_correct_cout = 0
                    for i in range(num_test):
                        # Evaluate one image at a time (batch of size 1).
                        y_correct_eval, fake_x_correct_eval = sess.run(
                            [y_correct, fake_x_correct],
                            feed_dict={
                                cycle_gan.y: [test_images_y[i]],
                                cycle_gan.y_label: [test_labels_y[i]]
                            })
                        if y_correct_eval:
                            y_correct_cout += 1
                        if fake_x_correct_eval:
                            fake_x_correct_cout += 1

                    print('fake_x_correct_cout', fake_x_correct_cout)
                    # An empty test set used to raise ZeroDivisionError, which
                    # was caught below and silently ended training.
                    if num_test:
                        x_accuracy = y_correct_cout / num_test
                        fake_x_accuracy = fake_x_correct_cout / num_test
                    else:
                        x_accuracy = fake_x_accuracy = 0.0
                    print('x_accuracy: {}'.format(x_accuracy))
                    print('fake_x_accuracy: {}'.format(fake_x_accuracy))

                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    print("Model saved in file: %s" % save_path)

                    if fake_x_accuracy > max_accuracy:
                        max_accuracy = fake_x_accuracy
                        best_dir = checkpoints_dir + "/bestmodel"
                        # Also creates checkpoints_dir if it is missing.
                        os.makedirs(best_dir, exist_ok=True)
                        # Record which step produced the best model ('w'
                        # truncates any previous record).
                        with open(checkpoints_dir + "/step.txt", 'w') as f:
                            f.write(str(step) + '\n')
                            f.write(str(fake_x_accuracy) + '\n')
                        save_path = saver.save(sess,
                                               best_dir + "/model.ckpt",
                                               global_step=step)
                        print("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            print("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)