Example #1
def process_weight_tf(weight, training_set, testing_set, queue):
    dnnlib.tflib.init_tf()
    print("running: " + weight)
    G, D, Gs = pickle.load(open(weight, "rb"))

    train_pred = []
    train_lab = []

    train_data = "/".join(training_set.split("/")[:-1])
    train_dir = training_set.split("/")[-1]

    test_data = "/".join(testing_set.split("/")[:-1])
    test_dir = testing_set.split("/")[-1]

    train = dataset.load_dataset(data_dir=train_data,
                                 tfrecord_dir=train_dir,
                                 max_label_size=1,
                                 repeat=False,
                                 shuffle_mb=0)
    test = dataset.load_dataset(data_dir=test_data,
                                tfrecord_dir=test_dir,
                                max_label_size=1,
                                repeat=False,
                                shuffle_mb=0)

    for x in range(train._np_labels.shape[0] - 1):
        try:
            image, label = train.get_minibatch_np(1)
            image = misc.adjust_dynamic_range(image, [0, 255], [-1, 1])
        except tf.errors.OutOfRangeError:  # dataset exhausted (repeat=False)
            break
        train_pred.append(D.run(image, None)[0][0])
        train_lab.append(label)
    print("done train")

    test_pred = []
    test_lab = []

    for x in range(test._np_labels.shape[0] - 1):
        try:
            image, label = test.get_minibatch_np(1)
            image = misc.adjust_dynamic_range(image, [0, 255], [-1, 1])

        except tf.errors.OutOfRangeError:  # dataset exhausted (repeat=False)
            break
        test_pred.append(D.run(image, None)[0][0])
        test_lab.append(label)

    hter = calculate_metrics(train_pred, train_lab, test_pred, test_lab)

    queue.put((weight.split("-")[-1].split(".")[0], hter))
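process_weight_tf is written as a worker: it reports the snapshot's kimg tag and HTER back through the queue, so it is meant to be launched in a separate process per network snapshot. A minimal driver sketch, assuming hypothetical snapshot and tfrecord paths (not part of the original snippet):

import glob
import multiprocessing as mp

if __name__ == '__main__':
    queue = mp.Queue()
    procs = []
    for weight in sorted(glob.glob('results/00000-run/network-snapshot-*.pkl')):  # hypothetical run dir
        p = mp.Process(target=process_weight_tf,
                       args=(weight, 'datasets/train', 'datasets/test', queue))  # hypothetical tfrecord dirs
        p.start()
        procs.append(p)
    results = [queue.get() for _ in procs]  # one (kimg, hter) tuple per worker
    for p in procs:
        p.join()
    print(sorted(results))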
Example #2
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_snapshots):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        print("read image from database:", images.shape)
        import cv2
        cv2.imwrite("test_project_real_images_image_from_db.png", images)
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #3
    def _evaluate(self, classifier, Gs_kwargs, num_gpus):
        self._set_dataset_obj(dataset.load_dataset(tfrecord_dir=self.test_dataset, data_dir=self.test_data_dir, shuffle_mb=1024))
        dataset_object = self._get_dataset_obj()
        dataset_object.configure(minibatch_size=self.minibatch_per_gpu)
        num_correct = 0
        num_total = 0

        images_placeholder = tf.placeholder(shape=classifier.input_shapes[0], dtype=tf.float32)
        label_placeholder = tf.placeholder(shape=[None, dataset_object.label_size], dtype=tf.float32)
        # vgg uses 0-255
        # images_adjust = misc.adjust_dynamic_range(images_placeholder, [0, 255], [-1, 1])
        prediction = classifier.get_output_for(images_placeholder)
        one_hot_prediction = tf.one_hot(indices=tf.argmax(prediction, axis=-1), depth=dataset_object.label_size)
        num_correct_pred = tf.reduce_sum(one_hot_prediction * label_placeholder)

        while num_total < self.num_images:
            images, labels = dataset_object.get_minibatch_np(minibatch_size=self.minibatch_per_gpu)
            num_correct_pred_out = tflib.run(
                num_correct_pred
            , feed_dict={
                images_placeholder: images,
                label_placeholder: labels
            })
            num_correct += num_correct_pred_out
            num_total += self.minibatch_per_gpu

        self._report_result(num_correct / num_total)
Example #4
    def _get_dataset_obj(self):
        if self._dataset_obj is None:
            self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir,
                                                     split=self.split,
                                                     repeat=False,
                                                     **self._dataset_args)
        return self._dataset_obj
Example #5
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_snapshots):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:], \
        "%sexpected shape %s, got %s%s" % (dnnlib.util.Col.RB, Gs.output_shape[1:],
                                           dataset_obj.shape, dnnlib.util.Col.AU)

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #6
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_snapshots):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    imgs = read_images('/gdata2/fengrl/imgs-for-embed')

    for image_idx in range(len(imgs)):
        print('Projecting image %d/%d ...' % (image_idx, len(imgs)))
        # images, _labels = dataset_obj.get_minibatch_np(1)
        images = np.expand_dims(imgs[image_idx], 0)
        # images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #7
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_steps, num_snapshots, save_every_dlatent,
                        save_final_dlatent):
    assert num_snapshots <= num_steps, "Can't have more snapshots than number of steps taken!"
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector(num_steps=num_steps)
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        try:
            images, _labels = dataset_obj.get_minibatch_np(1)
            images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
            project_image(proj,
                          targets=images,
                          png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                              image_idx),
                          num_snapshots=num_snapshots,
                          save_every_dlatent=save_every_dlatent,
                          save_final_dlatent=save_final_dlatent)
        except tf.errors.OutOfRangeError:
            print(
                f'Error! There are only {image_idx} images in {data_dir}{dataset_name}!'
            )
            sys.exit(1)
Example #8
    def _iterate_reals(self, minibatch_size):
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images
Example #9
def project_image(proj, src_file, dst_dir, tmp_dir, video=False):

    data_dir = '%s/dataset' % tmp_dir
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    image_dir = '%s/images' % data_dir
    tfrecord_dir = '%s/tfrecords' % data_dir
    os.makedirs(image_dir, exist_ok=True)
    shutil.copy(src_file, image_dir + '/')
    dataset_tool.create_from_images_raw(tfrecord_dir, image_dir, shuffle=0)
    dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )

    print('Projecting image "%s"...' % os.path.basename(src_file))
    images, _labels = dataset_obj.get_minibatch_np(1)
    images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
    proj.start(images)
    if video:
        video_dir = '%s/video' % tmp_dir
        os.makedirs(video_dir, exist_ok=True)
    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
        proj.step()
        if video:
            filename = '%s/%08d.png' % (video_dir, proj.get_cur_step())
            misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    print('\r%-30s\r' % '', end='', flush=True)

    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.join(dst_dir, os.path.basename(src_file)[:-4] + '.png')
    misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    filename = os.path.join(dst_dir, os.path.basename(src_file)[:-4] + '.npy')
    np.save(filename, proj.get_dlatents()[0])
Example #10
def project_real_images(submit_config, network_pkl, dataset_name, data_dir,
                        num_images, num_snapshots):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.verbose = submit_config.verbose
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    print('dso shape: ' + str(dataset_obj.shape) + ' vs gs shape: ' +
          str(Gs.output_shape[1:]))
    assert dataset_obj.shape == Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #11
def load_dataset_for_previous_run(run_id,
                                  **kwargs):  # => dataset_obj, mirror_augment
    cfg = parse_config_for_previous_run(run_id)
    cfg['dataset'].update(kwargs)
    dataset_obj = dataset.load_dataset(data_dir=config.data_dir,
                                       **cfg['dataset'])
    mirror_augment = cfg['train'].get('mirror_augment', False)
    return dataset_obj, mirror_augment
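A short usage sketch for the helper above; the run id is hypothetical, and extra kwargs override fields of the stored dataset config:

dataset_obj, mirror_augment = load_dataset_for_previous_run(36, shuffle_mb=0)  # run id 36 is hypothetical
images, labels = dataset_obj.get_minibatch_np(4)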
Example #12
def get_images(data_dir, dataset_name, n_images):
    tflib.init_tf()
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    results = [None] * n_images
    for i in range(n_images):
        images, _ = dataset_obj.get_minibatch_np(1)
        results[i] = images[0, ...].transpose(1, 2, 0)
    return results
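A quick usage sketch for get_images: the returned arrays are already HWC, so, assuming an RGB uint8 dataset, they can be saved directly with PIL (the data_dir and dataset name here are assumptions):

import PIL.Image

for i, img in enumerate(get_images('datasets', 'ffhq', n_images=4)):  # hypothetical paths
    PIL.Image.fromarray(img, 'RGB').save('real%02d.png' % i)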
Example #13
def test_d(submit_config,
           resume_run_id,
           dataset_args,
           tf_config={},
           resume_snapshot=None):

    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    latents_1 = tf.placeholder(tf.float32)
    labels_1 = None

    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    w_1 = Gs.components.mapping.get_output_for(latents_1,
                                               labels_1,
                                               is_validation=True)
    fake_image_1_op = Gs.components.synthesis.get_output_for(
        w_1, is_validation=True, randomize_noise=False)

    reals, labels = training_set.get_minibatch_tf()

    lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])

    reals = process_reals(reals, lod_in, False, training_set.dynamic_range,
                          [-1, 1])

    d_pred_real = D.get_output_for(reals, labels_1)
    d_pred_fake = D.get_output_for(fake_image_1_op, labels_1)

    training_set.configure(1, 0)

    for i in range(15):
        latents_1_val = np.random.randn(1, *G.input_shape[1:])

        # d_pred, fake_image_1 = tflib.run([d_pred_op, fake_image_1_op], feed_dict={latents_1: latents_1_val, lod_in: 0})
        d_pred_real_, d_pred_fake_, real_image = tflib.run(
            [d_pred_real, d_pred_fake, reals],
            feed_dict={
                latents_1: latents_1_val,
                lod_in: 0
            })

        print(d_pred_real_, d_pred_fake_)
        misc.save_mri_image(real_image,
                            os.path.join(submit_config.run_dir,
                                         'real_{}.nii.gz'.format(i)),
                            drange=[-1, 1])
Example #14
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
    num_snapshots, num_steps, save_snapshots=False, save_latents=False,
    save_umap=False, save_tiles=False):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)
    proj.num_steps = num_steps

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    latents = np.zeros((num_images, Gs.input_shape[1]), dtype=np.float32)
    tiles = [None] * num_images
    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        tiles[image_idx] = images[0, ...].transpose(1, 2, 0)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        latents[image_idx, ...] = project_image(proj, targets=images,
            png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx),
            num_snapshots=num_snapshots, save_snapshots=save_snapshots)

        if save_latents:
            filename = dnnlib.make_run_dir_path('real_image_latent_{:06d}'.format(image_idx))
            np.save(filename, latents[image_idx, ...])


    if save_latents:
        filename = dnnlib.make_run_dir_path('real_image_latents.npy')
        np.save(filename, latents)

    if save_umap:
        reducer = umap.UMAP()
        embeddings = reducer.fit_transform(latents)
        filename = dnnlib.make_run_dir_path('real_image_umap.json')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(embeddings.tolist(), f, ensure_ascii=False)

    if save_tiles:
        tiles_prefix = dnnlib.make_run_dir_path('real_tile_solid')
        misc.save_texture_grid(tiles, tiles_prefix)

        textures_prefix = dnnlib.make_run_dir_path('real_texture_solid')
        textures = [misc.make_white_square() for _ in range(len(tiles))]
        misc.save_texture_grid(textures, textures_prefix)

        filename = dnnlib.make_run_dir_path('labels.json')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump([0.0] * len(tiles), f, ensure_ascii=False)
Example #15
def show_real_data(data_dir, dataset_name, number):
    tflib.init_tf()
    dataset_args = EasyDict(tfrecord_dir=dataset_name, max_label_size='full')
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    gw = 1
    gh = 1
    for i in range(number):
        reals, _ = training_set.get_minibatch_np(gw * gh)
        misc.save_image_grid(reals,
                             dnnlib.make_run_dir_path('reals%04d.png' % (i)),
                             drange=training_set.dynamic_range,
                             grid_size=None)
Example #16
def project_image(proj, src_file, dst_dir, tmp_dir, video=False):
 
    data_dir = '%s/dataset' % tmp_dir  # ./stylegan2-tmp/dataset
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    image_dir = '%s/images' % data_dir  # ./stylegan2-tmp/dataset/images
    tfrecord_dir = '%s/tfrecords' % data_dir  # ./stylegan2-tmp/dataset/tfrecords
    os.makedirs(image_dir, exist_ok=True)
    # Copy the source image file into ./stylegan2-tmp/dataset/images.
    shutil.copy(src_file, image_dir + '/')
    # Generate temporary tfrecord files under ./stylegan2-tmp/dataset/tfrecords.
    # They serialize the image shape and data at each lod; for a 1024x1024 image
    # the files are named from r10 down to r02: tfrecords-r10.tfrecords ... tfrecords-r05.tfrecords ...
    dataset_tool.create_from_images(tfrecord_dir, image_dir, shuffle=0)
    # The TFRecordDataset class is defined in "dataset.py" and loads a dataset from a set of .tfrecords files.
    # load_dataset is a helper that builds the dataset object (a TFRecordDataset instance).
    dataset_obj = dataset.load_dataset(
        data_dir=data_dir, tfrecord_dir='tfrecords',
        max_label_size=0, repeat=False, shuffle_mb=0
    )
 
    # Produce the target image (batch) for the optimization loop.
    print('Projecting image "%s"...' % os.path.basename(src_file))
    # Fetch the next minibatch of size 1 as a NumPy array.
    images, _labels = dataset_obj.get_minibatch_np(1)
    # Rescale images from [0, 255] to [-1, 1].
    images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
    # Projector initialization: start.
    proj.start(images)
    if video:
        video_dir = '%s/video' % tmp_dir
        os.makedirs(video_dir, exist_ok=True)
    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
        # Projector optimization iteration: step.
        proj.step()
        # If the video option is set, save the in-progress images to ./stylegan2-tmp/video.
        if video:
            filename = '%s/%08d.png' % (video_dir, proj.get_cur_step())
            misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    print('\r%-30s\r' % '', end='', flush=True)
 
    # Save the final image and the dlatents file in the destination directory.
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.join(dst_dir, os.path.basename(src_file)[:-4] + '.png')
    misc.save_image_grid(proj.get_images(), filename, drange=[-1,1])
    filename = os.path.join(dst_dir, os.path.basename(src_file)[:-4] + '.npy')
    np.save(filename, proj.get_dlatents()[0])
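For context, a hedged sketch of how this project_image variant is typically driven (the network URL and file paths are assumptions, following the run_projector.py conventions seen in the other examples):

import dnnlib.tflib as tflib
import pretrained_networks
import projector

tflib.init_tf()
_G, _D, Gs = pretrained_networks.load_networks('gdrive:networks/stylegan2-ffhq-config-f.pkl')
proj = projector.Projector()
proj.set_network(Gs)
project_image(proj, src_file='raw/example.png', dst_dir='out',
              tmp_dir='./stylegan2-tmp', video=True)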
Example #17
def get_projected_real_images(dataset_name, data_dir, num_images, num_snapshots, num_steps, _Gs):
    proj = projector.Projector()
    proj.set_network(_Gs)
    proj.num_steps = num_steps

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    assert dataset_obj.shape == _Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        out = run_projector.get_projected_images(proj, targets=images, num_snapshots=num_snapshots)

    return out
Example #18
def project_real_dataset_images(network_pkl, dataset_name, data_dir, num_images, num_snapshots, create_new_G, new_func_name):
    print('Loading networks from "%s"...' % network_pkl)
    tflib.init_tf()
    # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    _G, _D, I, Gs = misc.load_pkl(network_pkl)
    proj = projector_vc2.ProjectorVC2()
    proj.set_network(Gs, create_new_G, new_func_name)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj, targets=images, I_net=I, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=num_snapshots)
Example #19
def my_project_real_images(num_images, data_dir): 
    network_pkl = 'gdrive:networks/stylegan2-ffhq-config-f.pkl'
    dataset_name = 'dataset'
    #data_dir = 'my' 
    num_snapshots = 5

    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    os.makedirs(data_dir+'/real_images', exist_ok=True)  
    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        
        targets = images
        png_prefix = data_dir + '/real_images/image' + str(image_idx)

        snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
        misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1])
        proj.start(targets)
        while proj.get_cur_step() < proj.num_steps:
            print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
            proj.step()
            if proj.get_cur_step() in snapshot_steps:
                misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1])
            
            if proj.get_cur_step() == proj.num_steps:
                vec = proj.get_dlatents()
                if image_idx == 0:
                    vec_syn = vec
                else:
                    vec_syn = np.concatenate([vec_syn, vec])
                print(vec_syn.shape)
        print('\r%-30s\r' % '', end='', flush=True)

    return vec_syn
Example #20
def project_real_images(network_pkl,
                        dataset_name,
                        data_dir,
                        num_images,
                        num_snapshots,
                        num_steps=1000,  # assumed default for the projector step count used below
                        D_size=0,
                        minibatch_size=1,
                        use_VGG=True):
    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, I, Gs = misc.load_pkl(network_pkl)
    # _G, _D, Gs = misc.load_pkl(network_pkl)
    # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)

    proj = projector_vc.ProjectorVC()
    proj.set_network(Gs,
                     minibatch_size=minibatch_size,
                     D_size=D_size,
                     use_VGG=use_VGG,
                     num_steps=num_steps)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size='full',
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
        print('images.shape:', images.shape)
        print('_labels.shape:', _labels.shape)
        print('_labels:', _labels)
        print('argmax of _labels:', np.argmax(_labels, axis=1))
        # pdb.set_trace()
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #21
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        start_index, num_snapshots, save_vector):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    print('Num images: %d, Starting Index: %d' % (num_images, start_index))
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       verbose=True,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    img_filenames = None
    if dataset_obj._np_filenames is not None:
        assert num_images <= dataset_obj.filenames_size
        img_filenames = dataset_obj._np_filenames

    for image_idx in range(start_index, start_index + num_images):
        filename = img_filenames[
            image_idx] if img_filenames is not None else 'unknown'
        print('Projecting image %d/%d... (index: %d, filename: %s)' %
              (image_idx - start_index, num_images, image_idx, filename))

        images, labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])

        project_image(proj,
                      targets=images,
                      labels=labels,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots,
                      save_npy=save_vector,
                      npy_file_prefix=dnnlib.make_run_dir_path(filename))
        print(
            '✅ Finished projecting image %d/%d... (index: %d, filename: %s)' %
            (image_idx - start_index + 1, num_images, image_idx, filename))
Example #22
def project_real_images(Gs,
                        data_dir,
                        dataset_name,
                        snapshot_name,
                        seq_no,
                        num_snapshots=5):
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s/%s"...' % (data_dir, dataset_name))
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]

    print('Projecting image ...')
    images, _labels = dataset_obj.get_minibatch_np(1)
    images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
    project_image(proj,
                  targets=images,
                  png_prefix=dnnlib.make_run_dir_path('%s/image%04d-' %
                                                      (snapshot_name, seq_no)),
                  num_snapshots=num_snapshots)
    ####################
    # print dlatents
    ####################
    dlatents = proj.get_dlatents()
    # for dlatents1 in dlatents:
    #     for dlatents2 in dlatents1:
    #         str = ''
    #         for e in dlatents2:
    #             str = '{} {}'.format(str, e)
    #         print('###', str)
    # img_name = f'100-100_01'
    # dir = 'results/dst'
    # img_name = '100-100_01.npy'
    # dir = 'results/src'
    # img_name = 'me_01.npy'
    # np.save(os.path.join(dir, img_name), dlatents[0])
    return dlatents[0]
Example #23
def draw_style_mixing_figure_transition(png):

    n = 8
    w = 512
    h = 512
    canvas = PIL.Image.new('RGB', (w * n, h), 'white')
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    image = training_set.get_minibatch_np(1, 0)[0][0].transpose(1, 2, 0)
    res = 8
    for i in range(n):

        pil_image = PIL.Image.fromarray(image, 'RGB')
        pil_image.thumbnail((res, res))  # downsample to res x res
        # thumbnail() never enlarges, so scale back up explicitly for the paste.
        pil_image = pil_image.resize((w, h), PIL.Image.NEAREST)
        canvas.paste(pil_image, (i * w, 0))
        res = res * 2
    canvas.save(png)
    print(png)
Example #24
    def __init__(self):
        super().__init__()
        self.state = -1
        self.canvas = tk.Canvas(self, bg='gray', height=args.window_size, width=args.window_size)
        self.canvas.bind("<Button-1>", self.L_press)
        self.canvas.bind("<ButtonRelease-1>", self.L_release)
        self.canvas.bind("<B1-Motion>", self.L_move)
        self.canvas.bind("<Button-3>", self.R_press)
        self.canvas.bind("<ButtonRelease-3>", self.R_release)
        self.canvas.bind("<B3-Motion>", self.R_move)
        self.canvas.bind("<Key>", self.key_down)
        self.canvas.bind("<KeyRelease>", self.key_up)
        self.canvas.pack()

        self.canvas.focus_set()
        self.canvas_image = self.canvas.create_image(0, 0, anchor='nw')

        dnnlib.tflib.init_tf()
        self.dataset = dataset.load_dataset(tfrecord_dir=args.data_dir, verbose=True, shuffle_mb=0)
        
        self.networks = []
        self.truncations = []
        self.model_names = []
        for ckpt in args.checkpoint.split(','):
            if ':' in ckpt:
                ckpt, truncation = ckpt.split(':')
                truncation = float(truncation)
            else:
                truncation = None
           
            _, _, Gs = misc.load_pkl(ckpt)
            
            self.networks.append(Gs)
            self.truncations.append(truncation)
            self.model_names.append(os.path.basename(os.path.splitext(ckpt)[0]))
        
        self.key_list = ['q', 'w', 'e', 'r', 't', 'y', 'u', 'i', 'o', 'p'][:len(self.networks)]
        self.image_id = -1
        
        self.new_image()
        self.display()
Example #25
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_snapshots):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)
    assert dataset_obj.shape == Gs.output_shape[1:]
    try:
        os.remove('latents.txt')
    except OSError:
        pass
    all_latents = list()
    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx + 1, num_images))
        images, _labels = dataset_obj.get_minibatch_np(1)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        all_latents.append(
            np.mean(project_image(proj,
                                  targets=images,
                                  png_prefix=dnnlib.make_run_dir_path(
                                      'image%04d-' % image_idx),
                                  num_snapshots=num_snapshots),
                    axis=0))
    for j in range(len(all_latents) - 1):
        for i in range(j + 1, len(all_latents)):
            #print(f"Euclid dist between {j} and {i}: {np.linalg.norm(all_latents[j]-all_latents[i])}")
            print(
                f"Dot product between {j} and {i}: {np.dot(all_latents[j],all_latents[i])}"
            )
            print(
                f"Cosine product between {j} and {i}: {spatial.distance.cosine(all_latents[j],all_latents[i])}"
            )
Example #26
def project_images_dataset(proj, dataset_name, data_dir, num_snapshots=2):

    print('Loading images from "%s"...' % dataset_name)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir=dataset_name,
                                       max_label_size=1,
                                       repeat=False,
                                       shuffle_mb=0)

    all_ssim = []
    all_mse = []
    all_times = []
    labels = []
    image_idx = 0
    while True:
        # print('Projecting image %d ...' % (image_idx), flush=True)
        try:
            images, label = dataset_obj.get_minibatch_np(1)
            labels.append(label)
        except tf.errors.OutOfRangeError:  # dataset exhausted (repeat=False)
            break
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        start = time.time()
        img, temp_ssim, temp_mse = project_image(proj,
                                                 targets=images,
                                                 img_num=image_idx,
                                                 num_snapshots=num_snapshots)
        end = time.time()
        all_ssim.append(temp_ssim)
        all_mse.append(temp_mse)
        all_times.append(end - start)

        # print("Time to process image: ", end-start , flush=True)
        avg_time = sum(all_times) / len(all_times)
        image_idx += 1
        break
    return all_ssim, all_mse, labels, avg_time
Example #27
def project_real_images(network_pkl, dataset_name, data_dir, num_images,
                        num_snapshots, input_images):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    proj = projector.Projector()
    proj.set_network(Gs)

    if input_images is None:
        print('Loading images from "%s"...' % dataset_name)
        dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                           tfrecord_dir=dataset_name,
                                           max_label_size=0,
                                           repeat=False,
                                           shuffle_mb=0)
        assert dataset_obj.shape == Gs.output_shape[1:]
    else:
        num_images = min(num_images, len(input_images))

    for image_idx in range(num_images):
        print('Projecting image %d/%d ...' % (image_idx, num_images))
        if input_images is None:
            images, _labels = dataset_obj.get_minibatch_np(1)
        else:
            # images = None
            image_path = input_images[image_idx]
            print(image_path)
            images = cv2.imread(image_path, cv2.IMREAD_COLOR)
            images = images.transpose((2, 0, 1))  # HWC -> CHW
            images = images[::-1, :, :]  # BGR -> RGB
            images = np.expand_dims(images, axis=0)
        images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
        project_image(proj,
                      targets=images,
                      png_prefix=dnnlib.make_run_dir_path('image%04d-' %
                                                          image_idx),
                      num_snapshots=num_snapshots)
Example #28
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            #print('Constructing networks...')
            #G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            #D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            #Gs = G.clone('Gs')
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)

    cmd = "gsutil cp " + os.path.join(submit_config.run_dir, 'fakes%06d.png' %
                                      resume_kimg) + "  gs://stylegan_out"
    response = subprocess.run(cmd, shell=True)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                cmd = "gsutil cp " + os.path.join(
                    submit_config.run_dir, 'fakes%06d.png' %
                    (cur_nimg // 1000)) + "  gs://stylegan_out"
                response = subprocess.run(cmd, shell=True)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
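This training_loop is designed to be launched through dnnlib's submit machinery rather than called directly; the option dicts in its signature are assembled by a small launcher. A sketch in the style of NVIDIA's train.py, with illustrative values (the run_func_name, dataset name, and scheduling choices are assumptions):

import copy
import dnnlib

train = dnnlib.EasyDict(run_func_name='training.training_loop.training_loop')  # assumed module path
train.update(total_kimg=15000, mirror_augment=False)
dataset_args = dnnlib.EasyDict(tfrecord_dir='ffhq')  # hypothetical dataset
sched_args = dnnlib.EasyDict()
grid_args = dnnlib.EasyDict(size='4k', layout='random')

sc = dnnlib.SubmitConfig()
sc.num_gpus = 1
sc.run_dir_root = 'results'
sc.run_desc = 'sgan-ffhq-1gpu'

kwargs = dnnlib.EasyDict(train)
kwargs.update(dataset_args=dataset_args, sched_args=sched_args, grid_args=grid_args)
kwargs.submit_config = copy.deepcopy(sc)
dnnlib.submit_run(**kwargs)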
Example #29
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,   # Model name 
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,  # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest',     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,   # Filename of network for partial restore
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # custom resolution - for saved model name below
    resolution = training_set.resolution
    if training_set.init_res != [4,4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)
    
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print(' Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            if resume_pkl == 'latest':
                resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
            elif resume_pkl == 'restore_partial':
                print(' Restore partially...')
                # Initialize networks
                G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
                D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
                Gs = G.clone('Gs')
                # Load pre-trained networks
                assert restore_partial_fn is not None
                G_partial, D_partial, Gs_partial = pickle.load(open(restore_partial_fn, 'rb'))
                # Restore (subset of) pre-trained weights (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)
            else:
                if resume_pkl is not None and resume_kimg == 0:
                    resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
                print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
                rG, rD, rGs = misc.load_pkl(resume_pkl)
                if resume_with_new_nets:
                    G.copy_vars_from(rG)
                    D.copy_vars_from(rD)
                    Gs.copy_vars_from(rGs)
                else:
                    G, D, Gs = rG, rD, rGs
                
    # Print layers if needed and generate initial image snapshot
    # G.print_layers(); D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
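        # minibatch_multiplier = number of gradient-accumulation rounds per update;
        # Gs_beta = per-update EMA decay for Gs, chosen so that old G weights
        # decay by half every G_smoothing_kimg thousand images.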
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
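        # Lazy regularization runs the reg term only every reg_interval minibatches,
        # so compensate by scaling the learning rate and Adam betas by
        # reg_interval / (reg_interval + 1).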
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
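                # Stage the fetched data through fixed-size variables: only the first
                # minibatch_gpu_in entries are refreshed each step, so the live
                # minibatch can change with the schedule without rebuilding the graph.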
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
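                # The regularizer runs only every reg_interval minibatches, so its
                # gradient is scaled up by the interval to keep its effective strength.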
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
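    # Gs holds an exponential moving average of G's weights (decay Gs_beta above);
    # sample grids and snapshots below are generated from Gs, not G.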

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Initialize logs.
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
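        # The total minibatch must split evenly across GPUs; the quotient is the
        # number of gradient-accumulation rounds per parameter update.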
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu) # , sched.lod
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
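            # With lazy regularization, the reg ops fire only on every
            # G_reg_interval-th / D_reg_interval-th minibatch.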
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

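            # Estimate the remaining wall-clock time; this is only meaningful once
            # progressive growing has finished (lod == 0) and timing has stabilized.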
            if sched.lod == 0:
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % (
            print('tick %-4d kimg %-6.1f time %-8s  %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                # autosummary('Progress/lod', sched.lod),
                # autosummary('Progress/minibatch', sched.minibatch_size),
                # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                # autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                misc.save_pkl(Gs, dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000)))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str)))
    misc.save_pkl(Gs, dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str)))

    # All done.
    summary_log.close()
    training_set.close()
Example #30
    def _get_dataset_obj(self):
        if self._dataset_obj is None:
            self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir, **self._dataset_args)
        return self._dataset_obj
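
The snippets above all reduce to the same call pattern. Below is a minimal, self-contained sketch of loading a TFRecord dataset and pulling one minibatch, assuming the StyleGAN2-style training.dataset module these examples come from; the datasets/mydata directory names are hypothetical placeholders.

# Minimal usage sketch (assumption: StyleGAN2-style repo layout;
# 'datasets/mydata' is a hypothetical TFRecord directory).
import dnnlib.tflib as tflib
from training import dataset

tflib.init_tf()

dataset_obj = dataset.load_dataset(
    data_dir='datasets',    # parent directory holding TFRecord folders
    tfrecord_dir='mydata',  # folder containing the *.tfrecords files
    max_label_size=0,       # 0 = ignore labels
    repeat=False,           # single pass over the data
    shuffle_mb=0)           # 0 = no shuffle buffer

dataset_obj.configure(minibatch_size=4)
images, labels = dataset_obj.get_minibatch_np(4)  # NCHW uint8 in [0, 255]
print(images.shape, labels.shape)
dataset_obj.close()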