Пример #1
0
    def __init__(self, filename=None, model=None):
        """Load a generator network from *filename* and prepare Theano
        expressions for image generation.

        Args:
            filename: path to a pickled snapshot; misc.load_pkl returns a
                3-tuple whose third element is the generator G.
            model: if given, initialization is skipped entirely.
                NOTE(review): the original marked this branch '# error'
                and silently returned — confirm callers expect a no-op.
        """
        if model is not None:
            # Silently skip initialization when a model is supplied.
            return

        _, _, G = misc.load_pkl(filename)
        self.net = Net()
        self.net.G = G
        self.have_compiled = False

        # experiment
        num_example_latents = 10
        self.net.example_latents = train.random_latents(
            num_example_latents, self.net.G.input_shape)
        self.net.example_labels = self.net.example_latents
        self.net.latents_var = T.TensorType(
            'float32',
            [False] * len(self.net.example_latents.shape))('latents_var')
        # NOTE: the original also assigned labels_var earlier with a fixed
        # 512-dim broadcast pattern; that assignment was dead (it was
        # unconditionally overwritten here) and has been removed.
        self.net.labels_var = T.TensorType(
            'float32',
            [False] * len(self.net.example_latents.shape))('labels_var')

        self.net.images_expr = self.net.G.eval(self.net.latents_var,
                                               self.net.labels_var,
                                               ignore_unused_inputs=True)
        # Map generator output from [-1, 1] into [0, 1] for display.
        self.net.images_expr = misc.adjust_dynamic_range(
            self.net.images_expr, [-1, 1], [0, 1])
        train.imgapi_compile_gen_fn(self.net)

        self.invert_models = def_invert_models(self.net,
                                               layer='conv4',
                                               alpha=0.002)
Пример #2
0
def imgapi_load_net(run_id, snapshot=None, random_seed=1000, num_example_latents=1000, load_dataset=True, compile_gen_fn=True):
    """Locate a training run, load its generator pickle, and build the
    Theano expressions needed to turn latents into images."""
    class Net: pass
    net = Net()

    # Resolve the run directory and the network snapshot inside it.
    net.result_subdir = misc.locate_result_subdir(run_id)
    net.network_pkl = misc.locate_network_pkl(net.result_subdir, snapshot)
    _, _, net.G = misc.load_pkl(net.network_pkl)

    # Deterministic example latents; labels are empty (width-0 array).
    np.random.seed(random_seed)
    net.example_latents = random_latents(num_example_latents, net.G.input_shape)
    net.example_labels = np.zeros((num_example_latents, 0), dtype=np.float32)
    net.dynamic_range = [0, 255]
    if load_dataset:
        imgapi_load_dataset(net)

    # Symbolic inputs whose broadcast patterns match the example arrays.
    def _symbolic_input(example, name):
        return T.TensorType('float32', [False] * len(example.shape))(name)
    net.latents_var = _symbolic_input(net.example_latents, 'latents_var')
    net.labels_var = _symbolic_input(net.example_labels, 'labels_var')

    # Progressive-growing networks expose cur_lod; pin evaluation to it.
    if hasattr(net.G, 'cur_lod'):
        net.lod = net.G.cur_lod.get_value()
        net.images_expr = net.G.eval(net.latents_var, net.labels_var,
                                     min_lod=net.lod, max_lod=net.lod,
                                     ignore_unused_inputs=True)
    else:
        net.lod = 0.0
        net.images_expr = net.G.eval(net.latents_var, net.labels_var,
                                     ignore_unused_inputs=True)

    # Rescale generator output from [-1, 1] to the display range.
    net.images_expr = misc.adjust_dynamic_range(net.images_expr, [-1, 1], net.dynamic_range)
    if compile_gen_fn:
        imgapi_compile_gen_fn(net)
    return net
def generate(network_pkl, out_dir):
    """Sample 1000 images from the generator in *network_pkl* and write
    them to *out_dir* as PNGs (plus a .npy dump for >3-channel images)."""
    if os.path.exists(out_dir):
        raise ValueError('{} already exists'.format(out_dir))
    misc.init_output_logging()
    np.random.seed(config.random_seed)
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        G, D, Gs = misc.load_pkl(network_pkl)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    number_of_images = 1000
    # All-zero labels: the model is driven purely by random latents here.
    grid_labels = np.zeros([number_of_images, training_set.label_size],
                           dtype=training_set.label_dtype)
    grid_latents = misc.random_latents(number_of_images, G)
    total_kimg = config.train.total_kimg
    sched = train.TrainingSchedule(total_kimg * 1000, training_set,
                                   **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)
    os.makedirs(out_dir)
    for idx, sample in enumerate(grid_fakes):
        # CHW -> HWC, squash [-1, 1] into [0, 1], then quantize to uint8.
        hwc = sample.transpose((1, 2, 0))
        hwc = (1 + np.clip(hwc, -1, 1)) / 2
        hwc = skimage.img_as_ubyte(hwc)
        imageio.imwrite(os.path.join(out_dir, '{}.png'.format(idx)),
                        hwc[..., :3])
        # Keep any extra channels in a raw .npy sidecar.
        if hwc.shape[-1] > 3:
            np.save(os.path.join(out_dir, '{}.npy'.format(idx)), hwc)
Пример #4
0
def read_epoch(fname):
    """Load an Epoched object from a .pkl or .mat file.

    Args:
        fname: path whose name contains '.pkl' or '.mat'.

    Returns:
        The unpickled object for .pkl files, or an Epoched instance
        rebuilt from the 'epoched' struct of a MATLAB file.

    Raises:
        ValueError: if *fname* matches neither extension (the original
            silently fell through and returned None).
    """
    if '.pkl' in fname:
        return load_pkl(fname)
    elif '.mat' in fname:
        # Import lazily: Epoched is only needed on the MATLAB path.
        from epoch import Epoched
        matfile = spio.loadmat(fname, squeeze_me=True, struct_as_record=False)
        mat_epoched = matfile['epoched']
        # Build a dummy Epoched, then copy over the MATLAB struct's fields.
        epoched = Epoched(0, (0, 0), 0)
        epoched.__dict__.update(mat_epoched.__dict__)
        return epoched
    raise ValueError('Unsupported epoch file (expected .pkl or .mat): {}'.format(fname))
Пример #5
0
def hello():
    """Run one throwaway generator sample (presumably a pipeline warm-up
    — TODO confirm), then return the endpoint banner string."""
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        G, D, Gs = misc.load_pkl(resume_network_pkl)

    # Output resolution of the generator; currently unused downstream.
    output_res = Gs.output_shape[-1]
    texture_latent = misc.random_latents(1, Gs)
    shape_mask = get_random_mask(1)
    color_vec = get_random_color(1)
    # The result is intentionally discarded.
    fake_images = Gs.run(texture_latent, color_vec, shape_mask)

    return "DCGAN endpoint -> /predict "
Пример #6
0
    def __init__(self):
        """Resolve LCFCN checkpoint and point-annotation paths based on
        the concrete dataset subclass name."""
        dataset_name = type(self).__name__
        base = "/mnt/projects/counting/Saves/main/"

        if "Pascal" in dataset_name:
            self.lcfcn_path = base + "dataset:Pascal2007_model:Res50FCN_metric:mRMSE_loss:water_loss_B_config:basic/"

        elif "CityScapes" in dataset_name:
            self.lcfcn_path = base + "dataset:CityScapes_model:Res50FCN_metric:mRMSE_loss:water_loss_B_config:basic/"

        elif "CocoDetection2014" in dataset_name:
            self.lcfcn_path = base + "dataset:CocoDetection2014_model:Res50FCN_metric:mRMSE_loss:water_loss_B_config:sample3000/"

        elif "Kitti" in dataset_name:
            self.lcfcn_path = base + "dataset:Kitti_model:Res50FCN_metric:mRMSE_loss:water_loss_B_config:basic/"
            self.proposals_path = "/mnt/datasets/public/issam/kitti/ProposalsSharp/"

        elif "Plants" in dataset_name:
            self.lcfcn_path = base + "dataset:Plants_model:Res50FCN_metric:mRMSE_loss:water_loss_B_config:basic/"

        else:
            # BUG FIX: the original used a bare `raise` with no active
            # exception, which surfaces as "RuntimeError: No active
            # exception to re-raise"; raise an explicit error instead.
            raise ValueError("Unknown dataset: {}".format(dataset_name))

        fname = base + "lcfcn_points/{}.pkl".format(dataset_name)

        if os.path.exists(fname):
            history = ms.load_pkl(self.lcfcn_path + "history.pkl")
            self.pointDict = ms.load_pkl(fname)

            # NOTE(review): `reset` is a dead local — this comparison has
            # no effect.  Kept for fidelity; confirm whether a stale point
            # file should actually trigger a reset here.
            if self.pointDict["best_model"]["epoch"] != history["best_model"][
                    "epoch"]:
                reset = "reset"
        else:
            if dataset_name == "PascalSmall":
                # PascalSmall reuses the Pascal2012 point annotations.
                self.pointDict = ms.load_pkl(
                    fname.replace("PascalSmall", "Pascal2012"))
            else:
                # NOTE(review): leftover debugging breakpoint — a missing
                # point file is otherwise unhandled; replace with a real
                # error before shipping.
                import ipdb
                ipdb.set_trace()  # breakpoint 5f76e230 //
