예제 #1
0
def prune_benchmark():
    """Benchmark iterative pruning of the Pacman DQN agent.

    A prunable ``DQNPacman`` (hyper-parameters taken from ``prune_config``)
    and a ``PacmanTargetNet`` (sizes taken from ``dense_config``) are both
    loaded from ``dense_config.ready_path``. The prunable model is switched to
    its benchmark loss and its global step is reset. ``iterative_pruning``
    then runs in benchmarking mode for ``prune_config.n_epoch`` epochs, and
    the returned sparsity-vs-accuracy data is plotted as figure 1.
    """
    # NOTE(review): the logger name says "pong" although this benchmarks the
    # Pacman agent; kept unchanged because log consumers may rely on it.
    logger = get_logger("prune_pong_agent_benchmark")
    pc = prune_config
    prune_model = DQNPacman(
        input_size=pc.input_size,
        output_size=pc.output_size,
        model_path=pc.model_path,
        scope=pc.scope,
        epsilon_stop=pc.final_epsilon,
        epsilon=pc.initial_epsilon,
        pruning_end=pc.pruning_end,
        pruning_freq=pc.pruning_freq,
        sparsity_end=pc.sparsity_end,
        target_sparsity=pc.target_sparsity,
        prune_till_death=True,
    )
    target_model = PacmanTargetNet(input_size=dense_config.input_size,
                                   output_size=dense_config.output_size)

    loading_msg = "loading models"
    logger.info(loading_msg)
    print(loading_msg)
    # Both networks start from the same trained ("ready") weights.
    for model in (target_model, prune_model):
        model.load_model(dense_config.ready_path)

    # Benchmark mode: swap in the benchmark loss and restart the step counter.
    prune_model.change_loss_to_benchmark_loss()
    prune_model.reset_global_step()

    logger.info("Commencing iterative pruning")
    results = iterative_pruning(logger, prune_model, target_model,
                                prune_config.n_epoch, benchmarking=True)
    print("benchmark finished")
    plot_graph(results, "sparsity_vs_accuracy_benchmark", figure_num=1)
예제 #2
0
 def __init__(self, data_type=NONE, value=None, name=None):
     """Create an attribute blueprint.

     :param data_type: data type of the attribute (``NONE`` by default)
     :param value: initial value of the attribute
     :param name: blueprint name, forwarded to ``Blueprint.__init__``
     """
     attribute_type = Blueprint.TYPES.get("ATTRIBUTE")
     Blueprint.__init__(self, type=attribute_type, name=name)
     self.__logger = logger_utils.get_logger(__name__)
     self.__data_type, self.__value = data_type, value
예제 #3
0
 def __init__(self):
     """Start the board: log, mark the app as STARTED, init pygame, set up."""
     logger_utils.get_logger(__name__).info("Board initialized")
     self.app_status = Status.STARTED
     pg.init()
     # The scene builder is attached later; start with none.
     self.__scene_builder = None
     self.__setup_board()
 def __init__(self, display, coords=None, size=None):
     """Initialise the form and its empty state.

     :param display: display surface forwarded to ``Form.__init__``
     :param coords: optional position forwarded to ``Form.__init__``
     :param size: optional size forwarded to ``Form.__init__``
     """
     Form.__init__(self, display, coords, size)
     self.__logger = logger_utils.get_logger(__name__)
     self.__bp = None
     self.boarder_rect = None
     self.ta_populated = False
     self.__tas = []
예제 #5
0
 def __init__(self,
              input_size,
              output_size,
              model_path: str,
              momentum=0.9,
              reg_str=0.0005,
              scope='ConvNet',
              pruning_start=int(10e4),
              pruning_end=int(10e5),
              pruning_freq=int(10),
              sparsity_start=0,
              sparsity_end=int(10e5),
              target_sparsity=0.0,
              dropout=0.5,
              initial_sparsity=0,
              wd=0.0):
     """Build the TensorFlow graph of a prunable conv net inside ``self.graph``.

     :param input_size: forwarded to the base-class constructor
     :param output_size: forwarded to the base-class constructor
     :param model_path: forwarded to the base-class constructor
     :param momentum: stored as ``self.momentum`` (consumed elsewhere, presumably by
         ``_optimizer`` — TODO confirm)
     :param reg_str: stored as ``self.reg_str``
     :param scope: logger name and value of the ``name`` pruning hyper-parameter
     :param pruning_start: ``begin_pruning_step`` (note ``int(10e4) == 100000``)
     :param pruning_end: ``end_pruning_step`` (note ``int(10e5) == 1000000``)
     :param pruning_freq: ``pruning_frequency``, in global steps
     :param sparsity_start: ``sparsity_function_begin_step`` — a step, not a sparsity
     :param sparsity_end: ``sparsity_function_end_step`` — a step, not a sparsity
     :param target_sparsity: ``target_sparsity`` pruning hyper-parameter
     :param dropout: stored as ``self.dropout``
     :param initial_sparsity: ``initial_sparsity`` pruning hyper-parameter
     :param wd: stored as ``self.wd``

     The sparsity-function exponent is fixed at 3. All global variables are
     initialised at the end of construction.
     """
     super(ConvNet, self).__init__(input_size=input_size,
                                   output_size=output_size,
                                   model_path=model_path)
     self.scope = scope
     self.momentum = momentum
     self.reg_str = reg_str
     self.dropout = dropout
     self.logger = get_logger(scope)
     self.wd = wd
     self.logger.info("creating graph...")
     with self.graph.as_default():
         # Statement order matters: each op below is built on the tensors
         # created by the previous ones.
         self.global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.logits = self._build_model()
         self.weights_matrices = pruning.get_masked_weights()
         self.sparsity = pruning.get_weight_sparsity()
         self.loss = self._loss()
         self.train_op = self._optimizer()
         self._create_metrics()
         self.saver = tf.train.Saver(var_list=tf.global_variables())
         # Pruning hyper-parameters are given as one comma-separated
         # "key=value" string to the hparams parser.
         self.hparams = pruning.get_pruning_hparams()\
             .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                    ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                    'pruning_frequency={},initial_sparsity={},'
                    ' sparsity_function_exponent={}'.format(scope,
                                                            pruning_start,
                                                            pruning_end,
                                                            target_sparsity,
                                                            sparsity_start,
                                                            sparsity_end,
                                                            pruning_freq,
                                                            initial_sparsity,
                                                            3))
         # The global step drives the pruning schedule: the higher the
         # global step, the closer the sparsity gets to its final value.
         self.pruning_obj = pruning.Pruning(self.hparams,
                                            global_step=self.global_step)
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
         # The pruning object defines the pruning mechanism; running
         # mask_update_op prunes the model. Pruning happens during training,
         # and its objective is to reach the configured final sparsity.
         self.init_variables(
             tf.global_variables())  # initialize variables in graph
예제 #6
0
 def __init__(self, panel):
     """Create an attribute blueprint widget on *panel*, sized and fonted.

     :param panel: parent panel forwarded to ``Blueprint.__init__``
     """
     Blueprint.__init__(self, panel, AB())
     self.__logger = logger_utils.get_logger(__name__)
     self.set_custom_size(AttributeBlueprint.SIZE)
     # Two slots: [is the data-type selector pressed?, its text box]
     self.data_type_pressed = [False, None]
     self.data_type_selection = []
     # Font height scales with the widget height (23 % of it).
     font_style = Themes.DEFAULT_THEME.get("text_font_style")
     font_px = int(self.get_rect().height * .23)
     self.change_font(pg.font.Font(font_style, font_px))
