Code example #1
def generate_fake_images(run_id=None,
                         snapshot=None,
                         grid_size=[1, 1],
                         num_pngs=1,
                         image_shrink=1,
                         png_prefix=None,
                         random_seed=1000,
                         minibatch_size=8,
                         path=None,
                         latent=None,
                         Gs=None,
                         labels=None):
    if path and latent is not None and Gs is not None:
        # latents = misc.random_latents(np.prod([1, 1]), Gs, random_state=random_state)
        if labels is None:
            labels = np.zeros([latent.shape[0], 0], np.float32)
        images = Gs.run(latent,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        misc.save_image_grid(images, path, [0, 255], [1, 1])
        return

    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    for png_idx in range(num_pngs):
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        latents = misc.random_latents(np.prod(grid_size),
                                      Gs,
                                      random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        misc.save_image_grid(
            images,
            os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)),
            [0, 255], grid_size)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
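A minimal invocation sketch, assuming the Progressive-GAN-style helper modules (config, tfutil) that these scripts import elsewhere; the run id and grid values are purely illustrative:

import tfutil
import config

tfutil.init_tf(config.tf_config)       # create the TF session the helpers expect
generate_fake_images(run_id=23,        # hypothetical training run to sample from
                     grid_size=[4, 4], # 16 latents per PNG, tiled into a 4x4 grid
                     num_pngs=10)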
Code example #2
def generate_fake_interpolate_midle_images(run_id,
                                           snapshot=None,
                                           grid_size=[1, 1],
                                           num_pngs=1,
                                           image_shrink=1,
                                           png_prefix=None,
                                           random_seed=1000,
                                           minibatch_size=8,
                                           middle_img=10):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    for png_idx in range(num_pngs):
        latents = misc.random_latents(middle_img + 2,
                                      Gs,
                                      random_state=random_state)
        from_to_tensor = latents[middle_img + 1] - latents[0]
        from_z = latents[0]
        #between_x_list = [from_x]
        # Sweep alpha over [-0.5, 0.5] (alternative: np.linspace(0, 1, middle_img + 1))
        # and overwrite each latent with a point on the line through from_z.
        for counter, alpha in enumerate(np.linspace(-0.5, 0.5, middle_img + 2)):
            print('alpha: ', alpha, 'counter= ', counter)
            latents[counter] = from_z + alpha * from_to_tensor
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        #grid_size_1=[middle_img+2,1]
        grid_size_1 = [middle_img + 1, 1]
        #png_prefix=0

        misc.save_image_grid(
            images[1:, :, :, :],
            os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)),
            [0, 255], grid_size_1)
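The sweep above is plain linear interpolation in latent space; a self-contained sketch of the same arithmetic (the 512-D latent size is typical for this codebase but assumed here):

import numpy as np

rng = np.random.RandomState(1000)
z0, z1 = rng.randn(2, 512).astype(np.float32)  # two endpoint latents
direction = z1 - z0
# alpha in [-0.5, 0.5] centers the sweep on z0 and extrapolates half-way
# towards (and away from) z1, mirroring the loop above.
betweens = np.stack([z0 + a * direction for a in np.linspace(-0.5, 0.5, 12)])
print(betweens.shape)  # (12, 512)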
Code example #3
def find_latent_with_query_image(run_id, snapshot=None, grid_size=[1,1], num_pngs=1, image_shrink=1, png_prefix=None, random_seed=4123, minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    
    # Create query image - tensorflow constant
    query_image = cv2.imread('../../data/ACDC/training/patient001/cardiac_cycles/0/0.png')
    query_image = cv2.resize(query_image, (256, 256))
    print('Saving query image to "%s"...' % result_subdir)
    cv2.imwrite(result_subdir+'/query_image.png', query_image)
    query_image = query_image.transpose(2,0,1)
    query_image = query_image[np.newaxis]
    x = tf.constant(query_image, dtype=tf.float32, name='query_image')
    # Create G(z) - tensorflow variable and label
    latent = misc.random_latents(np.prod(grid_size), Gs, random_state=random_state)
    initial = tf.constant(latent, dtype=tf.float32)
    z = tf.Variable(initial_value=initial, dtype=tf.float32, name='latent_space')
    label = np.zeros([latent.shape[0], 5], np.float32)
    label[:,4] = 1 # | 0 -> NOR | 1 -> DCM | 2 -> HCM | 3 -> MINF | 4 -> RV | 
    # Build G(z) symbolically so the loss gradient flows back into z; Gs.run
    # would return a fixed numpy array and the optimizer would never update z.
    gz = Gs.get_output_for(z, tf.constant(label), is_training=False)
    gz = gz * 127.5 + 127.5  # rescale from [-1, 1] to [0, 255] (cf. out_mul/out_add)
    # Define a loss function
    residual_loss = tf.losses.absolute_difference(x, gz)
    # Define an optimizer (over the latent only, not the generator weights)
    train_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(residual_loss, var_list=[z])
    
    zs, gzs, step = [], [], 1
    
    # Reuse the default session created by tfutil.init_tf(); a fresh tf.Session
    # (or tf.global_variables_initializer) would wipe the loaded Gs weights.
    sess = tf.get_default_session()
    with sess.as_default():
        tfutil.init_uninited_vars()  # initialize only z and the Adam state
        _, loss_value = sess.run([train_op, residual_loss])
        while (loss_value > 2e-04 and step <= 50000):
            _, loss_value = sess.run([train_op, residual_loss])
            step += 1
            if step % 10000 == 0:
                print('Step {}, Loss value: {}'.format(step, loss_value))
                gzs.append(sess.run(gz))
                zs.append(sess.run(z))
                
    for png_idx, image in enumerate(gzs):
        misc.save_image_grid(image, os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)), [0,255], grid_size)
        
    np.save(result_subdir+'/zs.npy', np.asarray(zs))
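The inversion loop in miniature, with a toy differentiable generator standing in for Gs (all names and numbers here are illustrative):

import numpy as np
import tensorflow as tf  # TF 1.x, as in the surrounding code

target = tf.constant(np.full((1, 4), 3.0, np.float32))  # stand-in query image
z = tf.Variable(np.zeros((1, 4), np.float32))           # latent being optimized
gz = 2.0 * z + 1.0                                      # toy generator G(z)
loss = tf.losses.absolute_difference(target, gz)
train_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        sess.run(train_op)
    print(sess.run(z))  # converges towards 1.0, since G(1) = 3 = target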
Code example #4
def generate_fake_images(run_id, snapshot=None, grid_size=[1,1], num_pngs=1, image_shrink=1, png_prefix=None, random_seed=1000, minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    for png_idx in range(num_pngs):
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        latents = misc.random_latents(np.prod(grid_size), Gs, random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents, labels, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8)
        misc.save_image_grid(images, os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)), [0,255], grid_size)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
Code example #5
def generate_fake_images_glob(run_id,
                              snapshot=None,
                              grid_size=[1, 1],
                              num_pngs=1,
                              image_shrink=1,
                              png_prefix=None,
                              random_seed=1000,
                              minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)
    latents = random_state.randn(num_pngs,
                                 *G.input_shape[1:]).astype(np.float32)
    dist = cdist(latents, latents)  # pairwise latent distances (computed but not used below)
    np.fill_diagonal(dist, 100)
    result_subdir = misc.create_result_subdir(config_test.result_dir,
                                              config_test.desc)
    for png_idx in range(num_pngs):
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        latents = misc.random_latents(np.prod(grid_size),
                                      Gs,
                                      random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config_test.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        misc.save_image_grid(
            images,
            os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)),
            [0, 255], grid_size)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
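cdist builds the full pairwise distance matrix, and filling the diagonal with a large value is a common trick before a nearest-neighbor lookup; a standalone sketch of that pattern (the shapes are illustrative):

import numpy as np
from scipy.spatial.distance import cdist

latents = np.random.RandomState(1000).randn(5, 512).astype(np.float32)
dist = cdist(latents, latents)  # (5, 5) pairwise Euclidean distances
np.fill_diagonal(dist, 100)     # mask self-distance before taking per-row minima
print(dist.argmin(axis=1))      # nearest other latent for each row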
Code example #6
def generate_fake_images(Gs,
                         D,
                         random_state,
                         race,
                         gender,
                         num_pngs=1,
                         grid_size=[1, 1],
                         image_shrink=1,
                         png_prefix=None,
                         random_seed=1000,
                         minibatch_size=8):

    for png_idx in range(num_pngs):
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        latents = misc.random_latents(np.prod(grid_size),
                                      Gs,
                                      random_state=random_state)

        trans_label, thr = parse_label(race, gender)
        labels = np.zeros([latents.shape[0], 8], np.float32)
        labels[:, trans_label] = 1.0
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)

        score, _ = D.run(images)
        if score >= thr:
            # output image
            img = images[0].transpose(1, 2, 0)
            scipy.misc.imsave('test%d.jpg' % png_idx, img)
            print('save image %d' % png_idx)
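The filtering here is simple rejection sampling on the discriminator score; a generic, self-contained sketch of the idea (the scorer and threshold are made up):

import numpy as np

def rejection_sample(generate, score, threshold, max_tries=100):
    """Keep drawing samples until one scores at or above threshold."""
    for _ in range(max_tries):
        sample = generate()
        if score(sample) >= threshold:
            return sample
    return None  # no sample passed the filter

rng = np.random.RandomState(0)
kept = rejection_sample(lambda: rng.randn(), lambda s: s, threshold=1.0)
print(kept)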
Code example #7
def evaluate_metrics(run_id, log, metrics, num_images, real_passes, minibatch_size=None):
    metric_class_names = {
        'swd':      'metrics.sliced_wasserstein.API',
        'fid':      'metrics.frechet_inception_distance.API',
        'is':       'metrics.inception_score.API',
        'msssim':   'metrics.ms_ssim.API',
    }

    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(result_subdir, verbose=True, shuffle_mb=0)
    if minibatch_size is None:
        minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256)

    # Initialize metrics.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(num_images=num_images, image_shape=image_shape, image_dtype=np.uint8, minibatch_size=minibatch_size)
        tfutil.init_uninited_vars()
        mode = 'warmup'
        obj.begin(mode)
        for idx in range(10):
            obj.feed(mode, np.random.randint(0, 256, size=[minibatch_size]+image_shape, dtype=np.uint8))
        obj.end(mode)
        metric_objs.append(obj)

    # Print table header.
    print()
    print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='')
    for obj in metric_objs:
        for name, fmt in zip(obj.get_metric_names(), obj.get_metric_formatting()):
            print('%-*s' % (len(fmt % 0), name), end='')
    print()
    print('%-10s%-12s' % ('---', '---'), end='')
    for obj in metric_objs:
        for fmt in obj.get_metric_formatting():
            print('%-*s' % (len(fmt % 0), '---'), end='')
    print()

    # Feed in reals.
    for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]:
        print('%-10s' % title, end='')
        time_begin = time.time()
        labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32)
        [obj.begin(mode) for obj in metric_objs]
        for begin in range(0, num_images, minibatch_size):
            end = min(begin + minibatch_size, num_images)
            images, labels[begin:end] = dataset_obj.get_minibatch_np(end - begin)
            if mirror_augment:
                images = misc.apply_mirror_augment(images)
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1]) # grayscale => RGB
            [obj.feed(mode, images) for obj in metric_objs]
        results = [obj.end(mode) for obj in metric_objs]
        print('%-12s' % misc.format_time(time.time() - time_begin), end='')
        for obj, vals in zip(metric_objs, results):
            for val, fmt in zip(vals, obj.get_metric_formatting()):
                print(fmt % val, end='')
        print()

    # Evaluate each network snapshot.
    for snapshot_idx, snapshot_pkl in enumerate(reversed(snapshot_pkls)):
        prefix = 'network-snapshot-'; postfix = '.pkl'
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix)
        snapshot_kimg = int(snapshot_name[len(prefix) : -len(postfix)])

        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        [obj.begin(mode) for obj in metric_objs]
        time_begin = time.time()
        with tf.Graph().as_default(), tfutil.create_session(config.tf_config).as_default():
            G, D, Gs = misc.load_pkl(snapshot_pkl)
            for begin in range(0, num_images, minibatch_size):
                end = min(begin + minibatch_size, num_images)
                latents = misc.random_latents(end - begin, Gs)
                images = Gs.run(latents, labels[begin:end], num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1]) # grayscale => RGB
                [obj.feed(mode, images) for obj in metric_objs]
        results = [obj.end(mode) for obj in metric_objs]
        print('%-12s' % misc.format_time(time.time() - time_begin), end='')
        for obj, vals in zip(metric_objs, results):
            for val, fmt in zip(vals, obj.get_metric_formatting()):
                print(fmt % val, end='')
        print()
    print()
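The minibatch heuristic divides a fixed activation budget by the image resolution (dataset_obj.shape is [channels, height, width] in this codebase), clamped to [4, 256]; a worked example:

import numpy as np

for res in [32, 128, 1024]:
    print(res, int(np.clip(8192 // res, 4, 256)))  # 32 -> 256, 128 -> 64, 1024 -> 8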
Code example #8
File: util_scripts.py  Project: zeta1999/TBGAN
def generate_fake_images(run_id,
                         snapshot=None,
                         grid_size=[1, 1],
                         batch_size=8,
                         num_pngs=1,
                         image_shrink=1,
                         png_prefix=None,
                         random_seed=1000,
                         minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    lsfm_model = m3io.import_lsfm_model(
        '/home/baris/Projects/faceganhd/models/all_all_all.mat')
    lsfm_tcoords = mio.import_pickle(
        '/home/baris/Projects/team members/stelios/UV_spaces_V2/UV_dicts/full_face/512_UV_dict.pkl'
    )['tcoords']
    lsfm_params = []

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    for png_idx in range(int(num_pngs / batch_size)):
        start = time.time()
        print('Generating png %d-%d / %d... in ' %
              (png_idx * batch_size, (png_idx + 1) * batch_size, num_pngs),
              end='')
        latents = misc.random_latents(np.prod(grid_size) * batch_size,
                                      Gs,
                                      random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_shrink=image_shrink)
        for i in range(batch_size):
            if images.shape[1] == 3:
                mio.export_pickle(
                    images[i],
                    os.path.join(
                        result_subdir,
                        '%s%06d.pkl' % (png_prefix, png_idx * batch_size + i)))
                # misc.save_image(images[i], os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx*batch_size+i)), [0,255], grid_size)
            elif images.shape[1] == 6:
                mio.export_pickle(images[i][3:6],
                                  os.path.join(
                                      result_subdir, '%s%06d.pkl' %
                                      (png_prefix, png_idx * batch_size + i)),
                                  overwrite=True)
                misc.save_image(
                    images[i][0:3],
                    os.path.join(
                        result_subdir,
                        '%s%06d.png' % (png_prefix, png_idx * batch_size + i)),
                    [-1, 1], grid_size)
            elif images.shape[1] == 9:
                texture = Image(np.clip(images[i, 0:3] / 2 + 0.5, 0, 1))
                mesh_raw = from_UV_2_3D(Image(images[i, 3:6]))
                normals = images[i, 6:9]
                normals_norm = (normals - normals.min()) / (normals.max() -
                                                            normals.min())
                mesh = lsfm_model.reconstruct(mesh_raw)
                lsfm_params.append(lsfm_model.project(mesh_raw))
                t_mesh = TexturedTriMesh(mesh.points, lsfm_tcoords.points,
                                         texture, mesh.trilist)
                m3io.export_textured_mesh(
                    t_mesh,
                    os.path.join(result_subdir,
                                 '%06d.obj' % (png_idx * batch_size + i)),  # batch_size, matching the loop stride
                    texture_extension='.png')
                mio.export_image(
                    Image(normals_norm),
                    os.path.join(
                        result_subdir,
                        '%06d_nor.png' % (png_idx * batch_size + i)))
                shape = images[i, 3:6]
                shape_norm = (shape - shape.min()) / (shape.max() -
                                                      shape.min())
                mio.export_image(
                    Image(shape_norm),
                    os.path.join(
                        result_subdir,
                        '%06d_shp.png' % (png_idx * batch_size + i)))
                mio.export_pickle(
                    t_mesh,
                    os.path.join(result_subdir,
                                 '%06d.pkl' % (png_idx * batch_size + i)))

        print('%0.2f seconds' % (time.time() - start))

    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
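The normals and shape maps above are min-max scaled to [0, 1] before export as images; the same one-liner in isolation:

import numpy as np

x = np.random.randn(3, 8, 8).astype(np.float32)  # stand-in for a normals map
x_norm = (x - x.min()) / (x.max() - x.min())     # global min-max scaling
print(x_norm.min(), x_norm.max())                # 0.0 1.0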
Code example #9
def compute_fid(Gs,
                minibatch_size,
                dataset_obj,
                iter_number,
                lod=0,
                num_images=10000,
                printing=True):

    # Initialize metrics.
    from metrics.frechet_inception_distance import API as class_def
    image_shape = [3] + dataset_obj.shape[1:]
    obj = class_def(num_images=num_images,
                    image_shape=image_shape,
                    image_dtype=np.uint8,
                    minibatch_size=minibatch_size)
    mode = 'warmup'
    obj.begin(mode)
    for idx in range(10):
        obj.feed(
            mode,
            np.random.randint(0,
                              256,
                              size=[minibatch_size] + image_shape,
                              dtype=np.uint8))
    obj.end(mode)

    # Print table header
    if printing:
        print(flush=True)
        print('%-10s%-12s' % ('KIMG', 'Time_eval'), end='', flush=True)
        print('%-12s' % ('FID'), end='', flush=True)
        print(flush=True)
        print('%-10s%-12s%-12s' % ('---', '---', '---'), end='', flush=True)
        print(flush=True)

    # Feed in reals.
    if printing:
        print('%-10s' % 'Reals', end='', flush=True)
    time_begin = time.time()
    labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32)
    mode = 'reals'  # mode was left at 'warmup' above; the reals pass needs 'reals'
    obj.begin(mode)
    for begin in range(0, num_images, minibatch_size):
        end = min(begin + minibatch_size, num_images)
        images, labels[begin:end] = dataset_obj.get_minibatch_np(end - begin,
                                                                 lod=lod)
        if images.shape[1] == 1:
            images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
        obj.feed(mode, images)  # feed every minibatch, not only the grayscale ones
    results = obj.end(mode)
    if printing:
        print('%-12s' % misc.format_time(time.time() - time_begin),
              end='',
              flush=True)
        print(results[0], end='', flush=True)
        print(flush=True)

    # Evaluate each network snapshot.
    if printing:
        print('%-10d' % iter_number, end='', flush=True)
    mode = 'fakes'
    obj.begin(mode)
    time_begin = time.time()
    for begin in range(0, num_images, minibatch_size):
        end = min(begin + minibatch_size, num_images)
        latents = misc.random_latents(end - begin, Gs)
        images = Gs.run(latents,
                        labels[begin:end],
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_dtype=np.uint8)
        if images.shape[1] == 1:
            images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
        obj.feed(mode, images)
    results = obj.end(mode)
    if printing:
        print('%-12s' % misc.format_time(time.time() - time_begin),
              end='',
              flush=True)
        print(results[0], end='', flush=True)
        print(flush=True)
    return results[0]
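A hedged call sketch, assuming the surrounding module's imports (misc, config) and a previously trained run; the run id 42 is illustrative:

import misc  # Progressive-GAN helper module assumed importable

G, D, Gs = misc.load_network_pkl(42, None)  # hypothetical run 42, latest snapshot
dataset_obj, _ = misc.load_dataset_for_previous_run(
    misc.locate_result_subdir(42), verbose=True, shuffle_mb=0)
fid = compute_fid(Gs, minibatch_size=8, dataset_obj=dataset_obj,
                  iter_number=10000, num_images=10000)
print('FID:', fid)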
Code example #10
def generate_fake_images(run_id,
                         snapshot=None,
                         grid_size=[1, 1],
                         num_pngs=1,
                         image_shrink=1,
                         png_prefix=None,
                         random_seed=1000,
                         minibatch_size=8):

    embeddings_constant = False
    labels_constant = False
    latents_constant = False

    idx = random.randint(0, 56880)
    df = pandas.read_csv('datasets/50k_sorted_tf/50k_index_sorted.csv')
    print('embeddings_constant : ' + str(embeddings_constant))
    print('labels_constant : ' + str(labels_constant))
    print('latents_constant : ' + str(latents_constant))

    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir + '/' + run_id,
                                              config.desc)

    if latents_constant:
        latents = misc.random_latents(np.prod(grid_size),
                                      Gs,
                                      random_state=None)
    #embeddings = np.zeros([1, 300], dtype=np.float32)
    #labels = np.zeros([1, 32], dtype=np.float32)
    embeddings = np.load(
        'datasets/50k_sorted_tf/sum_embedding_title.embeddings')
    embeddings = embeddings.astype('float32')

    labels = np.load(
        'datasets/50k_sorted_tf/sum_embedding_category_average.labels')
    labels = labels.astype('float32')
    name1 = ''
    if labels_constant:
        label = labels[idx]
        name1 = name1 + ' ' + df.at[idx, 'category1']
        label = label.reshape(1, label.shape[0])

    if embeddings_constant:
        embedding = embeddings[idx]
        title = df.at[idx, 'title']
        name1 = name1 + ' ' + title[:10]
        embedding = embedding.reshape(1, embedding.shape[0])

    #print(latents.shape)
    for png_idx in range(num_pngs):
        name = ''
        name = name + name1
        print('Generating png %d / %d...' % (png_idx, num_pngs))
        rand = random.randint(0, 56880)
        #rand = png_idx * 1810
        #labels = sess.run(classes[0])
        if not latents_constant:
            latents = misc.random_latents(np.prod(grid_size),
                                          Gs,
                                          random_state=random_state)
        if not labels_constant:
            label = labels[rand]
            label = label.reshape(1, label.shape[0])
            name = name + ' ' + df.at[rand, 'category1']
        if not embeddings_constant:
            embedding = embeddings[rand]
            title = df.at[rand, 'title']
            name = name + ' ' + title[:10]
            embedding = embedding.reshape(1, embedding.shape[0])

        #print(labels.shape)
        images = Gs.run(latents,
                        label,
                        embedding,
                        minibatch_size=minibatch_size,
                        num_gpus=config.num_gpus,
                        out_mul=127.5,
                        out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)
        misc.save_image_grid(
            images,
            os.path.join(result_subdir, '%s%06d.png' % (name, png_idx)),
            [0, 255], grid_size)
    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
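Indexing one row of a 2-D array drops the batch dimension, hence the reshape back to (1, dim) before feeding Gs.run; in isolation (shapes are illustrative, taken from this example):

import numpy as np

embeddings = np.zeros((56881, 300), np.float32)  # stand-in embedding matrix
embedding = embeddings[123]                      # shape (300,)
embedding = embedding.reshape(1, embedding.shape[0])
print(embedding.shape)                           # (1, 300)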
Code example #11
def generate_fake_images(run_id,
                         snapshot=None,
                         grid_size=[1, 1],
                         batch_size=8,
                         num_pngs=1,
                         image_shrink=1,
                         png_prefix=None,
                         random_seed=1000,
                         minibatch_size=8):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config_test.result_dir,
                                              config_test.desc)
    for png_idx in range(int(num_pngs / batch_size)):
        start = time.time()
        print('Generating png %d-%d / %d... in ' %
              (png_idx * batch_size, (png_idx + 1) * batch_size, num_pngs),
              end='')
        latents = misc.random_latents(np.prod(grid_size) * batch_size,
                                      Gs,
                                      random_state=random_state)
        labels = np.zeros([latents.shape[0], 7], np.float32)
        images = Gs.run(latents,
                        labels,
                        minibatch_size=minibatch_size,
                        num_gpus=config_test.num_gpus,
                        out_shrink=image_shrink)
        for i in range(batch_size):
            if images.shape[1] == 3:
                mio.export_pickle(
                    images[i],
                    os.path.join(
                        result_subdir,
                        '%s%06d.pkl' % (png_prefix, png_idx * batch_size + i)))
                # misc.save_image(images[i], os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx*batch_size+i)), [0,255], grid_size)
            elif images.shape[1] == 6:
                mio.export_pickle(images[i][3:6],
                                  os.path.join(
                                      result_subdir, '%s%06d.pkl' %
                                      (png_prefix, png_idx * batch_size + i)),
                                  overwrite=True)
                misc.save_image(
                    images[i][0:3],
                    os.path.join(
                        result_subdir,
                        '%s%06d.png' % (png_prefix, png_idx * batch_size + i)),
                    [-1, 1], grid_size)
            elif images.shape[1] == 9:
                mio.export_pickle(images[i][3:6],
                                  os.path.join(
                                      result_subdir, '%s%06d_shp.pkl' %
                                      (png_prefix, png_idx * batch_size + i)),
                                  overwrite=True)
                mio.export_pickle(images[i][6:9],
                                  os.path.join(
                                      result_subdir, '%s%06d_nor.pkl' %
                                      (png_prefix, png_idx * batch_size + i)),
                                  overwrite=True)
                misc.save_image(
                    images[i][0:3],
                    os.path.join(
                        result_subdir,
                        '%s%06d.png' % (png_prefix, png_idx * batch_size + i)),
                    [-1, 1], grid_size)
        print('%0.2f seconds' % (time.time() - start))

    open(os.path.join(result_subdir, '_done.txt'), 'wt').close()
Code example #12
def evaluate_metrics_swd_distributions_training_trad_prog(
        run_id,
        network_dir_conv,
        network_dir_prog,
        log,
        metrics,
        num_images_per_group,
        num_groups,
        real_passes,
        minibatch_size=None):
    metric_class_names = {
        'swd_distri_training_trad_prog':
        'metrics.swd_distributions_training_trad_prog.API',
    }
    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)

    # Initialize metrics.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(image_shape=image_shape, image_dtype=np.uint8)
        tfutil.init_uninited_vars()
        metric_objs.append(obj)

    # Generate fakes from both networks and append them to the reals.
    mode = 'fakes'
    [obj.begin(mode) for obj in metric_objs]
    images_real, labels = dataset_obj.get_minibatch_np(
        num_groups * num_images_per_group)

    with tf.Graph().as_default(), tfutil.create_session(
            config.tf_config).as_default():
        G, D, Gs = misc.load_pkl(network_dir_conv)
        #G, D, Gs = pickle.load(file)
        latents = misc.random_latents(num_groups * num_images_per_group, Gs)
        images = images_real
        # Gs cannot generate more than ~3000 images at once, so run in 10
        # chunks; num_groups * num_images_per_group must be divisible by 10.
        nn = num_groups * num_images_per_group // 10
        for k in range(10):
            images_fake = Gs.run(latents[k * nn:(k + 1) * nn],
                                 labels[k * nn:(k + 1) * nn],
                                 num_gpus=config.num_gpus,
                                 out_mul=127.5,
                                 out_add=127.5,
                                 out_dtype=np.uint8)
            images = np.concatenate((images, images_fake), axis=0)

    with tf.Graph().as_default(), tfutil.create_session(
            config.tf_config).as_default():
        G, D, Gs = misc.load_pkl(network_dir_prog)
        #  G, D, Gs = pickle.load(file)
        latents = misc.random_latents(num_groups * num_images_per_group, Gs)
        nn = num_groups * num_images_per_group // 10
        for k in range(10):
            images_fake = Gs.run(latents[k * nn:(k + 1) * nn],
                                 labels[k * nn:(k + 1) * nn],
                                 num_gpus=config.num_gpus,
                                 out_mul=127.5,
                                 out_add=127.5,
                                 out_dtype=np.uint8)
            images = np.concatenate((images, images_fake), axis=0)

    if images.shape[1] == 1:
        images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
    [
        obj.feed(mode, images, num_images_per_group, num_groups, result_subdir)
        for obj in metric_objs
    ]
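The chunking loop generalizes to any batch count; a standalone sketch of generating N samples in fixed-size chunks (the generator here is a stand-in):

import numpy as np

def generate_in_chunks(generate_batch, total, num_chunks=10):
    """Build `total` samples by calling generate_batch in equal-size chunks."""
    assert total % num_chunks == 0, 'total must divide evenly into chunks'
    nn = total // num_chunks
    parts = [generate_batch(nn) for _ in range(num_chunks)]
    return np.concatenate(parts, axis=0)

out = generate_in_chunks(lambda n: np.zeros((n, 3, 8, 8)), total=100)
print(out.shape)  # (100, 3, 8, 8)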
Code example #13
def evaluate_metrics_swd_distributions(run_id,
                                       log,
                                       metrics,
                                       num_images_per_group,
                                       num_groups,
                                       real_passes,
                                       minibatch_size=None):
    metric_class_names = {
        'swd_distri': 'metrics.swd_distributions.API',
    }
    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)

    # Initialize metrics.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(image_shape=image_shape, image_dtype=np.uint8)
        tfutil.init_uninited_vars()
        metric_objs.append(obj)

    # Evaluate each network snapshot.
    for snapshot_idx, snapshot_pkl in enumerate(reversed(snapshot_pkls)):
        prefix = 'network-snapshot-'
        postfix = '.pkl'
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(
            postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        [obj.begin(mode) for obj in metric_objs]

        images_real, labels = dataset_obj.get_minibatch_np(
            num_groups * num_images_per_group)

        with tf.Graph().as_default(), tfutil.create_session(
                config.tf_config).as_default():
            G, D, Gs = misc.load_pkl(snapshot_pkl)

            latents = misc.random_latents(num_groups * num_images_per_group,
                                          Gs)
            images = images_real
            # Gs cannot generate more than ~3000 images at once, so run in 10
            # chunks; num_groups * num_images_per_group must be divisible by 10.
            nn = num_groups * num_images_per_group // 10
            for k in range(10):
                images_fake = Gs.run(latents[k * nn:(k + 1) * nn],
                                     labels[k * nn:(k + 1) * nn],
                                     num_gpus=config.num_gpus,
                                     out_mul=127.5,
                                     out_add=127.5,
                                     out_dtype=np.uint8)
                images = np.concatenate((images, images_fake), axis=0)

            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
            [
                obj.feed(mode, images, num_images_per_group, num_groups,
                         snapshot_kimg, result_subdir) for obj in metric_objs
            ]
Code example #14
def find_dir_latent_with_query_image(run_id, snapshot=None, grid_size=[1,1], num_pngs=1, image_shrink=1, png_prefix=None, random_seed=4123, minibatch_size=8, dir_path='../../data/ACDC/latents/cleaned_testing/'):
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    if png_prefix is None:
        png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-'
    random_state = np.random.RandomState(random_seed)

    print('Loading network from "%s"...' % network_pkl)
    G, D, Gs = misc.load_network_pkl(run_id, snapshot)

    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    replicate_folder_structure(dir_path, result_subdir+'/')

    train_patients = sorted_nicely(glob.glob(dir_path+'*'))

    for patient in train_patients:
        cardiac_cycles = sorted_nicely(glob.glob(patient+'/*/*/*.png'))
        cfg = open(patient+'/Info.cfg')
        label = condition_to_onehot(cfg.readlines()[2][7:])
        cont = 0
        for cycle in cardiac_cycles:
            # Get folder containing the image
            supfolder = sup_folder(cycle)
            latent_subdir = result_subdir + '/' + supfolder

            # Create query image - tensorflow constant
            query_image = cv2.imread(cycle) # read frame
            query_image = cv2.resize(query_image, (256, 256))
            query_image = query_image.transpose(2,0,1)
            query_image = query_image[np.newaxis]
            x = tf.constant(query_image, dtype=tf.float32, name='query_image')

            # Create G(z) - tensorflow variable and label
            latent = misc.random_latents(np.prod(grid_size), Gs, random_state=random_state)
            initial = tf.constant(latent, dtype=tf.float32)
            z = tf.Variable(initial_value=initial, dtype=tf.float32, name='latent_space')
            # Build G(z) symbolically so the loss gradient flows back into z;
            # Gs.run would return a fixed numpy array and z would never update.
            gz = Gs.get_output_for(z, tf.constant(label, dtype=tf.float32), is_training=False)
            gz = gz * 127.5 + 127.5  # rescale from [-1, 1] to [0, 255]

            # Define a loss function
            residual_loss = tf.losses.absolute_difference(x, gz)
            # Define an optimizer (over the latent only, not the generator weights)
            train_op = tf.train.AdamOptimizer(learning_rate=0.1).minimize(residual_loss, var_list=[z])

            zs, gzs, step = [], [], 1
    
            # Reuse the default session created by tfutil.init_tf(); a fresh
            # tf.Session would not hold the loaded Gs weights.
            sess = tf.get_default_session()
            with sess.as_default():
                tfutil.init_uninited_vars()  # initialize only z and the Adam state
                _, loss_value = sess.run([train_op, residual_loss])
                while (loss_value > 2e-04 and step <= 5000):
                    _, loss_value = sess.run([train_op, residual_loss])
                    step += 1
                    if step % 1000 == 0:
                        print('Step {}, Loss value: {}'.format(step, loss_value))
                        gzs.append(sess.run(gz))
                        zs.append(sess.run(z))
            
            # save last image
            print('Image saved at {}'.format(os.path.join(latent_subdir, '%02d.png' % cont)))
            misc.save_image_grid(gzs[-1], os.path.join(latent_subdir, '%02d.png' % cont), [0,255], grid_size)
            print('Latent vectors saved at {}'.format(os.path.join(latent_subdir, 'latent_%02d.npy' % cont)))
            np.save(os.path.join(latent_subdir, 'latent_%02d.npy' % cont), zs[-1])
            print('Labels saved at {}'.format(os.path.join(latent_subdir, 'label_%02d.npy' % cont)))
            np.save(os.path.join(latent_subdir, 'label_%02d.npy' % cont), label)
            cont += 1

        cfg.close()
        cont = 0
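The Info.cfg parsing grabs the file's third line and strips a fixed-width prefix; a hedged sketch of the same slicing, with the ACDC file layout assumed from the indices in the code above:

import io

# Hypothetical Info.cfg content implied by readlines()[2][7:]:
cfg = io.StringIO('ED: 1\nES: 12\nGroup: DCM\n')
condition = cfg.readlines()[2][7:].strip()  # [7:] strips 'Group: ' -> 'DCM'
print(condition)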