예제 #1
0
파일: train.py 프로젝트: zgq91/BB8
    def regressor(self):
        """Build the pose-regression network and its trainer, then prepare data.

        Configuration is read from attributes on ``self``: ``network_model``
        (network type, cast to ``int``), ``batch_size`` (``int``),
        ``learning_rate`` (``float``), ``optimizer``, ``nb_training`` (``int``)
        and the optional learning-rate schedule ``steps`` / ``scales``.  Each
        schedule value may be a whitespace-separated string such as
        ``"20000 40000"`` or a sequence of values; ``None`` or ``[]`` means
        "not configured" and leaves the corresponding trainer parameter
        untouched.

        Side effects: sets ``self.network`` and ``self.trainer``, sets
        ``theano.config.exception_verbosity = 'high'``, prints the network,
        and finally calls ``self.create_training.pre_create_data()``.

        Raises:
            ValueError: if ``steps`` and ``scales`` contain a different number
                of entries, or if ``self.create_training.get_dim()`` reports
                ``None`` for any dimension.
        """

        def parse_schedule(value, cast):
            # ``[]`` is the "unset" sentinel used by the configuration; treat
            # ``None`` the same way.  Strings hold whitespace-separated values,
            # other sequences are converted element by element.
            if value is None or value == []:
                return None
            if isinstance(value, str):
                value = value.split()
            return [cast(item) for item in value]

        regressorNetType = int(self.network_model)
        batch_size = int(self.batch_size)
        learning_rate = float(self.learning_rate)
        optimizer = self.optimizer

        # Validate the schedule before building anything.  Compare the number
        # of entries, not the raw string lengths: "1000 2000" (9 chars) and
        # "0.1 0.01" (8 chars) both describe two schedule entries.
        steps = parse_schedule(self.steps, int)
        scales = parse_schedule(self.scales, float)
        n_steps = len(steps) if steps is not None else 0
        n_scales = len(scales) if scales is not None else 0
        if n_steps != n_scales:
            raise ValueError(
                "learning-rate steps and scales must have the same number of "
                "entries (got %d steps and %d scales)" % (n_steps, n_scales))

        n_chan, h_in, w_in, output_dim = self.create_training.get_dim()
        dims = (("n_chan", n_chan), ("h_in", h_in), ("w_in", w_in),
                ("output_dim", output_dim))
        missing = [name for name, dim in dims if dim is None]
        if missing:
            # Explicit check instead of ``assert`` so it still runs under -O.
            raise ValueError("create_training.get_dim() returned None for: %s"
                             % ", ".join(missing))
        nb_training = int(self.nb_training)

        # Fixed seed for reproducibility; the same rng is handed to both the
        # network and the trainer below.
        rng = np.random.RandomState(23455)
        # Enable for Theano graph debugging:
        # theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'

        regressorNetParams = PoseRegNetParams(type=regressorNetType,
                                              n_chan=n_chan,
                                              w_in=w_in,
                                              h_in=h_in,
                                              batchSize=batch_size,
                                              output_dim=output_dim)

        self.network = PoseRegNet(rng, cfgParams=regressorNetParams)
        if regressorNetType == 1:
            # Type 1 additionally loads VGG weights (presumably pretrained;
            # see PoseRegNet.load_vgg).
            self.network.load_vgg()
        print(self.network)

        regressorNetTrainingParams = PoseRegNetTrainingParams()
        regressorNetTrainingParams.batch_size = batch_size
        regressorNetTrainingParams.learning_rate = learning_rate
        if scales is not None:
            regressorNetTrainingParams.learning_rate_scales = scales
        if steps is not None:
            regressorNetTrainingParams.learning_rate_steps = steps
        regressorNetTrainingParams.optimizer = optimizer

        self.trainer = PoseRegNetTrainer(self.network,
                                         regressorNetTrainingParams, rng)

        self.trainer.setup_para(nb_training,
                                n_chan,
                                h_in,
                                w_in,
                                output_dim,
                                type='regressor')
        self.create_training.pre_create_data()
    pca = PCA(n_components=30)
    pca.fit(train_gt3D.reshape((train_gt3D.shape[0], train_gt3D.shape[1] * 3)))
    train_gt3D_embed = pca.transform(
        train_gt3D.reshape((train_gt3D.shape[0], train_gt3D.shape[1] * 3)))
    test_gt3D_embed = pca.transform(
        test_gt3D.reshape((test_gt3D.shape[0], test_gt3D.shape[1] * 3)))
    val_gt3D_embed = pca.transform(
        val_gt3D.reshape((val_gt3D.shape[0], val_gt3D.shape[1] * 3)))

    ############################################################################
    print("create network")
    batchSize = 128
    poseNetParams = PoseRegNetParams(type=0,
                                     nChan=nChannels,
                                     wIn=imgSizeW,
                                     hIn=imgSizeH,
                                     batchSize=batchSize,
                                     numJoints=1,
                                     nDims=train_gt3D_embed.shape[1])
    poseNet = PoseRegNet(rng, cfgParams=poseNetParams)

    poseNetTrainerParams = PoseRegNetTrainerParams()
    poseNetTrainerParams.batch_size = batchSize
    poseNetTrainerParams.learning_rate = 0.01

    print("setup trainer")
    poseNetTrainer = PoseRegNetTrainer(poseNet, poseNetTrainerParams, rng)
    poseNetTrainer.setData(train_data, train_gt3D_embed, val_data,
                           val_gt3D_embed)
    poseNetTrainer.compileFunctions(compileDebugFcts=False)
    di = NYUImporter('../data/NYU/')
    Seq2 = di.loadSequence('test_1')
    testSeqs = [Seq2]

    testDataSet = NYUDataset(testSeqs)
    test_data, test_gt3D = testDataSet.imgStackDepthOnly('test_1')

    # load trained network
    # poseNetParams = PoseRegNetParams(type=11, nChan=1, wIn=128, hIn=128, batchSize=1, numJoints=16, nDims=3)
    # poseNet = PoseRegNet(numpy.random.RandomState(23455), cfgParams=poseNetParams)
    # poseNet.load("./ICVL_network_prior.pkl")
    poseNetParams = PoseRegNetParams(type=11,
                                     nChan=1,
                                     wIn=128,
                                     hIn=128,
                                     batchSize=1,
                                     numJoints=14,
                                     nDims=3)
    poseNet = PoseRegNet(numpy.random.RandomState(23455),
                         cfgParams=poseNetParams)
    poseNet.load("./NYU_network_prior.pkl")
    # comrefNetParams = ScaleNetParams(type=1, nChan=1, wIn=128, hIn=128, batchSize=1, resizeFactor=2, numJoints=1, nDims=3)
    # comrefNet = ScaleNet(numpy.random.RandomState(23455), cfgParams=comrefNetParams)
    # comrefNet.load("./net_ICVL_COM.pkl")
    comrefNetParams = ScaleNetParams(type=1,
                                     nChan=1,
                                     wIn=128,
                                     hIn=128,
                                     batchSize=1,
                                     resizeFactor=2,