Beispiel #1
0
def setup_snapshot_3d(G, training_set):
    """Prepare one real sample plus a single latent/label pair for snapshots.

    The label drawn from the dataset is discarded; an all-zero label row of
    width ``training_set.label_size`` (dtype ``training_set.label_dtype``) is
    returned in its place.

    Returns:
        Tuple ``(reals, labels, latents)`` where ``reals`` is the one-item
        minibatch from ``training_set.get_minibatch_np(1)`` and ``latents``
        comes from ``misc.random_latents(1, G)``.
    """
    count = 1
    reals, _dataset_label = training_set.get_minibatch_np(1)
    latents = misc.random_latents(count, G)
    label_shape = [count, training_set.label_size]
    labels = np.zeros(label_shape, dtype=training_set.label_dtype)
    return reals, labels, latents
def setup_snapshot_image_grid(
    G,
    training_set,
    size='6by8',
    layout='random'
):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Collect a grid of real samples, their labels, and fresh latents.

    Args:
        G: generator network, handed to ``misc.random_latents``.
        training_set: dataset whose ``get_minibatch_np(1)`` returns a 4-tuple;
            only the first two items (images, labels) are used here.
        size: ``'6by8'`` gives an 8-wide, 6-tall grid; anything else is 1x1.
        layout: ``'random'`` keeps samples in draw order; ``'row_per_class'``
            (when the dataset has labels) redraws until the sample placed in
            row ``y`` has a non-zero label entry at ``y % label_size``.

    Returns:
        ``((gw, gh), reals, labels, latents)``.
    """
    grid_w, grid_h = (8, 6) if size == '6by8' else (1, 1)
    cell_count = grid_w * grid_h

    reals = np.zeros([cell_count] + training_set.shape,
                     dtype=training_set.dtype)
    labels = np.zeros([cell_count, training_set.label_size],
                      dtype=training_set.label_dtype)
    filter_by_class = (layout == 'row_per_class'
                       and training_set.label_size > 0)

    for cell in range(cell_count):
        real, label, _, _ = training_set.get_minibatch_np(1)
        if filter_by_class:
            wanted_column = (cell // grid_w) % training_set.label_size
            # Keep drawing until this row's class is present in the label.
            while label[0, wanted_column] == 0.0:
                real, label, _, _ = training_set.get_minibatch_np(1)
        reals[cell] = real[0]
        labels[cell] = label[0]

    latents = misc.random_latents(cell_count, G)
    return (grid_w, grid_h), reals, labels, latents
def generate(network_pkl, out_dir):
    """Sample 1000 images from a pickled generator and write them to ``out_dir``.

    Each generated image is clipped to [-1, 1], mapped to [0, 1], converted to
    uint8 HWC and written as ``<i>.png`` (first three channels only). When an
    image has more than three channels, the full uint8 array is also saved as
    ``<i>.npy``.

    Raises:
        ValueError: if ``out_dir`` already exists.
    """
    if os.path.exists(out_dir):
        raise ValueError('{} already exists'.format(out_dir))
    misc.init_output_logging()
    np.random.seed(config.random_seed)
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        G, _, Gs = misc.load_pkl(network_pkl)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    image_count = 1000
    zero_labels = np.zeros([image_count, training_set.label_size],
                           dtype=training_set.label_dtype)
    latents = misc.random_latents(image_count, G)
    # Only the schedule's minibatch size is needed, split across the GPUs.
    schedule = train.TrainingSchedule(config.train.total_kimg * 1000,
                                      training_set, **config.sched)
    fakes = Gs.run(latents,
                   zero_labels,
                   minibatch_size=schedule.minibatch // config.num_gpus)

    os.makedirs(out_dir)
    for index, fake in enumerate(fakes):
        hwc = fake.transpose((1, 2, 0))
        unit_range = (1 + np.clip(hwc, -1, 1)) / 2
        pixels = skimage.img_as_ubyte(unit_range)
        png_path = os.path.join(out_dir, '{}.png'.format(index))
        imageio.imwrite(png_path, pixels[..., :3])
        if pixels.shape[-1] > 3:
            np.save(os.path.join(out_dir, '{}.npy'.format(index)), pixels)
Beispiel #4
0
def generate_fake_images(model_path,
                         out_dir,
                         num_samples,
                         random_seed=1000,
                         image_shrink=1,
                         minibatch_size=32):
    """Generate ``num_samples`` unconditional images and persist them.

    Writes into the directory returned by ``misc.make_dir(out_dir)``:
      * ``samples.png``: a 10x10 grid of the first 100 images;
      * ``generated.npz``: the latents (``noise``) and the images as NHWC
        float32 scaled to [0, 1] (``img_r01``).
    """
    rng = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % model_path)
    _, _, Gs = misc.load_network_pkl(model_path)

    latents = misc.random_latents(num_samples, Gs, rng)
    empty_labels = np.zeros([latents.shape[0], 0], np.float32)
    images = Gs.run(latents,
                    empty_labels,
                    minibatch_size=minibatch_size,
                    num_gpus=config.num_gpus,
                    out_mul=127.5,
                    out_add=127.5,
                    out_shrink=image_shrink,
                    out_dtype=np.uint8)

    save_dir = misc.make_dir(out_dir)
    grid_path = os.path.join(save_dir, 'samples.png')
    misc.save_image_grid(images[:100], grid_path, [0, 255], [10, 10])

    # uint8 NCHW -> float32 NHWC in [0, 1].
    img_r01 = np.transpose(images.astype(np.float32) / 255., (0, 2, 3, 1))
    np.savez_compressed(os.path.join(save_dir, 'generated.npz'),
                        noise=latents,
                        img_r01=img_r01)
def setup_snapshot_image_grid(G,
                              training_set,
                              size='1080p',
                              layout='random'):
    """Pick a snapshot grid size and fill it with real samples and latents.

    Args:
        G: generator; ``G.output_shape[3]`` and ``G.output_shape[2]``
            (presumably image width / height) size the '1080p' and '4k'
            grids, and ``G`` is handed to ``misc.random_latents``.
        training_set: dataset whose ``get_minibatch_np(1)`` returns
            ``(images, labels)``.
        size: '1080p' = to be viewed on 1080p display, '4k' = to be viewed on
            4k display; any other value yields a 1x1 grid.
        layout: 'random' = grid contents are selected randomly,
            'row_per_class' = each row corresponds to one class label.

    Returns:
        ``((gw, gh), reals, labels, latents)``.
    """
    # size -> (display width, display height, min columns, min rows)
    display_presets = {
        '1080p': (1920, 1080, 3, 2),
        '4k': (3840, 2160, 7, 4),
    }
    preset = display_presets.get(size)
    if preset is None:
        gw, gh = 1, 1
    else:
        disp_w, disp_h, min_cols, min_rows = preset
        gw = np.clip(disp_w // G.output_shape[3], min_cols, 32)
        gh = np.clip(disp_h // G.output_shape[2], min_rows, 32)

    num_cells = gw * gh
    reals = np.zeros([num_cells] + training_set.shape,
                     dtype=training_set.dtype)
    labels = np.zeros([num_cells, training_set.label_size],
                      dtype=training_set.label_dtype)
    per_class = layout == 'row_per_class' and training_set.label_size > 0

    for cell in range(num_cells):
        real, label = training_set.get_minibatch_np(1)
        if per_class:
            wanted = (cell // gw) % training_set.label_size
            # Redraw until the row's class appears in the sample's label.
            while label[0, wanted] == 0.0:
                real, label = training_set.get_minibatch_np(1)
        reals[cell] = real[0]
        labels[cell] = label[0]

    latents = misc.random_latents(num_cells, G)
    return (gw, gh), reals, labels, latents
Beispiel #6
0
def setup_snapshot_image_grid(G,
                              training_set,
                              size='1080p',
                              layout='random'):
    """Build a snapshot grid of real samples, labels, latents and masks.

    Args:
        G: generator; ``G.output_shape[3]`` and ``G.output_shape[2]``
            (presumably image width / height) size the '1080p' and '4k'
            grids, and ``G`` is handed to ``misc.random_latents``.
        training_set: dataset whose ``get_minibatch_np(1)`` returns
            ``(images, labels, masks)``.
        size: '1080p' = to be viewed on 1080p display, '4k' = to be viewed on
            4k display; any other value yields a 1x1 grid.
        layout: 'random' = grid contents are selected randomly,
            'row_per_class' = each row corresponds to one class label.

    Returns:
        ``((gw, gh), reals, labels, latents, masks)`` where ``masks`` has shape
        ``[gw * gh, 1, S, S]`` with ``S = training_set.shape[-1]``.
    """
    # size -> (display width, display height, min columns, min rows)
    display_presets = {
        '1080p': (1920, 1080, 3, 2),
        '4k': (3840, 2160, 7, 4),
    }
    preset = display_presets.get(size)
    if preset is None:
        gw, gh = 1, 1
    else:
        disp_w, disp_h, min_cols, min_rows = preset
        gw = np.clip(disp_w // G.output_shape[3], min_cols, 32)
        gh = np.clip(disp_h // G.output_shape[2], min_rows, 32)

    num_cells = gw * gh
    side = training_set.shape[-1]
    reals = np.zeros([num_cells] + training_set.shape,
                     dtype=training_set.dtype)
    labels = np.zeros([num_cells, training_set.label_size],
                      dtype=training_set.label_dtype)
    masks = np.zeros([num_cells, 1, side, side], dtype=training_set.dtype)
    per_class = layout == 'row_per_class' and training_set.label_size > 0

    for cell in range(num_cells):
        real, label, mask = training_set.get_minibatch_np(1)
        if per_class:
            wanted = (cell // gw) % training_set.label_size
            # Redraw until the row's class appears in the sample's label.
            while label[0, wanted] == 0.0:
                real, label, mask = training_set.get_minibatch_np(1)
        reals[cell] = real[0]
        labels[cell] = label[0]
        masks[cell] = mask[0]

    latents = misc.random_latents(num_cells, G)
    return (gw, gh), reals, labels, latents, masks
def generate_fake_images(pkl_path,
                         out_dir,
                         num_pngs,
                         image_shrink=1,
                         random_seed=1000,
                         minibatch_size=1):
    """Render ``num_pngs`` samples to ``out_dir/ProGAN_%08d.png``.

    All images are generated up front with ``Gs.run`` using a minibatch of
    ``config.num_gpus * 256``; ``minibatch_size`` is accepted but not used.
    PNGs that already exist are not overwritten.
    """
    rng = np.random.RandomState(random_seed)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    print('Loading network...')
    _, _, Gs = misc.load_network_pkl(pkl_path)

    latents = misc.random_latents(num_pngs, Gs, random_state=rng)
    empty_labels = np.zeros([latents.shape[0], 0], np.float32)
    images = Gs.run(latents,
                    empty_labels,
                    minibatch_size=config.num_gpus * 256,
                    num_gpus=config.num_gpus,
                    out_mul=127.5,
                    out_add=127.5,
                    out_shrink=image_shrink,
                    out_dtype=np.uint8)

    for png_idx in range(num_pngs):
        print('Generating png to %s: %d / %d...' %
              (out_dir, png_idx, num_pngs),
              end='\r')
        png_path = os.path.join(out_dir, 'ProGAN_%08d.png' % png_idx)
        if os.path.exists(png_path):
            continue
        misc.save_image_grid(images[png_idx:png_idx + 1], png_path, [0, 255],
                             [1, 1])
    print()
Beispiel #8
0
def hello():
    """Smoke-test endpoint: load the generator, run one sample, return a banner.

    Loads ``resume_network_pkl`` on every call and runs ``Gs`` once on a
    random texture latent, a random mask and a random colour; the generated
    images are discarded.
    """
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        _, _, Gs = misc.load_pkl(resume_network_pkl)

    textures = misc.random_latents(1, Gs)
    shapes = get_random_mask(1)
    colors = get_random_color(1)
    Gs.run(textures, colors, shapes)

    return "DCGAN endpoint -> /predict "
def generate_fake_images_all(run_id,
                             out_dir,
                             num_pngs,
                             image_shrink=1,
                             random_seed=1000,
                             minibatch_size=1,
                             num_pkls=50):
    """Generate PNG samples from up to ``num_pkls`` snapshots of a training run.

    For each ``network-snapshot-<kimg>.pkl`` of run ``run_id`` (final pickle
    excluded), ``num_pngs`` images are written to
    ``<out_dir>/<run_id>/<snapshot name up to the first '.'>/ProGAN_%08d.png``.
    PNGs that already exist in that snapshot directory are not overwritten.

    A single ``RandomState`` seeded with ``random_seed`` is shared across
    snapshots, so each snapshot is sampled with a different set of latents.
    ``minibatch_size`` is accepted but not used; ``Gs.run`` uses
    ``config.num_gpus * 32``.
    """
    random_state = np.random.RandomState(random_seed)
    out_dir = os.path.join(out_dir, str(run_id))

    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1

    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in snapshot_pkls[:num_pkls]:
        snapshot_name = os.path.basename(snapshot_pkl)
        tmp_dir = os.path.join(out_dir, snapshot_name.split('.')[0])
        if not os.path.isdir(tmp_dir):
            os.makedirs(tmp_dir)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(
            postfix)
        # Parsed only to validate that the name embeds an integer kimg count.
        int(snapshot_name[len(prefix):-len(postfix)])

        print('Loading network...')
        G, D, Gs = misc.load_network_pkl(snapshot_pkl)

        latents = misc.random_latents(num_pngs, Gs, random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=config.num_gpus * 32,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        for png_idx in range(num_pngs):
            print('Generating png to %s: %d / %d...' %
                  (tmp_dir, png_idx, num_pngs),
                  end='\r')
            # Check the same path the image is written to (previously the
            # check looked in out_dir, so existing files were never skipped).
            png_path = os.path.join(tmp_dir, 'ProGAN_%08d.png' % png_idx)
            if os.path.exists(png_path):
                continue
            misc.save_image_grid(images[png_idx:png_idx + 1], png_path,
                                 [0, 255], [1, 1])
    print()
Beispiel #10
0
def predict():
    """Flask handler: render one image from an uploaded mask, colour and texture.

    Reads from the current request:
        files['image']: mask image; converted to grayscale, resized to
            ``imsize x imsize`` (``imsize = Gs.output_shape[-1]``), scaled to
            [-1, 1] and reshaped to ``(1, 1, imsize, imsize)``.
        form['R'], form['G'], form['B']: colour components, parsed as floats.
        form['texflag']: ``"true"`` samples a new texture latent and appends
            it to the module-level ``texture_list``; otherwise
            ``form['texid']`` selects an existing entry of ``texture_list``.

    Returns:
        JSON ``{"image": <base64-encoded PNG as str>, "id": <texture index>}``.

    Raises:
        ValueError / IndexError (from ``int`` / list indexing) on a malformed
        or out-of-range ``texid``; these are not handled here.
    """
    # NOTE(review): the network is re-loaded on every request, which is slow;
    # consider loading it once at start-up.
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        _, _, Gs = misc.load_pkl(resume_network_pkl)
    imsize = Gs.output_shape[-1]

    mask = Image.open(request.files['image']).convert('L')
    mask = mask.resize((imsize, imsize))
    mask = (np.float32(mask) - 127.5) / 127.5
    masks = np.vstack([mask.reshape((1, 1, imsize, imsize))])

    colors = np.array([[
        float(request.form['R']),
        float(request.form['G']),
        float(request.form['B']),
    ]], dtype=object)

    if request.form['texflag'] == "true":
        selected_textures = misc.random_latents(1, Gs)
        texture_list.append(selected_textures[0])
        texid = len(texture_list) - 1
    else:
        texid = int(request.form['texid'])
        selected_textures = np.array([texture_list[texid]], dtype=object)

    fake_images = convert_to_image(Gs.run(selected_textures, colors, masks))

    # Render through an in-memory buffer instead of a shared 'localtemp.png'
    # on disk: concurrent requests would otherwise overwrite each other's file.
    rendered = io.BytesIO()
    matplotlib.image.imsave(rendered, fake_images[0], format='png')
    rendered.seek(0)

    buffered = io.BytesIO()
    Image.open(rendered).save(buffered, format="PNG")
    # base64 output is pure ASCII; decoding yields the same text that
    # str(bytes)[2:-1] used to produce.
    img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
    return jsonify({
        "image": img_str,
        "id": texid
    })
def setup_snapshot_image_grid(
    G,
    training_set,
    size='6by8',  # '6by8'=6row and 8 column, '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
    layout='random'
):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Build a snapshot grid of reals, labels, well facies, latents and prob images.

    Args:
        G: generator; for '1080p' / '4k', ``G.output_shape[3]`` and
            ``G.output_shape[2]`` (presumably width / height) size the grid,
            and ``G`` is handed to ``misc.random_latents``.
        training_set: dataset whose ``get_minibatch_np(1)`` returns
            ``(real, label, probimage, wellfacies)``.
        size: '6by8' (8 columns x 6 rows, default), '1080p' or '4k';
            anything else yields a 1x1 grid.
        layout: 'random' or 'row_per_class' (row ``y`` must have a non-zero
            label entry at ``y % label_size``).

    Returns:
        ``((gw, gh), reals, labels, wellfacies, latents, probimages)``; the
        well-facies and prob-image arrays are float32 with the dataset shape.
    """
    # size -> (display width, display height, min columns, min rows)
    display_presets = {
        '1080p': (1920, 1080, 3, 2),
        '4k': (3840, 2160, 7, 4),
    }
    if size == '6by8':
        gw, gh = 8, 6
    elif size in display_presets:
        disp_w, disp_h, min_cols, min_rows = display_presets[size]
        gw = np.clip(disp_w // G.output_shape[3], min_cols, 32)
        gh = np.clip(disp_h // G.output_shape[2], min_rows, 32)
    else:
        gw, gh = 1, 1

    num_cells = gw * gh
    reals = np.zeros([num_cells] + training_set.shape,
                     dtype=training_set.dtype)
    labels = np.zeros([num_cells, training_set.label_size],
                      dtype=training_set.label_dtype)
    wellfacies = np.zeros([num_cells] + training_set.shape, dtype=np.float32)
    probimages = np.zeros([num_cells] + training_set.shape, dtype=np.float32)
    per_class = layout == 'row_per_class' and training_set.label_size > 0

    for cell in range(num_cells):
        real, label, probimage, wellface = training_set.get_minibatch_np(1)
        if per_class:
            wanted = (cell // gw) % training_set.label_size
            # Redraw until the row's class appears in the sample's label.
            while label[0, wanted] == 0.0:
                real, label, probimage, wellface = \
                    training_set.get_minibatch_np(1)
        reals[cell] = real[0]
        labels[cell] = label[0]
        wellfacies[cell] = wellface[0]
        probimages[cell] = probimage[0]

    latents = misc.random_latents(num_cells, G)
    return (gw, gh), reals, labels, wellfacies, latents, probimages
Beispiel #12
0
def f_z_generator(CNN_INPUT_SIZE):
    """Yield an endless stream of ``(z, x, f)`` triples.

    ``z`` is a single-sample latent batch from ``misc.random_latents(1, G)``;
    ``x`` is the generator's uint8 output for ``z`` squeezed, transposed to
    HWC and channel-swapped with ``cv2.COLOR_BGR2RGB``; ``f`` is the result of
    ``calculate_facial_features`` reshaped to a single row. Samples in which
    ``MarkDetector`` finds no face box are skipped.
    """
    # Official CelebA-HQ network pickle. NOTE: pickle.load can execute code
    # stored in the file; only point this at trusted model files.
    with open('models/pg_gan/karras2018iclr-celebahq-1024x1024.pkl',
              'rb') as pkl_file:
        G, _, _ = pickle.load(pkl_file)
    mark_detector = MarkDetector()

    age_net = cv2.dnn.readNetFromCaffe('models/race_age/deploy_age.prototxt',
                                       'models/race_age/age_net.caffemodel')
    gender_net = cv2.dnn.readNetFromCaffe(
        'models/race_age/deploy_gender.prototxt',
        'models/race_age/gender_net.caffemodel')

    while True:
        z = misc.random_latents(1, G)
        no_labels = np.zeros([z.shape[0], 0], np.float32)
        raw = G.run(z,
                    no_labels,
                    out_mul=127.5,
                    out_add=127.5,
                    out_dtype=np.uint8)
        hwc = np.transpose(np.squeeze(raw), (1, 2, 0))
        image = cv2.cvtColor(hwc, cv2.COLOR_BGR2RGB)

        facebox = mark_detector.extract_cnn_facebox(image)
        if facebox is not None:
            features = calculate_facial_features(image, facebox,
                                                 CNN_INPUT_SIZE,
                                                 mark_detector, age_net,
                                                 gender_net)
            yield (z, image, features.reshape(1, -1))
def setup_snapshot_image_grid(G,
                              training_set,
                              size='6by8'):  # '6by8'=6row and 8 column.
    """Fill a fixed 8-column x 6-row grid with real samples and fresh latents.

    ``size`` is accepted but ignored: the grid is always 8x6. Samples are
    taken in draw order from ``training_set.get_minibatch_np(1)``, which must
    return ``(images, labels)``.

    Returns:
        ``((8, 6), reals, labels, latents)``.
    """
    gw, gh = 8, 6
    count = gw * gh

    reals = np.zeros([count] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([count, training_set.label_size],
                      dtype=training_set.label_dtype)
    for slot in range(count):
        real, label = training_set.get_minibatch_np(1)
        reals[slot] = real[0]
        labels[slot] = label[0]

    latents = misc.random_latents(count, G)
    return (gw, gh), reals, labels, latents
# Ensure every output directory used by this script exists (no error if it
# already does). `args` is parsed outside the lines shown here.
os.makedirs(args.data_dir, exist_ok=True)
os.makedirs(args.mask_dir, exist_ok=True)
os.makedirs(args.generated_images_dir, exist_ok=True)
os.makedirs(args.dlatent_dir, exist_ok=True)
os.makedirs(args.dlabel_dir, exist_ok=True)

# Initialize generator and perceptual model

# load network
# NOTE(review): network_pkl is only used for the log message below; the
# pickle itself is loaded via args.results_dir. Presumably both resolve to
# the same file — confirm.
network_pkl = misc.locate_network_pkl(args.results_dir)
print('Loading network from "%s"...' % network_pkl)
G, D, Gs = misc.load_network_pkl(args.results_dir, None)

# initiate random input
# Fixed seed (800) so the starting latent is reproducible across runs.
# NOTE(review): `latents` and `labels` are not referenced in the lines shown
# here; presumably used further down the original script.
latents = misc.random_latents(1, Gs, random_state=np.random.RandomState(800))
# NOTE(review): these labels use args.labels_size, while the Generator below
# is built with a hard-coded labels_size=572 — confirm the two agree.
labels = np.random.rand(1, args.labels_size)

generator = Generator(Gs,
                      labels_size=572,
                      batch_size=1,
                      clipping_threshold=args.clipping_threshold,
                      model_res=args.resolution)

# The LPIPS perceptual model is loaded only when its loss weight is
# effectively positive (the tiny epsilon presumably treats ~0 weights as off).
# NOTE: pickle.load can execute code from the file; only load trusted models.
perc_model = None
if (args.use_lpips_loss > 0.00000001):
    with open(args.load_perc_model, "rb") as f:
        perc_model = pickle.load(f)

ff_model = None
beautyrater_model = beautyrater.BeautyRater(args.load_vgg_beauty_rater_model)
Beispiel #15
0
def _print_result_row(metric_objs, results, time_begin):
    """Print elapsed time since ``time_begin`` and every metric value as one table row."""
    print('%-12s' % misc.format_time(time.time() - time_begin), end='')
    for obj, vals in zip(metric_objs, results):
        for val, fmt in zip(vals, obj.get_metric_formatting()):
            print(fmt % val, end='')
    print()


def evaluate_metrics(run_id,
                     log,
                     metrics,
                     num_images,
                     real_passes,
                     minibatch_size=None):
    """Evaluate image-quality metrics over every network snapshot of a run.

    Args:
        run_id: training run, resolved with ``misc.locate_result_subdir``.
        log: file name (inside the run's result directory) that output is
            logged to via ``misc.set_output_log_file``.
        metrics: iterable of metric names; 'swd', 'fid', 'is' and 'msssim'
            map to built-in classes, any other value is treated as an import
            path for ``tfutil.import_obj``.
        num_images: number of images fed to the metrics in each pass.
        real_passes: how many dataset passes to run first (0, 1 or 2): the
            first feeds 'reals' mode, the second feeds dataset images in
            'fakes' mode (a reals-vs-reals baseline).
        minibatch_size: defaults to ``clip(8192 // dataset_obj.shape[1], 4, 256)``.

    Prints a table with one row per pass/snapshot; snapshots are evaluated in
    the reverse of the order returned by ``misc.list_network_pkls``.
    """
    metric_class_names = {
        'swd': 'metrics.sliced_wasserstein.API',
        'fid': 'metrics.frechet_inception_distance.API',
        'is': 'metrics.inception_score.API',
        'msssim': 'metrics.ms_ssim.API',
    }

    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)
    if minibatch_size is None:
        minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256)

    # Initialize metrics; each is warmed up on 10 minibatches of random noise.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(num_images=num_images,
                        image_shape=image_shape,
                        image_dtype=np.uint8,
                        minibatch_size=minibatch_size)
        tfutil.init_uninited_vars()
        mode = 'warmup'
        obj.begin(mode)
        for _ in range(10):
            obj.feed(
                mode,
                np.random.randint(0,
                                  256,
                                  size=[minibatch_size] + image_shape,
                                  dtype=np.uint8))
        obj.end(mode)
        metric_objs.append(obj)

    # Print table header.
    print()
    print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='')
    for obj in metric_objs:
        for name, fmt in zip(obj.get_metric_names(),
                             obj.get_metric_formatting()):
            print('%-*s' % (len(fmt % 0), name), end='')
    print()
    print('%-10s%-12s' % ('---', '---'), end='')
    for obj in metric_objs:
        for fmt in obj.get_metric_formatting():
            print('%-*s' % (len(fmt % 0), '---'), end='')
    print()

    # Labels fed to the generator for the fake passes. Each real pass below
    # replaces them with labels drawn from the dataset; binding them here
    # keeps real_passes=0 from failing with a NameError in the snapshot loop.
    labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32)

    # Feed in reals.
    for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]:
        print('%-10s' % title, end='')
        time_begin = time.time()
        labels = np.zeros([num_images, dataset_obj.label_size],
                          dtype=np.float32)
        for obj in metric_objs:
            obj.begin(mode)
        for begin in range(0, num_images, minibatch_size):
            end = min(begin + minibatch_size, num_images)
            images, labels[begin:end] = dataset_obj.get_minibatch_np(end -
                                                                     begin)
            if mirror_augment:
                images = misc.apply_mirror_augment(images)
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
            for obj in metric_objs:
                obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        _print_result_row(metric_objs, results, time_begin)

    # Evaluate each network snapshot.
    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in reversed(snapshot_pkls):
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(
            postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        for obj in metric_objs:
            obj.begin(mode)
        time_begin = time.time()
        # Fresh graph + session per snapshot so loaded networks do not pile up.
        with tf.Graph().as_default(), tfutil.create_session(
                config.tf_config).as_default():
            G, D, Gs = misc.load_pkl(snapshot_pkl)
            for begin in range(0, num_images, minibatch_size):
                end = min(begin + minibatch_size, num_images)
                latents = misc.random_latents(end - begin, Gs)
                images = Gs.run(latents,
                                labels[begin:end],
                                num_gpus=config.num_gpus,
                                out_mul=127.5,
                                out_add=127.5,
                                out_dtype=np.uint8)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
                for obj in metric_objs:
                    obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        _print_result_row(metric_objs, results, time_begin)
    print()
    def __init__(self,
                 model,
                 labels_size=572,
                 batch_size=1,
                 clipping_threshold=1,
                 model_res=128):
        """Build learnable latent/label variables and the generator graph on them.

        Args:
            model: generator network. ``get_output_for`` is called on the
                learnable variables, and tensors named after its
                ``input_templates`` / ``output_templates`` are looked up in
                the default graph.
            labels_size: width of the learnable label variable.
            batch_size: number of latent/label rows in the variables.
            clipping_threshold: bound used by the stochastic clipping ops.
            model_res: accepted but not used in this method.

        Raises:
            Exception: if ``generator_output`` is None (all graph ops are
                printed first to aid debugging).
        """
        self.batch_size = batch_size
        self.clipping_threshold = clipping_threshold
        # Fixed-seed (800) starting latent so runs are reproducible.
        # NOTE(review): this draws 1 row, while dlatent_variable has
        # batch_size rows; check how set_dlatents handles batch_size > 1.
        self.initial_dlatents = misc.random_latents(
            1, model, random_state=np.random.RandomState(
                800))  #np.zeros((self.batch_size, 512))
        # Starting labels: uniform [0, 1) values, one row per batch item.
        self.initial_dlabels = np.random.rand(self.batch_size, labels_size)
        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()

        def get_tensor(name):
            # Look up a tensor by name; None when it is not in the graph.
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        # Trainable latent input; the width 512 is hard-coded here.
        self.dlatent_variable = tf.get_variable(
            'learnable_dlatents',
            shape=(batch_size, 512),
            dtype='float32',
            initializer=tf.initializers.random_normal())

        # Trainable label input.
        self.dlabel_variable = tf.get_variable(
            'learnable_dlabels',
            shape=(batch_size, labels_size),
            dtype='float32',
            initializer=tf.initializers.random_normal())

        # Generator graph driven directly by the learnable variables.
        self.generator_output = model.get_output_for(self.dlatent_variable,
                                                     self.dlabel_variable)

        # Tensors of the model's own template inputs/outputs (None if absent).
        self.latents_name_tensor = get_tensor(model.input_templates[0].name)
        self.labels_name_tensor = get_tensor(model.input_templates[1].name)
        self.output_name_tensor = get_tensor(model.output_templates[0].name)

        # NHWC image view of the template output: float, then saturated uint8.
        self.output_name_image = tflib.convert_images_to_uint8(
            self.output_name_tensor, nchw_to_nhwc=True, uint8_cast=False)
        self.output_name_image_uint8 = tf.saturate_cast(
            self.output_name_image, tf.uint8)

        self.set_dlatents(self.initial_dlatents)
        self.set_dlabels(self.initial_dlabels)

        self.generator_output_shape = model.output_shape

        if self.generator_output is None:
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find generator_output")

        # NHWC image view of the learnable-variable output.
        self.generated_image = tflib.convert_images_to_uint8(
            self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image,
                                                      tf.uint8)

        # Stochastic clipping in the spirit of https://arxiv.org/abs/1702.04782:
        # latent components outside [-clipping_threshold, clipping_threshold]
        # are resampled from a standard normal rather than clamped.
        clipping_mask1 = tf.math.logical_or(
            self.dlatent_variable > self.clipping_threshold,
            self.dlatent_variable < -self.clipping_threshold)
        clipped_values1 = tf.where(
            clipping_mask1, tf.random_normal(shape=(self.batch_size, 512)),
            self.dlatent_variable)
        self.stochastic_clip_op1 = tf.assign(self.dlatent_variable,
                                             clipped_values1)

        # Labels: the first 60 columns are resampled when outside
        # [0, clipping_threshold]; the remaining columns when outside
        # [-clipping_threshold, clipping_threshold].
        clipping_mask2_1 = tf.math.logical_or(
            self.dlabel_variable[:, 0:60] > self.clipping_threshold,
            self.dlabel_variable[:, 0:60] < 0)
        clipping_mask2_2 = tf.math.logical_or(
            self.dlabel_variable[:, 60:] > self.clipping_threshold,
            self.dlabel_variable[:, 60:] < -self.clipping_threshold)
        clipping_mask2 = tf.concat([clipping_mask2_1, clipping_mask2_2],
                                   axis=1)
        clipped_values2 = tf.where(
            clipping_mask2,
            tf.random_normal(shape=(self.batch_size, labels_size)),
            self.dlabel_variable)
        self.stochastic_clip_op2 = tf.assign(self.dlabel_variable,
                                             clipped_values2)
Beispiel #17
0
def recovery(name,
             pkl_path1,
             pkl_path2,
             out_dir,
             target_latents_dir,
             num_init=20,
             num_total_sample=100,
             image_shrink=1,
             random_seed=2020,
             minibatch_size=1,
             noise_sigma=0):
    """Recover target-model latents by optimising the source model's inputs.

    For each of ``num_total_sample`` samples:
      1. a target latent is taken from row ``k`` of ``np.load(target_latents_dir)``
         (a file path despite the name) or freshly sampled;
      2. it is replicated ``num_init`` times and rendered by the recovery copy
         of the target generator (``pkl_path2``), then Gaussian noise with
         ``sigma=noise_sigma`` is added;
      3. ``num_init`` random starting inputs of the recovery copy of the
         source generator (``pkl_path1``) are optimised for 500 steps to
         minimise MSE to the noisy targets.

    Writes to ``<out_dir>/<name>/``: ``<k>.png`` (target vs. lowest-loss
    reconstruction), ``losses.txt`` (per-start losses, one line per sample),
    and ``z_init.npy`` / ``z_re.npy`` (target and recovered latents so far).

    ``random_seed`` and ``minibatch_size`` are accepted but not used.
    """
    #     misc.init_output_logging()
    #     np.random.seed(random_seed)
    #     print('Initializing TensorFlow...')
    #     os.environ.update(config.env)
    #     tfutil.init_tf(config.tf_config)

    print('num_init:' + str(num_init))

    # load source model
    print('Loading network1...' + pkl_path1)
    _, _, G_sorce = misc.load_network_pkl(pkl_path1)

    # load target model
    print('Loading  network2...' + pkl_path2)
    _, _, G_target = misc.load_network_pkl(pkl_path2)

    # Gt: recovery network initialised from the target generator.
    Gt = tfutil.Network('Gt',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    latents = misc.random_latents(num_init, Gt, random_state=None)
    labels = np.zeros([latents.shape[0], 0], np.float32)
    Gt.copy_vars_from_with_input(G_target, latents)

    # Gs: recovery network initialised from the source generator.
    Gs = tfutil.Network('Gs',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    Gs.copy_vars_from_with_input(G_sorce, latents)

    out_dir = os.path.join(out_dir, name)
    os.makedirs(out_dir, exist_ok=True)

    def G_loss(G, target_images):
        """MSE between ``G``'s rescaled output and ``target_images``."""
        tmp_latents = tfutil.run(G.trainables['Input/weight'])
        G_out = G.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        return tf.losses.mean_squared_error(target_images, G_out)

    z_init = []
    z_recovered = []

    # Optional pre-computed target latents.
    if target_latents_dir is not None:
        print('using latents:' + target_latents_dir)
        pre_latents = np.load(target_latents_dir)

    for k in range(num_total_sample):
        result_path = os.path.join(out_dir, str(k) + '.png')

        # ---- sample target latent
        if target_latents_dir is not None:
            latent = pre_latents[k]
        else:
            latent = misc.random_latents(1, Gs, random_state=None)[0]
        z_init.append(latent)

        # The same target latent, replicated for every start point.
        latents = np.zeros((num_init, 512))
        latents[:] = latent
        Gt.change_input(inputs=latents)

        # ---- render targets and add noise
        target_images = Gt.get_output_for(latents, labels, is_training=False)
        target_images = tfutil.run(rescale_output(target_images))
        target_images = tf.cast(addGaussianNoise(target_images,
                                                 sigma=noise_sigma),
                                dtype='float32')

        # ---- random start points for the source network
        latents_2 = misc.random_latents(num_init, Gs, random_state=None)
        Gs.change_input(inputs=latents_2)

        # ---- loss & optimizer
        # Latent-norm regulariser; built but currently not added to the loss.
        regularizer = tf.abs(tf.norm(latents_2) - np.sqrt(512))
        loss = G_loss(G=Gs, target_images=target_images)  # + regularizer
        G_opt = tfutil.Optimizer(name='latent_recovery', learning_rate=0.01)
        G_opt.register_gradients(loss, Gs.trainables)
        G_train_op = G_opt.apply_updates()

        # ---- recovery
        EPOCH = 500
        losses = [tfutil.run(loss)]
        for _ in range(EPOCH):
            # NOTE(review): optimizer state is reset before every step —
            # confirm this is intended.
            G_opt.reset_optimizer_state()
            tfutil.run([G_train_op])

        learned_latent = tfutil.run(Gs.trainables['Input/weight'])
        result_images = Gs.run(learned_latent,
                               labels,
                               minibatch_size=config.num_gpus * 256,
                               num_gpus=config.num_gpus,
                               out_mul=127.5,
                               out_add=127.5,
                               out_shrink=image_shrink,
                               out_dtype=np.float32)

        # Per-start-point MSE against the (noisy) targets.
        tmp_latents = tfutil.run(Gs.trainables['Input/weight'])
        G_out = rescale_output(
            Gs.get_output_for(tmp_latents, labels, is_training=True))
        sample_losses = [
            tfutil.run(
                tf.losses.mean_squared_error(target_images[i], G_out[i]))
            for i in range(num_init)
        ]

        # ---- save target next to the best reconstruction
        plt.subplot(1, 2, 1)
        plt.imshow(tfutil.run(target_images)[0].transpose(1, 2, 0) / 255.0)
        plt.subplot(1, 2, 2)
        plt.imshow(result_images[np.argmin(sample_losses)].transpose(1, 2, 0) /
                   255.0)
        plt.savefig(result_path)
        # Close the figure so images/axes do not pile up across iterations.
        plt.close()

        # ---- store optimized z
        z_recovered.append(tmp_latents)

        # ---- save losses and latents (rewritten every sample)
        with open(os.path.join(out_dir, 'losses.txt'), 'a') as f:
            f.write(''.join(str(v) + ' ' for v in sample_losses) + '\n')
        np.save(os.path.join(out_dir, 'z_init'), np.array(z_init))
        np.save(os.path.join(out_dir, 'z_re'), np.array(z_recovered))