예제 #1
0
파일: train.py 프로젝트: narutobns/Resnet
def data_batch(recordfile, batch_size):
    """Create a shuffled (image, label) batch pair from a TFRecord source.

    Args:
        recordfile: Forwarded unchanged to ``read_tfrecord``.
        batch_size: Number of examples per dequeued batch.

    Returns:
        The ``(img_batch, label_batch)`` tensors produced by
        ``tf.train.shuffle_batch``.
    """
    # read_tfrecord yields three values; the first one is not used here.
    _, image, target = read_tfrecord(recordfile)

    queue_options = dict(batch_size=batch_size,
                         num_threads=4,
                         capacity=50000,
                         min_after_dequeue=10000)
    images, targets = tf.train.shuffle_batch([image, target], **queue_options)
    return images, targets
    def read(self, filename_queue):
        """Read one record and expand its neighborhoods into a feature map.

        Every entry of the record's ``neighborhood`` tensor is treated as a
        row index into the record's ``nodes`` tensor and replaced by that
        row. Negative indices are replaced by an all-zero row instead
        (presumably they mark missing/padded neighbors -- confirm against
        the writer of the records).

        Args:
            filename_queue: Queue handed to ``read_tfrecord``.

        Returns:
            A ``Record`` built from the feature map (reshaped to
            ``[num_nodes, neighborhood_size, num_node_channels]``), that
            shape list, and the label returned by ``read_tfrecord``.
        """
        data, label = read_tfrecord(
            filename_queue,
            {'nodes': [-1, self._grapher.num_node_channels],
             'neighborhood': [self._num_nodes, self._neighborhood_size]})

        nodes = data['nodes']

        # Convert the neighborhood to a feature map.
        def _map_features(node):
            # Clamp the index so the slice below is valid even for negative
            # node ids; the clamped row is discarded for those ids.
            i = tf.maximum(node, 0)
            positive = tf.strided_slice(nodes, [i], [i+1], [1])
            negative = tf.zeros([1, self._grapher.num_node_channels])

            # The test must use the raw `node`: `i` is already clamped to
            # >= 0, so `i < 0` could never select the zero row.
            return tf.where(node < 0, negative, positive)

        data = tf.reshape(data['neighborhood'], [-1])
        data = tf.cast(data, tf.int32)
        data = tf.map_fn(_map_features, data, dtype=tf.float32)
        shape = [self._num_nodes, self._neighborhood_size,
                 self._grapher.num_node_channels]
        data = tf.reshape(data, shape)

        return Record(data, shape, label)
