Пример #1
0
    def test_save_load(self):
        """Round-trip the fixture session through save and load.

        Saves ``self.sess`` under ``models/test_m1`` and reads it back with
        ``Session.load``. As written, the test only verifies that neither call
        raises.

        NOTE(review): the reloaded session is never inspected. Consider
        asserting on it once the Session API to compare against is settled.
        """
        model_file = "models/test_m1"
        self.sess.save(model_file)

        restored = Session.load(model_file)
Пример #2
0
def train(epochs, ds, model=None, batch_size=64, record_epochs=1):
    """Train a model in mini-batches and record validation history.

    If ``model`` is None, a session is loaded via ``Session.load(model_path)``.
    ``model_path`` is resolved from module scope and is not defined in this
    function. Otherwise a new ``Session`` is built around ``model`` with an
    MSE loss and a fixed optimizer. After training, the session is saved to
    ``model_path``.

    Args:
        epochs: Number of passes over the training data.
        ds: Mapping with keys ``'train_x'``, ``'train_y'``, ``'test_x'`` and
            ``'test_y'``. The arrays must support ``.shape[0]`` and slicing
            along the first axis (presumably NumPy arrays; confirm with
            callers).
        model: Model to wrap in a new Session, or None to load a saved one.
        batch_size: Samples per training batch. Only whole batches are used;
            trailing samples that do not fill a batch are skipped.
        record_epochs: Evaluate on the test data and append to the history
            every ``record_epochs`` epochs. The final epoch is always
            recorded. Values below 1 record every epoch.

    Returns:
        dict with:
            ``'loss'``: value returned by ``sess.batch_train`` for the last
                batch of each recorded epoch.
            ``'val_loss'`` and ``'epochs'``: parallel lists for the recorded
                epochs.
            ``'val_x'`` / ``'val_y'``: the test data.
            ``'val_pred'``: predictions from the most recent evaluation, or
                None if nothing was recorded (e.g. ``epochs <= 0``).

    Raises:
        ValueError: If ``batch_size`` < 1, or if the training set has fewer
            than ``batch_size`` samples, so no full batch can be formed.
    """
    train_x = ds['train_x']
    train_y = ds['train_y']
    test_x = ds['test_x']
    test_y = ds['test_y']

    # Validate before touching the session, so a bad call has no side
    # effects (no load, no save of an untrained model).
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    # Only whole batches are used; any remainder of train_x is skipped.
    batchs = train_x.shape[0] // batch_size
    if batchs == 0:
        raise ValueError(
            "training set has %d samples, fewer than batch_size=%d"
            % (train_x.shape[0], batch_size))

    # Load a saved session, or build a new one around the given model.
    if model is None:
        sess = Session.load(model_path)
    else:
        sess = Session(model, loss=losses.Mse(), optimizer=optimizers.Fixed())

    print("epochs:%d, batchs=%d" % (epochs, batchs))

    # Training history.
    history = {
        'loss': [],
        'val_loss': [],
        'epochs': [],
        'val_x': test_x,
        'val_y': test_y,
        'val_pred': None,
    }

    print("start training ")
    t_start = time.time()

    last_recorded = 0  # epoch number of the most recent history entry

    # Training loop.
    for step in range(epochs * batchs):
        start = (step % batchs) * batch_size
        end = start + batch_size
        loss = sess.batch_train(train_x[start:end], train_y[start:end])

        # Report and record only after the last batch of an epoch, so the
        # numbers describe the model as it stands at the end of that epoch.
        if (step + 1) % batchs != 0:
            continue
        cur_epoch = (step + 1) // batchs

        # Once-per-epoch progress line; '\r' lets the next line overwrite it.
        # Padding (not truncation) clears leftovers of a longer prior line.
        print(('epoch:%05d/%d loss=%f' % (cur_epoch, epochs, loss)).ljust(50),
              end='\r')

        if cur_epoch - last_recorded >= record_epochs or cur_epoch == epochs:
            last_recorded = cur_epoch

            y_pred = sess.model.predict(test_x)
            val_loss = sess.loss(test_y, y_pred)

            history['loss'].append(loss)
            history['val_loss'].append(val_loss)
            history['epochs'].append(cur_epoch)
            history['val_pred'] = y_pred

            # Overwrite the progress line, then end it with a newline so the
            # recorded line stays visible.
            print(('epoch:%05d/%d loss=%f, val_loss=%f' %
                   (cur_epoch, epochs, loss, val_loss)).ljust(50),
                  end='\r')
            print("")

    sess.save(model_path)
    print("training finished cost:%f" % (time.time() - t_start))

    return history