Пример #7
0
def predict():
    """Flask endpoint: render a fake image from an uploaded grayscale
    mask, an RGB color from the form, and either a freshly drawn or a
    previously cached texture latent.  Returns base64 PNG JSON."""
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        G, D, Gs = misc.load_pkl(resume_network_pkl)
    imsize = Gs.output_shape[-1]

    # Build a (1, 1, H, W) mask in [-1, 1] from the uploaded image.
    uploaded = Image.open(request.files['image']).convert('L')
    uploaded = uploaded.resize((imsize, imsize))
    scaled = (np.float32(uploaded) - 127.5) / 127.5
    masks = np.vstack([scaled.reshape((1, 1, imsize, imsize))])

    # RGB color taken from the three form fields.
    colors = np.array([[float(request.form[ch]) for ch in ('R', 'G', 'B')]],
                      dtype=object)

    # Either draw a new texture latent and cache it, or reuse a cached one.
    if request.form['texflag'] == "true":
        selected_textures = misc.random_latents(1, Gs)
        texture_list.append(selected_textures[0])
        texid = len(texture_list) - 1
    else:
        texid = int(request.form['texid'])
        selected_textures = np.array([texture_list[texid]], dtype=object)

    fake_images = convert_to_image(Gs.run(selected_textures, colors, masks))
    matplotlib.image.imsave('localtemp.png', fake_images[0])

    # Round-trip through the PNG on disk and base64-encode the bytes.
    buffered = io.BytesIO()
    Image.open('localtemp.png').save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    # str(bytes) looks like b'...'; [2:-1] strips the b'' wrapper.
    return jsonify({
        "image": str(img_str)[2:-1],
        "id": texid
    })
Пример #8
0
    def run(self,
            network_pkl,
            run_dir=None,
            dataset_args=None,
            mirror_augment=None,
            num_gpus=1,
            tf_config=None,
            log_results=True):
        """Evaluate this metric on the network stored in *network_pkl*,
        timing the evaluation and optionally writing the result string to
        the run directory's metric log (or stdout)."""
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        # Recover dataset / augmentation settings from the previous run's
        # config when the caller did not supply them explicitly.
        config_missing = dataset_args is None or mirror_augment is None
        if config_missing and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0
            self._mirror_augment = run_config['train'].get(
                'mirror_augment', False)

        # Time the actual evaluation inside a fresh graph/session.
        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():  # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is None:
                print(result_str)
            else:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
Пример #9
0
import os
import sys
import time
import glob
import shutil
import operator
import theano
import lasagne 
import dataset
import network
from theano import tensor as T
import config
import misc
import numpy as np
import scipy.ndimage
# Load a pickled network snapshot; misc.load_pkl returns a 3-tuple and
# only the third element (the generator G) is kept here.
_, _, G = misc.load_pkl("network-snapshot-009041.pkl")

class Net: pass

net = Net()
net.G = G

import train

# Build a small batch of example latents plus symbolic Theano inputs
# whose broadcast patterns match them.  Labels simply alias the latents
# here (presumably the model ignores them -- TODO confirm); both
# TensorType patterns are therefore derived from example_latents.
num_example_latents = 10
net.example_latents = train.random_latents(num_example_latents, net.G.input_shape)
net.example_labels = net.example_latents
net.latents_var = T.TensorType('float32', [False] * len(net.example_latents.shape))('latents_var')
net.labels_var  = T.TensorType('float32', [False] * len(net.example_latents.shape)) ('labels_var')

