Beispiel #1
0
 def test_default(self):
     """Check that ``default`` fills in only keys that are not set yet.

     An existing key keeps its value; a missing key receives the given
     default, which must then be readable through ``get`` on its own.
     """
     registry = Registry()
     registry.set('a', 'aaa')
     # 'a' already exists: default() returns the stored value and must not
     # overwrite it.
     self.assertEqual(registry.default('a', 'bbb'), 'aaa')
     self.assertEqual(registry.get('a'), 'aaa')
     # 'c' is missing: default() returns the fallback and stores it.
     self.assertEqual(registry.default('c', 'ccc'), 'ccc')
     # No fallback argument here: get()'s second argument appears to be a
     # fallback, so passing 'ccc' would let this pass even if default()
     # had stored nothing.
     self.assertEqual(registry.get('c'), 'ccc')
Beispiel #2
0
 def test_default(self):
     """Test setting default values on a Registry built from a dict.

     An existing key keeps its value; a missing key receives the given
     default, which must then be readable through ``get`` on its own.
     """
     registry = Registry({'a': 'aaa'})
     # 'a' came from the constructor: default() must not overwrite it.
     self.assertEqual(registry.default('a', 'bbb'), 'aaa')
     self.assertEqual(registry.get('a'), 'aaa')
     # 'c' is missing: default() returns the fallback and stores it.
     self.assertEqual(registry.default('c', 'ccc'), 'ccc')
     # No fallback argument here, otherwise the assertion would hold even
     # if default() had stored nothing.
     self.assertEqual(registry.get('c'), 'ccc')
Beispiel #3
0
 def add_options(parser, train):
     """Register the 'loss computation' argument group.

     Nothing is added to *parser* unless *train* is true.
     """
     if not train:
         return
     loss_group = parser.add_argument_group('loss computation')
     loss_group.add_argument('--loss',
                             choices=Registry.keys('loss'),
                             default=Registry.default('loss'),
                             help='GAN loss')
     # Remaining options, listed in the order they appear in --help.
     remaining = (
         ('--gradient-penalty-factor',
          dict(type=float,
               default=10,
               help='gradient penalty factor (lambda in WGAN-GP)')),
         ('--soft-labels',
          dict(action=YesNoAction, help='use soft labels in GAN loss')),
         ('--noisy-labels',
          dict(action=YesNoAction, help='use noisy labels in GAN loss')),
         ('--noisy-labels-frequency',
          dict(type=float,
               default=0.1,
               help='how often to use noisy labels in GAN loss')),
         ('--l1-weight',
          dict(type=float,
               default=1,
               help='weight of the L1 distance contribution to the GAN loss')),
     )
     for flag, settings in remaining:
         loss_group.add_argument(flag, **settings)
Beispiel #4
0
 def add_options(parser, train):
     """Register the 'training options' argument group.

     Nothing is added to *parser* unless *train* is true.
     """
     if not train:
         return
     training = parser.add_argument_group('training options')
     training.add_argument('--game',
                           choices=Registry.keys('game'),
                           default=Registry.default('game'),
                           help='type of game')
     # (flag, type, default, help) for the simple typed options, in order.
     for flag, kind, fallback, text in (
         ('--generator-iterations', int, 1,
          'number of iterations for the generator'),
         ('--discriminator-iterations', int, 1,
          'number of iterations for the discriminator'),
         ('--generator-lr', float, 1e-4, 'learning rate for the generator'),
         ('--discriminator-lr', float, 1e-4,
          'learning rate for the discriminator'),
         ('--beta1', float, 0, 'first beta'),
         ('--beta2', float, 0.9, 'second beta'),
     ):
         training.add_argument(flag, type=kind, default=fallback, help=text)
     # No explicit default: argparse leaves it None when the flag is absent.
     training.add_argument('--max-batches-per-epoch',
                           type=int,
                           help='maximum number of minibatches per epoch')
