Example No. 1
def get_board_conf(chip_id):
    """Loads/generates a configuration in odb.boards_db
    The user then can edit the generated conf to suite the board and on the next run, the specified chached conf will be used for that board ID
    """
    if chip_id not in odb.boards_db:
        print(
            yellow(
                "Creating default board config for {}; consider updating the cached config"
                .format(chip_id)))

        print(green("What type of chip is {} ?".format(chip_id)))

        board_defaults = pick_from_list(odb.dft.boards_db.keys())
        odb.boards_db[chip_id] = EasyDict(odb.dft.boards_db[board_defaults])

        print(green("What type of board is {} ?".format(chip_id)))
        board_type = pick_from_list(odb.dft.boards_types[board_defaults])
        odb.boards_db[chip_id].BOARD = board_type

        if board_type == "d1_mini_lite":
            odb.boards_db[chip_id].UPLOAD_SPEED = 460800

        board_descript = prompt("Short description for this board: ")
        if board_descript != "":
            odb.boards_db[chip_id].DESCRIPTION = board_descript

    bcnf = EasyDict()
    transfer_missing_elements(bcnf, odb.boards_db[chip_id])
    transfer_missing_elements(bcnf, odb.dft.boards_db[bcnf.VARIANT])
    transfer_missing_elements(bcnf, odb.boards_db[chip_id])
    # bcnf.EXTRA_BUILD_FLAGS.update( odb.dft.EXTRA_BUILD_FLAGS)
    transfer_missing_elements(bcnf.EXTRA_BUILD_FLAGS, odb.EXTRA_BUILD_FLAGS)
    transfer_missing_elements(bcnf.BOARD_SPECIFIC_VARS,
                              odb.BOARD_SPECIFIC_VARS)
    return bcnf
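Every example on this page imports EasyDict from a local utils module. A minimal sketch of such an attribute-access dict follows; it is an illustration only, since the project's real helper evidently adds more (for instance the .has lookups used in later examples).

# Minimal attribute-access dict, similar in spirit to the EasyDict these
# examples import from utils (the real helper may differ).
class EasyDict(dict):
    """A dict whose items are also attributes: d.x is d['x']."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]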
Example No. 2
    def get_record(self, image_id):
        con = sqlite3.connect(self.dbname)
        cur = con.cursor()
        con.text_factory = str

        sql = '''
            select id, try_count, decieve_count,
            IFNULL(CAST(decieve_count as float) / try_count, 0) as decieve_rate
            from images
        '''
        cur.execute(sql)
        leader_board = [
            EasyDict(
                img_id=img_id, try_count=try_count,
                decieve_count=decieve_count, decieve_rate=decieve_rate
            )
            for img_id, try_count, decieve_count, decieve_rate in cur.fetchall()
        ]
        leader_board = {
            d.img_id: EasyDict(
                rank=i+1, try_count=d.try_count,
                decieve_count=d.decieve_count, decieve_rate=d.decieve_rate
            )
            for i, d in enumerate(sorted(leader_board, key=lambda x: -x.decieve_rate))
        }

        cur.close()
        con.close()

        return leader_board.get(image_id, None), len(leader_board)
Example No. 3
def save_board_parameters(bcnf):
    basic_params = [
        "FILESYSTEM", "SKETCH", "CHIP_ID", "PORT", "MAC", "CHIP_ID_SHORT",
        "WIFI_STATION_IP", "VALID_BOARD_ADDRESS", "HTTP_URI", "HTTP_PWD",
        "HTTP_USR", "HTTP_ADDR"
    ]

    changed = False

    # print(json.dumps(odb.EXTRA_BUILD_FLAGS, indent=2))

    transfer_missing_elements(odb.EXTRA_BUILD_FLAGS, odb.dft.EXTRA_BUILD_FLAGS)
    transfer_missing_elements(odb.BOARD_SPECIFIC_VARS,
                              odb.dft.BOARD_SPECIFIC_VARS)

    for bpar_ in basic_params:
        if bcnf.has[bpar_]:
            odb.boards_db[bcnf.CHIP_ID][bpar_] = bcnf[bpar_]
            changed = True

    if changed:
        towrite = EasyDict()
        towrite.boards_db = odb.boards_db
        towrite.EXTRA_BUILD_FLAGS = odb.EXTRA_BUILD_FLAGS
        towrite.BOARD_SPECIFIC_VARS = odb.BOARD_SPECIFIC_VARS
        write_to_db(odb.pth.boards_db, towrite)
Example No. 4
def run_espmake(str_args):
    """
    TODO move some constants to the states_db module
    """
    bcnf = EasyDict()
    bcnf.ESP_ROOT = pjoin(odb.pth.trd_deploy_libs, "esp8266")
    bcnf.DEMO = 1
    bcnf.action = str_args
    return call_makefile(bcnf)
Example No. 5
def compute_inst_x_var_by_metric_for_all_anchors(var_metric, model, args):
    result_dict = EasyDict({'var_metric': var_metric})
    result_dict.shown_dim = args.shown_dim
    x_var_ls = []
    for i in range(args.n_anchors):
        anchor_point = torch.randn(1, model.z_dim).to('cuda')
        x_var_tmp = compute_inst_x_var_by_metric_for_anchor(var_metric, i, anchor_point, model, args)
        # print('x_var_tmp.shape:', x_var_tmp.shape)
        x_var_ls.append(x_var_tmp)
        # print('x_var_tmp:', x_var_tmp[0, :10])
    result_dict.x_var = np.concatenate(x_var_ls, axis=0) # (n_anchors, n_samples_per_dim, ...)
    return result_dict
Example No. 6
    def __init__(self, env):
        super().__init__()
        self.save_hyperparameters()

        if env == {}:
            # because save_hyperparameters() had not been called
            from train import env

        if type(env) == dict:
            env = EasyDict(env)

        self.env = env

        assert env.base_model in MODELS

        if env.base_model == "vgg16":
            self.model = models.vgg16(pretrained=True)
            self.model = nn.Sequential(*list(self.model.children())[:-2])
            fc_in_features = 512

        if env.base_model.startswith("resnet"):
            self.model = getattr(models, env.base_model)(pretrained=True)
            fc_in_features = self.model.fc.in_features
            self.model = nn.Sequential(*list(self.model.children())[:-2])

        if env.base_model.startswith("efficientnet"):
            self._model = EfficientNet.from_pretrained(env.base_model,
                                                       include_top=False)
            fc_in_features = self._model._fc.in_features
            self.model = self._model.extract_features

        self.dropout = nn.Dropout(env.dropout_rate)
        self.fc = nn.Linear(fc_in_features, env.num_class)
        self.softmax = nn.Softmax(dim=1)
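Example No. 16 below instantiates this class as model.FineTuningModel(env). A hypothetical usage sketch (the class and module names are taken from that example; the env keys are the ones the constructor reads):

from model import FineTuningModel  # module layout assumed from Example No. 16

env = {
    "base_model": "resnet18",  # assumed to be listed in MODELS
    "dropout_rate": 0.2,
    "num_class": 3,
}
net = FineTuningModel(env)  # a plain dict is converted to EasyDict internally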
Example No. 7
def transfer_missing_elements(target_dict, source_dict, transfer_type=None):
    """Transferes missing elements from source to target recusevly
    """

    if transfer_type is None:
        transfer_type = source_dict.get("_transfer_type_", "recursive")

    for key_, val_ in source_dict.items():
        # print(key_,isinstance(val_, dict), val_)
        if isinstance(val_, dict):
            if key_ not in target_dict:
                target_dict[key_] = EasyDict()
            if transfer_type is None:
                transfer_type = val_.get("_transfer_type_", "recursive")
            # print("***********   ",transfer_type)

            if transfer_type == "recursive":
                transfer_missing_elements(target_dict[key_], val_,
                                          transfer_type)
            elif transfer_type == "update":
                target_dict[key_].update(val_)
            elif transfer_type == "overwrite":
                target_dict[key_] = copy.deepcopy(source_dict[key_])
                # target_dict[key_] = val_

        elif key_ not in target_dict:
            target_dict[key_] = copy.deepcopy(source_dict[key_])
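A minimal usage sketch of the merge above (values are hypothetical; EasyDict and transfer_missing_elements are assumed to be in scope). Keys already present in the target win, missing keys are copied in, and nested dicts are walked according to _transfer_type_ (see Example No. 26 for all three modes):

defaults = EasyDict({"UPLOAD_SPEED": 115200, "FLAGS": {"DEBUG": 0}})
board = EasyDict({"UPLOAD_SPEED": 460800})
transfer_missing_elements(board, defaults)
# board.UPLOAD_SPEED is still 460800 (already present);
# board.FLAGS.DEBUG was filled in from the defaults.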
Example No. 8
def compute_inst_x_var_by_metric_for_all_dims(var_metric, model, args):
    result_dict = EasyDict({'var_metric': var_metric})
    assert max(args.shown_dims) < model.z_dim
    assert min(args.shown_dims) >= 0
    result_dict.shown_dims = args.shown_dims
    x_var_ls = []
    for i in args.shown_dims:
        x_var_tmp = compute_inst_x_var_by_metric_for_dim(var_metric, i, model, args)
        # print('x_var_tmp.shape:', x_var_tmp.shape)
        x_var_ls.append(x_var_tmp)
        # print('x_var_tmp:', x_var_tmp[0, :10])
    result_dict.x_var = np.concatenate(x_var_ls, axis=0) # (shown_dims, n_samples_per_dim, ...)
    # print('x_var.shape:', result_dict.x_var.shape)
    assert result_dict.x_var.shape[0] == len(args.shown_dims)
    assert result_dict.x_var.shape[1] == args.n_samples_per_dim
    return result_dict
Example No. 9
    def create_leader_board(self, limit=10):
        con = sqlite3.connect(self.dbname)
        cur = con.cursor()
        con.text_factory = str

        sql = '''
            select id, try_count, decieve_count,
            IFNULL(CAST(decieve_count as float) / try_count, 0) as decieve_rate
            from images
        '''
        cur.execute(sql)

        # Fetch the database contents
        leader_board = [
            EasyDict(
                img_id=img_id, try_count=try_count,
                decieve_count=decieve_count, decieve_rate=decieve_rate
            )
            for img_id, try_count, decieve_count, decieve_rate in cur.fetchall()
        ]
        # Add rank information
        leader_board = [
            self.create_leader_board_row(rank=i+1, **d)
            for i, d in enumerate(sorted(leader_board, key=lambda x: -x.decieve_rate))
        ]
        # Limit the number of rows shown
        leader_board = leader_board[:limit]

        cur.close()
        con.close()

        return ''.join(leader_board)
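The queries in Example No. 2 and here read an images table with id, try_count and decieve_count columns; IFNULL(..., 0) guards decieve_rate against division by zero, which yields NULL in SQLite. A hypothetical schema reconstructed from just those columns (the real table may hold more; the misspelled decieve_count is the column name the queries actually expect):

import sqlite3

con = sqlite3.connect("images.db")  # the dbname is an assumption
con.execute("""
    CREATE TABLE IF NOT EXISTS images (
        id            TEXT PRIMARY KEY,
        try_count     INTEGER NOT NULL DEFAULT 0,
        decieve_count INTEGER NOT NULL DEFAULT 0
    )
""")
con.commit()
con.close()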
Example No. 10
def libs_iterator():
    for lib_ in odb.libs_manager.values():
        if not isinstance(lib_, dict):
            continue
        for grepo_ in lib_.get("git_repo", []):
            libd = EasyDict()
            populate_libd(libd, lib_, grepo_)
            yield libd
Example No. 11
def write_dashboard(get_path, parent_folder, isHome=False):
    dat = EasyDict()
    com = RecursiveFormatter(dat, BACKUP_INFO)

    raw_dbdi = json.loads(run("curl " + get_path))
    the_db = raw_dbdi["dashboard"]
    the_db.pop("version", None)
    the_db.pop("id", None)

    good_from = {}
    good_from["dashboard"] = the_db
    good_from["overwrite"] = True

    dat.grafana_type_export = "dashboard"
    dat.file_month = datetime.datetime.now().month
    dat.file_year = datetime.datetime.now().year
    dat.datasource_name = raw_dbdi["meta"]["slug"]

    # db_fn = (str(raw_dbdi["meta"]["slug"]) + "_dashboard.json")

    if isHome:
        # db_fn = "HOME-dashboard.json"
        dat.datasource_name = "_HOME_"
        good_from["meta"] = {"isHome": True}

    datasource_bk_path = com.raw_(
        pjoin(parent_folder, BACKUP_INFO.grafana_export_file))

    with open(datasource_bk_path, "w") as dbo:
        json.dump(good_from, dbo)
Example No. 12
def scan_for_boards():
    boards_list = EasyDict()

    if odb.arg.in_chip_id != "":
        valid_cids = [
            cids_ for cids_ in odb.boards_db.keys()
            if odb.arg.in_chip_id in cids_
        ]
        chip_id = pick_from_list(valid_cids)
        if odb.boards_db.has[chip_id] and chip_id_has_port(chip_id):
            bcnf = get_board_conf(chip_id)
            boards_list[bcnf.CHIP_ID] = bcnf
            return boards_list
        # else:
        #     print( magenta("No saved CHIP ID that containes {} : {}".format(odb.arg.in_chip_id, str(odb.boards_db.keys()))))
        #     print( yellow("Run 'aer node -P' to detect boards "))
        #     sys.exit(42)

    if odb.arg.boards_from_ports:
        for board_port in get_ports_list():
            bcnf = port_to_board(board_port)
            if boards_list.has[bcnf.CHIP_ID]:
                transfer_missing_elements(boards_list[bcnf.CHIP_ID], bcnf)
            else:
                boards_list[bcnf.CHIP_ID] = bcnf

    if odb.arg.boards_from_mdns:
        for bcnf_ in odb.boards_db.values():
            if bcnf_.has.CHIP_ID and odb.arg.in_chip_id in bcnf_.CHIP_ID:
                bcnf = get_board_conf(bcnf_.CHIP_ID)
                if mdns_exists(bcnf, 1):
                    if boards_list.has[bcnf.CHIP_ID]:
                        transfer_missing_elements(boards_list[bcnf.CHIP_ID],
                                                  bcnf)
                    else:
                        boards_list[bcnf.CHIP_ID] = bcnf

    # Iterate over a copy: popping entries from the dict while iterating it
    # directly would raise a RuntimeError.
    for key_, val_ in list(boards_list.items()):
        if val_.has.CHIP_ID and odb.arg.in_chip_id not in val_.CHIP_ID:
            boards_list.pop(key_, None)

    return boards_list
Example No. 13
def gen_conts_list():
    if "all_containers" not in odb.cache:
        odb.cache.all_containers = EasyDict()
    for croot_ in odb.pth.docker_containers_root:
        if fexists(croot_):
            for pth_ in os.listdir(croot_):
                cont_path = pjoin(croot_, pth_)
                if not os.path.isdir(cont_path):
                    continue
                if not os.path.isfile(pjoin(cont_path, "container_main.py")):
                    continue
                odb.cache.all_containers[pth_] = cont_path
Example No. 14
def grafana_retrive_datasources(parent_local_folder):
    grafana_url = "localhost:3000"
    dat = EasyDict()
    com = RecursiveFormatter(dat, BACKUP_INFO)
    dat.grafana_type_export = "datasource"

    dat.file_month = datetime.datetime.now().month
    dat.file_year = datetime.datetime.now().year
    with quiet():
        # TODO: if values are not found, prompt the user for them; maybe
        # generalize the procedure
        dash_json = run("curl http://{}:{}@{}/api/datasources/ ".format(
            odb.var.GRAFANA_ADMIN_NAME, odb.var.GRAFANA_ADMIN_PASS,
            grafana_url)).strip()
        print(dash_json)
        dasource = json.loads(dash_json)

    local("mkdir -p " + parent_local_folder)
    for datas_dict in dasource:
        dat.datasource_name = datas_dict["name"]

        datasource_bk_path = com.raw_(
            pjoin(parent_local_folder, BACKUP_INFO.grafana_export_file))
        datas_dict.pop("id", None)

        with open(datasource_bk_path, "w") as das:
            json.dump(datas_dict, das)
Example No. 15
def influx_section_database_time(database, strategy="full"):
    start_date = start_date_influxdb(database)
    if start_date is None:
        print(
            red("Did not find a valid starting date for database " + database))
        return []
    time_seq = []

    start_date = start_date.replace(day=1,
                                    hour=0,
                                    minute=0,
                                    second=0,
                                    microsecond=0)
    while start_date < datetime.datetime.now():
        end_date = (start_date + datetime.timedelta(35)).replace(day=1,
                                                                 hour=0,
                                                                 minute=0,
                                                                 second=0,
                                                                 microsecond=0)
        tslice = EasyDict()
        tslice.start = start_date
        tslice.end = end_date
        tslice.database = database
        tslice.status = "full"
        time_seq.append(tslice)
        start_date = end_date

    time_seq[0].status = "start"
    time_seq[-1].status = "part"
    if strategy == "last_two":
        return time_seq[-2:]
    return time_seq
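The slicing above relies on a small date trick: adding 35 days to the first of any month lands inside the following month (months are 28 to 31 days long), so replace(day=1, ...) snaps to that month's start. A standalone check:

import datetime

start = datetime.datetime(2020, 1, 1)
nxt = (start + datetime.timedelta(35)).replace(
    day=1, hour=0, minute=0, second=0, microsecond=0)
assert nxt == datetime.datetime(2020, 2, 1)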
Example No. 16
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything

import model
import data
from utils import EasyDict, plot_dataset

env = EasyDict(
    base_model="efficientnet-b0",
    dataset="custom",  # ["custom"]
    dataset_root="./dataset/",  # "custom" のデータセットの読み込み元
    dataset_download="./external_dataset/",  # image_net 等データセットのダウンロード先
    num_class=3,
    # 学習関連
    batch_size=32,
    learning_rate=1e-3,
    dropout_rate=0.2)
transform = transforms.Compose([
    transforms.Resize(400),
    transforms.RandomCrop(384),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

if __name__ == "__main__":
    seed_everything(42)

    model = model.FineTuningModel(env)

    train_dataset, val_dataset = data.get_dataset(env, transform)
Example No. 17
def parse_args(cmd_args):
    args = EasyDict()

    # args copied from cmd_args
    args.out_dir = cmd_args.out_dir
    args.height = cmd_args.height
    args.width = cmd_args.width
    args.n_triangle = cmd_args.n_triangle
    args.alpha_scale = cmd_args.alpha_scale
    args.coordinate_scale = cmd_args.coordinate_scale
    args.fps = cmd_args.fps
    args.n_population = cmd_args.n_population
    args.n_iterations = cmd_args.n_iterations
    args.report_interval = cmd_args.report_interval
    args.step_report_interval = cmd_args.step_report_interval
    args.save_as_gif_interval = cmd_args.save_as_gif_interval
    args.gpus = cmd_args.gpus
    args.thread_per_clip = cmd_args.thread_per_clip
    args.prompt = cmd_args.prompt

    return args
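When every attribute should be carried over unchanged, the field-by-field copy above collapses to one line. A sketch, not the original author's code; it is only equivalent when no filtering or renaming is wanted:

from utils import EasyDict  # as in the other examples

def parse_args(cmd_args):
    # vars() turns the argparse Namespace into a plain dict
    return EasyDict(vars(cmd_args))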
Example No. 18
def influxdb_export_points(database, backup_root, compress=True):
    dat = EasyDict()
    com = RecursiveFormatter(dat, BACKUP_INFO)
    dat.database = database
    dat.cont_name = "influxdb"
    dat.extension = "gz" if compress == True else "txt"
    dat.compress = "-compress" if compress == True else ""
    dat.root_data = "/data"
    with quiet():
        if run(com.raw_(
                "{dexec} {cont_name} test -d {root_data}/data")).failed:
            dat.root_data = "/var/lib/influxdb"
        if run(com.raw_(
                "{dexec} {cont_name} test -d {root_data}/data")).failed:
            print(
                red("Did not find root folder for influxdb ./data and ./wal"))
            return  # TODO Maybe something to improve
    dat.export_influx = "{dexec} {cont_name} influx_inspect export " + \
                        " -database {database} -datadir {root_data}/data -waldir {root_data}/wal {compress} " + \
                        " -start {start_date_iso} -end {end_date_iso} -out  {influx_bk_container}/{influx_export_file} "

    strategy = "last_two" if odb.arg.last_two else "full_range"
    sections_info = influx_section_database_time(database, strategy)

    # for seq_ in sections_info:
    #     print()
    #     print( seq_.start)
    #     print( seq_.end)

    # return

    run(com.raw_("mkdir -p {influx_bk_target}"))
    run(com.raw_("{dexec} {cont_name} mkdir -p {influx_bk_container}  "))

    # show("stdout")
    for seq_ in sections_info:
        dat.start_date_iso = seq_.start.isoformat() + "Z"
        dat.end_date_iso = seq_.end.isoformat() + "Z"
        dat.backup_status = seq_.status if "status" in seq_ else "full"
        dat.file_month = seq_.start.month
        dat.file_year = seq_.start.year

        influxdb_handle_exported(com, backup_root)
Example No. 19
def influxdb_restore_points(restore_root, compressed=True):
    dat = EasyDict()
    com = RecursiveFormatter(dat, BACKUP_INFO)

    dat.restore_root = restore_root
    dat.cont_name = "influxdb"
    dat.extension = "gz" if compressed else "txt"
    dat.compressed = "-compressed" if compressed else ""
    dat.import_influxdb = "{dexec} {cont_name} influx -import -pps {influx_PPS} " + \
        " -path={container_file_path} {compressed} "

    files_list = filter_backup_files(restore_root, print_info=True)

    if not confirm("Do you want to continue?"):
        return

    run(com.raw_("mkdir -p {influx_bk_target} "))
    run(com.raw_("{dexec} {cont_name}  mkdir -p {influx_bk_container} "))

    for file_ in files_list:
        local_file_path = pjoin(restore_root, file_)
        dat.file_name = file_
        dat.target_file_path = com.raw_("{influx_bk_target}/{file_name}")
        dat.container_file_path = "{influx_bk_container}/{file_name}"
        # if not texists(dat.target_file_path):
        l_to_t = put(local_file_path, dat.target_file_path)
        # if l_to_t.failed:
        #     continue

        t_to_c = run(
            com.raw_(
                "docker cp {target_file_path} {cont_name}:{influx_bk_container}/."
            ))
        # if t_to_c.failed:
        #     continue
        # TODO: if there are failed points, retry the procedure
        try_again = run(com.import_influxdb).failed
        if try_again:
            dat.influx_PPS = "3000"
            run(com.import_influxdb)
            dat.influx_PPS = "7000"

        print()
        print(com.raw_(dat.import_influxdb))
        print()
Example No. 20
def generate_user_db():
    user_db_path = odb.pth.user_db
    user_dict = EasyDict()
    # user_dict.pth = odb.pth.copy()
    user_dict.var = odb.var.copy()
    write_to_db(user_db_path, user_dict)
Example No. 21
"""Main entry point to train Pong agents."""

from training_loop import training_loop
from utils import EasyDict

from networks import *

if __name__ == "__main__":
    # Configure environment and agent.
    num_episodes = 50000000
    start_training_at_frame = 5000
    target_epsilon = 0.05
    reach_target_at_frame = 1e6
    update_target_freq = 10000
    save_every_n_ep = 1
    player_id = 1
    log_freq = 50
    agent_config = EasyDict(input_shape=(1, 84, 84),
                            network_fn=CNN,
                            num_actions=3,
                            stack_size=4,
                            replay_memory_size=int(1e6),
                            minibatch_size=128)

    # Call training loop.
    training_loop(num_episodes, target_epsilon, reach_target_at_frame,
                  player_id, start_training_at_frame, update_target_freq,
                  save_every_n_ep, log_freq, agent_config)
Example No. 22
def load_default_config():
    config = EasyDict()
    config.dataset_name = 'name'
    config.dataset_type = 'Raw_Image'
    config.save_datapath = None
    config.preload = './1.tfrecord'

    raw_image = EasyDict()
    raw_image.dataset_root_dir = '/Users/ecohnoch/Desktop/face_gan/StyleGAN-Tensorflow-master'
    raw_image.raw_data_format = ['.png', '.jpg']
    raw_image.ignore_dir_names = ['.DS_Store']
    raw_image.shape = 32  # None, 32, 64, 128, 224...
    raw_image.labeling_func = None
    raw_image.preprocessing_func = None

    raw_image.tfrecord = True
    raw_image.batch_size = 32
    raw_image.gpu_device = None
    config.Raw_Image = raw_image
    return config
Example No. 23
# pylint: disable=I0011,E1129,E0611
from __future__ import absolute_import, division, print_function

import os
from os.path import exists as fexists
from os.path import join as pjoin

from fabric.colors import blue, cyan, green, magenta, red, white, yellow

from utils import EasyDict, RecursiveFormatter, local_hostname

odb = EasyDict()

odb.boards_db = EasyDict()
odb.cache = EasyDict()
odb.run = EasyDict()

odb.env = EasyDict()
odb.env.skip_bad_hosts = True
odb.env.colorize_errors = True
odb.env.combine_stderr = True
odb.env.skip_unknown_tasks = True
odb.env.warn_only = True
odb.env.timeout = 5

odb.pth = EasyDict()
odb.pth.root = os.path.realpath(
    os.path.join(os.path.dirname(__file__), "../.."))
odb.pth.user_db = os.path.join(odb.pth.root, "user_conf_db.json")
odb.pth.boards_db = os.path.join(odb.pth.root, "dev_boards_db.json")
odb.pth.cache_db = os.path.join(odb.pth.root, "cache_gen_db.json")
Example No. 24
from fabric.colors import blue, cyan, green, magenta, red, white, yellow
from fabric.contrib import files
from fabric.contrib.console import confirm
from fabric.contrib.files import exists as texists

from aer import utils
# from aer.api import *
from states_db import odb
from utils import EasyDict, RecursiveFormatter

# pylint: disable=I0011,E1129

# reload(sys)
# sys.setdefaultencoding('utf8')

BACKUP_INFO = EasyDict()
BACKUP_INFO.influx_PPS = "7000"
BACKUP_INFO.influx_bk_container = "/backup/influxdb"
BACKUP_INFO.influx_bk_target = "~/backup/influxdb"
BACKUP_INFO.influx_export_file = "{database}_{host_id}_{file_year}-{file_month}_{backup_status}.{extension}"
BACKUP_INFO.grafana_export_file = "{grafana_type_export}_{datasource_name}_{file_year}-{file_month}.json"
BACKUP_INFO.dexec = "docker exec "


def entrypoint():
    # config.handle()
    BACKUP_INFO.host_id = run("hostname").strip()

    if odb.arg.enumerate_files:
        list_all_files()
Example No. 25
def parse_args(cmd_args):
    args = EasyDict()

    args.out_dir = cmd_args.out_dir
    args.height = cmd_args.height
    args.width = cmd_args.width
    args.target_fn = cmd_args.target_fn
    args.n_triangle = cmd_args.n_triangle
    args.loss_type = cmd_args.loss_type
    args.alpha_scale = cmd_args.alpha_scale
    args.coordinate_scale = cmd_args.coordinate_scale
    args.fps = cmd_args.fps
    args.n_population = cmd_args.n_population
    args.n_iterations = cmd_args.n_iterations
    args.mp_batch_size = cmd_args.mp_batch_size
    args.solver = cmd_args.solver
    args.report_interval = cmd_args.report_interval
    args.step_report_interval = cmd_args.step_report_interval
    args.save_as_gif_interval = cmd_args.save_as_gif_interval
    args.profile = cmd_args.profile

    return args
Example No. 26
def run_tests():
    print("TESTING ___________ ")

    print("TESTING transfer_missing_elements ")
    d1 = EasyDict()
    d1.a = 1
    d1.b = 1
    d1.c = 1
    d1.d = EasyDict()
    d1.d.a = 1
    d1.d.b = 1

    d1._transfer_type_ = "update"
    d2 = EasyDict()
    d2.a = 20
    d2.c = 20
    d2.d = EasyDict()
    d2.d.a = 20
    d2.d.x = 20
    transfer_missing_elements(d2, d1)
    expected = EasyDict({
        "a": 20,
        "d": {
            "a": 1,
            "x": 20,
            "b": 1
        },
        "c": 20,
        "_transfer_type_": "update",
        "b": 1
    })
    # print(json.dumps(d2,indent=2))
    # print(json.dumps(d2))
    if d2 != expected:
        print(expected, "!=", json.dumps(d2))
        raise Exception("transfer_missing_elements is different")

    d1._transfer_type_ = "recursive"
    d2 = EasyDict()
    d2.a = 20
    d2.c = 20
    d2.d = EasyDict()
    d2.d.a = 20
    d2.d.x = 20
    transfer_missing_elements(d2, d1)
    expected = EasyDict({
        "a": 20,
        "d": {
            "a": 20,
            "x": 20,
            "b": 1
        },
        "b": 1,
        "c": 20,
        "_transfer_type_": "recursive"
    })
    # print(json.dumps(d2,indent=2))
    # print(json.dumps(d2))
    if d2 != expected:
        print(expected, "!=", json.dumps(d2))
        raise Exception("transfer_missing_elements is different")

    d1._transfer_type_ = "overwrite"
    d2 = EasyDict()
    d2.a = 20
    d2.c = 20
    d2.d = EasyDict()
    d2.d.a = 20
    d2.d.x = 20
    transfer_missing_elements(d2, d1)
    expected = EasyDict({
        "a": 20,
        "d": {
            "a": 1,
            "b": 1
        },
        "c": 20,
        "b": 1,
        "_transfer_type_": "overwrite"
    })
    # print(json.dumps(d2,indent=2))
    # print(json.dumps(d2))
    if d2 != expected:
        print(expected, "!=", json.dumps(d2))
        raise Exception("transfer_missing_elements is different")