Ejemplo n.º 1
0
    def test_configs_load(self):
        '''Make sure every YAML config under the config root is loadable.

        Files are visited in sorted order for reproducible output, and each
        one is loaded inside its own ``subTest`` so a single broken config is
        reported by name without masking failures in the remaining files.
        '''
        cfg_root_path = utils.get_config_root_path()
        files = sorted(glob.glob(os.path.join(cfg_root_path, "./**/*.yaml"),
                                 recursive=True))
        self.assertGreater(len(files), 0)

        for fn in files:
            with self.subTest(cfg_file=fn):
                print('Loading {}...'.format(fn))
                utils.load_config_from_file(fn)
Ejemplo n.º 2
0
def _test_build_detectors(self, device):
    '''Build a model on *device* for every config not in EXCLUDED_FOLDERS.

    Each config is exercised in its own ``subTest`` keyed by file name.
    '''
    candidates = get_config_files(None, EXCLUDED_FOLDERS)
    self.assertGreater(len(candidates), 0)

    for path in candidates:
        with self.subTest(cfg_file=path):
            print('Testing {}...'.format(path))
            create_model(utils.load_config_from_file(path), device)
Ejemplo n.º 3
0
def _test_run_selected_detectors(self, cfg_files, device):
    '''Build each config's model on *device* and run one eval forward pass.

    For every file in *cfg_files* the RPN post-NMS proposal limits are
    lowered to 10, a model and a random input batch are created, and the
    number of outputs must equal the number of input images.
    '''
    self.assertGreater(len(cfg_files), 0)

    for path in cfg_files:
        with self.subTest(cfg_file=path):
            print('Testing {}...'.format(path))
            cfg = utils.load_config_from_file(path)
            # Cap proposals (presumably to keep the smoke test fast).
            cfg.MODEL.RPN.POST_NMS_TOP_N_TEST = 10
            cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 10
            net = create_model(cfg, device)
            batch = create_random_input(cfg, device)
            net.eval()
            predictions = net(batch)
            expected_count = len(batch.image_sizes)
            self.assertEqual(len(predictions), expected_count)
Ejemplo n.º 4
0
def create(name):
    '''Parse the switch configuration *name* into a ``model.DEVICE``.

    The text returned by ``util.load_config_from_file`` is split into lines.
    Each line is parsed with ``node.parsing`` and the results are grouped
    into a tree by ``node.gen_graph``. The top-level nodes are then copied
    onto a fresh device: hostname, default gateway, management SVI,
    per-port description/VLAN/shutdown settings, VLANs, and IGMP/DHCP
    snooping VLAN lists.

    :param name: value passed to ``util.load_config_from_file``
    :return: the populated ``model.DEVICE`` (previously the device was built
        but never returned)
    '''
    cfg = util.load_config_from_file(name)
    parsed_lines = [obj.Struct(**node.parsing(line.strip()))
                    for line in cfg.split('\n')]
    cfg_new = node.gen_graph(parsed_lines)
    node.show_node(cfg_new, debug=True)
    device = model.DEVICE()

    for i in cfg_new:
        if i.TYPE == ast.AST.hostname:
            device.hostname = i.RGX[1]
        elif i.TYPE == ast.AST.ip_default_gateway:
            device.default_gateway = i.RGX[1]
        elif i.TYPE == ast.AST.svi:
            # Management address/mask come from the SVI's first child
            # (assumed to be its "ip address" line -- TODO confirm).
            mng = i.Node[0]
            device.mng_ip, device.mng_mask = mng.RGX[2], mng.RGX[3]
            device.mng_int_vlan = i.RGX[1]
        elif i.TYPE == ast.AST.interface:
            port = device.getport(int(i.RGX[1]))
            for j in i.Node:
                if j.TYPE == ast.AST.description:
                    port.setdescription(j.RGX[1])
                elif j.TYPE == ast.AST.switchport_allowed_tagged:
                    port.addtagged(j.VLAN)
                elif j.TYPE == ast.AST.switchport_allowed_untagged:
                    port.adduntagged(j.VLAN)
                elif j.TYPE == ast.AST.shutdown:
                    port.setdown()
        elif i.TYPE == ast.AST.vlan:
            # VLANs '1' and '4093' are deliberately not copied to the device.
            if i.RGX[1] not in ('1', '4093'):
                device.addvlan(model.VLAN(i.RGX[4], i.RGX[1]))
        elif i.TYPE == ast.AST.ip_igmp_snooping:
            if i.RGX[4] not in device.igmp_snooping:
                device.igmp_snooping.append(i.RGX[4])
        elif i.TYPE == ast.AST.ip_dhcp_snooping:
            if i.RGX[4] not in device.dhcp_snooping:
                device.dhcp_snooping.append(i.RGX[4])

    return device
