Пример #1
0
    def __init__(self, model):
        """Build the encoder, decoder and ``pre`` LSTM stacks from ``model``.

        The encoder stack is made of ``lstm_spatial.LSTM`` layers; the decoder
        and ``pre`` (presumably prediction) stacks are made of ``lstm.LSTM``
        layers. If ``model.timestamp`` is non-empty, weights are restored from
        the HDF5 checkpoint ``<checkpoint_dir>/<name>_<timestamp[-1]>.h5``.

        Args:
          model: configuration object. Fields read here: ``lstm``,
            ``lstm_dec``, ``lstm_pre``, ``dec_seq_length``, ``pre_seq_length``,
            ``dec_conditional``, ``only_occ_predict``, ``timestamp``,
            ``checkpoint_dir`` and ``name``.

        Raises:
          AssertionError: (only when asserts are enabled) if
            ``model.dec_seq_length`` is not positive, or if a conditional
            decoder is requested but the decoder stack reports no inputs.
          Whatever ``h5py.File`` raises if the checkpoint cannot be opened
            for reading.
        """
        self.model_ = model
        self.lstm_stack_enc_ = lstm_spatial.LSTMStack()
        self.lstm_stack_dec_ = lstm.LSTMStack()
        self.lstm_stack_pre_ = lstm.LSTMStack()
        for spec in model.lstm:
            self.lstm_stack_enc_.Add(lstm_spatial.LSTM(spec))
        if model.dec_seq_length > 0:
            for spec in model.lstm_dec:
                self.lstm_stack_dec_.Add(lstm.LSTM(spec))
        if model.pre_seq_length > 0:
            for spec in model.lstm_pre:
                self.lstm_stack_pre_.Add(lstm.LSTM(spec))
        # This variant always requires a decoder.
        assert model.dec_seq_length > 0
        self.is_conditional_dec_ = model.dec_conditional
        if self.is_conditional_dec_ and model.dec_seq_length > 0:
            assert self.lstm_stack_dec_.HasInputs()
        # The following options are hard-coded instead of being read from
        # ``model``; each trailing comment records the config-driven
        # expression the constant replaced.
        self.squash_relu_ = False  # was: model.squash_relu
        self.squash_relu_lambda_ = 0  # was: model.squash_relu_lambda
        self.relu_data_ = False  # was: model.relu_data
        self.binary_data_ = True  # was: model.binary_data or self.squash_relu_
        self.only_occ_predict_ = model.only_occ_predict

        if len(model.timestamp) > 0:
            old_st = model.timestamp[-1]
            ckpt = os.path.join(model.checkpoint_dir,
                                '%s_%s.h5' % (model.name, old_st))
            # Open read-only: h5py < 3.0 defaults to mode 'a', which silently
            # creates an empty file when the checkpoint path is wrong. The
            # ``with`` block also closes the file if a Load() call raises.
            with h5py.File(ckpt, 'r') as f:
                self.lstm_stack_enc_.Load(f)
                if model.dec_seq_length > 0:
                    self.lstm_stack_dec_.Load(f)
                if model.pre_seq_length > 0 and not self.only_occ_predict_:
                    self.lstm_stack_pre_.Load(f)
    def __init__(self, model, board, board_ladv, board_sup,
                 model_file_sup='./data/bk20151009/part_1/'
                                'bvlc_googlenet_quick_iter_231760.caffemodel',
                 solver_file='./data/googlenet_ladv_solver.prototxt',
                 prototxt_file_sup='./data/train_val_quick_grad.prototxt',
                 mean_file='./data/bk20151009/part_1/'
                           'lmdb_casia_full_part1_mean.binaryproto',
                 mean_shape=(1, 128, 128)):
        """Build the LSTM stacks plus the Caffe CNN solver/net used with them.

        Args:
          model: configuration object. Fields read here: ``lstm``,
            ``lstm_dec``, ``lstm_pre``, ``dec_seq_length``, ``pre_seq_length``,
            ``dec_conditional``, ``only_occ_predict``, ``timestamp``,
            ``checkpoint_dir`` and ``name``.
          board: stored as ``board_``; compared with ``board_ladv`` to decide
            whether ``caffe.set_device`` is called.
          board_ladv: stored as ``board_ladv_``; passed to
            ``caffe.set_device`` when equal to ``board``.
          board_sup: stored as ``board_sup_``; not otherwise used here.
          model_file_sup: ``.caffemodel`` weights. They are copied into the
            solver's net and also loaded into ``cnn_net_sup``.
          solver_file: solver prototxt for ``caffe.SGDSolver``.
          prototxt_file_sup: net prototxt for ``cnn_net_sup`` (TRAIN phase).
          mean_file: mean file read with ``read_mean``.
          mean_shape: shape the mean is reshaped to before being stored in
            ``cnn_mean_``.

        Raises:
          AssertionError: (only when asserts are enabled) if
            ``model.dec_seq_length`` is not positive, or if a conditional
            decoder is requested but the decoder stack reports no inputs.
          Whatever ``h5py.File`` raises if the checkpoint cannot be opened
            for reading.
        """
        self.model_ = model
        self.board_ = board
        self.board_ladv_ = board_ladv
        self.board_sup_ = board_sup
        self.lstm_stack_enc_ = lstm_spatial.LSTMStack()
        self.lstm_stack_dec_ = lstm.LSTMStack()
        self.lstm_stack_pre_ = lstm.LSTMStack()

        # Select Caffe's compute mode and device *before* constructing any
        # net. Caffe layers (e.g. data-layer prefetch threads) pick up the
        # mode/device that is current when they are set up. pycaffe also
        # does not apply ``solver_mode`` from the solver prototxt.
        caffe.set_mode_gpu()
        # NOTE(review): set_device is only called when the CNN shares the
        # LSTM board. When they differ, Caffe stays on whatever CUDA device
        # is current. Confirm this is intended rather than ``!=``.
        if self.board_ladv_ == self.board_:
            caffe.set_device(self.board_ladv_)

        # The solver's net is initialised from the pretrained weights. A
        # separate TRAIN-phase net is loaded from the same weights.
        self.cnn_solver = caffe.SGDSolver(solver_file)
        self.cnn_solver.net.copy_from(model_file_sup)
        self.cnn_net_sup = caffe.Net(prototxt_file_sup, model_file_sup,
                                     caffe.TRAIN)
        self.cnn_mean_ = read_mean(mean_file).reshape(mean_shape)

        for spec in model.lstm:
            self.lstm_stack_enc_.Add(lstm_spatial.LSTM(spec))
        if model.dec_seq_length > 0:
            for spec in model.lstm_dec:
                self.lstm_stack_dec_.Add(lstm.LSTM(spec))
        if model.pre_seq_length > 0:
            for spec in model.lstm_pre:
                self.lstm_stack_pre_.Add(lstm.LSTM(spec))
        # This variant always requires a decoder.
        assert model.dec_seq_length > 0
        self.is_conditional_dec_ = model.dec_conditional
        if self.is_conditional_dec_ and model.dec_seq_length > 0:
            assert self.lstm_stack_dec_.HasInputs()
        # The following options are hard-coded instead of being read from
        # ``model``; each trailing comment records the config-driven
        # expression the constant replaced.
        self.squash_relu_ = False  # was: model.squash_relu
        self.squash_relu_lambda_ = 0  # was: model.squash_relu_lambda
        self.relu_data_ = False  # was: model.relu_data
        self.binary_data_ = True  # was: model.binary_data or self.squash_relu_
        self.only_occ_predict_ = model.only_occ_predict

        if len(model.timestamp) > 0:
            old_st = model.timestamp[-1]
            ckpt = os.path.join(model.checkpoint_dir,
                                '%s_%s.h5' % (model.name, old_st))
            # Open read-only: h5py < 3.0 defaults to mode 'a', which silently
            # creates an empty file when the checkpoint path is wrong. The
            # ``with`` block also closes the file if a Load() call raises.
            with h5py.File(ckpt, 'r') as f:
                self.lstm_stack_enc_.Load(f)
                if model.dec_seq_length > 0:
                    self.lstm_stack_dec_.Load(f)
                if model.pre_seq_length > 0 and not self.only_occ_predict_:
                    self.lstm_stack_pre_.Load(f)
