Example #1
File: train.py Project: vzhong/gazp
def main(args):
    args.gpu = torch.cuda.is_available()
    utils.manual_seed(args.seed)
    Model = utils.load_module(args.model)
    cache_file = args.fcache or (os.path.join(
        'cache', 'data_{}_{}.debug.pt'.format(args.model, args.dataset)
        if args.debug else 'data_{}_{}.pt'.format(args.model, args.dataset)))
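    # the cache file holds a (splits, ext) pair (presumably written by a preprocessing step)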
    splits, ext = torch.load(cache_file, map_location=torch.device('cpu'))
    splits = {k: dataset.Dataset(v) for k, v in splits.items()}
    splits['train'] = Model.prune_train(splits['train'], args)
    splits['dev'] = Model.prune_dev(splits['dev'], args)

    if args.model == 'nl2sql':
        Reranker = utils.load_module(args.beam_rank)
        ext['reranker'] = Reranker(args, ext)
    m = Model(args, ext).place_on_device()

    d = m.get_file('')
    if not os.path.isdir(d):
        os.makedirs(d)

    pprint.pprint(m.get_stats(splits, ext))

    if not args.test_only:
        if not args.skip_upperbound:
            print('upperbound')
            pprint.pprint(m.compute_upperbound(splits['train'][:1000]))
        if args.aug:
            augs = []
            for a in args.aug:
                augs.extend(torch.load(a))
            aug = dataset.Dataset(augs)
            splits['aug'] = Model.prune_train(aug, args)[:args.aug_lim]
            print('aug upperbound')
            pprint.pprint(m.compute_upperbound(aug[:10]))
            # aug_args = copy.deepcopy(args)
            # if 'consistent' not in args.aug:
            #     aug_args.epoch = 10
            # aug_dev = dataset.Dataset(random.sample(splits['train'], 3000))
            # m.run_train(aug, aug_dev, args=aug_args)
        pprint.pprint(m.get_stats(splits, ext))
        m.run_train(dataset.Dataset(splits['train'] + splits.get('aug', [])),
                    splits['dev'],
                    args=args)

    if args.resume:
        m.load_save(fname=args.resume)
    elif args.resumes:
        m.average_saves(args.resumes)
    if args.interactive_eval:
        dev_preds = m.run_interactive_pred(splits['dev'], args, verbose=True)
    else:
        dev_preds = m.run_pred(splits['dev'], args, verbose=True)

    if args.write_test_pred:
        with open(args.write_test_pred, 'wt') as f:
            json.dump(dev_preds, f, indent=2)
        print('saved test preds to {}'.format(args.write_test_pred))

    pprint.pprint(m.compute_metrics(splits['dev'], dev_preds))
Example #2
    def load_checks(self, check_ios):
        """
        Load checks from the database.

        Arguments:
            check_ios (Dict(int->List(CheckIO))): Mapping of check IDs to the CheckIOs to associate with each check

        Returns:
            List((Check, int)): A list of (check, system ID) tuples
        """
        checks = []
        check_rows = db.getall('service_check')
        for check_id, name, system, port, check_string, poller_string in check_rows:
            # Build check
            ios = check_ios[check_id]
            check_function = load_module(check_string)
            poller_class = load_module(poller_string)
            poller = poller_class()
            check = Check(check_id, name, port, check_function,
                          ios, poller)

            # Update link from check IOs to this check
            for check_io in ios:
                check_io.check = check

            checks.append((check, system))
        return checks
Example #3
File: model.py Project: vzhong/gazp
 def load_inst(cls, fresume, overwrite=None):
     binary = torch.load(fresume, map_location=torch.device('cpu'))
     args = binary['args']
     ext = binary['ext']
     for k, v in (overwrite or {}).items():
         setattr(args, k, v)
     Model = utils.load_module(args.model)
     if args.model == 'nl2sql':
         Reranker = utils.load_module(args.beam_rank)
         ext['reranker'] = Reranker(args, ext)
     m = Model(args, ext).place_on_device()
     m.load_save(fname=fresume)
     return m
Example #4
 def _parse_modules(self, modules, debug):
     L = []
     for mod in modules:
         spam = load_module(mod)
         if debug:
             spam.mtime = self._getmtime(spam.__file__)
         L.append(spam)
     return L
Example #5
def load_temp_module(src):
    import tempfile
    import ntpath
    from utils import load_module
    temp = tempfile.NamedTemporaryFile(delete=False)
    temp.file.write(src.encode('utf-8'))
    temp.file.flush()
    module = load_module(ntpath.basename(temp.name), temp.name)
    return module, temp
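Since load_temp_module creates the file with delete=False, the caller is responsible for removing it. A minimal usage sketch, assuming the project's two-argument load_module accepts a module name and a file path as above; the module source and the cleanup step are illustrative, not part of the original snippet:

import os

module, temp = load_temp_module("def greet():\n    return 'hi'\n")
try:
    print(module.greet())
finally:
    temp.close()
    os.unlink(temp.name)  # delete=False above, so the file must be removed manually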
Example #6
 def _add_module(self, mod):
     """Add interceptor(s) and route(s) in module 'mod' and return True
     if there are any, otherwise False."""
     added = False
     m = mod if isinstance(mod, types.ModuleType) else load_module(mod)
     for name in dir(m):
         fn = getattr(m, name)
         if callable(fn):
             if self._add_interceptor(fn) or self._add_route(fn):
                 added = True
     return added
Example #7
def create_engine(**kwargs):
    """Create database engine from kwargs.

    create_engine(engine='db.mysql', user='******', password='******')
    """
    global db
    if db and db.driver:
        raise DBError("Database engine has already been initialized!")
    kwargs = CaseInsensitiveDict(**kwargs)
    engine = kwargs.pop('ENGINE')
    if isinstance(engine, str):
        engine = load_module(engine)
    db = engine(**kwargs)
    return db
Example #8
    def load_results(self):
        """
        Update results with any results not yet loaded from the database.
        """
        if self.results is None:
            last_id = 0
            # Setup dict
            self.results = {}
            for team in self.teams:
                self.results[team.id] = {}
                for check in self.checks:
                    self.results[team.id][check.id] = []
        else:
            # If results exist, we can just load the latest ones and keep the old ones
            # Here we find the id of the last result we already have
            last_ids = []
            for team_results in self.results.values():
                for check_results in team_results.values():
                    if len(check_results) != 0:
                        last_ids.append(check_results[-1].id)
            last_id = -1
            if len(last_ids) != 0:
                last_id = max(last_ids)

        rows = db.get('result', ['*'],
                      where='id > %s',
                      orderby='time ASC',
                      args=(last_id,))  # trailing comma: args must be a one-element tuple

        # Gather the results
        for result_id, check_id, check_io_id, team_id, check_round, time, poll_input, poll_result, result in rows:
            # Construct the result from the database info
            check = [c for c in self.checks if c.id == check_id][0]
            check_io = [
                cio for cio in self.check_ios if cio.id == check_io_id
            ][0]
            team = [t for t in self.teams if t.id == team_id][0]

            input_class_str, input_args = json.loads(poll_input)
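            # resolve the serialized class path back to a class and rebuild the PollInput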
            input_class = utils.load_module(input_class_str)
            poll_input = input_class.deserialize(input_class, input_args,
                                                 self.teams, self.credentials)

            poll_result = json.loads(poll_result)[1]

            res = Result(result_id, check, check_io, team, check_round, time,
                         poll_input, poll_result, result)

            self.results[team_id][check_id].append(res)
Example #9
    def post(self):
        """
        Create a new DBot
        This API requires authorization with a signature in the headers
        """
        # TODO request data valid check
        profile = request.files['profile']
        specification = request.files['specification']
        form = request.form
        dbot_data = json.load(profile)

        domain = form.get('domain', dbot_data['info'].get('domain'))
        if domain is None:
            abort(400, message="DBot domain is required")
        dbot_data['info'].update({
            'addr': form['address'],
            'owner': form['owner'],
            'floor_price': form['floor_price'],
            'api_host': form['api_host'],
            'protocol': form['protocol'],
            'domain': domain
        })
        dbot_data['specification']['data'] = specification.read().decode('utf-8')
        address = to_checksum_address(dbot_data['info'].get('addr'))
        if db.dbots.get(address) is not None:
            abort(400, message="Bad Request, cannot insert an existing DBot service")
        dbot_data['info']['addr'] = address
        middleware = None
        try:
            mw = dbot_data.get('middleware')
            if mw:
                mw_path = os.path.join(db.path(), 'middleware/{}'.format(address))
                if not os.path.exists(mw_path):
                    os.makedirs(mw_path)
                request.files.get('middleware').save(os.path.join(mw_path, '{}.py'.format(mw['module'])))
                middleware = getattr(load_module(mw['module'], mw_path), mw['class'])
        except Exception as err:
            abort(400, message='unable to load DBot middleware: {}'.format(err))
        try:
            dbot.new_service(dbot_data, middleware)
        except Exception as err:
            abort(400, message=str(err))
        db.dbots.put(address, dbot_data)

        return 'ok', 200
Example #10
 def execute(self):
     try:
         import time
         start = time.time()
         from utils import load_module
         method = load_module(self.action)
         argument = self._args()
         result = method(self.owner,argument)
         self.done()
         end = time.time()
         print('Executing the job with jobid:%s took:%s seconds' % (self.id, (end - start)))
         return result
     except Exception as e:
         message = self.action + '\n'
         message += 'Exception raised while executing Job with jobid:%s :%s' % (self.id, str(e) + '\n')
         print(message)
         from utils.emailer import mail_admins
         mail_admins(message,locals())
Example #11
    def put(self, dbot_address):
        profile = request.files['profile']
        specification = request.files['specification']
        form = request.form
        dbot_data = json.load(profile)
        domain = form.get('domain', dbot_data['info'].get('domain'))
        if domain is None:
            abort(400, message="DBot domain is required")
        dbot_data['info'].update({
            'addr': form['address'],
            'owner': form['owner'],
            'floor_price': form['floor_price'],
            'api_host': form['api_host'],
            'protocol': form['protocol'],
            'domain': domain
        })
        dbot_data['specification']['data'] = specification.read().decode('utf-8')
        if not is_same_address(dbot_data['info']['addr'], dbot_address):
            abort(400, message="Bad Request, wrong address in DBot data")

        dbot_address = to_checksum_address(dbot_address)
        dbot_data['info']['addr'] = dbot_address

        middleware = None
        try:
            # TODO update middleware (delete old one)
            mw = dbot_data.get('middleware')
            if mw:
                mw_path = os.path.join(db.path(), 'middleware/{}'.format(dbot_address))
                if not os.path.exists(mw_path):
                    os.makedirs(mw_path)
                request.files.get('middleware').save(os.path.join(mw_path, '{}.py'.format(mw['module'])))
                middleware = getattr(load_module(mw['module'], mw_path), mw['class'])
        except Exception as err:
            abort(400, message='unable to load DBot middleware: {}'.format(err))

        try:
            dbot.update_service(dbot_data, dbot_address, middleware)
        except Exception as err:
            abort(400, message=str(err))
        db.dbots.put(dbot_address, dbot_data)
        return 'ok', 200
Example #12
    def init(self, app, private_key, http_provider=None):
        self.account = Account.privateKeyToAccount(private_key)
        self.web3 = Web3(
            HTTPProvider(app.config['WEB3_PROVIDER_DEFAULT']
                         if http_provider is None else http_provider))
        self.web3.middleware_stack.inject(geth_poa_middleware, layer=0)
        try:
            NETWORK_CFG.set_defaults(int(self.web3.version.network))
        except HTTPError as err:
            logger.error(
                'Cannot connect to blockchain node: {}'.format(err))
            raise  # re-raise the original error; raising the bare HTTPError class would lose it

        self.state_path = os.path.join(app.config['DB_ROOT'], 'channels')
        if not os.path.exists(self.state_path):
            os.makedirs(self.state_path)

        logger.info("load all exist dbot service")
        dbot_address_list = db.dbots.keys()
        for address in dbot_address_list:
            dbot_data = db.dbots.get(address)
            mw = dbot_data.get('middleware')
            if mw:
                mw_path = os.path.join(db.path(),
                                       'middleware/{}'.format(address))
                middleware = getattr(load_module(mw['module'], mw_path),
                                     mw['class'])
                self.new_service(dbot_data, middleware)
            else:
                self.new_service(dbot_data)

        logger.info("start metric collector")
        DBotMetricsCollector().Start(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         "collector.conf"))
        self.metric = DBotApiMetric()
        DBotMetricsCollector().RegisterMetric(self.metric)
        self.metric.EnableDetailRecord(False)

        self.rest_server = None
        self.server_greenlet = None
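The DBot examples (#9, #11, #12) all resolve middleware with load_module(name, path), where path is a directory expected to contain <name>.py. A hedged sketch of what such a helper could look like, built on importlib.util; the project's actual implementation may differ:

import importlib.util
import os

def load_module(name, path):
    """Load module `name` from `<path>/<name>.py` and return the module object."""
    spec = importlib.util.spec_from_file_location(name, os.path.join(path, name + '.py'))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

With this reading, getattr(load_module(mw['module'], mw_path), mw['class']) returns the middleware class defined in that file.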
Example #13
def read_module(template,
                metadata_module,
                identifier,
                path,
                replace_value,
                doc_index=0):
    module = utils.load_module(path, identifier, doc_index)
    if module:
        parameters = module.get("Parameters", [])
        if isinstance(metadata_module, dict):
            for parameter in parameters:
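                # defaults are looked up by stripping the identifier (or "Param" + identifier) suffix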
                if default := metadata_module.get(
                        parameter[:-len(identifier)],
                        metadata_module.get(
                            parameter[:-len("Param" + identifier)])):
                    parameters.get(parameter)['Default'] = default
        if replace_value:
            k = next(iter(replace_value))
            if (k + "Param") in module.get("Parameters", {}):
                module["Parameters"][k + "Param"]['Default'] = replace_value[k]
        reduce(utils.merge, [template, module])
Example #14
def main(cnf, weights_from):
    config = utils.load_module(cnf).config

    if weights_from is None:
        weights_from = config.weights_file
    else:
        weights_from = str(weights_from)

    files = data_orig.get_image_files(config.get('train_dir'))
    names = data_orig.get_names(files)
    labels = data_orig.get_labels(names).astype(np.float32)

    net = create_net(config)

    try:
        net.load_params_from(weights_from)
        print("loaded weights from {}".format(weights_from))
    except IOError:
        print("couldn't load weights starting from scratch")

    print("fitting ...")
    net.fit(files, labels)
Example #15
    def load_check_ios(self, credentials):
        """ 
        Load CheckIOs from the database.

        Arguments:
            credentials (List(Credential)): List of credentials to associate CheckIOs with

        Returns:
            Dict(int->List(CheckIO)): Mapping of check IDs to a list of CheckIOs
        """
        check_ios = super().load_check_ios(credentials)
        for check_id, cios in check_ios.items():
            for cio in cios:
                # Rebuild the PollInput
                poll_input = cio.poll_input
                input_class_str, input_args = json.loads(poll_input)
                input_class = load_module(input_class_str)
                poll_input = input_class.deserialize(input_class, input_args,
                                                     self.teams,
                                                     self.credentials)

                cio.poll_input = poll_input
        return check_ios
Example #16
File: ok.py Project: pombredanne/ok
    def load(self, module_name):
        # If it has been loaded already, just return it!
        if module_name in self._loaded_modules:
            return self._loaded_modules[module_name]

        # If it's not in the cache, download it
        if module_name not in self._cache:
            module_path = 'modules/{}.py'.format(module_name)
            self._download(module_path, module_path)
            # Just downloaded, still have to verify
            self._cache[module_name] = False

        module = utils.load_module(module_name)
        # Shim some stuff into the module
        # @TODO: There has got to be a better (but still explicit) way...
        setattr(module, 'ok', okapi)
        setattr(module, 'utils', utils)

        # If it's not verified, try to verify it
        if not self._cache[module_name]:
            okapi.log(
                'Checking whether {} is installed...'.format(module_name),
                important=False)
            if not module.check():
                okapi.log(
                    '{} not installed, installing now...'.format(module_name))
                try:
                    module.install()
                except Exception as e:
                    raise utils.OkException(
                        'Unable to install module {}: {}'.format(
                            module_name, e))
            self._cache[module_name] = True

        # At this point we know it exists and is loaded, so...
        self._loaded_modules[module_name] = module
        return module
Example #17
File: run.py Project: zgsxwsdxg/ELF
    def setup(self, all_args):
        self.game = load_module(os.environ["game"]).Loader()
        self.game.args.set(all_args, actor_only=True, game_multi=2)

        self.gpu = all_args.eval_gpu
        self.tqdm = all_args.tqdm

        self.runner = SingleProcessRun()
        self.runner.args.set(all_args)

        self.GC = self.game.initialize()
        self.GC.setup_gpu(self.gpu)

        self.sampler = Sampler()
        self.sampler.args.set(all_args, greedy=True)

        self.trainer = Trainer()
        self.trainer.args.set(all_args)

        if self.args.stats == "rewards":
            self.collector = RewardCount()
        elif self.args.stats == "winrate":
            self.collector = WinRate()

        def actor(sel, sel_gpu, reply):
            self.trainer.actor(sel, sel_gpu, reply)
            v = sel[0]

            for batch_idx, (id, last_terminal) in enumerate(zip(v["id"], v["last_terminal"])):
                self.collector.feed(id, v["last_r"][batch_idx])
                if last_terminal:
                    self.collector.terminal(id)

        self.GC.reg_callback("actor", actor)

        self.GC.Start()
"""
Training Code Adapted from https://github.com/TwentyBN/smth-smth-v2-baseline-with-models
"""
import os
import time
import signal
import torch
import utils
import importlib
from visualisation import PlotLearning

# load configurations
args = utils.load_args()
config = utils.load_module(args.config).config

# set column model
file_name = config['conv_model']
cnn_def = importlib.import_module("{}".format(file_name))

# setup device - CPU or GPU
device, device_ids = utils.setup_cuda_devices(args)
print(" > Using device: {}".format(device.type))
print(" > Active GPU ids: {}".format(device_ids))

best_loss = float('Inf')

if config["input_mode"] == "jpg":
    from data_loader_jpg import ImLoader
else:
    raise ValueError("Please provide a valid input mode")
Example #19
    def __init__(self, fconf, handler):
        """
        Initialize a MasterServer instance
        @param fconf the path to the configuration file
        @param handler the handler object in charge of managing HTTP requests
        """
        Logger.__init__(self, "Manager")

        conf = json.load(open(fconf))

        # Jinja2 initialization.
        tmpl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'templates')
        self.env = Environment(loader=FileSystemLoader(tmpl_path))
        self.status = ApplicationStatus()

        # This is a dictionary structure in the form
        # reduce_dict["group-name"] = [
        #   [ file list by unique integers, size in byte
        #   ] => Reduce-0
        #   [
        #   ] => Reduce-1
        # ]
        self.reduce_mark = set()
        self.reduce_dict = defaultdict(list)
        self.dead_reduce_dict = defaultdict(list)

        # This is a dictionary nick => Handler instance
        self.masters = {}
        self.last_id = -1
        self.pending_works = defaultdict(list) # nick => [work, ...]

        self.ping_max = int(conf["ping-max"])
        self.ping_interval = int(conf["ping-interval"])
        self.num_reducer = int(conf["num-reducer"])

        # This will just keep track of the name of the files
        self.reduce_files = []
        self.results_printed = False

        for _ in range(self.num_reducer):
            self.reduce_files.append("N/A")

        # Load the input module and assign the generator to the work_queue
        module = load_module(conf["input-module"])
        cls = getattr(module, "Input", None)
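        # NOTE: if the module does not define an Input class, cls is None and the call below fails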

        # Some code for the DFS
        generator = cls(fconf).input()
        self.use_dfs = use_dfs = conf['dfs-enabled']

        if use_dfs:
            dfsconf = conf['dfs-conf']
            dfsconf['host'] = dfsconf['master']

            self.path = conf['output-prefix']
        else:
            dfsconf = None

            self.path = os.path.join(conf['datadir'], conf['output-prefix'])

        self.work_queue = WorkQueue(self.logger, generator, use_dfs, dfsconf)

        # Lock to synchronize access to the timestamps dictionary
        self.lock = Lock()
        self.timestamps = {} # nick => (send_ts:enum, ts:float)

        # Ping thread
        self.hb_thread = Thread(target=self.hearthbeat)

        # Event to mark the end of the server
        self.finished = Event()

        self.addrinfo = (conf['master-host'], conf['master-port'])
        Server.__init__(self, self.addrinfo[0], self.addrinfo[1], handler)
Example #20
import os
import sys
import json
import utils
import tempfile
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

config = utils.load_config()
sys.path.append(os.path.expanduser(config['slim_path']))
resnet = utils.load_module('nets.resnet_v1')
logger = utils.get_default_logger()


class WildCat:
    def __init__(self,
                 images,
                 labels=None,
                 n_classes=None,
                 training=False,
                 transfer_conv_size=(3, 3),
                 n_maps_per_class=5,
                 alpha=1.0,
                 k=1,
                 reg=None):
        self.images = images
        self.labels = labels
        self.n_classes = n_classes
        self.training = training
        self.transfer_conv_size = transfer_conv_size
Example #21
def get_model(model_path):
    mod_model = load_module(model_path)
    return mod_model.get()
Example #22
File: potion.py Project: jmg/potion
def get_entities(name, path=None):

    models = load_module(name, path=path)
    return [model for name, model in models.__dict__.items() if is_subclass(model, Entity)]
Example #23
import sys

import utils


def run():
    pass


if __name__ == "__main__":
    if len(sys.argv) < 2:
        raise Exception("Command missing.")

    command = utils.load_module(sys.argv[1])
    command.run()
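Most single-argument calls on this page treat the argument as an importable dotted path. A minimal sketch of that variant, assuming plain importlib semantics; note that some projects' helpers instead return a class or other attribute (e.g. Model = utils.load_module(args.model) in Example #1):

import importlib

def load_module(dotted_path):
    """Import and return the module named by dotted_path, e.g. 'nets.resnet_v1'."""
    return importlib.import_module(dotted_path)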
Example #24
def main(orig_args):
    # load pretrained model
    fresume = os.path.abspath(orig_args.resume)
    # print('resuming from {}'.format(fresume))
    assert os.path.isfile(fresume), '{} does not exist'.format(fresume)

    orig_args.input = os.path.abspath(orig_args.input)
    orig_args.tables = os.path.abspath(orig_args.tables)
    orig_args.db = os.path.abspath(orig_args.db)
    orig_args.dcache = os.path.abspath(orig_args.dcache)

    binary = torch.load(fresume, map_location=torch.device('cpu'))
    args = binary['args']
    ext = binary['ext']
    args.gpu = torch.cuda.is_available()
    args.tables = orig_args.tables
    args.db = orig_args.db
    args.dcache = orig_args.dcache
    args.batch = orig_args.batch
    Model = utils.load_module(args.model)
    if args.model == 'nl2sql':
        Reranker = utils.load_module(args.beam_rank)
        ext['reranker'] = Reranker(args, ext)
    m = Model(args, ext).place_on_device()

    if orig_args.resumes:
        m.average_saves(orig_args.resumes)
    else:
        m.load_save(fname=fresume)

    # preprocess data
    data = dataset.Dataset()

    if orig_args.dataset == 'spider':
        import preprocess_nl2sql as preprocess
    elif orig_args.dataset == 'sparc':
        import preprocess_nl2sql_sparc as preprocess
    elif orig_args.dataset == 'cosql':
        import preprocess_nl2sql_cosql as preprocess

    proc_errors = set()
    with open(orig_args.input) as f:
        C = preprocess.SQLDataset
        raw = json.load(f)
        # make contexts and populate vocab
        for i, ex in enumerate(raw):
            for k in ['query', 'query_toks', 'query_toks_no_value', 'sql']:
                if k in ex:
                    del ex[k]
            ex['id'] = '{}/{}'.format(ex['db_id'], i)
            new = C.make_example(ex,
                                 m.bert_tokenizer,
                                 m.sql_vocab,
                                 m.conv.kmaps,
                                 m.conv,
                                 train=False,
                                 evaluation=True)
            if new is not None:
                new['question'] = ex['question']
                new['cands_query'], new['cands_value'] = C.make_cands(
                    new, m.sql_vocab)
                data.append(new)
            else:
                print('proc error')
                proc_errors.add(ex['id'])

    # run preds
    if orig_args.dataset in {'cosql', 'sparc'}:
        preds = m.run_interactive_pred(data, args, verbose=True)
        raise NotImplementedError()
    else:
        preds = m.run_pred(data, args, verbose=True)
        assert len(preds) + len(proc_errors) == len(
            data), 'got {} predictions for {} examples'.format(
                len(preds), len(data))
        #  print('writing to {}'.format(orig_args.output))
        with open(orig_args.output, 'wt') as f:
            for ex in data:
                if ex['id'] in proc_errors:
                    s = 'ERROR'
                else:
                    p = preds[ex['id']]
                    s = p['query']
                f.write(s + '\n')
            f.flush()
Example #25
    count += 1

    return True


def max_size(proto_message_type):

    cb_args = {}
    ext_refs = {}
    immutable_objs = {}
    ext_constraints = None
    num_objects = 1

    mbt_obj_store.set_walk_enum(True)
    mbt_obj_store.set_skip_none_enum(True)

    walk_proto(proto_message_type, max_size_cb, cb_args, num_objects, ext_refs,
               ext_constraints, immutable_objs)


module_name = 'fwlog_pb2'
proto_message_name = 'FWEvent'

module = utils.load_module(module_name)
proto_message_type = getattr(module, proto_message_name)

max_size(proto_message_type)

print("Max size: " + str(max_size_val))
Example #26
from __future__ import division
import time

import click
import numpy as np

import nn
import data_orig
import tta
import utils

cnf='configs/c_512_5x5_32.py'
config = utils.load_module(cnf).config
config.cnf['batch_size_train'] = 128

runs = {}
runs['train'] = config.get('train_dir')

net = nn.create_net(config)

weights_from = 'weights/c_512_5x5_32/weights_final.pkl'
net.load_params_from(weights_from)

tfs, color_vecs = tta.build_quasirandom_transforms(1, skip=0, color_sigma=0.0, **data_orig.no_augmentation_params)
for i, (tf, color_vec) in enumerate(zip(tfs, color_vecs), start=1):
    pass
Example #27
    Gen_optimizer = optim.Adam(Gen_Model.parameters(), lr=lr)  #lr = 1e-4
    Dis_optimizer = optim.Adam(Dis_Model.parameters(), lr=lr)

    Vggloss = Perceptual_Loss.vggloss()  # vgg(5,4) loss

    content_criterion = nn.MSELoss()
    adversal_criterion = nn.BCEWithLogitsLoss()
    PSNR_eval = np.zeros(TOTAL_EPOCH)
    PSNR_train = np.zeros(TOTAL_EPOCH)
    Train_Gen_loss = np.zeros(TOTAL_EPOCH)
    Train_Dis_loss = np.zeros(TOTAL_EPOCH)
    train_len = len(train_dataloader)

    start_epoch = 0
    if PRETRAINED_PATH is not None:
        _, gen_modelpath = utils.load_module(
            os.path.join(PRETRAINED_PATH, "Generator"))
        start_epoch, dis_modelpath = utils.load_module(
            os.path.join(PRETRAINED_PATH, "Discriminator"))
        print(dis_modelpath)
        print("load module : saved on {} epoch".format(start_epoch))
        Gen_Model.load_state_dict(torch.load(gen_modelpath))
        Dis_Model.load_state_dict(torch.load(dis_modelpath))

    if PRE_RESULT_DIR is not None:
        PSNR_eval = np.load("result_data/PSNR_eval.npy")
        PSNR_train = np.load("result_data/PSNR_train.npy")
        Train_Dis_loss = np.load("result_data/Train_Dis_loss.npy")
        Train_Gen_loss = np.load("result_data/Train_Gen_loss.npy")

    Gen_Model = Gen_Model.to(device)
    Dis_Model = Dis_Model.to(device)
Example #28
def main(orig_args):
    # load pretrained model
    fresume = os.path.abspath(orig_args.resume)
    # print('resuming from {}'.format(fresume))
    assert os.path.isfile(fresume), '{} does not exist'.format(fresume)

    orig_args.input = os.path.abspath(orig_args.input)
    orig_args.tables = os.path.abspath(orig_args.tables)
    orig_args.db = os.path.abspath(orig_args.db)
    orig_args.dcache = os.path.abspath(orig_args.dcache)

    binary = torch.load(fresume, map_location=torch.device('cpu'))
    args = binary['args']
    ext = binary['ext']
    args.gpu = torch.cuda.is_available()
    args.tables = orig_args.tables
    args.db = orig_args.db
    args.dcache = orig_args.dcache
    args.batch = orig_args.batch
    Model = utils.load_module(args.model)
    if args.model == 'nl2sql':
        Reranker = utils.load_module(args.beam_rank)
        ext['reranker'] = Reranker(args, ext)
    m = Model(args, ext).place_on_device()
    m.load_save(fname=fresume)

    # preprocess data
    data = dataset.Dataset()

    if args.dataset == 'sparc':
        import preprocess_nl2sql_sparc as preprocess
    elif args.dataset == 'cosql':
        import preprocess_nl2sql_cosql as preprocess
    else:
        raise NotImplementedError()

    proc_errors = set()
    with open(orig_args.input) as f:
        C = preprocess.SQLDataset
        raw = json.load(f)

        # make contexts and populate vocab
        for i, ex in enumerate(raw):
            for turn_i, turn in enumerate(ex['interaction']):
                turn['id'] = '{}/{}:{}'.format(ex['database_id'], i, turn_i)
                turn['db_id'] = ex['database_id']
                for k in ['query', 'query_toks', 'query_toks_no_value', 'sql']:
                    if k in turn:
                        del turn[k]
                turn['question'] = turn['utterance']
                turn['g_question_toks'] = C.tokenize_question(
                    turn['utterance'].split(), m.bert_tokenizer)
                turn['value_context'] = [
                    m.bert_tokenizer.cls_token
                ] + turn['g_question_toks'] + [m.bert_tokenizer.sep_token]
                turn['turn_i'] = turn_i
                data.append(turn)

    # run preds
    preds = m.run_interactive_pred(data, args, verbose=True)
    assert len(preds) == len(
        data), 'got {} predictions for {} examples'.format(
            len(preds), len(data))
    #  print('writing to {}'.format(orig_args.output))
    with open(orig_args.output, 'wt') as f:
        for i, ex in enumerate(data):
            if i != 0 and ex['turn_i'] == 0:
                f.write('\n')
            if ex['id'] in proc_errors:
                s = 'ERROR'
            else:
                p = preds[ex['id']]
                s = p['query']
            f.write(s + '\n')
        f.flush()
Example #29
File: run.py Project: zgsxwsdxg/ELF
                self.GC.Run()

        summary = c.summary()
        for k, v in summary.items():
            print("%s: %s" % (str(k), str(v)))

    def __del__(self):
        self.GC.Stop()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    use_multi_process = int(os.environ.get("multi_process", 0))

    sampler = Sampler()
    trainer = Trainer()
    game = load_module(os.environ["game"]).Loader()
    runner = MultiProcessRun() if use_multi_process else SingleProcessRun()
    model_file = load_module(os.environ["model_file"])
    model_class, method_class = model_file.Models[os.environ["model"]]

    model_loader = ModelLoader(model_class)
    method = method_class()

    args_providers = [sampler, trainer, game, runner, model_loader, method]

    eval_only = os.environ.get("eval_only", False)
    has_eval_process = os.environ.get("eval_process", False)
    if has_eval_process or eval_only:
        eval_process = EvaluationProcess()
        evaluator = Eval()
Example #30
conf.label_training = utils.abspath(conf.label_training)
conf.input_validation = [utils.abspath(path) for path in conf.input_validation]
conf.label_validation = utils.abspath(conf.label_validation)
conf.one_hot_palette_input = utils.abspath(conf.one_hot_palette_input)
conf.one_hot_palette_label = utils.abspath(conf.one_hot_palette_label)
conf.model = utils.abspath(conf.model)
if conf.unetxst_homographies is not None:
    conf.unetxst_homographies = utils.abspath(conf.unetxst_homographies)
if conf.model_weights is not None:
    conf.model_weights = utils.abspath(conf.model_weights)
conf.output_dir = utils.abspath(conf.output_dir)

# load network architecture module
architecture = utils.load_module(conf.model)

# get max_samples_training random training samples
n_inputs = len(conf.input_training)
files_train_input = [
    utils.get_files_in_folder(folder) for folder in conf.input_training
]
files_train_label = utils.get_files_in_folder(conf.label_training)
_, idcs = utils.sample_list(files_train_label,
                            n_samples=conf.max_samples_training)
files_train_input = [np.take(f, idcs) for f in files_train_input]
files_train_label = np.take(files_train_label, idcs)
image_shape_original_input = utils.load_image(
    files_train_input[0][0]).shape[0:2]
image_shape_original_label = utils.load_image(files_train_label[0]).shape[0:2]
print(f"Found {len(files_train_label)} training samples")
Example #31
 def template_engine(self, engine):
     if not callable(engine):
         engine = load_module(engine)
     self._template_engine = engine
#print("Arguments",args)

# Read master config
with open(args.master_config) as f:
    config = yaml.safe_load(f)
#print(json.dumps(config, indent=2))

# Get modular functions
# Note that event lists have parsers and handlers; state lists only parsers
event_parser = dict()
event_handler = dict()
state_parser = dict()
for listname in config['event_lists'].keys():
    print("Loading parser for event list {} from {}".format(
        listname, config['event_lists'][listname]['parser']))
    module = load_module('parsers.events.' + listname,
                         config['event_lists'][listname]['parser'])
    event_parser[listname] = module.parser
    print("Loading handler for event {} from {}".format(
        listname, config['event_lists'][listname]['event_handler']))
    module = load_module('handlers.events.' + listname,
                         config['event_lists'][listname]['event_handler'])
    event_handler[listname] = module.handler
print("Event list parsers:", event_parser.keys())
# Note that state lists do not have event handlers
for listname in config['state_lists'].keys():
    print("Loading parser for state list {} from {}".format(
        listname, config['state_lists'][listname]['parser']))
    module = load_module('parsers.state.' + listname,
                         config['state_lists'][listname]['parser'])
    state_parser[listname] = module.parser
print("State list parsers:", state_parser.keys())
Example #33
    def __init__(self, config_spec, cfg_spec_obj, hal_channel):

        self._config_spec  = config_spec
        self._service_name = config_spec.Service

        # create client stub object from module
        module           = utils.load_module(config_spec.ProtoObject)
        stub_method_name = self._service_name + "Stub"

        # connect the client stub to the GRPC server
        stub = getattr(module, stub_method_name)(hal_channel)

        self._cfg_spec_obj_name = cfg_spec_obj.name

        self._max_objects = mbt_obj_store.default_max_objects()

        if 'max_objects' in dir(cfg_spec_obj):
            self._max_objects = int(cfg_spec_obj.max_objects)

        self._key_or_handle_str = ""
        if 'key_handle' in dir(cfg_spec_obj):
            self._key_or_handle_str = cfg_spec_obj.key_handle

        # config methods which are ignored
        ignore = {}
        ignore[ConfigMethodType.CREATE] = False
        ignore[ConfigMethodType.GET]    = False
        ignore[ConfigMethodType.UPDATE] = False
        ignore[ConfigMethodType.DELETE] = False

        if 'ignore_v2' in dir(cfg_spec_obj):
            ignore_ops = cfg_spec_obj.ignore_v2
        else:
            ignore_ops = cfg_spec_obj.ignore

        for index in ignore_ops or []:
            if index.op == 'Create':
                ignore[ConfigMethodType.CREATE] = True

            if index.op == 'Get':
                ignore[ConfigMethodType.GET] = True

            if index.op == 'Update':
                ignore[ConfigMethodType.UPDATE] = True

            if index.op == 'Delete':
                ignore[ConfigMethodType.DELETE] = True

        # configuration methods for this object
        self._config_methods = {}

        cfg_method_type = ConfigMethodType.CREATE
        self._config_methods[cfg_method_type] = ConfigMethod(stub, module, cfg_spec_obj.create, ignore[cfg_method_type])

        cfg_method_type = ConfigMethodType.GET
        self._config_methods[cfg_method_type] = ConfigMethod(stub, module, cfg_spec_obj.get, ignore[cfg_method_type])

        cfg_method_type = ConfigMethodType.UPDATE
        self._config_methods[cfg_method_type] = ConfigMethod(stub, module, cfg_spec_obj.update, ignore[cfg_method_type])

        cfg_method_type = ConfigMethodType.DELETE
        self._config_methods[cfg_method_type] = ConfigMethod(stub, module, cfg_spec_obj.delete, ignore[cfg_method_type])

        # constraints specified in the config spec object
        self._constraints = []
        if 'constraints' in dir(cfg_spec_obj):
            cfg_spec_obj_constraints = getattr(cfg_spec_obj, 'constraints')
            for cfg_spec_obj_constraint in cfg_spec_obj_constraints:
                constraint = GrpcReqRspMsg.extract_constraints(cfg_spec_obj_constraint.constraint)
                self._constraints.append(constraint)

        # store external reference object for the specific key_handle
        # __DEPRECATED__
        self._ext_ref_obj_list = []

        # cache of mbt_handles of this type
        self._mbt_handle_list = []

        # stats for this object
        self._num_create_ops = 0
        self._num_get_ops    = 0
        self._num_update_ops = 0
        self._num_delete_ops = 0