def get_loss_and_output(model,
                        batchsize,
                        scoremap,
                        hand_motion,
                        reuse_variables=None):
    """Regress hand motion (r, x, y, z) from ``scoremap`` and build its loss.

    A network built by ``get_network`` under the ``"diff"`` variable scope is
    run on ``scoremap``. Every stage output is flattened and fed through a
    small fully connected head that predicts four scalars.

    Args:
        model: Network name forwarded to ``get_network``.
        batchsize: Divisor applied to the summed stage losses.
        scoremap: Input fed to the "diff" network. The flatten below reads
            ``shape[1:4]`` of each stage output, so those outputs are rank 4
            (presumably NHWC — TODO confirm).
        hand_motion: Regression targets. Columns 0, 1, 2 and 3 are compared
            with the r, x, y and z predictions respectively.
        reuse_variables: Forwarded as ``reuse`` to ``tf.variable_scope``.

    Returns:
        ``(total_loss, ufxuz)``.
        ``total_loss`` is the per-stage weighted L2 loss, summed over stages
        and divided by ``batchsize``; r and z are down-weighted by 0.001.
        ``ufxuz`` is ``[ur, ux, uy, uz]`` concatenated on axis 1 for the
        LAST stage only, and is named ``'fxuz'``.
    """
    with tf.variable_scope("diff", reuse=reuse_variables):
        # NOTE: this writes a module-level setting of network_mv2_hourglass,
        # which stays changed after this function returns.
        network_mv2_hourglass.N_KPOINTS = 1
        _, pred_diffmap_all = get_network(model, scoremap, True)

    losses = []
    for idx, pred_heat in enumerate(pred_diffmap_all):
        # Flatten each stage output to (batch, H * W * C).
        s = pred_heat.get_shape().as_list()
        pred_heat = tf.reshape(pred_heat, [-1, s[3] * s[1] * s[2]])

        # Shared trunk: FC + ReLU with 32 -> 16 -> 8 units (no dropout).
        for i, out_chan in enumerate([32, 16, 8]):
            pred_heat = ops.fully_connected_relu(pred_heat,
                                                 'fc_vp_%d_%d' % (idx, i),
                                                 out_chan=out_chan,
                                                 trainable=True)

        # One linear output per motion component; the variable names are
        # kept so existing checkpoints still load.
        ux = ops.fully_connected(pred_heat,
                                 'fc_vp_ux_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uy = ops.fully_connected(pred_heat,
                                 'fc_vp_uy_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uz = ops.fully_connected(pred_heat,
                                 'fc_vp_uz_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        ur = ops.fully_connected(pred_heat,
                                 'fc_vp_ur_%d' % idx,
                                 out_chan=1,
                                 trainable=True)

        loss_l2r = tf.nn.l2_loss(hand_motion[:, 0] - ur[:, 0],
                                 name='lossr_heatmap_stage%d' % idx)
        loss_l2x = tf.nn.l2_loss(hand_motion[:, 1] - ux[:, 0],
                                 name='lossx_heatmap_stage%d' % idx)
        loss_l2y = tf.nn.l2_loss(hand_motion[:, 2] - uy[:, 0],
                                 name='lossy_heatmap_stage%d' % idx)
        loss_l2z = tf.nn.l2_loss(hand_motion[:, 3] - uz[:, 0],
                                 name='lossz_heatmap_stage%d' % idx)
        # r and z are weighted 0.001 relative to x and y.
        losses.append(loss_l2x + loss_l2y + loss_l2r * 0.001 +
                      loss_l2z * 0.001)

    # Only the last stage's predictions are exported.
    ufxuz = tf.concat(values=[ur, ux, uy, uz], axis=1, name='fxuz')

    total_loss = tf.reduce_sum(losses) / batchsize
    return total_loss, ufxuz