Пример #3
0
    def __init__(self, model):
        """Build encoder, decoder and future-predictor LSTM stacks from ``model``.

        If ``model.timestamp`` is non-empty, weights for all three stacks are
        restored from the HDF5 checkpoint
        ``<checkpoint_dir>/<name>_<timestamp[-1]>.h5``.

        Args:
          model: configuration object. Fields read here: ``lstm``,
            ``lstm_dec``, ``lstm_future``, ``dec_seq_length``,
            ``future_seq_length``, ``decoder_copy_init_state``,
            ``future_copy_init_state``, ``dec_conditional``,
            ``future_conditional``, ``squash_relu``, ``squash_relu_lambda``,
            ``binary_data``, ``relu_data``, ``timestamp``,
            ``checkpoint_dir`` and ``name``.

        Raises:
          AssertionError: (only when asserts are enabled) if neither a
            decoder nor a future predictor is configured, or if a conditional
            branch is requested but its stack reports no inputs.
          Whatever ``h5py.File`` raises if the checkpoint cannot be opened
            for reading.
        """
        # Keep the model configuration alongside the global configuration.
        self.model_ = model

        # Stacks for the encoder, decoder and future predictor.
        self.lstm_stack_enc_ = lstm.LSTMStack()
        self.lstm_stack_dec_ = lstm.LSTMStack()
        self.lstm_stack_fut_ = lstm.LSTMStack()

        self.decoder_copy_init_state_ = model.decoder_copy_init_state
        self.future_copy_init_state_ = model.future_copy_init_state

        # One LSTM per layer spec. Decoder and future stacks are only
        # populated when their sequence length is positive.
        for spec in model.lstm:
            self.lstm_stack_enc_.Add(lstm.LSTM(spec))
        if model.dec_seq_length > 0:
            for spec in model.lstm_dec:
                self.lstm_stack_dec_.Add(lstm.LSTM(spec))
        if model.future_seq_length > 0:
            for spec in model.lstm_future:
                self.lstm_stack_fut_.Add(lstm.LSTM(spec))

        # At least one of decoder / future predictor must be enabled.
        assert model.dec_seq_length > 0 or model.future_seq_length > 0
        # Whether decoder and future predictor are conditional on inputs.
        self.is_conditional_dec_ = model.dec_conditional
        self.is_conditional_fut_ = model.future_conditional

        if self.is_conditional_dec_ and model.dec_seq_length > 0:
            assert self.lstm_stack_dec_.HasInputs()
        if self.is_conditional_fut_ and model.future_seq_length > 0:
            assert self.lstm_stack_fut_.HasInputs()

        self.squash_relu_ = model.squash_relu
        # squash_relu also implies binary data.
        self.binary_data_ = model.binary_data or model.squash_relu
        self.squash_relu_lambda_ = model.squash_relu_lambda
        self.relu_data_ = model.relu_data

        # Restore weights from the last checkpoint, if any.
        if len(model.timestamp) > 0:
            old_st = model.timestamp[-1]
            ckpt = os.path.join(model.checkpoint_dir,
                                '%s_%s.h5' % (model.name, old_st))
            # Open read-only: h5py < 3.0 defaults to mode 'a', which silently
            # creates an empty file when the checkpoint path is wrong. The
            # ``with`` block also closes the file if a Load() call raises.
            with h5py.File(ckpt, 'r') as f:
                self.lstm_stack_enc_.Load(f)
                # NOTE(review): decoder/future stacks are loaded even when
                # their sequence length is 0 (the stack is then empty);
                # presumably a no-op -- confirm LSTMStack.Load on an empty
                # stack.
                self.lstm_stack_dec_.Load(f)
                self.lstm_stack_fut_.Load(f)