예제 #3
0
    def build(self, encode):
        """Build the TensorFlow graph for predicting sub-images 2, 3 and 4.

        Hyper-parameters are read from ``self.args``. The training TFRecord
        is created when ``data_exist`` reports it missing, and a cropped
        input queue (``self.input_crop``) is set up regardless of which
        branch is taken afterwards.

        With ``encode`` truthy, the graph starts from the 3-channel
        placeholder ``self.input``, converts it with ``self.rgb2yuv``, splits
        it into four sub-images via ``tf.nn.space_to_depth`` + ``tf.split``
        and builds three ``model_conv`` predictors together with their
        losses and Adam optimizers. Otherwise only the three predictors are
        built, each fed from its own placeholder.

        Args:
            encode: Selects the branch described above.
        """

        # Parameters
        DATA_DIR = self.args.data_dir

        LAYER_NUM = self.args.layer_num
        HIDDEN_UNIT = self.args.hidden_unit

        LAMBDA_CTX = self.args.lambda_ctx
        CHANNEL_NUM = self.args.channel_num
        CHANNEL_TYPE = self.args.channel_type

        LR = self.args.lr

        BATCH_SIZE = self.args.batch_size
        CROP_SIZE = self.args.crop_size

        CHANNEL_EPOCH = self.args.channel_epoch
        JOINT_EPOCH = self.args.joint_epoch

        # TFRecord
        tfrecord_name = 'train.tfrecord'

        # if train tfrecord does not exist, create dataset
        if not data_exist(DATA_DIR, tfrecord_name):
            # DATA_DIR is concatenated directly, so it must end with a path
            # separator for 'train/' to resolve inside it.
            img_list = read_dir(DATA_DIR + 'train/')
            write_tfrecord(DATA_DIR, img_list, tfrecord_name)

        # Epoch budget: three CHANNEL_EPOCH phases plus one JOINT_EPOCH phase
        # (presumably one per-predictor phase each, then joint training --
        # confirm against the training loop).
        self.input_crop, _, _ = read_tfrecord(DATA_DIR,
                                              tfrecord_name,
                                              num_epochs=3 * CHANNEL_EPOCH +
                                              JOINT_EPOCH,
                                              batch_size=BATCH_SIZE,
                                              min_after_dequeue=10,
                                              crop_size=CROP_SIZE)

        if encode:
            self.input = tf.placeholder(tf.int16, (None, None, None, 3))

            input_yuv = self.rgb2yuv(self.input)

            if CHANNEL_NUM == 1:
                # Keep only channel 0, 1 or 2 of the rgb2yuv output, as
                # selected by CHANNEL_TYPE.
                # NOTE(review): any other CHANNEL_TYPE leaves input_img
                # unassigned, so the space_to_depth call below would raise
                # NameError -- confirm that callers validate it.
                if (CHANNEL_TYPE == 0):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 0], axis=3)
                elif (CHANNEL_TYPE == 1):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 1], axis=3)
                elif (CHANNEL_TYPE == 2):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 2], axis=3)
            elif CHANNEL_NUM == 3:
                input_img = input_yuv
            else:
                print("Invalid Channel Num")
                sys.exit(1)

            # Block size 2: every 2x2 spatial block is moved into the
            # channel axis, giving four half-resolution sub-images.
            input_depth = tf.nn.space_to_depth(input_img, 2)

            # Same rearrangement on the full YUV image; only its first
            # quarter of channels is kept, as self.input_0.
            original_img = tf.nn.space_to_depth(input_yuv, 2)

            self.input_0, _, _, _ = tf.split(original_img, 4, axis=3)

            # The four channel groups are bound in the order 1, 4, 3, 2.
            self.input_1, self.input_4, self.input_3, self.input_2 = tf.split(
                input_depth, 4, axis=3)

            # Prediction of 2
            pred_2, ctx_2 = model_conv(self.input_1, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_2')

            # Absolute prediction error for sub-image 2.
            error_pred_2 = abs(tf.subtract(pred_2, self.input_2))

            # Prediction of 3
            concat_1_2 = tf.concat([self.input_1, self.input_2], axis=3)
            pred_3, ctx_3 = model_conv(concat_1_2, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_3')

            # Absolute prediction error for sub-image 3.
            error_pred_3 = abs(tf.subtract(pred_3, self.input_3))

            # Prediction of 4
            concat_1_2_3 = tf.concat(
                [self.input_1, self.input_2, self.input_3], axis=3)
            pred_4, ctx_4 = model_conv(concat_1_2_3, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_4')

            # Absolute prediction error for sub-image 4.

            error_pred_4 = abs(tf.subtract(pred_4, self.input_4))

            # Losses
            # Prediction loss: mean absolute error of each predictor.
            loss_pred_2 = tf.reduce_mean(error_pred_2)
            loss_pred_3 = tf.reduce_mean(error_pred_3)
            loss_pred_4 = tf.reduce_mean(error_pred_4)

            # Context loss: the ctx output is trained to match the absolute
            # prediction error, weighted by LAMBDA_CTX.
            loss_ctx_2 = LAMBDA_CTX * tf.reduce_mean(
                abs(tf.subtract(ctx_2, error_pred_2)))
            loss_ctx_3 = LAMBDA_CTX * tf.reduce_mean(
                abs(tf.subtract(ctx_3, error_pred_3)))
            loss_ctx_4 = LAMBDA_CTX * tf.reduce_mean(
                abs(tf.subtract(ctx_4, error_pred_4)))

            loss_2 = loss_pred_2 + loss_ctx_2
            loss_3 = loss_pred_3 + loss_ctx_3
            loss_4 = loss_pred_4 + loss_ctx_4

            total_loss = loss_2 + loss_3 + loss_4

            # Optimizer
            # Variables are grouped by the name passed to model_conv
            # appearing as a substring of the variable name.
            all_vars = tf.trainable_variables()
            vars_2 = [var for var in all_vars if 'pred_2' in var.name]
            vars_3 = [var for var in all_vars if 'pred_3' in var.name]
            vars_4 = [var for var in all_vars if 'pred_4' in var.name]

            # One Adam optimizer per predictor, plus a joint one over the
            # summed loss and every trainable variable.
            self.optimizer_2 = tf.train.AdamOptimizer(LR).minimize(
                loss_2, var_list=vars_2)
            self.optimizer_3 = tf.train.AdamOptimizer(LR).minimize(
                loss_3, var_list=vars_3)
            self.optimizer_4 = tf.train.AdamOptimizer(LR).minimize(
                loss_4, var_list=vars_4)
            self.optimizer_all = tf.train.AdamOptimizer(LR).minimize(
                total_loss, var_list=all_vars)

            # Variables
            # Expose losses, predictions and context outputs as attributes.
            self.loss_2 = loss_2
            self.loss_3 = loss_3
            self.loss_4 = loss_4
            # Same sum as total_loss, built as a separate graph op.
            self.loss_all = loss_4 + loss_2 + loss_3

            self.loss_pred_2 = loss_pred_2
            self.loss_pred_3 = loss_pred_3
            self.loss_pred_4 = loss_pred_4
            self.loss_pred_all = loss_pred_2 + loss_pred_3 + loss_pred_4

            self.loss_ctx_2 = loss_ctx_2
            self.loss_ctx_3 = loss_ctx_3
            self.loss_ctx_4 = loss_ctx_4
            self.loss_ctx_all = loss_ctx_2 + loss_ctx_3 + loss_ctx_4

            self.pred_2 = pred_2
            self.pred_3 = pred_3
            self.pred_4 = pred_4

            self.ctx_2 = ctx_2
            self.ctx_3 = ctx_3
            self.ctx_4 = ctx_4

        else:
            # The bare string below is disabled code kept by the original
            # author; it is evaluated as a no-op expression statement.
            '''
            self.input = tf.placeholder(tf.uint8, (None, None, None, 3))
            input_img = tf.expand_dims(self.rgb2yuv(self.input)[:,:,:,0],axis=3)
            self.input_1, _, _, _ = tf.split(tf.nn.space_to_depth(input_img, 2), 4, axis=3)
            '''
            # Each input is an int16 placeholder; the attribute is then
            # rebound to the float cast, so self.input_N refers to the cast
            # tensor rather than the placeholder itself.
            self.input_1 = tf.placeholder(tf.int16,
                                          (None, None, None, CHANNEL_NUM))
            self.input_1 = tf.to_float(self.input_1)

            # Prediction of 2
            pred_2, ctx_2 = model_conv(self.input_1, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_2')

            self.pred_2 = pred_2
            self.ctx_2 = ctx_2

            self.input_2 = tf.placeholder(tf.int16,
                                          (None, None, None, CHANNEL_NUM))
            self.input_2 = tf.to_float(self.input_2)

            # Prediction of 3
            concat_1_2 = tf.concat([self.input_1, self.input_2], axis=3)
            pred_3, ctx_3 = model_conv(concat_1_2, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_3')

            self.pred_3 = pred_3
            self.ctx_3 = ctx_3

            self.input_3 = tf.placeholder(tf.int16,
                                          (None, None, None, CHANNEL_NUM))
            self.input_3 = tf.to_float(self.input_3)

            # Prediction of 4
            concat_1_2_3 = tf.concat(
                [self.input_1, self.input_2, self.input_3], axis=3)
            pred_4, ctx_4 = model_conv(concat_1_2_3, LAYER_NUM, HIDDEN_UNIT,
                                       'pred_4')

            self.pred_4 = pred_4
            self.ctx_4 = ctx_4
        # Original images
        '''self.input_1 = input_1
예제 #4
0
    def build(self, encode):
        """Build the per-plane prediction graph for encoding or decoding.

        Hyper-parameters are read from ``self.args``. The training TFRecord
        is created when ``data_exist`` reports it missing, and a cropped
        input queue (``self.input_crop``) is always set up.

        ``self.args.learning_order`` ("1234" or "yuv") fixes the order of
        the 12 planes (4 sub-images x 3 channels) and fills ``self.channel``
        and ``self.origin`` with the channel index and sub-image number of
        each position. A ``model_conv`` predictor named ``model_<k>`` is
        built for every position ``k`` in 1..NUM_PATCHES whose plane does
        not come from sub-image 1; it receives planes ``0..k-1``.

        With ``encode`` truthy the planes are derived from the placeholder
        ``self.input`` and losses plus Adam optimizers are created.
        Otherwise the predictors read from ``self.input_all`` and only
        their outputs are kept.

        Args:
            encode: Selects the branch described above.
        """

        # Parameters
        DATA_DIR = self.args.data_dir

        LAYER_NUM = self.args.layer_num
        HIDDEN_UNIT = self.args.hidden_unit

        LAMBDA_CTX = self.args.lambda_ctx
        CHANNEL_NUM = self.args.channel_num
        CHANNEL_TYPE = self.args.channel_type

        LR = self.args.lr

        BATCH_SIZE = self.args.batch_size
        CROP_SIZE = self.args.crop_size

        CHANNEL_EPOCH = self.args.channel_epoch
        JOINT_EPOCH = self.args.joint_epoch
        # Positions 1..11 of the 12-plane ordering are prediction targets;
        # position 0 is only ever used as input.
        NUM_PATCHES = 11

        # TFRecord
        tfrecord_name = 'train.tfrecord'

        # if train tfrecord does not exist, create dataset
        if not data_exist(DATA_DIR, tfrecord_name):
            # DATA_DIR is concatenated directly, so it must end with a path
            # separator for 'train/' to resolve inside it.
            img_list = read_dir(DATA_DIR + 'train/')
            write_tfrecord(DATA_DIR, img_list, tfrecord_name)

        self.input_crop, _, _ = read_tfrecord(DATA_DIR,
                                              tfrecord_name,
                                              num_epochs=3 * CHANNEL_EPOCH +
                                              JOINT_EPOCH,
                                              batch_size=BATCH_SIZE,
                                              min_after_dequeue=10,
                                              crop_size=CROP_SIZE)

        # Per-position lookup tables: self.channel[p] is the channel index
        # and self.origin[p] the sub-image number of plane p.
        # NOTE(review): any other learning_order leaves self.channel,
        # self.origin and (in the encode branch) `order` undefined, which
        # fails later with AttributeError/NameError -- confirm it is
        # validated upstream.
        self.type = self.args.learning_order
        if self.type == "1234":
            self.channel = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
            self.origin = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
        elif self.type == "yuv":
            self.channel = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
            self.origin = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]

        if encode:
            self.input = tf.placeholder(tf.uint8, (None, None, None, 3))

            input_yuv = self.rgb2yuv(self.input)

            if CHANNEL_NUM == 1:
                # Keep only channel 0, 1 or 2 of the rgb2yuv output.
                # NOTE(review): another CHANNEL_TYPE leaves input_img
                # unassigned. Also, the plane ordering below indexes
                # channels 0-2 of every sub-image, so this single-channel
                # path does not look usable with it -- confirm.
                if (CHANNEL_TYPE == 0):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 0], axis=3)
                elif (CHANNEL_TYPE == 1):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 1], axis=3)
                elif (CHANNEL_TYPE == 2):
                    input_img = tf.expand_dims(input_yuv[:, :, :, 2], axis=3)
            elif CHANNEL_NUM == 3:
                input_img = input_yuv
            else:
                print("Invalid Channel Num")
                sys.exit(1)

            # Block size 2: every 2x2 spatial block is moved into the
            # channel axis, giving four half-resolution sub-images.
            input_depth = tf.nn.space_to_depth(input_img, 2)

            # The four channel groups are bound in the order 1, 4, 3, 2.
            self.input_1, self.input_4, self.input_3, self.input_2 = tf.split(
                input_depth, 4, axis=3)

            if self.type == "1234":
                # Channel-major: all four sub-images for channel 0, then
                # channel 1, then channel 2 (matches self.channel/origin).
                order = tf.concat([
                    tf.expand_dims(self.input_1[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_1[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_1[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 2], axis=3)
                ],
                                  axis=3)
            elif self.type == "yuv":
                # Sub-image-major: the three channels of sub-image 1, then
                # of sub-images 2, 3 and 4 (matches self.channel/origin).
                order = tf.concat([
                    tf.expand_dims(self.input_1[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_1[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_1[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_2[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_3[:, :, :, 2], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 0], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 1], axis=3),
                    tf.expand_dims(self.input_4[:, :, :, 2], axis=3)
                ],
                                  axis=3)
            print("building order completed.\n")

            # Per-model results, appended in build order; positions that
            # are skipped below have no entry, so list index != plane index.
            pred_li = []
            ctx_li = []
            error_pred_li = []
            loss_pred_li = []
            loss_ctx_li = []
            loss_li = []
            total_loss = 0

            for i in range(NUM_PATCHES):
                # Planes taken from sub-image 1 get no predictor.
                if self.origin[i + 1] != 1:
                    # Context is planes 0..i; the target is plane i + 1.
                    pred, ctx = model_conv(order[:, :, :, :(i + 1)], LAYER_NUM,
                                           HIDDEN_UNIT, 'model_' + str(i + 1))
                    error_pred = abs(
                        tf.subtract(
                            pred,
                            tf.expand_dims(order[:, :, :, (i + 1)], axis=3)))
                    # Prediction loss is the mean absolute error; the ctx
                    # output is trained to match that absolute error.
                    loss_pred = tf.reduce_mean(error_pred)
                    loss_ctx = LAMBDA_CTX * tf.reduce_mean(
                        abs(tf.subtract(ctx, error_pred)))
                    pred_li.append(pred)
                    ctx_li.append(ctx)
                    error_pred_li.append(error_pred)
                    loss_pred_li.append(loss_pred)
                    loss_ctx_li.append(loss_ctx)
                    loss_li.append(loss_pred + loss_ctx)
                    total_loss += loss_pred + loss_ctx

            all_vars = tf.trainable_variables()

            optimizer_li = []
            # k indexes loss_li, which only has entries for built models.
            k = 0
            for j in range(NUM_PATCHES):
                if self.origin[j + 1] != 1:
                    # NOTE(review): this is a substring test, so 'model_1'
                    # also matches variables of 'model_10' and 'model_11'.
                    # Presumably harmless because those variables have no
                    # gradient w.r.t. this loss -- confirm.
                    vars = [
                        var for var in all_vars
                        if 'model_' + str(j + 1) in var.name
                    ]
                    optimizer = tf.train.AdamOptimizer(LR).minimize(
                        loss_li[k], var_list=vars)
                    optimizer_li.append(optimizer)
                    k += 1

            self.pred_li = pred_li
            self.ctx_li = ctx_li
            self.error_pred_li = error_pred_li
            self.loss_pred_li = loss_pred_li
            self.loss_ctx_li = loss_ctx_li
            self.loss_li = loss_li
            self.optimizer_li = optimizer_li
            # Joint optimizer over the summed loss of every built model.
            self.optimizer_all = tf.train.AdamOptimizer(LR).minimize(
                total_loss, var_list=all_vars)
            self.order = order
            # Number of predictors actually built.
            self.true_patches = k

        else:
            # The bare string below is disabled code kept by the original
            # author; it is evaluated as a no-op expression statement.
            '''
            self.input = tf.placeholder(tf.uint8, (None, None, None, 3))
            input_img = tf.expand_dims(self.rgb2yuv(self.input)[:,:,:,0],axis=3)
            self.input_1, _, _, _ = tf.split(tf.nn.space_to_depth(input_img, 2), 4, axis=3)
            '''
            # One placeholder carries all NUM_PATCHES + 1 planes; the
            # attribute is then rebound to its float cast.
            self.input_all = tf.placeholder(
                tf.uint8, (None, None, None, NUM_PATCHES + 1))
            self.input_all = tf.to_float(self.input_all)

            pred_li = []
            ctx_li = []

            for i in range(NUM_PATCHES):
                # Same skip rule and model names as in the encode branch.
                if self.origin[i + 1] != 1:
                    pred, ctx = model_conv(self.input_all[:, :, :, :(i + 1)],
                                           LAYER_NUM, HIDDEN_UNIT,
                                           'model_' + str(i + 1))
                    pred_li.append(pred)
                    ctx_li.append(ctx)

            self.pred_li = pred_li
            self.ctx_li = ctx_li
예제 #5
0
                              [adjacency_unweighted])
# Number of feature channels per node, as reported by the grapher.
num_channels = grapher.num_node_channels
# Dataset wrapper; only its `train_filenames` attribute is used below.
patchy = PatchySan(pascal,
                   grapher,
                   data_dir='/tmp/patchy_san_slic_pascal_voc_data',
                   num_nodes=400,
                   node_stride=1,
                   neighborhood_size=1)

# Single, unshuffled pass over the training files.
filename_queue = tf.train.string_input_producer(patchy.train_filenames,
                                                num_epochs=1,
                                                shuffle=False)

# Load the node features. We are not interested in the labels.
# The 'neighborhood' shape mirrors num_nodes=400 / neighborhood_size=1 above.
data, _ = read_tfrecord(filename_queue, {
    'nodes': [-1, num_channels],
    'neighborhood': [400, 1]
})
data = data['nodes']

# The data queue.
# dynamic_pad lets examples with different numbers of nodes share a batch;
# the final batch may be smaller than 128.
data_batch = tf.train.batch([data],
                            batch_size=128,
                            num_threads=16,
                            capacity=300,
                            dynamic_pad=True,
                            allow_smaller_final_batch=True)
# Collapse the batch axis so every row is one node's feature vector.
data_batch = tf.reshape(data_batch, [-1, num_channels])

sess = tf.Session()
# Local variables are initialized too; presumably needed for the epoch
# counter that num_epochs=1 creates -- see the string_input_producer docs.
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
예제 #6
0
    def build(self):
        """Build the training graph with one predictor per colour plane.

        The placeholder ``self.input`` is converted with ``self.rgb2yuv``,
        split into four sub-images with ``tf.nn.space_to_depth`` and
        re-stacked into a 12-plane tensor whose plane order is chosen by
        ``self.args.learning_order`` ("1234" or "yuv"). Predictor
        ``model_<k>`` (k = 1..NUM_PATCHES) receives planes ``0..k-1`` and
        predicts plane ``k``. Each predictor gets its own Adam optimizer
        (``self.optimizer_li``) and ``self.optimizer_all`` trains all of
        them on the summed loss.

        Exits the process with status 1 on an unsupported channel count or
        learning order.
        """

        # Parameters
        DATA_DIR = self.args.data_dir

        LAYER_NUM = self.args.layer_num
        HIDDEN_UNIT = self.args.hidden_unit

        LAMBDA_CTX = self.args.lambda_ctx
        CHANNEL_NUM = self.args.channel_num

        LR = self.args.lr

        BATCH_SIZE = self.args.batch_size
        CROP_SIZE = self.args.crop_size

        CHANNEL_EPOCH = self.args.channel_epoch
        JOINT_EPOCH = self.args.joint_epoch
        # Planes 1..11 of the 12-plane ordering are prediction targets.
        NUM_PATCHES = 11

        # TFRecord
        tfrecord_name = 'train.tfrecord'

        # if train tfrecord does not exist, create dataset
        if not data_exist(DATA_DIR, tfrecord_name):
            img_list = read_dir(DATA_DIR + 'train/')
            write_tfrecord(DATA_DIR, img_list, tfrecord_name)

        self.input_crop, _, _ = read_tfrecord(DATA_DIR,
                                              tfrecord_name,
                                              num_epochs=3 * CHANNEL_EPOCH +
                                              JOINT_EPOCH,
                                              batch_size=BATCH_SIZE,
                                              min_after_dequeue=10,
                                              crop_size=CROP_SIZE)

        self.input = tf.placeholder(tf.uint8, (None, None, None, 3))

        input_yuv = self.rgb2yuv(self.input)

        # NOTE(review): the plane ordering below indexes channels 0-2 of
        # every sub-image, so CHANNEL_NUM == 1 does not look usable with it
        # -- confirm before relying on the single-channel path.
        if CHANNEL_NUM == 1:
            input_img = tf.expand_dims(input_yuv[:, :, :, 0], axis=3)
        elif CHANNEL_NUM == 3:
            input_img = input_yuv
        else:
            print("Invalid Channel Num")
            sys.exit(1)

        # Block size 2: every 2x2 spatial block is moved into the channel
        # axis; the four channel groups are bound in the order 1, 4, 3, 2.
        input_depth = tf.nn.space_to_depth(input_img, 2)
        input_1, input_4, input_3, input_2 = tf.split(input_depth, 4, axis=3)

        sub_images = (input_1, input_2, input_3, input_4)
        learning_order = self.args.learning_order
        if learning_order == "1234":
            # Channel-major: all four sub-images for channel 0, then 1, 2.
            picks = [(img, ch) for ch in range(3) for img in sub_images]
        elif learning_order == "yuv":
            # Sub-image-major: the three channels of sub-image 1, then 2-4.
            picks = [(img, ch) for img in sub_images for ch in range(3)]
        else:
            # Fail explicitly instead of hitting a NameError on `order`.
            print("Invalid Learning Order")
            sys.exit(1)

        order = tf.concat(
            [tf.expand_dims(img[:, :, :, ch], axis=3) for img, ch in picks],
            axis=3)
        print("building order completed.\n")

        pred_li = []
        ctx_li = []
        error_pred_li = []
        loss_pred_li = []
        loss_ctx_li = []
        loss_li = []
        total_loss = 0

        for i in range(NUM_PATCHES):
            # Context is every plane before the target, i.e. planes 0..i.
            # The slice must end at i + 1: ending at i would hide plane i
            # from the model that predicts plane i + 1, and would be empty
            # for i == 0.
            pred, ctx = model_conv(order[:, :, :, :(i + 1)], LAYER_NUM,
                                   HIDDEN_UNIT, 'model_' + str(i + 1))
            target = tf.expand_dims(order[:, :, :, (i + 1)], axis=3)
            error_pred = abs(tf.subtract(pred, target))
            # Prediction loss is the mean absolute error; the ctx output is
            # trained to match that absolute error, weighted by LAMBDA_CTX.
            loss_pred = tf.reduce_mean(error_pred)
            loss_ctx = LAMBDA_CTX * tf.reduce_mean(
                abs(tf.subtract(ctx, error_pred)))
            pred_li.append(pred)
            ctx_li.append(ctx)
            error_pred_li.append(error_pred)
            loss_pred_li.append(loss_pred)
            loss_ctx_li.append(loss_ctx)
            loss_li.append(loss_pred + loss_ctx)
            total_loss += loss_pred + loss_ctx

        all_vars = tf.trainable_variables()

        optimizer_li = []
        for j in range(NUM_PATCHES):
            # NOTE(review): substring test, so 'model_1' also matches the
            # variables of 'model_10' and 'model_11'. Presumably harmless
            # because those have no gradient w.r.t. this loss -- confirm.
            model_vars = [
                var for var in all_vars if 'model_' + str(j + 1) in var.name
            ]
            optimizer = tf.train.AdamOptimizer(LR).minimize(
                loss_li[j], var_list=model_vars)
            optimizer_li.append(optimizer)

        self.pred_li = pred_li
        self.ctx_li = ctx_li
        self.error_pred_li = error_pred_li
        self.loss_pred_li = loss_pred_li
        self.loss_ctx_li = loss_ctx_li
        self.loss_li = loss_li
        self.optimizer_li = optimizer_li
        # Joint optimizer over the summed loss of all predictors.
        self.optimizer_all = tf.train.AdamOptimizer(LR).minimize(
            total_loss, var_list=all_vars)

        # Original images
        self.input_1 = input_1
        self.input_2 = input_2
        self.input_3 = input_3
        self.input_4 = input_4
    def build(self):
        """Build the chained Y -> U -> V prediction graph and its optimizers.

        Crops read from the training TFRecord are turned into
        ``(input_data, label)`` by ``self.crop_to_data``; label columns 0, 1
        and 2 are used as the Y, U and V targets. Three ``model`` networks
        are chained: the U network additionally receives the Y network's
        hidden output, the Y target and the Y prediction, and the V network
        receives the corresponding U values as well. Column 0 of each
        network output is the prediction and column 1 (after ReLU) is the
        context value, which is trained to match the absolute prediction
        error.

        Per-channel and joint losses, optimizers and context tensors are
        stored as attributes on ``self``.
        """

        # Parameters
        DATA_DIR = self.args.data_dir

        LAYER_NUM = self.args.layer_num
        HIDDEN_UNIT = self.args.hidden_unit
        LAMBDA_CTX = self.args.lambda_ctx
        LAMBDA_Y = self.args.lambda_y
        LAMBDA_U = self.args.lambda_u
        LAMBDA_V = self.args.lambda_v
        LR = self.args.lr

        # NOTE(review): BATCH_SIZE is read but the queue below uses a fixed
        # batch_size=4 -- confirm whether that is intentional.
        BATCH_SIZE = self.args.batch_size
        CROP_SIZE = self.args.crop_size

        CHANNEL_EPOCH = self.args.channel_epoch
        JOINT_EPOCH = self.args.joint_epoch

        tfrecord_name = 'train.tfrecord'

        # Create the training TFRecord on first use.
        if not data_exist(DATA_DIR, tfrecord_name):
            img_list = read_dir(DATA_DIR + 'train/')
            write_tfrecord(DATA_DIR, img_list, tfrecord_name)

        input_crop, _, _ = read_tfrecord(DATA_DIR, tfrecord_name,
                                         num_epochs=3*CHANNEL_EPOCH+JOINT_EPOCH,
                                         batch_size=4, min_after_dequeue=10,
                                         crop_size=CROP_SIZE)

        input_data, label = self.crop_to_data(input_crop)

        # Ground-truth columns, each kept 2-D so they can be concatenated.
        y_gt = tf.slice(label, [0, 0], [-1, 1])
        u_gt = tf.slice(label, [0, 1], [-1, 1])
        v_gt = tf.slice(label, [0, 2], [-1, 1])

        out_y, hidden_y = model(input_data, LAYER_NUM, HIDDEN_UNIT, 'pred_y')

        # U network input: Y hidden output, raw input, Y target, Y prediction.
        input_f2 = tf.concat(
            [hidden_y, input_data, y_gt, tf.expand_dims(out_y[:, 0], axis=1)],
            axis=1)

        out_u, hidden_u = model(input_f2, LAYER_NUM, HIDDEN_UNIT, 'pred_u')

        # V network input: as above, plus the U target and U prediction.
        input_f3 = tf.concat(
            [hidden_u, input_data, y_gt, tf.expand_dims(out_y[:, 0], axis=1),
             u_gt, tf.expand_dims(out_u[:, 0], axis=1)],
            axis=1)

        out_v, _ = model(input_f3, LAYER_NUM, HIDDEN_UNIT, 'pred_v')

        # Column 0 is the prediction, column 1 the (non-negative) context.
        pred_y = out_y[:, 0]
        pred_u = out_u[:, 0]
        pred_v = out_v[:, 0]
        ctx_y = tf.nn.relu(out_y[:, 1])
        ctx_u = tf.nn.relu(out_u[:, 1])
        ctx_v = tf.nn.relu(out_v[:, 1])

        predError_y = abs(tf.subtract(pred_y, tf.squeeze(y_gt, axis=1)))
        predError_u = abs(tf.subtract(pred_u, tf.squeeze(u_gt, axis=1)))
        predError_v = abs(tf.subtract(pred_v, tf.squeeze(v_gt, axis=1)))

        loss_pred_y = LAMBDA_Y * tf.reduce_mean(predError_y)
        loss_pred_u = LAMBDA_U * tf.reduce_mean(predError_u)
        loss_pred_v = LAMBDA_V * tf.reduce_mean(predError_v)

        # The context output is trained to match the absolute error.
        loss_ctx_y = LAMBDA_Y * LAMBDA_CTX * tf.reduce_mean(
            abs(tf.subtract(ctx_y, predError_y)))
        loss_ctx_u = LAMBDA_U * LAMBDA_CTX * tf.reduce_mean(
            abs(tf.subtract(ctx_u, predError_u)))
        loss_ctx_v = LAMBDA_V * LAMBDA_CTX * tf.reduce_mean(
            abs(tf.subtract(ctx_v, predError_v)))

        loss_y = loss_pred_y + loss_ctx_y
        loss_u = loss_pred_u + loss_ctx_u
        loss_v = loss_pred_v + loss_ctx_v

        loss_yuv = loss_y + loss_u + loss_v

        # Variables are grouped by the name passed to model() appearing as
        # a substring of the variable name.
        t_vars = tf.trainable_variables()
        y_vars = [var for var in t_vars if 'pred_y' in var.name]
        u_vars = [var for var in t_vars if 'pred_u' in var.name]
        v_vars = [var for var in t_vars if 'pred_v' in var.name]

        self.optimizer_y = tf.train.AdamOptimizer(LR).minimize(
            loss_y, var_list=y_vars)
        self.optimizer_u = tf.train.AdamOptimizer(LR).minimize(
            loss_u, var_list=u_vars)
        self.optimizer_v = tf.train.AdamOptimizer(LR).minimize(
            loss_v, var_list=v_vars)
        self.optimizer_yuv = tf.train.AdamOptimizer(LR).minimize(
            loss_yuv, var_list=t_vars)

        # Variables
        self.loss_y = loss_y
        self.loss_u = loss_u
        self.loss_v = loss_v
        self.loss_yuv = loss_yuv
        self.loss_pred_y = loss_pred_y
        self.loss_pred_u = loss_pred_u
        self.loss_pred_v = loss_pred_v
        # Sum of all three prediction losses: Y must be included and V
        # counted once, matching how loss_ctx_yuv is formed below.
        self.loss_pred_yuv = loss_pred_y + loss_pred_u + loss_pred_v
        self.loss_ctx_y = loss_ctx_y
        self.loss_ctx_u = loss_ctx_u
        self.loss_ctx_v = loss_ctx_v
        self.loss_ctx_yuv = loss_ctx_y + loss_ctx_u + loss_ctx_v
        self.ctx_y = ctx_y
        self.ctx_u = ctx_u
        self.ctx_v = ctx_v
예제 #8
0
import os
import tqdm
import pandas as pd
import tensorflow as tf
from data import read_tfrecord
from train import accuracy, num_blocks, classes, batch_size, one_hot_label
from resnet import resnet_backbone

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Evaluation script: builds the test input pipeline and the ResNet graph,
# then opens a session and starts the queue runners.
if __name__ == '__main__':
    # Error counter; not updated within the lines visible here.
    err = 0
    num_class = len(classes)
    # Single, unshuffled pass over the test records.
    img_name, img_data, label = read_tfrecord('./test.record',
                                              num_epochs=1,
                                              shuffle=False)
    # Batch size 1: one example per step.
    img_name, img_batch, label_batch = tf.train.batch(
        [img_name, img_data, label], 1)
    logits, prob = resnet_backbone(img_batch,
                                   num_blocks,
                                   num_class,
                                   training=False)
    pred_class = tf.argmax(prob, axis=1)
    # The last argument mirrors the batch size of 1 used above (presumably
    # one_hot_label expects the batch size there -- confirm).
    gt_class, _ = one_hot_label(label_batch, classes, 1)
    acc = accuracy(pred_class, gt_class)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Only local variables are initialized here; no saver.restore call
        # is visible in this excerpt, so model weights are presumably
        # restored from a checkpoint further down -- confirm.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)