Ejemplo n.º 5
0
def get_args():
    '''Check required files, load the config, and return validated settings.

    The process exits with status 1 in any of these cases:

    - ``config.json`` or ``course_list.txt`` is missing from the current
      directory (the user is first asked to press Enter);
    - the username or password is empty;
    - no course id list is configured.

    :return: dict with keys ``reload_course``, ``usn``, ``pwd``, ``id_list``
        and ``wait``. ``reload_course`` and ``wait`` default to ``True`` when
        their config keys (``reload`` / ``wait``) are absent.
    '''
    # priority: config file > default settings
    # List the working directory once and check every required file.
    cwd_entries = os.listdir('./')
    for required in ('config.json', 'course_list.txt'):
        if required not in cwd_entries:
            print(colorama.Fore.LIGHTRED_EX + '未找到配置文件 {}!'.format(required))
            input('按 Enter 键退出')
            sys.exit(1)

    config = load_config_from_file()
    logging.debug('Config loaded!')
    reload_course = config.get('reload', True)
    usn = config.get('username')
    pwd = config.get('password')
    id_list = config.get('course_id')
    wait = config.get('wait', True)

    if not usn or not pwd:
        if not usn:
            print(colorama.Fore.LIGHTRED_EX + '错误: 学号为空')
        if not pwd:
            print(colorama.Fore.LIGHTRED_EX + '错误: 密码为空')
        logging.critical('No username or no password provided, exiting')
        sys.exit(1)
    if id_list is None:
        print(colorama.Fore.LIGHTRED_EX + '错误: 必须输入课程ID列表')
        logging.critical('No course id list provided in batch mode')
        sys.exit(1)
    return {
        'reload_course': reload_course,
        'usn': usn,
        'pwd': pwd,
        'id_list': id_list,
        'wait': wait
    }
Ejemplo n.º 6
0
if __name__ == "__main__":
    args = docopt(__doc__)
    lang_key = args['LANGUAGE-KEY']

    # For each directory setting, an explicit CLI value wins; otherwise the
    # existing environment value gets the language key appended.
    for env_name, cli_option in (("DATA_DIR", '--data-dir'),
                                 ("OUT_DIR", '--output-dir')):
        cli_value = args[cli_option]
        if cli_value is not None:
            os.environ[env_name] = cli_value
        else:
            os.environ[env_name] = os.path.join(os.environ[env_name], lang_key)

    # Fall back to the bundled mc.config module when no --config is given.
    if args['--config'] is None:
        from mc import config
    else:
        config = load_config_from_file(args['--config'])

    print("Using configuration", config.__file__)

    if args['--test'] is True:
        test_file, eval_type = config.filename_test, 'test'
    elif args['--dev'] is True:
        test_file, eval_type = config.filename_dev, 'dev'
    else:
        raise ValueError('Specify --dev or --test.')

    config_holder = ConfigHolder(config)
