Example #1
0
File: train.py Project: zgq91/BB8
    def regressor(self):
        """Build the pose-regression network and its trainer.

        Reads the hyper-parameters stored on ``self``, queries the data
        dimensions from ``self.create_training.get_dim()``, creates
        ``self.network`` (loading VGG weights when ``network_model`` is 1) and
        ``self.trainer``, then calls ``self.create_training.pre_create_data()``.

        ``steps`` and ``scales`` may each be empty, a whitespace-separated
        string, a single number, or a list.

        Raises:
            ValueError: if ``steps`` and ``scales`` have a different number of
                entries.
        """

        def _to_list(value, cast):
            # Normalise a schedule setting to a list of cast entries:
            # None/''/[] -> [], "a b" -> [a, b], scalar -> [scalar].
            if value is None:
                return []
            if hasattr(value, 'split'):
                value = value.split()
            elif not isinstance(value, (list, tuple)):
                value = [value]
            return [cast(v) for v in value]

        regressor_net_type = int(self.network_model)
        batch_size = int(self.batch_size)
        learning_rate = float(self.learning_rate)
        optimizer = self.optimizer

        # Compare the number of schedule entries, not the length of the raw
        # strings (len() of a string counts characters).
        steps = _to_list(self.steps, int)
        scales = _to_list(self.scales, float)
        if len(steps) != len(scales):
            raise ValueError(
                'steps and scales must have the same number of entries '
                '({0} != {1})'.format(len(steps), len(scales)))

        n_chan, h_in, w_in, output_dim = self.create_training.get_dim()
        assert n_chan is not None, 'get_dim returned n_chan=None'
        assert h_in is not None, 'get_dim returned h_in=None'
        assert w_in is not None, 'get_dim returned w_in=None'
        assert output_dim is not None, 'get_dim returned output_dim=None'
        nb_training = int(self.nb_training)

        # Fixed seed so network initialisation is reproducible between runs.
        rng = np.random.RandomState(23455)
        theano.config.exception_verbosity = 'high'

        net_params = PoseRegNetParams(type=regressor_net_type,
                                      n_chan=n_chan,
                                      w_in=w_in,
                                      h_in=h_in,
                                      batchSize=batch_size,
                                      output_dim=output_dim)

        self.network = PoseRegNet(rng, cfgParams=net_params)
        if regressor_net_type == 1:
            # Model 1 is the VGG-based architecture; start from VGG weights.
            self.network.load_vgg()
        print(self.network)

        training_params = PoseRegNetTrainingParams()
        training_params.batch_size = batch_size
        training_params.learning_rate = learning_rate
        # Only override the trainer's schedule when one was configured.
        if scales:
            training_params.learning_rate_scales = scales
        if steps:
            training_params.learning_rate_steps = steps
        training_params.optimizer = optimizer

        self.trainer = PoseRegNetTrainer(self.network, training_params, rng)

        self.trainer.setup_para(nb_training,
                                n_chan,
                                h_in,
                                w_in,
                                output_dim,
                                type='regressor')
        self.create_training.pre_create_data()
    batchSize = 128
    poseNetParams = PoseRegNetParams(type=0,
                                     nChan=nChannels,
                                     wIn=imgSizeW,
                                     hIn=imgSizeH,
                                     batchSize=batchSize,
                                     numJoints=1,
                                     nDims=train_gt3D_embed.shape[1])
    poseNet = PoseRegNet(rng, cfgParams=poseNetParams)

    poseNetTrainerParams = PoseRegNetTrainerParams()
    poseNetTrainerParams.batch_size = batchSize
    poseNetTrainerParams.learning_rate = 0.01

    print("setup trainer")
    poseNetTrainer = PoseRegNetTrainer(poseNet, poseNetTrainerParams, rng)
    poseNetTrainer.setData(train_data, train_gt3D_embed, val_data,
                           val_gt3D_embed)
    poseNetTrainer.compileFunctions(compileDebugFcts=False)

    ###################################################################
    #
    # TRAIN
    nEpochs = 100
    train_res = poseNetTrainer.train(n_epochs=nEpochs, storeFilters=True)
    train_costs = train_res[0]
    wvals = train_res[1]
    val_errs = train_res[2]

    ###################################################################
    # TEST
