コード例 #1
0
    def select_object(board: np.ndarray, click_location: tuple,
                      screen_height: "int | None" = None) -> dotdict:
        """
        Select the board tile whose actor was clicked on the PyGame canvas.

        Tiles are drawn with a one-tile margin, so the actor of tile (x, y) is
        centred at roughly ((x + 1.5) * scale, (y + 1.5) * scale) pixels and is
        hit-tested as a circle of radius ``int(scale / 3)``. Ownership of the
        tile is not checked here (board values are not inspected).

        :param board: game state board; only ``board.shape[0]`` (tiles per
            side) is used
        :param click_location: tuple (x, y) that represents the canvas click
            location in pixels
        :param screen_height: screen height in pixels from which the canvas
            scale is derived; when None (the default) it is queried with the
            Windows-only ``ctypes.windll.user32.GetSystemMetrics(1)``
        :return: dotdict with keys "x" and "y" holding the clicked tile
            coordinate, or {"x": -1, "y": -1} when no actor was hit
        """
        n = board.shape[0]
        if screen_height is None:
            # GetSystemMetrics(1) is SM_CYSCREEN, the primary screen height.
            screen_height = ctypes.windll.user32.GetSystemMetrics(1)
        # The board is drawn over 16/30 (~53%) of the screen height.
        canvas_scale = int(screen_height * (16 / 30) / n)

        # Loop-invariant values, computed once instead of once per tile.
        click_x, click_y = click_location
        actor_size = int(canvas_scale / 3)
        half_tile = canvas_scale / 2

        for y in range(n):
            actor_y = int(y * canvas_scale + half_tile) + canvas_scale
            for x in range(n):
                actor_x = int(x * canvas_scale + half_tile + canvas_scale)
                dist = sqrt((actor_x - click_x)**2 + (actor_y - click_y)**2)
                if dist <= actor_size:
                    return dotdict({"x": x, "y": y})
        return dotdict({"x": -1, "y": -1})
コード例 #2
0
 def fetch_balance_ws(self):
     """Build a balance from the websocket ``funds()`` snapshot.

     A ``BTC`` entry is added with ``free`` = ``availableMargin`` * 1e-8 and
     ``total`` = ``marginBalance`` * 1e-8 (presumably satoshi -> BTC), and
     ``used`` = total - free. The BTC figures are logged.
     """
     to_btc = 0.00000001
     balance = dotdict(self.ws.funds())
     btc = dotdict()
     balance.BTC = btc
     btc.free = balance.availableMargin * to_btc
     btc.total = balance.marginBalance * to_btc
     btc.used = btc.total - btc.free
     template = "BALANCE: free {free:.3f} used {used:.3f} total {total:.3f}"
     self.logger.info(template.format(**btc))
     return balance
コード例 #3
0
 def fetch_order(self, order_id):
     """Fetch an order by id through the REST exchange client.

     The order and its ``info`` field are wrapped in ``dotdict``. If the
     exchange raises ``ccxt.OrderNotFound``, a warning is logged and the
     placeholder ``{'status': 'closed', 'id': order_id}`` is returned.
     """
     try:
         fetched = dotdict(self.exchange.fetch_order(order_id))
         fetched.info = dotdict(fetched.info)
     except ccxt.OrderNotFound as err:
         self.logger.warning(f"{type(err).__name__}: {err}")
         return dotdict({'status': 'closed', 'id': order_id})
     return fetched
コード例 #4
0
ファイル: strategy.py プロジェクト: zihpzhong/mexbot
 def fetch_balance(self):
     """Fetch asset (balance) information through the REST exchange client.

     The balance and its ``BTC`` entry are wrapped in ``dotdict``; the BTC
     free/used/total amounts are logged.
     """
     balance = dotdict(self.exchange.fetch_balance())
     btc = dotdict(balance.BTC)
     balance.BTC = btc
     template = "BALANCE: free {free:.3f} used {used:.3f} total {total:.3f}"
     self.logger.info(template.format(**btc))
     return balance
コード例 #5
0
ファイル: strategy.py プロジェクト: zihpzhong/mexbot
 def fetch_order_ws(self, order_id):
     """Look up an order by id among the websocket's cached orders.

     The first entry whose ``orderID`` equals ``order_id`` is converted with
     ``self.exchange.parse_order`` (result and ``info`` wrapped in
     ``dotdict``). If nothing matches, the placeholder
     ``{'status': 'closed', 'id': order_id}`` is returned.
     """
     match = next(
         (raw for raw in self.ws.all_orders() if raw['orderID'] == order_id),
         None)
     if match is None:
         return dotdict({'status': 'closed', 'id': order_id})
     parsed = dotdict(self.exchange.parse_order(match))
     parsed.info = dotdict(parsed.info)
     return parsed