Ejemplo n.º 7
0
def create(name):
    '''Parse the switch configuration *name* into a ``model.DEVICE``.

    The text returned by ``util.load_config_from_file`` is split into lines.
    Each line is parsed with ``node.parsing``, the results are grouped into
    a tree by ``node.gen_graph``, and the tree is printed. The top-level
    nodes are then applied to a fresh device:

    - hostname;
    - interface name/inactive;
    - MVR VLANs and their ports;
    - VLANs with fixed/untagged ports and management/gateway addresses.

    Child lists that are absent from a block are treated as empty instead of
    crashing.

    :param name: value passed to ``util.load_config_from_file``
    :return: the populated ``model.DEVICE``
    '''
    device = model.DEVICE()
    cfg = util.load_config_from_file(name)
    parsed_lines = [obj.Struct(**node.parsing(line.strip()))
                    for line in cfg.split('\n')]
    cfg_new = node.gen_graph(parsed_lines)
    node.show_node(cfg_new)

    for i in cfg_new:
        if i.TYPE == ast.AST.hostname:
            device.hostname = i.RGX[1]
        elif i.TYPE == ast.AST.interface:
            for j in i.Node:
                if j.TYPE == ast.AST.name:
                    port = device.getport(int(i.RGX[2]))
                    port.setdescription(j.RGX[1])
                elif j.TYPE == ast.AST.inactive:
                    port = device.getport(int(i.RGX[2]))
                    port.setdown()
        elif i.TYPE == ast.AST.mvr:
            vlan_id = i.RGX[1]
            source_port, receiver_port, name, tagged = None, None, None, None
            for j in i.Node:
                if j.TYPE == ast.AST.source_port:
                    source_port = j.PORT
                elif j.TYPE == ast.AST.receiver_port:
                    receiver_port = j.PORT
                elif j.TYPE == ast.AST.name:
                    name = j.RGX[1]
                elif j.TYPE == ast.AST.tagged:
                    tagged = j.PORT
            # Register the VLAN only after its name child has been read;
            # creating it before the loop always gave it name=None.
            device.addvlan(model.VLAN(name, vlan_id))
            device.setmvr(model.MVR(vlan_id, name))
            # NOTE(review): ports listed as source_port are flagged as
            # receiver ports and vice versa. This looks swapped; confirm
            # against the target format before changing it.
            for v in source_port or ():
                port = device.getport(v)
                port.mvr.tag = vlan_id
                port.mvr.receiver_port = True
            for v in receiver_port or ():
                port = device.getport(v)
                port.mvr.tag = vlan_id
                port.mvr.source_port = True
            for v in tagged or ():
                port = device.getport(v)
                port.addtagged(vlan_id)
        elif i.TYPE == ast.AST.vlan:
            vlan_id = i.RGX[1]
            # "forbidden" port lists are not applied to the device.
            fixed, untagged, name = None, None, None
            for j in i.Node:
                if j.TYPE == ast.AST.fixed:
                    fixed = j.PORT
                elif j.TYPE == ast.AST.untagged:
                    untagged = j.PORT
                elif j.TYPE == ast.AST.name:
                    name = j.RGX[1]
                elif j.TYPE == ast.AST.ip:
                    if ast.default_management in j.RGX:
                        device.mng_ip, device.mng_mask = j.RGX[3], j.RGX[4]
                        device.mng_int_vlan = vlan_id
                    elif ast.default_gateway in j.RGX:
                        device.default_gateway = j.RGX[3]
            device.addvlan(model.VLAN(name, vlan_id))
            # Fixed members are untagged if also listed as untagged,
            # otherwise tagged.
            untagged_ports = untagged or ()
            for k in fixed or ():
                port = device.getport(k)
                if k in untagged_ports:
                    port.adduntagged(vlan_id)
                else:
                    port.addtagged(vlan_id)
    print("*" * 45)
    return device
Ejemplo n.º 8
0
    # accuracy = nb_right_sum * 1.0 / nb_samples

    return average_loss, sp, pe, mse_