Beispiel #5
0
 def add_options(parser, train):
     """Register the 'result snapshotting' argument group.

     *train* is not used by this group; the options are always added.
     """
     snapshots = parser.add_argument_group('result snapshotting')
     add = snapshots.add_argument
     add('--save-images-as',
         choices=Registry.keys('snapshot'),
         default=Registry.default('snapshot'),
         help='how to save the output images')
     add('--output-dir',
         default='output',
         help='directory where to store the generated images')
     add('--snapshot-size',
         type=int,
         default=16,
         help='how many images to generate for each sample (must be <= batch-size)')
     add('--sample-every',
         type=int,
         default=10,
         help='how often to sample images (in epochs)')
     add('--sample-from-fixed-noise',
         action=YesNoAction,
         help='always use the same input noise when sampling')
     add('--snapshot-translate',
         action=YesNoAction,
         help='generate snapshots for an image translation task')
Beispiel #6
0
 def add_options(parser, train):
     """Register the 'model evaluation' argument group (training only)."""
     if not train:
         return
     registry_key = 'evaluation'
     evaluation = parser.add_argument_group('model evaluation')
     evaluation.add_argument('--evaluation-criterion',
                             choices=Registry.keys(registry_key),
                             default=Registry.default(registry_key),
                             help='the criterion to evaluate model improvement')
Beispiel #7
0
 def add_options(parser, train):
     """Register the 'logging' argument group (training only)."""
     if not train:
         return
     logging_group = parser.add_argument_group('logging')
     logging_group.add_argument(
         '--log',
         choices=Registry.keys('log'),
         default=Registry.default('log'),
         help='logging format',
     )
     # No explicit default: argparse leaves it None when the flag is absent.
     logging_group.add_argument('--log-file',
                                help='file to log statistics')
Beispiel #8
0
class Aria2Local(object):
    """Start, stop and query a local ``aria2c`` process.

    RPC-related options get defaults in ``__init__`` unless the caller
    supplied them.
    """

    # The launched process, or None until start() is called.
    process: 'subprocess.Popen | None'

    def __init__(self, program='aria2c', **kwargs):
        """Create the wrapper without starting anything.

        :param program: executable to launch (a name on PATH or a path).
        :param kwargs: ``args`` -- initial aria2c options keyed by option
            name without the leading ``--`` (whatever ``Registry`` accepts).
        """
        self.program = program
        self.process = None
        self.args = Registry(kwargs.get('args'))
        # default() only fills in options the caller did not already set.
        self.args.default('rpc-listen-port', '6800')
        self.args.default('enable-rpc', 'true')
        self.args.default('rpc-allow-origin-all', 'true')
        self.args.default('rpc-listen-all', 'true')

    def set_arg(self, key, value):
        """Set or overwrite the aria2c option *key*."""
        self.args.set(key, value)

    def _build_command(self):
        """Return the argv list: the program followed by ``--key=value`` items."""
        return [self.program] + [
            '--%s=%s' % (key, value) for key, value in self.args.get().items()
        ]

    def get_args_string(self):
        """Return the options as one space-separated ``--key=value`` string."""
        return ' '.join(self._build_command()[1:])

    def start(self):
        """Start the service and return the new process's PID.

        The command is passed as an argument list: with ``shell=False`` a
        single command string is treated as the executable's name on POSIX,
        so the program and its options must be separate argv entries.
        """
        self.process = subprocess.Popen(self._build_command(),
                                        shell=False,
                                        stdout=subprocess.DEVNULL,
                                        stderr=subprocess.DEVNULL)
        return self.process.pid

    def stop(self):
        """Stop the service: terminate descendants, then the process itself.

        :return: the process's exit code from ``wait()``, or None if
            start() was never called.
        """
        if self.process is None:
            return None
        try:
            parent = psutil.Process(self.process.pid)
            children = parent.children(recursive=True)
        except psutil.NoSuchProcess:
            # Already exited; just reap it.
            return self.process.wait()
        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                pass  # exited between listing and terminating
        try:
            parent.terminate()
        except psutil.NoSuchProcess:
            pass
        return self.process.wait()

    def is_install(self):
        """Return True if ``<program> --version`` runs and writes nothing to stderr."""
        try:
            completed = subprocess.run([self.program, '--version'],
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)
        except OSError:
            # Executable not found or not runnable.
            return False
        return completed.stderr == b''

    def is_running(self):
        """Return True if a started process has not exited yet."""
        return self.process is not None and self.process.poll() is None