コード例 #6
0
def ieee_query():
    """
    Run queries on the IEEE Xplore DB using their REST API. Requires a valid API key.

    For every query from ``construct_queries()`` the results are requested in
    pages of ``max_records`` (module-level setting), converted to BibTeX with
    ``ieeeparser.bibtexize`` / ``bibtexparser.dumps`` and written with
    ``write_bibtex_str``. The process exits with ``NO_DATA_ERROR`` as soon as
    a request returns no data.
    """
    api_key = load_api_key("ieee.key")

    queries = construct_queries()

    fields_mask = ieeelib.SEARCH_FIELD_ABSTRACT | ieeelib.SEARCH_FIELD_DOC_TITLE

    for query in queries:
        query_result = ieeelib.query(query,
                                     api_key,
                                     max_records=max_records,
                                     fields_mask=fields_mask)
        ieee_data = json.loads(query_result)

        if not ieee_data:
            print("No data returned for query %s. Aborting." % query,
                  file=sys.stderr)
            sys.exit(NO_DATA_ERROR)

        ieee_data = dotdict(ieee_data)
        bibtex_str = bibtexparser.dumps(ieeeparser.bibtexize(ieee_data))

        total_records = ieee_data.total_records

        # Records are numbered from 1; the first request presumably covered
        # 1..max_records, so the next page starts at 1 + max_records. Keep
        # going while that start is still a valid record number (<=, so a
        # final page holding a single record is not dropped).
        next_start_record = 1 + max_records
        while next_start_record <= total_records:
            # NOTE(review): next_start_record is not passed to ieeelib.query,
            # so unless ieeelib pages internally every iteration re-requests
            # the same first page. Confirm ieeelib.query's start-record
            # parameter and pass next_start_record to it.
            query_result = ieeelib.query(query,
                                         api_key,
                                         max_records=max_records,
                                         fields_mask=fields_mask)
            ieee_data = json.loads(query_result)
            if not ieee_data:
                # One placeholder per formatted value (start record, query).
                print(
                    "Something went wrong while retrieving records starting at %s for query %s. Aborting."
                    % (next_start_record, query),
                    file=sys.stderr)
                sys.exit(NO_DATA_ERROR)

            ieee_data = dotdict(ieee_data)
            bibtex_str = bibtex_str + bibtexparser.dumps(
                ieeeparser.bibtexize(ieee_data))
            # Advance only after the page is handled so an error message
            # reports the start record that actually failed.
            next_start_record = next_start_record + max_records

        write_bibtex_str(bibtex_str, query, fields_mask)
        # Report inside the loop so every query is reported and an empty
        # query list never touches an unbound ``query``.
        print('ieee_query() for query "%s" finished.' % query)
コード例 #7
0
    def generate_portfolio(self, **kwargs):
        """Evolve portfolio weights with a genetic algorithm.

        Keyword arguments (read via attribute access on a dotdict):
          cov_matrix     -- object whose ``columns`` name the symbols
          sample_returns -- passed to ``self.select`` to score genes

        ``self.select`` yields entries whose item 0 is the score and item 1
        the gene; entry 0 is taken as the best. The population goes through
        ``self.iterations`` rounds of select -> mutate. The best gene of the
        final population is normalized with ``normalize_weights`` and
        returned as ``{symbol: weight}``.
        """
        options = dotdict(kwargs)
        symbols = list(options.cov_matrix.columns)
        self.gene_length = len(symbols)

        population = self.generate_initial_genes(symbols)
        for _ in range(self.iterations):
            ranked = self.select(options.sample_returns, population)
            survivors = [entry[1] for entry in ranked]
            population = self.mutate(survivors)

        final_ranking = self.select(options.sample_returns, population)
        best_gene = final_ranking[0][1]
        # A gene is a distribution of weights over the symbols.
        normalized = normalize_weights(best_gene)
        return {symbols[k]: normalized[k] for k in range(len(best_gene))}
