def load_config(cfg_dir, nopause=False):
    '''Import ``config.py`` from *cfg_dir* and return ``config.get_cfg(nopause)``.

    *cfg_dir* is temporarily prepended to ``sys.path`` so that the import
    resolves to the ``config.py`` inside it.  Any module already cached as
    ``config`` is evicted first so it is not reused.  ``sys.path`` and
    ``sys.modules`` are restored even when importing the module or calling
    ``get_cfg`` raises.

    Args:
        cfg_dir: directory that must contain a ``config.py`` defining
            ``get_cfg(nopause)``.
        nopause: forwarded unchanged to ``get_cfg``.

    Returns:
        Whatever ``get_cfg(nopause)`` returns.

    Raises:
        ImportError: if 'config.py' doesn't exist in cfg_dir.
    '''
    import sys

    if not os.path.isfile(os.path.join(cfg_dir, 'config.py')):
        raise ImportError('config.py not found in {0}'.format(cfg_dir))

    # Evict a previously imported 'config' so the import below re-executes
    # the file found in cfg_dir.
    sys.modules.pop('config', None)
    sys.path.insert(0, cfg_dir)
    try:
        import config as loading_config
        return loading_config.get_cfg(nopause)
    finally:
        # Cleanup must run on failure too; otherwise cfg_dir stays on
        # sys.path and the half-used module stays cached.
        sys.modules.pop('config', None)
        sys.path.remove(cfg_dir)
예제 #2
0
def main():
    """Create a ``DefaultTrainer`` from the default config and train it.

    The resume flag comes from a module-level ``args`` object that is not a
    parameter or local of this function.  Returns ``trainer.train()``.
    """
    trainer = DefaultTrainer(get_cfg())
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
예제 #3
0
파일: main.py 프로젝트: MaureenZOU/ECS277
def main():
    """Run the interpolation selected by ``cfg.interpolation``.

    ``'linear'`` dispatches to ``triLinear(cfg)`` and ``'cubic'`` to
    ``triCubic(cfg)``.

    Raises:
        NotImplementedError: for any other interpolation name.
    """
    cfg = get_cfg()

    if cfg.interpolation == 'linear':
        print('Running Linear interpolation ...')
        triLinear(cfg)
    elif cfg.interpolation == 'cubic':
        print('Running Cubic interpolation ...')
        triCubic(cfg)
    else:
        # An explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would make an unknown value a silent no-op.
        raise NotImplementedError(
            "Interpolation function not implemented: {!r}".format(
                cfg.interpolation))
예제 #4
0
 def get_config(self):
     """Copy settings from the mapping returned by ``get_cfg()`` onto ``self``.

     Keys are applied in a fixed order.  Any exception, including a missing
     key, is logged through ``self.logger.error`` and otherwise swallowed.
     Attributes assigned before the failure keep their new values.
     """
     # (config key, attribute name) pairs.  'duration' and 'batch size' are
     # currently not read.
     key_to_attr = (
         ('time constant', 'time_const'),
         ('state size', 'state_size'),
         ('look time', 'look_time'),
         ('threshold', 'threshold'),
         ('labels', 'conf_labels'),
         ('rnn size', 'rnn_size'),
     )
     try:
         settings = get_cfg()
         for cfg_key, attr_name in key_to_attr:
             setattr(self, attr_name, settings[cfg_key])
     except Exception as err:  # pragma: no cover
         self.logger.error(
             "unable to read 'opts/config.json' properly because: %s",
             str(err))
예제 #5
0
def _launch_job(cfg, job):
    """Run *job* (``train`` or ``test``) once, on one or several GPUs.

    With ``cfg.NUM_GPUS > 1`` one process per GPU is started through
    ``torch.multiprocessing.spawn`` with ``mpu.run`` as entry point and the
    distributed settings from *cfg*.  Otherwise ``job(cfg=cfg)`` runs
    in-process.
    """
    if cfg.NUM_GPUS > 1:
        torch.multiprocessing.spawn(
            mpu.run,
            nprocs=cfg.NUM_GPUS,
            args=(
                cfg.NUM_GPUS,
                job,
                cfg.DIST_INIT_METHOD,
                cfg.SHARD_ID,
                cfg.NUM_SHARDS,
                cfg.DIST_BACKEND,
                cfg,
            ),
            daemon=False,
        )
    else:
        job(cfg=cfg)


def main():
    """
    Main function to spawn the train and test process.

    Training runs first when ``cfg.TRAIN.ENABLE`` is set, then testing when
    ``cfg.TEST.ENABLE`` is set.  Both use the same single/multi-GPU launch
    logic in ``_launch_job``.
    """
    cfg = get_cfg()

    if cfg.TRAIN.ENABLE:
        _launch_job(cfg, train)

    if cfg.TEST.ENABLE:
        _launch_job(cfg, test)
