Example #1
def test_best(main_dict, reset=None):
    _, val_set = load_trainval(main_dict)

    history = ms.load_history(main_dict)

    # if reset == "reset":    
    try:
        pred_annList = ms.load_best_annList(main_dict)
    except:
        model = ms.load_best_model(main_dict)
        pred_annList = dataset2annList(model, val_set, 
                 predict_method="BestDice", 
                 n_val=None)
        ms.save_pkl(main_dict["path_best_annList"], pred_annList)
    # else:
        # pred_annList = ms.load_best_annList(main_dict)

    gt_annDict = load_gtAnnDict(main_dict)
    results = get_perCategoryResults(gt_annDict, pred_annList)
    
    result_dict = results["result_dict"]
    result_dict["Model"] = main_dict["model_name"]
    result_dict["epoch"] = history["best_model"]["epoch"]
    result_list = test_baselines(main_dict)
    result_list += [result_dict]

    print(pd.DataFrame(result_list))
Example #2
def merge(gather_dir=os.getcwd(),
          combine_type='run',
          base_name='',
          out_dir='',
          **params):
    """
    Merges the epoched*.pkl objects in gather_dir
    Arguments:
        gather_dir: path to folder containing epoched files
        combine_type: 'run' to combine averages of each run
                      'trial' to combine every trial
    Outputs:
        merged.pkl, merged.mat containing the merged Epoched object
        A plot is also generated

    """
    files = [
        f for f in os.listdir(gather_dir)
        if ('mat' in f or 'pkl' in f) and 'epoched' in f
    ]
    merged = None
    for f in files:
        run = read_epoch(os.path.join(gather_dir, f))

        if merged is None:
            merged = Epoched(run.n_categs, run.n_samples, 0)
            merged.names = run.names
            merged.num_trials = [0] * len(run.num_trials)
            merged.num_rejected = [0] * len(run.num_rejected)

        if combine_type == 'run':
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                avg = np.nanmean(run.matrix, axis=2, keepdims=True)
            merged.matrix = np.concatenate((merged.matrix, avg), axis=2)
        elif combine_type == 'trial':
            merged.matrix = np.concatenate((merged.matrix, run.matrix), axis=2)

        if len(run.num_trials) == len(merged.num_trials):
            merged.num_trials = [
                x + y for x, y in zip(run.num_trials, merged.num_trials)
            ]
        if len(run.num_rejected) == len(merged.num_rejected):
            merged.num_rejected = [
                x + y for x, y in zip(run.num_rejected, merged.num_rejected)
            ]
    spio.savemat(
        make_path('merged', '.mat', out_dir=out_dir, base_name=base_name),
        {'merged': merged})
    save_pkl(make_path('merged', '.pkl', out_dir=out_dir, base_name=base_name),
             merged)
    plot_conds(merged, out_dir=out_dir, base_name=base_name, **params)
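
A minimal usage sketch (not part of the original): it assumes merge is importable from the module above and that gather_dir already holds epoched*.pkl / .mat files; the paths and base_name are illustrative placeholders.

# Hypothetical call: average each run before concatenating across runs.
merge(gather_dir='/data/sub01/epoched',
      combine_type='run',
      base_name='sub01',
      out_dir='/data/sub01/merged')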
Example #3
def test_upperbound(main_dict, reset=None):
    # pointDict = load_LCFCNPoints(main_dict)
    fname = main_dict["path_baselines"].replace("baselines","upperbound")

    if reset == "reset":
        pred_annList = load_predAnnList(main_dict, reset=reset)
        gt_annDict = load_gtAnnDict(main_dict)

        results = get_perSizeResults(gt_annDict, pred_annList)

        result_dict = results["result_dict"]

        result_dict["Model"] = "UpperBound"
        result_list = [result_dict]
        ms.save_pkl(fname, result_list)
    else:
        result_list = ms.load_pkl(fname)

    return result_list
Example #4
def test_baselines(main_dict, reset=None):
    #### Best Objectness
    if os.path.exists(main_dict["path_baselines"]) and reset != "reset":
        result_list = ms.load_pkl(main_dict["path_baselines"])
        return result_list

    else:
        gt_annDict = load_gtAnnDict(main_dict)
        pred_annList = load_predAnnList(main_dict,
                                        predict_method="BestObjectness",
                                        reset=reset)

        # idList1 = get_image_ids(pred_annList)
        # idList2 = get_image_ids(gt_annDict["annotations"])

        # results = get_perCategoryResults(gt_annDict, pred_annList)
        results = get_perSizeResults(gt_annDict, pred_annList)

        result_dict = results["result_dict"]

        result_dict["Model"] = "BestObjectness"
        result_list = [result_dict]

        #### Upper bound

        pred_annList = load_predAnnList(main_dict, predict_method="UpperBound", 
                                        reset=reset)
        # results = get_perCategoryResults(gt_annDict, pred_annList)
        results = get_perSizeResults(gt_annDict, pred_annList)


        result_dict = results["result_dict"]

        result_dict["Model"] = "UpperBound"
        result_list += [result_dict]
        ms.save_pkl(main_dict["path_baselines"], result_list)

    print(pd.DataFrame(result_list))
    return result_list
Example #5
def validation_phase_mAP(history, main_dict, model, val_set, predict_name, epoch):
    val_dict, pred_annList = au.validate(model, val_set, 
                predict_method=predict_name, 
                n_val=len(val_set), return_annList=True)
  
    val_dict["predict_name"] = predict_name
    val_dict["epoch"] = epoch
    val_dict["time"] = datetime.datetime.now().strftime("%b %d, 20%y")

    # Update history
    history["val"] += [val_dict]

    # Higher is better
    if (history["best_model"] == {} or 
        history["best_model"]["0.5"] <= val_dict["0.5"]):

      history["best_model"] = val_dict
      ms.save_best_model(main_dict, model)
      
      ms.save_pkl(main_dict["path_best_annList"], pred_annList)
      ms.copy_code_best(main_dict)

    return history
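
A minimal sketch (an assumption, not from the original) of how validation_phase_mAP could be driven from an outer training loop; fit_epoch, train_set and num_epochs are hypothetical placeholders, and predict_name="BestDice" mirrors the prediction method used in the other examples.

history = ms.load_history(main_dict)
for epoch in range(num_epochs):                      # num_epochs: placeholder
    fit_epoch(model, train_set)                      # hypothetical training step
    history = validation_phase_mAP(history, main_dict, model, val_set,
                                   predict_name="BestDice", epoch=epoch)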
Example #6
    def __init__(self, batch):
        path = "/mnt/datasets/public/issam/VOCdevkit/proposals/MCG_2012/"
        fname = path + "{}.mat".format(batch["name"][0])
        fname_pkl = fname.replace(".mat", ".pkl")

        if not os.path.exists(fname_pkl):
            self.proposals = ms.loadmat(fname)

            self.n_annList = self.proposals["scores"].shape[0]
            self.superpixel = self.proposals["superpixels"]
            self.min_score = abs(np.min(self.proposals["scores"]))
            self.max_score = np.max(self.proposals["scores"] + self.min_score)
            annList = []
            for i in range(len(self)):
                print(i, "/", len(self))
                prp = self.proposals["labels"][i][0].ravel()
                proposal_mask = np.zeros(self.superpixel.shape, int)
                proposal_mask[np.isin(self.superpixel, prp)] = 1

                score = self.proposals["scores"][i][0] + self.min_score
                score /= self.max_score

                ann = au.mask2ann(proposal_mask,
                                  category_id=1,
                                  image_id=batch["name"][0],
                                  height=self.superpixel.shape[0],
                                  width=self.superpixel.shape[1],
                                  maskVoid=None,
                                  score=score)

                annList += [ann]

            ms.save_pkl(fname_pkl, annList)

        self.annList = ms.load_pkl(fname_pkl)
        self.n_annList = len(self.annList)
Example #7
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    unlabeled_training_set = dataset.load_dataset(
        data_dir=config.unlabeled_data_dir,
        verbose=True,
        **config.unlabeled_dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            print("Training-Set Label Size: ", training_set.label_size)
            print("Unlabeled-Training-Set Label Size: ",
                  unlabeled_training_set.label_size)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])

        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        unlabeled_reals, _ = unlabeled_training_set.get_minibatch_tf()

        reals_split = tf.split(reals, config.num_gpus)
        unlabeled_reals_split = tf.split(unlabeled_reals, config.num_gpus)

        labels_split = tf.split(labels, config.num_gpus)

    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    G_opt_pggan = tfutil.Optimizer(name='TrainG_pggan',
                                   learning_rate=lrate_in,
                                   **config.G_opt)
    D_opt_pggan = tfutil.Optimizer(name='TrainD_pggan',
                                   learning_rate=lrate_in,
                                   **config.D_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    print("CUDA_VISIBLE_DEVICES: ", os.environ['CUDA_VISIBLE_DEVICES'])

    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]

            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            unlabeled_reals_gpu = process_reals(
                unlabeled_reals_split[gpu], lod_in, mirror_augment,
                unlabeled_training_set.dynamic_range, drange_net)

            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.G_loss)
            with tf.name_scope('G_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss_pggan)

            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss)
            with tf.name_scope('D_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss_pggan)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            G_opt_pggan.register_gradients(tf.reduce_mean(G_loss_pggan),
                                           G_gpu.trainables)
            D_opt_pggan.register_gradients(tf.reduce_mean(D_loss_pggan),
                                           D_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            print('GPU %d loaded!' % gpu)

    G_train_op = G_opt.apply_updates()
    G_train_op_pggan = G_opt_pggan.apply_updates()
    D_train_op_pggan = D_opt_pggan.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * TrainingSpeedInt, training_set,
                             **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.compat.v1.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print("Start Time: ",
          datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print('Training...')
    cur_nimg = int(resume_kimg * TrainingSpeedInt)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0

    while cur_nimg < total_kimg * TrainingSpeedInt:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        sched2 = TrainingSchedule(cur_nimg, unlabeled_training_set,
                                  **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        unlabeled_training_set.configure(sched2.minibatch, sched2.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                G_opt_pggan.reset_optimizer_state()
                D_opt_pggan.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                # Run the Pggan loss if lod != 0 else run SSL loss with feature matching
                if sched.lod == 0:
                    tfutil.run(
                        [D_train_op, Gs_update_op], {
                            lod_in: sched.lod,
                            lrate_in: sched.D_lrate,
                            minibatch_in: sched.minibatch
                        })
                else:
                    tfutil.run(
                        [D_train_op_pggan, Gs_update_op], {
                            lod_in: sched.lod,
                            lrate_in: sched.D_lrate,
                            minibatch_in: sched.minibatch
                        })
                cur_nimg += sched.minibatch
                #tmp = min(tick_start_nimg + sched.tick_kimg * TrainingSpeedInt, total_kimg * TrainingSpeedInt)
                #print("Tick progress:  {}/{}".format(cur_nimg, tmp), end="\r", flush=True)
            # Run the Pggan loss if lod != 0 else run SSL loss with feature matching
            if sched.lod == 0:
                tfutil.run(
                    [G_train_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.G_lrate,
                        minibatch_in: sched.minibatch
                    })
            else:
                tfutil.run(
                    [G_train_op_pggan], {
                        lod_in: sched.lod,
                        lrate_in: sched.G_lrate,
                        minibatch_in: sched.minibatch
                    })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * TrainingSpeedInt)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * TrainingSpeedInt or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / TrainingSpeedFloat
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f date %s'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg',
                                      cur_nimg / TrainingSpeedFloat),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary(
                       'Timing/maintenance_sec', maintenance_time),
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))

            #######################
            # VALIDATION ACCURACY #
            #######################

            # example ndim = 512 for an image that is 512x512 pixels
            # All images for SSL-PGGAN must be square
            ndim = 256
            correct = 0
            guesses = 0

            dir_tuple = (config.validation_dog, config.validation_cat)
            # If guessed the wrong class seeing if there is a bias
            FP_RATE = [[0], [0]]
            # For each class
            for indx, directory in enumerate(dir_tuple):
                # Go through every image that needs to be tested
                for filename in os.listdir(directory):
                    guesses += 1
                    #tensor = np.zeros((1, 3, 512, 512))
                    print(filename)
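                    # NOTE (added): np.asarray on a PIL RGB image yields shape
                    # (H, W, C); the reshape below assumes channel-first data,
                    # so a transpose(2, 0, 1) may be what is actually intended.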
                    img = np.asarray(PIL.Image.open(directory +
                                                    filename)).reshape(
                                                        3, ndim, ndim)
                    img = np.expand_dims(
                        img, axis=0)  # makes the image (1,3,512,512)
                    K_logits_out, fake_logit_out, features_out = test_discriminator(
                        D, img)

                    #print("K Logits Out:",K_logits_out.eval())
                    sample_probs = tf.nn.softmax(K_logits_out)
                    #print("Softmax Output:", sample_probs.eval())
                    label = np.argmax(sample_probs.eval()[0], axis=0)
                    if label == indx:
                        correct += 1
                    else:
                        FP_RATE[indx][0] += 1
                    print("-----------------------------------")
                    print("GUESSED LABEL: ", label)
                    print("CORRECT LABEL: ", indx)
                    validation = (correct / guesses)
                    print("Total Correct: ", correct, "\n", "Total Guesses: ",
                          guesses, "\n", "Percent correct: ", validation)
                    print("False Positives: Dog, Cat", FP_RATE)
                    print()

            tfutil.autosummary('Accuracy/Validation', (correct / guesses))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir, 'fakes%06d.png' %
                                         (cur_nimg // TrainingSpeedInt)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs),
                              os.path.join(
                                  result_subdir, 'network-snapshot-%06d.pkl' %
                                  (cur_nimg // TrainingSpeedInt)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Example #8
    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl(
            './inception_v3_features.pkl')  # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]],
                               dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(
                    self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end - begin],
                                                       num_gpus=num_gpus,
                                                       assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        # different from stylegan
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] +
                                           Gs_clone.input_shape[1:])
                # beholdergan-id
                # labels = misc.make_rand_labels(self.minibatch_per_gpu,dims=labelsize)
                # labels = tf.constant(labels)

                # beholdergan
                labels = tf.constant(
                    np.zeros([self.minibatch_per_gpu, labelsize],
                             dtype=np.float32))

                # stylegan (unconditional)
                # images = Gs_clone.get_output_for(latents)

                # CGANs / label-conditioned generator; assumed here to be the
                # active variant so that `images` is defined before conversion.
                images = Gs_clone.get_output_for(latents, labels)

                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr),
                                                    axis=0)[:end - begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
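        # Frechet distance between two Gaussians (mu, sigma):
        # FID = ||mu_fake - mu_real||^2 + Tr(sigma_fake + sigma_real - 2*sqrt(sigma_fake.dot(sigma_real)))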
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)  # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2 * s)
        self._report_result(np.real(dist))
Example #9
def train_progressive_gan(
    G_smoothing             = 0.999,        # Exponential running average of generator weights.
    D_repeats               = 1,            # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,         # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,        # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,        # Enable mirror augment?
    drange_net              = [-1,1],       # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,            # How often to export image snapshots?
    network_snapshot_ticks  = 10,           # How often to export network snapshots?
    save_tf_graph           = False,        # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,        # Include weight histograms in the tfevents file?
    resume_run_id           = None,         # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,         # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,          # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):         # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G)
            D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels   = training_set.get_minibatch_tf()
        reals_split     = tf.split(reals, config.num_gpus)
        labels_split    = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tfutil.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (
                tfutil.autosummary('Progress/tick', cur_tick),
                tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                tfutil.autosummary('Progress/lod', sched.lod),
                tfutil.autosummary('Progress/minibatch', sched.minibatch),
                misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
                tfutil.autosummary('Timing/sec_per_tick', tick_time),
                tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                tfutil.autosummary('Timing/maintenance_sec', maintenance_time)))
            tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Example #10
def train_detector(
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=1,  # Total length of the training, measured in thousands of real images.
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    snapshot_size=16,  # Size of the snapshot image
    snapshot_ticks=2**13,  # Number of images before maintenance
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=1,  # How often to export network snapshots?
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()

    # Load the datasets
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)

    # TODO: data augmentation
    # TODO: testing set

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...'
                  )  # TODO: better network (like lod-wise network)
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()

    print('Building TensorFlow graph...')
    # Training set up
    with tf.name_scope('Inputs'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # minibatch_in            = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        reals, labels, bboxes = training_set.get_minibatch_tf(
        )  # TODO: increase the size of the batch by several loss computation and mean
    N_opt = tfutil.Optimizer(name='TrainN',
                             learning_rate=lrate_in,
                             **config.N_opt)

    with tf.device('/gpu:0'):
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)

        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()

    # Testing set up
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = testing_set.get_minibatch_tf(
        )
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()

    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals,
                         test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)

    test_reals = misc.adjust_dynamic_range(test_reals,
                                           training_set.dynamic_range,
                                           drange_net)
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)

    print('Training...')
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())

    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time

    # Choose training parameters and configure training ops.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)

    _train_loss = 0

    while cur_nimg < total_kimg * 1000:

        # Run training ops.
        # for _ in range(minibatch_repeats):
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        if (cur_nimg >= total_kimg * 1000) or (cur_nimg % snapshot_ticks == 0
                                               and cur_nimg > 0):

            cur_tick += 1
            cur_time = time.time()
            # tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            testing_set.configure(sched.minibatch)
            _test_loss = 0
            # testing_set_len = 1 # TMP
            for _ in range(0, testing_set_len, sched.minibatch):
                _test_loss += tfutil.run(test_loss)
            _test_loss /= testing_set_len

            # Report progress. # TODO: improved report display
            print(
                'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/maintenance', maintenance_time),
                   tfutil.autosummary('TrainN/train_loss', _train_loss),
                   tfutil.autosummary('TrainN/test_loss', _test_loss)))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            if cur_tick % image_snapshot_ticks == 0:
                test_bboxes, test_refs = N.run(test_reals,
                                               minibatch_size=snapshot_size)
                misc.save_img_bboxes_ref(
                    test_reals, test_bboxes, test_refs,
                    os.path.join(result_subdir,
                                 'fakes%06d.png' % (cur_nimg // 1000)),
                    snapshot_size)
            if cur_tick % network_snapshot_ticks == 0:
                misc.save_pkl(
                    N,
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            _train_loss = 0

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # misc.save_pkl(N, os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Example #11
def test_model(main_dict, reset=None):
    # pointDict = load_LCFCNPoints(main_dict)
    _, val_set = load_trainval(main_dict)
    
    model = ms.load_best_model(main_dict)   
    gt_annDict = load_gtAnnDict(main_dict)
    # for i in range(50):
    import ipdb; ipdb.set_trace()  # breakpoint 887ad390 //

    if 1:
        b_list = [23]
        for i in b_list:
            batch = ms.get_batch(val_set, [i])
            annList_ub = pointList2UpperBoundMask(batch["lcfcn_pointList"], batch)["annList"]
            annList_bo = pointList2BestObjectness(batch["lcfcn_pointList"], batch)["annList"]
            annList = model.predict(batch, predict_method="BestDice")["annList"]
            results = get_perSizeResults(gt_annDict, annList)
            print(i,"Counts:", batch["counts"].item(),
                    " - BestObjectness:", len(annList_bo),
                    " - Model:", len(annList), 
                    " - UpperBound", len(annList_ub))
            print(i, 
                     get_perSizeResults(gt_annDict, annList_bo, pred_images_only=1)["result_dict"]["0.25"], 
                    get_perSizeResults(gt_annDict, annList, pred_images_only=1)["result_dict"]["0.25"], 
                    get_perSizeResults(gt_annDict, annList_ub, pred_images_only=1)["result_dict"]["0.25"])
        import ipdb; ipdb.set_trace()  # breakpoint 98d0193a //
        image_points = ms.get_image(batch["images"], batch["points"], enlarge=1,denorm=1)
        ms.images(image_points, annList2mask(annList)["mask"], 
                        win="model prediction")
        ms.images(batch["images"], annList2mask(annList_bo)["mask"],win="2",  denorm=1)
        ms.images(batch["images"], annList2mask(annList_ub)["mask"], win="3", denorm=1)
        ms.images(batch["images"], batch["points"], win="4", enlarge=1,denorm=1)
        ms.images(batch["images"],  model.predict(batch, predict_method="points")["blobs"], 
                        win="5", enlarge=1,denorm=1)
        ms.images(batch["images"], pointList2points(batch["lcfcn_pointList"])["mask"],
 
                              win="predicted_points", enlarge=1,denorm=1)
    fname = main_dict["path_baselines"].replace("baselines", main_dict["model_name"])

    if reset == "reset":
        _, val_set = load_trainval(main_dict)
        history = ms.load_history(main_dict)
        import ipdb; ipdb.set_trace()  # breakpoint a769ce6e //

        model = ms.load_best_model(main_dict)
        pred_annList = dataset2annList(model, val_set, 
                 predict_method="BestDice", 
                 n_val=None)

        pred_annList_up = load_predAnnList(main_dict, predict_method="UpperBoundMask")
        pred_annList_up = load_predAnnList(main_dict, predict_method="UpperBound")
        gt_annDict = load_gtAnnDict(main_dict)

        results = get_perSizeResults(gt_annDict, pred_annList)

        result_dict = results["result_dict"]

        result_dict["Model"] = main_dict["model_name"]
        result_list = [result_dict]
        ms.save_pkl(fname, result_list)
    else:
        result_list = ms.load_pkl(fname)

    return result_list
Example #12
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    compute_fid_score=False,  # Compute FID during training once sched.lod=0.0 
    minimum_fid_kimg=0,  # Compute FID only after this many kimg of training.
    fid_snapshot_ticks=1,  # How often to compute FID
    fid_patience=2,  # When to end training based on FID
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    result_subdir="./"):
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id != "None":
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            resume_pkl_name = os.path.splitext(
                os.path.basename(network_pkl))[0]
            try:
                resume_kimg = int(resume_pkl_name.split('-')[-1])
                print('** Setting resume kimg to', resume_kimg, flush=True)
            except ValueError:
                print('** Keeping resume kimg as:', resume_kimg, flush=True)
            print('Loading networks from "%s"...' % network_pkl, flush=True)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...', flush=True)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...', flush=True)
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...', flush=True)
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...', flush=True)
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...', flush=True)
    # FID patience parameters:
    fid_list = []
    fid_steps = 0
    fid_stop = False
    fid_patience_step = 0

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            if (compute_fid_score and cur_tick % fid_snapshot_ticks == 0
                    and sched.lod == 0.0
                    and cur_nimg >= minimum_fid_kimg * 1000):
                fid = compute_fid(Gs=Gs,
                                  minibatch_size=sched.minibatch,
                                  dataset_obj=training_set,
                                  iter_number=cur_nimg / 1000,
                                  lod=0.0,
                                  num_images=10000,
                                  printing=False)
                fid_list.append(fid)

            # Report progress without FID.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)),
                flush=True)
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save image snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Save network snapshots
            if (cur_tick % network_snapshot_ticks == 0 or done
                    or (compute_fid_score
                        and cur_tick % fid_snapshot_ticks == 0
                        and cur_nimg >= minimum_fid_kimg * 1000)):
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # End training based on FID patience
            if (compute_fid_score and cur_tick % fid_snapshot_ticks == 0
                    and sched.lod == 0.0
                    and cur_nimg >= minimum_fid_kimg * 1000):
                fid_patience_step += 1
                if len(fid_list) == 1:
                    fid_patience_step = 0
                    misc.save_pkl((G, D, Gs),
                                  os.path.join(result_subdir,
                                               'network-final-full-conv.pkl'))
                    print(
                        "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                        % (fid_list[-1], cur_nimg // 1000),
                        flush=True)
                else:
                    if fid_list[-1] < np.min(fid_list[:-1]):
                        fid_patience_step = 0
                        misc.save_pkl(
                            (G, D, Gs),
                            os.path.join(result_subdir,
                                         'network-final-full-conv.pkl'))
                        print(
                            "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                            % (fid_list[-1], cur_nimg // 1000),
                            flush=True)
                    else:
                        print("No improvement for FID: %.3f at kimg %-8.1f." %
                              (fid_list[-1], cur_nimg // 1000),
                              flush=True)
                if fid_patience_step == fid_patience:
                    fid_stop = True
                    print("Training stopped due to FID early-stopping.",
                          flush=True)
                    cur_nimg = total_kimg * 1000

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # Save the final snapshot only if FID early-stopping has not happened:
    if not fid_stop:
        fid = compute_fid(Gs=Gs,
                          minibatch_size=sched.minibatch,
                          dataset_obj=training_set,
                          iter_number=cur_nimg / 1000,
                          lod=0.0,
                          num_images=10000,
                          printing=False)
        print("Final FID: %.3f at kimg %-8.1f." % (fid, cur_nimg // 1000),
              flush=True)
        ### save final FID to .csv file in result_parent_dir
        csv_file = os.path.join(
            os.path.dirname(os.path.dirname(result_subdir)),
            "results_full_conv.csv")
        list_to_append = [
            result_subdir.split("/")[-2] + "/" + result_subdir.split("/")[-1],
            fid
        ]
        with open(csv_file, 'a') as f_object:
            writer_object = writer(f_object)
            writer_object.writerow(list_to_append)
        misc.save_pkl((G, D, Gs),
                      os.path.join(result_subdir,
                                   'network-final-full-conv.pkl'))
        print("Save network-final-full-conv.", flush=True)
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
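
The loop above stops training early once the FID has failed to improve for fid_patience consecutive evaluations. A minimal, self-contained sketch of that patience rule, using hypothetical names (fid_history, patience) rather than the variables of the example itself:

import numpy as np

def fid_improved(fid_history):
    # True if the most recent FID is the best (lowest) value seen so far.
    return len(fid_history) < 2 or fid_history[-1] < np.min(fid_history[:-1])

def should_stop(fid_history, patience):
    # Stop once `patience` consecutive evaluations fail to improve the FID.
    steps_without_improvement = 0
    for i in range(1, len(fid_history)):
        if fid_history[i] < min(fid_history[:i]):
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
    return steps_without_improvement >= patience

For example, should_stop([20.0, 15.0, 16.0, 15.5, 15.2], patience=3) returns True, which corresponds to the point at which the loop above would set fid_stop.
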
Beispiel #13
0
def train_classifier(
    smoothing=0.999,  # Exponential running average of encoder weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    lr_mirror_augment=True,  # Enable left-right mirror augment?
    ud_mirror_augment=False,  # Enable up-down mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False  # Include weight histograms in the tfevents file?
):

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.training_set)
    validation_set = dataset.load_dataset(data_dir=config.data_dir,
                                          verbose=True,
                                          **config.validation_set)
    network_snapshot_ticks = total_kimg // 100  # How often to export network snapshots?

    # Construct networks.
    with tf.device('/gpu:0'):
        try:
            network_pkl = misc.locate_network_pkl()
            resume_kimg, resume_time = misc.resume_kimg_time(network_pkl)
            print('Loading networks from "%s"...' % network_pkl)
            EG, D_rec, EGs = misc.load_pkl(network_pkl)
        except:
            print('Constructing networks...')
            resume_kimg = 0.0
            resume_time = 0.0
            EG = tfutil.Network('EG',
                                num_channels=training_set.shape[0],
                                resolution=training_set.shape[1],
                                label_size=training_set.label_size,
                                **config.EG)
            D_rec = tfutil.Network('D_rec',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   **config.D_rec)
            EGs = EG.clone('EGs')
        EGs_update_op = EGs.setup_as_moving_average_of(EG, beta=smoothing)
    EG.print_layers()
    D_rec.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    EG_opt = tfutil.Optimizer(name='TrainEG',
                              learning_rate=lrate_in,
                              **config.EG_opt)
    D_rec_opt = tfutil.Optimizer(name='TrainD_rec',
                                 learning_rate=lrate_in,
                                 **config.D_rec_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            EG_gpu = EG if gpu == 0 else EG.clone(EG.name + '_shadow_%d' % gpu)
            D_rec_gpu = D_rec if gpu == 0 else D_rec.clone(D_rec.name +
                                                           '_shadow_%d' % gpu)
            reals_fade_gpu, reals_orig_gpu = process_reals(
                reals_split[gpu], lod_in, lr_mirror_augment, ud_mirror_augment,
                training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('EG_loss'):
                EG_loss = tfutil.call_func_by_name(EG=EG_gpu,
                                                   D_rec=D_rec_gpu,
                                                   reals_orig=reals_orig_gpu,
                                                   labels=labels_gpu,
                                                   **config.EG_loss)
            with tf.name_scope('D_rec_loss'):
                D_rec_loss = tfutil.call_func_by_name(
                    EG=EG_gpu,
                    D_rec=D_rec_gpu,
                    D_rec_opt=D_rec_opt,
                    minibatch_size=minibatch_split,
                    reals_orig=reals_orig_gpu,
                    **config.D_rec_loss)
            EG_opt.register_gradients(tf.reduce_mean(EG_loss),
                                      EG_gpu.trainables)
            D_rec_opt.register_gradients(tf.reduce_mean(D_rec_loss),
                                         D_rec_gpu.trainables)
    EG_train_op = EG_opt.apply_updates()
    D_rec_train_op = D_rec_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, train_reals, train_labels = setup_snapshot_image_grid(
        training_set, drange_net, [450, 10], **config.grid)
    grid_size, val_reals, val_labels = setup_snapshot_image_grid(
        validation_set, drange_net, [450, 10], **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)

    train_recs, train_fingerprints, train_logits = EGs.run(
        train_reals, minibatch_size=sched.minibatch // config.num_gpus)
    train_preds = np.argmax(train_logits, axis=1)
    train_gt = np.argmax(train_labels, axis=1)
    train_acc = np.float32(np.sum(train_gt == train_preds)) / np.float32(
        len(train_gt))
    print('Training Accuracy = %f' % train_acc)

    val_recs, val_fingerprints, val_logits = EGs.run(
        val_reals, minibatch_size=sched.minibatch // config.num_gpus)
    val_preds = np.argmax(val_logits, axis=1)
    val_gt = np.argmax(val_labels, axis=1)
    val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(len(val_gt))
    print('Validation Accuracy = %f' % val_acc)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(train_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'train_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'train_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'train_fingerprints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'val_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'val_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'val_fingerprints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])

    est_fingerprints = np.transpose(
        EGs.vars['Conv_fingerprints/weight'].eval(), axes=[3, 2, 0, 1])
    misc.save_image_grid(
        est_fingerprints,
        os.path.join(result_subdir, 'est_fingerprints-init.png'),
        drange=[np.amin(est_fingerprints),
                np.amax(est_fingerprints)],
        grid_size=[est_fingerprints.shape[0], 1])

    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        EG.setup_weight_histograms()
        D_rec.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                EG_opt.reset_optimizer_state()
                D_rec_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            tfutil.run(
                [D_rec_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run(
                [EG_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run([EGs_update_op], {})
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f resolution %-4d minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/resolution', sched.resolution),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Print accuracy.
            if cur_tick % image_snapshot_ticks == 0 or done:

                train_recs, train_fingerprints, train_logits = EGs.run(
                    train_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                train_preds = np.argmax(train_logits, axis=1)
                train_gt = np.argmax(train_labels, axis=1)
                train_acc = np.float32(np.sum(
                    train_gt == train_preds)) / np.float32(len(train_gt))
                print('Training Accuracy = %f' % train_acc)

                val_recs, val_fingerprints, val_logits = EGs.run(
                    val_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                val_preds = np.argmax(val_logits, axis=1)
                val_gt = np.argmax(val_labels, axis=1)
                val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(
                    len(val_gt))
                print('Validation Accuracy = %f' % val_acc)

                misc.save_image_grid(train_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'train_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(train_fingerprints[::30, :, :, :],
                                     os.path.join(
                                         result_subdir,
                                         'train_fingerprints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_fingerprints[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_fingerprints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])

                est_fingerprints = np.transpose(
                    EGs.vars['Conv_fingerprints/weight'].eval(),
                    axes=[3, 2, 0, 1])
                misc.save_image_grid(est_fingerprints,
                                     os.path.join(result_subdir,
                                                  'est_fingerprints-final.png'),
                                     drange=[
                                         np.amin(est_fingerprints),
                                         np.amax(est_fingerprints)
                                     ],
                                     grid_size=[est_fingerprints.shape[0], 1])

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (EG, D_rec, EGs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((EG, D_rec, EGs),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
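
In the classifier above, EGs is maintained as an exponential moving average of the encoder weights via setup_as_moving_average_of(EG, beta=smoothing). A rough stand-in for that update, assuming plain name-to-array dicts instead of the framework's variables (hypothetical names, not the library's API):

import numpy as np

def ema_update(avg_weights, new_weights, beta=0.999):
    # One smoothing step per parameter: avg <- beta * avg + (1 - beta) * new.
    # avg_weights / new_weights are hypothetical dicts of numpy arrays.
    return {
        name: beta * avg_weights[name] + (1.0 - beta) * new_weights[name]
        for name in avg_weights
    }

Running such an update once per training step, as EGs_update_op is run in the loop above, keeps the averaged copy trailing the live weights.
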
Beispiel #14
0
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400,
              lod_transition_kimg=400,
              total_kimg=10000,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              resume_network_pkl=None,
              resume_kimg=0.0,
              resume_time=0.0):

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    if resume_network_pkl:
        print('Resuming', resume_network_pkl)
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)
    misc.print_network_topology_info(G.output_layers)
    misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 // w, 3,
                                      16), np.clip(1080 // h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    elif image_grid_type == 'category':
        W = training_set.labels.shape[1]
        H = W if image_grid_size is None else image_grid_size[1]
        image_grid_size = W, H
        snapshot_fake_latents = random_latents(W * H, G.input_shape)
        snapshot_fake_labels = np.zeros((W * H, W),
                                        dtype=training_set.labels.dtype)
        example_real_images = np.zeros((W * H, ) + training_set.shape[1:],
                                       dtype=training_set.dtype)
        for x in range(W):
            snapshot_fake_labels[x::W, x] = 1.0
            indices = np.arange(
                training_set.shape[0])[training_set.labels[:, x] != 0]
            for y in range(H):
                example_real_images[x + y * W] = training_set.h5_lods[0][
                    np.random.choice(indices)]
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print('Setting up Theano...')
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    if config.D.get('mbdisc_kernels', None):
        print('Initializing minibatch discrimination...')
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
        D.eval(real_images_var, deterministic=False, init=True)
        init_layers = lasagne.layers.get_all_layers(D.output_layers)
        init_updates = [
            update for layer in init_layers
            for update in getattr(layer, 'init_updates', [])
        ]
        init_fn = theano.function(inputs=[real_images_var],
                                  outputs=None,
                                  updates=init_updates)
        init_reals = training_set.get_random_minibatch(500, lod=initial_lod)
        init_reals = misc.adjust_dynamic_range(init_reals, drange_orig,
                                               drange_net)
        init_fn(init_reals)
        del init_reals

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / 1000.0) / (lod_training_kimg +
                                          lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print('Compiling training funcs...')
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            # Materialize the update dicts as lists so they can be
            # concatenated below (Python 3 .items() returns a view).
            G_updates = list(
                adam(G_loss,
                     G.trainable_params(),
                     learning_rate=G_lrate,
                     beta1=adam_beta1,
                     beta2=adam_beta2,
                     epsilon=adam_epsilon).items())
            D_updates = list(
                adam(D_loss,
                     D.trainable_params(),
                     learning_rate=D_lrate,
                     beta1=adam_beta1,
                     beta2=adam_beta2,
                     epsilon=adam_epsilon).items())

            # Compile training funcs.
            if not separate_funcs:
                GD_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                              updates=G_updates + D_updates +
                                              Gs.updates,
                                              on_unused_input='ignore')
            else:
                D_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                             updates=D_updates,
                                             on_unused_input='ignore')
                G_train_fn = theano.function(
                    [fake_latents_var, fake_labels_var], [],
                    updates=G_updates + Gs.updates,
                    on_unused_input='ignore')

        # Invoke training funcs.
        if not separate_funcs:
            assert D_training_repeats == 1
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_train_out = GD_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        else:
            for idx in range(D_training_repeats):
                mb_reals, mb_labels = training_set.get_random_minibatch(
                    minibatch_size,
                    lod=cur_lod,
                    shrink_based_on_lod=True,
                    labels=True)
                mb_train_out = D_train_fn(
                    mb_reals, mb_labels,
                    random_latents(minibatch_size, G.input_shape),
                    random_labels(minibatch_size, training_set))
                cur_nimg += minibatch_size
                tick_train_out.append(mb_train_out)
            G_train_fn(random_latents(minibatch_size, G.input_shape),
                       random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f'
                % ((cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                    misc.format_time(cur_time - train_start_time), tick_time,
                    tick_time / tick_kimg, gdrop_strength) + tick_train_avg))

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print('Done.')
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
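
The progressive-growing loop above recomputes the level of detail (cur_lod) from the number of images seen so far. The same schedule pulled out into a self-contained helper, with hypothetical argument names mirroring lod_training_kimg and lod_transition_kimg:

import numpy as np

def current_lod(cur_kimg, initial_lod, training_kimg=400, transition_kimg=400):
    # Each resolution level is trained for training_kimg and then faded into
    # the next one over transition_kimg; mirrors the in-loop computation above.
    lod = float(initial_lod)
    if training_kimg or transition_kimg:
        phase = cur_kimg / (training_kimg + transition_kimg)
        lod -= np.floor(phase)
        if transition_kimg:
            lod -= max(
                1.0 + (np.fmod(phase, 1.0) - 1.0) *
                (training_kimg + transition_kimg) / transition_kimg, 0.0)
        lod = max(lod, 0.0)
    return lod

With the defaults, current_lod(0, 3) returns 3.0 and the schedule reaches 0.0 after roughly 3 * (400 + 400) = 2400 kimg.
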
Beispiel #15
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('-e', '--exp')
    parser.add_argument('-b', '--borgy', default=0, type=int)
    parser.add_argument('-br', '--borgy_running', default=0, type=int)
    parser.add_argument('-m', '--mode', default="summary")
    parser.add_argument('-r', '--reset', default="None")
    parser.add_argument('-s', '--status', type=int, default=0)
    parser.add_argument('-k', '--kill', type=int, default=0)
    parser.add_argument('-g', '--gpu', type=int)
    parser.add_argument('-c', '--configList', nargs="+", default=None)
    parser.add_argument('-l', '--lossList', nargs="+", default=None)
    parser.add_argument('-d', '--datasetList', nargs="+", default=None)
    parser.add_argument('-metric', '--metricList', nargs="+", default=None)
    parser.add_argument('-model', '--modelList', nargs="+", default=None)
    parser.add_argument('-p', '--predictList', nargs="+", default=None)

    args = parser.parse_args()

    if args.borgy or args.kill:
        global_prompt = input("Do all? \n(y/n)\n")

    # SEE IF CUDA IS AVAILABLE
    assert torch.cuda.is_available()
    print("CUDA: %s" % torch.version.cuda)
    print("Pytroch: %s" % torch.__version__)

    mode = args.mode
    exp_name = args.exp

    exp_dict = experiments.get_experiment_dict(args, exp_name)

    pp_main = None
    results = {}

    # Get Main Class
    project_name = os.path.realpath(__file__).split("/")[-2]
    MC = ms.MainClass(path_models="models",
                      path_datasets="datasets",
                      path_metrics="metrics/metrics.py",
                      path_losses="losses/losses.py",
                      path_samplers="addons/samplers.py",
                      path_transforms="addons/transforms.py",
                      path_saves="/mnt/projects/counting/Saves/main/",
                      project=project_name)

    key_set = set()
    for model_name, config_name, metric_name, dataset_name, loss_name in product(
            exp_dict["modelList"], exp_dict["configList"],
            exp_dict["metricList"], exp_dict["datasetList"],
            exp_dict["lossList"]):

        # if model_name in ["LC_RESFCN"]:
        #   loss_name = "water_loss"

        config = configs.get_config_dict(config_name)

        key = ("{} - {} - {}".format(model_name, config_name, loss_name),
               "{}_({})".format(dataset_name, metric_name))

        if key in key_set:
            continue

        key_set.add(key)

        main_dict = MC.get_main_dict(mode, dataset_name, model_name,
                                     config_name, config, args.reset,
                                     exp_dict["epochs"], metric_name,
                                     loss_name)
        main_dict["predictList"] = exp_dict["predictList"]

        if mode == "paths":
            print("\n{}_({})".format(dataset_name, model_name))
            print(main_dict["path_best_model"])
            # print( main_dict["exp_name"])

        predictList_str = ' '.join(exp_dict["predictList"])

        if args.status:
            results[key] = borgy.borgy_status(mode, config_name, metric_name,
                                              model_name, dataset_name,
                                              loss_name, args.reset,
                                              predictList_str)

            continue

        if args.kill:
            results[key] = borgy.borgy_kill(mode, config_name, metric_name,
                                            model_name, dataset_name,
                                            loss_name, args.reset,
                                            predictList_str)
            continue

        if args.borgy:
            results[key] = borgy.borgy_submit(project_name, global_prompt,
                                              mode, config_name, metric_name,
                                              model_name, dataset_name,
                                              loss_name, args.reset,
                                              predictList_str)

            continue

        if mode == "debug":
            debug.debug(main_dict)

        if mode == "validate":
            validate.validate(main_dict)
        if mode == "save_gam_points":
            train_set, _ = au.load_trainval(main_dict)
            model = ms.load_best_model(main_dict)
            for i in range(len(train_set)):
                print(i, "/", len(train_set))
                batch = ms.get_batch(train_set, [i])
                fname = train_set.path + "/gam_{}.pkl".format(
                    batch["index"].item())
                points = model.get_points(batch)
                ms.save_pkl(fname, points)
            import ipdb
            ipdb.set_trace()  # breakpoint ee49ab9f //

        if mode == "save_prm_points":
            train_set, _ = au.load_trainval(main_dict)
            model = ms.load_best_model(main_dict)
            for i in range(len(train_set)):
                print(i, "/", len(train_set))
                batch = ms.get_batch(train_set, [i])

                fname = "{}/prm{}.pkl".format(batch["path"][0],
                                              batch["name"][0])
                points = model.get_points(batch)
                ms.save_pkl(fname, points)
            import ipdb
            ipdb.set_trace()  # breakpoint 679ce152 //

            # train_set, _ = au.load_trainval(main_dict)
            # model = ms.load_best_model(main_dict)
            # for i in range(len(train_set)):
            #   print(i, "/", len(train_set))
            #   batch = ms.get_batch(train_set, [i])
            #   fname = train_set.path + "/gam_{}.pkl".format(batch["index"].item())
            #   points = model.get_points(batch)
            #   ms.save_pkl(fname, points)

        # if mode == "pascal_annList":
        #   data_utils.pascal2lcfcn_points(main_dict)
        if mode == "upperboundmasks":
            import ipdb
            ipdb.set_trace()  # breakpoint 02fac8ce //

            results = au.test_upperboundmasks(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "model":

            results = au.test_model(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "upperbound":
            results = au.test_upperbound(main_dict, reset=args.reset)

            print(pd.DataFrame(results))

        if mode == "MUCov":
            gtAnnDict = au.load_gtAnnDict(main_dict, reset=args.reset)

            # model = ms.load_best_model(main_dict)
            fname = main_dict["path_save"] + "/pred_annList.pkl"
            if not os.path.exists(fname):
                _, val_set = au.load_trainval(main_dict)
                model = ms.load_best_model(main_dict)
                pred_annList = au.dataset2annList(model,
                                                  val_set,
                                                  predict_method="BestDice",
                                                  n_val=None)
                ms.save_pkl(fname, pred_annList)

            else:
                pred_annList = ms.load_pkl(fname)
            import ipdb
            ipdb.set_trace()  # breakpoint 527a7f36 //
            pred_annList = au.load_predAnnList(main_dict,
                                               predict_method="BestObjectness")
            # 0.31 best objectness pred_annList =
            # 0.3482122335421256
            # au.get_MUCov(gtAnnDict, pred_annList)
            au.get_SBD(gtAnnDict, pred_annList)

        if mode == "dic_sbd":
            import ipdb
            ipdb.set_trace()  # breakpoint 4af08a17 //

        if mode == "point_mask":
            from datasets import base_dataset

            import ipdb
            ipdb.set_trace()  # breakpoint 7fd55e0c //
            _, val_set = ms.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [1])
            model = ms.load_best_model(main_dict)
            pred_dict = model.LCFCN.predict(batch)
            # ms.pretty_vis(batch["images"], base_dataset.batch2annList(batch))
            ms.images(ms.pretty_vis(
                batch["images"],
                model.LCFCN.predict(batch,
                                    predict_method="original")["annList"]),
                      win="blobs")
            ms.images(ms.pretty_vis(batch["images"],
                                    base_dataset.batch2annList(batch)),
                      win="erww")
            ms.images(batch["images"],
                      batch["points"],
                      denorm=1,
                      enlarge=1,
                      win="e21e")
            import ipdb
            ipdb.set_trace()  # breakpoint ab9240f0 //

        if mode == "lcfcn_output":
            import ipdb
            ipdb.set_trace()  # breakpoint 7fd55e0c //

            gtAnnDict = au.load_gtAnnDict(main_dict, reset=args.reset)

        if mode == "load_gtAnnDict":
            _, val_set = au.load_trainval(main_dict)
            gtAnnDict = au.load_gtAnnDict(val_set)

            # gtAnnClass = COCO(gtAnnDict)
            # au.assert_gtAnnDict(main_dict, reset=None)
            # _,val_set = au.load_trainval(main_dict)
            # annList_path = val_set.annList_path

            # fname_dummy = annList_path.replace(".json","_best.json")
            # predAnnDict = ms.load_json(fname_dummy)
            import ipdb
            ipdb.set_trace()  # breakpoint 100bfe1b //
            pred_annList = ms.load_pkl(main_dict["path_best_annList"])
            # model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [1])

            import ipdb
            ipdb.set_trace()  # breakpoint 2310bb33 //
            model = ms.load_best_model(main_dict)
            pred_dict = model.predict(batch, "BestDice", "mcg")
            ms.images(batch["images"],
                      au.annList2mask(pred_dict["annList"])["mask"],
                      denorm=1)
            # pointList2UpperBoundMCG
            pred_annList = au.load_predAnnList(main_dict,
                                               predict_method="BestDice",
                                               proposal_type="mcg",
                                               reset="reset")
            # annList = au.pointList2UpperBoundMCG(batch["lcfcn_pointList"], batch)["annList"]
            ms.images(batch["images"],
                      au.annList2mask(annList)["mask"],
                      denorm=1)
            pred_annList = au.load_BestMCG(main_dict, reset="reset")
            # pred_annList = au.dataset2annList(model, val_set,
            #                   predict_method="BestDice",
            #                   n_val=None)
            au.get_perSizeResults(gtAnnDict, pred_annList)

        if mode == "vis":
            _, val_set = au.load_trainval(main_dict)
            batch = ms.get_batch(val_set, [3])

            import ipdb
            ipdb.set_trace()  # breakpoint 05e6ef16 //

            vis.visBaselines(batch)

            model = ms.load_best_model(main_dict)
            vis.visBlobs(model, batch)

        if mode == "qual":
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            path = "/mnt/home/issam/Summaries/{}_{}".format(
                dataset_name, model_name)
            try:
                ms.remove_dir(path)
            except:
                pass
            n_images = len(val_set)
            base = "{}_{}".format(dataset_name, model_name)
            for i in range(50):
                print(i, "/10", "- ", base)
                index = np.random.randint(0, n_images)
                batch = ms.get_batch(val_set, [index])
                if len(batch["lcfcn_pointList"]) == 0:
                    continue
                image = vis.visBlobs(model, batch, return_image=True)

                # image_baselines = vis.visBaselines(batch, return_image=True)
                # imgAll = np.concatenate([image, image_baselines], axis=1)

                fname = path + "/{}_{}.png".format(i, base)
                ms.create_dirs(fname)
                ms.imsave(fname, image)

        if mode == "test_baselines":
            import ipdb
            ipdb.set_trace()  # breakpoint b51c5b1f //
            results = au.test_baselines(main_dict, reset=args.reset)
            print(pd.DataFrame(results))

        if mode == "test_best":
            au.test_best(main_dict)

        if mode == "qualitative":
            au.qualitative(main_dict)

        if mode == "figure1":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            # proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            # vidList = glob("/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*")
            # vidList.sort()

            # pretty_image = ms.visPretty(model, batch = ms.get_batch(val_set, [i]), with_void=1, win="with_void")
            batch = ms.get_batch(val_set, [68])
            bestdice = ms.visPretty(model,
                                    batch=batch,
                                    with_void=0,
                                    win="no_void")
            blobs = ms.visPretty(model,
                                 batch=batch,
                                 predict_method="blobs",
                                 with_void=0,
                                 win="no_void")

            ms.images(bestdice, win="BestDice")
            ms.images(blobs, win="Blobs")
            ms.images(batch["images"], denorm=1, win="Image")
            ms.images(batch["images"],
                      batch["points"],
                      enlarge=1,
                      denorm=1,
                      win="Points")
            import ipdb
            ipdb.set_trace()  # breakpoint cf4bb3d3 //

        if mode == "video2":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            _, val_set = au.load_trainval(main_dict)
            # proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            # vidList = glob("/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*")
            # vidList.sort()
            index = 0
            for i in range(len(val_set)):

                # pretty_image = ms.visPretty(model, batch = ms.get_batch(val_set, [i]), with_void=1, win="with_void")
                batch = ms.get_batch(val_set, [i])
                pretty_image = ms.visPretty(model,
                                            batch=batch,
                                            with_void=0,
                                            win="no_void")
                # pred_dict = model.predict(batch, predict_method="BestDice")
                path_summary = main_dict["path_summary"]
                ms.create_dirs(path_summary + "/tmp")
                ms.imsave(
                    path_summary + "vid_mask_{}.png".format(index),
                    ms.get_image(batch["images"],
                                 batch["points"],
                                 enlarge=1,
                                 denorm=1))
                index += 1
                ms.imsave(path_summary + "vid_mask_{}.png".format(index),
                          pretty_image)
                index += 1
                # ms.imsave(path_summary+"vid1_full_{}.png".format(i), ms.get_image(img, pred_dict["blobs"], denorm=1))
                print(i, "/", len(val_set))

        if mode == "video":
            from PIL import Image
            from addons import transforms
            model = ms.load_best_model(main_dict)
            # _, val_set = au.load_trainval(main_dict)
            proposals_path = "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/ProposalsSharp/"
            vidList = glob(
                "/mnt/datasets/public/issam/Cityscapes/demoVideo/leftImg8bit/demoVideo/stuttgart_01/*"
            )
            vidList.sort()
            for i, img_path in enumerate(vidList):
                image = Image.open(img_path).convert('RGB')
                image = image.resize((1200, 600), Image.BILINEAR)
                img, _ = transforms.Tr_WTP_NoFlip()([image, image])

                pred_dict = model.predict(
                    {
                        "images": img[None],
                        "split": ["test"],
                        "resized": torch.FloatTensor([1]),
                        "name": [ms.extract_fname(img_path)],
                        "proposals_path": [proposals_path]
                    },
                    predict_method="BestDice")
                path_summary = main_dict["path_summary"]
                ms.create_dirs(path_summary + "/tmp")
                ms.imsave(path_summary + "vid1_mask_{}.png".format(i),
                          ms.get_image(pred_dict["blobs"]))
                ms.imsave(path_summary + "vid1_full_{}.png".format(i),
                          ms.get_image(img, pred_dict["blobs"], denorm=1))
                print(i, "/", len(vidList))

        if mode == "5_eval_BestDice":
            gtAnnDict = au.load_gtAnnDict(main_dict)
            gtAnnClass = COCO(gtAnnDict)
            results = au.assert_gtAnnDict(main_dict, reset=None)

        if mode == "cp_annList":
            ms.dataset2cocoformat(dataset_name="CityScapes")

        if mode == "pascal2lcfcn_points":
            data_utils.pascal2lcfcn_points(main_dict)

        if mode == "cp2lcfcn_points":
            data_utils.cp2lcfcn_points(main_dict)

        if mode == "train":

            train.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "train_only":

            train.main(main_dict, train_only=True)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "sharpmask2psfcn":
            for split in ["train", "val"]:
                root = "/mnt/datasets/public/issam/COCO2014/ProposalsSharp/"
                path = "{}/sharpmask/{}/jsons/".format(root, split)

                jsons = glob(path + "*.json")
                propDict = {}
                for k, json in enumerate(jsons):
                    print("{}/{}".format(k, len(jsons)))
                    props = ms.load_json(json)
                    for p in props:
                        if p["image_id"] not in propDict:
                            propDict[p["image_id"]] = []
                        propDict[p["image_id"]] += [p]

                for k in propDict.keys():
                    fname = "{}/{}.json".format(root, k)
                    ms.save_json(fname, propDict[k])

        if mode == "cp2coco":
            import ipdb
            ipdb.set_trace()  # breakpoint f2eb9e70 //
            dataset2cocoformat.cityscapes2cocoformat(main_dict)
            # train.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "train_lcfcn":
            train_lcfcn.main(main_dict)
            import ipdb
            ipdb.set_trace()  # breakpoint a5d091b9 //

        if mode == "summary":

            try:
                history = ms.load_history(main_dict)

                # if predictList_str == "MAE":
                #   results[key] = "{}/{}: {:.2f}".format(history["best_model"]["epoch"],
                #                                                           history["epoch"],
                #                                                           history["best_model"][metric_name])

                # else:
                val_dict = history["val"][-1]
                val_dict = history["best_model"]
                iou25 = val_dict["0.25"]
                iou5 = val_dict["0.5"]
                iou75 = val_dict["0.75"]
                results[key] = "{}/{}: {:.1f} - {:.1f} - {:.1f}".format(
                    val_dict["epoch"], history["epoch"], iou25 * 100,
                    iou5 * 100, iou75 * 100)
                # if history["val"][-1]["epoch"] != history["epoch"]:
                #   results[key] += " | Val {}".format(history["epoch"])
                try:
                    results[key] += " | {}/{}".format(
                        len(history["trained_batch_names"]),
                        history["train"][-1]["n_samples"])
                except:
                    pass
            except:
                pass
        if mode == "vals":

            history = ms.load_history(main_dict)

            for i in range(1, len(main_dict["predictList"]) + 1):
                if len(history['val']) == 0:
                    res = "NaN"
                    continue
                else:
                    res = history["val"][-i]

                map50 = res["map50"]
                map75 = res["map75"]

                # if map75 < 1e-3:
                #   continue

                string = "{} - {} - map50: {:.2f} - map75: {:.2f}".format(
                    res["epoch"], res["predict_name"], map50, map75)

                key_tmp = list(key).copy()
                key_tmp[1] += " {} - {}".format(metric_name,
                                                res["predict_name"])
                results[tuple(key_tmp)] = string

            # print("map75", pd.DataFrame(history["val"])["map75"].max())
            # df = pd.DataFrame(history["vals"][:20])["water_loss_B"]
            # print(df)
    try:
        print(ms.dict2frame(results))
    except:
        print("Results not printed...")
Beispiel #16
0
    def __init__(self,
                 root="",
                 split=None,
                 transform_function=None,
                 ratio=None,
                 year="2017"):
        super().__init__()
        fname = split

        if fname == "test":
            fname = "val"

        dataset_name = "COCO"

        if year == "2014":
            dataset_name = "COCO2014"

        self.n_classes = 81

        self.path = "/mnt/datasets/public/issam/{}/".format(dataset_name)
        self.proposals_path = "{}/ProposalsSharp/".format(self.path)
        self.split = split
        self.year = year
        self.transform_function = transform_function()
        fname_names = self.path + "/{}.json".format(self.split)
        fname_catids = self.path + "/{}_catids.json".format(self.split)
        fname_categories = self.path + "/categories.json"
        fname_ids = self.path + "/{}_ids.json".format(self.split)

        if os.path.exists(fname_names):

            self.image_names = ms.load_json(fname_names)
            self.catids = ms.load_json(fname_catids)
            self.categories = ms.load_json(fname_categories)
            self.ids = ms.load_json(fname_ids)
        else:
            # Save ids

            annFile = "{}/annotations/instances_{}{}.json".format(
                self.path, fname, year)
            self.coco = COCO(annFile)
            self.ids = list(self.coco.imgs.keys())

            self.image_names = []
            # Save Labels
            for index in range(len(self.ids)):
                print(index, "/", len(self.ids))
                img_id = self.ids[index]
                ann_ids = self.coco.getAnnIds(imgIds=img_id)
                annList = self.coco.loadAnns(ann_ids)
                name = self.coco.loadImgs(img_id)[0]['file_name']

                self.image_names += [name]
                ms.save_pkl(
                    self.path +
                    "/groundtruth/{}_{}.pkl".format(self.split, name), annList)

            ms.save_json(fname_names, self.image_names)

            # Category ids
            self.catids = self.coco.getCatIds()
            ms.save_json(fname_catids, self.catids)

            # Build the category-id -> contiguous-label mapping before it is
            # used below (it is rebuilt from self.catids after this branch).
            self.category2label = {c: i + 1 for i, c in enumerate(self.catids)}

            self.categories = []

            categories = self.coco.cats.values()

            for c in categories:
                c["id"] = self.category2label[c["id"]]
                self.categories += [c]
            ms.save_json(fname_categories, self.categories)

            ms.save_json(fname_ids, self.ids)

            if split == "val":
                # gt_annDict = ms.load_json(annFile)

                annDict = {}
                # fname_ann = '/mnt/datasets/public/issam/COCO2014//annotations/val_gt_annList.json'
                annDict["categories"] = self.categories
                annDict["images"] = self.coco.loadImgs(self.ids[:5000])

                annIDList = self.coco.getAnnIds(self.ids[:5000])
                annList = self.coco.loadAnns(annIDList)

                for p in annList:
                    # p["id"] = str(p["id"])
                    p["image_id"] = str(p["image_id"])
                    p["category_id"] = self.category2label[p["category_id"]]

                for p in annDict["images"]:
                    p["id"] = str(p["id"])
                annDict["annotations"] = annList

                ms.save_json(
                    '{}//annotations/val_gt_annList.json'.format(self.path),
                    annDict)

        self.category2label = {}
        self.label2category = {}

        for i, c in enumerate(self.catids):
            self.category2label[c] = i + 1
            self.label2category[i + 1] = c

        if split == "val":
            # gt_annList_path = '/mnt/datasets/public/issam/COCO2014//annotations/val_gt_annList.json'

            annList_path = self.path + "/annotations/{}_gt_annList.json".format(
                split)

            assert os.path.exists(annList_path)
            self.annList_path = annList_path

            # self.image_names.sort()
            self.image_names = self.image_names[:5000]
            self.ids = self.ids[:5000]

        elif split == "test":
            # self.image_names.sort()
            self.image_names = self.ids[-5000:]
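
The dataset above remaps the sparse COCO category ids to contiguous labels starting at 1 (with 0 implicitly left for the background). A minimal sketch of that mapping, assuming `catids` stands in for the list returned by COCO.getCatIds():

# Sketch of the catid <-> label remapping used in the dataset above.
catids = [1, 2, 3, 4, 5, 90]  # hypothetical sparse COCO category ids

category2label = {c: i + 1 for i, c in enumerate(catids)}
label2category = {i + 1: c for i, c in enumerate(catids)}

assert category2label[90] == 6
assert label2category[6] == 90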
Beispiel #17
0
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        #lod_training_kimg       = 40,
        #lod_transition_kimg     = 40,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        #tick_kimg_default       = 1,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=4,
        network_snapshot_ticks=40,
        image_grid_type='default',
        #resume_network_pkl      = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
        resume_network_pkl=None,
        resume_kimg=0,
        resume_time=0.0):

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()

    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # training_set is the object produced by the dataset module after parsing the h5 file,
    # drange_orig is training_set.get_dynamic_range()

    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)

    # The G and D objects can either be loaded from a pkl by misc or constructed by the network module

    print G
    print D

    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configure the layout of the snapshot images written during training
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var),real_images_var
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Variables suffixed with _var are the input tensors
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # theano.shared sets the initial value and returns a shared variable object
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # gen_fn is a function whose inputs are [fake_latents_var, fake_labels_var]
    # and whose output is Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
    # i.e. the generation function

    # Misc init.
    # Read the current resolution
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # lod = level of detail
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point.
    # Note that training never leaves the outermost while loop, so operations
    # such as switching resolution happen inside it.

    # Number of real images processed so far
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0

    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # Compute the current level of detail
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        # Current resolution
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        # Update the learning rates and the networks' current LOD
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))

        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))

        #print " cur_lod%f\n  min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n"%(cur_lod,min_lod,new_min_lod,max_lod,new_max_lod)

        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_train_fn = theano.function([
                real_images_var, real_labels_var, fake_latents_var,
                fake_labels_var
            ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                         updates=D_updates,
                                         on_unused_input='ignore')
            G_train_fn = theano.function([fake_latents_var, fake_labels_var],
                                         [],
                                         updates=G_updates + Gs.updates,
                                         on_unused_input='ignore')

        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)

            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels

            mb_train_out = D_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
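
The level-of-detail (LOD) arithmetic near the top of the training loop drives the progressive growing: each resolution is trained for lod_training_kimg thousand images and then faded into the next one over lod_transition_kimg thousand images. Below is a minimal standalone sketch of that schedule, assuming speed_factor = 1 and nonzero kimg values; the function name and defaults are illustrative only.

import numpy as np

def lod_schedule(cur_nimg, resolution_log2, lod_initial_resolution=4,
                 lod_training_kimg=400, lod_transition_kimg=400):
    """Sketch of the progressive-growing schedule used in train_gan above
    (speed_factor assumed to be 1). Returns (cur_lod, cur_res)."""
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = float(initial_lod)
    tlod = (cur_nimg / 1000.0) / (lod_training_kimg + lod_transition_kimg)
    cur_lod -= np.floor(tlod)  # one full LOD step per (training + transition) period
    if lod_transition_kimg:
        # Fade smoothly during the transition part of each step.
        cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                       (lod_training_kimg + lod_transition_kimg) /
                       lod_transition_kimg, 0.0)
    cur_lod = max(cur_lod, 0.0)
    cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
    return cur_lod, cur_res

# e.g. for a 128x128 generator (resolution_log2 = 7):
#   at   0 kimg -> lod 5.0, res 4
#   at 600 kimg -> lod 4.5, res 8   (halfway through the first transition)
#   at 800 kimg -> lod 4.0, res 8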
Beispiel #18
0
def main(main_dict, train_only=False):

    ms.print_welcome(main_dict)

    # EXTRACT VARIABLES
    reset = main_dict["reset"]
    epochs = main_dict["epochs"] = 100
    batch_size = main_dict["batch_size"]
    sampler_name = main_dict["sampler_name"]
    verbose = main_dict["verbose"]
    loss_name = main_dict["loss_name"]
    metric_name = main_dict["metric_name"]
    metric_class = main_dict["metric_dict"][metric_name]
    loss_function = main_dict["loss_dict"][loss_name]
    predictList = main_dict["predictList"]

    # Assert everything is available
    ## Sharp proposals
    ## LCFCN points
    ## gt_annDict

    # Dataset
    train_set, val_set = ms.load_trainval(main_dict)
    train_set[0]  # fetch one item as a sanity check that the dataset loads

    # Model

    if reset == "reset" or not ms.model_exists(main_dict):
        model, opt, history = ms.init_model_and_opt(main_dict, train_set)
        print("TRAINING FROM SCRATCH EPOCH: %d/%d" %
              (history["epoch"], epochs))
    else:
        model, opt, history = ms.load_latest_model_and_opt(
            main_dict, train_set)
        print("RESUMING EPOCH %d/%d" % (history["epoch"], epochs))

    # Get Dataloader
    trainloader = ms.get_dataloader(
        dataset=train_set,
        batch_size=batch_size,
        sampler_class=main_dict["sampler_dict"][sampler_name])

    # SAVE HISTORY
    history["epoch_size"] = len(trainloader)

    if "trained_batch_names" in history:
        model.trained_batch_names = set(history["trained_batch_names"])

    ms.save_pkl(main_dict["path_history"], history)

    # START TRAINING
    start_epoch = history["epoch"]
    predict_name = predictList[0]

    for epoch in range(start_epoch + 1, epochs):
        # %%%%%%%%%%% 1. Training PHASE %%%%%%%%%%%%"
        history = training_phase(history, main_dict, model, trainloader, opt,
                                 loss_function, verbose, epoch)

        # %%%%%%%%%%% 2. VALIDATION PHASE %%%%%%%%%%%%"
        if (epoch % 5) == 0:
            history = validation_phase_mAP(history, main_dict, model, val_set,
                                           predict_name, epoch)

        ms.save_pkl(main_dict["path_history"], history)
Beispiel #19
0
def epoch(pupil_data,
          events_path,
          sample_rate=250,
          epoch_time=200,
          back_time=60,
          out_dir='',
          base_name='',
          baseline_type='no',
          **params):
    '''
    Epochs the pupil_data according to behavioural events specified
    by the events file in events_path.

    Arguments:
        pupil_data: A pupil_data object like the one loaded by read_pupil

        events_path <str>: A path to the events file (see tutorial)

        sample_rate <int>: The sample rate of your pupil recording

        epoch_time <int>: Time in ms to epoch

        back_time <int>: Time before the event to plot

    Returns:
        epoched <Epoched>: an Epoched object defined above
    '''

    print('\nEpoching...\n')

    # Find all the events
    pupil_events = pupil_data[pupil_data['Content'] != '-']
    num_events = len(pupil_events)
    print('There are {} events in the pupil data, does '
          'this look correct? If not, the first two events '
          'of pupil may not have been recorded...'.format(num_events))

    # Reads behavioral events
    # TODO: from python
    categories = read_events(events_path, pupil_events.Time.iat[0])
    samples_per_epoch = get_nsamples(sample_rate, epoch_time, back_time)
    # Initialize output variables
    epoched = Epoched(len(categories), samples_per_epoch, num_events)
    # Extra: Report median miss time.
    # Get the baseline dictionary mapping from time to baseline values.
    bl_events = None
    if type(baseline_type) == tuple:
        bl_events = get_baseline_events(pupil_events,
                                        categories[baseline_type[1]])

    for c, category in enumerate(categories):
        epoched.names.append(category.name)
        # +1 to exclude the content sample, which is out of sync.
        onsets = list(
            map(lambda x: get_nearest_ind(pupil_events, x) + 1,
                category.start))
        rejected_inds = []
        rejected = 0
        for t, onset in enumerate(onsets):
            # Note the -1 to avoid the content sample
            pre = pupil_data.Pupil.iloc[onset - samples_per_epoch[0] -
                                        1:onset - 1].values
            post = pupil_data.Pupil.iloc[onset:onset +
                                         samples_per_epoch[1]].values
            baseline = get_baseline(pupil_data, onset, baseline_type,
                                    sample_rate, bl_events)
            trial = np.concatenate((pre, post)) - baseline

            if np.isnan(np.sum(trial)):  # Checks if there is a nan
                rejected_inds.append(t)
                rejected += 1
            elif len(trial) == sum(samples_per_epoch):
                epoched.matrix[c, :, t - rejected] = trial

        epoched.num_trials.append(len(onsets))
        epoched.num_rejected.append(rejected)
        epoched.rejected[category.name] = rejected_inds

    # Save .mat and .pkl of epoched data
    spio.savemat(
        make_path('epoched', '.mat', out_dir=out_dir, base_name=base_name),
        {'epoched': epoched})
    save_pkl(
        make_path('epoched', '.pkl', out_dir=out_dir, base_name=base_name),
        epoched)

    return epoched
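
Since epoch() both returns the Epoched object and writes epoched.mat / epoched.pkl via make_path, a typical call looks roughly like the sketch below. The file names are invented, and read_pupil is assumed (per the docstring) to come from the same module as epoch; this is not a verified API.

# Hypothetical usage sketch for epoch(); file names and read_pupil are assumptions.
pupil_data = read_pupil('sub01_run01_pupil.csv')  # frame with 'Time', 'Pupil', 'Content'
epoched = epoch(pupil_data,
                events_path='sub01_run01_events.csv',
                sample_rate=250,      # Hz
                epoch_time=2000,      # ms after each event
                back_time=200,        # ms before each event
                baseline_type='no',
                out_dir='results/',
                base_name='sub01_run01_')
print(epoched.names, epoched.num_trials, epoched.num_rejected)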
Beispiel #20
0
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs, E = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)

            E = tfutil.Network('E',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.E)

            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()
    E.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    E_opt = tfutil.Optimizer(name='TrainE',
                             learning_rate=lrate_in,
                             **config.E_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
                tf.assign(E_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            with tf.name_scope('E_loss'), tf.control_dependencies(
                    lod_assign_ops):
                E_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=E_opt,
                    training_set=training_set,
                    reals=reals_gpu,
                    minibatch_size=minibatch_split,
                    **config.E_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            E_opt.register_gradients(tf.reduce_mean(E_loss), E_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    E_train_op = E_opt.apply_updates()

    #sys.exit(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                E_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op, E_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

                misc.save_all_res(training_set.shape[1],
                                  Gs,
                                  result_subdir,
                                  50,
                                  minibatch_size=sched.minibatch //
                                  config.num_gpus)

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs, E),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs, E),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()


def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    #resume_run_id = '/dresden/users/mk1391/evl/pggan_logs/logs_celeba128cc/fsg16_results_0/000-pgan-celeba-preset-v2-2gpus-fp32/network-snapshot-010211.pkl'
    resume_with_new_nets = False
    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            rG, rD, rGs = misc.load_pkl(network_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    ### pyramid draw fsg (comment out for actual training to happen)
    #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'pggan_fsg_draw.png'))
    #print('>>> done printing fsgs.')
    #return

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    ### shift reals
    print('>>> reals shape: ', grid_reals.shape)
    fc_x = 0.5
    fc_y = 0.5
    im_size = grid_reals.shape[-1]
    kernel_loc = 2.*np.pi*fc_x * np.arange(im_size).reshape((1, 1, im_size)) + \
        2.*np.pi*fc_y * np.arange(im_size).reshape((1, im_size, 1))
    kernel_cos = np.cos(kernel_loc)
    kernel_sin = np.sin(kernel_loc)
    reals_t = (grid_reals / 255.) * 2. - 1
    reals_t *= kernel_cos
    grid_reals_sh = np.rint(
        (reals_t + 1.) * 255. / 2.).clip(0, 255).astype(np.uint8)
    ### end shift reals
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    ### fft drawing
    #sys.path.insert(1, '/home/mahyar/CV_Res/ganist')
    #from fig_draw import apply_fft_win
    #data_size = 1000
    #latents = np.random.randn(data_size, *Gs.input_shapes[0][1:])
    #labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])
    #g_samples = Gs.run(latents, labels, minibatch_size=sched.minibatch//config.num_gpus)
    #g_samples = g_samples.transpose(0, 2, 3, 1)
    #print('>>> g_samples shape: {}'.format(g_samples.shape))
    #apply_fft_win(g_samples, 'fft_pggan_hann.png')
    ### end fft drawing

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    ### drawing shifted real images
    misc.save_image_grid(grid_reals_sh,
                         os.path.join(result_subdir, 'reals_sh.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    ### drawing shifted fake images
    misc.save_image_grid(grid_fakes * kernel_cos,
                         os.path.join(result_subdir, 'fakes%06d_sh.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    #### True cosine fft eval
    #fft_data_size = 1000
    #im_size = training_set.shape[1]
    #freq_centers = [(64/128., 64/128.)]
    #true_samples = sample_true(training_set, fft_data_size, dtype=training_set.dtype, batch_size=32).transpose(0, 2, 3, 1) / 255. * 2. - 1.
    #true_fft, true_fft_hann, true_hist = cosine_eval(true_samples, 'true', freq_centers, log_dir=result_subdir)
    #fractal_eval(true_samples, f'koch_snowflake_true', result_subdir)

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                ### drawing shifted fake images
                misc.save_image_grid(
                    grid_fakes * kernel_cos,
                    os.path.join(result_subdir,
                                 'fakes%06d_sh.png' % (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size)
                ### drawing fsg
                #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'fakes%06d_fsg_draw.png' % (cur_nimg // 1000)))
                ### Gen fft eval
                #gen_samples = sample_gen(Gs, fft_data_size).transpose(0, 2, 3, 1)
                #print(f'>>> fake_samples: max={np.amax(grid_fakes)} min={np.amin(grid_fakes)}')
                #print(f'>>> gen_samples: max={np.amax(gen_samples)} min={np.amin(gen_samples)}')
                #misc.save_image_grid(gen_samples[:25], os.path.join(result_subdir, 'fakes%06d_gsample.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                #cosine_eval(gen_samples, f'gen_{cur_nimg//1000:06d}', freq_centers, log_dir=result_subdir, true_fft=true_fft, true_fft_hann=true_fft_hann, true_hist=true_hist)
                #fractal_eval(gen_samples, f'koch_snowflake_fakes{cur_nimg//1000:06d}', result_subdir)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
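
The "shift reals" block in the second variant modulates images with a 2D cosine at spatial frequency (fc_x, fc_y) in cycles per pixel; with fc_x = fc_y = 0.5 the kernel is the checkerboard (-1)^(x+y), which shifts the image spectrum by the Nyquist frequency. A small standalone sketch of the same modulation on an NCHW uint8 batch (the function name is local to this sketch):

import numpy as np

def cosine_shift(images_uint8, fc_x=0.5, fc_y=0.5):
    """Modulate NCHW uint8 images with cos(2*pi*(fc_x*x + fc_y*y)),
    mirroring the grid_reals_sh computation above."""
    im_size = images_uint8.shape[-1]
    kernel_loc = (2. * np.pi * fc_x * np.arange(im_size).reshape((1, 1, im_size)) +
                  2. * np.pi * fc_y * np.arange(im_size).reshape((1, im_size, 1)))
    kernel_cos = np.cos(kernel_loc)          # shape (1, H, W), broadcasts over batch/channels
    imgs = (images_uint8 / 255.) * 2. - 1.   # map to [-1, 1]
    imgs *= kernel_cos                       # spectrum shifted by (fc_x, fc_y)
    return np.rint((imgs + 1.) * 255. / 2.).clip(0, 255).astype(np.uint8)

# e.g. shifted = cosine_shift(grid_reals)    # same result as grid_reals_sh above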