예제 #7
0
def main():
    """Build the LunarLander actor and critic networks and train them."""
    logger = get_logger('train_lunarlander')
    n_inputs = dense_config.input_size
    actor = ActorLunarlander(input_size=n_inputs,
                             output_size=dense_config.output_size,
                             model_path=FLAGS.actor_path)
    # The critic has its own output width (critic_output).
    critic = CriticLunarLander(input_size=n_inputs,
                               output_size=dense_config.critic_output,
                               model_path=FLAGS.critic_path)
    train(logger, actor, critic, epochs=FLAGS.n_epoch)
 def __init__(self, display, coords=None, size=None):
     """Initialise the language-selection form with the default language.

     :param display: display surface forwarded to ``Form.__init__``
     :param coords: optional position forwarded to ``Form.__init__``
     :param size: optional size forwarded to ``Form.__init__``
     """
     Form.__init__(self, display, coords, size)
     # Widgets are created later; start without them.
     self.lang_select = self.btn_drop_down = None
     self.__logger = logger_utils.get_logger(__name__)
     self.__lang = StringUtils.get_string(StringUtils.DEFAULT_LANGUAGE)
     self.__lang_content = []
     self.__lang_counter = 0
     self.__is_drop_down_pressed = self.__selected = False
예제 #9
0
def main():
    """Train the CartPole DQN against its target network.

    The agent's model path is ``dense_config.model_path_overtrained``.
    """
    in_size, out_size = dense_config.input_size, dense_config.output_size
    agent = CartPoleDQN(input_size=in_size,
                        output_size=out_size,
                        model_path=dense_config.model_path_overtrained)
    target_agent = CartPoleDQNTarget(input_size=in_size, output_size=out_size)
    for net in (agent, target_agent):
        net.print_num_of_params()
    fit(get_logger("train_Cartpole_agent"), agent, target_agent,
        dense_config.n_epoch)
예제 #10
0
 def __init__(self, text, pos):
     """Create a text button whose box is 10 % larger than the rendered text.

     :param text: button label
     :param pos: position passed to ``set_coordinates``
     """
     # TODO set button size according to the text object size
     font_style = Themes.DEFAULT_THEME.get("button_font_style")
     font_px = int(DisplaySettings.get_size_by_key()[1] * .045)
     font = pg.font.Font(font_style, font_px)
     self.logger = logger_utils.get_logger(__name__)
     self.__text_str = text
     self.__text = font.render(text, True, Themes.DEFAULT_THEME.get("font"))
     text_w, text_h = font.size(self.__text_str)
     self.__height = int(text_h * 1.1)
     self.__width = int(text_w * 1.1)
     # Default color; scene drawing code may override it.
     self.color = Themes.DEFAULT_THEME.get("button")
     self.__x, self.__y = self.set_coordinates(pos)
예제 #11
0
class TemplateManager(Manager):
    """Find and load the source templates used by the code generators.

    Each supported game API is mapped (``TEMPLATE_PATHS``) to a directory of
    ``*.py.temp`` files below ``ROOT_PATH``.
    """
    ROOT_PATH = logger_utils.ROOT_PATH + "GeneratorAPI\\templates\\"

    TEMPLATE_PATHS = {
        "Car Simulator": "{}{}\\".format(ROOT_PATH, "car_simulator")
    }

    TEMPLATE_EXTENSION = ".py.temp"

    __LOGGER = logger_utils.get_logger(__name__)

    @classmethod
    def get_templates(cls, api):
        """Return a ``{template name: template path}`` mapping for *api*.

        The template name is the file name up to its first ``.``
        (``board.py.temp`` -> ``board``).

        :param api: key of ``TEMPLATE_PATHS`` (e.g. ``"Car Simulator"``)
        :return: dict of template name to file path
        """
        path = TemplateManager.TEMPLATE_PATHS.get(api)
        r = dict()
        for root, dirs, files in os.walk(path):
            for file in files:
                if file.endswith(TemplateManager.TEMPLATE_EXTENSION):
                    fn = file.split(".")
                    # NOTE(review): os.walk also descends into sub-directories,
                    # but the path is built from ``path`` rather than ``root``,
                    # so a template in a sub-directory would get a wrong path.
                    r[fn[0]] = "{}{}".format(path, file)
                    TemplateManager.__LOGGER.debug(file)
        return r

    @classmethod
    def read_template(cls, path):
        """Read the template at *path*, converting space runs to tabs per line.

        :param path: template file path
        :return: the whole template text
        """
        with open(path, "r") as file:
            return "".join(TemplateManager.spaces_to_tabs(line)
                           for line in file)

    @classmethod
    def spaces_to_tabs(cls, line, width=4):
        """Return *line* with every run of *width* spaces replaced by a tab.

        Runs are matched left to right without overlap, so with the default
        width six spaces become a tab followed by two spaces. Every other
        character, including text after an interior run and the trailing
        newline, is preserved.

        :param line: text to convert
        :param width: number of spaces that make up one tab (default 4)
        :return: converted text
        :raises ValueError: if *width* is smaller than 1
        """
        if width < 1:
            raise ValueError("width must be at least 1, got {}".format(width))
        return line.replace(" " * width, "\t")
예제 #12
0
def train():
    """Train the Pong DQN student against its target network.

    The student is saved to ``FLAGS.model_path`` and trained for
    ``FLAGS.n_epoch`` epochs.
    """
    logger = get_logger("train_pong_student")
    cfg = dense_config
    agent = DQNPong(input_size=cfg.input_size,
                    output_size=cfg.output_size,
                    model_path=FLAGS.model_path,
                    scope=cfg.scope,
                    epsilon_stop=cfg.final_epsilon,
                    epsilon=cfg.initial_epsilon)
    target_agent = PongTargetNet(input_size=cfg.input_size,
                                 output_size=cfg.output_size)
    for net in (agent, target_agent):
        net.print_num_of_params()
    fit(logger, agent, target_agent, FLAGS.n_epoch)
 def __init__(self, display):
     """Build the settings scene: four themed buttons and three sub-forms.

     :param display: display surface forwarded to ``SceneBuilder.__init__``
     """
     SceneBuilder.__init__(self, display)
     self.__logger = logger_utils.get_logger(__name__)
     buttons = (("btn_theme", ThemeButton),
                ("btn_language", LanguageButton),
                ("btn_back", BackButton),
                ("btn_display", DisplayButton))
     for attr_name, button_cls in buttons:
         button = button_cls()
         button.color = Themes.DEFAULT_THEME.get("front_screen")
         setattr(self, attr_name, button)
     self.frm_theme = ThemeSelectionForm(self.display)
     self.frm_lang = LanguageSelectionForm(self.display)
     self.frm_display = DisplaySelectionForm(self.display)
 def __init__(self, display):
     """Build the new-project scene with create/cancel buttons.

     :param display: display surface forwarded to ``SceneBuilder.__init__``
     """
     SceneBuilder.__init__(self, display)
     self.__logger = logger_utils.get_logger(__name__)
     # Input widgets are created lazily.
     self.input = self.api_select = self.btn_drop_down = None
     self.__project_name = ""
     self.__api = GameApi.DEFAULT_API
     self.btn_create = CreateButton(1)
     self.btn_cancel = CancelButton(0)
     # Drop-down menu state.
     self.__is_drop_down_pressed = False
     self.__menu_content = []
     self.__menu_counter = 0
     self.__popup = None
예제 #15
0
 def __init__(self, control_panel, display, project_info, generated,
              coords=None, size=None):
     """Create the blueprint editing form.

     :param control_panel: companion control panel form
     :param display: display surface forwarded to ``Form.__init__``
     :param project_info: project descriptor kept for later use
     :param generated: stored as-is in ``self.generated``
     :param coords: optional position forwarded to ``Form.__init__``
     :param size: optional size forwarded to ``Form.__init__``
     """
     Form.__init__(self, display, coords, size)
     self.__cont_panel = control_panel
     self.__project_info = project_info
     self.generated = generated
     self.__logger = logger_utils.get_logger(__name__)
     self.__bps, self.__bps_connections = [], []
     self.popup = None