예제 #6
0
def load_config( cfg_dir, nopause=False ):
    '''Import ``config.py`` from *cfg_dir* and return ``get_cfg( nopause )``.

    Args:
        cfg_dir: directory that must contain a ``config.py`` defining
            ``get_cfg(nopause)``.
        nopause: forwarded unchanged to ``get_cfg``.

    Returns:
        Whatever that ``get_cfg`` returns.

    Raises:
        ImportError: if 'config.py' doesn't exist in cfg_dir.
    '''
    import sys

    if not os.path.isfile( os.path.join( cfg_dir, 'config.py' ) ):
        raise ImportError( 'config.py not found in {0}'.format( cfg_dir ) )

    # Evict any cached 'config' first.  Otherwise `from config import` would
    # silently reuse that module instead of loading cfg_dir/config.py.
    sys.modules.pop( 'config', None )
    sys.path.insert( 0, cfg_dir )
    try:
        from config import get_cfg
        return get_cfg( nopause )
    finally:
        # cleanup, also when the import or get_cfg raises
        sys.modules.pop( 'config', None )
        sys.path.remove( cfg_dir )
예제 #7
0
    def evaluate(self, test_data, iter_idx):
        """Score predictions on *test_data* and tabulate them.

        Each ``(X, target)`` pair is passed through ``self.forward(X, flag=0)``.
        Prediction and target are then de-normalised with the close-price
        bounds stored in ``cfg`` (``value * (max - min) + min``).

        Returns:
            ``(hits, frame)``.  ``hits`` counts predictions within 0.1 of
            their target (the "correct" predictions).  ``frame`` is a
            ``pd.DataFrame`` indexed by ``cfg.get_idx()`` that holds targets,
            predictions and percentage-error strings; the "begin" and "opt"
            columns receive the same data.
        """
        g_max_close = cfg.get_value('g_max_close')
        g_min_close = cfg.get_value('g_min_close')
        scale = g_max_close - g_min_close

        hits = 0
        targets, predictions, pct_errors = [], [], []
        for X, target in test_data:
            pred = self.forward(X, flag=0) * scale + g_min_close
            target = target * scale + g_min_close

            if abs(pred - target) <= 1e-1:
                hits += 1
            targets.append(target)
            predictions.extend(pred[0])
            # NOTE(review): divides by the de-normalised target, so a target of
            # exactly 0 would raise — confirm that cannot occur.
            pct_errors.append("%.2f%%" % (abs(target - pred[0])*100/target))

        hidden_layer_num, node_num, epoch = cfg.get_cfg(True, True, True)
        begin_col = "l_{0}-N_{1}_b_{2}/{3}".format(hidden_layer_num, node_num, iter_idx, epoch)
        opt_col = "l_{0}-N_{1}_opt_{2}/{3}".format(hidden_layer_num, node_num, iter_idx, epoch)
        begin_err_col = "per_b_{0}/{1}".format(iter_idx, epoch)
        opt_err_col = "per_opt_{0}/{1}".format(iter_idx, epoch)

        columns = {
            'ture_value': targets,  # (sic) column name kept for consumers
            begin_col: predictions,
            begin_err_col: pct_errors,
            opt_col: predictions,
            opt_err_col: pct_errors,
        }
        final_result = pd.DataFrame(columns, index=cfg.get_idx())
        return hits, final_result
예제 #8
0
import torch
import torch.nn as nn
import torch.optim as optim
import config

# The `settings.toy_corewar` section of the project config, read once at
# import time and used by the network definitions below.
CFG = config.get_cfg().settings.toy_corewar

# Default tensor type for newly created tensors.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (set_default_dtype / set_default_device replace it) — confirm the
# installed torch version before touching this.
torch.set_default_tensor_type('torch.FloatTensor')


class Dueling_DQN(nn.Module):
    """Dueling DQN layers.

    The constructor creates these layers:

    * an LSTM ``lstm_p`` whose input width is derived from ``CFG``;
    * two square linear layers over a target vector of size
      ``2 * CFG.N_TARGETS``;
    * a joint layer ``fc1`` whose input width is ``h_size`` plus that target
      size (presumably the concatenation of LSTM output and target features);
    * an advantage head with ``CFG.NUM_ACTIONS`` outputs;
    * a scalar value head.
    """

    def __init__(self, h_size, middle_size, lstm_layers):
        super().__init__()
        self.num_lstm_layers = lstm_layers

        step_width = CFG.N_INSTRUCTIONS + CFG.N_VARS * CFG.NUM_REGISTERS + 1
        target_width = CFG.N_TARGETS * 2

        # Creation order is kept: it fixes parameter order and the RNG draws
        # used for weight initialisation.
        self.lstm_p = nn.LSTM(
            input_size=step_width, hidden_size=h_size, num_layers=lstm_layers)
        self.fc_s1 = nn.Linear(in_features=target_width, out_features=target_width)
        self.fc_s2 = nn.Linear(in_features=target_width, out_features=target_width)

        joint_width = h_size + target_width
        self.fc1 = nn.Linear(in_features=joint_width, out_features=middle_size)

        self.fc_adv = nn.Linear(in_features=middle_size, out_features=CFG.NUM_ACTIONS)
        self.fc_val = nn.Linear(in_features=middle_size, out_features=1)