if __name__ == "__main__":
    train_data, dev_data, test_data, word2idx = utils.load_snli_data()
    word_emb = get_pretrained_embedding("../SNLI/embedding_matrix.pkl",
                                        voc_size=57323)

    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
    # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    sess = tf.Session()

    # Load the "snli" section, then load "train_snli" on top of it
    # (presumably overriding/extending it via config=...; confirm in utils).
    config = utils.load_config_from_file("config", main_key="snli")
    config = utils.load_config_from_file("config",
                                         main_key="train_snli",
                                         config=config)
    config["word_emb"] = word_emb

    # Pass the config assembled above. The previous `CONFIG` argument
    # silently discarded it, including the pretrained word embeddings.
    model = models.WSAN(config_=config)
    sess.run(tf.global_variables_initializer())

    # NOTE(review): nb_epoches, decay_start and lr_decay are not defined in
    # this block; presumably module-level settings -- confirm.
    # best_accuracy = [0.0, 0.0]
    best_valid_corr = []
    best_corr = [0.0, 0.0]
    for i in range(nb_epoches):
        if i >= decay_start:
            new_lr = sess.run(model.learning_rate) * lr_decay
            model.assign_lr(sess, new_lr)
Ejemplo n.º 9
0
def _apply_interface_range(device, iface):
    '''Apply an ``interface range ethernet`` node's children to each port.'''
    ports = node.parsing_interface(iface.RGX[1]).PORT
    for stmt in iface.Node:
        kind = stmt.TYPE
        if kind is None:
            continue
        if kind == ast.AST.description:
            for h in ports:
                device.getport(int(h)).setdescription(stmt.RGX[1])
        elif kind == ast.AST.switchport_trunk_allowed_vlan_add:
            for h in ports:
                device.getport(int(h)).addtagged(stmt.RGX[1])
        elif kind == ast.AST.switchport_access:
            for h in ports:
                device.getport(int(h)).adduntagged(stmt.RGX[2])
        elif kind == ast.AST.switchport_trunk_native_vlan:
            for h in ports:
                device.getport(int(h)).adduntagged(stmt.RGX[1])
        elif kind == ast.AST.switchport_general_pvid:
            for h in ports:
                port = device.getport(int(h))
                port.adduntagged(stmt.RGX[1])
                port.addgeneral_pid(stmt.RGX[1])
        elif kind == ast.AST.switchport_general_allowed_vlan_add:
            for h in ports:
                port = device.getport(int(h))
                # 2 regex groups -> tagged VLAN, 3 groups -> untagged VLAN.
                if len(stmt.RGX) == 2:
                    port.addtagged(stmt.RGX[1])
                    port.addgeneral_tag(stmt.RGX[1])
                elif len(stmt.RGX) == 3:
                    port.adduntagged(stmt.RGX[1])
                    port.addgeneral_untag(stmt.RGX[1])
                else:
                    print(
                        "   port range", h,
                        " ast.AST.switchport_general_allowed_vlan_add",
                        stmt.RGX)


def _apply_interface_ethernet(device, iface):
    '''Apply a single ``interface ethernet`` node's children to its port.'''
    for stmt in iface.Node:
        kind = stmt.TYPE
        if kind is None:
            continue
        port = device.getport(int(iface.RGX[2]))
        if kind == ast.AST.description:
            port.setdescription(stmt.RGX[1])
        elif kind == ast.AST.switchport_general_allowed_vlan_add:
            # NOTE(review): unlike the range variant, this records no
            # general_tag/untag and uses RGX[2] when there are 3 groups --
            # confirm the intended mapping.
            if len(stmt.RGX) == 2:
                port.addtagged(stmt.RGX[1])
            else:
                port.addtagged(stmt.RGX[2])
        elif kind == ast.AST.switchport_general_pvid:
            port.adduntagged(stmt.RGX[1])
            port.addgeneral_pid(stmt.RGX[1])
        elif kind == ast.AST.switchport_general:
            if len(stmt.RGX) == 2:
                port.addtagged(stmt.RGX[1])
                port.addgeneral_tag(stmt.RGX[1])
            elif len(stmt.RGX) == 3:
                # An untagged VLAN is recorded as general *untag*, matching
                # the interface-range handling (was addgeneral_tag).
                port.adduntagged(stmt.RGX[1])
                port.addgeneral_untag(stmt.RGX[1])
            else:
                print("   port", iface.RGX[2],
                      " ast.AST.switchport_general", stmt.RGX)


