def anomaly_detection_encoder(run_id,
                              test_data_folder,
                              log,
                              test_batch_size=10,
                              start_at_batch=0,
                              end_at_batch=10):

    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=True)
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)
    print('# snapshot_pkls: ' + str(len(snapshot_pkls)))

    with tf.Graph().as_default(), tfutil.create_session(
            config.tf_config).as_default():
        # Load the networks from the final snapshot of the specified run.
        G, D, Gs, E = misc.load_pkl(snapshot_pkls[-1])

        # Rebuild the generator so it no longer requires labels at run time.
        Ga = tfutil.Network('G_anomaly',
                            num_channels=Gs.output_shapes[0][1],
                            resolution=Gs.output_shapes[0][2],
                            label_size=dataset_obj.label_size,
                            **config.G_anomaly)
        Ga.copy_vars_from(Gs)

        print("Initializing Anomaly detector")
        anoGAN = tfutil.AnomalyDetectorEncoder(config,
                                               Ga,
                                               E,
                                               test_data_folder,
                                               test_batch_size=test_batch_size)
        print('# AnoGAN test data names: ' + str(len(anoGAN.test_data_names)))

        for batch in range(len(anoGAN.filename_batches)):
            if batch < start_at_batch:
                continue
            if batch >= end_at_batch:
                break
            test_input = anoGAN.preprocess_img(anoGAN.filename_batches[batch])
            test_name = anoGAN.filename_batches[batch]
            anoGAN.find_closest_match(test_input, test_name)
            print(f'Batch {batch} complete...')
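
A minimal invocation sketch (hypothetical run ID and test folder; assumes config, misc, and tfutil are initialized as in the rest of this repo):

# Sketch only: '42' and the folder path are placeholder values.
anomaly_detection_encoder(run_id=42,
                          test_data_folder='datasets/anomaly_test/',
                          log=None,
                          test_batch_size=10,
                          start_at_batch=0,
                          end_at_batch=10)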
Example #2
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0  # Assumed wallclock time at the beginning. Affects reporting.
):

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs, E = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)

            E = tfutil.Network('E',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.E)

            Gs = G.clone('Gs')
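        # 'Gs' tracks an exponential moving average of G's weights; it is the
        # copy used for image snapshots and for the exported network pickles.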
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()
    E.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
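        # Each GPU processes an equal share of the global minibatch.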
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    E_opt = tfutil.Optimizer(name='TrainE',
                             learning_rate=lrate_in,
                             **config.E_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
                tf.assign(E_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            with tf.name_scope('E_loss'), tf.control_dependencies(
                    lod_assign_ops):
                E_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=E_opt,
                    training_set=training_set,
                    reals=reals_gpu,
                    minibatch_size=minibatch_split,
                    **config.E_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            E_opt.register_gradients(tf.reduce_mean(E_loss), E_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    E_train_op = E_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        E.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
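            # A change in floor(lod) or ceil(lod) marks the start or end of a
            # resolution transition, i.e. new layers have been introduced.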
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                E_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op, E_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

                misc.save_all_res(training_set.shape[1],
                                  Gs,
                                  result_subdir,
                                  50,
                                  minibatch_size=sched.minibatch //
                                  config.num_gpus)

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs, E),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs, E),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
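
For reference, the update op built by Gs.setup_as_moving_average_of(G, beta=G_smoothing) amounts to a per-variable exponential moving average; a rough sketch of the math (the actual tfutil implementation may differ in detail):

# Sketch: for every corresponding variable pair, Gs tracks G with decay beta:
#     gs_var <- beta * gs_var + (1 - beta) * g_var
beta = G_smoothing  # e.g. 0.999
ema_ops = [tf.assign(gs_var, beta * gs_var + (1.0 - beta) * g_var)
           for gs_var, g_var in zip(Gs.trainables.values(),
                                    G.trainables.values())]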
Example #3
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0  # Assumed wallclock time at the beginning. Affects reporting.
):

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    unlabeled_training_set = dataset.load_dataset(
        data_dir=config.unlabeled_data_dir,
        verbose=True,
        **config.unlabeled_dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            print("Training-Set Label Size: ", training_set.label_size)
            print("Unlabeled-Training-Set Label Size: ",
                  unlabeled_training_set.label_size)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])

        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        unlabeled_reals, _ = unlabeled_training_set.get_minibatch_tf()

        reals_split = tf.split(reals, config.num_gpus)
        unlabeled_reals_split = tf.split(unlabeled_reals, config.num_gpus)

        labels_split = tf.split(labels, config.num_gpus)

    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    G_opt_pggan = tfutil.Optimizer(name='TrainG_pggan',
                                   learning_rate=lrate_in,
                                   **config.G_opt)
    D_opt_pggan = tfutil.Optimizer(name='TrainD_pggan',
                                   learning_rate=lrate_in,
                                   **config.D_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    print("CUDA_VISIBLE_DEVICES: ", os.environ['CUDA_VISIBLE_DEVICES'])

    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]

            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            unlabeled_reals_gpu = process_reals(
                unlabeled_reals_split[gpu], lod_in, mirror_augment,
                unlabeled_training_set.dynamic_range, drange_net)

            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.G_loss)
            with tf.name_scope('G_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss_pggan)

            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss)
            with tf.name_scope('D_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss_pggan)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            G_opt_pggan.register_gradients(tf.reduce_mean(G_loss_pggan),
                                           G_gpu.trainables)
            D_opt_pggan.register_gradients(tf.reduce_mean(D_loss_pggan),
                                           D_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            print('GPU %d loaded!' % gpu)

    G_train_op = G_opt.apply_updates()
    G_train_op_pggan = G_opt_pggan.apply_updates()
    D_train_op_pggan = D_opt_pggan.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * TrainingSpeedInt, training_set,
                             **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.compat.v1.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print("Start Time: ",
          datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print('Training...')
    cur_nimg = int(resume_kimg * TrainingSpeedInt)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0

    while cur_nimg < total_kimg * TrainingSpeedInt:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        sched2 = TrainingSchedule(cur_nimg, unlabeled_training_set,
                                  **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        unlabeled_training_set.configure(sched2.minibatch, sched2.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                G_opt_pggan.reset_optimizer_state()
                D_opt_pggan.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                # While lod != 0, run the plain PGGAN loss; once lod == 0,
                # switch to the SSL loss with feature matching.
                if sched.lod == 0:
                    tfutil.run(
                        [D_train_op, Gs_update_op], {
                            lod_in: sched.lod,
                            lrate_in: sched.D_lrate,
                            minibatch_in: sched.minibatch
                        })
                else:
                    tfutil.run(
                        [D_train_op_pggan, Gs_update_op], {
                            lod_in: sched.lod,
                            lrate_in: sched.D_lrate,
                            minibatch_in: sched.minibatch
                        })
                cur_nimg += sched.minibatch
            # Same switch for the generator: PGGAN loss while lod != 0,
            # SSL loss with feature matching at lod == 0.
            if sched.lod == 0:
                tfutil.run(
                    [G_train_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.G_lrate,
                        minibatch_in: sched.minibatch
                    })
            else:
                tfutil.run(
                    [G_train_op_pggan], {
                        lod_in: sched.lod,
                        lrate_in: sched.G_lrate,
                        minibatch_in: sched.minibatch
                    })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * TrainingSpeedInt)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * TrainingSpeedInt or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / TrainingSpeedFloat
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f date %s'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg',
                                      cur_nimg / TrainingSpeedFloat),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary(
                       'Timing/maintenance_sec', maintenance_time),
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))

            #######################
            # VALIDATION ACCURACY #
            #######################

            # ndim is the image side length in pixels (e.g. 512 for a
            # 512x512 image). All images for SSL-PGGAN must be square.
            ndim = 256
            correct = 0
            guesses = 0

            dir_tuple = (config.validation_dog, config.validation_cat)
            # Count wrong guesses per class to check for a bias toward one class.
            FP_RATE = [[0], [0]]
            # For each class
            for indx, directory in enumerate(dir_tuple):
                # Go through every image that needs to be tested
                for filename in os.listdir(directory):
                    guesses += 1
                    print(filename)
                    # Load the image as HWC and convert to CHW for the network.
                    img = np.asarray(
                        PIL.Image.open(os.path.join(directory, filename)))
                    img = img.transpose(2, 0, 1)  # HWC -> CHW
                    img = np.expand_dims(img, axis=0)  # now (1, 3, ndim, ndim)
                    K_logits_out, fake_logit_out, features_out = test_discriminator(
                        D, img)

                    #print("K Logits Out:",K_logits_out.eval())
                    sample_probs = tf.nn.softmax(K_logits_out)
                    #print("Softmax Output:", sample_probs.eval())
                    label = np.argmax(sample_probs.eval()[0], axis=0)
                    if label == indx:
                        correct += 1
                    else:
                        FP_RATE[indx][0] += 1
                    print("-----------------------------------")
                    print("GUESSED LABEL: ", label)
                    print("CORRECT LABEL: ", indx)
                    validation = (correct / guesses)
                    print("Total Correct: ", correct, "\n", "Total Guesses: ",
                          guesses, "\n", "Percent correct: ", validation)
                    print("False Positives: Dog, Cat", FP_RATE)
                    print()

            tfutil.autosummary('Accuracy/Validation', (correct / guesses))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir, 'fakes%06d.png' %
                                         (cur_nimg // TrainingSpeedInt)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs),
                              os.path.join(
                                  result_subdir, 'network-snapshot-%06d.pkl' %
                                  (cur_nimg // TrainingSpeedInt)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
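
The validation loop above is a softmax-plus-argmax classification per image; a numpy equivalent of the per-image decision (a sketch, assuming K_logits_out evaluates to shape [1, num_classes]):

import numpy as np

def classify(k_logits):
    # Numerically stable softmax over the class logits, then argmax.
    z = k_logits[0] - np.max(k_logits[0])
    probs = np.exp(z) / np.sum(np.exp(z))
    return int(np.argmax(probs)), probs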
Example #4
def train_detector(
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=1,  # Total length of the training, measured in thousands of real images.
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    snapshot_size=16,  # Size of the snapshot image
    snapshot_ticks=2**13,  # Number of images before maintenance
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=1,  # How often to export network snapshots?
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0  # Assumed wallclock time at the beginning. Affects reporting.
):

    maintenance_start_time = time.time()

    # Load the datasets
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)

    # TODO: data augmentation
    # TODO: testing set

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...')
            # TODO: better network (like lod-wise network)
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()

    print('Building TensorFlow graph...')
    # Training set up
    with tf.name_scope('Inputs'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # minibatch_in            = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # TODO: increase the batch size via several loss computations and a mean
        reals, labels, bboxes = training_set.get_minibatch_tf()
    N_opt = tfutil.Optimizer(name='TrainN',
                             learning_rate=lrate_in,
                             **config.N_opt)

    with tf.device('/gpu:0'):
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)

        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()

    # Testing set up
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = \
            testing_set.get_minibatch_tf()
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()

    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals,
                         test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)

    test_reals = misc.adjust_dynamic_range(test_reals,
                                           training_set.dynamic_range,
                                           drange_net)
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)

    print('Training...')
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())

    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time

    # Choose training parameters and configure training ops.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)
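    # Note: unlike the GAN trainers above, the schedule is configured once here
    # and stays fixed for the whole run.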

    _train_loss = 0

    while cur_nimg < total_kimg * 1000:

        # Run training ops.
        # for _ in range(minibatch_repeats):
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        if (cur_nimg >= total_kimg * 1000) or (cur_nimg % snapshot_ticks == 0
                                               and cur_nimg > 0):

            cur_tick += 1
            cur_time = time.time()
            # tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            testing_set.configure(sched.minibatch)
            _test_loss = 0
            for _ in range(0, testing_set_len, sched.minibatch):
                _test_loss += tfutil.run(test_loss)
            _test_loss /= testing_set_len

            # Report progress. # TODO: improved report display
            print(
                'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/maintenance', maintenance_time),
                   tfutil.autosummary('TrainN/train_loss', _train_loss),
                   tfutil.autosummary('TrainN/test_loss', _test_loss)))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            if cur_tick % image_snapshot_ticks == 0:
                test_bboxes, test_refs = N.run(test_reals,
                                               minibatch_size=snapshot_size)
                misc.save_img_bboxes_ref(
                    test_reals, test_bboxes, test_refs,
                    os.path.join(result_subdir,
                                 'fakes%06d.png' % (cur_nimg // 1000)),
                    snapshot_size)
            if cur_tick % network_snapshot_ticks == 0:
                misc.save_pkl(
                    N,
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            _train_loss = 0

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # misc.save_pkl(N, os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
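
One caveat in the loop above: 'cur_nimg % snapshot_ticks == 0' only fires when the minibatch size divides snapshot_ticks exactly. A boundary-crossing test is more robust (a sketch, not the code's current behavior):

def crossed_boundary(prev_nimg, cur_nimg, interval):
    # True whenever the counter passed a multiple of `interval`
    # between two consecutive updates, for any minibatch size.
    return cur_nimg // interval > prev_nimg // interval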
Example #5
def train_classifier(
    smoothing=0.999,  # Exponential running average of encoder weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    lr_mirror_augment=True,  # Enable mirror augment?
    ud_mirror_augment=False,  # Enable up-down mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False  # Include weight histograms in the tfevents file?
):

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.training_set)
    validation_set = dataset.load_dataset(data_dir=config.data_dir,
                                          verbose=True,
                                          **config.validation_set)
    network_snapshot_ticks = total_kimg // 100  # How often to export network snapshots?

    # Construct networks.
    with tf.device('/gpu:0'):
        try:
            network_pkl = misc.locate_network_pkl()
            resume_kimg, resume_time = misc.resume_kimg_time(network_pkl)
            print('Loading networks from "%s"...' % network_pkl)
            EG, D_rec, EGs = misc.load_pkl(network_pkl)
        except Exception:
            print('Constructing networks...')
            resume_kimg = 0.0
            resume_time = 0.0
            EG = tfutil.Network('EG',
                                num_channels=training_set.shape[0],
                                resolution=training_set.shape[1],
                                label_size=training_set.label_size,
                                **config.EG)
            D_rec = tfutil.Network('D_rec',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   **config.D_rec)
            EGs = EG.clone('EGs')
        EGs_update_op = EGs.setup_as_moving_average_of(EG, beta=smoothing)
    EG.print_layers()
    D_rec.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    EG_opt = tfutil.Optimizer(name='TrainEG',
                              learning_rate=lrate_in,
                              **config.EG_opt)
    D_rec_opt = tfutil.Optimizer(name='TrainD_rec',
                                 learning_rate=lrate_in,
                                 **config.D_rec_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            EG_gpu = EG if gpu == 0 else EG.clone(EG.name + '_shadow_%d' % gpu)
            D_rec_gpu = D_rec if gpu == 0 else D_rec.clone(D_rec.name +
                                                           '_shadow_%d' % gpu)
            reals_fade_gpu, reals_orig_gpu = process_reals(
                reals_split[gpu], lod_in, lr_mirror_augment, ud_mirror_augment,
                training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('EG_loss'):
                EG_loss = tfutil.call_func_by_name(EG=EG_gpu,
                                                   D_rec=D_rec_gpu,
                                                   reals_orig=reals_orig_gpu,
                                                   labels=labels_gpu,
                                                   **config.EG_loss)
            with tf.name_scope('D_rec_loss'):
                D_rec_loss = tfutil.call_func_by_name(
                    EG=EG_gpu,
                    D_rec=D_rec_gpu,
                    D_rec_opt=D_rec_opt,
                    minibatch_size=minibatch_split,
                    reals_orig=reals_orig_gpu,
                    **config.D_rec_loss)
            EG_opt.register_gradients(tf.reduce_mean(EG_loss),
                                      EG_gpu.trainables)
            D_rec_opt.register_gradients(tf.reduce_mean(D_rec_loss),
                                         D_rec_gpu.trainables)
    EG_train_op = EG_opt.apply_updates()
    D_rec_train_op = D_rec_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, train_reals, train_labels = setup_snapshot_image_grid(
        training_set, drange_net, [450, 10], **config.grid)
    grid_size, val_reals, val_labels = setup_snapshot_image_grid(
        validation_set, drange_net, [450, 10], **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)

    train_recs, train_fingerprints, train_logits = EGs.run(
        train_reals, minibatch_size=sched.minibatch // config.num_gpus)
    train_preds = np.argmax(train_logits, axis=1)
    train_gt = np.argmax(train_labels, axis=1)
    train_acc = np.float32(np.sum(train_gt == train_preds)) / np.float32(
        len(train_gt))
    print('Training Accuracy = %f' % train_acc)

    val_recs, val_fingerprints, val_logits = EGs.run(
        val_reals, minibatch_size=sched.minibatch // config.num_gpus)
    val_preds = np.argmax(val_logits, axis=1)
    val_gt = np.argmax(val_labels, axis=1)
    val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(len(val_gt))
    print('Validation Accuracy = %f' % val_acc)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(train_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'train_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'train_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'train_fingerprints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'val_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'val_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'val_fingerprints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])

    est_fingerprints = np.transpose(
        EGs.vars['Conv_fingerprints/weight'].eval(), axes=[3, 2, 0, 1])
    misc.save_image_grid(
        est_fingerprints,
        os.path.join(result_subdir, 'est_fingerprints-init.png'),
        drange=[np.amin(est_fingerprints),
                np.amax(est_fingerprints)],
        grid_size=[est_fingerprints.shape[0], 1])

    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        EG.setup_weight_histograms()
        D_rec.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                EG_opt.reset_optimizer_state()
                D_rec_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            tfutil.run(
                [D_rec_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run(
                [EG_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run([EGs_update_op], {})
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f resolution %-4d minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/resolution', sched.resolution),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Print accuracy.
            if cur_tick % image_snapshot_ticks == 0 or done:

                train_recs, train_fingerprints, train_logits = EGs.run(
                    train_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                train_preds = np.argmax(train_logits, axis=1)
                train_gt = np.argmax(train_labels, axis=1)
                train_acc = np.float32(np.sum(
                    train_gt == train_preds)) / np.float32(len(train_gt))
                print('Training Accuracy = %f' % train_acc)

                val_recs, val_fingerprints, val_logits = EGs.run(
                    val_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                val_preds = np.argmax(val_logits, axis=1)
                val_gt = np.argmax(val_labels, axis=1)
                val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(
                    len(val_gt))
                print('Validation Accuracy = %f' % val_acc)

                misc.save_image_grid(train_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'train_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(train_fingerprints[::30, :, :, :],
                                     os.path.join(
                                         result_subdir,
                                         'train_fingerprints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_fingerprints[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_fingerprints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])

                est_fingerprints = np.transpose(
                    EGs.vars['Conv_fingerprints/weight'].eval(),
                    axes=[3, 2, 0, 1])
                misc.save_image_grid(est_fingerprints,
                                     os.path.join(result_subdir,
                                                  'est_fingerprints-final.png'),
                                     drange=[
                                         np.amin(est_fingerprints),
                                         np.amax(est_fingerprints)
                                     ],
                                     grid_size=[est_fingerprints.shape[0], 1])

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (EG, D_rec, EGs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((EG, D_rec, EGs),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
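
The train/validation accuracy blocks above repeat the same computation; a small helper could capture it (a sketch, assuming EGs.run returns (recs, fingerprints, logits) as used above):

import numpy as np

def classification_accuracy(net, reals, labels, minibatch_size):
    # Run the averaged encoder-classifier and compare argmax predictions
    # against the one-hot ground-truth labels.
    _, _, logits = net.run(reals, minibatch_size=minibatch_size)
    preds = np.argmax(logits, axis=1)
    gt = np.argmax(labels, axis=1)
    return np.float32(np.sum(gt == preds)) / np.float32(len(gt))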
Example #6
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    compute_fid_score=False,  # Compute FID during training once sched.lod = 0.0?
    minimum_fid_kimg=0,  # Compute FID only after this many kimg.
    fid_snapshot_ticks=1,  # How often to compute FID?
    fid_patience=2,  # Number of FID evaluations without improvement before ending training.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    result_subdir="./"):
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None and resume_run_id != "None":
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            resume_pkl_name = os.path.splitext(
                os.path.basename(network_pkl))[0]
            try:
                resume_kimg = int(resume_pkl_name.split('-')[-1])
                print('** Setting resume kimg to', resume_kimg, flush=True)
            except ValueError:
                print('** Keeping resume kimg as:', resume_kimg, flush=True)
            print('Loading networks from "%s"...' % network_pkl, flush=True)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...', flush=True)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
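        # with beta = G_smoothing = 0.999, Gs tracks an exponential moving
        # average of G's weights over an effective horizon of roughly
        # 1 / (1 - beta) = 1000 updates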
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...', flush=True)
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
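            # each GPU tower registers its gradients on the shared optimizers;
            # apply_updates() below combines them across devices into a single
            # training op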

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...', flush=True)
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...', flush=True)
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...', flush=True)
    # FID patience parameters:
    fid_list = []
    fid_steps = 0
    fid_stop = False
    fid_patience_step = 0
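    # fid_patience_step counts consecutive FID evaluations without a new best
    # score; training stops early once it reaches fid_patience (see below)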

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })
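            # cur_nimg advances only on the D steps above, so each repeat
            # consumes D_repeats * sched.minibatch real images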

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            if (compute_fid_score and cur_tick % fid_snapshot_ticks == 0
                    and sched.lod == 0.0
                    and cur_nimg >= minimum_fid_kimg * 1000):
                fid = compute_fid(Gs=Gs,
                                  minibatch_size=sched.minibatch,
                                  dataset_obj=training_set,
                                  iter_number=cur_nimg / 1000,
                                  lod=0.0,
                                  num_images=10000,
                                  printing=False)
                fid_list.append(fid)

            # Report progress without FID.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)),
                flush=True)
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save image snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Save network snapshots
            if (cur_tick % network_snapshot_ticks == 0 or done
                    or (compute_fid_score
                        and cur_tick % fid_snapshot_ticks == 0
                        and cur_nimg >= minimum_fid_kimg * 1000)):
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # End training based on FID patience
            if (compute_fid_score and cur_tick % fid_snapshot_ticks == 0
                    and sched.lod == 0.0
                    and cur_nimg >= minimum_fid_kimg * 1000):
                fid_patience_step += 1
                if len(fid_list) == 1:
                    fid_patience_step = 0
                    misc.save_pkl((G, D, Gs),
                                  os.path.join(result_subdir,
                                               'network-final-full-conv.pkl'))
                    print(
                        "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                        % (fid_list[-1], cur_nimg // 1000),
                        flush=True)
                else:
                    if fid_list[-1] < np.min(fid_list[:-1]):
                        fid_patience_step = 0
                        misc.save_pkl(
                            (G, D, Gs),
                            os.path.join(result_subdir,
                                         'network-final-full-conv.pkl'))
                        print(
                            "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                            % (fid_list[-1], cur_nimg // 1000),
                            flush=True)
                    else:
                        print("No improvement for FID: %.3f at kimg %-8.1f." %
                              (fid_list[-1], cur_nimg // 1000),
                              flush=True)
                if fid_patience_step == fid_patience:
                    fid_stop = True
                    print("Training stopped due to FID early-stopping.",
                          flush=True)
                    # Force the outer while-loop to terminate.
                    cur_nimg = total_kimg * 1000

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # Save the final network only if FID early-stopping has not happened:
    if not fid_stop:
        fid = compute_fid(Gs=Gs,
                          minibatch_size=sched.minibatch,
                          dataset_obj=training_set,
                          iter_number=cur_nimg / 1000,
                          lod=0.0,
                          num_images=10000,
                          printing=False)
        print("Final FID: %.3f at kimg %-8.1f." % (fid, cur_nimg // 1000),
              flush=True)
        ### save the final FID to a .csv file in the result parent dir
        csv_file = os.path.join(
            os.path.dirname(os.path.dirname(result_subdir)),
            "results_full_conv.csv")
        list_to_append = [
            result_subdir.split("/")[-2] + "/" + result_subdir.split("/")[-1],
            fid
        ]
        # the with-block closes the file; no explicit close() is needed
        with open(csv_file, 'a') as f_object:
            writer_object = writer(f_object)
            writer_object.writerow(list_to_append)
        misc.save_pkl((G, D, Gs),
                      os.path.join(result_subdir,
                                   'network-final-full-conv.pkl'))
        print("Save network-final-full-conv.", flush=True)
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
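
# A hedged readback sketch for the results CSV appended above (two columns per
# row: run name, FID); the helper name is ours, not part of the original code.
def best_run_from_csv(csv_file):
    import csv
    with open(csv_file) as f:
        rows = [(r[0], float(r[1])) for r in csv.reader(f) if r]
    # lower FID is better
    return min(rows, key=lambda r: r[1])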
Example #7
def recovery(name,
             pkl_path1,
             pkl_path2,
             out_dir,
             target_latents_dir,
             num_init=20,
             num_total_sample=100,
             image_shrink=1,
             random_seed=2020,
             minibatch_size=1,
             noise_sigma=0):
    #     misc.init_output_logging()
    #     np.random.seed(random_seed)
    #     print('Initializing TensorFlow...')
    #     os.environ.update(config.env)
    #     tfutil.init_tf(config.tf_config)

    print('num_init: ' + str(num_init))

    # load source model
    print('Loading network1... ' + pkl_path1)
    _, _, G_source = misc.load_network_pkl(pkl_path1)

    # load target model
    print('Loading network2... ' + pkl_path2)
    _, _, G_target = misc.load_network_pkl(pkl_path2)

    # load Gt
    Gt = tfutil.Network('Gt',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    latents = misc.random_latents(num_init, Gt, random_state=None)
    labels = np.zeros([latents.shape[0], 0], np.float32)
    Gt.copy_vars_from_with_input(G_target, latents)

    # load Gs
    Gs = tfutil.Network('Gs',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    Gs.copy_vars_from_with_input(G_source, latents)

    out_dir = os.path.join(out_dir, name)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    def G_loss(G, target_images):
        tmp_latents = tfutil.run(G.trainables['Input/weight'])
        G_out = G.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        return tf.losses.mean_squared_error(target_images, G_out)
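    # note: the recovered latent lives in the generator's 'Input/weight'
    # variable, but register_gradients below receives all of Gs.trainables,
    # so the generator weights are optimized together with that latent (the
    # commented-out init_var line hints at restricting this to the latent only)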

    z_init = []
    z_recovered = []

    # load pre-computed target latents (z), if provided
    if target_latents_dir is not None:
        print('using latents:' + target_latents_dir)
        pre_latents = np.load(target_latents_dir)

    for k in range(num_total_sample):
        result_dir = os.path.join(out_dir, str(k) + '.png')

        #============sample target image
        if target_latents_dir is not None:
            latent = pre_latents[k]
        else:
            latents = misc.random_latents(1, Gs, random_state=None)
            latent = latents[0]
        z_init.append(latent)

        latents = np.zeros((num_init, 512))
        for i in range(num_init):
            latents[i] = latent
        Gt.change_input(inputs=latents)

        #================add_noise
        target_images = Gt.get_output_for(latents, labels, is_training=False)
        target_images_tf = rescale_output(target_images)
        target_images = tfutil.run(target_images_tf)

        target_images_noise = addGaussianNoise(target_images,
                                               sigma=noise_sigma)
        target_images_noise = tf.cast(target_images_noise, dtype='float32')
        target_images = target_images_noise

        #=============select random start point
        latents_2 = misc.random_latents(num_init, Gs, random_state=None)
        Gs.change_input(inputs=latents_2)

        #==============define loss&optimizer
        regularizer = tf.abs(tf.norm(latents_2) - np.sqrt(512))
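        # for z ~ N(0, I_512), E[||z||] ~= sqrt(512), so this term would keep
        # the latents near the Gaussian shell; it is currently unused (see the
        # commented-out '+ regularizer' in the loss definition below)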
        loss = G_loss(G=Gs, target_images=target_images)  # + regularizer
        # init_var = OrderedDict([('Input/weight',Gs.trainables['Input/weight'])])
        # decayed_lr = tf.train.exponential_decay(0.1,500, 50, 0.5, staircase=True)
        G_opt = tfutil.Optimizer(name='latent_recovery', learning_rate=0.01)
        G_opt.register_gradients(loss, Gs.trainables)
        G_train_op = G_opt.apply_updates()

        #===========recovery==========
        EPOCH = 500
        losses = []
        losses.append(tfutil.run(loss))
        for _ in range(EPOCH):
            # note: resetting the optimizer state on every iteration discards
            # Adam's moment estimates, so the update behaves like plain
            # gradient descent at the fixed learning rate
            G_opt.reset_optimizer_state()
            tfutil.run([G_train_op])

        # read back the optimized latents and render the recovered images
        learned_latent = tfutil.run(Gs.trainables['Input/weight'])
        result_images = Gs.run(learned_latent,
                               labels,
                               minibatch_size=config.num_gpus * 256,
                               num_gpus=config.num_gpus,
                               out_mul=127.5,
                               out_add=127.5,
                               out_shrink=image_shrink,
                               out_dtype=np.float32)

        sample_losses = []
        tmp_latents = tfutil.run(Gs.trainables['Input/weight'])
        G_out = Gs.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        for i in range(num_init):
            loss = tf.losses.mean_squared_error(target_images[i], G_out[i])
            sample_losses.append(tfutil.run(loss))

        #========save best optimized image
        plt.subplot(1, 2, 1)
        plt.imshow(tfutil.run(target_images)[0].transpose(1, 2, 0) / 255.0)
        plt.subplot(1, 2, 2)
        plt.imshow(result_images[np.argmin(sample_losses)].transpose(1, 2, 0) /
                   255.0)
        plt.savefig(result_dir)

        #========store optimized z
        z_recovered.append(tmp_latents)

        #=========save losses
        #         loss=min(sample_losses)

        with open(out_dir + "/losses.txt", "a") as f:
            for loss in sample_losses:
                f.write(str(loss) + ' ')
            f.write('\n')
        np.save(out_dir + '/z_init', np.array(z_init))
        np.save(out_dir + '/z_re', np.array(z_recovered))
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    #resume_run_id = '/dresden/users/mk1391/evl/pggan_logs/logs_celeba128cc/fsg16_results_0/000-pgan-celeba-preset-v2-2gpus-fp32/network-snapshot-010211.pkl'
    resume_with_new_nets = False
    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            rG, rD, rGs = misc.load_pkl(network_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    ### pyramid draw fsg (comment out for actual training to happen)
    #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'pggan_fsg_draw.png'))
    #print('>>> done printing fsgs.')
    #return

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    ### shift reals
    print('>>> reals shape: ', grid_reals.shape)
    fc_x = 0.5
    fc_y = 0.5
    im_size = grid_reals.shape[-1]
    kernel_loc = 2.*np.pi*fc_x * np.arange(im_size).reshape((1, 1, im_size)) + \
        2.*np.pi*fc_y * np.arange(im_size).reshape((1, im_size, 1))
    kernel_cos = np.cos(kernel_loc)
    kernel_sin = np.sin(kernel_loc)
    reals_t = (grid_reals / 255.) * 2. - 1
    reals_t *= kernel_cos
    grid_reals_sh = np.rint(
        (reals_t + 1.) * 255. / 2.).clip(0, 255).astype(np.uint8)
    ### end shift reals
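    # multiplying by cos(2*pi*(fc_x*x + fc_y*y)) amplitude-modulates the
    # images, shifting their spectral energy by +-(fc_x, fc_y) cycles/pixel;
    # with fc_x = fc_y = 0.5 this pushes energy to the Nyquist frequency
    # (a checkerboard pattern)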
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    ### fft drawing
    #sys.path.insert(1, '/home/mahyar/CV_Res/ganist')
    #from fig_draw import apply_fft_win
    #data_size = 1000
    #latents = np.random.randn(data_size, *Gs.input_shapes[0][1:])
    #labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])
    #g_samples = Gs.run(latents, labels, minibatch_size=sched.minibatch//config.num_gpus)
    #g_samples = g_samples.transpose(0, 2, 3, 1)
    #print('>>> g_samples shape: {}'.format(g_samples.shape))
    #apply_fft_win(g_samples, 'fft_pggan_hann.png')
    ### end fft drawing

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    ### drawing shifted real images
    misc.save_image_grid(grid_reals_sh,
                         os.path.join(result_subdir, 'reals_sh.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    ### drawing shifted fake images
    misc.save_image_grid(grid_fakes * kernel_cos,
                         os.path.join(result_subdir, 'fakes%06d_sh.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    #### True cosine fft eval
    #fft_data_size = 1000
    #im_size = training_set.shape[1]
    #freq_centers = [(64/128., 64/128.)]
    #true_samples = sample_true(training_set, fft_data_size, dtype=training_set.dtype, batch_size=32).transpose(0, 2, 3, 1) / 255. * 2. - 1.
    #true_fft, true_fft_hann, true_hist = cosine_eval(true_samples, 'true', freq_centers, log_dir=result_subdir)
    #fractal_eval(true_samples, f'koch_snowflake_true', result_subdir)

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                ### drawing shifted fake images
                misc.save_image_grid(
                    grid_fakes * kernel_cos,
                    os.path.join(result_subdir,
                                 'fakes%06d_sh.png' % (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size)
                ### drawing fsg
                #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'fakes%06d_fsg_draw.png' % (cur_nimg // 1000)))
                ### Gen fft eval
                #gen_samples = sample_gen(Gs, fft_data_size).transpose(0, 2, 3, 1)
                #print(f'>>> fake_samples: max={np.amax(grid_fakes)} min={np.amin(grid_fakes)}')
                #print(f'>>> gen_samples: max={np.amax(gen_samples)} min={np.amin(gen_samples)}')
                #misc.save_image_grid(gen_samples[:25], os.path.join(result_subdir, 'fakes%06d_gsample.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                #cosine_eval(gen_samples, f'gen_{cur_nimg//1000:06d}', freq_centers, log_dir=result_subdir, true_fft=true_fft, true_fft_hann=true_fft_hann, true_hist=true_hist)
                #fractal_eval(gen_samples, f'koch_snowflake_fakes{cur_nimg//1000:06d}', result_subdir)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
def anomaly_detection_encoder(run_id,
                              log,
                              test_data_folder,
                              test_batch_size=64,
                              n_samples=1000):

    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    print('# snapshot_pkls: ' + str(len(snapshot_pkls)))

    for idx in range(0, n_samples, test_batch_size):

        with tf.Graph().as_default(), tfutil.create_session(
                config.tf_config).as_default():
            # load network from specific run; the pickle is reloaded into a
            # fresh graph on every batch, which is simple but slow for large
            # n_samples
            G, D, Gs, E = misc.load_pkl(snapshot_pkls[-1])
            print(snapshot_pkls[-1])

            dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
                result_subdir, verbose=True, shuffle_mb=0)

            Ga = tfutil.Network('G_anomaly',
                                num_channels=G.output_shapes[0][1],
                                resolution=G.output_shapes[0][2],
                                label_size=dataset_obj.label_size,
                                **config.G_anomaly)
            Ga.copy_vars_from(Gs)

            Da_Gout = tfutil.Network('D_anomaly_Gout',
                                     num_channels=G.output_shapes[0][1],
                                     resolution=G.output_shapes[0][2],
                                     label_size=dataset_obj.label_size,
                                     images_in=Ga.output_templates[0],
                                     **config.D_anomaly_Gout)
            image_dims = [
                G.output_shapes[0][1], G.output_shapes[0][2],
                G.output_shapes[0][3]
            ]
            Da_test = tfutil.Network('D_anomaly_test',
                                     num_channels=G.output_shapes[0][1],
                                     resolution=G.output_shapes[0][2],
                                     label_size=dataset_obj.label_size,
                                     **config.D_anomaly_test)

            Da_Gout.copy_vars_from(D)
            Da_test.copy_vars_from(D)

            Da_Gout.print_layers()
            Da_test.print_layers()
            E.print_layers()

            print("Initializing Anomaly detector")
            anoGAN = tfutil.AnomalyDetectorEncoder(config, Ga, Da_Gout,
                                                   Da_test, E,
                                                   test_data_folder)
            print('# AnoGAN test data names: ' +
                  str(len(anoGAN.test_data_names)))
            assert len(anoGAN.test_data_names) > 0

            test_input = anoGAN.test_data[idx:idx + test_batch_size]
            test_name = anoGAN.test_data_names[idx:idx + test_batch_size]
            anoGAN.find_closest_match(test_input, test_name)
    tf.reset_default_graph()
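
# A hedged invocation sketch; the run id and test folder below are
# placeholders, not values from the original code:
#
#   anomaly_detection_encoder(run_id=0, log='anomaly.log',
#                             test_data_folder='datasets/test_images',
#                             test_batch_size=64, n_samples=1000)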