コード例 #8
0
    def load_checkpoint(self):
        """
        Best-effort resume of network weights, optimizer state and best
        validation result.

        The checkpoint path is ``self.config.ckpt`` when set, otherwise the
        latest ``checkpoint_G`` file returned by ``routines.get_latest_model``
        for ``self.config.EXP_DIR_PARAMS``. The checkpoint must provide
        ``state_dict``, ``epoch`` and ``optimizer`` entries;
        ``self.epoch_on_start`` becomes ``epoch + 1``. If
        ``self.va_res_file`` exists, ``self.best_va_res`` is read from it.
        Any exception is printed followed by "resume fail" and is not
        re-raised.
        """
        try:
            if self.config.ckpt is not None:
                net_path = self.config.ckpt
            else:
                net_path = routines.get_latest_model(
                    self.config.EXP_DIR_PARAMS, 'checkpoint_G')

            checkpoint = dotdict(torch.load(net_path))
            state_dict = checkpoint.state_dict
            self.epoch_on_start = checkpoint.epoch + 1

            self.net.load_state_dict(state_dict)
            cprint('resume network weights from %s successfully' % net_path,
                   'red',
                   attrs=['reverse', 'blink'])

            self.optimizer.load_state_dict(checkpoint.optimizer)
            cprint('resume optimizer from %s successfully' % net_path,
                   'red',
                   attrs=['reverse', 'blink'])

            if os.path.exists(self.va_res_file):
                with open(self.va_res_file, "r") as ifp:
                    dump_res = ifp.read()
                dump_res = parse("{best_va_res:e}\n", dump_res)
                # parse() returns None when the text does not match the
                # format (e.g. no trailing newline). Subscripting None would
                # raise TypeError and report the whole resume as failed even
                # though weights and optimizer are already restored.
                if dump_res is None:
                    print("could not parse best validation result from %s"
                          % self.va_res_file)
                else:
                    self.best_va_res = dump_res["best_va_res"]
        except Exception as e:
            # Deliberately best-effort: report the failure and return.
            print(e)
            print("resume fail")