def get_loss_and_output(model,
                        batchsize,
                        input_image,
                        scoremap,
                        is_loss,
                        reuse_variables=None):
    """Build 2-keypoint heatmaps plus an is-loss gate. No loss is computed.

    Despite the name, this variant only builds output tensors.
    ``batchsize``, ``scoremap`` and ``is_loss`` are accepted for interface
    compatibility but are not used.

    Args:
        model: Network name forwarded to ``get_network``.
        batchsize: Unused.
        input_image: Network input. It is given the graph name
            ``'input_image'``.
        scoremap: Unused.
        is_loss: Unused.
        reuse_variables: Forwarded as ``reuse`` to ``tf.variable_scope``.

    Returns:
        ``(pred_heatmaps_tmp, pre_is_loss, pred_heatmaps_tmp_01_modi)`` for
        the LAST network stage:

        * ``pred_heatmaps_tmp``: the stage heatmap upsampled by 2, with a
          softmax over the last axis.
        * ``pre_is_loss``: a softmax over the 2 outputs of the is-loss head.
        * ``pred_heatmaps_tmp_01_modi``: the softmax of the upsampled heatmap
          after it is multiplied by ``pre_is_loss``, broadcast over axes 1
          and 2.
    """
    # Identity op that gives the input a stable graph name.
    input_image = tf.add(input_image, 0, name='input_image')
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        # NOTE: this writes a module-level setting of network_mv2_hourglass.
        network_mv2_hourglass.N_KPOINTS = 2
        _, pred_heatmaps_all = get_network(model, input_image, True)

    # Every stage gets its own is-loss head (variables
    # 'is_loss_fc_<stage>_<layer>'), but only the last stage's tensors are
    # returned.
    for loss_i, stage_heatmap in enumerate(pred_heatmaps_all):
        # Is-loss head: flatten, then FC + ReLU with 32 -> 16 -> 8 -> 2 units.
        is_loss_s = stage_heatmap.get_shape().as_list()
        pre_is_loss = tf.reshape(
            stage_heatmap,
            [-1, is_loss_s[1] * is_loss_s[2] * is_loss_s[3]])
        for i, out_chan in enumerate([32, 16, 8, 2]):
            pre_is_loss = ops.fully_connected_relu(pre_is_loss,
                                                   'is_loss_fc_%d_%d' %
                                                   (loss_i, i),
                                                   out_chan=out_chan,
                                                   trainable=True)

        # Upsample the stage heatmap by a factor of 2.
        scale = 2
        pred_heatmaps_tmp = upsample(stage_heatmap,
                                     scale,
                                     name="upsample_for_hotmap_loss_%d" %
                                     loss_i)

        # Gate the heatmap with the softmaxed is-loss prediction.
        pre_is_loss = tf.nn.softmax(pre_is_loss)
        pred_heatmaps_tmp_01_modi = tf.expand_dims(
            tf.expand_dims(pre_is_loss, axis=1), axis=1) * pred_heatmaps_tmp
        pred_heatmaps_tmp = tf.nn.softmax(pred_heatmaps_tmp)
        pred_heatmaps_tmp_01_modi = tf.nn.softmax(pred_heatmaps_tmp_01_modi)

    return pred_heatmaps_tmp, pre_is_loss, pred_heatmaps_tmp_01_modi
                hand_motion_back = batch_data_all_back[9]
                scoremap1_back = batch_data_all_back[11]
                scoremap2_back = batch_data_all_back[12]
                is_loss1_back = batch_data_all_back[13]
                is_loss2_back = batch_data_all_back[14]

                input_image1 = tf.concat([input_image1, input_image1_back],
                                         0)  # 第一个batch的维度 hand1 back1
                input_image2 = tf.concat([input_image2, input_image2_back], 0)
                input_image12 = tf.concat([input_image1, input_image2],
                                          0)  # hand1 back1 hand2 back2
                input_image12.set_shape([batchsize * 4, 32, 32, 3])

                with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                    network_mv2_hourglass.N_KPOINTS = 1
                    _, pred_heatmaps_all12 = get_network(
                        'mv2_hourglass', input_image12, True)

                for batch_i in range(len(pred_heatmaps_all12)):
                    # 计算 isloss,用softmax计算 0~1}
                    is_loss_s = pred_heatmaps_all12[batch_i].get_shape(
                    ).as_list()
                    pre_is_loss = tf.reshape(
                        pred_heatmaps_all12[batch_i],
                        [is_loss_s[0], -1])  # this is Bx16*16*1
                    out_chan_list = [32, 16, 8, 1]
                    for i, out_chan in enumerate(out_chan_list):
                        pre_is_loss = ops.fully_connected_relu(
                            pre_is_loss,
                            'is_loss_fc_%d_%d' % (batch_i, i),
                            out_chan=out_chan,
                            trainable=True)  # (128,1)