예제 #9
0
#!/usr/bin/python
# -*- coding: utf8 -*-

import requests
import logging
import time
import os
import subprocess
import sys
from threading import Thread
from Queue import Queue
from NetworkHelper import BatchRequests
from config import get_cfg

# Download tuning read from the project config at import time.  int() is
# applied, so the config presumably stores these as strings.
chunk_size = int(get_cfg('chunk_size'))
download_threads = int(get_cfg('download_threads'))

# Path of the `axel` executable shipped next to this module; macOS uses the
# copy under the darwin/ subdirectory.
if sys.platform == 'darwin':
    AXEL_PATH = os.path.join(os.path.dirname(__file__), "darwin", "axel")
else:
    AXEL_PATH = os.path.join(os.path.dirname(__file__), "axel")


class Downloader:
    def __init__(self,
                 url,
                 download_threads,
                 chunk_size,
                 start_percent=0,
                 outfile=None,
                 file_seq=0,
예제 #10
0
from helpers.bithumb_api import BithumbGlobalRestAPI
from config import get_cfg

# Module-wide API client built from the 'bithumb' config section, which must
# provide 'key' and 'secret'.
cfg = get_cfg('bithumb')
bh = BithumbGlobalRestAPI(cfg['key'], cfg['secret'])


def get_all_pairs_list():
    """Return the ``'symbol'`` of every pair reported by ``bh.all_pairs()``.

    The list keeps the order of the API response.
    """
    return [entry['symbol'] for entry in bh.all_pairs()]
예제 #11
0
#!/usr/bin/python
# -*- coding: utf8 -*-
from webparser import Video
import sqlite3
import time
import logging
from config import get_cfg
#CREATE TABLE media(id INTEGER PRIMARY KEY, title TEXT NOT NULL, url TEXT NOT NULL, last_play_date INTEGER, last_play_pos INTEGER, duration INTEGER, site TEXT)
# Path of the SQLite database file, read from the config at import time.
dbStorage = get_cfg('dbStorage')


def db_getHistory():
    """Return every row of the ``media`` table as ``Video`` objects.

    Rows are ordered by ``last_play_date``, most recent first.  The database
    connection is closed before returning.
    """
    from contextlib import closing

    # `with con:` alone only commits/rolls back the transaction; it does not
    # close the connection.  closing() releases it on every exit path.
    with closing(sqlite3.connect(dbStorage)) as con:
        con.row_factory = sqlite3.Row
        with con:
            rows = con.execute(
                "SELECT * FROM media ORDER BY last_play_date DESC").fetchall()
        # NOTE(review): int() raises TypeError when last_play_pos is NULL,
        # which the table schema permits — confirm writers always set it.
        return [
            Video(str(d['title']),
                  str(d['url']),
                  '',
                  d['duration'],
                  d['site'],
                  dbid=d['id'],
                  progress=int(d['last_play_pos']))
            for d in rows
        ]
예제 #12
0
    for order in orders:
        threads.append(ThreadWithReturnValue(target=get_order, args=(order, )))

    results = []

    for thread in threads:
        thread.start()
        time.sleep(latency)

    for thread in threads:
        results.append(thread.join())

    return results


# Module-wide API client built from the 'bithumb_info' config section, which
# must provide 'key' and 'secret'.
cfg = get_cfg('bithumb_info')
bh = BithumbGlobalRestAPI(cfg['key'], cfg['secret'])


def check_orders():
    """Fetch the remote status of every order stored with status 'active'.

    NOTE(review): as visible here the body stops after reading ``bh_status``.
    ``success`` and ``cancel`` are never filled or used, so the remainder
    appears to be missing — confirm against the full source.
    """
    orders = Order.select().where(Order.status == 'active')

    if len(orders) == 0:
        return  # no active orders, nothing to request

    # One request per order (see multi_request).
    r = multi_request(orders)

    success = []
    cancel = []
    for i, order in enumerate(orders, 0):
        # Assumes r[i] is the response for this order (same ordering).
        bh_status = r[i]['status']
예제 #13
0
파일: server.py 프로젝트: benzhe/VideoPI
import urllib
import urllib2
import bottle
from urlparse import urlparse
from Constants import websites, actionDesc
from Helper import newFifo, newDir
from database import db_getHistory, db_delete
from player import Controller
from config import get_cfg
from webparser import Video

# `bottle.debug` is a function: calling it switches debug mode on, whereas
# assigning `bottle.debug = True` only rebinds the name and leaves it off.
bottle.debug(True)

import logging
import logging.handlers
# Root logger: every record at DEBUG and above goes to a size-rotated file
# (about 1 MB per file, two backups) at the configured 'logStorage' path.
handler = logging.handlers.RotatingFileHandler(get_cfg('logStorage'), maxBytes=1000000, backupCount=2)
formatter = logging.Formatter(fmt='%(asctime)s %(threadName)s %(module)s:%(lineno)d %(levelname)s: %(message)s')
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)


