Example #1
if __name__ == '__main__':
    # (reconstructed for context) the listing starts mid-way through argument
    # parsing; the flags below are the ones the rest of the snippet references
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma-separated list of GPU(s) to use')
    parser.add_argument('--load', help='checkpoint to load')
    parser.add_argument('--sample', action='store_true', help='run sampling instead of training')
    parser.add_argument('--data', help='directory containing the training data')
    parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
    parser.add_argument('-b', '--batch', type=int, default=1)
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    BATCH = args.batch

    if args.sample:
        assert args.load
        sample(args.data, args.load)
    else:
        logger.auto_set_dir()

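        # wrap the dataflow in a queue-based input pipeline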
        data = QueueInput(get_data())

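        # pick the plain GAN trainer on a single GPU, the data-parallel trainer otherwise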
        nr_tower = max(get_num_gpu(), 1)
        if nr_tower == 1:
            trainer = GANTrainer(data, Model())
        else:
            trainer = MultiGPUGANTrainer(nr_tower, data, Model())

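        # save a checkpoint every 3 epochs and set the learning rate to 1e-4 at epoch 200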
        trainer.train_with_defaults(
            callbacks=[
                PeriodicTrigger(ModelSaver(), every_k_epochs=3),
                ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
            ],
            steps_per_epoch=data.size(),
            max_epoch=300,
            session_init=SaverRestore(args.load) if args.load else None)
Example #2
    # (reconstructed for context) the excerpt starts inside the sample/train branch
    if args.sample:
        # sample
        pass
    else:
        # Set up configuration
        # Set the logger directory
        logger.auto_set_dir()

        # SyncMultiGPUTrainer(config).train()
        nr_tower = max(get_nr_gpu(), 1)
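        # on one GPU, alternate the optimizers: update D every step, G every 4th step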
        if nr_tower == 1:
            trainer = SeparateGANTrainer(data_set,
                                         model,
                                         g_period=4,
                                         d_period=1)
        else:
            trainer = MultiGPUGANTrainer(nr_tower, data_set, model)
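        # learning rate decays linearly from 2e-4 to 1e-6 over 500 epochs; run the visualization every 5 epochs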
        trainer.train_with_defaults(
            callbacks=[
                # PeriodicTrigger(ModelSaver(), every_k_epochs=20),
                ClipCallback(),
                ScheduledHyperParamSetter('learning_rate', [(0, 2e-4),
                                                            (100, 1e-4),
                                                            (200, 2e-5),
                                                            (300, 1e-5),
                                                            (400, 2e-6),
                                                            (500, 1e-6)],
                                          interp='linear'),
                PeriodicTrigger(VisualizeRunner(), every_k_epochs=5),
            ],
            session_init=SaverRestore(args.load) if args.load else None,
            steps_per_epoch=data_set.size(),
        )
Example #3
def get_config():
    return TrainConfig(
        model=Model(),
        dataflow=DCGAN.get_data(G.data),
        callbacks=[
            ModelSaver(),
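            # halve the learning rate whenever 'measure' has not decreased over the last 10 epochs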
            StatMonitorParamSetter('learning_rate', 'measure',
                                   lambda x: x * 0.5, 0, 10)
        ],
        steps_per_epoch=500,
        max_epoch=400,
    )


if __name__ == '__main__':
    args = DCGAN.get_args()
    if args.sample:
        DCGAN.sample(args.load, 'gen/conv4.3/output')
    else:
        assert args.data
        logger.auto_set_dir()
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        nr_gpu = get_nr_gpu()
        config.nr_tower = max(nr_gpu, 1)
        if config.nr_tower == 1:
            GANTrainer(config).train()
        else:
            MultiGPUGANTrainer(config).train()
Example #4
    def optimizer(self):
        # (reconstructed) assumed to be the model's optimizer method; the variable
        # name 'learning_rate' matches the hyper-parameter tuned by the callback below
        lr = tf.get_variable('learning_rate',
                             initializer=1e-4,
                             trainable=False)
        opt = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
        return opt


if __name__ == '__main__':
    args = DCGAN.get_args(default_batch=32, default_z_dim=64)
    if args.sample:
        DCGAN.sample(Model(), args.load, 'gen/conv4.3/output')
    else:
        logger.auto_set_dir()

        input = QueueInput(DCGAN.get_data())
        model = Model()
        nr_tower = max(get_nr_gpu(), 1)
        if nr_tower == 1:
            trainer = GANTrainer(input, model)
        else:
            trainer = MultiGPUGANTrainer(nr_tower, input, model)

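        # checkpoint every epoch; halve the learning rate when 'measure' stops decreasing for 10 epochs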
        trainer.train_with_defaults(
            callbacks=[
                ModelSaver(),
                StatMonitorParamSetter('learning_rate', 'measure',
                                       lambda x: x * 0.5, 0, 10)
            ],
            session_init=SaverRestore(args.load) if args.load else None,
            steps_per_epoch=500,
            max_epoch=400)