Exemple #4
0
def get_loss_and_output(model,
                        batchsize,
                        input_image1,
                        input_image2,
                        hand_motion,
                        scoremap1,
                        scoremap2,
                        is_loss1,
                        is_loss2,
                        reuse_variables=None):
    """Build the training graph: heatmap, is-loss and hand-motion losses.

    The two image batches are stacked on axis 0 so a single ``get_network``
    instance extracts features for both. The stacked batch has a static size
    of ``batchsize * 4``; layout comments call it ``hand1 back1 hand2 back2``
    (TODO confirm against the data pipeline).

    For every stage, the first ``batchsize * 2`` rows minus the last
    ``batchsize * 2`` rows form a "diff map". All stage diff maps are
    concatenated on the channel axis and fed to a second network under the
    ``"diff"`` scope. Each stage output of that network drives a small FC
    head that regresses the motion values r, x, y and z.

    Args:
        model: Network name forwarded to ``get_network``.
        batchsize: Base batch size. The stacked images and score maps are
            statically shaped to ``batchsize * 4`` rows of 32x32.
        input_image1, input_image2: Image batches, concatenated on axis 0.
        hand_motion: Regression targets. Columns 0, 1, 2 and 3 are compared
            with the r, x, y and z predictions.
        scoremap1, scoremap2: Ground-truth score maps. They are concatenated
            on axis 0, summed over the last axis and clipped above at 1.
        is_loss1, is_loss2: Is-loss targets, concatenated on axis 0.
        reuse_variables: Forwarded as ``reuse`` to both variable scopes.

    Returns:
        ``(total_loss, motion_loss * 2, loss_scoremap * 2, loss_is_loss * 0.5,
        ur, ux, uy, uz, ufxuz, pred_heatmaps_tmp, pred_heatmaps_tmp_,
        pre_is_loss, is_loss12)``.
        ``total_loss`` is the sum of the three weighted loss terms. The
        per-stage tensors (``u*``, heatmaps, ``pre_is_loss``) come from the
        LAST stage.
    """
    # Stack both inputs along the batch so the feature extractor is shared.
    input_image12 = tf.concat([input_image1, input_image2], 0)
    input_image12.set_shape([batchsize * 4, 32, 32, 3])
    # Identity op that gives the stacked input the graph name 'input_image'.
    input_image12 = tf.add(input_image12, 0, name='input_image')
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        # NOTE: this writes a module-level setting of network_mv2_hourglass.
        network_mv2_hourglass.N_KPOINTS = 1
        _, pred_heatmaps_all12 = get_network(model, input_image12, True)

    # Ground-truth score map: sum the keypoint channels into one map and
    # clip values above 1 down to 1.
    scoremap12 = tf.concat([scoremap1, scoremap2], 0)
    scoremap12 = tf.reduce_sum(scoremap12, axis=-1)
    one_scoremap12 = tf.ones_like(scoremap12)
    scoremap12 = tf.where(scoremap12 > 1, x=one_scoremap12, y=scoremap12)
    scoremap12 = tf.expand_dims(scoremap12, axis=-1)
    scoremap12.set_shape([batchsize * 4, 32, 32, 1])

    is_loss12 = tf.concat([is_loss1, is_loss2], 0)
    is_loss12 = tf.expand_dims(is_loss12, axis=-1)
    is_loss12.set_shape([batchsize * 4, 1])

    loss_scoremap = 0
    loss_is_loss = 0
    for loss_i, stage_heatmap in enumerate(pred_heatmaps_all12):
        # Is-loss head: flatten the stage output, then FC + ReLU with
        # 32 -> 16 -> 8 -> 1 units.
        is_loss_s = stage_heatmap.get_shape().as_list()
        pre_is_loss = tf.reshape(stage_heatmap, [is_loss_s[0], -1])
        for i, out_chan in enumerate([32, 16, 8, 1]):
            pre_is_loss = ops.fully_connected_relu(pre_is_loss,
                                                   'is_loss_fc_%d_%d' %
                                                   (loss_i, i),
                                                   out_chan=out_chan,
                                                   trainable=True)
        # Clamp the is-loss prediction into [0, 1].
        one_pre_is_loss = tf.ones_like(pre_is_loss)
        zero_pre_is_loss = tf.zeros_like(pre_is_loss)
        pre_is_loss = tf.where(pre_is_loss > 1,
                               x=one_pre_is_loss,
                               y=pre_is_loss)
        pre_is_loss = tf.where(pre_is_loss < 0,
                               x=zero_pre_is_loss,
                               y=pre_is_loss)
        loss_is_loss += tf.nn.l2_loss(pre_is_loss - is_loss12)

        # Heatmap: upsample by 2, then clip values above 1 down to 1.
        scale = 2
        pred_heatmaps_tmp = upsample(stage_heatmap,
                                     scale,
                                     name="upsample_for_hotmap_loss_%d" %
                                     loss_i)
        one_tmp = tf.ones_like(pred_heatmaps_tmp)
        pred_heatmaps_tmp = tf.where(pred_heatmaps_tmp > 1,
                                     x=one_tmp,
                                     y=pred_heatmaps_tmp)

        # Heatmap gated by the is-loss prediction. It is returned but does
        # not enter the loss.
        pred_heatmaps_tmp_ = tf.expand_dims(
            tf.expand_dims(pre_is_loss, axis=-1), axis=-1) * pred_heatmaps_tmp
        loss_scoremap += tf.nn.l2_loss(pred_heatmaps_tmp - scoremap12)

    # NOTE(review): these divisors are hard-coded. 32.0 * 4.0 looks like a
    # batch size of 32 times the 4 stacked groups, and 32.0 * 32.0 like the
    # heatmap pixel count. They do not follow ``batchsize``; confirm before
    # training with a different batch size.
    loss_is_loss = loss_is_loss / 32.0 / 4.0
    loss_scoremap = loss_scoremap / 32.0 / 4.0 / 32.0 / 32.0

    # Diff map per stage: first 2 * batchsize rows minus the last
    # 2 * batchsize rows.
    diffmap = [stage[0:batchsize * 2] - stage[batchsize * 2:batchsize * 4]
               for stage in pred_heatmaps_all12]
    # Merge all stages along the channel axis, with the latest stage first.
    diffmap_t = tf.concat(diffmap[::-1], axis=3)

    with tf.variable_scope("diff", reuse=reuse_variables):
        network_mv2_hourglass.N_KPOINTS = 1
        _, pred_diffmap_all = get_network(model, diffmap_t, True)

    losses = []
    for idx, pred_heat in enumerate(pred_diffmap_all):
        # Flatten each stage output to (batch, -1).
        s = pred_heat.get_shape().as_list()
        pred_heat = tf.reshape(pred_heat, [s[0], -1])

        # Shared trunk: FC + ReLU with 32 -> 16 -> 8 units (no dropout).
        for i, out_chan in enumerate([32, 16, 8]):
            pred_heat = ops.fully_connected_relu(pred_heat,
                                                 'fc_vp_%d_%d' % (idx, i),
                                                 out_chan=out_chan,
                                                 trainable=True)

        ux = ops.fully_connected(pred_heat,
                                 'fc_vp_ux_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uy = ops.fully_connected(pred_heat,
                                 'fc_vp_uy_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uz = ops.fully_connected(pred_heat,
                                 'fc_vp_uz_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        ur = ops.fully_connected(pred_heat,
                                 'fc_vp_ur_%d' % idx,
                                 out_chan=1,
                                 trainable=True)

        loss_l2r = tf.nn.l2_loss(hand_motion[:, 0] - ur[:, 0],
                                 name='lossr_heatmap_stage%d' % idx)
        loss_l2x = tf.nn.l2_loss(hand_motion[:, 1] - ux[:, 0],
                                 name='lossx_heatmap_stage%d' % idx)
        loss_l2y = tf.nn.l2_loss(hand_motion[:, 2] - uy[:, 0],
                                 name='lossy_heatmap_stage%d' % idx)
        loss_l2z = tf.nn.l2_loss(hand_motion[:, 3] - uz[:, 0],
                                 name='lossz_heatmap_stage%d' % idx)
        # r and z are weighted 0.001 relative to x and y.
        losses.append(loss_l2x + loss_l2y + loss_l2r * 0.001 +
                      loss_l2z * 0.001)
    ufxuz = tf.concat(values=[ur, ux, uy, uz], axis=1, name='fxuz')

    motion_loss = tf.reduce_sum(losses) / batchsize / 2.0
    total_loss = motion_loss * 2 + loss_scoremap * 2 + loss_is_loss * 0.5
    return (total_loss, motion_loss * 2, loss_scoremap * 2, loss_is_loss * 0.5,
            ur, ux, uy, uz, ufxuz, pred_heatmaps_tmp, pred_heatmaps_tmp_,
            pre_is_loss, is_loss12)
