Example #1
0
def train_callbacks(rendergan,
                    output_dir,
                    nb_visualise,
                    real_hdf5_fname,
                    distribution,
                    lr_schedule=None,
                    overwrite=False):
    """Build the list of callbacks used while training ``rendergan``.

    Args:
        rendergan: Object exposing ``gan`` (with ``g_optimizer``,
            ``d_optimizer`` and ``random_z``), ``discriminator`` and
            ``sample_generator_given_z`` together with
            ``sample_generator_given_z_output_names``.
        output_dir: Root directory; each callback gets its own
            sub-directory (``models``, ``history``, ``samples``, ...).
        nb_visualise: Sample budget for the visualisation callbacks.
        real_hdf5_fname: File name handed to ``train_data_generator`` to
            fetch the real samples.
        distribution: Forwarded to ``get_distribution_hdf5_attrs`` and
            ``StoreSamples``.
        lr_schedule: Optional callable that receives the current learning
            rate (a float) and returns the schedule handed to
            ``LearningRateScheduler``. When ``None``, a schedule with the
            keys 200, 250 and 300 mapping to ``lr / 4``, ``lr / 16`` and
            ``lr / 64`` is used (keys are presumably epoch indices --
            confirm against ``LearningRateScheduler``).
        overwrite: Forwarded to ``StoreSamples``.

    Returns:
        list: ``[sample_cb, save_gan_cb, hist, hist_save]`` followed by one
        learning rate scheduler for the generator optimizer and one for the
        discriminator optimizer.
    """

    def clip_to_unit_range(x):
        # Single shared preprocessing step for every visualisation callback.
        return np.clip(x, -1, 1)

    def default_lr_schedule(lr):
        return {
            200: lr / 4,
            250: lr / 4**2,
            300: lr / 4**3,
        }

    # Resolve the schedule up front so the closure below captures its final
    # value instead of relying on a later rebinding of the parameter.
    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    def lr_scheduler(opt):
        return LearningRateScheduler(opt,
                                     lr_schedule(float(K.get_value(opt.lr))))

    save_gan_cb = SaveGAN(rendergan,
                          join(output_dir, "models/{epoch:03d}/{name}.hdf5"),
                          every_epoch=10,
                          hdf5_attrs=get_distribution_hdf5_attrs(distribution))
    nb_score = 1000

    sample_fn = predict_wrapper(
        rendergan.sample_generator_given_z.predict,
        rendergan.sample_generator_given_z_output_names)

    # Only the first item of the generator is needed as the fixed set of
    # real samples.
    real = next(train_data_generator(real_hdf5_fname, nb_score, 1))['data']

    vis_cb = VisualiseTag3dAndFake(nb_samples=nb_visualise // 2,
                                   output_dir=join(output_dir,
                                                   'visualise_tag3d_fake'),
                                   show=False,
                                   preprocess=clip_to_unit_range)
    # The sample budget is split evenly over the generator outputs.
    vis_all = VisualiseAll(nb_samples=nb_visualise //
                           len(rendergan.sample_generator_given_z.outputs),
                           output_dir=join(output_dir, 'visualise_all'),
                           show=False,
                           preprocess=clip_to_unit_range)

    vis_fake_sorted = VisualiseFakesSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_fakes_sorted'),
        show=False,
        preprocess=clip_to_unit_range)

    vis_real_sorted = VisualiseRealsSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_reals_sorted'),
        show=False,
        preprocess=clip_to_unit_range)

    lr_schedulers = [
        lr_scheduler(rendergan.gan.g_optimizer),
        lr_scheduler(rendergan.gan.d_optimizer),
    ]

    hist_dir = join(output_dir, "history")
    os.makedirs(hist_dir, exist_ok=True)
    hist = HistoryPerBatch(hist_dir)

    def history_plot(e, logs=None):
        # ``logs`` is accepted for callback compatibility but not used; a
        # ``None`` default avoids sharing one mutable dict between calls.
        fig, _ = hist.plot(save_as="{:03d}.png".format(e),
                           metrics=['g_loss', 'd_loss'])
        plt.close(fig)  # allows fig to be garbage collected

    hist_save = OnEpochEnd(history_plot, every_nth_epoch=20)

    sample_outdir = join(output_dir, 'samples')
    os.makedirs(sample_outdir, exist_ok=True)
    store_samples_cb = StoreSamples(sample_outdir, distribution, overwrite)

    dscore_outdir = join(output_dir, 'd_score_hist')
    os.makedirs(dscore_outdir, exist_ok=True)
    dscore = DScoreHistogram(dscore_outdir)

    # Draw enough latent vectors for whichever consumer needs the most.
    nb_sample = max(nb_score, nb_visualise)
    sample_cb = SampleGAN(sample_fn,
                          rendergan.discriminator.predict,
                          rendergan.gan.random_z(nb_sample),
                          real,
                          callbacks=[
                              vis_cb, vis_fake_sorted, vis_all,
                              vis_real_sorted, dscore, store_samples_cb
                          ])
    return [sample_cb, save_gan_cb, hist, hist_save] + lr_schedulers
