# Entry-point logic: either run apply() or launch GAN training, depending on
# the parsed command-line options held in `args`.

# Expose only the requested GPUs to CUDA, when a list was given.
if args.gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

if args.apply:
    # Call apply() with the checkpoint, low-res and output arguments
    # instead of training.
    apply(args.load, args.lowres, args.output)
else:
    logger.auto_set_dir()

    if args.load:
        # Resume from an existing checkpoint.
        session_init = SaverRestore(args.load)
    else:
        # No checkpoint given: initialise from the VGG19 weight file instead.
        # An explicit check (rather than `assert`) still runs under
        # `python -O` and reports a missing --vgg19 argument clearly.
        if not args.vgg19 or not os.path.isfile(args.vgg19):
            raise IOError(
                'VGG19 weight file not found: {!r}'.format(args.vgg19))
        # Prefix every stored array name with 'VGG19/' before restoring.
        param_dict = {
            'VGG19/' + name: value
            for name, value in six.iteritems(dict(np.load(args.vgg19)))
        }
        session_init = DictRestore(param_dict)

    nr_tower = max(get_num_gpu(), 1)
    data = QueueInput(get_data(args.data))
    model = Model()

    trainer = SeparateGANTrainer(data, model, d_period=3)
    trainer.train_with_defaults(
        callbacks=[ModelSaver(keep_checkpoint_every_n_hours=2)],
        session_init=session_init,
        # The input source reports its epoch size through size(); it is not
        # a sized container, so len() is not used here.
        steps_per_epoch=data.size() // 4,
        max_epoch=300,
    )
# Entry-point logic: run apply() when requested, otherwise train the GAN
# using the options parsed into `args`.

if args.gpu:
    # Limit CUDA to the comma-separated device list from the command line.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

if args.apply:
    apply(args.load, args.lowres, args.output)
else:
    logger.auto_set_dir()

    if args.load:
        # A checkpoint was supplied: restore the session from it.
        session_init = SaverRestore(args.load)
    else:
        # Otherwise the session is initialised from the VGG19 weight file.
        # Validate the path explicitly: an `assert` would be stripped under
        # `python -O`, and a missing --vgg19 would otherwise surface as an
        # unrelated TypeError from os.path.isfile(None).
        if not args.vgg19 or not os.path.isfile(args.vgg19):
            raise IOError(
                'VGG19 weight file not found: {!r}'.format(args.vgg19))
        param_dict = dict(np.load(args.vgg19))
        # Re-key every array under the 'VGG19/' name prefix.
        param_dict = {
            'VGG19/' + name: value
            for name, value in six.iteritems(param_dict)
        }
        session_init = DictRestore(param_dict)

    nr_tower = max(get_num_gpu(), 1)
    data = QueueInput(get_data(args.data))
    model = Model()

    trainer = SeparateGANTrainer(data, model, d_period=3)
    trainer.train_with_defaults(
        callbacks=[ModelSaver(keep_checkpoint_every_n_hours=2)],
        session_init=session_init,
        # A quarter of the input source's reported size per epoch.
        steps_per_epoch=data.size() // 4,
        max_epoch=300,
    )