Example #3
0
            'di':
            di,
            'aug_modes':
            aug_modes,
            'hd':
            HandDetector(train_data[0, 0].copy(),
                         abs(di.fx),
                         abs(di.fy),
                         importer=di),
            'proj':
            pca
        }
    }

    print("setup trainer")
    poseNetTrainer = PoseRegNetTrainer(poseNet, poseNetTrainerParams, rng,
                                       './eval/' + eval_prefix)
    poseNetTrainer.setData(train_data, train_gt3D_embed, val_data,
                           val_gt3D_embed)
    poseNetTrainer.addStaticData({'val_data_y3D': val_gt3D})
    poseNetTrainer.addStaticData({
        'pca_data': pca.components_,
        'mean_data': pca.mean_
    })
    poseNetTrainer.addManagedData({
        'train_data_cube': train_data_cube,
        'train_data_com': train_data_com,
        'train_gt3Dcrop': train_gt3Dcrop
    })
    poseNetTrainer.compileFunctions(compileDebugFcts=False)

    ###################################################################
Example #4
0
File: train.py Project: zgq91/BB8
class Network:
    """Configuration holder and driver for training a BB8 pose regressor.

    The attributes set in ``__init__`` are defaults. ``setup_from_config``
    may overwrite any of them from a YAML file (one attribute per top-level
    key). Numeric settings are coerced with ``int``/``float`` where they are
    used, so config values may also be given as strings.
    """

    def __init__(self):
        self.type = 1  # 1 = Regressor (the only type print_type accepts)
        self.network_model = '0'  # 0 = Tiny BB8, 1 = BB8 - VGG arch.

        self.batch_size = 128
        self.optimizer = 'MOMENTUM'
        self.learning_rate = 0.001
        # Learning-rate schedule: empty, a whitespace-separated string, a
        # single number or a list; both must have the same number of entries.
        self.steps = []
        self.scales = []
        self.nb_epoch = 300
        # Number of training samples; must be set (e.g. from the config file)
        # before regressor() runs.
        self.nb_training = None

        self.network = None
        self.trainer = None
        self.network_name = None
        self.save_path = './'

        self.validation_size = 5000
        # Path whose last component names the training-set module.
        self.train_set_para = None
        # Module imported from train_set_para by update().
        self.create_training = None

        self.config = None
        self.nb_process = 10

    def setup_from_config(self):
        """Overwrite attributes from the YAML file at ``self.config``.

        Each top-level key of the YAML mapping becomes an attribute of the
        same name. This does nothing when ``self.config`` is None. An empty
        file is treated as an empty mapping.
        """
        if self.config is None:
            return
        with open(self.config, 'r') as f:
            # safe_load: a config only needs plain scalars/lists. yaml.load
            # without an explicit Loader can build arbitrary objects from the
            # file and is rejected outright by PyYAML >= 6.
            config = yaml.safe_load(f) or {}
        for key, value in config.items():
            print('set {0} to {1}'.format(key, value))
            setattr(self, key, value)

    def update(self):
        """Load the config, import the training-set module and build the net.

        The last path component of ``train_set_para`` is imported as a module
        (an optional ``.py`` suffix is ignored). Its directory is prepended to
        ``sys.path`` first. The module is expected to provide ``init``,
        ``pre_create_data``, ``create_data`` and ``get_dim``.

        Raises:
            ValueError: if ``train_set_para`` is not set.
        """
        self.setup_from_config()
        self.print_type()

        if not self.train_set_para:
            raise ValueError(
                'train_set_para must be set to the training-set module path')
        # rpartition also handles a bare module name: the directory is then
        # '', which sys.path treats as the current directory.
        module_dir, _, module_name = self.train_set_para.rpartition('/')
        if module_name.endswith('.py'):
            module_name = module_name[:-3]
        sys.path.insert(0, module_dir)
        self.create_training = __import__(
            module_name,
            fromlist=['init', 'pre_create_data', 'create_data', 'get_dim'])
        self.create_training.init()

        self.regressor()

    @staticmethod
    def _parse_list(value, cast):
        """Return ``value`` as a list of ``cast``-converted entries.

        Accepts None or an empty value (-> ``[]``), a whitespace-separated
        string, a single scalar, or a list/tuple.
        """
        if value is None:
            return []
        if hasattr(value, 'split'):
            value = value.split()
        elif not isinstance(value, (list, tuple)):
            value = [value]
        return [cast(v) for v in value]

    def regressor(self):
        """Build the pose-regression network and its trainer.

        Reads the hyper-parameters stored on ``self`` and queries the data
        dimensions from ``self.create_training.get_dim()``. It then creates
        ``self.network`` (loading VGG weights when ``network_model`` is 1) and
        ``self.trainer``, and finally calls
        ``self.create_training.pre_create_data()``.

        Raises:
            ValueError: if ``steps`` and ``scales`` have a different number of
                entries, or ``nb_training`` is unset.
        """
        regressor_net_type = int(self.network_model)
        batch_size = int(self.batch_size)
        learning_rate = float(self.learning_rate)
        optimizer = self.optimizer

        # Compare the number of schedule entries, not the length of the raw
        # strings (len() of a string counts characters).
        steps = self._parse_list(self.steps, int)
        scales = self._parse_list(self.scales, float)
        if len(steps) != len(scales):
            raise ValueError(
                'steps and scales must have the same number of entries '
                '({0} != {1})'.format(len(steps), len(scales)))
        if self.nb_training is None:
            raise ValueError(
                'nb_training must be set before building the regressor')
        nb_training = int(self.nb_training)

        n_chan, h_in, w_in, output_dim = self.create_training.get_dim()
        assert n_chan is not None, 'get_dim returned n_chan=None'
        assert h_in is not None, 'get_dim returned h_in=None'
        assert w_in is not None, 'get_dim returned w_in=None'
        assert output_dim is not None, 'get_dim returned output_dim=None'

        # Fixed seed so network initialisation is reproducible between runs.
        rng = np.random.RandomState(23455)
        theano.config.exception_verbosity = 'high'

        net_params = PoseRegNetParams(type=regressor_net_type,
                                      n_chan=n_chan,
                                      w_in=w_in,
                                      h_in=h_in,
                                      batchSize=batch_size,
                                      output_dim=output_dim)

        self.network = PoseRegNet(rng, cfgParams=net_params)
        if regressor_net_type == 1:
            # Model 1 is the VGG-based architecture; start from VGG weights.
            self.network.load_vgg()
        print(self.network)

        training_params = PoseRegNetTrainingParams()
        training_params.batch_size = batch_size
        training_params.learning_rate = learning_rate
        # Only override the trainer's schedule when one was configured.
        if scales:
            training_params.learning_rate_scales = scales
        if steps:
            training_params.learning_rate_steps = steps
        training_params.optimizer = optimizer

        self.trainer = PoseRegNetTrainer(self.network, training_params, rng)

        self.trainer.setup_para(nb_training,
                                n_chan,
                                h_in,
                                w_in,
                                output_dim,
                                type='regressor')
        self.create_training.pre_create_data()

    def train(self):
        """Train the network for ``nb_epoch`` epochs, storing filters."""
        self.trainer.train(n_epochs=int(self.nb_epoch), storeFilters=True)

    def train_para(self):
        """Train using the module's ``create_data`` in ``nb_process`` workers."""
        self.trainer.train_para(int(self.nb_epoch),
                                self.create_training.create_data,
                                int(self.nb_process))

    def _default_network_name(self):
        """Return a name such as ``'network_model0_epoch300'``."""
        return 'network_model{0}_epoch{1}'.format(self.network_model,
                                                  self.nb_epoch)

    def save(self):
        """Write the weights (``.weight``) and pickled config (``.cfg``).

        Both files go to ``save_path``. When ``network_name`` is unset, a
        default name is built from ``network_model`` and ``nb_epoch`` and
        stored on ``self``.
        """
        if self.network_name is None:
            self.network_name = self._default_network_name()

        self.network.save(join(self.save_path, self.network_name + ".weight"))
        # 'with' guarantees the file is closed even if pickling fails.
        with open(join(self.save_path, self.network_name + ".cfg"), 'wb') as f:
            cPickle.dump(self.network.cfgParams,
                         f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    def print_type(self):
        """Print a banner for the configured network type.

        Raises:
            NotImplementedError: for any ``type`` other than 1 (regressor).
        """
        if int(self.type) != 1:
            # raise rather than assert so the check survives ``python -O``.
            raise NotImplementedError('It is not implemented')
        print('*****************************************************')
        print('*                    REGRESSOR                      *')
        print('*****************************************************')
Example #5
0
                'di':
                di,
                'aug_modes':
                aug_modes,
                'hd':
                HandDetector(train_data[0, 0].copy(),
                             abs(di.fx),
                             abs(di.fy),
                             importer=di),
                'proj':
                pca
            }
        }

        print("setup trainer")
        poseNetTrainer = PoseRegNetTrainer(poseNet, poseNetTrainerParams, rng,
                                           './eval/' + eval_prefix)
        poseNetTrainer.setData(train_data, train_gt3D_embed, val_data,
                               val_gt3D_embed)
        poseNetTrainer.addManagedData({
            'train_data_cube': train_data_cube,
            'train_data_com': train_data_com,
            'train_data_M': train_data_M,
            'train_gt3Dcrop': train_gt3Dcrop
        })
        poseNetTrainer.compileFunctions(compileDebugFcts=False)

        ###################################################################
        # TRAIN
        train_res = poseNetTrainer.train(n_epochs=100)
        train_costs = train_res[0]
        val_errs = train_res[2]