Exemple #5
0
def get_loss_and_output(model,
                        batchsize,
                        input_image1,
                        input_image2,
                        reuse_variables=None):
    """Build the inference graph for heatmaps, the is-loss gate and motion.

    This mirrors the training variant without the loss terms. The two image
    batches are stacked on axis 0 (layout comment: ``hand1 back1 hand2
    back2``; TODO confirm). Stage differences between the first and second
    ``batchsize * 2`` rows feed a second network under the ``"diff"`` scope,
    whose FC heads predict r, x, y and z.

    Args:
        model: Network name forwarded to ``get_network``.
        batchsize: Base batch size. Reshapes and slices assume the stacked
            batch has ``batchsize * 4`` rows.
        input_image1, input_image2: Image batches, concatenated on axis 0.
        reuse_variables: Forwarded as ``reuse`` to both variable scopes.

    Returns:
        ``(ur, ux, uy, uz, ufxuz, pred_heatmaps_tmp, pred_heatmaps_tmp_,
        pre_is_loss)``, all taken from the LAST stage of the relevant
        network.
    """
    # Stack both inputs along the batch so the feature extractor is shared.
    input_image12 = tf.concat([input_image1, input_image2], 0)
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        # NOTE: this writes a module-level setting of network_mv2_hourglass.
        network_mv2_hourglass.N_KPOINTS = 1
        _, pred_heatmaps_all12 = get_network(model, input_image12, True)

    # Every stage gets its own is-loss head (variables
    # 'is_loss_fc_<stage>_<layer>'); only the last stage's tensors are
    # returned.
    for loss_i, stage_heatmap in enumerate(pred_heatmaps_all12):
        # Is-loss head: flatten, then FC + ReLU with 32 -> 16 -> 8 -> 1 units.
        pre_is_loss = tf.reshape(stage_heatmap, [batchsize * 4, -1])
        for i, out_chan in enumerate([32, 16, 8, 1]):
            pre_is_loss = ops.fully_connected_relu(pre_is_loss,
                                                   'is_loss_fc_%d_%d' %
                                                   (loss_i, i),
                                                   out_chan=out_chan,
                                                   trainable=True)
        # Clamp the is-loss prediction into [0, 1].
        one_pre_is_loss = tf.ones_like(pre_is_loss)
        zero_pre_is_loss = tf.zeros_like(pre_is_loss)
        pre_is_loss = tf.where(pre_is_loss > 1,
                               x=one_pre_is_loss,
                               y=pre_is_loss)
        pre_is_loss = tf.where(pre_is_loss < 0,
                               x=zero_pre_is_loss,
                               y=pre_is_loss)

        # Heatmap: upsample by 2, then clip values above 1 down to 1.
        scale = 2
        pred_heatmaps_tmp = upsample(stage_heatmap,
                                     scale,
                                     name="upsample_for_hotmap_loss_%d" %
                                     loss_i)
        one_tmp = tf.ones_like(pred_heatmaps_tmp)
        pred_heatmaps_tmp = tf.where(pred_heatmaps_tmp > 1,
                                     x=one_tmp,
                                     y=pred_heatmaps_tmp)

        # Heatmap gated by the is-loss prediction.
        pred_heatmaps_tmp_ = tf.expand_dims(
            tf.expand_dims(pre_is_loss, axis=-1), axis=-1) * pred_heatmaps_tmp

    # Diff map per stage: first 2 * batchsize rows minus the last
    # 2 * batchsize rows.
    diffmap = [stage[0:batchsize * 2] - stage[batchsize * 2:batchsize * 4]
               for stage in pred_heatmaps_all12]
    # Merge all stages along the channel axis, with the latest stage first.
    diffmap_t = tf.concat(diffmap[::-1], axis=3)

    with tf.variable_scope("diff", reuse=reuse_variables):
        network_mv2_hourglass.N_KPOINTS = 1
        _, pred_diffmap_all = get_network(model, diffmap_t, True)

    for idx, pred_heat in enumerate(pred_diffmap_all):
        # Flatten each stage output to (batch, -1).
        s = pred_heat.get_shape().as_list()
        pred_heat = tf.reshape(pred_heat, [s[0], -1])

        # Shared trunk: FC + ReLU with 32 -> 16 -> 8 units (no dropout).
        for i, out_chan in enumerate([32, 16, 8]):
            pred_heat = ops.fully_connected_relu(pred_heat,
                                                 'fc_vp_%d_%d' % (idx, i),
                                                 out_chan=out_chan,
                                                 trainable=True)

        ux = ops.fully_connected(pred_heat,
                                 'fc_vp_ux_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uy = ops.fully_connected(pred_heat,
                                 'fc_vp_uy_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        uz = ops.fully_connected(pred_heat,
                                 'fc_vp_uz_%d' % idx,
                                 out_chan=1,
                                 trainable=True)
        ur = ops.fully_connected(pred_heat,
                                 'fc_vp_ur_%d' % idx,
                                 out_chan=1,
                                 trainable=True)

    ufxuz = tf.concat(values=[ur, ux, uy, uz], axis=1, name='fxuz')

    return ur, ux, uy, uz, ufxuz, pred_heatmaps_tmp, pred_heatmaps_tmp_, pre_is_loss