Beispiel #9
0
 def add_options(parser, train):
     """Register the 'noise generation' argument group; *train* is unused."""
     noise = parser.add_argument_group('noise generation')
     noise.add_argument(
         '--noise',
         choices=Registry.keys('noise'),
         default=Registry.default('noise'),
         help='type of noise',
     )
     noise.add_argument('--state-size', type=int, default=128,
                        help='state size')
Beispiel #10
0
 def add_options(parser, train):
     """Register the 'discriminator' argument group; *train* is unused."""
     disc = parser.add_argument_group('discriminator')
     disc.add_argument('--discriminator',
                       choices=Registry.keys('discriminator'),
                       default=Registry.default('discriminator'),
                       help='type of discriminator')
     # No explicit default: argparse leaves it None when the flag is absent.
     disc.add_argument('--discriminator-dropout', type=float,
                       help='dropout coefficient in discriminator layers')
     for flag, text in (
         ('--discriminator-layers', 'number of discriminator layers'),
         ('--discriminator-channels',
          'number of channels for the discriminator'),
     ):
         disc.add_argument(flag, type=int, default=4, help=text)
class MySQL(object):
    """Per-``instance`` singleton around a ``PooledDB`` of pymysql connections.

    ``MySQL(instance=k, ...)`` returns the same object for a given ``k``
    (default 0). Only the first construction applies ``options`` and
    ``state``; use ``set_option`` to change settings afterwards.
    """

    def __init__(self, *args, **kwargs):
        """Configure the instance on its first construction only.

        :param kwargs: ``options`` -- connection keyword arguments later
            passed to ``PooledDB``; ``state`` -- initial health flag
            (default True); ``instance`` is consumed by ``__new__``.
        """
        # __new__ hands back a cached object, but Python still calls
        # __init__ on it every time; re-running this body would reset the
        # options to their defaults and drop the existing pool and server.
        if getattr(self, '_initialized', False):
            return
        self.options = Registry(kwargs.get('options', {}))
        self.pool = None
        self.server = None
        self.state = kwargs.get('state', True)

        # default() only fills in options the caller did not already set.
        self.options.default('host', '127.0.0.1')
        self.options.default('user', 'root')
        self.options.default('password', '')
        self.options.default('charset', 'utf8')
        self._initialized = True

    def set_option(self, key, value):
        """Set a connection option; affects pools created afterwards."""
        self.options.set(key, value)

    def get_pool(self) -> PooledDB:
        """Return the connection pool, creating it on first use (DictCursor rows)."""
        if self.pool is None:
            self.pool = PooledDB(creator=pymysql, cursorclass=DictCursor, **self.options.get())
        return self.pool

    def reconnect(self):
        """Discard the pool and take a fresh cached connection from a new one."""
        self.pool = None
        self.server = self.get_pool().connection()

    def get_tmp_server(self) -> Connection:
        """Return a new pooled connection that is not cached on the instance."""
        return self.get_pool().connection()

    def get_server(self) -> Connection:
        """Return the cached connection, opening it on first use."""
        if self.server is None:
            self.server = self.get_pool().connection()
        return self.server

    def check_state(self) -> bool:
        """Ping the server (connecting first if needed) and record the result.

        Going through get_server() matters: with ``self.server`` still None
        a direct ping always fails, and check_wait() would loop forever.
        """
        try:
            self.get_server().ping()
            self.state = True
        except Exception:
            self.state = False
        return self.state

    def check_wait(self, interval_time=60):
        """Block until check_state() succeeds, retrying every *interval_time* seconds."""
        while self.check_state() is False:
            time.sleep(interval_time)

    def __new__(cls, *args, **kwargs):
        """Return the cached object for ``kwargs['instance']`` (default 0)."""
        instance = kwargs.get('instance', 0)
        if not hasattr(cls, '_instances'):
            cls._instances = {}
        if instance not in cls._instances:
            cls._instances[instance] = object.__new__(cls)
        return cls._instances[instance]