예제 #16
0
 def __init__(self, display):
     """Build the project-browser scene and load the existing projects.

     :param display: display surface forwarded to ``SceneBuilder.__init__``
     """
     SceneBuilder.__init__(self, display)
     self.__logger = logger_utils.get_logger(__name__)
     for attr_name, button_cls in (("btn_select", SelectButton),
                                   ("btn_delete", DeleteButton),
                                   ("btn_back", BackButton)):
         button = button_cls(0)
         button.color = Themes.DEFAULT_THEME.get("front_screen")
         setattr(self, attr_name, button)
     # File list area: 1 % / 22 % offset, 98 % x 66 % of the screen.
     screen = DisplaySettings.get_size_by_key()
     self.file_container = pg.Rect(
         (int(screen[0] * .01), int(screen[1] * 0.22)),
         (int(screen[0] * .98), int(screen[1] * .66)))
     self.files = ProjectManager.get_projects()
     self.__logger.debug(self.files)
예제 #17
0
    def __init__(self, display, project):
        """Build the project editor scene.

        Sets the window caption, creates the File/Run/Settings/Edit menu
        buttons and their menus, lays out a control panel on the left and a
        blueprint panel to its right, then loads saved blueprints and
        connections when both are present.

        :param display: display surface forwarded to ``SceneBuilder.__init__``
        :param project: mapping read via ``.get`` for the keys
            ``PROJECT_NAME``, ``PROJECT_API``, ``GENERATED``, ``CONNECTIONS``
            and ``BLUEPRINTS``
        """
        SceneBuilder.__init__(self, display)
        self.__project = (project.get("PROJECT_NAME"),
                          project.get("PROJECT_API"))
        pg.display.set_caption("{} - {}   {}".format(self.__project[0],
                                                     self.__project[1],
                                                     app_utils.CAPTION))

        self.__logger = logger_utils.get_logger(__name__)
        self.__logger.debug("{} --- {}".format(self.__project[0],
                                               self.__project[1]))

        # Menu-bar buttons share the background colour.
        self.btn_file = FileButton()
        self.btn_file.color = Themes.DEFAULT_THEME.get("background")
        self.btn_run = RunButton()
        self.btn_run.color = Themes.DEFAULT_THEME.get("background")
        self.btn_settings = SettingsButton()
        self.btn_settings.color = Themes.DEFAULT_THEME.get("background")
        self.btn_edit = EditButton()
        self.btn_edit.color = Themes.DEFAULT_THEME.get("background")
        # Must run before the panels below read btn_file's rect.
        self.__init_btn_size()
        self.__file_menu_content = self.__init_file_menu()
        self.__edit_menu_content = self.__init_edit_menu()
        self.__settings_menu_content = self.__init_settings_menu()
        self.__run_menu_content = self.__init_run_menu()
        # All menus start closed.
        self.__btn_file_pressed, self.__btn_edit_pressed, self.__btn_run_pressed, \
            self.__btn_settings_pressed = False, False, False, False
        self.__popup = None

        # Left control panel: 26.5 % of the screen width, 94.5 % of its height.
        # NOTE(review): this panel's top is btn_file bottom * 1.005 while the
        # blueprint panel below uses * 1.05 — confirm whether the mismatch is
        # intentional.
        self.__cont_panel = ControlPanelForm(
            self.display, (int(DisplaySettings.get_size_by_key()[0] * .005),
                           int(self.btn_file.get_rect().bottom * 1.005)),
            (int(DisplaySettings.get_size_by_key()[0] * .265),
             int(DisplaySettings.get_size_by_key()[1] * .945)))
        # Blueprint panel, placed just right of the control panel.
        self.__bp_panel = BlueprintControlForm(
            self.__cont_panel, self.display, self.__project,
            project.get("GENERATED"),
            (int(self.__cont_panel.get_rect().right +
                 DisplaySettings.get_size_by_key()[0] * .005),
             int(self.btn_file.get_rect().bottom * 1.05)),
            (int(DisplaySettings.get_size_by_key()[0] * .723),
             int(DisplaySettings.get_size_by_key()[1] * .945)))
        # Restore a saved project only when both parts are present.
        if project.get("CONNECTIONS") is not None and project.get(
                "BLUEPRINTS") is not None:
            self.__bp_panel.load_project(project.get("CONNECTIONS"),
                                         project.get("BLUEPRINTS"))
예제 #18
0
def get_all_i_need():
    """Parse the command line and set up the run.

    :return: ``(args, logger, d_class)``. ``args`` holds the parsed arguments,
        with ``args.cn`` always overwritten by ``d_class.clu_num``. ``logger``
        writes to ``./logging/<name>.txt``, where the name is built by
        ``entries2name`` from all argument values. ``d_class`` is the data
        class selected with ``-dn``.
    """
    cli_options = [
        ('-do', {'type': float, 'help': 'model: dropout'}),
        ('-ed', {'type': int, 'help': 'model: embed dim'}),
        ('-bs', {'type': int, 'help': 'iter: pos sample batch size'}),
        ('-ws', {'type': int, 'help': 'model: window size'}),
        ('-id', {'type': int, 'default': 0, 'help': 'identification'}),
        ('-cn', {'help': 'cluster num'}),
        ('-dn', {'help': 'data class name'}),
        ('-gp', {'type': float, 'help': 'gpu fraction'}),
    ]
    parser = ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    print('Using class', args.dn)
    d_class = name2object[args.dn]
    # The data class is the source of truth for the cluster count.
    args.cn = d_class.clu_num
    log_file = './logging/{}.txt'.format(entries2name(vars(args)))
    return args, lu.get_logger(log_file), d_class
예제 #19
0
class SecurityManager(Manager):
    """AES-128-CBC helpers to encrypt/decrypt application strings.

    NOTE(review): the key (``CIPHER``) and ``IV`` are hard-coded, and the same
    IV is used for every message. Identical plaintexts therefore give
    identical ciphertexts, and anyone with the source can decrypt. They are
    kept unchanged because changing them would make existing data unreadable.
    """

    CIPHER = b"CB*(GH&V09IKdsf4"
    IV = b"0000000000000000"
    __LOGGER = logger_utils.get_logger(__name__)

    # AES block size in bits, as expected by the PKCS7 padder.
    _BLOCK_BITS = 128

    @classmethod
    def _cipher(cls):
        """Build the AES-CBC cipher object from the class key and IV."""
        return Cipher(algorithms.AES(SecurityManager.CIPHER),
                      modes.CBC(SecurityManager.IV),
                      backend=default_backend())

    @classmethod
    def encode_data(cls, data):
        """Description: method encodes incoming string into AES-128 cipher

        The UTF-8 bytes are PKCS7-padded to the 16-byte AES block size, so
        every string, including ones ending in "0", round-trips exactly
        through ``decode_data``.

        :param data: Data string to save
        :return: Encrypted cypher (bytes)
        """
        from cryptography.hazmat.primitives import padding
        padder = padding.PKCS7(SecurityManager._BLOCK_BITS).padder()
        raw = padder.update(data.encode("utf-8")) + padder.finalize()
        SecurityManager.__LOGGER.debug("Size: {}".format(len(raw)))
        encryptor = SecurityManager._cipher().encryptor()
        return encryptor.update(raw) + encryptor.finalize()

    @classmethod
    def decode_data(cls, e_data):
        """Decrypt bytes produced by ``encode_data`` back into a string.

        Data written by the older format is still accepted. That format padded
        with ASCII "0" characters, so its padding is removed by stripping
        trailing "0"s. This is inherently ambiguous for plaintexts that
        themselves ended in "0".

        :param e_data: encrypted bytes
        :return: decrypted string
        :raises UnicodeDecodeError: if the plaintext is not valid UTF-8
        """
        from cryptography.hazmat.primitives import padding
        decryptor = SecurityManager._cipher().decryptor()
        raw = decryptor.update(e_data) + decryptor.finalize()
        unpadder = padding.PKCS7(SecurityManager._BLOCK_BITS).unpadder()
        try:
            raw = unpadder.update(raw) + unpadder.finalize()
        except ValueError:
            # Not PKCS7 padded: fall back to the legacy "0"-character padding.
            return raw.decode("utf-8").rstrip("0")
        return raw.decode("utf-8")