pass else: # Set up configuration # Set the logger directory logger.auto_set_dir() # SyncMultiGPUTrainer(config).train() nr_tower = max(get_nr_gpu(), 1) if nr_tower == 1: trainer = SeparateGANTrainer(data_set, model, g_period=4, d_period=1) else: trainer = MultiGPUGANTrainer(nr_tower, data_set, model) trainer.train_with_defaults( callbacks=[ # PeriodicTrigger(ModelSaver(), every_k_epochs=20), ClipCallback(), ScheduledHyperParamSetter('learning_rate', [(0, 2e-4), (100, 1e-4), (200, 2e-5), (300, 1e-5), (400, 2e-6), (500, 1e-6)], interp='linear'), PeriodicTrigger(VisualizeRunner(), every_k_epochs=5), ], session_init=SaverRestore(args.load) if args.load else None, steps_per_epoch=data_set.size(), max_epoch=300)
def main():
    """Parse command-line options, then sample one instance or train.

    Seeds the NumPy and TensorFlow random generators with 2018, stores the
    parsed options in the module-level ``args``, and then either calls
    ``sample()`` (when ``--sample`` is given) or builds the training and
    validation dataflows and runs the GAN trainer for 500 epochs.
    """
    global args

    np.random.seed(2018)
    tf.set_random_seed(2018)

    # https://docs.python.org/3/library/argparse.html
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load models for continue train or predict')
    parser.add_argument('--sample', help='run sampling one instance')
    parser.add_argument('--imageDir', help='Image directory', required=True)
    parser.add_argument('--maskDir', help='Masks directory', required=False)
    parser.add_argument('--labelDir', help='Label directory', required=True)
    # Debug one particular function in main flow
    parser.add_argument('-db', '--debug', type=int, default=0)
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.sample:
        sample(args.imageDir, args.maskDir, args.labelDir, args.load, args.sample)
        return

    logger.auto_set_dir()

    ds_train, ds_valid = get_data(args.imageDir, args.maskDir, args.labelDir)
    ds_train = PrefetchDataZMQ(ds_train, nr_proc=4)
    ds_valid = PrefetchDataZMQ(ds_valid, nr_proc=4)
    # NOTE(review): both dataflows are reset here and then handed to
    # QueueInput / InferenceRunner -- confirm against the tensorpack version
    # in use that this does not end up resetting them twice.
    ds_train.reset_state()
    ds_valid.reset_state()

    nr_tower = max(get_nr_gpu(), 1)
    ds_train = QueueInput(ds_train)
    model = Model()

    # One tower uses the alternating trainer; several towers use the
    # multi-GPU trainer.
    trainer = (
        SeparateGANTrainer(ds_train, model, g_period=1, d_period=1)
        if nr_tower == 1
        else MultiGPUGANTrainer(nr_tower, ds_train, model)
    )

    # Scalars collected by the inference runner on the validation dataflow.
    validation_scalars = (
        'PSNR_zfill_A',
        'PSNR_zfill_B',
        'PSNR_recon_A',
        'PSNR_recon_B',
        'PSNR_boost_A',
        'PSNR_boost_B',
        'losses/Img/Zfill/zfill_img_MA',
        'losses/Img/Zfill/zfill_img_MB',
        'losses/Frq/Recon/recon_frq_AA',
        'losses/Frq/Recon/recon_frq_BB',
        'losses/Img/Recon/recon_img_AA',
        'losses/Img/Recon/recon_img_BB',
        'losses/Frq/Boost/recon_frq_Aa',
        'losses/Frq/Boost/recon_frq_Bb',
        'losses/Img/Boost/recon_img_Aa',
        'losses/Img/Boost/recon_img_Bb',
    )
    # (epoch, learning rate) anchor points, interpolated linearly between.
    lr_schedule = [
        (0, 2e-4),
        (100, 1e-4),
        (200, 2e-5),
        (300, 1e-5),
        (400, 2e-6),
        (500, 1e-6),
    ]

    trainer.train_with_defaults(
        callbacks=[
            PeriodicTrigger(ModelSaver(), every_k_epochs=20),
            PeriodicTrigger(MaxSaver('validation_PSNR_recon_A'), every_k_epochs=20),
            PeriodicTrigger(MaxSaver('validation_PSNR_boost_A'), every_k_epochs=20),
            VisualizeRunner(),
            InferenceRunner(
                ds_valid,
                [ScalarStats(name) for name in validation_scalars],
            ),
            ClipCallback(),
            ScheduledHyperParamSetter('learning_rate', lr_schedule, interp='linear'),
        ],
        session_init=SaverRestore(args.load) if args.load else None,
        steps_per_epoch=ds_train.size(),
        max_epoch=500,
    )