Beispiel #12
0
 def add_options(parser, train):
     """Register the 'data loading' argument group; *train* is unused."""
     data = parser.add_argument_group('data loading')
     data.add_argument('--data-format',
                       choices=Registry.keys('data'),
                       default=Registry.default('data'),
                       help='type of dataset')
     data.add_argument('--data-dir', default='data',
                       help='directory with the images')
     known_datasets = [
         'folder', 'mnist', 'emnist', 'fashion-mnist',
         'lsun', 'cifar10', 'cifar100'
     ]
     data.add_argument('--dataset', choices=known_datasets, default='folder',
                       help='source of the dataset')
     data.add_argument('--image-class',
                       help='class to train on, only for some datasets')
     for flag, fallback, text in (('--image-size', 64, 'image dimension'),
                                  ('--image-colors', 3, 'image colors')):
         data.add_argument(flag, type=int, default=fallback, help=text)
     data.add_argument('--split', choices=['horizontal', 'vertical'],
                       help='how to split an image pair')
     for flag, fallback, text in (('--batch-size', 64, 'batch size'),
                                  ('--loader-workers', 4,
                                   'number of threads loading data')):
         data.add_argument(flag, type=int, default=fallback, help=text)
     data.add_argument('--pin-memory', action=YesNoAction,
                       help='pin memory to CPU cores for loading data')
class Redis(object):
    """Lazily connected redis client with a simple health check."""

    def __init__(self, **kwargs):
        """Store connection options; no connection is made here.

        :param kwargs: keyword arguments for ``redis.ConnectionPool``, plus
            the wrapper-only ``state`` flag (initial health state, default
            False).
        """
        # 'state' belongs to this wrapper, not to redis: left in the options
        # it would be forwarded to redis.ConnectionPool, whose connections
        # do not accept it.
        options = {key: value for key, value in kwargs.items() if key != 'state'}
        self.options = Registry(options)
        self.pool = None
        self.server = None
        self.state = kwargs.get('state', False)

        # default() only fills in options the caller did not already set.
        self.options.default('host', '127.0.0.1')
        self.options.default('port', 6379)
        self.options.default('db', 0)
        self.options.default('decode_responses', True)

    def set_option(self, key, value):
        """Set a connection option; takes effect on the next reconnect()."""
        self.options.set(key, value)

    def reconnect(self):
        """Build a new connection pool and client from the current options."""
        self.pool = redis.ConnectionPool(**self.options.get())
        self.server = redis.Redis(connection_pool=self.pool)

    def get_server(self) -> redis.Redis:
        """Return the client, creating it on first use."""
        if self.server is None:
            self.reconnect()
        return self.server

    def check_state(self):
        """Ping the server (creating the client if needed) and record the result.

        Going through get_server() matters: with ``self.server`` still None
        a direct ping always fails, and check_wait() would loop forever.
        """
        try:
            self.get_server().ping()
            self.state = True
        except Exception:
            self.state = False
        return self.state

    def check_wait(self, interval_time=60):
        """Block until check_state() succeeds, retrying every *interval_time* seconds."""
        while self.check_state() is False:
            time.sleep(interval_time)
