    def make_network(self, is_train):
        if is_train:
            image = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.data_shape, 3])
            label15 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label11 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label9 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            label7 = tf.placeholder(tf.float32, shape=[cfg.batch_size, *cfg.output_shape, cfg.nr_skeleton])
            valids = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.nr_skeleton])
            labels = [label15, label11, label9, label7]
            # labels.reverse() # The original labels are reversed. For reproduction of our pre-trained model, I'll keep it the same.
            self.set_inputs(image, label15, label11, label9, label7, valids)
        else:
            image = tf.placeholder(tf.float32, shape=[None, *cfg.data_shape, 3])
            self.set_inputs(image)

        resnet_fms = resnet101(image, is_train, bn_trainable=True)
        global_fms, global_outs = create_global_net(resnet_fms, is_train)
        refine_out = create_refine_net(global_fms, is_train)

        # make loss
        if is_train:
            def ohkm(loss, top_k):
                ohkm_loss = 0.
                for i in range(cfg.batch_size):
                    sub_loss = loss[i]
                    topk_val, topk_idx = tf.nn.top_k(sub_loss, k=top_k, sorted=False, name='ohkm{}'.format(i))
                    tmp_loss = tf.gather(sub_loss, topk_idx, name='ohkm_loss{}'.format(i))  # gathers the same values as topk_val; the gather could be dropped
                    ohkm_loss += tf.reduce_sum(tmp_loss) / top_k
                ohkm_loss /= cfg.batch_size
                return ohkm_loss

            global_loss = 0.
            for i, (global_out, label) in enumerate(zip(global_outs, labels)):
                global_label = label * tf.to_float(tf.greater(tf.reshape(valids, (-1, 1, 1, cfg.nr_skeleton)), 1.1))  # only fully visible keypoints (valids > 1.1) supervise the global net
                global_loss += tf.reduce_mean(tf.square(global_out - global_label)) / len(labels)
            dataset = "PoseTrack"
            if dataset == "COCO":
                global_loss /= 2.
            elif dataset == "PoseTrack":
                # global_loss /= (15.0/7.0)
                global_loss /= 1.  # no extra scaling for PoseTrack
            self.add_tower_summary('global_loss', global_loss)
            refine_loss = tf.reduce_mean(tf.square(refine_out - label7), (1, 2)) * tf.to_float(tf.greater(valids, 0.1))  # per-keypoint loss, masked to labeled keypoints (valids > 0.1)
            if dataset == "COCO":
                refine_loss = ohkm(refine_loss, 8)
            elif dataset == "PoseTrack":
                refine_loss = ohkm(refine_loss, 7)
                print("ohkm = 7")
            self.add_tower_summary('refine_loss', refine_loss)

            total_loss = refine_loss + global_loss
            self.add_tower_summary('loss', total_loss)
            self.set_loss(total_loss)
        else:
            self.set_outputs(refine_out)
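
The per-sample Python loop in `ohkm` above builds one `top_k` op per batch element. A minimal sketch of an equivalent batched formulation (TF 1.x API; the name `ohkm_batched` is illustrative and not part of the original code), relying on `tf.nn.top_k` operating over the last axis:

def ohkm_batched(loss, top_k):
    # loss: [batch_size, nr_skeleton] per-keypoint losses.
    # tf.nn.top_k works on the last axis, so one call covers the whole batch
    # and returns the top_k hardest keypoint losses per sample.
    topk_val, _ = tf.nn.top_k(loss, k=top_k, sorted=False, name='ohkm_batched')
    # Average the hardest top_k losses per sample, then average over the batch,
    # matching the loop version's sum / top_k followed by division by batch_size.
    return tf.reduce_mean(tf.reduce_sum(topk_val, axis=1) / top_k)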
Example #2
    def make_network(self, is_train):
        if is_train:
            image = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.data_shape[0], cfg.data_shape[1], 3])
            label15 = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.output_shape[0], cfg.output_shape[1], cfg.nr_skeleton])
            label11 = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.output_shape[0], cfg.output_shape[1], cfg.nr_skeleton])
            label9 = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.output_shape[0], cfg.output_shape[1], cfg.nr_skeleton])
            label7 = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.output_shape[0], cfg.output_shape[1], cfg.nr_skeleton])
            valids = tf.placeholder(tf.float32, shape=[cfg.batch_size, cfg.nr_skeleton])
            labels = [label15, label11, label9, label7]
            self.set_inputs(image, label15, label11, label9, label7, valids)
        else:
            image = tf.placeholder(tf.float32, shape=[None, cfg.data_shape[0], cfg.data_shape[1], 3])
            self.set_inputs(image)

        resnet_fms = resnet101(image, is_train, bn_trainable=True)
        global_fms, global_outs = create_global_net(resnet_fms, is_train)
        refine_out = create_refine_net(global_fms, is_train)

        # make loss
        if is_train:
            def ohkm(loss, top_k):
                ohkm_loss = 0.
                for i in range(cfg.batch_size):
                    sub_loss = loss[i]
                    topk_val, topk_idx = tf.nn.top_k(sub_loss, k=top_k, sorted=False, name='ohkm{}'.format(i))
                    tmp_loss = tf.gather(sub_loss, topk_idx, name='ohkm_loss{}'.format(i))  # gathers the same values as topk_val; the gather could be dropped
                    ohkm_loss += tf.reduce_sum(tmp_loss) / top_k
                ohkm_loss /= cfg.batch_size
                return ohkm_loss

            global_loss = 0.
            for i, (global_out, label) in enumerate(zip(global_outs, labels)):
                global_label = label * tf.to_float(tf.greater(tf.reshape(valids, (-1, 1, 1, cfg.nr_skeleton)), 1.1))
                global_loss += tf.reduce_mean(tf.square(global_out - global_label)) / len(labels)
            global_loss /= 2.
            self.add_tower_summary('global_loss', global_loss)
            refine_loss = tf.reduce_mean(tf.square(refine_out - label7), (1, 2)) * tf.to_float(tf.greater(valids, 0.1))
            refine_loss = ohkm(refine_loss, 8)
            self.add_tower_summary('refine_loss', refine_loss)

            total_loss = refine_loss + global_loss
            self.add_tower_summary('loss', total_loss)
            self.set_loss(total_loss)
        else:
            self.set_outputs(refine_out)
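
Both examples apply two different validity masks: the global loss keeps only keypoints with valids > 1.1, while the refine loss keeps any keypoint with valids > 0.1. A minimal sketch of what those thresholds select, assuming `valids` carries COCO-style visibility flags (0 = not labeled, 1 = labeled but occluded, 2 = labeled and visible); the tensor values below are illustrative only:

import tensorflow as tf

valids = tf.constant([[0., 1., 2.]])                 # [batch_size, nr_skeleton]
global_mask = tf.to_float(tf.greater(valids, 1.1))   # -> [[0., 0., 1.]]: only fully visible keypoints
refine_mask = tf.to_float(tf.greater(valids, 0.1))   # -> [[0., 1., 1.]]: any labeled keypoint

with tf.Session() as sess:
    print(sess.run([global_mask, refine_mask]))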