# model load & save configs
flags.DEFINE_string('summaries_dir', 'volume/TF_Logs/TensorflowModelZoo/resnet50_deform/', 'where to store summary log')
flags.DEFINE_string(
    'pretrained_ckpts',
    '/home/chenyifeng/TF_Models/ptrain/ILSVRC/ResNet_50_V1_GN/model.ckpt',
    'where to load pretrained model')
flags.DEFINE_string('last_ckpt', None, 'where to load last saved model')

FLAGS = flags.FLAGS

# config devices
store_device = parse_device_name(FLAGS.store_device)
run_device = parse_device_name(FLAGS.run_device)
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)

# A run_device string mentioning CPU (in any letter case) selects CPU mode;
# otherwise it is treated as a comma-separated list of GPU ids. Matching
# case-insensitively keeps e.g. 'cpu' from being exported as a bogus
# CUDA_VISIBLE_DEVICES value and reported as a one-GPU deployment.
if 'CPU' not in FLAGS.run_device.upper():
    GPU_NUMS = len(FLAGS.run_device.split(','))
    print(
        '====================================Deploying Model on {} GPUs===================================='
        .format(GPU_NUMS))
    # run_device is already a str, so it can be exported directly
    # (''.join on a str would just return the same string).
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.run_device
    config.gpu_options.allow_growth = FLAGS.allow_growth
else:
    GPU_NUMS = 1
    print(
        '====================================Deploying Model on CPU===================================='
    )
'/home/chenyifeng/TF_Models/atrain/SEGS/fcn/mgpu', 'where to load last saved model') tf.app.flags.DEFINE_string('next_ckpt', '/home/chenyifeng/TF_Models/atrain/SEGS/fcn/mgpu', 'where to store current model') tf.app.flags.DEFINE_integer('save_per_step', 1000, 'save model per xxx steps') FLAGS = tf.app.flags.FLAGS if (FLAGS.reshape_height is None or FLAGS.reshape_weight is None) and FLAGS.batch_size != 1: assert 0, 'Can' 't Stack Images Of Different Shapes, Please Speicify Reshape Size!' store_device = parse_device_name(FLAGS.store_device) run_device = parse_device_name(FLAGS.run_device) config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) if FLAGS.run_device in '01234567': print('Deploying Model on {} GPU Card'.format(''.join(FLAGS.run_device))) # os.environ['CUDA_VISIBLE_DEVICES'] = ''.join(FLAGS.run_device) config.gpu_options.allow_growth = FLAGS.allow_growth config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_fraction else: print('Deploying Model on CPU') weight_reg = regularizer(mode=FLAGS.weight_reg_func, scale=FLAGS.weight_reg_scale) bias_reg = regularizer(mode=FLAGS.bias_reg_func, scale=FLAGS.bias_reg_scale) net = get_net(FLAGS.net_name)