Beispiel #14
0
class App(object):
    """Command-line MySQL helper: add a user, change a password, create a database.

    ``args`` is the parsed command line. Its keys look like docopt output
    (``'--host'``, ``'<username>'``, ``'adduser'``); confirm against the
    caller.
    """

    # Placeholder written to the log in place of secrets.
    _MASK = '******'

    def __init__(self, args):
        self.args = args
        self.options = Registry()
        self.init_options()
        self.init_env()
        # Connects immediately; exits the process on a MySQL error.
        self.client = self.mysql_server()

    def init_options(self):
        """Store the CLI args under ``args`` and fill optional positionals."""
        self.options.load(dict(args=self.args))
        self.options.default('args.<hostname>', 'localhost')
        self.options.default('args.<password>', '')

    def dispose(self):
        """Run the sub-command selected on the command line, if any.

        Unrecognised combinations, including ``create`` without
        ``database``, do nothing.
        """
        args = self.options.get('args')
        if args.get('adduser'):
            self.adduser()
        elif args.get('passwd'):
            self.passwd()
        elif args.get('create') and args.get('database'):
            self.create_database()

    def mysql_server(self):
        """Open a pymysql connection (DictCursor rows) from the ``server`` options.

        On ``pymysql.err.MySQLError`` the error is logged and the process
        exits with the error's first argument as the exit status.
        """
        options = self.options.get('server')
        for key, value in options.items():
            # Never write the password to the log, even at debug level.
            shown = self._MASK if key == 'password' else value
            logging.debug('MySQL Server: %s=%s', key, shown)
        try:
            return pymysql.connect(cursorclass=pymysql.cursors.DictCursor,
                                   **options)
        except pymysql.err.MySQLError as ex:
            logging.error(ex)
            raise SystemExit(ex.args[0])

    def init_env(self):
        """Load the server connection options.

        Each ``MYSQL_*`` environment variable overrides the matching
        ``--option`` command-line value.
        """
        self.options.set(
            'server',
            dict(
                host=os.getenv('MYSQL_HOST', self.options.get('args.--host')),
                # int() raises if neither MYSQL_PORT nor --port is provided
                # (presumably --port has a CLI default; confirm).
                port=int(
                    os.getenv('MYSQL_PORT', self.options.get('args.--port'))),
                user=os.getenv('MYSQL_USER', self.options.get('args.--user')),
                password=os.getenv('MYSQL_PASSWORD',
                                   self.options.get('args.--password')),
                charset=os.getenv('MYSQL_CHARSET',
                                  self.options.get('args.--charset')),
            ))

    @classmethod
    def _redact(cls, text, secrets):
        """Return *text* with every non-empty value in *secrets* masked."""
        for secret in secrets:
            if secret:
                text = text.replace(str(secret), cls._MASK)
        return text

    def execute(self, sql, secrets=()):
        """Execute *sql* line by line in one cursor, then commit.

        Each non-blank line is sent as its own statement, so every statement
        must fit on a single line. Values in *secrets* are masked in the
        debug log; the executed SQL itself is unchanged.
        """
        statements = [line.strip() for line in sql.split('\n') if line.strip()]
        with self.client.cursor() as cursor:
            for statement in statements:
                logging.debug('Execute Sql: %s',
                              self._redact(statement, secrets))
                cursor.execute(statement)
        self.client.commit()

    # NOTE(review): the SQL templates below are filled with unescaped CLI
    # values; a quote in a value would break the statement. Acceptable only
    # while the operator supplies them -- confirm.
    def adduser(self):
        """Create the user given on the command line."""
        password = self.options.get('args.<password>')
        self.execute(__SQL_ADDUSER__ % dict(
            hostname=self.options.get('args.<hostname>'),
            username=self.options.get('args.<username>'),
            password=password,
        ), secrets=(password,))

    def passwd(self):
        """Change the password of the user given on the command line."""
        password = self.options.get('args.<password>')
        self.execute(__SQL_PASSWD__ % dict(
            hostname=self.options.get('args.<hostname>'),
            username=self.options.get('args.<username>'),
            password=password,
        ), secrets=(password,))

    def create_database(self):
        """Create the database named on the command line."""
        self.execute(__SQL_CREATE_DATABASE__ %
                     dict(database=self.options.get('args.<database>'), ))