コード例 #9
0
def prepare_affix_paths(
    *,
    dir_path=dir_path,
):
    """
    Download and prepare Baroni's affix data set under ``dir_path``.

    Each step is skipped when its output file already exists:

    1. download ``affix_complete_set.txt.gz`` with ``wget``;
    2. decompress it to ``affix_complete_set.txt``;
    3. write ``queries.txt`` with the ``derived`` word of every data row,
       plus its lower-cased form when that differs.

    :param dir_path: target directory (str or path-like); created if missing.
    :return: dotdict with ``dir_path``, ``gz_path``, ``raw_path`` and
        ``queries_path`` as ``Path`` objects.
    :raises sp.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    dir_path = Path(dir_path)
    # Make sure wget -O has a directory to write into.
    dir_path.mkdir(parents=True, exist_ok=True)
    gz_path = dir_path / "affix_complete_set.txt.gz"
    raw_path = dir_path / "affix_complete_set.txt"
    queries_path = dir_path / "queries.txt"

    if not gz_path.exists():
        logger.info(f"Downloading {gz_path}")
        url = "http://marcobaroni.org/PublicData/affix_complete_set.txt.gz"
        # An argument list (not a whitespace-split string) keeps paths with
        # spaces intact, and check=True surfaces download failures. wget -O
        # may leave an empty/partial file behind on failure; remove it so it
        # is not mistaken for a finished download on the next run.
        try:
            sp.run(["wget", "-O", str(gz_path), url], check=True)
        except sp.CalledProcessError:
            if gz_path.exists():
                gz_path.unlink()
            raise

    if not raw_path.exists():
        logger.info(f"Unzipping {raw_path}")
        with gzip.open(gz_path, 'rb') as fin, open(raw_path, 'wb') as fout:
            shutil.copyfileobj(fin, fout)

    if not queries_path.exists():
        logger.info(f"Making {queries_path}")
        with open(raw_path) as fin, open(queries_path, 'w') as fout:
            for line in islice(fin, 1, None):  # skip the title row
                # row fmt: affix stem stemPOS derived derivedPOS type ...
                # Unpacking six fields also rejects rows that are too short.
                _affix, _stem, _, derived, _, _type = line.split()[:6]
                print(derived, file=fout)
                if derived.lower() != derived:
                    print(derived.lower(), file=fout)

    return dotdict(
        dir_path=dir_path,
        gz_path=gz_path,
        raw_path=raw_path,
        queries_path=queries_path,
    )
コード例 #10
0
def make_as_dotdict(obj):
    """Recursively wrap a plain ``dict`` and its nested plain dicts in ``dotdict``.

    Only values whose type is exactly ``dict`` are converted; any other
    object (including dict subclasses and dicts inside lists) is returned
    as-is.
    """
    if type(obj) is not dict:
        return obj
    wrapped = dotdict(obj)
    for name in wrapped:
        child = wrapped[name]
        if type(child) is dict:
            wrapped[name] = make_as_dotdict(child)
    return wrapped
コード例 #11
0
ファイル: parameters.py プロジェクト: sidak/otfusion
def get_deprecated_params_mnist_act():
    """Return the deprecated MNIST parameter set (activation histograms on) as a dotdict."""
    parameters = dict(
        # training
        n_epochs=1,
        enable_dropout=False,
        batch_size_train=64,
        batch_size_test=1000,
        learning_rate=0.01,
        momentum=0.5,
        log_interval=100,
        to_download=True,  # set to True if MNIST/dataset hasn't been downloaded
        disable_bias=True,  # no bias at all in fc or conv layers
        dataset='mnist',
        # models
        num_models=2,
        model_name='simplenet',  # other values seen: 'net', 'mlpnet'
        num_hidden_nodes=400,
        num_hidden_nodes1=400,
        num_hidden_nodes2=200,
        num_hidden_nodes3=100,
        # runtime / fusion options
        gpu_id=5,
        skip_last_layer=False,
        reg=1e-2,
        debug=False,
        activation_histograms=True,
        act_num_samples=100,
        softmax_temperature=1,
    )
    return dotdict(parameters)
コード例 #12
0
ファイル: strategy.py プロジェクト: zihpzhong/mexbot
 def fetch_position_ws(self):
     """Return the current position from the websocket feed.

     Adds ``unrealisedPnlPcnt100`` (``unrealisedPnlPcnt`` multiplied by 100)
     and logs quantity, average cost and PnL.
     """
     position = dotdict(self.ws.position())
     position.unrealisedPnlPcnt100 = position.unrealisedPnlPcnt * 100
     template = "POSITION: qty {currentQty} cost {avgCostPrice} pnl {unrealisedPnl}({unrealisedPnlPcnt100:.2f}%) {realisedPnl}"
     self.logger.info(template.format(**position))
     return position
コード例 #13
0
def exp(ref_vec_name):
    """
    Train, run inference with and evaluate the "sasaki" model for ``ref_vec_name``.

    Output goes under ``results/ws/<ref_vec_name>_sasaki``. DEBUG logging is
    written to ``log.txt`` there, and ``result.txt`` there is passed to
    ``evaluate`` as ``eval_result_path``.

    :param ref_vec_name: name of the reference word vectors, resolved with
        ``prepare_target_vector_paths``.
    """
    result_path = Path("results") / "ws" / f"{ref_vec_name}_sasaki"
    ref_vec_path = prepare_target_vector_paths(ref_vec_name).w2v_emb_path
    codecs_path = prepare_codecs_path(ref_vec_path, result_path)

    # Let logging open and own the log file (a FileHandler, closed by
    # logging.shutdown) instead of holding a file object that is never
    # closed. basicConfig does nothing if the root logger already has
    # handlers.
    logging.basicConfig(level=logging.DEBUG,
                        filename=str(result_path / "log.txt"),
                        filemode="w")

    logger.info("Training...")
    model_info = train(
        ref_vec_path,
        result_path,
        codecs_path=codecs_path,
        H=40_000,
        F=500_000,
        epoch=300,
    )

    logger.info("Inferencing...")
    combined_query_path = prepare_ws_combined_query_path()
    result_emb_path = inference(model_info, combined_query_path)

    logger.info("Evaluating...")
    evaluate(
        dotdict(eval_result_path=result_path / "result.txt",
                pred_path=result_emb_path))
コード例 #14
0
def exp(ref_vec_name):
    """
    Train, run inference with and evaluate the multilingual "sasaki" model.

    Reference vectors are ``wiki2vec-<ref_vec_name>``. Output goes under
    ``results/ws_multi/<ref_vec_name>_sasaki``. DEBUG logging is written to
    ``log.txt`` there, and the model info is read back from the ``sep_kvq``
    sub-directory after training.

    :param ref_vec_name: language/vector name used to build the reference
        vector name and the combined word-similarity queries.
    """
    result_path = Path("results") / "ws_multi" / f"{ref_vec_name}_sasaki"
    ref_vec_path = prepare_target_vector_paths(f"wiki2vec-{ref_vec_name}").w2v_emb_path
    codecs_path = prepare_codecs_path(ref_vec_path, result_path)

    # Let logging open and own the log file (a FileHandler, closed by
    # logging.shutdown) instead of holding a file object that is never
    # closed. basicConfig does nothing if the root logger already has
    # handlers.
    logging.basicConfig(level=logging.DEBUG,
                        filename=str(result_path / "log.txt"),
                        filemode="w")

    logger.info("Training...")
    train(
        ref_vec_path,
        result_path,
        codecs_path=codecs_path,
        H=40_000,
        F=500_000,
        epoch=300,
    )

    model_info = get_info_from_result_path(result_path / "sep_kvq")

    logger.info("Inferencing...")
    combined_query_path = prepare_ws_combined_query_path(ref_vec_name)
    result_emb_path = inference(model_info, combined_query_path)

    logger.info("Evaluating...")
    evaluate(dotdict(
        model_type="sasaki",
        eval_result_path=result_path / "result.txt",
        pred_path=result_emb_path,
        target_vector_name=ref_vec_name,
        results_dir=result_path,
    ))
コード例 #15
0
ファイル: strategy.py プロジェクト: zihpzhong/mexbot
 def fetch_ticker_ws(self):
     """Return the websocket ticker with ``datetime`` taken from a cached trade.

     ``datetime`` is ``pd.to_datetime`` of the ``timestamp`` of the last item
     in ``self.ws.recent_trades()``; bid/ask/last are logged.
     """
     trades = self.ws.recent_trades()
     last_trade = trades[-1]
     ticker = dotdict(self.ws.get_ticker())
     ticker.datetime = pd.to_datetime(last_trade['timestamp'])
     template = "TICK: bid {bid} ask {ask} last {last}"
     self.logger.info(template.format(**ticker))
     return ticker
コード例 #16
0
def get_args_cloud_ds1():
    """Return the DeepSpeech model file locations as a dotdict.

    Keys and files (all under ``./deepspeech-models/``): model ->
    output_graph.pb, lm -> lm.binary, trie -> trie, alphabet -> alphabet.txt.
    """
    base = "./deepspeech-models/"
    files = (
        ("model", "output_graph.pb"),
        ("lm", "lm.binary"),
        ("trie", "trie"),
        ("alphabet", "alphabet.txt"),
    )
    return dotdict({key: base + filename for key, filename in files})
コード例 #17
0
ファイル: strategy_bitflyer.py プロジェクト: zihpzhong/mexbot
 def fetch_ticker(self, symbol=None, timeframe=None):
     """Fetch the ticker for ``symbol`` (default ``self.settings.symbol``).

     The ``datetime`` field is converted with ``pd.to_datetime`` and
     last/bid/ask are logged. ``timeframe`` is not used by this method; it is
     still accepted so existing callers that pass it keep working.
     """
     symbol = symbol or self.settings.symbol
     ticker = dotdict(self.exchange.fetch_ticker(symbol))
     ticker.datetime = pd.to_datetime(ticker.datetime)
     self.logger.info(
         "TICK: last {last} bid {bid} ask {ask}".format(**ticker))
     return ticker
コード例 #18
0
ファイル: strategy_bitflyer.py プロジェクト: zihpzhong/mexbot
 def fetch_position(self, symbol=None):
     """Fetch the current position (bitFlyer ``getpositions``).

     The endpoint may return several entries for one product. ``currentQty``
     is the net signed size over all of them (BUY positive, SELL negative),
     and ``unrealisedPnl`` is the sum of their ``pnl``. The remaining fields
     are those of the first entry. With no entries, a flat position
     (qty 0, pnl 0) is returned. The result is logged.
     """
     symbol = symbol or self.settings.symbol
     market = self.exchange.market(symbol)
     req = {'product_code': market['id']}
     res = self.exchange.privateGetGetpositions(req)
     if res:
         pos = dotdict(res[0])
         # Aggregate every entry; counting only res[0] misreports the
         # position whenever it is split across several entries.
         pos.currentQty = sum(
             p['size'] if p['side'] == 'BUY' else -p['size'] for p in res)
         pos.unrealisedPnl = sum(p['pnl'] for p in res)
     else:
         pos = dotdict()
         pos.currentQty = 0
         pos.unrealisedPnl = 0
     self.logger.info(
         "POSITION: qty {currentQty} pnl {unrealisedPnl}".format(**pos))
     return pos
コード例 #19
0
 def fetch_position(self, symbol=None):
     """Fetch the current position for ``symbol`` (default ``self.settings.symbol``).

     The first entry of ``privateGetPosition()`` whose ``symbol`` equals the
     market id is used, with its ``timestamp`` converted by
     ``pd.to_datetime``. When none matches, a flat position with quantities
     and PnL set to 0 is returned. ``unrealisedPnlPcnt100``
     (``unrealisedPnlPcnt`` * 100) is added and the position is logged.
     """
     symbol = symbol or self.settings.symbol
     res = self.exchange.privateGetPosition()
     # Resolve the market id once instead of once per returned row.
     market_id = self.exchange.market(symbol)['id']
     matches = [x for x in res if x['symbol'] == market_id]
     if matches:
         pos = dotdict(matches[0])
         pos.timestamp = pd.to_datetime(pos.timestamp)
     else:
         pos = dotdict()
         pos.currentQty = 0
         pos.avgCostPrice = 0
         pos.unrealisedPnl = 0
         pos.unrealisedPnlPcnt = 0
         pos.realisedPnl = 0
     pos.unrealisedPnlPcnt100 = pos.unrealisedPnlPcnt * 100
     self.logger.info("POSITION: qty {currentQty} cost {avgCostPrice} pnl {unrealisedPnl}({unrealisedPnlPcnt100:.2f}%) {realisedPnl}".format(**pos))
     return pos
コード例 #20
0
	def loss_batch(self, nb_neg_samples):
		"""Build one sample per ``(u, v)`` pair in ``self.all_relations``.

		Each sample is a dotdict with ``u_id``, ``v_id`` and ``neigh_u_ids``;
		the last is ``self.negative_samples(u_id, nb_neg_samples)``.
		"""
		return [
			dotdict({
				"u_id": u_id,
				"v_id": v_id,
				"neigh_u_ids": self.negative_samples(u_id, nb_neg_samples),
			})
			for u_id, v_id in self.all_relations
		]
コード例 #21
0
def get_player(time):
    """Build an MCTS-backed player for ``TaflGame(7, True)``.

    Returns a callable ``(board, turn_player) -> action index``. It returns
    the argmax of ``mcts.getActionProb`` at temp=0, forwarding ``time``.
    """
    game = TaflGame(7, True)
    checkpoint_dir = './tafl_model_1/'
    white_nnet, black_nnet = nn(game), nn(game)
    white_nnet.load_checkpoint(checkpoint_dir, 'white.pth.tar')
    # NOTE(review): the black network is loaded from 'white.pth.tar' too.
    # If a separate black checkpoint exists this looks like a copy-paste
    # slip; confirm before changing.
    black_nnet.load_checkpoint(checkpoint_dir, 'white.pth.tar')
    search_args = dotdict({'numMCTSSims': 10000, 'cpuct': 1.1})
    mcts = MCTS(game, white_nnet, black_nnet, search_args)

    def play(board, turn_player):
        probs = mcts.getActionProb(board, turn_player, temp=0, time=time)
        return np.argmax(probs)

    return play
コード例 #22
0
def get_best_action(board, player, use_mcts):
    """
    Choose an action for ``player`` with the module-level ``game`` and ``nnet``.

    The board is first converted with ``game.getCanonicalForm(board, player)``
    and both strategies evaluate that canonical board:

    * ``use_mcts`` true: 128 MCTS simulations (cpuct 1.0), then the argmax of
      ``getActionProb(..., temp=0)``;
    * otherwise: the argmax of the policy returned by ``nnet.predict``.

    :return: the chosen action index as a Python ``int``.
    """
    # Canonicalize once for both branches. MCTS already feeds canonical
    # boards to the network, so the policy network presumably expects them
    # too; the raw board would be scored from the wrong side for the
    # non-canonical player. TODO confirm against the nnet training code.
    canonical_board = game.getCanonicalForm(board, player)
    if use_mcts:
        args1 = utils.dotdict({'numMCTSSims': 128, 'cpuct': 1.0})
        mcts1 = MCTS(game, nnet, args1)
        best_action = int(np.argmax(mcts1.getActionProb(canonical_board, temp=0)))
    else:
        p, _v = nnet.predict(canonical_board)
        best_action = int(np.argmax(p))
    return best_action
コード例 #23
0
def train(args):
    """Build the subword vocabulary (and optionally subword probabilities), then train.

    Each subword CLI gets ``args`` with ``word_freq`` and ``output``
    overridden. Afterwards ``pbos_train.main(args)`` is run.
    """
    def with_overrides(word_freq, output):
        # ChainMap consults the overrides first and falls back to args.
        return dotdict(ChainMap(dict(word_freq=word_freq, output=output), args))

    subwords.build_subword_vocab_cli(
        with_overrides(args.subword_vocab_word_freq, args.subword_vocab))

    if args.subword_prob:
        subwords.build_subword_prob_cli(
            with_overrides(args.subword_prob_word_freq, args.subword_prob))

    pbos_train.main(args)
コード例 #24
0
def dict_from_group(group):
    """Recursively convert an h5py group into a ``utils.dotdict``.

    Sub-groups become nested dotdicts. Every other member is read in full
    (``member[...]``) and passed through ``read_clean``.
    """
    assert isinstance(group, h5py.Group)
    result = utils.dotdict()
    for name in group:
        member = group[name]
        if isinstance(member, h5py.Group):
            result[name] = dict_from_group(member)
        else:
            result[name] = read_clean(member[...])
    return result
コード例 #25
0
def run():
    """
    Entry point: parse the command line, merge it into the YAML config,
    prepare the output directory, seed the RNGs and call ``main(config)``.

    Command-line values override YAML keys of the same name. The config
    file is copied to ``<model_path>/sampling/<exp_name>/config.yaml``.
    """
    parser = ArgumentParser()
    parser.add_argument("-c",
                        "--config_path",
                        default='config/config.yaml',
                        help="The default config file.")
    # obligatory arguments
    parser.add_argument("--dataset_path",
                        help="Input data folder",
                        required=True)
    parser.add_argument("--dataset_cache",
                        help="Cache for input data folder",
                        required=True)
    parser.add_argument("-mq",
                        "--model_path",
                        type=str,
                        required=True,
                        help='Pretrained model path to local checkpoint')
    parser.add_argument("-e",
                        "--exp_name",
                        type=str,
                        default='qgen',
                        help='The name of experiment')
    args = parser.parse_args()

    # Read config from the yaml file. An empty file loads as None; fall back
    # to an empty mapping so the command-line values can still be merged.
    with open(args.config_path) as reader:
        config = dotdict(yaml.safe_load(reader) or {})

    # Command-line arguments take precedence over the YAML values.
    for k, v in vars(args).items():
        config[k] = v

    config.checkpoint = os.path.join(config.model_path, "sampling",
                                     config.exp_name)
    os.makedirs(config.checkpoint, exist_ok=True)
    copyfile(config.config_path, os.path.join(config.checkpoint,
                                              "config.yaml"))

    config.device = "cuda" if torch.cuda.is_available() else "cpu"
    # n_gpu is fixed at 1 regardless of torch.cuda.device_count().
    config.n_gpu = 1

    # logging is set to INFO
    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(config))
    logger.info("device: {}, n_gpu {}".format(config.device, config.n_gpu))

    random.seed(config.seed)
    # torch.manual_seed seeds the RNGs of all devices (CPU and every CUDA
    # device).
    torch.manual_seed(config.seed)
    main(config)
コード例 #26
0
def main():
    """
    Start the ``amr_nn_controller_service`` ROS node and spin.

    The controller model path comes from the required ``/nn_controller_path``
    parameter and is passed to ``NNController``.

    :raises KeyError: if ``/nn_controller_path`` is not set.
    """
    rospy.init_node('amr_nn_controller_service')
    # get_param without a default raises KeyError naming the missing
    # parameter; the controller cannot be built without it.
    model_path = rospy.get_param('/nn_controller_path')

    if rospy.has_param('/params'):
        # convert dict to work with dot notation.
        # NOTE(review): params is not used below (not passed to
        # NNController) - confirm whether it should be.
        params = dotdict(rospy.get_param('/params'))

    controller = NNController(model_path)
    rospy.spin()
コード例 #27
0
ファイル: strategy_bitflyer.py プロジェクト: zihpzhong/mexbot
    def order(self,
              myid,
              side,
              qty,
              limit=None,
              stop=None,
              trailing_offset=None,
              symbol=None):
        """Place (or replace) the order tracked under ``myid``.

        The requested ``qty`` is reduced so the resulting position does not
        exceed ``self.risk.max_position_size``. When the existing position is
        on the opposite side, its size is added to the limit so the position
        can be reversed. Nothing is sent if the capped ``qty`` is not
        positive.

        If ``self.orders`` already holds an order for ``myid``, that order is
        re-fetched. An open order whose price or amount differs from
        ``limit``/``qty`` is cancelled and re-created. An open order with
        matching price and amount is kept. A non-open order is replaced by a
        new one. The resulting order is stored in ``self.orders[myid]``.

        :param myid: caller-chosen key of the order in ``self.orders``
        :param side: 'buy' or 'sell'
        :param qty: requested order quantity
        :param limit: limit price, passed to ``create_order`` (None presumably
            means a market order - confirm in ``create_order``)
        :param stop: stop price, passed to ``create_order``
        :param trailing_offset: trailing offset, passed to ``create_order``
        :param symbol: market symbol; defaults to ``self.settings.symbol``
        """

        qty_total = qty
        qty_limit = self.risk.max_position_size

        # Holding a long position
        if self.position.currentQty > 0:
            # Adding to the long position
            if side == 'buy':
                # Add the current position size
                qty_total = qty_total + self.position.currentQty
            else:
                # Opposite-side order: raise the limit so the position can be reversed
                qty_limit = qty_limit + self.position.currentQty

        # Holding a short position
        if self.position.currentQty < 0:
            # Adding to the short position
            if side == 'sell':
                # Add the current position size
                qty_total = qty_total + -self.position.currentQty
            else:
                # Opposite-side order: raise the limit so the position can be reversed
                qty_limit = qty_limit + -self.position.currentQty

        # Cap the order quantity at the maximum position size
        if qty_total > qty_limit:
            qty = qty - (qty_total - qty_limit)

        if qty > 0:
            symbol = symbol or self.settings.symbol
            if myid in self.orders:
                order_id = self.orders[myid].id
                order = dotdict(self.exchange.fetch_order(order_id, symbol))
                if order.status == 'open':
                    # Only replace an open order when its price or amount changed;
                    # the cancel result is discarded in favour of the new order.
                    if ((order.price != limit) or (order.amount != qty)):
                        order = self.exchange.cancel_order(order_id, symbol)
                        order = self.create_order(side, qty, limit, stop,
                                                  trailing_offset, symbol)
                else:
                    # Previous order is no longer open: place a fresh one.
                    order = self.create_order(side, qty, limit, stop,
                                              trailing_offset, symbol)
            else:
                order = self.create_order(side, qty, limit, stop,
                                          trailing_offset, symbol)
            self.orders[myid] = order
コード例 #28
0
 def fetch_ticker(self, symbol=None, timeframe=None):
     """Build a ticker from the top of the order book and one recent trade.

     ``bid``/``ask`` are the first price levels of
     ``fetch_order_book(symbol, limit=1)``. ``last`` and ``datetime`` come
     from the first trade of
     ``fetch_trades(symbol, limit=1, params={'reverse': True})``, presumably
     the most recent one. ``timeframe`` is not used by this method; it is
     still accepted so existing callers that pass it keep working.
     """
     symbol = symbol or self.settings.symbol
     book = self.exchange.fetch_order_book(symbol, limit=1)
     trade = self.exchange.fetch_trades(symbol, limit=1, params={'reverse':True})
     ticker = dotdict()
     ticker.bid = book['bids'][0][0]
     ticker.ask = book['asks'][0][0]
     ticker.last = trade[0]['price']
     ticker.datetime = pd.to_datetime(trade[0]['datetime'])
     self.logger.info("TICK: bid {bid} ask {ask} last {last}".format(**ticker))
     return ticker
コード例 #29
0
ファイル: run_maml.py プロジェクト: iamsimha/pytorch-maml
def build_network(args):
    """Create the MAML ``CNNModel`` from parsed command-line arguments.

    Image size (28) and channel count (1) are fixed. The output size,
    learning rates, number of meta-test inner updates and hidden size come
    from ``args``.
    """
    from_args = {
        "dim_output": args.num_classes,
        "inner_update_lr": args.inner_update_lr,
        "meta_lr": args.meta_lr,
        "meta_test_num_inner_updates": args.meta_test_num_inner_updates,
        "dim_hidden": args.dim_hidden,
    }
    fixed = {"img_size": 28, "channels": 1}
    return CNNModel(dotdict({**from_args, **fixed}))
コード例 #30
0
def get_args_edge():
    """Return the fixed DBiRNN settings as a dotdict.

    The settings are 2 tanh layers of 256 hidden units, 39 features,
    29 classes, batch size 1 and 1 epoch, with save directory
    ``./models/04262030``.
    """
    return dotdict({
        'model': 'DBiRNN',
        'num_layer': 2,
        'activation': 'tanh',
        'batch_size': 1,
        'num_hidden': 256,
        'num_feature': 39,
        'num_class': 29,
        'num_epochs': 1,
        'savedir': './models/04262030',
    })
コード例 #31
0
from utils import dotdict
import os, yaml

# config.yaml lives next to this script, independent of the working directory.
target = os.path.dirname(os.path.realpath(__file__)) + "/config.yaml"
with open(target, 'r') as config_file:
    config_yaml = config_file.read()

# safe_load: yaml.load without an explicit Loader can construct arbitrary
# Python objects and raises TypeError on PyYAML >= 6 (Loader is required).
c = dotdict(yaml.safe_load(config_yaml))

with open('config.h', 'w') as o:

    o.write("""#ifndef CONFIG_H
#define CONFIG_H

//===================================
// CONFIGURATION CONSTANTS
//===================================

// GENERAL
""")

    for key, value in c.firmware.iteritems():
        o.write('#define ')
        o.write(key.upper() + " ")
        o.write(str(value) + '\n')

    o.write("\n// DEVICES\n")
    for key in c.devices:
        for key2, value in c.devices[key].iteritems():
            o.write('#define ')
            o.write(key.upper() + '_')