print("HIYA", net.example_latents[:1].shape, net.example_labels[:1].shape)
Пример #10
0
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Main progressive-growing GAN training loop.

    Builds (or resumes) the G/D networks, wires up per-GPU training ops,
    then alternates D and G updates until ``total_kimg * 1000`` images
    have been shown, periodically writing image grids, autosummaries, and
    network snapshot pickles into a fresh result subdirectory.

    This variant additionally saves cosine-modulated ("shifted") copies
    of the real and fake image grids (see the kernel_cos code below).
    All configuration beyond the explicit arguments comes from the global
    ``config`` module.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    #resume_run_id = '/dresden/users/mk1391/evl/pggan_logs/logs_celeba128cc/fsg16_results_0/000-pgan-celeba-preset-v2-2gpus-fp32/network-snapshot-010211.pkl'
    # When True, resuming copies variables into freshly constructed nets
    # instead of adopting the loaded net objects directly.
    resume_with_new_nets = False
    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            rG, rD, rGs = misc.load_pkl(network_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    ### pyramid draw fsg (comment out for actual training to happen)
    #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'pggan_fsg_draw.png'))
    #print('>>> done printing fsgs.')
    #return

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    # One loss tower per GPU; GPU 0 uses the canonical nets, the rest use
    # shadow clones whose 'lod' variables are kept in sync via lod_in.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    ### shift reals
    print('>>> reals shape: ', grid_reals.shape)
    # NOTE(review): kernel_cos/kernel_sin form a spatial cosine/sine
    # carrier at frequencies (fc_x, fc_y); kernel_sin is computed but
    # never used below.
    fc_x = 0.5
    fc_y = 0.5
    im_size = grid_reals.shape[-1]
    kernel_loc = 2.*np.pi*fc_x * np.arange(im_size).reshape((1, 1, im_size)) + \
        2.*np.pi*fc_y * np.arange(im_size).reshape((1, im_size, 1))
    kernel_cos = np.cos(kernel_loc)
    kernel_sin = np.sin(kernel_loc)
    # Modulate the reals by the cosine carrier in [-1, 1] space, then map
    # back to uint8 [0, 255].
    reals_t = (grid_reals / 255.) * 2. - 1
    reals_t *= kernel_cos
    grid_reals_sh = np.rint(
        (reals_t + 1.) * 255. / 2.).clip(0, 255).astype(np.uint8)
    ### end shift reals
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    ### fft drawing
    #sys.path.insert(1, '/home/mahyar/CV_Res/ganist')
    #from fig_draw import apply_fft_win
    #data_size = 1000
    #latents = np.random.randn(data_size, *Gs.input_shapes[0][1:])
    #labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])
    #g_samples = Gs.run(latents, labels, minibatch_size=sched.minibatch//config.num_gpus)
    #g_samples = g_samples.transpose(0, 2, 3, 1)
    #print('>>> g_samples shape: {}'.format(g_samples.shape))
    #apply_fft_win(g_samples, 'fft_pggan_hann.png')
    ### end fft drawing

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    ### drawing shifted real images
    misc.save_image_grid(grid_reals_sh,
                         os.path.join(result_subdir, 'reals_sh.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    ### drawing shifted fake images
    misc.save_image_grid(grid_fakes * kernel_cos,
                         os.path.join(result_subdir, 'fakes%06d_sh.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    #### True cosine fft eval
    #fft_data_size = 1000
    #im_size = training_set.shape[1]
    #freq_centers = [(64/128., 64/128.)]
    #true_samples = sample_true(training_set, fft_data_size, dtype=training_set.dtype, batch_size=32).transpose(0, 2, 3, 1) / 255. * 2. - 1.
    #true_fft, true_fft_hann, true_hist = cosine_eval(true_samples, 'true', freq_centers, log_dir=result_subdir)
    #fractal_eval(true_samples, f'koch_snowflake_true', result_subdir)

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        # Reset Adam moments whenever the integer part of lod changes,
        # i.e. when a new resolution layer starts fading in or out.
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps (with the Gs
        # moving average updated alongside) per single generator step.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                ### drawing shifted fake images
                misc.save_image_grid(
                    grid_fakes * kernel_cos,
                    os.path.join(result_subdir,
                                 'fakes%06d_sh.png' % (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size)
                ### drawing fsg
                #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'fakes%06d_fsg_draw.png' % (cur_nimg // 1000)))
                ### Gen fft eval
                #gen_samples = sample_gen(Gs, fft_data_size).transpose(0, 2, 3, 1)
                #print(f'>>> fake_samples: max={np.amax(grid_fakes)} min={np.amin(grid_fakes)}')
                #print(f'>>> gen_samples: max={np.amax(gen_samples)} min={np.amin(gen_samples)}')
                #misc.save_image_grid(gen_samples[:25], os.path.join(result_subdir, 'fakes%06d_gsample.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                #cosine_eval(gen_samples, f'gen_{cur_nimg//1000:06d}', freq_centers, log_dir=result_subdir, true_fft=true_fft, true_fft_hann=true_fft_hann, true_hist=true_hist)
                #fractal_eval(gen_samples, f'koch_snowflake_fakes{cur_nimg//1000:06d}', result_subdir)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Пример #11
0
    def _evaluate(self, Gs, num_gpus):
        """Compute the Frechet Inception Distance between real images
        (statistics cached on disk) and samples drawn from *Gs*.

        Args:
            Gs: generator network; cloned once per GPU.
            num_gpus: number of GPUs to spread the minibatches over.
        """
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            './inception_v3_features.pkl')  # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals (cached across runs on disk).
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end - begin],
                                                       num_gpus=num_gpus,
                                                       assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph: one generator/inception clone per GPU.
        # different from stylegan
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                # beholdergan: all-zero conditioning labels.
                labels = tf.constant(
                    np.zeros([self.minibatch_per_gpu, labelsize],
                             dtype=np.float32))

                # BUG FIX: the original left every get_output_for variant
                # commented out, so `images` was undefined (NameError) at
                # the convert_images_to_uint8 call.  Use the conditional
                # form matching the `labels` tensor built above.
                images = Gs_clone.get_output_for(latents, labels)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID: ||mu_f-mu_r||^2 + Tr(S_f + S_r - 2*sqrt(S_f S_r)).
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
Пример #12
0
    def __getitem__(self, index):
        """Return one COCO sample: image, point annotations, per-class
        counts, instance/class masks, and sharp-proposal metadata.

        For the training split the masks are zeroed out (only point
        supervision is used); for other splits the full masks are kept.
        """
        name_id = self.ids[index]
        name = self.image_names[index]

        image = np.array(
            Image.open(self.path +
                       "/{}/{}".format(self.split +
                                       self.year, name)).convert('RGB'))
        points = np.zeros(image.shape[:2], "uint8")[:, :, None]
        counts = np.zeros(80)
        maskVoid = np.zeros(points.shape[:2])
        annList = ms.load_pkl(
            self.path + "/groundtruth/{}_{}.pkl".format(self.split, name))

        h, w, _ = image.shape
        maskClasses = np.zeros((h, w), int)
        maskObjects = np.zeros((h, w), int)
        for obj_id, ann in enumerate(annList):
            mask = maskUtils.decode(COCO.annToRLE_issam(h, w, ann))
            if ann["iscrowd"]:
                # Crowd regions are accumulated into the void mask.
                maskVoid += mask
            else:
                # Place the point annotation at the interior pixel farthest
                # from the object boundary (max of the distance transform).
                dist = distance_transform_edt(mask)
                yx = np.unravel_index(dist.argmax(), dist.shape)

                label = self.category2label[int(ann["category_id"])]
                points[yx] = label
                counts[label - 1] += 1

                assert mask.max() <= 1
                mask_ind = mask == 1

                maskObjects[mask_ind] = obj_id + 1
                maskClasses[mask_ind] = label

        maskVoid = maskVoid.clip(0, 1)
        assert maskVoid.max() <= 1
        counts = torch.LongTensor(counts)

        image, points, maskObjects, maskClasses = self.transform_function(
            [image, points, maskObjects, maskClasses])

        # Sharp Proposals
        image_id = int(name[:-4].split("_")[-1])
        SharpProposals_name = self.proposals_path + "{}".format(image_id)
        lcfcn_pointList = self.get_lcfcn_pointList(str(image_id))
        assert image_id == name_id

        # The original duplicated the full 14-key dict in both branches,
        # differing only in whether the masks were zeroed for training.
        is_train = self.split == "train"
        return {
            "images": image,
            "points": points.squeeze(),
            "SharpProposals_name": str(name_id),
            "counts": counts,
            "index": index,
            "name": str(name_id),
            "image_id": str(image_id),
            "maskObjects": maskObjects * 0 if is_train else maskObjects,
            "maskClasses": maskClasses * 0 if is_train else maskClasses,
            "proposals_path": self.proposals_path,
            "dataset": "coco2014",
            "lcfcn_pointList": lcfcn_pointList,
            "split": self.split
        }
Пример #13
0
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        #lod_training_kimg       = 40,
        #lod_transition_kimg     = 40,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        #tick_kimg_default       = 1,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=4,
        network_snapshot_ticks=40,
        image_grid_type='default',
        #resume_network_pkl      = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
        resume_network_pkl=None,
        resume_kimg=0,
        resume_time=0.0):
    """Train a progressive-growing GAN end to end (Theano/Lasagne, Python 2).

    Loads the dataset and G/D networks (or resumes from a pkl), compiles the
    Theano generation and training functions, then runs the main training loop:
    the level of detail (LOD) is gradually decreased so the effective output
    resolution grows over time, and image/network snapshots are saved once per
    "tick".

    Durations are measured in kimg (thousands of real images shown); several
    defaults are divided by the module-level `speed_factor` to scale down the
    schedule.

    Key parameters:
        D_training_repeats: how many D minibatches per G update.
        minibatch_overrides / tick_kimg_overrides: per-resolution overrides
            keyed by resolution.
        lod_*: schedule controlling when resolution doublings happen.
        gdrop_*: feedback loop that fades in discriminator noise when D gets
            too confident (fake score approaches 1).
        resume_*: restart bookkeeping so reporting/schedule continue smoothly.

    NOTE(review): several defaults are mutable ({} / lists); they are only read,
    never mutated here, but confirm before relying on that elsewhere.

    Side effects: creates a result subdir, writes 'reals.png', periodic
    'fakes%06d.png' grids, 'network-snapshot-%06d.pkl' files, a final
    'network-final.pkl', and an empty '_training-done.txt' marker.
    """

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()

    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # training_set is the object produced by the dataset module after parsing
    # the h5 file; drange_orig is training_set.get_dynamic_range().

    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    # Gs tracks an exponential moving average of G's weights (used for eval).
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)

    # G and D can either be loaded from a pkl via misc, or freshly constructed
    # by the network module.

    print G
    print D

    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configure the layout used for intermediate snapshot images.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var),real_images_var
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # All names ending in _var are symbolic input tensors.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # theano.shared() sets an initial value and returns a mutable shared
    # variable, so the learning rates can be updated without recompiling.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # gen_fn is a compiled function with inputs [fake_latents_var, fake_labels_var]
    # and output Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True).

    # Generation function (uses the smoothed Gs, not the raw G).

    # Misc init.
    # Read the current output resolution.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # lod = level of detail; higher lod means coarser resolution.
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Deliberately invalid sentinels so the first loop iteration always
    # triggers compilation of the training functions.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point.
    # Note: execution never leaves the outer while loop, so resolution changes
    # and similar transitions all happen inside it.

    # Number of real images shown so far.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0

    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # Compute the current level of detail from the schedule.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                # Fractional part implements the smooth fade-in between
                # resolutions during the transition phase.
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        # Current resolution.
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        # Update learning rates and LOD hyper-parameters.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))

        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))

        #print " cur_lod%f\n  min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n"%(cur_lod,min_lod,new_min_lod,max_lod,new_max_lod)

        # Recompile the training functions only when the integer LOD bracket
        # changes (compilation is expensive).
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_train_fn = theano.function([
                real_images_var, real_labels_var, fake_latents_var,
                fake_labels_var
            ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                         updates=D_updates,
                                         on_unused_input='ignore')
            # G's step also applies Gs.updates so the smoothed copy tracks G.
            G_train_fn = theano.function([fake_latents_var, fake_labels_var],
                                         [],
                                         updates=G_updates + Gs.updates,
                                         on_unused_input='ignore')

        # Train D for D_training_repeats minibatches, then G once.
        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)

            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels

            mb_train_out = D_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Average each reported quantity over all minibatches in the tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Touch a marker file so external tooling can detect a completed run.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def train_progressive_gan(
    G_smoothing             = 0.999,        # Exponential running average of generator weights.
    D_repeats               = 1,            # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,         # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,        # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,        # Enable mirror augment?
    drange_net              = [-1,1],       # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,            # How often to export image snapshots?
    network_snapshot_ticks  = 10,           # How often to export network snapshots?
    save_tf_graph           = False,        # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,        # Include weight histograms in the tfevents file?
    resume_run_id           = None,         # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,         # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,          # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):         # Assumed wallclock time at the beginning. Affects reporting.
    """Train a progressive-growing GAN (TensorFlow 1.x implementation).

    Builds (or resumes) the G/D networks plus an exponentially smoothed copy
    Gs, constructs a multi-GPU training graph driven by `config.*`, then runs
    the training loop where `TrainingSchedule` controls minibatch size,
    learning rates and the level of detail (lod) as resolution grows.
    Progress is reported via `tfutil.autosummary` tfevents summaries, and
    image/network snapshots are saved once per tick.

    Side effects: creates a result subdir, writes 'reals.png', periodic
    'fakes%06d.png' grids and 'network-snapshot-%06d.pkl' files, a tfevents
    summary log, a final 'network-final.pkl', and an empty
    '_training-done.txt' marker.
    """

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G)
            D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D)
            Gs = G.clone('Gs')
        # Gs tracks a moving average of G's weights; used for snapshots/eval.
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels   = training_set.get_minibatch_tf()
        # Split each minibatch evenly across the available GPUs.
        reals_split     = tf.split(reals, config.num_gpus)
        labels_split    = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Control dependencies below ensure lod is assigned before each
            # loss evaluation.
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)  # number of real images shown so far
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0  # invalid sentinel: first iteration always differs
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # Clear Adam moments when the integer lod bracket changes, i.e.
            # when new layers are faded in.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                # D step also updates the Gs moving average.
                tfutil.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tfutil.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (
                tfutil.autosummary('Progress/tick', cur_tick),
                tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                tfutil.autosummary('Progress/lod', sched.lod),
                tfutil.autosummary('Progress/minibatch', sched.minibatch),
                misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
                tfutil.autosummary('Timing/sec_per_tick', tick_time),
                tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                tfutil.autosummary('Timing/maintenance_sec', maintenance_time)))
            tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # Touch a marker file so external tooling can detect a completed run.
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
# Example #15
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('-e', '--exp')
    parser.add_argument('-b', '--borgy', default=0, type=int)
    parser.add_argument('-br', '--borgy_running', default=0, type=int)
    parser.add_argument('-m', '--mode', default="summary")
    parser.add_argument('-r', '--reset', default="None")
    parser.add_argument('-s', '--status', type=int, default=0)
    parser.add_argument('-k', '--kill', type=int, default=0)
    parser.add_argument('-g', '--gpu', type=int)
    parser.add_argument('-c', '--configList', nargs="+", default=None)
    parser.add_argument('-l', '--lossList', nargs="+", default=None)
    parser.add_argument('-d', '--datasetList', nargs="+", default=None)
    parser.add_argument('-metric', '--metricList', nargs="+", default=None)
    parser.add_argument('-model', '--modelList', nargs="+", default=None)
    parser.add_argument('-p', '--predictList', nargs="+", default=None)

    args = parser.parse_args()

    if args.borgy or args.kill:
        global_prompt = input("Do all? \n(y/n)\n")

    # SEE IF CUDA IS AVAILABLE
    assert torch.cuda.is_available()
    print("CUDA: %s" % torch.version.cuda)
    print("Pytroch: %s" % torch.__version__)

    mode = args.mode
    exp_name = args.exp

    exp_dict = experiments.get_experiment_dict(args, exp_name)

    pp_main = None
    results = {}

    # Get Main Class
    project_name = os.path.realpath(__file__).split("/")[-2]
    MC = ms.MainClass(path_models="models",
                      path_datasets="datasets",
                      path_metrics="metrics/metrics.py",
                      path_losses="losses/losses.py",
                      path_samplers="addons/samplers.py",
                      path_transforms="addons/transforms.py",
                      path_saves="/mnt/projects/counting/Saves/main/",
                      project=project_name)

    key_set = set()
    for model_name, config_name, metric_name, dataset_name, loss_name in product(
            exp_dict["modelList"], exp_dict["configList"],
            exp_dict["metricList"], exp_dict["datasetList"],
            exp_dict["lossList"]):

        # if model_name in ["LC_RESFCN"]:
        #   loss_name = "water_loss"

        config = configs.get_config_dict(config_name)

        key = ("{} - {} - {}".format(model_name, config_name, loss_name),
               "{}_({})".format(dataset_name, metric_name))

        if key in key_set:
            continue

        key_set.add(key)

        main_dict = MC.get_main_dict(mode, dataset_name, model_name,
                                     config_name, config, args.reset,
                                     exp_dict["epochs"], metric_name,
                                     loss_name)
        main_dict["predictList"] = exp_dict["predictList"]

        if mode == "paths":
            print("\n{}_({})".format(dataset_name, model_name))
            print(main_dict["path_best_model"])
            # print( main_dict["exp_name"])

        predictList_str = ' '.join(exp_dict["predictList"])

        if args.status:
            results[key] = borgy.borgy_status(mode, config_name, metric_name,
                                              model_name, dataset_name,
                                              loss_name, args.reset,
                                              predictList_str)

            continue

        if args.kill:
            results[key] = borgy.borgy_kill(mode, config_name, metric_name,
                                            model_name, dataset_name,
                                            loss_name, args.reset,
                                            predictList_str)
            continue

        if args.borgy:
            results[key] = borgy.borgy_submit(project_name, global_prompt,
                                              mode, config_name, metric_name,
                                              model_name, dataset_name,
                                              loss_name, args.reset,
                                              predictList_str)

            continue

        if mode == "debug":
            debug.debug(main_dict)

        if mode == "validate":
            validate.validate(main_dict)
        if mode == "save_gam_points":
            train_set, _ = au.load_trainval(main_dict)
            model = ms.load_best_model(main_dict)
            for i in range(len(train_set)):
                print(i, "/", len(train_set))
                batch = ms.get_batch(train_set, [i])
                fname = train_set.path + "/gam_{}.pkl".format(
                    batch["index"].item())
                points = model.get_points(batch)
                ms.save_pkl(fname, points)
            import ipdb
            ipdb.set_trace()  # breakpoint ee49ab9f //

        if mode == "save_prm_points":
            train_set, _ = au.load_trainval(main_dict)
            model = ms.load_best_model(main_dict)
            for i in range(len(train_set)):
                print(i, "/", len(train_set))
                batch = ms.get_batch(train_set, [i])

                fname = "{}/prm{}.pkl".format(batch["path"][0],
                                              batch["name"][0])
                points = model.get_points(batch)
                ms.save_pkl(fname, points)
            import ipdb
            ipdb.set_trace()  # breakpoint 679ce152 //

            # train_set, _ = au.load_trainval(main_dict)
            # model = ms.load_best_model(main_dict)
            # for i in range(len(train_set)):
            #   print(i, "/", len(train_set))
            #   batch = ms.get_batch(train_set, [i])
            #   fname = train_set.path + "/gam_{}.pkl".format(batch["index"].item())
            #   points = model.get_points(batch)
            #   ms.save_pkl(fname, points)

        # if mode == "pascal_annList":
        #   data_utils.pascal2lcfcn_points(main_dict)
        if mode == "upperboundmasks":
            import ipdb
            ipdb.set_trace()  # breakpoint 02fac8ce //

            results = au.test_upperboundmasks(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "model":

            results = au.test_model(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "upperbound":
            results = au.test_upperbound(main_dict, reset=args.reset)

            print(pd.DataFrame(results))

        if mode == "MUCov":
            gtAnnDict = au.load_gtAnnDict(main_dict, reset=args.reset)

            # model = ms.load_best_model(main_dict)
            fname = main_dict["path_save"] + "/pred_annList.pkl"
            if not os.path.exists(fname):
                _, val_set = au.load_trainval(main_dict)
                model = ms.load_best_model(main_dict)
                pred_annList = au.dataset2annList(model,
                                                  val_set,
                                                  predict_method="BestDice",
                                                  n_val=None)
                ms.save_pkl(fname, pred_annList)

            else:
                pred_annList = ms.load_pkl(fname)
            import ipdb
            ipdb.set_trace()  # breakpoint 527a7f36 //
            pred_annList = au.load_predAnnList(main_dict,
                                               predict_method="BestObjectness")
            # 0.31 best objectness pred_annList =
            # 0.3482122335421256
            # au.get_MUCov(gtAnnDict, pred_annList)
            au.get_SBD(gtAnnDict, pred_annList)

        if mode == "dic_sbd":
            import ipdb
            ipdb.set_trace()  # breakpoint 4af08a17 //

        if mode == "point_mask":
            from datasets import base_dataset

            import ipdb
            ipdb.set_trace()  # breakpoint 7fd55e0c //
            _, val_set = ms.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [1])
            model = ms.load_best_model(main_dict)
            pred_dict = model.LCFCN.predict(batch)
            # ms.pretty_vis(batch["images"], base_dataset.batch2annList(batch))
            ms.images(ms.pretty_vis(
                batch["images"],
                model.LCFCN.predict(batch,
                                    predict_method="original")["annList"]),
                      win="blobs")
            ms.images(ms.pretty_vis(batch["images"],
                                    base_dataset.batch2annList(batch)),
                      win="erww")
            ms.images(batch["images"],
                      batch["points"],
                      denorm=1,
                      enlarge=1,
                      win="e21e")
            import ipdb
            ipdb.set_trace()  # breakpoint ab9240f0 //

        if mode == "lcfcn_output":
            import ipdb
            ipdb.set_trace()  # breakpoint 7fd55e0c //

            gtAnnDict = au.load_gtAnnDict(main_dict, reset=args.reset)

        if mode == "load_gtAnnDict":
            _, val_set = au.load_trainval(main_dict)
            gtAnnDict = au.load_gtAnnDict(val_set)

            # gtAnnClass = COCO(gtAnnDict)
            # au.assert_gtAnnDict(main_dict, reset=None)
            # _,val_set = au.load_trainval(main_dict)
            # annList_path = val_set.annList_path

            # fname_dummy = annList_path.replace(".json","_best.json")
            # predAnnDict = ms.load_json(fname_dummy)
            import ipdb
            ipdb.set_trace()  # breakpoint 100bfe1b //
            pred_annList = ms.load_pkl(main_dict["path_best_annList"])
            # model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [1])

            import ipdb
            ipdb.set_trace()  # breakpoint 2310bb33 //
            model = ms.load_best_model(main_dict)
            pred_dict = model.predict(batch, "BestDice", "mcg")
            ms.images(batch["images"],
                      au.annList2mask(pred_dict["annList"])["mask"],
                      denorm=1)
            # pointList2UpperBoundMCG
            pred_annList = au.load_predAnnList(main_dict,
                                               predict_method="BestDice",
                                               proposal_type="mcg",
                                               reset="reset")
            # annList = au.pointList2UpperBoundMCG(batch["lcfcn_pointList"], batch)["annList"]
            ms.images(batch["images"],
                      au.annList2mask(annList)["mask"],
                      denorm=1)
            pred_annList = au.load_BestMCG(main_dict, reset="reset")
            # pred_annList = au.dataset2annList(model, val_set,
            #                   predict_method="BestDice",
            #                   n_val=None)
            au.get_perSizeResults(gtAnnDict, pred_annList)

        if mode == "vis":
            _, val_set = au.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [3])

            import ipdb
            ipdb.set_trace()  # breakpoint 05e6ef16 //

            vis.visBaselines(batch)

            model = ms.load_best_model(main_dict)
            vis.visBlobs(model, batch)

        if mode == "qual":
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            path = "/mnt/home/issam/Summaries/{}_{}".format(
                dataset_name, model_name)
            try:
                ms.remove_dir(path)
            except:
                pass
            n_images = len(val_set)
            base = "{}_{}".format(dataset_name, model_name)
            for i in range(50):
                print(i, "/10", "- ", base)
                index = np.random.randint(0, n_images)
                batch = ms.get_batch(val_set, [index])
                if len(batch["lcfcn_pointList"]) == 0:
                    continue
                image = vis.visBlobs(model, batch, return_image=True)

                # image_baselines = vis.visBaselines(batch, return_image=True)
                # imgAll = np.concatenate([image, image_baselines], axis=1)

                fname = path + "/{}_{}.png".format(i, base)
                ms.create_dirs(fname)
                ms.imsave(fname, image)

        if mode == "test_baselines":
            import ipdb
            ipdb.set_trace()  # breakpoint b51c5b1f //
            results = au.test_baselines(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "test_best":
            au.test_best(main_dict)

        if mode == "qualitative":
            au.qualitative(main_dict)

        if mode == "figure1":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            # proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            # vidList = glob("/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*")
            # vidList.sort()

            # pretty_image = ms.visPretty(model, batch = ms.get_batch(val_set, [i]), with_void=1, win="with_void")
            batch = ms.get_batch(val_set, [68])
            bestdice = ms.visPretty(model,
                                    batch=batch,
                                    with_void=0,
                                    win="no_void")
            blobs = ms.visPretty(model,
                                 batch=batch,
                                 predict_method="blobs",
                                 with_void=0,
                                 win="no_void")

            ms.images(bestdice, win="BestDice")
            ms.images(blobs, win="Blobs")
            ms.images(batch["images"], denorm=1, win="Image")
            ms.images(batch["images"],
                      batch["points"],
                      enlarge=1,
                      denorm=1,
                      win="Points")
            import ipdb
            ipdb.set_trace()  # breakpoint cf4bb3d3 //

        if mode == "video2":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            # proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            # vidList = glob("/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*")
            # vidList.sort()
            index = 0
            for i in range(len(val_set)):

                # pretty_image = ms.visPretty(model, batch = ms.get_batch(val_set, [i]), with_void=1, win="with_void")
                batch = ms.get_batch(val_set, [i])
                pretty_image = ms.visPretty(model,
                                            batch=batch,
                                            with_void=0,
                                            win="no_void")
                # pred_dict = model.predict(batch, predict_method="BestDice")
                path_summary = main_dict["path_summary"]
                ms.create_dirs(path_summary + "/tmp")
                ms.imsave(
                    path_summary + "vid_mask_{}.png".format(index),
                    ms.get_image(batch["images"],
                                 batch["points"],
                                 enlarge=1,
                                 denorm=1))
                index += 1
                ms.imsave(path_summary + "vid_mask_{}.png".format(index),
                          pretty_image)
                index += 1
                # ms.imsave(path_summary+"vid1_full_{}.png".format(i), ms.get_image(img, pred_dict["blobs"], denorm=1))
                print(i, "/", len(val_set))

        if mode == "video":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            # _, val_set = au.load_trainval(main_dict)
            proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            vidList = glob(
                "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*"
            )
            vidList.sort()
            for i, img_path in enumerate(vidList):
                image = Image.open(img_path).convert('RGB')
                image = image.resize((1200, 600), Image.BILINEAR)
                img, _ = transforms.Tr_WTP_NoFlip()([image, image])

                pred_dict = model.predict(
                    {
                        "images": img[None],
                        "split": ["test"],
                        "resized": torch.FloatTensor([1]),
                        "name": [ms.extract_fname(img_path)],
                        "proposals_path": [proposals_path]
                    },
                    predict_method="BestDice")
                path_summary = main_dict["path_summary"]
                ms.create_dirs(path_summary + "/tmp")
                ms.imsave(path_summary + "vid1_mask_{}.png".format(i),
                          ms.get_image(pred_dict["blobs"]))
                ms.imsave(path_summary + "vid1_full_{}.png".format(i),
                          ms.get_image(img, pred_dict["blobs"], denorm=1))
                print(i, "/", len(vidList))

        if mode == "5_eval_BestDice":
            gtAnnDict = au.load_gtAnnDict(main_dict)
            gtAnnClass = COCO(gtAnnDict)
            results = au.assert_gtAnnDict(main_dict, reset=None)

        if mode == "cp_annList":
            ms.dataset2cocoformat(dataset_name="CityScapes")

        if mode == "pascal2lcfcn_points":
            data_utils.pascal2lcfcn_points(main_dict)

        if mode == "cp2lcfcn_points":
            data_utils.cp2lcfcn_points(main_dict)

        if mode == "train":

            train.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "train_only":

            train.main(main_dict, train_only=True)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "sharpmask2psfcn":
            for split in ["train", "val"]:
                root = "/mnt/datasets/public/issam/COCO2014/ProposalsSharp/"
                path = "{}/sharpmask/{}/jsons/".format(root, split)

                jsons = glob(path + "*.json")
                propDict = {}
                for k, json in enumerate(jsons):
                    print("{}/{}".format(k, len(jsons)))
                    props = ms.load_json(json)
                    for p in props:
                        if p["image_id"] not in propDict:
                            propDict[p["image_id"]] = []
                        propDict[p["image_id"]] += [p]

                for k in propDict.keys():
                    fname = "{}/{}.json".format(root, k)
                    ms.save_json(fname, propDict[k])

        if mode == "cp2coco":
            import ipdb
            ipdb.set_trace()  # breakpoint f2eb9e70 //
            dataset2cocoformat.cityscapes2cocoformat(main_dict)
            # train.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "train_lcfcn":
            train_lcfcn.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "summary":

            try:
                history = ms.load_history(main_dict)

                # if predictList_str == "MAE":
                #   results[key] = "{}/{}: {:.2f}".format(history["best_model"]["epoch"],
                #                                                           history["epoch"],
                #                                                           history["best_model"][metric_name])

                # else:
                val_dict = history["val"][-1]
                val_dict = history["best_model"]
                iou25 = val_dict["0.25"]
                iou5 = val_dict["0.5"]
                iou75 = val_dict["0.75"]
                results[key] = "{}/{}: {:.1f} - {:.1f} - {:.1f}".format(
                    val_dict["epoch"], history["epoch"], iou25 * 100,
                    iou5 * 100, iou75 * 100)
                # if history["val"][-1]["epoch"] != history["epoch"]:
                #   results[key] += " | Val {}".format(history["epoch"])
                try:
                    results[key] += " | {}/{}".format(
                        len(history["trained_batch_names"]),
                        history["train"][-1]["n_samples"])
                except:
                    pass
            except:
                pass
        if mode == "vals":

            history = ms.load_history(main_dict)

            for i in range(1, len(main_dict["predictList"]) + 1):
                if len(history['val']) == 0:
                    res = "NaN"
                    continue
                else:
                    res = history["val"][-i]

                map50 = res["map50"]
                map75 = res["map75"]

                # if map75 < 1e-3:
                #   continue

                string = "{} - {} - map50: {:.2f} - map75: {:.2f}".format(
                    res["epoch"], res["predict_name"], map50, map75)

                key_tmp = list(key).copy()
                key_tmp[1] += " {} - {}".format(metric_name,
                                                res["predict_name"])
                results[tuple(key_tmp)] = string

            # print("map75", pd.DataFrame(history["val"])["map75"].max())
            # df = pd.DataFrame(history["vals"][:20])["water_loss_B"]
            # print(df)
    try:
        print(ms.dict2frame(results))
    except:
        print("Results not printed...")
Пример #16
0
    def __init__(self, root, split, transform_function):
        """Build the PASCAL VOC image/mask/class-mask file lists for `split`.

        Collects path triples from the standard VOC2012 layout, falling back
        to the Berkeley SBD ("benchmark_RELEASE") layout when a file is
        missing from VOC.

        Args:
            root: accepted but never used — the dataset path is hard-coded
                below.  NOTE(review): confirm whether callers rely on it.
            split: one of 'train', 'train_small', 'val', 'test'.
            transform_function: zero-argument factory; instantiated at the
                end of __init__ and stored as self.transform_function.

        Raises:
            RuntimeError: if no images were found for the split.
        """
        super().__init__()
        self.split = split

        # Hard-coded dataset root (the `root` argument is ignored).
        self.path = "/mnt/datasets/public/issam/VOCdevkit"

        # COCO-style category metadata taken from the val2012 annotations.
        self.categories = ms.load_json(
            "/mnt/datasets/public/issam/"
            "VOCdevkit/annotations/pascal_val2012.json")["categories"]

        # assert split in ['train', 'val', 'test']
        self.img_names = []
        self.mask_names = []
        self.cls_names = []

        # Precomputed LCFCN point annotations (presumably keyed per image —
        # TODO confirm against the pickle's actual structure).
        base = "/mnt/projects/counting/Saves/main/"
        fname = base + "lcfcn_points/Pascal2012.pkl"
        self.pointDict = ms.load_pkl(fname)

        berkley_root = os.path.join(self.path, 'benchmark_RELEASE')
        pascal_root = os.path.join(self.path)

        # Augmented (VOC + Berkeley SBD) train/val image-name lists.
        data_dict = d_helpers.get_augmented_filenames(pascal_root,
                                                      berkley_root,
                                                      mode=1)
        # train
        # Sanity-check the canonical augmented-PASCAL split sizes.
        assert len(data_dict["train_imgNames"]) == 10582
        assert len(data_dict["val_imgNames"]) == 1449

        berkley_path = berkley_root + '/dataset/'
        pascal_path = pascal_root + '/VOC2012/'

        # Images skipped in 'train_small'/'val'/'test' (known-bad files).
        corrupted = [
            "2008_005262", "2008_004172", "2008_004562", "2008_005145",
            "2008_008051", "2008_000763", "2009_000573"
        ]

        if split == 'train':
            # Prefer the Berkeley copy of each training image; fall back to
            # the VOC JPEG otherwise.
            for name in data_dict["train_imgNames"]:
                name_img = os.path.join(berkley_path, 'img/' + name + '.jpg')
                if os.path.exists(name_img):
                    name_img = name_img
                    name_mask = os.path.join(berkley_path,
                                             'cls/' + name + '.mat')
                else:
                    name_img = os.path.join(pascal_path,
                                            'JPEGImages/' + name + '.jpg')
                    # NOTE(review): 'SegmentationLabels/*.jpg' looks suspect —
                    # VOC masks normally live in 'SegmentationClass/*.png'.
                    # Confirm this fallback branch is ever actually hit.
                    name_mask = os.path.join(
                        pascal_path, 'SegmentationLabels/' + name + '.jpg')

                self.img_names += [name_img]
                self.mask_names += [name_mask]

        if split == 'train_small':
            # VOC paths first, Berkeley fallback; unlike 'val'/'test' below,
            # images whose instance mask is missing are silently dropped.
            for name in data_dict["train_imgNames"]:
                # name_img = os.path.join(berkley_path, 'img/' + name + '.jpg')
                # if os.path.exists(name_img):
                #     name_img = name_img
                #     name_mask = os.path.join(berkley_path, 'cls/' + name + '.mat')
                # else:
                #     name_img = os.path.join(pascal_path, 'JPEGImages/' + name + '.jpg')
                #     name_mask =  os.path.join(pascal_path, 'SegmentationLabels/' +  name + '.jpg')

                # self.img_names += [name_img]
                # self.mask_names += [name_mask]

                if name in corrupted:
                    continue

                name_img = os.path.join(pascal_path,
                                        'JPEGImages/' + name + '.jpg')
                name_mask = os.path.join(pascal_path,
                                         'SegmentationObject/' + name + '.png')
                name_cls = os.path.join(pascal_path,
                                        'SegmentationClass/' + name + '.png')

                if not os.path.exists(name_img):
                    name_img = os.path.join(berkley_path,
                                            'img/' + name + '.jpg')
                    name_mask = os.path.join(berkley_path,
                                             'inst/' + name + '.mat')
                    name_cls = os.path.join(berkley_path,
                                            'cls/' + name + '.mat')

                assert os.path.exists(name_img)
                if not os.path.exists(name_mask):
                    continue
                assert os.path.exists(name_cls)

                self.img_names += [name_img]
                self.mask_names += [name_mask]
                self.cls_names += [name_cls]

        elif split in ['val', "test"]:
            # Deterministic ordering for evaluation.
            data_dict["val_imgNames"].sort()
            for k, name in enumerate(data_dict["val_imgNames"]):

                if name in corrupted:
                    continue
                name_img = os.path.join(pascal_path,
                                        'JPEGImages/' + name + '.jpg')
                name_mask = os.path.join(pascal_path,
                                         'SegmentationObject/' + name + '.png')
                name_cls = os.path.join(pascal_path,
                                        'SegmentationClass/' + name + '.png')

                if not os.path.exists(name_img):
                    name_img = os.path.join(berkley_path,
                                            'img/' + name + '.jpg')
                    name_mask = os.path.join(berkley_path,
                                             'inst/' + name + '.mat')
                    name_cls = os.path.join(berkley_path,
                                            'cls/' + name + '.mat')

                # Unlike 'train_small', every val/test file must exist.
                assert os.path.exists(name_img)
                assert os.path.exists(name_mask)
                assert os.path.exists(name_cls)

                self.img_names += [name_img]
                self.mask_names += [name_mask]
                self.cls_names += [name_cls]

        self.proposals_path = "/mnt/datasets/public/issam/VOCdevkit/VOC2012/ProposalsSharp/"
        if len(self.img_names) == 0:
            raise RuntimeError('Found 0 images, please check the data set')

        # 20 PASCAL object classes + background.
        self.n_classes = 21
        self.transform_function = transform_function()

        # Mask value treated as void/unlabelled.
        self.ignore_index = 255
        # "What's the Point" weak point annotations for trainval.
        self.pointsJSON = ms.jload(
            os.path.join('/mnt/datasets/public/issam/VOCdevkit/VOC2012',
                         'whats_the_point/data',
                         "pascal2012_trainval_main.json"))
        if split == "val":
            # COCO-format ground-truth annotation list used for evaluation.
            annList_path = self.path + "/annotations/{}_gt_annList.json".format(
                split)
            self.annList_path = annList_path
Пример #17
0
def train_classifier(
    smoothing=0.999,  # Exponential running average of encoder weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    lr_mirror_augment=True,  # Enable mirror augment?
    ud_mirror_augment=False,  # Enable up-down mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False
):  # Include weight histograms in the tfevents file?
    """Train the EG encoder/classifier together with the D_rec discriminator.

    Builds (or resumes) the EG and D_rec networks, constructs a multi-GPU
    TensorFlow training graph, then runs the tick-based training loop:
    reporting progress, periodically printing train/validation accuracy,
    saving image grids and network snapshot pickles under a freshly created
    result subdirectory, and writing a final 'network-final.pkl'.

    All remaining configuration (datasets, network/optimizer/loss kwargs,
    schedule, number of GPUs) is read from the module-level `config`.

    NOTE(review): `drange_net` is a mutable default argument; it is only
    read here, but callers should not mutate the returned-by-reference
    default.  Kept as-is for interface compatibility.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.training_set)
    validation_set = dataset.load_dataset(data_dir=config.data_dir,
                                          verbose=True,
                                          **config.validation_set)
    network_snapshot_ticks = total_kimg // 100  # How often to export network snapshots?

    # Construct networks: try to resume from a located snapshot, otherwise
    # build fresh networks from config.
    with tf.device('/gpu:0'):
        try:
            network_pkl = misc.locate_network_pkl()
            resume_kimg, resume_time = misc.resume_kimg_time(network_pkl)
            print('Loading networks from "%s"...' % network_pkl)
            EG, D_rec, EGs = misc.load_pkl(network_pkl)
        except Exception:
            # Fall back to training from scratch when no snapshot can be
            # located/loaded.  (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            print('Constructing networks...')
            resume_kimg = 0.0
            resume_time = 0.0
            EG = tfutil.Network('EG',
                                num_channels=training_set.shape[0],
                                resolution=training_set.shape[1],
                                label_size=training_set.label_size,
                                **config.EG)
            D_rec = tfutil.Network('D_rec',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   **config.D_rec)
            EGs = EG.clone('EGs')
        # EGs tracks an exponential moving average of EG's weights.
        EGs_update_op = EGs.setup_as_moving_average_of(EG, beta=smoothing)
    EG.print_layers()
    D_rec.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    EG_opt = tfutil.Optimizer(name='TrainEG',
                              learning_rate=lrate_in,
                              **config.EG_opt)
    D_rec_opt = tfutil.Optimizer(name='TrainD_rec',
                                 learning_rate=lrate_in,
                                 **config.D_rec_opt)
    # Replicate the networks and losses on every GPU; gradients are
    # registered per-GPU and averaged by the optimizer.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            EG_gpu = EG if gpu == 0 else EG.clone(EG.name + '_shadow_%d' % gpu)
            D_rec_gpu = D_rec if gpu == 0 else D_rec.clone(D_rec.name +
                                                           '_shadow_%d' % gpu)
            reals_fade_gpu, reals_orig_gpu = process_reals(
                reals_split[gpu], lod_in, lr_mirror_augment, ud_mirror_augment,
                training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('EG_loss'):
                EG_loss = tfutil.call_func_by_name(EG=EG_gpu,
                                                   D_rec=D_rec_gpu,
                                                   reals_orig=reals_orig_gpu,
                                                   labels=labels_gpu,
                                                   **config.EG_loss)
            with tf.name_scope('D_rec_loss'):
                D_rec_loss = tfutil.call_func_by_name(
                    EG=EG_gpu,
                    D_rec=D_rec_gpu,
                    D_rec_opt=D_rec_opt,
                    minibatch_size=minibatch_split,
                    reals_orig=reals_orig_gpu,
                    **config.D_rec_loss)
            EG_opt.register_gradients(tf.reduce_mean(EG_loss),
                                      EG_gpu.trainables)
            D_rec_opt.register_gradients(tf.reduce_mean(D_rec_loss),
                                         D_rec_gpu.trainables)
    EG_train_op = EG_opt.apply_updates()
    D_rec_train_op = D_rec_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, train_reals, train_labels = setup_snapshot_image_grid(
        training_set, drange_net, [450, 10], **config.grid)
    grid_size, val_reals, val_labels = setup_snapshot_image_grid(
        validation_set, drange_net, [450, 10], **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)

    # Initial accuracy of the (possibly resumed) moving-average network.
    train_recs, train_fingerprints, train_logits = EGs.run(
        train_reals, minibatch_size=sched.minibatch // config.num_gpus)
    train_preds = np.argmax(train_logits, axis=1)
    train_gt = np.argmax(train_labels, axis=1)
    train_acc = np.float32(np.sum(train_gt == train_preds)) / np.float32(
        len(train_gt))
    print('Training Accuracy = %f' % train_acc)

    val_recs, val_fingerprints, val_logits = EGs.run(
        val_reals, minibatch_size=sched.minibatch // config.num_gpus)
    val_preds = np.argmax(val_logits, axis=1)
    val_gt = np.argmax(val_labels, axis=1)
    val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(len(val_gt))
    print('Validation Accuracy = %f' % val_acc)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    # NOTE(review): 'fingerrints' below is a typo in the output filenames;
    # preserved so downstream tooling that globs these names keeps working.
    misc.save_image_grid(train_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'train_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'train_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'train_fingerrints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'val_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'val_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'val_fingerrints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])

    # Visualize the learned fingerprint conv weights as an image grid.
    est_fingerprints = np.transpose(
        EGs.vars['Conv_fingerprints/weight'].eval(), axes=[3, 2, 0, 1])
    misc.save_image_grid(
        est_fingerprints,
        os.path.join(result_subdir, 'est_fingerrints-init.png'),
        drange=[np.amin(est_fingerprints),
                np.amax(est_fingerprints)],
        grid_size=[est_fingerprints.shape[0], 1])

    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        EG.setup_weight_histograms()
        D_rec.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # Reset Adam moments whenever the level-of-detail crosses an
            # integer boundary (new layers fading in).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                EG_opt.reset_optimizer_state()
                D_rec_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_rec step, EG step, then update the EMA copy.
        for repeat in range(minibatch_repeats):
            tfutil.run(
                [D_rec_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run(
                [EG_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run([EGs_update_op], {})
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f resolution %-4d minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/resolution', sched.resolution),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Print accuracy.
            if cur_tick % image_snapshot_ticks == 0 or done:

                train_recs, train_fingerprints, train_logits = EGs.run(
                    train_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                train_preds = np.argmax(train_logits, axis=1)
                train_gt = np.argmax(train_labels, axis=1)
                train_acc = np.float32(np.sum(
                    train_gt == train_preds)) / np.float32(len(train_gt))
                print('Training Accuracy = %f' % train_acc)

                val_recs, val_fingerprints, val_logits = EGs.run(
                    val_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                val_preds = np.argmax(val_logits, axis=1)
                val_gt = np.argmax(val_labels, axis=1)
                val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(
                    len(val_gt))
                print('Validation Accuracy = %f' % val_acc)

                misc.save_image_grid(train_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'train_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(train_fingerprints[::30, :, :, :],
                                     os.path.join(
                                         result_subdir,
                                         'train_fingerrints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_fingerprints[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_fingerrints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])

                est_fingerprints = np.transpose(
                    EGs.vars['Conv_fingerprints/weight'].eval(),
                    axes=[3, 2, 0, 1])
                misc.save_image_grid(est_fingerprints,
                                     os.path.join(result_subdir,
                                                  'est_fingerrints-final.png'),
                                     drange=[
                                         np.amin(est_fingerprints),
                                         np.amax(est_fingerprints)
                                     ],
                                     grid_size=[est_fingerprints.shape[0], 1])

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (EG, D_rec, EGs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((EG, D_rec, EGs),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Пример #18
0
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    compute_fid_score=False,  # Compute FID during training once sched.lod=0.0
    minimum_fid_kimg=0,  # Do not evaluate FID before this many kimg.
    fid_snapshot_ticks=1,  # How often to compute FID
    fid_patience=2,  # When to end training based on FID
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    result_subdir="./"):  # Unused: overwritten by misc.create_result_subdir() below; kept for call compatibility.
    """Main training loop for a progressive GAN with optional FID early stopping.

    Builds (or resumes) the G/D networks, trains them on config.dataset with
    the progressive-growing schedule from config.sched, writes image/network
    snapshots and TensorBoard summaries into a freshly created result subdir,
    and — when compute_fid_score is enabled — evaluates FID once sched.lod
    reaches 0.0, keeping the best-FID checkpoint and stopping early after
    `fid_patience` evaluations without improvement.

    Relies on module-level collaborators visible elsewhere in this file:
    config, dataset, misc, tfutil, tf, np, TrainingSchedule, process_reals,
    setup_snapshot_image_grid, compute_fid, and csv's `writer`.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        # BUG FIX: the original test was `resume_run_id != "None"`, which is
        # True for the default resume_run_id=None, so a fresh run tried to
        # resume from a nonexistent snapshot. Treat both None and the string
        # "None" (as it arrives from config files) as "start from scratch".
        if resume_run_id is not None and resume_run_id != "None":
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            resume_pkl_name = os.path.splitext(
                os.path.basename(network_pkl))[0]
            try:
                # Snapshots are named 'network-snapshot-%06d'; recover the
                # kimg counter from the trailing number when possible.
                resume_kimg = int(resume_pkl_name.split('-')[-1])
                print('** Setting resume kimg to', resume_kimg, flush=True)
            except ValueError:  # narrowed from a bare except: only int() can fail here
                print('** Keeping resume kimg as:', resume_kimg, flush=True)
            print('Loading networks from "%s"...' % network_pkl, flush=True)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...', flush=True)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')  # Gs holds the EMA of G's weights.
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...', flush=True)
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    # Build one loss tower per GPU; gradients are accumulated into the two
    # shared optimizers created above.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 owns the variables; the other towers are clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...', flush=True)
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...', flush=True)
    # NOTE: intentionally shadows the `result_subdir` parameter.
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...', flush=True)
    # FID early-stopping state.
    fid_list = []           # All FID values computed so far, in order.
    fid_stop = False        # Set when early stopping triggers.
    fid_patience_step = 0   # Evaluations since the last FID improvement.

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # Crossing either edge of a lod blend introduces new layers, so
            # stale Adam moments would be harmful — reset them.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps (each also updating
        # the Gs moving average) followed by one generator step.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # FID is evaluated every `fid_snapshot_ticks` ticks once past the
            # `minimum_fid_kimg` warm-up, and only at full resolution
            # (lod 0.0). These booleans replace the three duplicated
            # `== True` condition chains of the original.
            fid_snapshot_due = (compute_fid_score
                                and cur_tick % fid_snapshot_ticks == 0
                                and cur_nimg >= minimum_fid_kimg * 1000)
            fid_eval_due = fid_snapshot_due and sched.lod == 0.0

            if fid_eval_due:
                fid = compute_fid(Gs=Gs,
                                  minibatch_size=sched.minibatch,
                                  dataset_obj=training_set,
                                  iter_number=cur_nimg / 1000,
                                  lod=0.0,
                                  num_images=10000,
                                  printing=False)
                fid_list.append(fid)

            # Report progress without FID.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)),
                flush=True)
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save image snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Save network snapshots (also on FID ticks, so every FID value
            # has a matching checkpoint).
            if (cur_tick % network_snapshot_ticks == 0 or done
                    or fid_snapshot_due):
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # FID early stopping: keep the best-FID checkpoint and stop after
            # `fid_patience` evaluations without improvement.
            if fid_eval_due:
                fid_patience_step += 1
                # First evaluation, or a new best FID: reset patience and save
                # the current networks as the best checkpoint. The len()==1
                # check short-circuits before np.min sees an empty slice.
                if len(fid_list) == 1 or fid_list[-1] < np.min(fid_list[:-1]):
                    fid_patience_step = 0
                    misc.save_pkl((G, D, Gs),
                                  os.path.join(result_subdir,
                                               'network-final-full-conv.pkl'))
                    print(
                        "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                        % (fid_list[-1], cur_nimg // 1000),
                        flush=True)
                else:
                    print("No improvement for FID: %.3f at kimg %-8.1f." %
                          (fid_list[-1], cur_nimg // 1000),
                          flush=True)
                if fid_patience_step == fid_patience:
                    fid_stop = True
                    print("Training stopped due to FID early-stopping.",
                          flush=True)
                    cur_nimg = total_kimg * 1000  # Forces the while loop to exit.

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # Save a final checkpoint only if FID early stopping has not already
    # saved the best one.
    if not fid_stop:
        fid = compute_fid(Gs=Gs,
                          minibatch_size=sched.minibatch,
                          dataset_obj=training_set,
                          iter_number=cur_nimg / 1000,
                          lod=0.0,
                          num_images=10000,
                          printing=False)
        print("Final FID: %.3f at kimg %-8.1f." % (fid, cur_nimg // 1000),
              flush=True)
        # Append the final FID to a shared .csv in the result parent dir.
        csv_file = os.path.join(
            os.path.dirname(os.path.dirname(result_subdir)),
            "results_full_conv.csv")
        list_to_append = [
            result_subdir.split("/")[-2] + "/" + result_subdir.split("/")[-1],
            fid
        ]
        # newline='' per the csv module docs; the `with` block closes the
        # file, so the original's explicit close() was redundant.
        with open(csv_file, 'a', newline='') as f_object:
            writer_object = writer(f_object)
            writer_object.writerow(list_to_append)
        misc.save_pkl((G, D, Gs),
                      os.path.join(result_subdir,
                                   'network-final-full-conv.pkl'))
        print("Save network-final-full-conv.", flush=True)
    summary_log.close()
    # Marker file signalling that this run finished cleanly.
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()