示例#1
0
    # Expose only the GPU ids given on the command line, if any.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.apply:
        # Inference mode: delegate to apply() with the checkpoint, the
        # low-resolution input and the output location.
        apply(args.load, args.lowres, args.output)
    else:
        logger.auto_set_dir()

        if args.load:
            # Resume from an existing checkpoint.
            session_init = SaverRestore(args.load)
        else:
            # Fresh run: initialise from the VGG19 weight file, re-keyed
            # under the 'VGG19/' prefix.
            assert os.path.isfile(args.vgg19)
            param_dict = dict(np.load(args.vgg19))
            param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)}
            session_init = DictRestore(param_dict)

        # NOTE(review): nr_tower is not consumed anywhere in this excerpt.
        nr_tower = max(get_num_gpu(), 1)
        data = QueueInput(get_data(args.data))
        model = Model()

        trainer = SeparateGANTrainer(data, model, d_period=3)

        trainer.train_with_defaults(
            callbacks=[
                ModelSaver(keep_checkpoint_every_n_hours=2)
            ],
            session_init=session_init,
            # QueueInput is an InputSource: its epoch length is exposed through
            # size(), not len().
            steps_per_epoch=data.size() // 4,
            max_epoch=300
        )
示例#2
0
    # Expose only the GPU ids given on the command line, if any.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.apply:
        # Inference mode.
        apply(args.load, args.lowres, args.output)
    else:
        logger.auto_set_dir()

        if args.load:
            # Resume from an existing checkpoint.
            session_init = SaverRestore(args.load)
        else:
            # Fresh run: load the VGG19 weight file and re-key every entry
            # under the 'VGG19/' prefix.
            assert os.path.isfile(args.vgg19)
            raw_weights = dict(np.load(args.vgg19))
            param_dict = {}
            for layer_name, layer_value in six.iteritems(raw_weights):
                param_dict['VGG19/' + layer_name] = layer_value
            session_init = DictRestore(param_dict)

        nr_tower = max(get_num_gpu(), 1)
        data = QueueInput(get_data(args.data))
        model = Model()

        trainer = SeparateGANTrainer(data, model, d_period=3)

        # Collect the training options first, then hand them over in one go.
        train_options = dict(
            callbacks=[ModelSaver(keep_checkpoint_every_n_hours=2)],
            session_init=session_init,
            steps_per_epoch=data.size() // 4,
            max_epoch=300,
        )
        trainer.train_with_defaults(**train_options)
示例#3
0
        pass
    else:
        # Training branch (the matching `if` lies above this excerpt).
        # Let the logger choose its own output directory.
        logger.auto_set_dir()

        # Disabled alternative: SyncMultiGPUTrainer(config).train()
        # max(..., 1) routes a zero result to the single-tower trainer;
        # presumably get_nr_gpu() is the visible GPU count -- confirm.
        nr_tower = max(get_nr_gpu(), 1)
        if nr_tower == 1:
            # g_period / d_period are forwarded unchanged; see the
            # SeparateGANTrainer documentation for the update schedule.
            trainer = SeparateGANTrainer(data_set,
                                         model,
                                         g_period=4,
                                         d_period=1)
        else:
            trainer = MultiGPUGANTrainer(nr_tower, data_set, model)
        trainer.train_with_defaults(
            callbacks=[
                # No ModelSaver is registered in this list (disabled below).
                # PeriodicTrigger(ModelSaver(), every_k_epochs=20),
                ClipCallback(),
                # NOTE(review): the schedule has points at 400 and 500, but
                # max_epoch below is 300; assuming the first tuple field is an
                # epoch index (confirm), those two entries are never reached.
                ScheduledHyperParamSetter('learning_rate', [(0, 2e-4),
                                                            (100, 1e-4),
                                                            (200, 2e-5),
                                                            (300, 1e-5),
                                                            (400, 2e-6),
                                                            (500, 1e-6)],
                                          interp='linear'),
                PeriodicTrigger(VisualizeRunner(), every_k_epochs=5),
            ],
            # Resume from a checkpoint only when --load was supplied.
            session_init=SaverRestore(args.load) if args.load else None,
            steps_per_epoch=data_set.size(),
            max_epoch=300)
