Example no. 1
    def load_model(self):
        checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        checkpoint = Path(checkpoint_root) / f'stylegan/stylegan_{self.outclass}_{self.resolution}.pt'
        
        self.model = stylegan.StyleGAN_G(self.resolution).to(self.device)

        urls_tf = {
            'vases': 'https://thisvesseldoesnotexist.s3-us-west-2.amazonaws.com/public/network-snapshot-008980.pkl',
            'fireworks': 'https://mega.nz/#!7uBHnACY!quIW-pjdDa7NqnZOYh1z5UemWwPOW6HkYSoJ4usCg9U',
            'abstract': 'https://mega.nz/#!vCQyHQZT!zdeOg3VvT4922Z2UfxO51xgAfJD-NAK2nW7H_jMlilU',
            'anime': 'https://mega.nz/#!vawjXISI!F7s13yRicxDA3QYqYDL2kjnc2K7Zk3DwCIYETREmBP4',
            'ukiyo-e': 'https://drive.google.com/uc?id=1CHbJlci9NhVFifNQb3vCGu6zw4eqzvTd',
        }

        urls_torch = {
            'celebahq': 'https://drive.google.com/uc?export=download&id=1lGcRwNoXy_uwXkD6sy43aAa-rMHRR7Ad',
            'bedrooms': 'https://drive.google.com/uc?export=download&id=1r0_s83-XK2dKlyY3WjNYsfZ5-fnH8QgI',
            'ffhq': 'https://drive.google.com/uc?export=download&id=1GcxTcLDPYxQqcQjeHpLUutGzwOlXXcks',
            'cars': 'https://drive.google.com/uc?export=download&id=1aaUXHRHjQ9ww91x4mtPZD0w50fsIkXWt',
            'cats': 'https://drive.google.com/uc?export=download&id=1JzA5iiS3qPrztVofQAjbb0N4xKdjOOyV',
            'wikiart': 'https://drive.google.com/uc?export=download&id=1fN3noa7Rsl9slrDXsgZVDsYFxV0O08Vx',
        }

        if not checkpoint.is_file():
            os.makedirs(checkpoint.parent, exist_ok=True)
            if self.outclass in urls_torch:
                download_ckpt(urls_torch[self.outclass], checkpoint)
            else:
                checkpoint_tf = checkpoint.with_suffix('.pkl')
                if not checkpoint_tf.is_file():
                    download_ckpt(urls_tf[self.outclass], checkpoint_tf)
                print('Converting TensorFlow checkpoint to PyTorch')
                self.model.export_from_tf(checkpoint_tf)
        
        self.model.load_weights(checkpoint)
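
Every example below relies on a download_ckpt helper whose definition is not part of these excerpts. A minimal sketch of what a (url, outfile) variant could look like, assuming a plain streaming HTTP download (the Google Drive and mega.nz links above generally need a dedicated client such as gdown rather than a bare GET):

import requests
from pathlib import Path

def download_ckpt(url, outfile, chunk_size=1 << 20):
    # Hypothetical helper: stream the checkpoint to disk in 1 MiB chunks.
    outfile = Path(outfile)
    outfile.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(outfile, 'wb') as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    print('Downloaded checkpoint to', outfile)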
Example no. 2
    def load_model(self):
        checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        checkpoint = Path(checkpoint_root) / f'progan/{self.outclass}_lsun.pth'
        
        if not checkpoint.is_file():
            os.makedirs(checkpoint.parent, exist_ok=True)
            url = f'http://netdissect.csail.mit.edu/data/ganmodel/karras/{self.outclass}_lsun.pth'
            download_ckpt(url, checkpoint)

        self.model = proggan.from_pth_file(str(checkpoint.resolve())).to(self.device)
Example no. 3
def main():
    query_folder = ''
    cfg = Config(query_folder)

    # download checkpoint and check it
    utils.download_ckpt()
    utils.check_ckpt()

    ckpt_path = utils.default_ckpt_path()
    batch_size = 4
    gpu_id = '1'
    gpu_fraction = 0.25
    test_obj = ArtiRetrieval(ckpt_path, batch_size, gpu_id, gpu_fraction)

    # generate the base of descriptors of a given directory
    des_base, base_img_paths = test_obj.generate_directory_descriptors(
        cfg.base_img_dir)

    # save the descriptors to the disk
    test_obj.save_descriptors(des_base, cfg.des_base_path)

    # load base descriptors
    des_base = test_obj.load_descriptors(cfg.des_base_path)

    # inference
    query_descriptors, query_img_paths = test_obj.generate_directory_descriptors(
        cfg.query_img_dir)

    fig_save_dir = os.path.join(os.path.dirname(cfg.query_img_dir), 'res')
    if not os.path.exists(fig_save_dir):
        os.mkdir(fig_save_dir)
    print('the plot will be saved in          ', fig_save_dir)

    # query and visualization
    for idx in range(query_descriptors.shape[0]):
        query_des = query_descriptors[idx]
        query_img_path = query_img_paths[idx]
        retrieved_idxes, retrieved_values = test_obj.get_topk_matches(
            query_des, des_base, topk=1)
        retrieved_img_path = base_img_paths[retrieved_idxes[0]]
        retrieved_value = retrieved_values[0]
        utils.vis_query_best_res(query_img_path,
                                 retrieved_value,
                                 retrieved_img_path,
                                 fig_save_dir,
                                 show_plot=False,
                                 vertical=True)

    print('visualization results are saved in ', fig_save_dir)

    # make a video demo
    utils.img2video(fig_save_dir, fps=10.0, output=fig_save_dir + '.mp4')
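
The query loop above depends on ArtiRetrieval.get_topk_matches, which is not shown in the excerpt. A sketch of one way it could work, assuming the descriptor base is an (N, D) numpy array and similarity is the cosine of L2-normalized vectors (the actual implementation may differ):

import numpy as np

def get_topk_matches(query_des, des_base, topk=1):
    # Hypothetical standalone version of the method used above.
    # Normalize so that a dot product equals cosine similarity.
    q = query_des / (np.linalg.norm(query_des) + 1e-12)
    base = des_base / (np.linalg.norm(des_base, axis=1, keepdims=True) + 1e-12)
    sims = base @ q                         # (N,) similarity to every base image
    top_idx = np.argsort(-sims)[:topk]      # indices of the best matches, best first
    return top_idx, sims[top_idx]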
Example no. 4
    def download_checkpoint(self, outfile):
        checkpoints = {
            'horse': 'https://drive.google.com/uc?export=download&id=18SkqWAkgt0fIwDEf2pqeaenNi4OoCo-0',
            'ffhq': 'https://drive.google.com/uc?export=download&id=1FJRwzAkV-XWbxgTwxEmEACvuqF5DsBiV',
            'church': 'https://drive.google.com/uc?export=download&id=1HFM694112b_im01JT7wop0faftw9ty5g',
            'car': 'https://drive.google.com/uc?export=download&id=1iRoWclWVbDBAy5iXYZrQnKYSbZUqXI6y',
            'cat': 'https://drive.google.com/uc?export=download&id=15vJP8GDr0FlRYpE8gD7CdeEz2mXrQMgN',
            'places': 'https://drive.google.com/uc?export=download&id=1X8-wIH3aYKjgDZt4KMOtQzN1m4AlCVhm',
            'bedrooms': 'https://drive.google.com/uc?export=download&id=1nZTW7mjazs-qPhkmbsOLLA_6qws-eNQu',
            'kitchen': 'https://drive.google.com/uc?export=download&id=15dCpnZ1YLAnETAPB0FGmXwdBclbwMEkZ'
        }

        url = checkpoints[self.outclass]
        download_ckpt(url, outfile)
Example no. 5
def generate_des_for_24dataset():
    query_folder = ''
    cfg = Config(query_folder)

    # download checkpoint and check it
    utils.download_ckpt()
    utils.check_ckpt()

    ckpt_path = utils.default_ckpt_path()
    batch_size = 8
    gpu_id = '1'
    gpu_fraction = 0.75
    arti_obj = ArtiRetrieval(ckpt_path, batch_size, gpu_id, gpu_fraction)

    base_dir = '/data/qing/datasets/24h_Dataset/'
    descriptor_base_dir = '/data/qing/datasets/24h_Dataset/ref_des_base/'

    seqs = [
        '2019-08-05_15-09-41', '2019-08-05_15-21-20', '2019-08-06_11-36-14',
        '2019-08-06_11-50-46'
    ]
    cams = ['cam0', 'cam1']
    img_types = ['kf', 'all']

    for seq in seqs:
        for cam in cams:
            img_dir = '{}/{}/undistorted_images/{}'.format(base_dir, seq, cam)
            des_base_path = '{}/desp_{}_all_{}.npy'.format(
                descriptor_base_dir, seq, cam)

            # generate the base of descriptors of a given directory
            des_base, base_img_paths = arti_obj.generate_directory_descriptors(
                img_dir)

            arti_obj.save_descriptors(des_base, des_base_path)

            kf_img_dir = '{}/{}/undistorted_images/kf/{}'.format(
                base_dir, seq, cam)
            des_base_path = '{}/desp_{}_kf_{}.npy'.format(
                descriptor_base_dir, seq, cam)

            # generate the base of descriptors of a given directory
            des_base, base_img_paths = arti_obj.generate_directory_descriptors(
                kf_img_dir)

            arti_obj.save_descriptors(des_base, des_base_path)

            print('finished ', img_dir)
Example no. 6
    val_summary_merged = tf.summary.merge([loss_summary, acc_summary])
    val_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "summaries", "val"), graph=sess.graph)

    # checkPoint saver
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "ckpt"))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    # Load the pre_trained weights into the non-trainable layer
    if "resnet_v1_50.ckpt" not in os.listdir("./pre_trained_models/"):
        print(" ")
        download_ckpt(url="http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz")

    resnetv1_50.load_initial_weights(sess)
    print("run the tensorboard in terminal: \ntensorboard --logdir={} --port=6006 \n".format(out_dir))

    while True:
        step = 0
        # train loop
        x_batch_train, y_batch_train = sess.run(train_next_batch)
        _, step, train_summaries, loss, accuracy = sess.run([resnetv1_50.train_op, resnetv1_50.global_step, train_summary_merged, resnetv1_50.loss, resnetv1_50.accuracy],
                                                            feed_dict={
                                                                resnetv1_50.x_input: x_batch_train,
                                                                resnetv1_50.y_input: y_batch_train,
                                                                resnetv1_50.learning_rate: FLAGS.learning_rate
                                                            })
        train_summary_writer.add_summary(train_summaries, step)
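
In Examples 6 through 9, download_ckpt is called with only a URL and the checkpoint file is expected to appear under ./pre_trained_models/. A sketch of that variant, assuming it downloads the .tar.gz archive and extracts it in place (the helper's real signature and target directory are not shown in the excerpts):

import os
import tarfile
import urllib.request

def download_ckpt(url, target_dir="./pre_trained_models/"):
    # Hypothetical helper: fetch the archive and unpack it, yielding
    # e.g. resnet_v1_50.ckpt inside target_dir.
    os.makedirs(target_dir, exist_ok=True)
    archive_path = os.path.join(target_dir, os.path.basename(url))
    print("Downloading", url)
    urllib.request.urlretrieve(url, archive_path)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=target_dir)
    os.remove(archive_path)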
Example no. 7
    val_summary_merged = tf.summary.merge([loss_summary, acc_summary])
    val_summary_writer = tf.summary.FileWriter(os.path.join(out_dir, "summaries", "val"), graph=sess.graph)

    # checkPoint saver
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "ckpt"))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    # Load the pre_trained weights into the non-trainable layer
    if "resnet_v2_152.ckpt" not in os.listdir("./pre_trained_models/"):
        print(" ")
        download_ckpt(url="http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz")

    resnetv2_152.load_initial_weights(sess)
    print("run the tensorboard in terminal: \ntensorboard --logdir={} --port=6006 \n".format(out_dir))

    while True:
        step = 0
        # train loop
        x_batch_train, y_batch_train = sess.run(train_next_batch)
        _, step, train_summaries, loss, accuracy = sess.run([resnetv2_152.train_op, resnetv2_152.global_step, train_summary_merged, resnetv2_152.loss, resnetv2_152.accuracy],
                                                            feed_dict={
                                                                resnetv2_152.x_input: x_batch_train,
                                                                resnetv2_152.y_input: y_batch_train,
                                                                resnetv2_152.learning_rate: FLAGS.learning_rate
                                                            })
        train_summary_writer.add_summary(train_summaries, step)
Example no. 8
    # checkPoint saver
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "ckpt"))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    # Load the pre_trained weights into the non-trainable layer
    if "vgg_16.ckpt" not in os.listdir("./pre_trained_models/"):
        print(" ")
        download_ckpt(
            url="http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
        )

    vgg16.load_initial_weights(sess)
    print(
        "run the tensorboard in terminal: \ntensorboard --logdir={} --port=6006 \n"
        .format(out_dir))

    while True:
        step = 0
        # train loop
        x_batch_train, y_batch_train = sess.run(train_next_batch)
        _, step, train_summaries, loss, accuracy = sess.run(
            [
                vgg16.train_op, vgg16.global_step, train_summary_merged,
                vgg16.loss, vgg16.accuracy
            ],
            feed_dict={
                # Assumed feed pattern, following the ResNet examples above.
                vgg16.x_input: x_batch_train,
                vgg16.y_input: y_batch_train,
                vgg16.learning_rate: FLAGS.learning_rate
            })
        train_summary_writer.add_summary(train_summaries, step)
Example no. 9
    # checkPoint saver
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "ckpt"))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    # Load the pre_trained weights into the non-trainable layer
    if "inception_v3.ckpt" not in os.listdir("./pre_trained_models/"):
        print(" ")
        download_ckpt(
            url=
            "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"
        )

    inceptionv3.load_initial_weights(sess)
    print(
        "run the tensorboard in terminal: \ntensorboard --logdir={} --port=6006 \n"
        .format(out_dir))

    while True:
        step = 0
        # train loop
        x_batch_train, y_batch_train = sess.run(train_next_batch)
        _, step, train_summaries, loss, accuracy = sess.run(
            [
                inceptionv3.train_op, inceptionv3.global_step,
                train_summary_merged, inceptionv3.loss, inceptionv3.accuracy
            ],
            feed_dict={
                # Assumed feed pattern, following the ResNet examples above.
                inceptionv3.x_input: x_batch_train,
                inceptionv3.y_input: y_batch_train,
                inceptionv3.learning_rate: FLAGS.learning_rate
            })
        train_summary_writer.add_summary(train_summaries, step)
Example no. 10
    def download_checkpoint(self, outfile):
        checkpoints = conf["checkpoints"]

        url = checkpoints[self.outclass]
        download_ckpt(url, outfile)
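
Example 10 reads the checkpoint URLs from a configuration object instead of hard-coding them. A sketch of what that conf entry might look like, assuming it mirrors the per-class mapping of Example 4 (keys and URLs here are placeholders, not the project's actual config):

conf = {
    "checkpoints": {
        "ffhq": "https://drive.google.com/uc?export=download&id=<file-id>",
        "church": "https://drive.google.com/uc?export=download&id=<file-id>",
        # ... one entry per supported outclass
    }
}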