def exceptionLogger(type, value, tb):
    """Log an uncaught exception with its traceback, then run the default hook.

    Meant to be installed as ``sys.excepthook``.  The exception info is passed
    explicitly through ``exc_info``.  Inside an excepthook no exception is
    being handled, so ``logging.exception`` would record an empty traceback.
    """
    import sys

    logging.error("Uncaught exception: %s", value, exc_info=(type, value, tb))
    sys.__excepthook__(type, value, tb)


import sys

# Route uncaught exceptions through exceptionLogger.  The attribute must be
# spelled `excepthook`: a misspelled name is silently created on the sys
# module and ignored, leaving the default hook in place.
sys.excepthook = exceptionLogger

controller = None
예제 #14
0
파일: database.py 프로젝트: benzhe/VideoPI
#!/usr/bin/python
# -*- coding: utf8 -*-
from webparser import Video
import sqlite3
import time
import logging
from config import get_cfg
#CREATE TABLE media(id INTEGER PRIMARY KEY, title TEXT NOT NULL, url TEXT NOT NULL, last_play_date INTEGER, last_play_pos INTEGER, duration INTEGER, site TEXT)
# Path of the SQLite database file, read from the config at import time.
dbStorage = get_cfg('dbStorage')


def db_getHistory():
    """Return every row of the ``media`` table as ``Video`` objects.

    Rows are ordered by ``last_play_date``, most recent first.  The database
    connection is closed before returning.
    """
    from contextlib import closing

    # `with con:` alone only commits/rolls back; closing() actually releases
    # the connection on every exit path.
    with closing(sqlite3.connect(dbStorage)) as con:
        con.row_factory = sqlite3.Row
        with con:
            rows = con.execute(
                "SELECT * FROM media ORDER BY last_play_date DESC").fetchall()
        # NOTE(review): int() raises TypeError when last_play_pos is NULL,
        # which the table schema permits — confirm writers always set it.
        return [Video(str(d['title']), str(d['url']), '', d['duration'], d['site'],
                      dbid=d['id'], progress=int(d['last_play_pos']))
                for d in rows]


def db_writeHistory(video):
    con = sqlite3.connect(dbStorage)
    # logging.debug("Write: %s", video)
    with con:
        cur = con.cursor()
        if video.dbid:
예제 #15
0
"""
Main program
May the Force be with you.

This main file is used on slurm server without interactive check of config
"""
from torch.utils.data import DataLoader

from dataset import get_dataset
from logger import get_logger
from core.models import get_model
from core.trainer import Trainer
from config import get_cfg

# prepare configuration (non-interactive: no config confirmation prompt)
cfg = get_cfg(interactive=False)

# prepare dataset: one DataLoader per mode; only the 'train' loader shuffles
DatasetClass = get_dataset(cfg.DATASET)
dataloader_dict = {}
for mode in cfg.MODES:
    phase_dataset = DatasetClass(cfg, mode=mode)
    dataloader_dict[mode] = DataLoader(
        phase_dataset,
        batch_size=cfg.BATCHSIZE,
        shuffle=mode in ('train',),
        num_workers=cfg.DATALOADER_WORKERS,
        pin_memory=True,
        drop_last=True,
    )

# prepare models
예제 #16
0
파일: webparser.py 프로젝트: benzhe/VideoPI
import urllib2
import json
import logging
import subprocess
import time
import requests
try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET
from urlparse import urlparse, parse_qs
from struct import unpack
from NetworkHelper import BatchRequests
from config import get_cfg

# Local playlist file path and the default video format number, both read
# from the config at import time.
playlistStorage = get_cfg('playlistStorage')
default_format = int(get_cfg('default_format'))

# Format number -> keyword (format 1 maps to the empty string).  Presumably
# the quality keyword used when requesting a stream — confirm with callers.
format2keyword = {
    1: "",
    2: "high",
    3: "super",
    4: "orig"
}

# Scheme -> proxy URL mapping, presumably passed as `proxies=` to requests.
# NOTE(review): hard-coded proxy host; consider moving it into the config.
proxies = {"http": "http://h0.edu.bj.ie.sogou.com"}