Example #2
0
    def train(self):
        """Fit the model on generated data and evaluate it afterwards.

        The method calls ``self.save()`` and ``self.summary()``, writes one
        preview image of a training batch and one of a validation batch,
        builds the model via ``self.get_model``, trains it with
        ``fit_generator`` and finally calls ``evaluate_decoder.run``.

        The ground-truth HDF5 files are only read here, so they are opened
        read-only and are closed again when the method finishes, including
        when training raises.
        """
        marker = os.path.basename(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        self.save()
        self.summary()
        nb_vis_samples = 20**2

        gt_train = None
        # Explicit 'r': older h5py versions default to a read/write mode.
        gt_val = h5py.File(self.gt_val_fname, 'r')
        try:
            print("bits", gt_val["bits"].shape)

            label_output_sizes = self.get_label_output_sizes()

            if self.use_combined_data:
                gt_train = h5py.File(self.gt_train_fname, 'r')
                data_gen = self.combined_generator_factory(gt_train,
                                                           label_output_sizes)
            else:
                data_gen = self.data_generator_factory()

            truth_gen_only_bits = self.truth_generator_factory(
                gt_val, missing_label_sizes=[])
            self.check_generator(data_gen, "train")
            self.check_generator(truth_gen_only_bits, "test")

            # save samples
            vis_out = next(data_gen(nb_vis_samples))
            vis_bits = np.array(vis_out[1][:nb_bits]).T
            save_samples(vis_out[0][:, 0], vis_bits,
                         self.outname(marker + "_train_samples.png"))
            gt_data, gt_bits, _ = next(truth_gen_only_bits(nb_vis_samples))
            gt_bits = np.array(gt_bits).T
            print("gt_data", gt_data.shape, gt_data.min(), gt_data.max())
            print("gt_bits", gt_bits.shape, gt_bits.min(), gt_bits.max())
            save_samples(gt_data[:, 0], gt_bits,
                         self.outname(marker + "_val_samples.png"))

            # build model
            bs = self.batch_size
            model = self.get_model(label_output_sizes)

            # setup training
            hist = HistoryPerBatch(
                self.output_dir,
                extra_metrics=['bits_loss', 'val_bits_loss'])
            hist_saver = OnEpochEnd(lambda e, logs: hist.save(),
                                    every_nth_epoch=5)

            def lr_schedule(optimizer):
                # Key 40 maps to a tenth of the optimizer's current rate.
                lr = K.get_value(optimizer.lr)
                return {
                    40: lr / 10.,
                }

            scheduler = LearningRateScheduler(model.optimizer,
                                              lr_schedule(model.optimizer))
            hdf5_attrs = get_distribution_hdf5_attrs(
                self.get_label_distributions())
            hdf5_attrs[
                'decoder_uses_hist_equalization'] = self.use_hist_equalization
            checkpointer = SaveModelAndWeightsCheckpoint(
                self.model_fname(),
                monitor='val_bits_loss',
                verbose=0,
                save_best_only=True,
                hdf5_attrs=hdf5_attrs)
            plot_history = hist.plot_callback(
                fname=self.outname(marker + '_loss.png'),
                metrics=['bits_loss', 'val_bits_loss'])

            # train
            truth_gen = self.truth_generator_factory(gt_val,
                                                     label_output_sizes)
            callbacks = [
                CollectBitsLoss(), scheduler, checkpointer, hist,
                plot_history, hist_saver
            ]
            if int(self.verbose) == 0:
                callbacks.append(DotProgressBar())

            model.fit_generator(
                data_gen(bs),
                samples_per_epoch=bs * self.nb_batches_per_epoch,
                nb_epoch=self.nb_epoch,
                callbacks=callbacks,
                verbose=self.verbose,
                validation_data=truth_gen(bs),
                nb_val_samples=gt_val['tags'].shape[0],
                nb_worker=1,
                max_q_size=4 * 10,
                pickle_safe=False)
            # Evaluation runs while the files are still open, as before.
            evaluate_decoder.run(self, cache=False)
        finally:
            if gt_train is not None:
                gt_train.close()
            gt_val.close()
Example #3
0
def run(output_dir, force, tags_3d_hdf5_fname, nb_units, depth, nb_epoch,
        filter_size, project_factor, nb_dense):
    """Fit a network mapping concatenated labels to ``tag3d``/``depth_map``.

    Batches are read from the dataset in ``tags_3d_hdf5_fname``; the model
    built by ``tag3d_network_dense`` is trained with an MSE loss on both
    outputs. Snapshots, prediction previews, the final weights, the model
    JSON and the loss history are written next to each other using the
    prefix ``<output_dir>/network_tags3d_n<nb_units>_d<depth>_e<nb_epoch>``.

    Args:
        output_dir: Directory receiving all output files; created if missing.
        force: If true, an existing weights file is removed and rewritten.
        tags_3d_hdf5_fname: File name passed to ``DistributionHDF5Dataset``.
        nb_units: Forwarded to ``tag3d_network_dense`` as ``nb_units``.
        depth: Forwarded to ``tag3d_network_dense`` as ``depth``.
        nb_epoch: Number of epochs passed to ``fit_generator``.
        filter_size: Unused; kept so existing callers keep working.
        project_factor: Unused; kept so existing callers keep working.
        nb_dense: Forwarded to ``tag3d_network_dense`` as ``nb_dense_units``.

    Raises:
        OSError: If the weights file already exists and ``force`` is false.
    """
    batch_size = 64
    basename = "network_tags3d_n{}_d{}_e{}".format(nb_units, depth, nb_epoch)
    output_basename = os.path.join(output_dir, basename)
    weights_fname = output_basename + ".hdf5"

    # Fail fast, before the dataset is touched, and name the offending file.
    weights_exist = os.path.exists(weights_fname)
    if weights_exist and not force:
        raise OSError(
            "File {} already exists. Use --force to override it".format(
                weights_fname))

    tag_dataset = DistributionHDF5Dataset(tags_3d_hdf5_fname)
    tag_dataset._dataset_created = True
    print("Got {} images from the 3d model".format(tag_dataset.nb_samples))
    if weights_exist:
        os.remove(weights_fname)
    os.makedirs(output_dir, exist_ok=True)

    def generator(batch_size):
        for batch in tag_dataset.iter(batch_size):
            label_struct = batch['labels']
            assert not np.isnan(batch['tag3d']).any()
            assert not np.isnan(batch['depth_map']).any()
            # Flatten the named label fields into one array along the last
            # axis.
            labels = np.concatenate(
                [label_struct[name] for name in label_struct.dtype.names],
                axis=-1)
            yield labels, [batch['tag3d'], batch['depth_map']]

    labels = next(generator(batch_size))[0]
    print("labels.shape ", labels.shape)
    print("labels.dtype ", labels.dtype)
    # Reuse the batch fetched above instead of reading a second one.
    nb_input = labels.shape[1]

    x = Input(shape=(nb_input, ))
    tag3d, depth_map = tag3d_network_dense(x,
                                           nb_units=nb_units,
                                           depth=depth,
                                           nb_dense_units=nb_dense)
    g = Model(x, [tag3d, depth_map])
    optimizer = Nadam()

    g.compile(optimizer, loss=['mse', 'mse'], loss_weights=[1, 1 / 3.])

    scheduler = AutomaticLearningRateScheduler(optimizer,
                                               'loss',
                                               epoch_patience=5,
                                               min_improvement=0.0002)
    history = HistoryPerBatch()
    # ``{epoch:03d}`` left-pads with zeros. The previous ``{epoch:^03}``
    # centred the number, turning epoch 5 into "050" and mapping both epoch
    # 10 and epoch 100 to "100".
    save = SaveModels({basename + '_snapshot_{epoch:03d}.hdf5': g},
                      output_dir=output_dir,
                      hdf5_attrs=tag_dataset.get_distribution_hdf5_attrs())
    history_plot = history.plot_callback(fname=output_basename + "_loss.png",
                                         every_nth_epoch=10)
    g.fit_generator(generator(batch_size),
                    samples_per_epoch=800 * batch_size,
                    nb_epoch=nb_epoch,
                    verbose=1,
                    callbacks=[scheduler, save, history, history_plot])

    nb_visualize = 18**2
    vis_labels, (tags_3d, depth_map) = next(generator(nb_visualize))
    predict_tags_3d, predict_depth_map = g.predict(vis_labels)

    def zip_and_save(fname, *args):
        # Clip to [0, 1] and keep index 0 of the second axis of each array.
        clipped = [np.clip(arr, 0, 1)[:, 0] for arr in args]
        print(clipped[0].shape)
        tiled = zip_tile(*clipped)
        print(tiled.shape)
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
        # call needs an older SciPy or a replacement image writer.
        scipy.misc.imsave(fname, tiled)

    zip_and_save(output_basename + "_predict_tags.png", tags_3d,
                 predict_tags_3d)
    zip_and_save(output_basename + "_predict_depth_map.png", depth_map,
                 predict_depth_map)

    save_model(g,
               weights_fname,
               attrs=tag_dataset.get_distribution_hdf5_attrs())
    with open(output_basename + '.json', 'w') as f:
        f.write(g.to_json())

    with open(output_basename + '_loss_history.json', 'w') as f:
        json.dump(history.history, f)

    fig, _ = history.plot()
    fig.savefig(output_basename + "_loss.png")
    print("Saved weights to: {}".format(weights_fname))