示例#4
0
def main():
	"""Parse the command line, then either sample one instance or train.

	With ``--sample`` the work is delegated to ``sample(...)``. Otherwise the
	training and validation dataflows are built from ``--imageDir``,
	``--maskDir`` and ``--labelDir`` and a GAN trainer is run with
	``max_epoch=500``, resuming from ``--load`` when it is given.

	Side effects: seeds NumPy and TensorFlow with 2018, stores the parsed
	options in the module-level ``args`` and sets ``CUDA_VISIBLE_DEVICES``
	when ``--gpu`` is given.
	"""
	np.random.seed(2018)
	tf.set_random_seed(2018)

	# https://docs.python.org/3/library/argparse.html
	parser = argparse.ArgumentParser()
	parser.add_argument('--gpu',        help='comma separated list of GPU(s) to use.')
	parser.add_argument('--load',       help='load models for continue train or predict')
	parser.add_argument('--sample',     help='run sampling one instance')
	parser.add_argument('--imageDir',   help='Image directory', required=True)
	parser.add_argument('--maskDir',    help='Masks directory', required=False)
	parser.add_argument('--labelDir',   help='Label directory', required=True)
	# Debug one particular function in the main flow.
	parser.add_argument('-db', '--debug', type=int, default=0)

	# The parsed options are published as a module-level global.
	global args
	args = parser.parse_args()

	if args.gpu:
		os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

	if args.sample:
		sample(args.imageDir, args.maskDir, args.labelDir, args.load, args.sample)
		return

	logger.auto_set_dir()
	ds_train, ds_valid = get_data(args.imageDir, args.maskDir, args.labelDir)

	ds_train = PrefetchDataZMQ(ds_train, nr_proc=4)
	ds_valid = PrefetchDataZMQ(ds_valid, nr_proc=4)

	# NOTE(review): tensorpack consumers normally call reset_state()
	# themselves; these explicit calls are kept to preserve the existing
	# behaviour -- confirm they are still wanted on the installed version.
	ds_train.reset_state()
	ds_valid.reset_state()

	# max(..., 1) routes a zero result to the single-tower trainer.
	nr_tower = max(get_nr_gpu(), 1)
	ds_train = QueueInput(ds_train)
	model = Model()
	if nr_tower == 1:
		trainer = SeparateGANTrainer(ds_train, model, g_period=1, d_period=1)
	else:
		trainer = MultiGPUGANTrainer(nr_tower, ds_train, model)

	# Scalars collected on the validation dataflow by InferenceRunner.
	validation_stats = [
		'PSNR_zfill_A',
		'PSNR_zfill_B',
		'PSNR_recon_A',
		'PSNR_recon_B',
		'PSNR_boost_A',
		'PSNR_boost_B',
		'losses/Img/Zfill/zfill_img_MA',
		'losses/Img/Zfill/zfill_img_MB',
		'losses/Frq/Recon/recon_frq_AA',
		'losses/Frq/Recon/recon_frq_BB',
		'losses/Img/Recon/recon_img_AA',
		'losses/Img/Recon/recon_img_BB',
		'losses/Frq/Boost/recon_frq_Aa',
		'losses/Frq/Boost/recon_frq_Bb',
		'losses/Img/Boost/recon_img_Aa',
		'losses/Img/Boost/recon_img_Bb',
	]

	trainer.train_with_defaults(
		callbacks=[
			PeriodicTrigger(ModelSaver(), every_k_epochs=20),
			VisualizeRunner(),
			InferenceRunner(ds_valid, [ScalarStats(name) for name in validation_stats]),
			# The MaxSavers are listed after InferenceRunner (and after
			# ModelSaver) so that, on a trigger epoch, they see the
			# validation_* value just computed for this epoch rather than
			# the previous epoch's.
			PeriodicTrigger(MaxSaver('validation_PSNR_recon_A'), every_k_epochs=20),
			PeriodicTrigger(MaxSaver('validation_PSNR_boost_A'), every_k_epochs=20),
			ClipCallback(),
			ScheduledHyperParamSetter(
				'learning_rate',
				[(0, 2e-4), (100, 1e-4), (200, 2e-5), (300, 1e-5), (400, 2e-6), (500, 1e-6)],
				interp='linear'),
		],
		# Resume from a checkpoint only when --load was supplied.
		session_init=SaverRestore(args.load) if args.load else None,
		steps_per_epoch=ds_train.size(),
		max_epoch=500
	)