class Video:
    formatDict = {
        1: "普通",
예제 #17
0
import re
import subprocess
import os
import signal
import logging

from threading import Thread
from config import get_cfg
# Extra omxplayer arguments and a zoom value (converted with float), read from
# the config at import time.  The 'additonal' misspelling is shared by the
# variable name and the config key, so both must change together, if ever.
additonal_omxplayer_args = get_cfg('additonal_omxplayer_args')
zoom = float(get_cfg('zoom'))


class OMXPlayer(object):
    """Command templates and per-instance state for driving /usr/bin/omxplayer."""

    # Matches "M" or "V", optional spaces, ":", optional spaces, then captures
    # the following digits as `position`.
    _STATUS_REXP = re.compile(r"[MV]\s*:\s*(?P<position>[\d]+).*")

    # Shell command template run under `nice -n -1`.  The first %s is placed in
    # double quotes (presumably the media path/URL) and the second carries
    # extra arguments.  stdin is redirected from /tmp/cmd.
    # NOTE(review): the quoted value is not shell-escaped; a value containing
    # '"', '$' or backticks would be interpreted by the shell.
    _LAUNCH_CMD = 'nice -n -1 /usr/bin/omxplayer -s "%s" %s < /tmp/cmd \n'
    # Single-character player commands (pause, toggle subtitles, quit,
    # volume up/down), presumably written to the player's input.
    _PAUSE_CMD = 'p'
    _TOGGLE_SUB_CMD = 's'
    _QUIT_CMD = 'q'
    _VOLUP_CMD = '+'
    _VOLDOWN_CMD = '-'

    # Path of a helper shell script (presumably where the launch command is
    # written before execution — confirm).
    _SCRIPT_NAME = '/tmp/play.sh'

    def __init__(self, currentVideo, screenWidth=0, screenHeight=0):
        # Video to play; presumably a webparser.Video — confirm with callers.
        self.currentVideo = currentVideo
        # Screen size; 0 is the default when unknown.
        self.screenWidth = screenWidth
        self.screenHeight = screenHeight
        # No player process has been started yet.
        self._process = None
예제 #18
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths

from config import get_cfg, parse_args
from engine import CenterNet

if __name__ == '__main__':
    # Start from the default config and select the model definition file.
    # merge_from_file() takes no argument here; presumably it reads
    # cfg.model_cfg — confirm in config.py.
    cfg = get_cfg()
    cfg.model_cfg = 'centernet_RD_34_fpn_obj_s2_2x.yaml'
    cfg.merge_from_file()

    # Command-line overrides are currently disabled:
    #     args = parse_args()
    #     cfg.update_dict(args)
    # Individual fields can be set here before training, for example:
    #     cfg.down_ratio = 2
    print(cfg)

    centernet = CenterNet()
    centernet.train(cfg)
예제 #19
0
    def __init__(self, args, port):
        """Set up config, summary writer, data, model/optimiser and loaders.

        Args:
            args: parsed arguments.  ``args.config`` is merged into the
                default config; ``args.world_size`` is overwritten with 1.
            port: stored as ``self.port``.

        Raises:
            AssertionError: if ``./saved_models`` does not exist.
            RuntimeError: if CUDA is not available.
        """
        cfg = get_cfg()
        cfg.merge_from_file(args.config)
        self.cfg = cfg
        self.port = port
        assert os.path.exists(
            'saved_models'
        ), "Create a path to save the trained models: <default: ./saved_models> "
        self.model_dir = os.path.join('saved_models', cfg.NAME)
        self.writer = SummaryWriter(
            log_dir=os.path.join(self.model_dir, "summary"))
        self.iteration = 0
        print("Arguments used: {}".format(args), flush=True)

        self.trainset, self.testset = get_datasets(cfg)
        self.model = get_model(cfg)
        print("Using model: {}".format(self.model.__class__), flush=True)

        # Single- and multi-GPU runs both go through init_distributed (the two
        # original branches were identical).
        # TODO: do not use distributed package in the single-GPU case
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available.")
        self.model, self.optimiser = self.init_distributed(cfg)

        # Weight loading, which also produced start_epoch, is disabled:
        # self.model, self.optimiser, self.start_epoch, start_iter = \
        #   load_weightsV2(self.model, self.optimiser, args.wts, self.model_dir)
        # Fall back to epoch 0 unless start_epoch was already set elsewhere
        # (e.g. by init_distributed); otherwise the read below raises
        # AttributeError.
        self.start_epoch = getattr(self, 'start_epoch', 0)
        self.lr_schedulers = get_lr_schedulers(self.optimiser, cfg,
                                               self.start_epoch)
        self.batch_size = self.cfg.TRAINING.BATCH_SIZE

        args.world_size = 1
        print(args)
        self.args = args
        self.epoch = 0
        self.best_loss_train = math.inf
        self.losses = AverageMeterDict()
        self.ious = AverageMeterDict()

        # NUM_SAMPLES == -1 in the config maps to None (no fixed sample count).
        num_samples = None if cfg.DATALOADER.NUM_SAMPLES == -1 else cfg.DATALOADER.NUM_SAMPLES
        if torch.cuda.device_count() > 1:
            # shuffle parameter does not seem to shuffle the data for distributed sampler
            # NOTE(review): DistributedSampler treats its first argument as a
            # dataset and only uses its length, so wrapping a RandomSampler
            # yields positions 0..len(sampler)-1, not the RandomSampler's
            # draws.  If num_samples exceeds len(trainset), indices overflow.
            # Confirm this is intended.
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                torch.utils.data.RandomSampler(self.trainset,
                                               replacement=True,
                                               num_samples=num_samples),
                shuffle=True)
        elif num_samples is not None:
            self.train_sampler = torch.utils.data.RandomSampler(
                self.trainset, replacement=True, num_samples=num_samples)
        else:
            self.train_sampler = None
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle = self.train_sampler is None
        self.trainloader = DataLoader(self.trainset,
                                      batch_size=self.batch_size,
                                      num_workers=cfg.DATALOADER.NUM_WORKERS,
                                      shuffle=shuffle,
                                      sampler=self.train_sampler)

        print(
            summary(self.model,
                    tuple((3, cfg.INPUT.TW, 256, 256)),
                    batch_size=1))
