示例#1
0
    def train(self):
        """Run the complete training pipeline and the follow-up evaluation.

        Steps, in order:

        1. Create ``self.output_dir`` and call ``self.save()`` and
           ``self.summary()``.
        2. Open the validation HDF5 file (and the training HDF5 file when
           ``self.use_combined_data`` is set) read-only.
        3. Pass the training and validation generators to
           ``self.check_generator`` and hand ``20**2`` samples from each to
           ``save_samples`` (``*_train_samples.png`` / ``*_val_samples.png``).
        4. Build the model, assemble the callbacks (per-batch history,
           learning-rate schedule, checkpointing monitored on
           ``val_bits_loss``, loss plot) and call ``fit_generator``.
        5. Call ``evaluate_decoder.run(self, cache=False)``.

        The HDF5 handles are closed when the method exits, whether training
        finished or raised.

        Note: ``nb_bits`` is not defined inside this method; it is expected
        to be provided by the enclosing module.
        """
        marker = os.path.basename(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        self.save()
        self.summary()

        nb_vis_samples = 20**2
        gt_train = None
        # Explicit read-only mode: the default mode of h5py.File differs
        # between h5py versions, and these files are only read here.
        gt_val = h5py.File(self.gt_val_fname, 'r')
        try:
            print("bits", gt_val["bits"].shape)
            label_output_sizes = self.get_label_output_sizes()

            if self.use_combined_data:
                gt_train = h5py.File(self.gt_train_fname, 'r')
                data_gen = self.combined_generator_factory(
                    gt_train, label_output_sizes)
            else:
                data_gen = self.data_generator_factory()

            truth_gen_only_bits = self.truth_generator_factory(
                gt_val, missing_label_sizes=[])
            self.check_generator(data_gen, "train")
            self.check_generator(truth_gen_only_bits, "test")

            # save samples
            vis_out = next(data_gen(nb_vis_samples))
            vis_bits = np.array(vis_out[1][:nb_bits]).T
            save_samples(vis_out[0][:, 0], vis_bits,
                         self.outname(marker + "_train_samples.png"))
            # The third element is unpacked but not used below.
            gt_data, gt_bits, _gt_masks = next(
                truth_gen_only_bits(nb_vis_samples))
            gt_bits = np.array(gt_bits).T
            print("gt_data", gt_data.shape, gt_data.min(), gt_data.max())
            print("gt_bits", gt_bits.shape, gt_bits.min(), gt_bits.max())
            save_samples(gt_data[:, 0], gt_bits,
                         self.outname(marker + "_val_samples.png"))

            # build model
            bs = self.batch_size
            model = self.get_model(label_output_sizes)

            # setup training
            hist = HistoryPerBatch(
                self.output_dir,
                extra_metrics=['bits_loss', 'val_bits_loss'])
            hist_saver = OnEpochEnd(lambda e, l: hist.save(),
                                    every_nth_epoch=5)

            def lr_schedule(optimizer):
                """Map key 40 to a tenth of the optimizer's current rate."""
                lr = K.get_value(optimizer.lr)
                return {
                    40: lr / 10.,
                }

            scheduler = LearningRateScheduler(
                model.optimizer, lr_schedule(model.optimizer))
            hdf5_attrs = get_distribution_hdf5_attrs(
                self.get_label_distributions())
            hdf5_attrs['decoder_uses_hist_equalization'] = \
                self.use_hist_equalization
            checkpointer = SaveModelAndWeightsCheckpoint(
                self.model_fname(), monitor='val_bits_loss',
                verbose=0, save_best_only=True,
                hdf5_attrs=hdf5_attrs)
            plot_history = hist.plot_callback(
                fname=self.outname(marker + '_loss.png'),
                metrics=['bits_loss', 'val_bits_loss'])

            # train
            truth_gen = self.truth_generator_factory(gt_val,
                                                     label_output_sizes)
            callbacks = [CollectBitsLoss(), scheduler, checkpointer, hist,
                         plot_history, hist_saver]
            if int(self.verbose) == 0:
                callbacks.append(DotProgressBar())

            model.fit_generator(
                data_gen(bs),
                samples_per_epoch=bs * self.nb_batches_per_epoch,
                nb_epoch=self.nb_epoch,
                callbacks=callbacks,
                verbose=self.verbose,
                validation_data=truth_gen(bs),
                nb_val_samples=gt_val['tags'].shape[0],
                nb_worker=1, max_q_size=4 * 10, pickle_safe=False
            )
            # Evaluation runs while the files are still open, exactly as
            # before; they are released only afterwards.
            evaluate_decoder.run(self, cache=False)
        finally:
            if gt_train is not None:
                gt_train.close()
            gt_val.close()
示例#2
0
def train_callbacks(rendergan,
                    output_dir,
                    nb_visualise,
                    real_hdf5_fname,
                    distribution,
                    lr_schedule=None,
                    overwrite=False):
    """Build the list of callbacks used while training ``rendergan``.

    Args:
        rendergan: Object exposing ``sample_generator_given_z``,
            ``sample_generator_given_z_output_names``, ``discriminator`` and
            ``gan`` (with ``g_optimizer``, ``d_optimizer`` and ``random_z``).
        output_dir: Root directory. The sub-directories ``history``,
            ``samples`` and ``d_score_hist`` are created below it; the model
            and visualisation callbacks receive further sub-paths of it.
        nb_visualise: Number of samples for the visualisation callbacks
            (halved for ``VisualiseTag3dAndFake``, divided by the number of
            generator outputs for ``VisualiseAll``).
        real_hdf5_fname: File name passed to ``train_data_generator``; the
            ``'data'`` entry of the first yielded item is used as the real
            samples.
        distribution: Passed to ``get_distribution_hdf5_attrs`` and to
            ``StoreSamples``.
        lr_schedule: Callable that receives an optimizer's current learning
            rate (as ``float``) and returns the schedule handed to
            ``LearningRateScheduler``. When ``None``, a default mapping
            ``{200: lr/4, 250: lr/16, 300: lr/64}`` is used (keys are
            presumably epoch numbers).
        overwrite: Forwarded to ``StoreSamples``.

    Returns:
        list: ``[sample_cb, save_gan_cb, hist, hist_save]`` followed by the
        learning-rate scheduler of the generator optimizer and the one of
        the discriminator optimizer.
    """
    def default_lr_schedule(lr):
        """Return the fallback schedule derived from the rate ``lr``."""
        return {
            200: lr / 4,
            250: lr / 4**2,
            300: lr / 4**3,
        }

    # Resolve the default up front so the closure below never depends on a
    # later rebinding of ``lr_schedule``.
    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    def clip_preprocess(x):
        """Clip ``x`` to ``[-1, 1]``; shared by all visualisation callbacks."""
        return np.clip(x, -1, 1)

    save_gan_cb = SaveGAN(rendergan,
                          join(output_dir, "models/{epoch:03d}/{name}.hdf5"),
                          every_epoch=10,
                          hdf5_attrs=get_distribution_hdf5_attrs(distribution))
    nb_score = 1000

    sample_fn = predict_wrapper(
        rendergan.sample_generator_given_z.predict,
        rendergan.sample_generator_given_z_output_names)

    real = next(train_data_generator(real_hdf5_fname, nb_score, 1))['data']

    vis_cb = VisualiseTag3dAndFake(
        nb_samples=nb_visualise // 2,
        output_dir=join(output_dir, 'visualise_tag3d_fake'),
        show=False,
        preprocess=clip_preprocess)
    vis_all = VisualiseAll(
        nb_samples=nb_visualise //
        len(rendergan.sample_generator_given_z.outputs),
        output_dir=join(output_dir, 'visualise_all'),
        show=False,
        preprocess=clip_preprocess)
    vis_fake_sorted = VisualiseFakesSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_fakes_sorted'),
        show=False,
        preprocess=clip_preprocess)
    vis_real_sorted = VisualiseRealsSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_reals_sorted'),
        show=False,
        preprocess=clip_preprocess)

    def lr_scheduler(opt):
        """Wrap ``opt`` in a scheduler seeded with its current rate."""
        return LearningRateScheduler(opt,
                                     lr_schedule(float(K.get_value(opt.lr))))

    g_optimizer = rendergan.gan.g_optimizer
    d_optimizer = rendergan.gan.d_optimizer
    lr_schedulers = [
        lr_scheduler(g_optimizer),
        lr_scheduler(d_optimizer),
    ]
    hist_dir = join(output_dir, "history")
    os.makedirs(hist_dir, exist_ok=True)
    hist = HistoryPerBatch(hist_dir)

    def history_plot(e, logs=None):
        """Plot generator/discriminator loss for epoch ``e``.

        ``logs`` is accepted for the callback signature but not used; it
        defaults to ``None`` rather than a shared mutable dict.
        """
        fig, _ = hist.plot(save_as="{:03d}.png".format(e),
                           metrics=['g_loss', 'd_loss'])
        plt.close(fig)  # allows fig to be garbage collected

    hist_save = OnEpochEnd(history_plot, every_nth_epoch=20)

    sample_outdir = join(output_dir, 'samples')
    os.makedirs(sample_outdir, exist_ok=True)
    store_samples_cb = StoreSamples(sample_outdir, distribution, overwrite)

    dscore_outdir = join(output_dir, 'd_score_hist')
    os.makedirs(dscore_outdir, exist_ok=True)
    dscore = DScoreHistogram(dscore_outdir)

    # One shared sample count, large enough for both scoring and
    # visualisation.
    nb_sample = max(nb_score, nb_visualise)
    sample_cb = SampleGAN(sample_fn,
                          rendergan.discriminator.predict,
                          rendergan.gan.random_z(nb_sample),
                          real,
                          callbacks=[
                              vis_cb, vis_fake_sorted, vis_all,
                              vis_real_sorted, dscore, store_samples_cb
                          ])
    return [sample_cb, save_gan_cb, hist, hist_save] + lr_schedulers
