Example #1
0
File: main.py Project: romaad/msae
def train(configPath, name):
    """Build the model described by section ``name`` of ``configPath`` and train it.

    The configuration file is read with ``configparser`` using
    ``ExtendedInterpolation``. Section ``name`` must contain a ``model`` option
    naming one of the supported model types: ``ae``, ``lae``, ``pae``, ``sae``
    or ``msae``. The matching class is constructed as ``Cls(config, name)`` and
    its ``train()`` method is called.

    :param configPath: path of the INI-style configuration file.
    :param name: section of the configuration file that describes the model.
    :raises ValueError: if the ``model`` option names an unsupported type
        (previously this surfaced as an opaque ``UnboundLocalError``).
    """
    import importlib

    # Only the explicit value "no" is reported as CPU; an unset variable,
    # "auto", or any other value is reported as GPU.
    useGpu = os.environ.get('GNUMPY_USE_GPU', 'auto')
    mode = "cpu" if useGpu == "no" else "gpu"

    # Single-argument print(...) prints the same text on Python 2 and 3.
    print('========================================================')
    print('train %s' % name)
    print("the program is on %s" % mode)
    print('=======================================================')

    config = configparser.ConfigParser(
        interpolation=configparser.ExtendedInterpolation())
    config.read(configPath)
    model_name = config.get(name, 'model')

    # model type -> (module, class). Modules are imported only once selected,
    # so only the chosen model's dependencies need to be importable.
    model_classes = {
        "ae": ("ae", "AE"),
        "lae": ("lae", "LAE"),
        "pae": ("pae", "PAE"),
        "sae": ("sae", "SAE"),
        "msae": ("msae", "MSAE"),
    }
    if model_name not in model_classes:
        raise ValueError(
            "unknown model type %r in section [%s] of %s; expected one of: %s"
            % (model_name, name, configPath,
               ", ".join(sorted(model_classes))))

    module_name, class_name = model_classes[model_name]
    model_cls = getattr(importlib.import_module(module_name), class_name)
    model = model_cls(config, name)
    model.train()
Example #2
0
 def createsae(self, prefix, saeName):
     """Return a stacked autoencoder for option ``saeName`` of this section.

     When ``self.config`` has option ``saeName`` in section ``self.name``:
     - its value is read with ``self.readField`` and passed to ``self.loadModel``;
     - unless this section's ``reset_hyperparam`` value is the string
       ``"False"``, every entry of ``sae.ae`` after the first has
       ``resetHyperParam(config, value)`` called on it;
     - the loaded model is returned.

     Otherwise a new ``SAE(config, section, prefix=prefix)`` is returned.
     """
     # Guard clause: nothing configured under saeName -> build a fresh SAE.
     if not self.config.has_option(self.name, saeName):
         return SAE(self.config, self.name, prefix=prefix)

     loaded = self.loadModel(
         self.config, self.readField(self.config, self.name, saeName))
     resetSpec = self.readField(self.config, self.name, "reset_hyperparam")
     # The literal string "False" disables the reset; any other value is
     # forwarded to each layer after the first.
     if resetSpec == "False":
         return loaded
     for layer in loaded.ae[1:]:
         layer.resetHyperParam(self.config, resetSpec)
     return loaded
Example #3
0
 def __init__(self):
     """Set up the action/observation spaces and the SAE predictor.

     Builds a ``spaces.Box`` action space and a ``spaces.Tuple`` observation
     space, creates ``SAE()`` as the predictor, calls ``self.load_sae()``,
     then caches the validation and test splits returned by the predictor's
     ``dp.get_data``. It also initialises empty lists for predicted and real
     values.
     """
     super(TrafficPrediction, self).__init__()

     # Width shared by both spaces and by self.link below.
     n_links = 257

     self.action_space = spaces.Box(
         low=-0.05, high=0.05, shape=(n_links,), dtype=np.float32)
     # NOTE(review): per the original "(last_state_point,)" comment, the second
     # sub-space appears to hold the last state point; the meaning of the 6 rows
     # in the first is not shown here -- confirm.
     window_box = spaces.Box(
         low=0., high=1., shape=(6, n_links), dtype=np.float32)
     point_box = spaces.Box(
         low=0., high=1., shape=(n_links,), dtype=np.float32)
     self.observation_space = spaces.Tuple([window_box, point_box])

     self.delta = 1.
     self.state = None

     # NOTE(review): SEKNN / KNN were listed as alternative predictors, and
     # load_sae() is called right after -- confirm whether it replaces this
     # default SAE() instance.
     self.predictor = SAE()
     self.pointer = 0
     self.load_sae()

     self.link = n_links
     # Read dp only after load_sae(), which may have changed self.predictor.
     dp = self.predictor.dp
     self.predstep = dp.predstep
     self.maxv = np.asarray(dp.maxv)
     self.valiX, self.valiY, self.valiY_nofilt = dp.get_data(data_type='vali')
     self.testX, self.testY, self.testY_nofilt = dp.get_data(data_type='test')

     # Start empty; presumably appended to by other methods of this class.
     self.set_predY, self.set_predY_, self.set_realY = [], [], []
Example #4
0
from sae import SAE
import torch

# Train a stacked autoencoder (SAE) on pre-saved datasets, run sae.perform on
# the training and test sets, and save the trained model object.
#
# NOTE(review): torch.load unpickles these files, which can execute arbitrary
# code -- only load files you produced yourself. If they contain plain tensors,
# consider weights_only=True on a torch version that supports it (TODO confirm
# their contents first).
training_set = torch.load('./training_set.pkl')
test_set = torch.load('./test_set.pkl')
# presumably 3787 is the input width -- confirm against the SAE constructor.
sae = SAE(3787, encoder_input=40, decoder_input=40)
# The calls below presumably append layers in call order, so their order
# defines the architecture. ("add_hiden_layer" is the method name as spelled
# by the sae module.)
sae.add_hiden_layer(20)
sae.add_dropout(0.2)
sae.add_hiden_layer(40)
sae.compile(optimizer='adam')
# presumably the second argument is the number of epochs -- confirm in SAE.fit.
sae.fit(training_set, 5)
sae.perform(training_set, test_set)
# Pickles the whole model object; loading it back requires the sae module to
# be importable.
torch.save(sae, 'model.pkl')