예제 #20
0
from peewee import Model, CharField, IntegerField, DateTimeField, PostgresqlDatabase, TextField
from playhouse.postgres_ext import JSONField

from datetime import datetime

from config import get_cfg

db_cfg = get_cfg('db')
# Connection settings copied verbatim from the 'db' config section.
db_conn = {
    setting: db_cfg[setting]
    for setting in ('host', 'user', 'password', 'database', 'autorollback')
}
db = PostgresqlDatabase(**db_conn)

# Initial message ids for 'b_info' and 'b_main'.
msg_ids = {"b_info": 0, "b_main": 0}


class BaseModel(Model):
    """Common base for this module's peewee models; binds them to ``db``."""
    class Meta:
        database = db


class User(BaseModel):
    """A registered user."""
    # Presumably a Telegram user id ("tg_") — confirm.  If so, note such ids
    # can exceed 32 bits; IntegerField may need to be BigIntegerField.
    tg_id = IntegerField()
    username = CharField()
    # Registration timestamp; datetime.now yields naive local time.
    registered = DateTimeField(default=datetime.now)

예제 #21
0
import torch
from collections import deque, namedtuple
from game.environment import Env
import config

# Project configuration, loaded once at import time.
CFG = config.get_cfg()

# One (state, action, reward, next_state, done) record; presumably a replay
# buffer entry.
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

class LinearSchedule(object):
    def __init__(self, schedule_episodes, total_episodes, final_p, initial_p=1.0):
        """Linear interpolation from ``initial_p`` to ``final_p``.

        The value is annealed over ``self.schedule_episodes`` episodes; after
        that many episodes, ``final_p`` is returned.

        Parameters
        ----------
        schedule_episodes: int or float
            Number of episodes to anneal over.  A float is read as a fraction
            of ``total_episodes`` and truncated with ``int()``.
        total_episodes: int
            Total number of episodes; only used when ``schedule_episodes`` is
            a float.
        final_p: float
            final output value
        initial_p: float
            initial output value (default 1.0)
        """
        is_fraction = isinstance(schedule_episodes, float)
        self.schedule_episodes = (int(schedule_episodes * total_episodes)
                                  if is_fraction else schedule_episodes)
        self.final_p = final_p
        self.initial_p = initial_p
예제 #22
0
from cryptography.fernet import Fernet
from config import get_cfg

# Fernet key as bytes (UTF-8 encoded), taken from the 'dec' config section at
# import time.
key = str.encode(get_cfg('dec')['key'])


def encr(data):
    """Encrypt the str *data* with the module key; return the token as str."""
    cipher = Fernet(key)
    token = cipher.encrypt(data.encode('UTF-8'))
    return token.decode('UTF-8')


def decr(data):
    """Decrypt a str token produced by ``encr``; return the plaintext str."""
    cipher = Fernet(key)
    plain = cipher.decrypt(data.encode('UTF-8'))
    return plain.decode('UTF-8')
예제 #23
0
import re
import subprocess
import os
import signal
import logging

from threading import Thread
from config import get_cfg
# Extra omxplayer arguments and a zoom value (converted with float), read from
# the config at import time.  The 'additonal' misspelling is shared by the
# variable name and the config key, so both must change together, if ever.
additonal_omxplayer_args = get_cfg('additonal_omxplayer_args')
zoom = float(get_cfg('zoom'))


class OMXPlayer(object):
    """Command templates and per-instance state for driving /usr/bin/omxplayer."""

    # Matches "M" or "V", optional spaces, ":", optional spaces, then captures
    # the following digits as `position`.
    _STATUS_REXP = re.compile(r"[MV]\s*:\s*(?P<position>[\d]+).*")

    # Shell command template run under `nice -n -1`.  The first %s is placed in
    # double quotes (presumably the media path/URL) and the second carries
    # extra arguments.  stdin is redirected from /tmp/cmd.
    # NOTE(review): the quoted value is not shell-escaped; a value containing
    # '"', '$' or backticks would be interpreted by the shell.
    _LAUNCH_CMD = 'nice -n -1 /usr/bin/omxplayer -s "%s" %s < /tmp/cmd \n'
    # Single-character player commands (pause, toggle subtitles, quit,
    # volume up/down), presumably written to the player's input.
    _PAUSE_CMD = 'p'
    _TOGGLE_SUB_CMD = 's'
    _QUIT_CMD = 'q'
    _VOLUP_CMD = '+'
    _VOLDOWN_CMD = '-'

    # Path of a helper shell script (presumably where the launch command is
    # written before execution — confirm).
    _SCRIPT_NAME = '/tmp/play.sh'

    def __init__(self, currentVideo, screenWidth=0, screenHeight=0):
        # Video to play; presumably a webparser.Video — confirm with callers.
        self.currentVideo = currentVideo
        # Screen size; 0 is the default when unknown.
        self.screenWidth = screenWidth
        self.screenHeight = screenHeight
        # No player process has been started yet.
        self._process = None