示例#3
0
    def train(self):
        """Run the complete training pipeline and the follow-up evaluation.

        Steps, in order:

        1. Create ``self.output_dir`` and call ``self.save()`` and
           ``self.summary()``.
        2. Open the validation HDF5 file (and the training HDF5 file when
           ``self.use_combined_data`` is set) read-only.
        3. Pass the training and validation generators to
           ``self.check_generator`` and hand ``20**2`` samples from each to
           ``save_samples`` (``*_train_samples.png`` / ``*_val_samples.png``).
        4. Build the model, assemble the callbacks (per-batch history,
           learning-rate schedule, checkpointing monitored on
           ``val_bits_loss``, loss plot) and call ``fit_generator``.
        5. Call ``evaluate_decoder.run(self, cache=False)``.

        The HDF5 handles are closed when the method exits, whether training
        finished or raised.

        Note: ``nb_bits`` is not defined inside this method; it is expected
        to be provided by the enclosing module.
        """
        marker = os.path.basename(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        self.save()
        self.summary()

        nb_vis_samples = 20**2
        gt_train = None
        # Explicit read-only mode: the default mode of h5py.File differs
        # between h5py versions, and these files are only read here.
        gt_val = h5py.File(self.gt_val_fname, 'r')
        try:
            print("bits", gt_val["bits"].shape)
            label_output_sizes = self.get_label_output_sizes()

            if self.use_combined_data:
                gt_train = h5py.File(self.gt_train_fname, 'r')
                data_gen = self.combined_generator_factory(
                    gt_train, label_output_sizes)
            else:
                data_gen = self.data_generator_factory()

            truth_gen_only_bits = self.truth_generator_factory(
                gt_val, missing_label_sizes=[])
            self.check_generator(data_gen, "train")
            self.check_generator(truth_gen_only_bits, "test")

            # save samples
            vis_out = next(data_gen(nb_vis_samples))
            vis_bits = np.array(vis_out[1][:nb_bits]).T
            save_samples(vis_out[0][:, 0], vis_bits,
                         self.outname(marker + "_train_samples.png"))
            # The third element is unpacked but not used below.
            gt_data, gt_bits, _gt_masks = next(
                truth_gen_only_bits(nb_vis_samples))
            gt_bits = np.array(gt_bits).T
            print("gt_data", gt_data.shape, gt_data.min(), gt_data.max())
            print("gt_bits", gt_bits.shape, gt_bits.min(), gt_bits.max())
            save_samples(gt_data[:, 0], gt_bits,
                         self.outname(marker + "_val_samples.png"))

            # build model
            bs = self.batch_size
            model = self.get_model(label_output_sizes)

            # setup training
            hist = HistoryPerBatch(
                self.output_dir,
                extra_metrics=['bits_loss', 'val_bits_loss'])
            hist_saver = OnEpochEnd(lambda e, l: hist.save(),
                                    every_nth_epoch=5)

            def lr_schedule(optimizer):
                """Map key 40 to a tenth of the optimizer's current rate."""
                lr = K.get_value(optimizer.lr)
                return {
                    40: lr / 10.,
                }

            scheduler = LearningRateScheduler(model.optimizer,
                                              lr_schedule(model.optimizer))
            hdf5_attrs = get_distribution_hdf5_attrs(
                self.get_label_distributions())
            hdf5_attrs[
                'decoder_uses_hist_equalization'] = self.use_hist_equalization
            checkpointer = SaveModelAndWeightsCheckpoint(
                self.model_fname(),
                monitor='val_bits_loss',
                verbose=0,
                save_best_only=True,
                hdf5_attrs=hdf5_attrs)
            plot_history = hist.plot_callback(
                fname=self.outname(marker + '_loss.png'),
                metrics=['bits_loss', 'val_bits_loss'])

            # train
            truth_gen = self.truth_generator_factory(gt_val,
                                                     label_output_sizes)
            callbacks = [
                CollectBitsLoss(), scheduler, checkpointer, hist,
                plot_history, hist_saver
            ]
            if int(self.verbose) == 0:
                callbacks.append(DotProgressBar())

            model.fit_generator(
                data_gen(bs),
                samples_per_epoch=bs * self.nb_batches_per_epoch,
                nb_epoch=self.nb_epoch,
                callbacks=callbacks,
                verbose=self.verbose,
                validation_data=truth_gen(bs),
                nb_val_samples=gt_val['tags'].shape[0],
                nb_worker=1,
                max_q_size=4 * 10,
                pickle_safe=False)
            # Evaluation runs while the files are still open, exactly as
            # before; they are released only afterwards.
            evaluate_decoder.run(self, cache=False)
        finally:
            if gt_train is not None:
                gt_train.close()
            gt_val.close()
示例#4
0
def train_callbacks(rendergan, output_dir, nb_visualise, real_hdf5_fname,
                    distribution, lr_schedule=None, overwrite=False):
    """Build the list of callbacks used while training ``rendergan``.

    Args:
        rendergan: Object exposing ``sample_generator_given_z``,
            ``sample_generator_given_z_output_names``, ``discriminator`` and
            ``gan`` (with ``g_optimizer``, ``d_optimizer`` and ``random_z``).
        output_dir: Root directory. The sub-directories ``history``,
            ``samples`` and ``d_score_hist`` are created below it; the model
            and visualisation callbacks receive further sub-paths of it.
        nb_visualise: Number of samples for the visualisation callbacks
            (halved for ``VisualiseTag3dAndFake``, divided by the number of
            generator outputs for ``VisualiseAll``).
        real_hdf5_fname: File name passed to ``train_data_generator``; the
            ``'data'`` entry of the first yielded item is used as the real
            samples.
        distribution: Passed to ``get_distribution_hdf5_attrs`` and to
            ``StoreSamples``.
        lr_schedule: Callable that receives an optimizer's current learning
            rate (as ``float``) and returns the schedule handed to
            ``LearningRateScheduler``. When ``None``, a default mapping
            ``{200: lr/4, 250: lr/16, 300: lr/64}`` is used (keys are
            presumably epoch numbers).
        overwrite: Forwarded to ``StoreSamples``.

    Returns:
        list: ``[sample_cb, save_gan_cb, hist, hist_save]`` followed by the
        learning-rate scheduler of the generator optimizer and the one of
        the discriminator optimizer.
    """
    def default_lr_schedule(lr):
        """Return the fallback schedule derived from the rate ``lr``."""
        return {
            200: lr / 4,
            250: lr / 4**2,
            300: lr / 4**3,
        }

    # Resolve the default up front so the closure below never depends on a
    # later rebinding of ``lr_schedule``.
    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    def clip_preprocess(x):
        """Clip ``x`` to ``[-1, 1]``; shared by all visualisation callbacks."""
        return np.clip(x, -1, 1)

    save_gan_cb = SaveGAN(rendergan, join(output_dir, "models/{epoch:03d}/{name}.hdf5"),
                          every_epoch=10, hdf5_attrs=get_distribution_hdf5_attrs(distribution))
    nb_score = 1000

    sample_fn = predict_wrapper(rendergan.sample_generator_given_z.predict,
                                rendergan.sample_generator_given_z_output_names)

    real = next(train_data_generator(real_hdf5_fname, nb_score, 1))['data']

    vis_cb = VisualiseTag3dAndFake(
        nb_samples=nb_visualise // 2,
        output_dir=join(output_dir, 'visualise_tag3d_fake'),
        show=False,
        preprocess=clip_preprocess)
    vis_all = VisualiseAll(
        nb_samples=nb_visualise // len(rendergan.sample_generator_given_z.outputs),
        output_dir=join(output_dir, 'visualise_all'),
        show=False,
        preprocess=clip_preprocess)
    vis_fake_sorted = VisualiseFakesSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_fakes_sorted'),
        show=False,
        preprocess=clip_preprocess)
    vis_real_sorted = VisualiseRealsSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_reals_sorted'),
        show=False,
        preprocess=clip_preprocess)

    def lr_scheduler(opt):
        """Wrap ``opt`` in a scheduler seeded with its current rate."""
        return LearningRateScheduler(opt, lr_schedule(float(K.get_value(opt.lr))))

    g_optimizer = rendergan.gan.g_optimizer
    d_optimizer = rendergan.gan.d_optimizer
    lr_schedulers = [
        lr_scheduler(g_optimizer),
        lr_scheduler(d_optimizer),
    ]
    hist_dir = join(output_dir, "history")
    os.makedirs(hist_dir, exist_ok=True)
    hist = HistoryPerBatch(hist_dir)

    def history_plot(e, logs=None):
        """Plot generator/discriminator loss for epoch ``e``.

        ``logs`` is accepted for the callback signature but not used; it
        defaults to ``None`` rather than a shared mutable dict.
        """
        fig, _ = hist.plot(save_as="{:03d}.png".format(e), metrics=['g_loss', 'd_loss'])
        plt.close(fig)  # allows fig to be garbage collected

    hist_save = OnEpochEnd(history_plot, every_nth_epoch=20)

    sample_outdir = join(output_dir, 'samples')
    os.makedirs(sample_outdir, exist_ok=True)
    store_samples_cb = StoreSamples(sample_outdir, distribution, overwrite)

    dscore_outdir = join(output_dir, 'd_score_hist')
    os.makedirs(dscore_outdir, exist_ok=True)
    dscore = DScoreHistogram(dscore_outdir)

    # One shared sample count, large enough for both scoring and
    # visualisation.
    nb_sample = max(nb_score, nb_visualise)
    sample_cb = SampleGAN(sample_fn, rendergan.discriminator.predict,
                          rendergan.gan.random_z(nb_sample), real,
                          callbacks=[vis_cb, vis_fake_sorted, vis_all, vis_real_sorted,
                                     dscore, store_samples_cb])
    return [sample_cb, save_gan_cb, hist, hist_save] + lr_schedulers