예제 #1
0
def main(config):
    """Dump loader batches to temporary files and score them.

    Fetches ``config.num_step`` batches of ``config.batch_size`` items from an
    ``LSUNCatLoader`` built over ``config.data_dir``, writes the first element
    of each batch to its own temporary ``.npy`` file, and hands an iterator
    over those files to ``get_resnet18_score`` together with
    ``config.model_path``.  The resulting mean/std are printed and, when
    ``config.save_path`` is not ``None``, pickled to that path along with the
    batch size.

    The loader is always stopped and every temporary file is always removed,
    whether or not scoring succeeds.
    """
    num_step = config.num_step
    data_loader = LSUNCatLoader(get_lsun_patterns(config.data_dir),
                                num_workers=4,
                                actions=lsun_process_actions())

    names = []
    try:
        data_loader.start_fetch()
        print("generating images...")
        for _ in xrange(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            # Record the path immediately so the cleanup below removes the
            # file even if wrapping or writing it fails.
            names.append(name)
            # The context manager closes the handle as soon as the batch is
            # written (or the write fails), before the file is unlinked.
            with os.fdopen(fd, "wb+") as fobj:
                image_arr = data_loader.next_batch(config.batch_size)[0]
                np.save(fobj, image_arr, allow_pickle=False)

        mean_score, std_score = get_resnet18_score(images_iter(names),
                                                   config.model_path,
                                                   batch_size=100,
                                                   split=10)

        print("mean = %.4f, std = %.4f." % (mean_score, std_score))

        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                cPickle.dump(dict(batch_size=config.batch_size,
                                  scores=dict(mean=mean_score, std=std_score)), f)
    finally:
        try:
            data_loader.stop_fetch()
        finally:
            # Runs even when stop_fetch() raises, so temp files never linger.
            for name in names:
                os.unlink(name)
예제 #2
0
        eval_losses.append(eval_loss)

    sess.close()
    print("accuracy:", np.mean(eval_losses))


if __name__ == "__main__":
    # Command line: two required positional paths plus two integer options.
    parser = argparse.ArgumentParser()
    for arg_name, placeholder in (("val_data_dir", "VALDATADIR"),
                                  ("model_path", "MODELPATH")):
        parser.add_argument(arg_name, metavar=placeholder)
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=100)
    parser.add_argument("--dim", dest="dim", type=int, default=64)
    config = parser.parse_args()
    print("config: %r" % config)

    # Build the loader over the validation data directory.
    val_patterns = get_lsun_patterns(config.val_data_dir)
    eval_data_loader = LSUNCatLoader(val_patterns,
                                     num_workers=2,
                                     actions=lsun_process_actions())

    try:
        eval_data_loader.start_fetch()
        eval_optimizer = tf.train.AdamOptimizer()
        run_task(config, eval_data_loader, classifier_forward, eval_optimizer)
    finally:
        # Stop fetching no matter how the task ended.
        eval_data_loader.stop_fetch()
예제 #3
0
    else:
        clipper = clipper_ret
        sampler = None
    scheduler = get_scheduler(config.scheduler, config)

    def callback_before_train(_0, _1, _2):
        print(clipper.info())

    supervisor = BasicSupervisorLSUN(
        config,
        clipper,
        scheduler,
        sampler=sampler,
        callback_before_train=callback_before_train)
    if config.adaptive_rate:
        supervisor.put_key("lr", lr)

    try:
        train(config,
              data_loader,
              generator_forward,
              discriminator_forward,
              disc_optimizer=disc_optimizer,
              gen_optimizer=gen_optimizer,
              accountant=accountant,
              supervisor=supervisor)
    finally:
        data_loader.stop_fetch()
        if sampler is not None:
            sampler.data_loader.stop_fetch()