コード例 #1
0
class FLAGS:
    """Hard-coded run configuration, grouped into per-concern namespaces."""

    class TRAIN:
        """Settings consumed while building the training input pipeline."""

        # Samples per batch; reused below as the leading dimension of the
        # sampler target shape.
        BATCH_SIZE = 32
        # Ratio handed to DownSampler.
        DOWN_SAMPLING_RATIO = 4
        # Target shape handed to AlignSampler: batch size followed by
        # 32 x 32 x 1 (presumably height, width, channels -- confirm).
        SAMPLER_TARGET_SHAPE = [BATCH_SIZE] + [32, 32, 1]

    class SUMMARY:
        """Settings for TensorBoard summaries."""

        # NOTE(review): the '4x_down_' suffix is hard-coded and is not tied to
        # TRAIN.DOWN_SAMPLING_RATIO; update both together.
        SUMMARY_DIR = '/home/qinglong/node3share/remote_drssrn/tensorboard_log/4x_down_'


# Build the training input tensors: produce a (low, high) pair of batches,
# standardize every element of each batch, then post-process both with ``rs``.
# NOTE(review): ``file``, DataGen, DownSampler, AlignSampler, shape_as_list,
# ``rs`` and ``tf`` all come from outside this snippet. ``file`` also shadows
# a Python 2 builtin.
d = DataGen(file, FLAGS.TRAIN.BATCH_SIZE)
# The down-sampler is built from the generator's ``phantom`` and the configured ratio.
ds = DownSampler(d.phantom, FLAGS.TRAIN.DOWN_SAMPLING_RATIO)
# Pair the down-sampler's output with the original ``phantom``; presumably
# AlignSampler samples both to SAMPLER_TARGET_SHAPE -- confirm.
aspr = AlignSampler(ds(), d.phantom, FLAGS.TRAIN.SAMPLER_TARGET_SHAPE)
train_low, train_high = aspr()

# Capture the static shapes before the tf.py_func calls below. tf.py_func does
# not propagate static shape information, so these shapes are used to
# restore it with tf.reshape.
train_high_shape = shape_as_list(train_high)
train_low_shape = shape_as_list(train_low)

# tf.map_fn applies per-image standardization to each element along axis 0
# (one image per element, assuming a batch-first layout -- confirm).
train_high = tf.map_fn(tf.image.per_image_standardization, train_high)
train_low = tf.map_fn(tf.image.per_image_standardization, train_low)

# Run the Python function ``rs`` on each batch; its output is declared float32.
# NOTE(review): the reshape assumes ``rs`` preserves the element count of its
# input -- verify against the definition of ``rs``.
train_high = tf.reshape(tf.py_func(rs, [train_high], tf.float32),
                        train_high_shape)
train_low = tf.reshape(tf.py_func(rs, [train_low], tf.float32),
                       train_low_shape)

# with tf.device('/device:GPU:1'):
#     config = tf.ConfigProto()
#     config.gpu_options.allow_growth = True
#     config.log_device_placement = True
#
コード例 #2
0
 def input_dim(self):
     """Return the number of dimensions reported by ``shape_as_list(self.input_)``."""
     dims = shape_as_list(self.input_)
     return len(dims)
コード例 #3
0
 def high_shape(self):
     """Return the shape of ``self.high`` as produced by ``shape_as_list``."""
     shape = shape_as_list(self.high)
     return shape
コード例 #4
0
 def input_shape(self):
     """Return the shape of ``self.input_`` as produced by ``shape_as_list``."""
     shape = shape_as_list(self.input_)
     return shape
コード例 #5
0
 def low_shape(self):
     """Return the shape of ``self.low`` as produced by ``shape_as_list``."""
     shape = shape_as_list(self.low)
     return shape