示例#1
0
文件: trainer.py 项目: har07/sgld-1
def runall(cuda_device, model_desc, db_string):
    """Build, train and evaluate one MnistModel configuration on a GPU.

    Args:
        cuda_device: CUDA device passed to ``torch.cuda.set_device`` and
            forwarded to ``train``.
        model_desc: configuration mapping. Keys read here: ``'seed'``,
            ``'optimizer'`` (an expression evaluated to the optimizer
            constructor), ``'lr'``, and ``'addnoise'`` (only when the
            optimizer expression starts with ``'sgld'``). The whole mapping
            is also forwarded to ``train``.
        db_string: database connection string for ``labnotebook.initialize``.

    Returns:
        Whatever ``train`` returns for this run.
    """
    labnotebook.initialize(db_string)

    torch.cuda.set_device(cuda_device)
    torch.manual_seed(model_desc['seed'])

    net = MnistModel()
    train_loader, test_loader = sgld.make_datasets()
    net = net.cuda()

    opt_expr = model_desc['optimizer']
    # NOTE(review): eval() runs arbitrary code taken from model_desc; this is
    # only acceptable while model_desc comes from a trusted source.
    opt_factory = eval(opt_expr)
    opt_kwargs = {'lr': model_desc['lr']}
    if opt_expr.startswith('sgld'):
        # SGLD-style optimizers additionally accept an `addnoise` option.
        opt_kwargs['addnoise'] = model_desc['addnoise']
    optimizer = opt_factory(net.parameters(), **opt_kwargs)

    return train(net, train_loader, test_loader, optimizer, model_desc,
                 cuda_device)
示例#2
0
def start_backend():
    """Parse CLI arguments and serve the labnotebook REST API with Flask.

    Command-line arguments:
        database_url  URL passed to ``labnotebook.initialize``.
        --host        interface to bind (default ``127.0.0.1``).
        --port        TCP port to bind (default ``3000``).
        --no-debug    run Flask without debug mode (debug is on by default).

    Registered GET endpoints:
        /steps/<run_id>             step metrics of one run as a dict of lists
        /customfields/<run_id>?fieldname=NAME
                                    one custom field per timestep, dict of lists
        /customfieldnames/<run_id>  keys of ``custom_fields`` at timestep 1
        /experiments                all experiments, ordered by run_id desc

    Finishes by calling ``app.run`` to start the Flask development server.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("database_url")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=3000)
    # Debug stays enabled by default for backward compatibility. The Werkzeug
    # debugger permits arbitrary code execution from the browser, so do not
    # combine it with a publicly reachable --host.
    parser.add_argument("--no-debug", dest="debug", action="store_false",
                        help="disable Flask debug mode")
    args = parser.parse_args()

    app = Flask(__name__)
    # Allow cross-origin requests so the frontend can be served from a
    # different server than this API.
    CORS(app)
    api = Api(app)

    # initialize() returns three table handles; only experiments (xp) and
    # timesteps (ts) are queried below.
    xp, ts, _ = labnotebook.initialize(args.database_url)

    def flip_dict(dict_list):
        """Turn a list of dicts into a dict mapping each key to a list of values.

        Values are appended in input order; a key missing from some dicts
        simply gets a shorter list.
        """
        result = defaultdict(list)
        for row in dict_list:
            for key, value in row.items():
                result[key].append(value)
        return result

    class Steps(Resource):
        def get(self, run_id):
            """Return the step metrics of ``run_id`` as a dict of lists."""
            query = labnotebook.session.query(
                ts.step_id, ts.timestep, ts.run_id, ts.trainacc, ts.valacc,
                ts.trainloss).filter(ts.run_id == run_id).all()
            # NOTE(review): `[0]` assumes a marshmallow version whose dump()
            # returns a (data, errors) pair (2.x); marshmallow 3 returns the
            # data directly. Confirm against the pinned version.
            result = labnotebook.runs_schema.dump(query)[0]
            return jsonify(flip_dict(result))

    class Experiments(Resource):
        def get(self):
            """Return all experiments, ordered by run_id descending."""
            query = labnotebook.session.query(
                xp.run_id, xp.dt, xp.gpu, xp.completed, xp.final_trainacc,
                xp.final_trainloss, xp.final_valacc,
                xp.model_desc).order_by(xp.run_id.desc()).all()
            result = labnotebook.xps_schema.dump(query)[0]
            return jsonify(result)

    class CustomFieldNames(Resource):
        def get(self, run_id):
            """Return the custom-field keys recorded at timestep 1 of a run."""
            # NOTE(review): assumes every run has a timestep == 1 row and
            # that it carries all custom fields — confirm with the writer.
            query = labnotebook.session.query(
                func.jsonb_object_keys(ts.custom_fields)).filter(
                    ts.run_id == run_id).filter(ts.timestep == 1).all()
            return jsonify([row[0] for row in query])

    class CustomFields(Resource):
        def get(self, run_id):
            """Return ``fieldname`` from custom_fields per timestep of a run."""
            req_parser = reqparse.RequestParser()
            # Without `required`, a missing ?fieldname= becomes None and the
            # query would index custom_fields[None]; reject it with a 400.
            req_parser.add_argument('fieldname', type=str, required=True)
            fieldname = req_parser.parse_args()['fieldname']

            query = labnotebook.session.query(
                ts.timestep, ts.custom_fields[fieldname].label('cf')).filter(
                    ts.run_id == run_id).all()

            result = labnotebook.cfs_schema.dump(query)[0]
            return jsonify(flip_dict(result))

    # our api:
    api.add_resource(Steps, '/steps/<string:run_id>')
    api.add_resource(CustomFields, '/customfields/<string:run_id>')
    api.add_resource(CustomFieldNames, '/customfieldnames/<string:run_id>')
    api.add_resource(Experiments, '/experiments')

    app.run(debug=args.debug, host=args.host, port=args.port)
import logging
import os

import labnotebook
from mas_tools.api import Binance

# logging
# NOTE(review): machine-specific absolute path; the directory must already
# exist or FileHandler will fail at import time.
path = 'E:/Projects/market-analysis-system/mas_arbitrage/'
# os.path.join avoids the double slash that "{p}/{fn}.log" produced when
# `path` already ends with a separator.
log_file = os.path.join(path, 'bot_0.0.log')
logging.basicConfig(level=logging.INFO,
                    handlers=[
                        logging.FileHandler(log_file),
                        logging.StreamHandler(),
                    ])
log = logging.getLogger()

# labnotebook
# The connection URL may be supplied via LABNOTEBOOK_DB_URL so credentials
# need not live in source; the previous literal remains the default.
db_url = os.environ.get('LABNOTEBOOK_DB_URL',
                        'postgres://*****:*****@localhost/postgres')
experiments, steps, model_params = labnotebook.initialize(db_url)

# crypto exchange api
# Keys are read from the environment when set; '---' placeholders otherwise.
MY_API_KEY = os.environ.get('BINANCE_API_KEY', '---')
MY_API_SECRET = os.environ.get('BINANCE_API_SECRET', '---')
api = Binance(MY_API_KEY, MY_API_SECRET)

# parameters
symb1 = 'BTCUSDT'
symb2 = 'ETHUSDT'
period = '1m'  # presumably the candle interval — confirm with the Binance wrapper
arbitrage_sum = True  # NOTE(review): meaning not visible here; confirm with the strategy code
coef = 15  # NOTE(review): presumably a weighting between symb1 and symb2 — confirm
h_level = 1345  # NOTE(review): presumably an upper threshold — confirm
l_level = 1180  # NOTE(review): presumably a lower threshold — confirm