Example #1
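Restore a trained character-level RNN from its checkpoint directory and print text sampled from it.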
def infer(self):
    args = self.args
    # reload the hyperparameters and vocabulary saved at training time
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = RnnModel(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # restore the trained weights, then sample from the model
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(
                model.infer(sess, chars, vocab, args.n, args.prime,
                            args.sample).encode('utf-8'))
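The method pulls everything it needs from self.args. A minimal command-line driver might look like the sketch below; only the attribute names (save_dir, n, prime, sample) come from the method itself, while the defaults and the Sampler class are assumptions.

import argparse

# Hypothetical driver for infer() above. The owning class is not part of the
# snippet, so 'Sampler' and the flag defaults here are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='save',
                    help='directory holding config.pkl, chars_vocab.pkl and checkpoints')
parser.add_argument('-n', type=int, default=500,
                    help='number of characters to sample')
parser.add_argument('--prime', type=str, default=' ',
                    help='priming text fed to the network before sampling')
parser.add_argument('--sample', type=int, default=1,
                    help='sampling mode forwarded to model.infer()')
args = parser.parse_args()
# Sampler(args).infer()  # hypothetical class exposing the method above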
Example #2
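A Flask route that loads a named model for later prediction requests and reports the outcome as a one-character status code.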
def setup():
    '''
    API for setting up the model for prediction.
    CURL format:
        curl -X POST -d '{"model_name": "rlstm"}' -H 'Content-Type: application/json'
        'https://127.0.0.1:5000/model_setup'
    JSON Parameter:
        'model_name':   [COMPULSORY] name of the model to activate for the predict route
    :return:    '0':    success
                'M':    missing compulsory parameter in JSON
                'N':    model not found
                'i':    internal error when loading
                'R':    another session is already running
    '''
    global loaded_model
    global running
    global req_logger
    global model_logger
    if request.method == 'POST':
        req_logger.info('request for setup is {}'.format(request.data))
        if running == 1:
            return 'R'
        running = 1
        arg_json = request.get_json(silent=True) or {}
        if 'model_name' in arg_json:
            model_name = arg_json['model_name']
            model_config = ModelInfo(model_name)
            try:
                loaded_model = RnnModel(model_config)
                loaded_model.load()
            except (KeyError, FileNotFoundError):
                running = 0
                return 'N'
            except Exception:
                running = 0
                return 'i'
        else:
            running = 0
            return 'M'
        return '0'
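For reference, a client-side call matching the curl line in the docstring might look like this; a minimal sketch, assuming the service runs at 127.0.0.1:5000 as documented and that 'rlstm' is a deployed model name.

import requests

# Client sketch for the /model_setup route documented above; host, port,
# and model name follow the docstring and are assumptions.
resp = requests.post(
    'https://127.0.0.1:5000/model_setup',
    json={'model_name': 'rlstm'},  # sent as application/json
    verify=False,                  # the docstring's example uses https on localhost
)
# The route answers with a one-character status code: '0' on success,
# 'M', 'N', 'i', or 'R' on the error paths described in the docstring.
print(resp.text)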
Example #3
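An entry point that builds maze environments and, for the 'train' task, runs supervised training with an RNN model and a simple model.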
def main():
    seed()
    meta_featurizer = lambda env: MetaFeatureEnv(env)
    cache = get_cache()
    template_env = MazeEnv(0)
    template_meta_env = meta_featurizer(template_env)
    if FLAGS.run is None:
        raise Exception("'run' flag must be specified")
    elif FLAGS.run == 'train':
        model1 = RnnModel(template_meta_env.featurizer, template_env.n_actions)
        model2 = SimpleModel(template_env.featurizer, template_env.n_actions)
        learning.train_sup(model1, model2,
                           lambda: MazeEnv(np.random.randint(2000)),
                           lambda: MazeEnv(2000 + np.random.randint(500)),
                           meta_featurizer, cache / 'base.maze.txt')
    else:
        raise Exception('no such task: %s' % FLAGS.run)
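FLAGS is read but never defined in the snippet. A sketch of the missing flag plumbing, assuming absl-style flags (older tf.app.flags code would look much the same), might be:

from absl import app, flags

# Sketch of the flag definition the snippet relies on; 'absl' is an
# assumption, since the snippet only shows FLAGS.run being read.
FLAGS = flags.FLAGS
flags.DEFINE_string('run', None, "task to run; only 'train' is handled above")

if __name__ == '__main__':
    app.run(lambda argv: main())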
Example #4
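Load a saved character-level PyTorch RNN, prime it with a seed string, and sample n further characters.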
def main(n, read_model):
    state = torch.load(read_model)
    word2int = state['word2int']
    int2word = state['int2word']
    input_size = len(int2word)
    hidden_size = state['hidden_size']
    num_layers = state['num_layers']
    model = RnnModel(input_size, hidden_size, num_layers)
    model.load_state_dict(state['state_dict'])

    initial = [word2int[c] for c in "anna"]
    h = model.init_hidden(1)
    model.eval()
    for c in initial:
        print(int2word[c], end='')
        char, h = predict(model, c, h, 3)

    for i in range(n):
        char, h = predict(model, char, h, 3)
        print(int2word[char], end='')
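predict is defined elsewhere. Judging from its call sites, predict(model, char_index, hidden, top_k) returns the next character index and the updated hidden state; a sketch consistent with that might be the following, where model.input_size is an assumed attribute holding len(int2word) as in main().

import numpy as np
import torch
import torch.nn.functional as F

def predict(model, char, h, top_k):
    # One-hot encode the single input character: (batch=1, seq=1, vocab).
    # model.input_size is an assumption matching len(int2word) in main().
    x = torch.zeros(1, 1, model.input_size)
    x[0, 0, char] = 1.0
    with torch.no_grad():
        output, h = model(x, h)
    # Keep only the top_k most likely characters and sample among them.
    probs = F.softmax(output.squeeze(), dim=-1)
    top_p, top_ix = probs.topk(top_k)
    top_p = top_p.numpy()
    next_char = np.random.choice(top_ix.numpy(), p=top_p / top_p.sum())
    return int(next_char), h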
Example #5
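Train a character-level PyTorch RNN on a text file and save the weights together with the vocabulary lookup tables.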
def main(epochs, batch_size, seq_length, hidden_size, num_layers, text_file, write_model):
    text = read_text(text_file)
    vocabulary = set(text)

    print("Text file #words: {}".format(len(text)))
    print("Dictionary size: {}".format(len(vocabulary)))

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    int2word, word2int = create_lookup(vocabulary)
    encoded_text = np.array([word2int[t] for t in text])

    model = RnnModel(len(vocabulary), hidden_size, num_layers=num_layers).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    pbar = tqdm(range(epochs))
    for epoch in pbar:
        h = tuple([state.to(device) for state in model.init_hidden(batch_size)])
        for x, y in generate_batches(encoded_text, batch_size, seq_length):
            optimizer.zero_grad()
            x = one_hot_encoder(x, len(vocabulary))
            x, y = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)

            h = tuple([state.clone().detach() for state in h])

            output, h = model(x, h)

            loss = criterion(output, y.view(batch_size * seq_length))
            loss.backward()
            optimizer.step()

        pbar.set_description("Loss: {:0.6f}".format(loss.item()))

    state = {
        'state_dict': model.state_dict(),
        'word2int': word2int,
        'int2word': int2word,
        'hidden_size': hidden_size,
        'num_layers': num_layers
    }

    torch.save(state, write_model)
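The helpers create_lookup, one_hot_encoder, and generate_batches are not shown. Sketches consistent with their call sites might be the following; the real implementations may differ.

import numpy as np

def create_lookup(vocabulary):
    # int2word maps index -> token, word2int maps token -> index
    int2word = dict(enumerate(sorted(vocabulary)))
    word2int = {w: i for i, w in int2word.items()}
    return int2word, word2int

def one_hot_encoder(x, n_labels):
    # (batch, seq) int array -> (batch, seq, n_labels) float32 one-hot
    one_hot = np.zeros((x.size, n_labels), dtype=np.float32)
    one_hot[np.arange(x.size), x.flatten()] = 1.0
    return one_hot.reshape(*x.shape, n_labels)

def generate_batches(encoded_text, batch_size, seq_length):
    # Yield (x, y) pairs of shape (batch_size, seq_length) where y is x
    # shifted left by one character.
    n_chunks = len(encoded_text) // (batch_size * seq_length)
    data = encoded_text[:n_chunks * batch_size * seq_length].reshape(batch_size, -1)
    for i in range(0, data.shape[1] - seq_length, seq_length):
        x = data[:, i:i + seq_length]
        y = data[:, i + 1:i + 1 + seq_length]
        # copies keep the arrays contiguous for torch.from_numpy / .view
        yield x.copy(), y.copy()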
Example #6
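A Flask route that accepts an uploaded training file, fits the configured RNN model on it, and returns validation metrics as JSON.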
def train():
    '''
    API for training the model; should be called after configuration.
    CURL format:
        curl -X POST -F file=@/path/to/encoded/data 'https://127.0.0.1:5000/train'
    :return: a jsonified result
        {
            're_startup': 'xxx',
            're_total': 'xxx',
            're_row': 'xxx',
            're_mem': 'xxx',
            'max_startup': 'xxx',
            'max_total': 'xxx',
            'max_row': 'xxx',
            'max_mem': 'xxx',
            'converged': '0/1',
            'feature_length': '???'
        }
        Errors:
            'M': missing compulsory parameter in the request
            'R': a session is already running
    '''
    global running
    global model_config
    global req_logger
    global model_logger
    if request.method == 'POST':
        if running == 1:
            return 'R'
        running = 1
        # save file
        if 'file' in request.files:
            f = request.files['file']
        else:
            running = 0
            return 'M'
        base_path = os.path.dirname(__file__)
        fname = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '-' + secure_filename(f.filename)
        file_path = os.path.join(base_path, settings.PATH_UPLOAD, fname)
        f.save(file_path)
        # trigger training
        try:
            model = RnnModel(model_config)
            val_re = model.fit(file_path)
        except:
            running = 0
            raise
        re_startup, re_total, re_row, re_mem = -1, -1, -1, -1
        converged = 1
        for v in val_re:
            if v > 2:
                converged = 0
                break
        for i in range(int(model_config.label_length)):
            if model_config.model_targets[i] == 'S':
                re_startup = val_re[i]
            elif model_config.model_targets[i] == 'T':
                re_total = val_re[i]
            elif model_config.model_targets[i] == 'R':
                re_row = val_re[i]
            elif model_config.model_targets[i] == 'M':
                re_mem = val_re[i]
        res = {
            're_startup': re_startup,
            're_total': re_total,
            're_row': re_row,
            're_mem': re_mem,
            'max_startup': float(model_config.max_startup),
            'max_total': float(model_config.max_total),
            'max_row': float(model_config.max_row),
            'max_mem': float(model_config.max_mem),
            'converged': converged,
            'feature_length': int(model_config.feature_length)
        }

        running = 0
        model_logger.info(res)
        return jsonify(res)
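As with the setup route, a client-side upload matching the curl line in the docstring might look like this; the host, port, and file path follow the docstring's example and are assumptions.

import requests

# Client sketch for the /train route documented above.
with open('/path/to/encoded/data', 'rb') as data_file:
    resp = requests.post('https://127.0.0.1:5000/train',
                         files={'file': data_file},
                         verify=False)
# On success the body is the JSON result described in the docstring;
# 'M' and 'R' come back as plain-text error codes.
print(resp.text)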
Example #7
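A TensorFlow training loop for a character-level RNN, with TensorBoard summaries, periodic checkpointing, and optional resumption from a previously saved model.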
def train(self):
    args = self.args
    data_loader = DataLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from a pre-trained model
    if args.init_from is not None:
        # check that all necessary files exist
        assert os.path.isdir(args.init_from), "%s must be a path!" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), \
            "config.pkl file does not exist in path! %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), \
            "chars_vocab.pkl file does not exist in path! %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found!"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint!"

        # open old config and check if the models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s'!" % checkme

        # open saved vocab/dict and check compatibility
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = RnnModel(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
                os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(">> restoring model from: ", ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr,
                               args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                # feed the LSTM state carried over from the previous batch
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # run the train op exactly once per batch, fetching the
                # tensorboard summaries in the same call
                summ, train_loss, state, _ = sess.run(
                    [summaries, model.cost, model.final_state, model.train_op], feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                end = time.time()
                print(">> {}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * data_loader.num_batches + b,
                              args.num_epochs * data_loader.num_batches,
                              e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and
                            b == data_loader.num_batches - 1):
                    # save a checkpoint, always including the final batch
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print(">> model saved to {}".format(checkpoint_path))