Exemple #6
0
def get_loss_and_output(model,
                        batchsize,
                        input_image,
                        scoremap,
                        finger_mask_sum,
                        reuse_variables=None):
    """Build a 2-stage, 2-keypoint heatmap network with a softmax CE loss.

    Each stage output is reshaped to one row per pixel,
    ``(batchsize * H * W, -1)``, where H and W are read from ``scoremap``.
    It is then scored against ``scoremap`` with
    ``softmax_cross_entropy_with_logits_v2``. The per-stage mean losses are
    summed.

    Args:
        model: Network name forwarded to ``get_network``.
        batchsize: Batch size used in the per-pixel reshapes.
        input_image: Network input. It is given the graph name
            ``'input_image'``.
        scoremap: Ground-truth label maps. Indexing ``shape[1]`` and
            ``shape[2]`` implies rank >= 3; presumably NHWC (TODO confirm).
            The stage outputs are assumed to share its spatial size.
        finger_mask_sum: Unused. It appears to be a leftover of a disabled
            z-rate loss.
        reuse_variables: Forwarded as ``reuse`` to ``tf.variable_scope``.

    Returns:
        ``(total_loss, pred_heatmaps_tmp)``: the summed CE loss, and the
        softmax over the last axis of the LAST stage output.
    """
    # Identity op that gives the input a stable graph name.
    input_image = tf.add(input_image, 0, name='input_image')
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        # NOTE: these write module-level settings of network_mv2_hourglass,
        # which stay changed after this function returns.
        network_mv2_hourglass.N_KPOINTS = 2
        network_mv2_hourglass.STAGE_NUM = 2
        _, pred_heatmaps_all = get_network(model, input_image, True)

    # Ground truth as one label distribution per pixel. This is the same
    # for every stage, so it is built once.
    s = scoremap.get_shape().as_list()
    gt = tf.reshape(scoremap, [batchsize * s[1] * s[2], -1])

    loss_scoremap = 0.0
    for stage_heatmap in pred_heatmaps_all:
        pred = tf.reshape(stage_heatmap, [batchsize * s[1] * s[2], -1])
        loss_scoremap += tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=gt))

    # Only the last stage's probabilities are returned.
    pred_heatmaps_tmp = tf.nn.softmax(pred_heatmaps_all[-1])

    total_loss = loss_scoremap
    return total_loss, pred_heatmaps_tmp
        int(inputs.get_shape()[1]) * factor,
        int(inputs.get_shape()[2]) * factor
    ],
                                    name=name)