예제 #24
0
import urllib
import urllib2
import bottle
from urlparse import urlparse
from Constants import websites, actionDesc
from Helper import newFifo, newDir
from database import db_getHistory, db_delete
from player import Controller
from config import get_cfg
from webparser import Video

# `bottle.debug` is a function: calling it switches debug mode on, whereas
# assigning `bottle.debug = True` only rebinds the name and leaves it off.
bottle.debug(True)

import logging
import logging.handlers
# Root logger: every record at DEBUG and above goes to a size-rotated file
# (about 1 MB per file, two backups) at the configured 'logStorage' path.
handler = logging.handlers.RotatingFileHandler(get_cfg('logStorage'),
                                               maxBytes=1000000,
                                               backupCount=2)
formatter = logging.Formatter(
    fmt=
    '%(asctime)s %(threadName)s %(module)s:%(lineno)d %(levelname)s: %(message)s'
)
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)


def exceptionLogger(type, value, tb):
    """Log an uncaught exception with its traceback, then run the default hook.

    Meant to be installed as ``sys.excepthook``.  The exception info is passed
    explicitly through ``exc_info``.  Inside an excepthook no exception is
    being handled, so ``logging.exception`` would record an empty traceback.
    """
    import sys

    logging.error("Uncaught exception: %s", value, exc_info=(type, value, tb))
    sys.__excepthook__(type, value, tb)
예제 #25
0
파일: main.py 프로젝트: MaureenZOU/ECS277
def main():
    """Interpolate a test volume from scattered samples and save an xy slice.

    A test volume of size ``cfg.vol_size`` is sampled at ``cfg.sample_num``
    scattered points.  It is then reconstructed with the routine named by
    ``cfg.method``.  The result is rendered with
    ``render_image(out, plane='xy', axis=32)`` and saved to
    ``cfg.file_name + '_xy.png'``.

    Raises:
        NotImplementedError: if ``cfg.method`` is not a supported method name.
    """
    cfg = get_cfg()
    # volume = load_raw(cfg.data_root, cfg.vol_size)
    volume = get_test_volume(cfg.vol_size)

    scattered_tree, F, sample_points, mask, f_volume = sample_scatter(
        volume, sample_num=cfg.sample_num)
    F_locs = torch.tensor(scattered_tree.get_arrays()[0])
    F = torch.tensor(F)
    mask = torch.tensor(mask).float()
    f_volume = torch.tensor(f_volume).float()
    sample_points = torch.tensor(sample_points)
    # sample points (w,h,d,n)

    if cfg.method == 'global_shaper_s2':
        out = global_shaper_s2(sample_points, F, F_locs, cfg.vol_size, mask,
                               f_volume)
    elif cfg.method == 'local_shaper_s2':
        out = local_shaper_s2(sample_points,
                              scattered_tree,
                              cfg.vol_size,
                              F_locs,
                              F,
                              mask,
                              f_volume,
                              k=cfg.knn)
    elif cfg.method == 'global_hardy_negative':
        out = global_hardy_negative(sample_points, F, F_locs, cfg.vol_size,
                                    mask, f_volume)
    elif cfg.method == 'global_hardy_positive':
        out = global_hardy_positive(sample_points, F, F_locs, cfg.vol_size,
                                    mask, f_volume)
    elif cfg.method == 'local_hardy_negative':
        out = local_hardy_negative(sample_points,
                                   F,
                                   F_locs,
                                   cfg.vol_size,
                                   mask,
                                   f_volume,
                                   scattered_tree,
                                   k=cfg.knn)
    elif cfg.method == 'local_hardy_positive':
        out = local_hardy_positive(sample_points,
                                   F,
                                   F_locs,
                                   cfg.vol_size,
                                   mask,
                                   f_volume,
                                   scattered_tree,
                                   k=cfg.knn)
    elif cfg.method == 'global_shaper_s3':
        out = global_shaper_s3(sample_points,
                               scattered_tree,
                               cfg.vol_size,
                               F_locs,
                               F,
                               mask,
                               f_volume,
                               k=cfg.knn)
    elif cfg.method == 'local_shaper_s3':
        out = local_shaper_s3(sample_points,
                              scattered_tree,
                              cfg.vol_size,
                              F_locs,
                              F,
                              mask,
                              f_volume,
                              k1=8,
                              k2=cfg.knn)
    else:
        # Without this branch an unknown method left `out` unbound and the
        # function crashed below with UnboundLocalError.
        raise NotImplementedError(
            "Interpolation method not implemented: {!r}".format(cfg.method))

    plt = render_image(out, plane='xy', axis=32)
    plt.savefig(cfg.file_name + '_xy.png',
                bbox_inches='tight',
                pad_inches=0,
                transparent=True)
    plt.clf()
    # yz / xz slices can be rendered the same way with plane='yz' / 'xz'.
    plt.close()
