コード例 #1
0
    def build_g_cpu(self) -> None:
        """Create the shared inputs, optimizer and model variables of the graph.

        Sets ``self.batch``, the input placeholders (``point_pl``,
        ``label_pl``, ``smpws_pl``, ``is_train_pl``, ``ave_tp_pl``),
        ``self.optimizer`` and ``self.bn_decay``, then builds the model once
        through ``SEG_MODEL.get_model`` and discards its outputs.

        NOTE(review): despite the method name, no ``tf.device`` is opened
        here; CPU placement presumably comes from the caller — confirm.
        """
        # Non-trainable counter variable starting at 0. Presumably the global
        # step read by get_learning_rate()/get_bn_decay() below — TODO confirm;
        # if so, it must stay created before those calls.
        self.batch = tf.Variable(0, name='batch', trainable=False)
        # Full-batch placeholders for points, labels and (presumably) per-sample
        # weights, sized by self.batch_sz and self.point_sz.
        self.point_pl, self.label_pl, self.smpws_pl = SEG_MODEL.placeholder_inputs(self.batch_sz, self.point_sz)
        # Scalar bool placeholder, passed to get_model below as the training flag.
        self.is_train_pl = tf.placeholder(dtype=tf.bool, shape=())
        # Scalar float32 placeholder. Not used in this method; presumably fed
        # elsewhere (e.g. for a summary) — TODO confirm.
        self.ave_tp_pl = tf.placeholder(dtype=tf.float32, shape=())
        # Adam optimizer whose learning rate comes from self.get_learning_rate().
        self.optimizer = tf.train.AdamOptimizer(self.get_learning_rate())
        self.bn_decay = self.get_bn_decay()

        # Build the model once on the full-batch placeholders and discard the
        # outputs. NOTE(review): presumably this creates the model variables
        # ahead of the per-GPU towers; confirm against the caller's
        # variable-scope / reuse setup.
        SEG_MODEL.get_model(self.point_pl, self.is_train_pl, num_class=NUM_CLASS, bn_decay=self.bn_decay)
コード例 #2
0
    def build_g_gpu(self, gpu_idx):
        """Build one training tower on GPU ``gpu_idx``.

        The tower takes its ``BATCH_PER_GPU``-row share of the point, label
        and sample-weight placeholders and runs ``SEG_MODEL.get_model`` and
        ``SEG_MODEL.get_loss`` on it. It then sums the ``'losses'``
        collection entries under the tower's name scope into ``total_loss``
        and adds a scalar summary for each loss term and for the total.

        The tower's gradient/variable pairs, network output and total loss
        are appended to ``self.tower_grads``, ``self.net_gpu`` and
        ``self.total_loss_gpu_list`` respectively.

        Args:
            gpu_idx: Index of the GPU / tower. It also selects which
                ``BATCH_PER_GPU`` rows of the batch this tower consumes.
        """
        print("build graph in gpu %d" % gpu_idx)
        first_row = gpu_idx * BATCH_PER_GPU
        device_name = '/gpu:%d' % gpu_idx
        scope_name = 'gpu_%d' % gpu_idx

        with tf.device(device_name), tf.name_scope(scope_name) as tower_scope:
            # Rows [first_row, first_row + BATCH_PER_GPU) of each input; the
            # remaining dimensions are kept whole (-1).
            points = tf.slice(self.point_pl, [first_row, 0, 0], [BATCH_PER_GPU, -1, -1])
            labels = tf.slice(self.label_pl, [first_row, 0], [BATCH_PER_GPU, -1])
            weights = tf.slice(self.smpws_pl, [first_row, 0], [BATCH_PER_GPU, -1])

            net, _ = SEG_MODEL.get_model(points,
                                         self.is_train_pl,
                                         num_class=NUM_CLASS,
                                         bn_decay=self.bn_decay)
            # get_loss's return value is ignored; the tower's loss terms are
            # read back from the 'losses' collection (presumably populated by
            # get_loss), restricted to this tower's name scope.
            SEG_MODEL.get_loss(net, labels, smpw=weights)
            tower_losses = tf.compat.v1.get_collection('losses', scope=tower_scope)
            total_loss = tf.add_n(tower_losses, name='total_loss')

            # One scalar summary per loss term, followed by one for the total.
            for term in tower_losses:
                tf.summary.scalar(term.op.name, term)
            tf.summary.scalar(total_loss.op.name, total_loss)

            grads_and_vars = self.optimizer.compute_gradients(total_loss)

        self.tower_grads.append(grads_and_vars)
        self.net_gpu.append(net)
        self.total_loss_gpu_list.append(total_loss)