# Build an inference graph for the 2-keypoint, 2-stage mv2_hourglass model.
# The softmax of the last stage is exposed under the graph name
# 'final_pred_heatmaps_tmp'. Variables are then initialized in a session.
# NOTE(review): `i` (used as the GPU index) and `args` are not defined in
# this snippet. They presumably come from surrounding code, e.g. a device
# loop and argparse. Confirm.
with tf.Graph().as_default(), tf.device("/cpu:0"):
    with tf.device("/gpu:%d" % i):
        with tf.name_scope("GPU_%d" % i):
            # A single image of shape [1, args.size, args.size * 5, 3].
            input_node = tf.placeholder(tf.float32,
                                        shape=[1, args.size, args.size * 5, 3],
                                        name="input_image")
            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                # Configure module-level network settings before building.
                network_mv2_hourglass.N_KPOINTS = 2
                network_mv2_hourglass.STAGE_NUM = 2
                _, pred_heatmaps_all = get_network('mv2_hourglass', input_node,
                                                   True)
            # A softmax is built for every stage, but after the loop
            # pred_heatmaps_tmp refers only to the last stage's softmax.
            for loss_i in range(len(pred_heatmaps_all)):
                pred_heatmaps_tmp = pred_heatmaps_all[loss_i]
                pred_heatmaps_tmp = tf.nn.softmax(pred_heatmaps_tmp)

            # Identity op giving the output a stable name for export.
            # NOTE(review): the variable name mentions ufxuz, but the tensor
            # is the heatmap softmax. The old "(1,4)" shape note looks stale;
            # confirm the actual output shape.
            output_node_ufxuz = tf.add(pred_heatmaps_tmp,
                                       0,
                                       name='final_pred_heatmaps_tmp')
    saver = tf.train.Saver(max_to_keep=10)
    init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init.run()