예제 #20
0
class ExecutionManager(Manager):
    """Launch generated projects in a separate interpreter process."""
    __LOGGER = logger_utils.get_logger(__name__)

    @classmethod
    def execute_program(cls, project, main_file):
        """Run ``<PATH><project>\\out\\<project>\\<main_file>.py`` with ``python``.

        :param project: project name
        :param main_file: entry module name, without ``.py``
        :return: the ``subprocess.Popen`` handle of the started program
        :raises FileNotFoundError: if the generated entry file does not exist
            (or if Popen cannot find the ``python`` executable)
        """
        main_path = "{}{}\\out\\{}\\{}.py".format(ProjectManager.PATH, project,
                                                  project, main_file)
        if not os.path.exists(main_path):
            ExecutionManager.__LOGGER.error(
                "Cannot find generated source code")
            raise FileNotFoundError("Cannot find generated source code")
        # Popen starts the child and returns immediately without waiting for
        # it, so no helper thread is needed. Launching here also lets launch
        # errors propagate to the caller.
        return ExecutionManager.call_subprocess(main_path, "python")

    @classmethod
    def call_subprocess(cls, path, language):
        """Start ``[language, path]`` as a child process and return its handle."""
        return subprocess.Popen([language, path])
예제 #21
0
class ConfigManager(Manager):
    """Read, create and persist the application's JSON configuration file.

    The file at ``CONFIG_PATH`` holds a ``"DEFAULT"`` section and, once the
    user has saved settings, a ``"CUSTOM"`` section. Both use the layout of
    ``DEFAULT_CONFIG``.
    """
    LOGGER = logger_utils.get_logger(__name__)
    DEFAULT_CONFIG = {
        "SIZE": {
            "WIDTH": 800,
            "HEIGHT": 600
        },
        "LANGUAGE": "ID_ENGLISH",
        "THEME": "ST_1"
    }

    CONFIG_FILE_NAME = "app.config"
    ROOT_PATH = logger_utils.ROOT_PATH + "BlueprintsApp\\"
    CONFIG_PATH = ROOT_PATH + CONFIG_FILE_NAME

    @classmethod
    def set_configurations(cls):
        """Apply the stored screen size, language and theme.

        A default configuration file is generated first if none exists. The
        ``CUSTOM`` section is preferred; if any of its keys is missing, the
        whole ``DEFAULT`` section is used instead.
        """
        if not os.path.exists(ConfigManager.CONFIG_PATH):
            ConfigManager.LOGGER.error("Configuration file not found...")
            ConfigManager.LOGGER.info(
                "Generating default configuration file...")
            ConfigManager.generate_default_configuration()
        with open(ConfigManager.CONFIG_PATH, 'r') as json_cfg:
            cfg = json.load(json_cfg)
        ConfigManager.LOGGER.info("Loading custom configuration settings...")
        try:
            settings = ConfigManager._read_section(cfg, "CUSTOM")
        except KeyError:
            ConfigManager.LOGGER.error("Custom configurations not found")
            ConfigManager.LOGGER.info("Loading default configurations...")
            settings = ConfigManager._read_section(cfg, "DEFAULT")
        DisplaySettings.set_size_by_key(
            DisplaySettings.get_size_name(settings["screen"]))
        StringUtils.set_language(settings["lang"])  # LANGUAGE
        Themes.set_theme(settings["theme"])  # THEME

    @classmethod
    def _read_section(cls, cfg, section):
        """Return screen size, language and theme from ``cfg[section]``.

        :raises KeyError: if the section or any required key is missing
        """
        block = cfg[section]
        size = block["SIZE"]
        return {
            "screen": [size["WIDTH"], size["HEIGHT"]],
            "lang": block["LANGUAGE"],
            "theme": block["THEME"],
        }

    @classmethod
    def _write_config(cls, cfg_dict):
        """Serialise *cfg_dict* as JSON to ``CONFIG_PATH``, replacing the file."""
        with open(ConfigManager.CONFIG_PATH, 'w+') as cfg_file:
            json.dump(cfg_dict, cfg_file)

    @classmethod
    def generate_default_configuration(cls):
        """Write a configuration file containing only the DEFAULT section."""
        ConfigManager._write_config({"DEFAULT": ConfigManager.DEFAULT_CONFIG})

    @classmethod
    def save_configurations(cls):
        """Persist the current settings as CUSTOM, alongside DEFAULT."""
        # TODO implement data crypto for security reasons
        screen = DisplaySettings.DEFAULT_SCREEN_SIZE
        custom = {
            "SIZE": {
                "WIDTH": screen[0],
                "HEIGHT": screen[1]
            },
            "LANGUAGE": StringUtils.DEFAULT_LANGUAGE,
            "THEME": Themes.get_value(Themes.DEFAULT_THEME)
        }
        ConfigManager._write_config({
            "CUSTOM": custom,
            "DEFAULT": ConfigManager.DEFAULT_CONFIG
        })