Пример #4
0
 def __init__(self, model):
   """Build a single LSTM stack from ``model`` and optionally restore weights.

   If ``model.timestamp`` is non-empty, weights are restored from the HDF5
   checkpoint ``<checkpoint_dir>/<name>_<timestamp[-1]>.h5``.

   Args:
     model: configuration object. Fields read here: ``lstm``,
       ``squash_relu``, ``squash_relu_lambda``, ``timestamp``,
       ``checkpoint_dir`` and ``name``.

   Raises:
     Whatever ``h5py.File`` raises if the checkpoint cannot be opened for
     reading.
   """
   self.model_ = model
   self.lstm_stack_ = lstm.LSTMStack()
   for spec in model.lstm:
     self.lstm_stack_.Add(lstm.LSTM(spec))
   self.squash_relu_ = model.squash_relu
   self.squash_relu_lambda_ = model.squash_relu_lambda

   if len(model.timestamp) > 0:
     old_st = model.timestamp[-1]
     ckpt = os.path.join(model.checkpoint_dir, '%s_%s.h5' % (model.name, old_st))
     # Open read-only: h5py < 3.0 defaults to mode 'a', which silently creates
     # an empty file when the checkpoint path is wrong. The ``with`` block also
     # closes the file if Load() raises.
     with h5py.File(ckpt, 'r') as f:
       self.lstm_stack_.Load(f)