def _add_vlan_database(device, database):
    '''Add every VLAN id listed in a ``vlan database`` node, with no name.'''
    for stmt in database.Node:
        if stmt.TYPE == ast.AST.vlan:
            vlans = node.parsing_port(stmt.RGX[1]).VLAN
            # Distinct loop name: the original reused the caller's `i`.
            for vlan_id in vlans.values():
                device.addvlan(model.VLAN("", vlan_id))


def _apply_interface_vlan(device, iface):
    '''Apply an ``interface vlan`` node: name, management IP, IGMP snooping.'''
    vlan_id = iface.RGX[1]
    for stmt in iface.Node:
        kind = stmt.TYPE
        if kind is None:
            continue
        if kind == ast.AST.name:
            device.setvlanname(vlan_id, stmt.RGX[1])
        elif kind == ast.AST.ip_address:
            device.mng_ip, device.mng_mask = stmt.RGX[1], stmt.RGX[2]
            device.mng_int_vlan = vlan_id
        elif kind == ast.AST.ip_igmp_snooping:
            if vlan_id not in device.igmp_snooping:
                device.igmp_snooping.append(vlan_id)
        else:
            print("   ", stmt.__dict__)


def create(name):
    '''Parse the switch configuration *name* into a ``model.DEVICE``.

    The text returned by ``util.load_config_from_file`` is split into lines.
    Each line is parsed with ``node.parsing`` and the results are grouped
    into a tree by ``node.gen_graph``. Each top-level node is then applied
    to a fresh device by the helper for its type.

    :param name: value passed to ``util.load_config_from_file``
    :return: the populated ``model.DEVICE``
    '''
    device = model.DEVICE()

    cfg = util.load_config_from_file(name)
    parsed_lines = [obj.Struct(**node.parsing(line.strip()))
                    for line in cfg.split('\n')]
    graph = node.gen_graph(parsed_lines)

    for item in graph:
        kind = item.TYPE
        if kind == ast.AST.hostname:
            device.hostname = item.RGX[1]
        elif kind == ast.AST.interface_range_ethernet:
            _apply_interface_range(device, item)
        elif kind == ast.AST.interface_ethernet:
            _apply_interface_ethernet(device, item)
        elif kind == ast.AST.vlan_database:
            _add_vlan_database(device, item)
        elif kind == ast.AST.interface_vlan:
            _apply_interface_vlan(device, item)
        elif kind == ast.AST.ip_dhcp_snooping:
            if item.RGX[2] not in device.dhcp_snooping:
                device.dhcp_snooping.append(item.RGX[2])
        elif kind == ast.AST.ip_default_gateway:
            device.default_gateway = item.RGX[1]

    return device
Ejemplo n.º 10
0
        :param mask: (None, sentence_length)
        :param keep_prob:
        :return:
        '''
        dropouted_mask = tf.nn.dropout(mask, keep_prob=keep_prob)
        outputs = tf.expand_dims(dropouted_mask, axis=-1) * inputs
        
        return outputs
  

    def assign_lr(self, session, new_lr):
        '''
        Set the model's learning rate by running its assign op.
        :param session: the session in which to run ``self.assign_lr_op``
        :param new_lr: the new learning-rate value, fed to ``self.new_lr``
        :return: None
        '''
        feed = {self.new_lr: new_lr}
        session.run(self.assign_lr_op, feed_dict=feed)

    def assign_batch_size(self, session, new_batch_size):
        '''
        Set the model's batch size by running its assign op.
        :param session: the session in which to run ``self.assign_bs_op``
        :param new_batch_size: value fed to ``self.new_batch_size``
        :return: None
        '''
        feed = {self.new_batch_size: new_batch_size}
        session.run(self.assign_bs_op, feed_dict=feed)


if __name__ == "__main__":
    # Smoke check: load the "snli" section of "config" and build the model.
    config = utils.load_config_from_file("config", "snli")

    logger.info("build model ...")
    wsan = WSAN(config)
    logger.info("build model successfully")