예제 #26
0
from Constants import websites
from webparser import Video
from database import db_writeHistory, db_getById
from downloader import MultiDownloader
from show_image import ImgService, FINISHED, LOADING
from config import get_cfg
import logging
import sys
import os
import time

# RLock is used below but was not imported in this module's visible imports.
from threading import RLock

# Module-wide re-entrant locks; presumably guarding downloading and playback
# respectively (their users are outside this view).
downloadLock = RLock()
playLock = RLock()

imgService = ImgService()
# Local playlist file path; Player.getUrls compares video.realUrl against it.
playlistStorage = get_cfg('playlistStorage')


class Player:
    def __init__(self, video):
        self.video = video
        screenWidth, screenHeight = getScreenSize()
        if sys.platform == 'darwin':
            self.player = MPlayerX(self.video, screenWidth, screenHeight)
        else:
            self.player = OMXPlayer(self.video, screenWidth, screenHeight)
        self.from_position = video.progress

    def getUrls(self):
        if self.video.realUrl == playlistStorage:
            with open(playlistStorage, 'r') as f:
예제 #27
0
import os, shutil
import config
# Load config.json from this script's directory before the heavier imports
# below, and abort the program if that fails.
try:
    path = os.path.dirname(__file__)
    config.load(os.path.join(path, "config.json"))
    cfg = config.get_cfg()
except Exception as err:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit, and
    # `exit()` exits with status 0 (success).  SystemExit with a message
    # prints it to stderr and exits with status 1, keeping the cause visible.
    raise SystemExit("cannot load config.json: {}".format(err))

import torch
import reward
import multiprocessing
from multiprocessing import Pool
import argparse
import json
from DQN.DQN_agent import DQN_Agent
from Actor_Critic.AC_agent import AC_Agent

def unpack(args):
    """Call ``run_training`` with the keyword arguments in the mapping *args*.

    Takes a single argument, presumably so it can be used with
    ``Pool.map``-style APIs that pass one item per call — confirm.
    """
    run_training(**args)

def run_training(id, algo, episodes, reward_func, reward_settings, targets, reg_inits, root_dir):
    """Train one agent and save its best checkpoint under ``root_dir/<id>``.

    ``algo`` names a preset in ``cfg.presets``.  The preset's ``agent`` field
    is the name of an agent class looked up in this module's globals, and its
    ``parameters`` become constructor keyword arguments.  ``reward_func``
    names a function in the ``reward`` module.  ``os.makedirs`` raises if the
    run directory already exists.
    """
    run_dir = os.path.join(root_dir, str(id))
    os.makedirs(run_dir)

    preset = getattr(cfg.presets, algo)
    agent_cls = globals()[preset.agent]
    agent_kwargs = preset.parameters.todict()
    agent = agent_cls(**agent_kwargs, verbose=True, log_dir=run_dir)

    reward_fn = getattr(reward, reward_func)
    agent.train(reward_fn, reward_settings, episodes, targets, reg_inits)
    agent.save("best", best=True)
예제 #28
0
from cryptography.fernet import Fernet
from config import get_cfg

# Fernet key as bytes (UTF-8 encoded), taken from the 'security' config
# section at import time.
key = str.encode(get_cfg('security')['key'])


def generate_key():
    """
    Generate a new secret key (delegates to ``Fernet.generate_key()``).
    """
    return Fernet.generate_key()


def enc(data):
    """
    Encrypt the string *data* with the module key and return the token as str.
    """
    cipher = Fernet(key)
    token = cipher.encrypt(data.encode('UTF-8'))
    return token.decode('UTF-8')


def dec(data):
    """
    Decrypt a str token produced by ``enc`` and return the plaintext str.
    """
    cipher = Fernet(key)
    plain = cipher.decrypt(data.encode('UTF-8'))
    return plain.decode('UTF-8')
예제 #29
0
#!/usr/bin/python
# -*- coding: utf8 -*-

import requests
import logging
import time
import os
import subprocess
import sys
from threading import Thread
from Queue import Queue
from NetworkHelper import BatchRequests
from config import get_cfg

# Download tuning read from the project config at import time.  int() is
# applied, so the config presumably stores these as strings.
chunk_size = int(get_cfg('chunk_size'))
download_threads = int(get_cfg('download_threads'))

# Path of the `axel` executable shipped next to this module; macOS uses the
# copy under the darwin/ subdirectory.
if sys.platform == 'darwin':
    AXEL_PATH = os.path.join(os.path.dirname(__file__),"darwin", "axel")
else:
    AXEL_PATH = os.path.join(os.path.dirname(__file__), "axel")


class Downloader:

    def __init__(self, url, download_threads, chunk_size, start_percent=0,
                 outfile=None, file_seq=0, alternativeUrls=[], total_length=None):
        self.url = url
        self.total_length = total_length
        logging.info("Construct downloader for url %s", self.url)
        self.download_threads = download_threads