Example #1
    def make_network(self, is_train):
        if is_train:
            image = tf.placeholder(tf.float32,
                                   shape=[cfg.batch_size, *cfg.data_shape, 3])
            label15 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label11 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label9 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label7 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            valids = tf.placeholder(tf.float32,
                                    shape=[cfg.batch_size, cfg.nr_skeleton])
            labels = [label15, label11, label9, label7]
            # labels.reverse()  # the original labels are reversed; kept unchanged to reproduce the pre-trained model
            self.set_inputs(image, label15, label11, label9, label7, valids)
        else:
            image = tf.placeholder(tf.float32,
                                   shape=[None, *cfg.data_shape, 3])
            self.set_inputs(image)

        resnet_fms = resnet50(image, is_train, bn_trainable=True)
        out = create_deconv_net(resnet_fms[3], is_train)

        def ohkm(loss, top_k):
            ohkm_loss = 0.
            for i in range(cfg.batch_size):
                sub_loss = loss[i]
                topk_val, topk_idx = tf.nn.top_k(sub_loss,
                                                 k=top_k,
                                                 sorted=False,
                                                 name='ohkm{}'.format(i))
                tmp_loss = tf.gather(
                    sub_loss, topk_idx,
                    name='ohkm_loss{}'.format(i))  # keep only the top_k hardest joints (topk_val holds the same values)
                ohkm_loss += tf.reduce_sum(tmp_loss) / top_k
            ohkm_loss /= cfg.batch_size
            return ohkm_loss

        # make loss
        if is_train:
            # out and label7 have shape (batch_size, *output_shape, nr_skeleton), e.g. (24, 96, 72, 17)
            # per-sample, per-joint MSE over the heatmap, masked by joint validity
            total_loss = tf.reduce_mean(tf.square(out - label7),
                                        (1, 2)) * tf.to_float(
                                            tf.greater(valids, 0.1))
            total_loss_value = tf.reduce_sum(total_loss) / cfg.batch_size
            self.add_tower_summary('loss', total_loss_value)
            self.set_loss(total_loss_value)
        else:
            self.set_outputs(out)
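
The loss in this variant is just the masked per-joint MSE (the `ohkm` helper and the `labels` list are defined but not used here). The mean squared error over each heatmap is multiplied by a 0/1 mask built from `valids`, so joints without annotations contribute nothing. A minimal NumPy sketch of that computation, with illustrative shapes in place of the real `cfg` values:

import numpy as np

batch_size, out_h, out_w, nr_skeleton = 4, 96, 72, 17  # illustrative, not the real cfg values

out = np.random.rand(batch_size, out_h, out_w, nr_skeleton).astype(np.float32)      # network output
label7 = np.random.rand(batch_size, out_h, out_w, nr_skeleton).astype(np.float32)   # target heatmaps
valids = np.random.randint(0, 2, size=(batch_size, nr_skeleton)).astype(np.float32) # per-joint flags

# per-sample, per-joint MSE over the spatial axes -> shape (batch_size, nr_skeleton)
per_joint = np.mean((out - label7) ** 2, axis=(1, 2))

# zero out joints whose validity flag is <= 0.1, as tf.greater(valids, 0.1) does above
masked = per_joint * (valids > 0.1).astype(np.float32)

# same normalisation as the graph: sum everything, divide by the batch size
total_loss_value = masked.sum() / batch_size
print(total_loss_value)
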
Example #2
    def make_network(self, is_train):
        if is_train:
            image = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.data_shape, 3])
            label15 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label11 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label9 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label7 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            valids = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.nr_skeleton])
            labels = [label15, label11, label9, label7]
            self.set_inputs(image, label15, label11, label9, label7, valids)
        else:
            image = tf.placeholder(tf.float32, shape=[None, *cfg.data_shape, 3])
            self.set_inputs(image)

        resnet_fms = resnet50(image, is_train, bn_trainable=True)
        global_fms, global_outs = create_global_net(resnet_fms, is_train)
        refine_out = create_refine_net(global_fms, is_train)

        # make loss
        if is_train:
            def ohkm(loss, top_k):
                ohkm_loss = 0.
                for i in range(cfg.batch_size):
                    sub_loss = loss[i]
                    topk_val, topk_idx = tf.nn.top_k(sub_loss, k=top_k, sorted=False, name='ohkm{}'.format(i))
                    tmp_loss = tf.gather(sub_loss, topk_idx, name='ohkm_loss{}'.format(i))  # keep only the top_k hardest joints (topk_val holds the same values)
                    ohkm_loss += tf.reduce_sum(tmp_loss) / top_k
                ohkm_loss /= cfg.batch_size
                return ohkm_loss

            global_loss = 0.
            for i, (global_out, label) in enumerate(zip(global_outs, labels)):
                global_label = label * tf.to_float(tf.greater(tf.reshape(valids, (-1, 1, 1, cfg.nr_skeleton)), 1.1))
                global_loss += tf.reduce_mean(tf.square(global_out - global_label)) / len(labels)
            global_loss /= 2.
            self.add_tower_summary('global_loss', global_loss)
            refine_loss = tf.reduce_mean(tf.square(refine_out - label7), (1, 2)) * tf.to_float(tf.greater(valids, 0.1))
            refine_loss = ohkm(refine_loss, 8)
            self.add_tower_summary('refine_loss', refine_loss)

            total_loss = refine_loss + global_loss
            self.add_tower_summary('loss', total_loss)
            self.set_loss(total_loss)
        else:
            self.set_outputs(refine_out)
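
The difference from the first example is that the refine loss goes through `ohkm`, an online hard keypoint mining step: for each sample only the `top_k` largest per-joint losses are kept and averaged, so joints that are already predicted well stop dominating the gradient. A NumPy sketch of the same selection, with made-up shapes:

import numpy as np

def ohkm_np(per_joint_loss, top_k):
    # per_joint_loss: (batch_size, nr_skeleton) array; mirrors the tf.nn.top_k / tf.gather loop above
    batch_size = per_joint_loss.shape[0]
    ohkm_loss = 0.0
    for i in range(batch_size):
        sub_loss = per_joint_loss[i]
        # indices of the top_k hardest joints for this sample (unsorted, like sorted=False)
        topk_idx = np.argpartition(sub_loss, -top_k)[-top_k:]
        ohkm_loss += sub_loss[topk_idx].sum() / top_k
    return ohkm_loss / batch_size

loss = np.random.rand(4, 17).astype(np.float32)  # illustrative batch of 4 samples, 17 joints
print(ohkm_np(loss, top_k=8))                    # corresponds to ohkm(refine_loss, 8) above
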
Example #3
    def make_network(self, is_train):
        if is_train:
            image = tf.placeholder(tf.float32,
                                   shape=[cfg.batch_size, *cfg.data_shape, 3])
            label15 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label11 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label9 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label7 = tf.placeholder(
                tf.float32,
                shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            valids = tf.placeholder(tf.float32,
                                    shape=[cfg.batch_size, cfg.nr_skeleton])
            labels = [label15, label11, label9, label7]
            # labels.reverse()  # the original labels are reversed; kept unchanged to reproduce the pre-trained model
            self.set_inputs(image, label15, label11, label9, label7, valids)
        else:
            image = tf.placeholder(tf.float32,
                                   shape=[None, *cfg.data_shape, 3])
            self.set_inputs(image)

        #with tf.name_scope('image'):
        #    image_shaped_input = tf.reshape(image, [-1, 256, 192, 3])
        #   tf.summary.image('input',image_shaped_input,4)

        # image: (32, 256, 192, 3)
        # resnet_fms: [(32, 64, 48, 256), (32, 32, 24, 512), (32, 16, 12, 1024), (32, 8, 6, 2048)]
        resnet_fms = resnet50(image, is_train, bn_trainable=True)
        # global_fms: [(32, 64, 48, 256), (32, 32, 24, 256), (32, 16, 12, 256), (32, 8, 6, 256)]
        # global_outs: four tensors of shape (32, 64, 48, 17)
        global_fms, global_outs = create_global_net(resnet_fms, is_train)

        # coarse net inserted between the global and refine stages
        coarse_fms, coarse_outs = create_coarse_net(global_fms, is_train)
        refine_out = create_refine_net(coarse_fms, is_train)  # refine_out: (32, 64, 48, 17)
        #with tf.name_scope('refine_out'):
        #    refine_out_reshape = tf.reshape(refine_out,[-1,64,48,1])
        #   tf.summary.image('heatmap',refine_out_reshape,17)
        # make loss
        if is_train:

            def ohkm(loss, top_k):
                ohkm_loss = 0.
                for i in range(cfg.batch_size):
                    sub_loss = loss[i]
                    topk_val, topk_idx = tf.nn.top_k(sub_loss,
                                                     k=top_k,
                                                     sorted=False,
                                                     name='ohkm{}'.format(i))
                    tmp_loss = tf.gather(
                        sub_loss, topk_idx,
                        name='ohkm_loss{}'.format(i))  # keep only the top_k hardest joints (topk_val holds the same values)
                    ohkm_loss += tf.reduce_sum(tmp_loss) / top_k
                ohkm_loss /= cfg.batch_size
                return ohkm_loss

            global_loss = 0.
            for i, (global_out, label) in enumerate(zip(global_outs, labels)):
                global_label = label * tf.to_float(
                    tf.greater(tf.reshape(valids,
                                          (-1, 1, 1, cfg.nr_skeleton)), 1.1))
                global_loss += tf.reduce_mean(
                    tf.square(global_out - global_label)) / len(labels)
            global_loss /= 2.
            self.add_tower_summary('global_loss', global_loss)
            tf.summary.scalar('global_loss', global_loss)
            '''
            #cdx 2018.10.18
            coarse_loss = 0.
            for i,(coarse_out,label) in enumerate(zip(coarse_outs,labels)):
                coarse_label = label*tf.to_float(tf.greater(tf.reshape(valids,(-1,1,1,cfg.nr_skeleton)),1.1))
                coarse_loss += tf.reduce_mean(tf.square(coarse_out - coarse_label)) / len(labels)
            self.add_tower_summary('coarse_loss',coarse_loss)
            tf.summary.scalar('coarse_loss',coarse_loss)
            #coarse_loss = ohkm(coarse_loss,12)
            #
            '''
            coarse_loss = tf.reduce_mean(tf.square(coarse_outs - label7),
                                         (1, 2)) * tf.to_float(
                                             tf.greater(valids, 0.1))
            coarse_loss = ohkm(coarse_loss, 12)
            self.add_tower_summary('coarse_loss', coarse_loss)
            tf.summary.scalar('coarse_loss', coarse_loss)

            refine_loss = tf.reduce_mean(tf.square(refine_out - label7),
                                         (1, 2)) * tf.to_float(
                                             tf.greater(valids, 0.1))
            refine_loss = ohkm(refine_loss, 8)
            self.add_tower_summary('refine_loss', refine_loss)
            tf.summary.scalar('refine_loss', refine_loss)
            total_loss = refine_loss + global_loss + coarse_loss
            self.add_tower_summary('loss', total_loss)
            self.set_loss(total_loss)
        else:
            self.set_outputs(refine_out)
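
Note the two different validity thresholds in this variant: the global labels are zeroed wherever `valids <= 1.1` (presumably so that only joints marked as clearly visible supervise the global net), while the coarse and refine losses only require `valids > 0.1`. A small NumPy sketch of the global-label masking, with illustrative values:

import numpy as np

batch_size, out_h, out_w, nr_skeleton = 4, 64, 48, 17  # illustrative shapes

label = np.random.rand(batch_size, out_h, out_w, nr_skeleton).astype(np.float32)
valids = np.random.randint(0, 3, size=(batch_size, nr_skeleton)).astype(np.float32)  # e.g. 0/1/2 flags

# broadcast the per-joint flag over the heatmap, as tf.reshape(valids, (-1, 1, 1, nr_skeleton)) does
mask = (valids.reshape(batch_size, 1, 1, nr_skeleton) > 1.1).astype(np.float32)

# joints with valids <= 1.1 get an all-zero target heatmap for the global loss
global_label = label * mask
print(global_label.shape)  # (4, 64, 48, 17)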