예제 #22
0
                     help='use assaf measure')
 FLAGS, unparsed = parser.parse_known_args()
 Net_OOP = TilesUnetMirrored()
 config.Augment = FLAGS.aug
 config.use_assaf = FLAGS.assaf
 print("Augment is : {}".format(int(config.Augment)))
 optimizer = tf.compat.v2.optimizers.Adam(beta_1=0.99)
 loader = Loader(batch_size=FLAGS.batch_size)
 Net_OOP.compile(optimizer=optimizer,
                 loss=loss_fn,
                 metrics=['acc', 'loss', 'val_acc', 'val_loss'])
 if not os.path.exists(FLAGS.log_dir):
     os.makedirs(FLAGS.log_dir)
 if not os.path.exists(FLAGS.model_path):
     os.makedirs(FLAGS.model_path)
 logger = get_logger(os.path.join(FLAGS.log_dir, "train_log"))
 Tensorcallback = callbacks.TensorBoard(FLAGS.log_dir,
                                        write_graph=False,
                                        write_images=False)
 Checkpoint = callbacks.ModelCheckpoint(filepath=FLAGS.model_path +
                                        "/checkpoint.hdf5",
                                        monitor='val_acc',
                                        mode='max',
                                        save_best_only=True)
 Checkpoint.set_model(Net_OOP)
 Tensorcallback.set_model(Net_OOP)
 callbacks = {'tensorboard': Tensorcallback, 'checkpoint': Checkpoint}
 Net_OOP.fit(logger=logger,
             callbacks=callbacks,
             epochs=FLAGS.epochs,
             steps_per_epoch=config.steps_per_epoch,

nmap -sP 192.168.0.*
"""

import re
from subprocess import Popen, PIPE
from enumeration import Commands
from utils.logger_utils import get_logger
from utils.conf_reader import get_config_option
from utils.custom_exceptions import IPNotFoundError, BroadcastFailureError, \
    PingFailureError, ArpFailureError



# NOTE(review): this binds the name ``logging`` to the logger returned by
# get_logger(), shadowing the standard-library ``logging`` module inside this
# module. Consider renaming it (e.g. ``logger``) if the stdlib module is needed.
logging = get_logger()


def get_ip_mac_address():
    """
    command - ifconfig
    parsing commands output to get valid ip and mac address from eth or wlan.

    :return: ip and mac address.

    """

    try:
        process = Popen(args=[Commands.ifconfig], stdout=PIPE, stderr=PIPE)
        std_out, std_err = process.communicate()
예제 #24
0
 def __init__(self, name=Status.NONE, b_type=Status.NONE, data_type=Status.NONE, value=Status.NONE):
     """Create a blueprint carrying a data type and a value.

     :param name: forwarded to ``Blueprint.__init__``
     :param b_type: blueprint type, forwarded to ``Blueprint.__init__``
     :param data_type: stored as ``self.data_type``
     :param value: stored as ``self.value``
     """
     Blueprint.__init__(self, name, b_type)
     self.__logger = logger_utils.get_logger(__name__)
     self.data_type, self.value = data_type, value
예제 #25
0
# For usages not covered by the Apache License, Version 2.0, please
# contact [email protected]
__author__ = 'jesus.movilla'

import sys
import argparse
import time
import requests
import datetime
import xmltodict

from sdcclient.client import SDCClient
from paasmanagerclient.client import PaaSManagerClient
from utils.logger_utils import get_logger

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Request header names and values.
X_AUTH_TOKEN = "X-Auth-Token"
TENANT_ID = "Tenant-Id"
ACCEPT = "Accept"
APPLICATION_JSON = "application/json"

# Protocol scheme.
HTTPS_PROTOCOL = "https"

# HTTP status codes.
HTTP_STATUSCODE_NO_CONTENT = 204
HTTP_STATUSCODE_OK = 200

# Glance service type and endpoint type.
GLANCE_SERVICE_TYPE = "glance"
GLANCE_ENDPOINT_TYPE = "publicURL"
예제 #26
0
class PythonGenerator(Generator):
    """Turn a project model into Python source files using text templates.

    Templates come from ``TemplateManager``. Generator tags (``Generator.*``)
    inside them mark where project-specific code is spliced in. Output is
    written to ``<TemplateManager.ROOT_PATH>out\\<project>\\src\\<project>\\``.
    """
    __LOGGER = logger_utils.get_logger(__name__)

    @classmethod
    def generate(cls, project):
        """Generate all source files for *project*.

        :param project: project model; its ``name``, ``api``, ``characters``
            and ``sprites`` are read
        :return: ``Status.SUCCESS``
        """
        s = Status.SUCCESS
        PythonGenerator.initialize_directory(project.name)
        temps = TemplateManager.get_templates(project.api)

        # CHARACTER DATA: one module per character, from the same template.
        file = "custom_character"
        content = TemplateManager.read_template(temps.get(file))
        PythonGenerator.generate_character(project, content)

        # SPRITE DATA
        # NOTE(review): nothing here writes sprite modules yet, although the
        # generated board imports "<name>_sprite" modules.

        # SYSTEM SPECIFIC DATA (BOARD)
        file = "board"
        content = TemplateManager.read_template(temps.get(file))
        content = PythonGenerator.generate_board(project, content)
        PythonGenerator.save_generated_content(project.name, file, content)

        # SAVE OTHER FILES THAT DO NOT HAVE GENERATOR TAGS (copied verbatim)
        for k, v in temps.items():
            content = TemplateManager.read_template(v)
            if Generator.FINDER not in content:
                PythonGenerator.save_generated_content(project.name, k,
                                                       content)

        return s

    @classmethod
    def initialize_directory(cls, project):
        """(Re)create an empty ``out\\<project>\\src\\<project>\\`` tree.

        Any previous output for *project* is removed first (best effort).

        :param project: project name
        """
        out_dir = "{}{}\\".format(TemplateManager.ROOT_PATH, "out")
        path = "{}{}".format(out_dir, project)
        if os.path.exists(path):
            shutil.rmtree(path=path, ignore_errors=True)
        # makedirs creates out\, the project dir and src\ as needed.
        # exist_ok tolerates leftovers that rmtree could not delete.
        os.makedirs("{}\\src\\{}\\".format(path, project), exist_ok=True)

    @classmethod
    def generate_board(cls, project, content):
        """Fill the board template's import and initialisation tags.

        :param project: project model with ``characters`` and ``sprites``
        :param content: board template text
        :return: the filled-in board source
        """
        # DEFINE IMPORTS
        data = "".join("from {}_character import {}\n".format(
            ch.name.lower(), str(ch)) for ch in project.characters)
        data += "".join("from {}_sprite import {}\n".format(
            sp.name.lower(), str(sp)) for sp in project.sprites)
        content = PythonGenerator.insert_data(content, data,
                                              Generator.SYSTEM_IMPORT)
        # DEFINE CHARACTER INITIALIZATION
        # "\n\t\t" presumably re-indents to the tag's depth in the template.
        data = "".join(
            "result.append({}(pos={}, size={}, alive={}))\n\t\t".format(
                str(ch), ch.pos, ch.size, ch.alive)
            for ch in project.characters)
        content = PythonGenerator.insert_data(content, data,
                                              Generator.SYSTEM_INIT_CHARACTER)
        # DEFINE SPRITE INITIALIZATION
        data = "".join("result.append({}())\n\t\t".format(str(sp))
                       for sp in project.sprites)
        content = PythonGenerator.insert_data(content, data,
                                              Generator.SYSTEM_INIT_SPRITE)
        return content

    @classmethod
    def generate_character(cls, project, content):
        """Write one ``<name>_character.py`` module per character.

        :param project: project model with ``name`` and ``characters``
        :param content: character template text
        """
        for ch in project.characters:
            file = "{}_character".format(ch.name.lower())
            # CHANGE CHARACTER CLASS
            f_content = PythonGenerator.insert_data(content, str(ch),
                                                    Generator.CHARACTER_CLASS)
            # ATTRIBUTE GENERATION: insert_data consumes its tag, so every
            # attribute line must be inserted in a single call.
            attributes = "".join("self.{}\n\t".format(str(att))
                                 for att in ch.attributes)
            f_content = PythonGenerator.insert_data(f_content, attributes,
                                                    Generator.CHARACTER_ATTR)
            PythonGenerator.save_generated_content(project.name, file,
                                                   f_content)

    @classmethod
    def remove_tag(cls, content, tag):
        """Remove the first occurrence of *tag*; log and no-op if absent."""
        try:
            i = content.index(tag)
        except ValueError:
            PythonGenerator.__LOGGER.debug(
                "No tag found to remove [{}]".format(tag))
            return content
        return content[:i] + content[i + len(tag):]

    @classmethod
    def insert_data(cls, content, data, tag):
        """Replace the first occurrence of *tag* in *content* with *data*.

        If the tag is missing, an error is logged and *content* is returned
        unchanged.
        """
        try:
            i = content.index(tag)
        except ValueError:
            PythonGenerator.__LOGGER.error(
                "Tag [{}] not found to insert data".format(tag))
            return content
        return content[:i] + data + content[i + len(tag):]

    @classmethod
    def save_generated_content(cls, project_name, file, content):
        """Write *content* to ``out\\<project>\\src\\<project>\\<file>.py``."""
        path = "{}{}\\{}\\{}\\{}\\".format(TemplateManager.ROOT_PATH, "out",
                                           project_name, "src", project_name)
        f_path = "{}{}.py".format(path, file)
        with open(f_path, "w+") as f:
            f.write(content)
예제 #27
0
def main():
    """Run the PoPS prune / policy-distil compression loop for CartPole.

    Loads and evaluates the teacher from ``FLAGS.teacher_path`` and copies its
    weights into "policy step 0". Each loop iteration then:
    (1) iteratively prunes the current student, keeping the best model at
        ``prune_step_<i>``;
    (2) measures the per-layer redundancy (parameters that became zero);
    (3) builds a compact student whose layer sizes come from that redundancy
        and trains it from the teacher with ``fit_supervised``;
    (4) records its size as a percentage of the initial size.
    The loop stops when ``check_convergence`` accepts the last two size
    percentages. Finally it plots score vs. number of non-zero parameters.
    """
    #   ----------------- Setting initial variables Section -----------------
    logger = get_logger(FLAGS.PoPS_dir + "/PoPS_ITERATIVE")
    logger.info(" ------------- START: -------------")
    logger.info("Setting initial data structures")
    # [0]: NNZ parameter counts, [1]: matching evaluation scores (for the final plot)
    accuracy_vs_size = [[], []]
    logger.info("Loading models")
    teacher = CartPoleDQNTarget(input_size=dense_config.input_size,
                                output_size=dense_config.output_size)
    teacher.load_model(path=FLAGS.teacher_path)  # load teacher
    logger.info("----- evaluating teacher -----")
    print("----- evaluating teacher -----")
    teacher_score = evaluate(agent=teacher, n_epoch=FLAGS.eval_epochs)
    logger.info("----- teacher evaluated with {} ------".format(teacher_score))
    print("----- teacher evaluated with {} -----".format(teacher_score))
    prune_step_path = FLAGS.PoPS_dir + "/prune_step_"
    policy_step_path = FLAGS.PoPS_dir + "/policy_step_"
    initial_path = policy_step_path + "0"
    logger.info(
        "creating policy step 0 model, which is identical in size to the original model"
    )
    copy_weights(
        output_path=initial_path,
        teacher_path=FLAGS.teacher_path)  # in order to create the initial model
    compressed_agent = StudentCartpole(
        input_size=student_config.input_size,
        output_size=student_config.output_size,
        model_path=initial_path,
        tau=student_config.tau,
        pruning_freq=student_config.pruning_freq,
        sparsity_end=student_config.sparsity_end,
        target_sparsity=student_config.target_sparsity)
    compressed_agent.load_model()
    initial_size = compressed_agent.get_number_of_nnz_params()
    accuracy_vs_size[0].append(initial_size)
    accuracy_vs_size[1].append(teacher_score)
    initial_number_of_params_at_each_layer = compressed_agent.get_number_of_nnz_params_per_layer(
    )
    initial_number_of_nnz = sum(initial_number_of_params_at_each_layer)
    converge = False
    iteration = 0
    # holds only the two most recent model sizes (in % of the initial size);
    # seeded with 100 for the full-size starting model
    convergence_information = deque(maxlen=2)
    convergence_information.append(100)
    precent = 100  # current model size as % of the initial size
    # learning-rate schedule selector for pruning / distillation (0, 1, 2, 3)
    arch_type = 0
    # NNZ count of the most recently built student; passed to StudentCartpole
    last_measure = initial_size
    while not converge:
        iteration += 1
        print("-----  Pruning Step {} -----".format(iteration))
        logger.info(" -----  Pruning Step {} -----".format(iteration))
        path_to_save_pruned_model = prune_step_path + str(iteration)
        #   ----------------- Pruning Section -----------------
        if arch_type == 2:
            arch_type = 3  # special arch_type for prune-oriented learning rate
        sparsity_vs_accuracy = iterative_pruning_policy_distilliation(
            logger=logger,
            agent=compressed_agent,
            target_agent=teacher,
            iterations=FLAGS.iterations,
            config=student_config,
            best_path=path_to_save_pruned_model,
            arch_type=arch_type,
            lower_bound=student_config.LOWER_BOUND,
            accumulate_experience_fn=accumulate_experience_cartpole,
            evaluate_fn=evaluate,
            objective_score=student_config.OBJECTIVE_SCORE)
        plot_graph(data=sparsity_vs_accuracy,
                   name=FLAGS.PoPS_dir +
                   "/initial size {}%,  Pruning_step number {}".format(
                       precent, iteration),
                   figure_num=iteration)

        # loading model which has reasonable score with the highest sparsity
        compressed_agent.load_model(path_to_save_pruned_model)
        #   ----------------- Measuring redundancy Section -----------------
        # the amount of parameters that are not zero at each layer
        nnz_params_at_each_layer = compressed_agent.get_number_of_nnz_params_per_layer(
        )
        # the amount of parameters that are not zero
        nnz_params = sum(nnz_params_at_each_layer)
        # redundancy is the share of parameters we don't need (in %);
        # nnz_params / initial is the share we do need
        redundancy = (1 - nnz_params / initial_number_of_nnz) * 100
        print(
            "-----  Pruning Step {} finished, got {}% redundancy in net params -----"
            .format(iteration, redundancy))
        logger.info(
            "-----  Pruning Step {} finished , got {}% redundancy in net params -----"
            .format(iteration, redundancy))
        logger.info(
            "-----  Pruning Step {} finished with {} NNZ params at each layer".
            format(iteration, nnz_params_at_each_layer))
        print(
            " -----  Evaluating redundancy at each layer Step {}-----".format(
                iteration))
        logger.info(
            " -----  Evaluating redundancy at each layer Step {} -----".format(
                iteration))
        redundancy_at_each_layer = calculate_redundancy(
            initial_nnz_params=initial_number_of_params_at_each_layer,
            next_nnz_params=nnz_params_at_each_layer)
        logger.info(
            "----- redundancy for each layer at step {} is {} -----".format(
                iteration, redundancy_at_each_layer))
        # NOTE(review): on the first step the measured per-layer redundancy is
        # discarded and replaced by hard-coded values for 4 layers. This looks
        # like an experiment-specific override; confirm it is intended.
        if iteration == 1:
            redundancy_at_each_layer = [
                0.83984375, 0.8346405029296875, 0.83795166015625, 0.83984375
            ]
        #   ----------------- Policy distillation Section -----------------
        print(
            " -----  Creating Model with size according to the redundancy at each layer ----- "
        )
        logger.info(
            "----- Creating Model with size according to the redundancy at each layer -----"
        )
        policy_distilled_path = policy_step_path + str(iteration)
        # creating the compact model where every layer size is determined by the redundancy measure
        compressed_agent = StudentCartpole(
            input_size=student_config.input_size,
            output_size=student_config.output_size,
            model_path=policy_distilled_path,
            tau=student_config.tau,
            redundancy=redundancy_at_each_layer,
            pruning_freq=student_config.pruning_freq,
            sparsity_end=student_config.sparsity_end,
            target_sparsity=student_config.target_sparsity,
            last_measure=last_measure)
        nnz_params_at_each_layer = compressed_agent.get_number_of_nnz_params_per_layer(
        )
        logger.info(
            "-----  Step {} ,Created Model with {} NNZ params at each layer".
            format(iteration, nnz_params_at_each_layer))
        iterative_size = compressed_agent.get_number_of_nnz_params()
        last_measure = iterative_size
        precent = (iterative_size / initial_size) * 100
        convergence_information.append(precent)
        print(
            " ----- Step {}, Created Model with size {} which is {}% from original size ----- "
            .format(iteration, iterative_size, precent))
        logger.info(
            "----- Created Model with size {} which is {}% from original size -----"
            .format(iterative_size, precent))
        # scheduling the right learning rate for the size of the model:
        # > 40% -> 0, 10..40% -> 1, < 10% -> 2 (turned into 3 for the next pruning step)
        if precent > 40:
            arch_type = 0
        elif 10 <= precent <= 40:
            arch_type = 1
        else:
            arch_type = 2
        print(" -----  policy distilling Step {} ----- ".format(iteration))
        logger.info("----- policy distilling Step {} -----".format(iteration))
        fit_supervised(logger=logger,
                       arch_type=arch_type,
                       student=compressed_agent,
                       teacher=teacher,
                       n_epochs=FLAGS.n_epoch,
                       evaluate_fn=evaluate,
                       accumulate_experience_fn=accumulate_experience_cartpole,
                       lower_score_bound=student_config.LOWER_BOUND,
                       objective_score=student_config.OBJECTIVE_SCORE)

        policy_distilled_score = evaluate(agent=compressed_agent,
                                          n_epoch=FLAGS.eval_epochs)
        # presumably restarts the step counter for the next pruning phase; confirm
        compressed_agent.reset_global_step()
        print(
            " -----  policy distilling Step {} finished  with score {} ----- ".
            format(iteration, policy_distilled_score))
        logger.info(
            "----- policy distilling Step {} finished with score {}  -----".
            format(iteration, policy_distilled_score))
        # checking convergence on the last two size percentages
        converge = check_convergence(convergence_information)
        # for debugging purposes
        accuracy_vs_size[0].append(iterative_size)
        accuracy_vs_size[1].append(policy_distilled_score)

    plot_graph(data=accuracy_vs_size,
               name=FLAGS.PoPS_dir + "/accuracy_vs_size",
               figure_num=iteration + 1,
               xaxis='NNZ params',
               yaxis='Accuracy')
예제 #28
0
파일: dsmm_trainer.py 프로젝트: yyht/simnet
def _ensure_dir(path):
    """Create directory ``path`` (including missing parents) if it is absent."""
    os.makedirs(path, exist_ok=True)


def _build_model(scope):
    """Return a fresh, not-yet-built model instance for ``scope``.

    :raises ValueError: for an unsupported scope (previously this fell through
        and crashed later with a ``NameError`` on ``model``).
    """
    # An if/return chain rather than a dict of classes, so only the selected
    # class name has to resolve.
    if scope == "BCNN":
        return BCNN()
    if scope == "ABCNN1":
        return ABCNN1()
    if scope == "ABCNN2":
        return ABCNN2()
    if scope == "ABCNN3":
        return ABCNN3()
    if scope == "MatchPyramid":
        return MatchPyramid()
    if scope == "GMatchPyramid":
        return GMatchPyramid()
    if scope == "DSSM":
        return DSSM()
    if scope == "CDSSM":
        return CDSSM()
    if scope == "RDSSM":
        return RDSSM()
    if scope == "DecAtt":
        return DecAtt()
    if scope == "DSMM_ESIM":
        return ESIM()
    if scope == "DSMM":
        return DSMM()
    if scope == "DIIN":
        return DIIN()
    raise ValueError("Unsupported model scope [{}]".format(scope))


def _train_one_epoch(model, batches):
    """Run one training pass over ``batches``.

    :return: (example-weighted mean loss, predicted labels, gold labels)
    :raises RuntimeError: if ``batches`` yields nothing
    """
    total_loss, n_examples = 0.0, 0
    predicted, gold = [], []
    for anchor, check, label in batches:
        feed = input_dict_formulation(anchor, check, label)
        loss, _, _, _, preds = model.step(feed, is_training=True,
                                          symmetric=False)
        predicted += np.argmax(preds, axis=-1).tolist()
        gold += label.tolist()
        total_loss += loss * anchor.shape[0]
        n_examples += anchor.shape[0]
    if n_examples == 0:
        raise RuntimeError("training data produced no batches")
    return total_loss / float(n_examples), predicted, gold


def _evaluate_dev(model, batches, logger):
    """Run inference over ``batches``, skipping batches whose inference fails.

    :return: (example-weighted mean loss, predicted labels, gold labels)
    :raises RuntimeError: if no dev batch could be evaluated
    """
    total_loss, n_examples = 0.0, 0
    predicted, gold = [], []
    for anchor, check, label in batches:
        feed = input_dict_formulation(anchor, check, label)
        try:
            loss, _, pred_probs, _ = model.infer(feed, mode="test",
                                                 is_training=False,
                                                 symmetric=False)
            batch_pred = np.argmax(pred_probs, axis=-1).tolist()
            batch_gold = label.tolist()
        except Exception as err:  # best-effort: keep evaluating other batches
            logger.warning("skipping dev batch that failed: {}".format(err))
            continue
        # Extend only after the whole batch succeeded, so predictions and
        # gold labels stay aligned.
        predicted += batch_pred
        gold += batch_gold
        total_loss += loss * anchor.shape[0]
        n_examples += anchor.shape[0]
    if n_examples == 0:
        raise RuntimeError("no dev batch could be evaluated")
    return total_loss / float(n_examples), predicted, gold


def train(config):
    """Train a text-matching model and checkpoint it whenever dev improves.

    ``config`` keys read: ``model_config_path``, ``train_path``, ``w2v_path``,
    ``vocab_path``, ``dev_path``, ``model_dir``, ``data_type`` and, optionally,
    ``gpu_id``, ``elmo_w2v_path`` and ``model`` (the folder-name fallback).
    Hyper-parameters come from the namespace loaded from
    ``model_config_path``. Logs, the dumped flags and checkpoints are written
    under ``<model_dir>/<model_name>/{logs,models}``.

    :raises ValueError: if ``FLAGS.scope`` is not a supported model
    :raises RuntimeError: if an epoch yields no usable train or dev batch
    """
    model_config_path = config["model_config_path"]
    FLAGS = namespace_utils.load_namespace(model_config_path)

    os.environ["CUDA_VISIBLE_DEVICES"] = config.get("gpu_id", "")
    train_path = config["train_path"]
    w2v_path = config["w2v_path"]
    vocab_path = config["vocab_path"]
    dev_path = config["dev_path"]
    elmo_w2v_path = config.get("elmo_w2v_path", None)

    model_dir = config["model_dir"]
    # Prefer the output folder name from the model config; fall back to
    # config["model"] when the namespace does not provide it.
    try:
        model_name = FLAGS["output_folder_name"]
    except Exception:
        model_name = config["model"]
    print(model_name, "====model name====")

    logs_dir = os.path.join(model_dir, model_name, "logs")
    models_dir = os.path.join(model_dir, model_name, "models")
    _ensure_dir(logs_dir)
    _ensure_dir(models_dir)

    logger = logger_utils.get_logger(os.path.join(logs_dir, "log.info"))
    FLAGS.vocab_path = vocab_path
    with open(os.path.join(logs_dir, model_name + ".json"), "w") as flags_file:
        json.dump(FLAGS, flags_file)

    # The training call builds the vocabulary (make_vocab=True); its
    # embedding info is superseded by the dev call below, as before.
    train_anchor, train_check, train_label, _, _, _ = prepare_data(
        train_path, w2v_path, vocab_path,
        make_vocab=True,
        elmo_w2v_path=elmo_w2v_path,
        elmo_pca=FLAGS.elmo_pca,
        data_type=config["data_type"])

    dev_anchor, dev_check, dev_label, _, _, embedding_info = prepare_data(
        dev_path, w2v_path, vocab_path,
        make_vocab=False,
        elmo_w2v_path=elmo_w2v_path,
        elmo_pca=FLAGS.elmo_pca,
        data_type=config["data_type"])

    token2id = embedding_info["token2id"]
    embedding_mat = embedding_info["embedding_matrix"]

    logger.info("==vocab size {}".format(len(token2id)))
    logger.info("vocab path {}".format(vocab_path))

    FLAGS.token_emb_mat = embedding_mat
    FLAGS.char_emb_mat = 0
    FLAGS.vocab_size = embedding_mat.shape[0]
    FLAGS.char_vocab_size = 0
    FLAGS.emb_size = embedding_mat.shape[1]
    FLAGS.extra_symbol = embedding_info["extra_symbol"]

    if FLAGS.apply_elmo:
        elmo_mat = embedding_info["elmo"]
        FLAGS.elmo_token_emb_mat = elmo_mat
        FLAGS.elmo_vocab_size = elmo_mat.shape[0]
        FLAGS.elmo_emb_size = elmo_mat.shape[1]

    model = _build_model(FLAGS.scope)
    # Every supported model uses the word-level maximum sequence length.
    total_max_len = FLAGS.max_seq_len_word

    model.build_placeholder(FLAGS)
    model.build_op()
    model.init_step()

    print("========begin to train=========")

    best_dev_loss, best_dev_f1 = 100, 0
    toleration_cnt = 0
    toleration = 10
    for epoch in range(FLAGS.max_epochs):
        train_data = get_batch_data.get_batches(
            train_anchor, train_check, train_label, FLAGS.batch_size,
            token2id, is_training=True, total_max_len=total_max_len)
        train_loss, train_pred, train_gold = _train_one_epoch(model,
                                                              train_data)
        train_accuracy = accuracy_score(train_gold, train_pred)
        train_f1 = f1_score(train_gold, train_pred, average='macro')

        info = OrderedDict()
        info["epoch"] = str(epoch)
        info["train_loss"] = str(train_loss)
        info["train_accuracy"] = str(train_accuracy)
        info["train_f1"] = str(train_f1)

        logger.info("epoch\t{}\ttrain\tloss\t{}\taccuracy\t{}\tf1\t{}".format(
            epoch, train_loss, train_accuracy, train_f1))

        dev_data = get_batch_data.get_batches(
            dev_anchor, dev_check, dev_label, FLAGS.batch_size,
            token2id, is_training=False, total_max_len=total_max_len)
        dev_loss, dev_pred, dev_gold = _evaluate_dev(model, dev_data, logger)
        dev_accuracy = accuracy_score(dev_gold, dev_pred)
        dev_f1 = f1_score(dev_gold, dev_pred, average='macro')

        info["dev_loss"] = str(dev_loss)
        info["dev_accuracy"] = str(dev_accuracy)
        info["dev_f1"] = str(dev_f1)

        logger.info("epoch\t{}\tdev\tloss\t{}\taccuracy\t{}\tf1\t{}".format(
            epoch, dev_loss, dev_accuracy, dev_f1))

        if dev_f1 > best_dev_f1 or dev_loss < best_dev_loss:
            timestamp = str(int(time.time()))
            model.save_model(
                models_dir,
                model_name + "_{}_{}_{}".format(timestamp, dev_loss, dev_f1))
            # NOTE: both bests are overwritten even if only one improved
            # (unchanged behaviour).
            best_dev_f1 = dev_f1
            best_dev_loss = dev_loss

            toleration_cnt = 0

            info["best_dev_loss"] = str(dev_loss)
            info["best_dev_f1"] = str(best_dev_f1)

            logger_utils.json_info(os.path.join(logs_dir, "info.json"), info)
            logger.info(
                "epoch\t{}\tbest_dev\tloss\t{}\tbest_accuracy\t{}\tbest_f1\t{}"
                .format(epoch, dev_loss, dev_accuracy, best_dev_f1))
        else:
            toleration_cnt += 1
            if toleration_cnt == toleration:
                # NOTE(review): the counter is reset but training continues;
                # add `break` here if real early stopping is intended.
                toleration_cnt = 0
예제 #29
0
class CommsUtils(Utils):
    """HTTP helpers for talking to the project-generation API.

    Requests go to ``ROOT_PATH`` (``http://127.0.0.1:8001/api`` by default).
    ``get``/``post``/``put`` return the decoded JSON body on HTTP 200 and
    ``None`` (after logging the status code) otherwise.
    """
    HOST = "http://127.0.0.1"
    PORT = 8001
    ROOT = '/api'
    ROOT_PATH = "{}:{}{}".format(HOST, PORT, ROOT)
    __LOGGER = logger_utils.get_logger(__name__)

    @classmethod
    def _parse_response(cls, r):
        """Return ``r.json()`` for HTTP 200; otherwise log the status and return None."""
        if r.status_code == HTTPStatus.OK:
            return r.json()
        cls.__LOGGER.error(
            "Something went wrong while making request to the API: {}".
            format(r.status_code))
        return None

    @staticmethod
    def _to_dict_list(items):
        """Return ``[item.to_dict() ...]`` for ``items``, or ``[]`` when it is None."""
        if items is None:
            return list()
        return [item.to_dict() for item in items]

    @classmethod
    def get(cls, path):
        """Description: GET method request data from API and returns JSON response

        :param path: API endpoint
        :return: None - if error occurred OR dictionary of the parsed JSON response
        """
        return cls._parse_response(
            requests.get("{}{}".format(cls.ROOT_PATH, path)))

    @classmethod
    def post(cls, path, data_send):
        """Descriptions: POST method sends data to the API and receives action response back

        :param path: API endpoint
        :param data_send: Dictionary of data to send
        :return: None - if error occurred OR dictionary of the parsed JSON response
        """
        return cls._parse_response(
            requests.post("{}{}".format(cls.ROOT_PATH, path), json=data_send))

    @classmethod
    def put(cls, path, data_send):
        """Description: PUT method updates already existing data in the API and receives
        request action response

        :param path: API endpoint
        :param data_send: Dictionary of data to send
        :return: None - if error occurred OR dictionary of the parsed JSON response
        """
        return cls._parse_response(
            requests.put("{}{}".format(cls.ROOT_PATH, path), json=data_send))

    @classmethod
    def build_project_model(cls,
                            name,
                            api,
                            characters=None,
                            attributes=None,
                            functions=None,
                            sprites=None):
        """Description: Method build JSON object that is understandable on the API endpoint

        Each list argument holds objects exposing ``to_dict()``; ``None`` is
        serialised as an empty list.

        :param name: Project name
        :param api: Project API used
        :param characters: List of characters
        :param attributes: List of attributes
        :param functions: List of functions
        :param sprites: List of sprites
        :return: JSON-encoded string of the project model
        """
        model = {
            "NAME": name,
            "API": api,
            "CHARACTERS": cls._to_dict_list(characters),
            "ATTRIBUTES": cls._to_dict_list(attributes),
            "FUNCTIONS": cls._to_dict_list(functions),
            "SPRITES": cls._to_dict_list(sprites),
        }
        return json.dumps(model)

    @classmethod
    def download_project(cls, name):
        """Download the generated sources of project ``name`` and unpack them.

        The archive type depends on the platform (zip on Windows, tar on
        Linux, no extension on macOS). It is written to
        ``<ProjectManager.PATH>/<name>/out``, extracted there and then
        deleted.

        :param name: project name
        :return: ``Status.SUCCESS``
        :raises DownloadError: if the API does not answer with HTTP 200
        :raises FileNotFoundError: if the project directory does not exist
        """
        import tarfile  # only needed for the Linux (tar) download

        compressions = {"Darwin": "", "Windows": "zip", "Linux": "tar"}
        ext = compressions.get(platform.system())
        path = "{}/python/download/{}/{}".format(cls.ROOT_PATH, name, ext)
        resp = requests.get(path, allow_redirects=True)
        if resp.status_code != HTTPStatus.OK:
            cls.__LOGGER.error(
                "Failed to download project [{}] from [{}]".format(name, path))
            raise DownloadError("Failed to download project [{}]".format(name))

        # Check the project directory first: previously os.mkdir(out) raised
        # before this check could run when "out" did not exist yet.
        project_dir = os.path.join(ProjectManager.PATH, name)
        if not os.path.exists(project_dir):
            cls.__LOGGER.error(
                "Project [{}] directory does not exists".format(name))
            raise FileNotFoundError("Failed to save generated source code")

        # os.path.join instead of hard-coded backslashes so the layout is
        # also correct outside Windows.
        out_path = os.path.join(project_dir, "out")
        archive_path = os.path.join(out_path, "{}.{}".format(name, ext))
        if os.path.exists(archive_path):
            os.remove(archive_path)
        os.makedirs(out_path, exist_ok=True)

        with open(archive_path, "wb") as f:
            f.write(resp.content)

        if ext == "tar":
            # A tar archive cannot be opened with zipfile (the old code
            # always failed on Linux).
            with tarfile.open(archive_path) as archive:
                if hasattr(tarfile, "data_filter"):
                    archive.extractall(path=out_path, filter="data")
                else:
                    archive.extractall(path=out_path)
        else:
            with zipfile.ZipFile(archive_path, "r") as archive:
                archive.extractall(path=out_path)

        os.remove(archive_path)
        return Status.SUCCESS
예제 #30
0
import os
import sys
import subprocess

from utils import logger_utils
logger = logger_utils.get_logger(__name__)

# The orchestrator to run is named (without ".py") by the first command-line
# argument; exit with a usage message instead of an IndexError when missing.
if len(sys.argv) < 2:
    logger.error("No orchestrator name given on the command line")
    sys.exit("usage: {} <orchestrator_name>".format(sys.argv[0]))
orchestrator = sys.argv[1]

logger.info(f"Started Executing the orchestrator {orchestrator}.py")
print(f"Started Executing the orchestrator {orchestrator}.py")
# Run the orchestrator with the same interpreter as this launcher (a bare
# "python" may resolve to another installation or be missing). stdin is
# closed, as the previous PIPE + communicate() did, so the child cannot block
# waiting for input. subprocess.run already waits for the process to finish.
orchestrator_execution = subprocess.run(
    [sys.executable, f"./orchestrators/{orchestrator}.py"],
    stdin=subprocess.DEVNULL,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE)
# Decode so the captured output prints as text rather than a bytes repr.
print(orchestrator_execution.stdout.decode(errors="replace"))
print(orchestrator_execution.stderr.decode(errors="replace"))
logger.info(f"Ended Executing the orchestrator {orchestrator}.py")
print(f"Ended Executing the orchestrator {orchestrator}.py")

# Propagate a failing orchestrator to this launcher's exit status.
if orchestrator_execution.returncode != 0:
    logger.error(f"Orchestrator {orchestrator}.py exited with code "
                 f"{orchestrator_execution.returncode}")
    sys.exit(orchestrator_execution.returncode)