class FLAGS:
    """Namespace of hard-coded configuration constants used by the script below."""

    class TRAIN:
        # Batch size passed to DataGen; also reused as the first entry of
        # SAMPLER_TARGET_SHAPE (plain class-body name lookup).
        BATCH_SIZE = 32
        # Ratio passed to DownSampler (name suggests a 4x reduction -- TODO confirm
        # against DownSampler).
        DOWN_SAMPLING_RATIO = 4
        # Target shape handed to AlignSampler. Presumably [batch, height, width,
        # channels] -- TODO confirm against AlignSampler.
        SAMPLER_TARGET_SHAPE = [BATCH_SIZE, 32, 32, 1]

    class SUMMARY:
        # Presumably the TensorBoard log directory prefix (path contains
        # 'tensorboard_log'). NOTE(review): hard-coded absolute path on one
        # user's machine; consider making it configurable.
        SUMMARY_DIR = '/home/qinglong/node3share/remote_drssrn/tensorboard_log/4x_down_'

# Build the input pipeline: DataGen exposes `phantom`, which is fed to
# DownSampler; AlignSampler receives the DownSampler output plus the original
# `phantom` and produces a (low, high) pair.
# NOTE(review): `file`, DataGen, DownSampler, AlignSampler, shape_as_list, tf
# and rs are not defined in this chunk -- presumably imported/defined elsewhere.
d = DataGen(file, FLAGS.TRAIN.BATCH_SIZE)
ds = DownSampler(d.phantom, FLAGS.TRAIN.DOWN_SAMPLING_RATIO)
aspr = AlignSampler(ds(), d.phantom, FLAGS.TRAIN.SAMPLER_TARGET_SHAPE)
train_low, train_high = aspr()

# Capture the shapes before the py_func step below; the reshape calls use them
# to restore the shape (presumably because tf.py_func outputs carry no static
# shape information).
train_high_shape = shape_as_list(train_high)
train_low_shape = shape_as_list(train_low)

# Apply per-image standardization to each element along axis 0 (tf.map_fn).
train_high = tf.map_fn(tf.image.per_image_standardization, train_high)
train_low = tf.map_fn(tf.image.per_image_standardization, train_low)

# Run the Python-side `rs` transform (defined elsewhere; semantics not visible
# here) through tf.py_func with a float32 output, then reshape back to the
# shapes recorded above.
train_high = tf.reshape(tf.py_func(rs, [train_high], tf.float32), train_high_shape)
train_low = tf.reshape(tf.py_func(rs, [train_low], tf.float32), train_low_shape)

# Disabled device/session configuration, kept for reference:
# with tf.device('/device:GPU:1'):
#     config = tf.ConfigProto()
#     config.gpu_options.allow_growth = True
#     config.log_device_placement = True
#
def input_dim(self):
    """Return the number of entries ``shape_as_list`` reports for ``self.input_``.

    Presumably this is the rank of ``self.input_`` (one entry per dimension);
    that depends on ``shape_as_list``, which is defined elsewhere.
    """
    input_shape_list = shape_as_list(self.input_)
    return len(input_shape_list)
def high_shape(self):
    """Return ``self.high``'s shape as produced by ``shape_as_list``."""
    high_dims = shape_as_list(self.high)
    return high_dims
def input_shape(self):
    """Return ``self.input_``'s shape as produced by ``shape_as_list``."""
    input_dims = shape_as_list(self.input_)
    return input_dims
def low_shape(self):
    """Return ``self.low``'s shape as produced by ``shape_as_list``."""
    low_dims = shape_as_list(self.low)
    return low_dims