Exemple #8
0
                    type=str,
                    default='',
                    help='checkpoint path')
# Comma-separated names of the graph node(s) to keep as outputs when freezing.
parser.add_argument('--output_node_names',
                    type=str,
                    default='upsample2_for_loss_3')
# Destination path of the serialized frozen GraphDef.
parser.add_argument('--output_graph',
                    type=str,
                    default='./model.pb',
                    help='output_freeze_path')

args = parser.parse_args()

# Inference input: one image of shape [1, args.size, args.size, 3], named "image".
input_node = tf.placeholder(tf.float32,
                            shape=[1, args.size, args.size, 3],
                            name="image")

with tf.Session() as sess:
    # Build the network on the placeholder. Only the graph it adds is used
    # below; `net` itself is not referenced again.
    net = get_network(args.model, input_node, trainable=False)
    # Load the trained weights from the checkpoint.
    saver = tf.train.Saver()
    saver.restore(sess, args.checkpoint)

    input_graph_def = tf.get_default_graph().as_graph_def()
    # Fold the restored variable values into the GraphDef as constants,
    # using the requested output node names.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,  # The session
        input_graph_def,  # input_graph_def is useful for retrieving the nodes
        args.output_node_names.split(","))

# Write the frozen GraphDef to args.output_graph in binary protobuf form.
with tf.gfile.GFile(args.output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())