Ejemplo n.º 1
0
def predict(config, args):
    """Run the trained U-Net over the validation split and write shapefiles.

    Reads ``val.csv`` from ``args.data_dir``, loads the images through the
    configured CSV/image loaders, runs predictor + polygonizer on each batch
    and writes one shapefile per image.

    :param config: experiment configuration; ``xy_splitter``, ``loader``,
        ``model_name``, ``model`` and ``job_dir`` are read from it.
    :param args: CLI arguments; ``data_dir``, ``data_pre``, ``model_path``
        and ``save`` are read, and ``output_dir`` is used when present.
    """
    plug = Wire(filename=os.path.join(args.data_dir, 'val.csv'),
                train_mode=True)

    loader = (plug
              | CSVLoaderXYZ(
                  name='xyz', prefix=args.data_pre, **config.xy_splitter)
              | ImageLoaderInference('image_loader', True, **config.loader))

    unet = predictors[config.model_name](**config.model)
    predictor = Predictor('unet_predictor', unet, need_setup=True)
    predictor.setup(path=args.model_path, load_weights=True)

    # NOTE(review): evaluator and viewer are never used below; they are kept
    # only in case their constructors have side effects — confirm and remove.
    evaluator = Evaluator('evaluator', predictor.predictor, need_setup=False)
    viewer = PolygonViewer('viewer', save=args.save, job_dir=config.job_dir)
    polygonizer = Polygonizer('polygons')
    # One creator serves every batch (it used to be rebuilt per batch).
    creator = ShapefileCreator('shapefile')
    generator, steps = loader.generator

    # The output location used to be hard-coded to one developer's machine;
    # it stays the fallback so existing invocations behave the same.
    output_dir = (getattr(args, 'output_dir', None)
                  or '/Users/nikhilsaraf/Documents')

    # `while True` never terminated; run exactly the number of batches the
    # loader reports. next() works on Python 3 generators, .next() did not.
    for _ in range(steps):
        x, y, z = next(generator)
        batch = Wire(x=x, y=y)
        prediction = batch | predictor | polygonizer
        for i in range(len(x)):
            shp_name = os.path.basename(z[i]).replace(
                'image', 'pred').replace('.jpg', '.shp')
            creator.transform(polygons=prediction.polygons[i],
                              filename=os.path.join(output_dir, shp_name),
                              transform=get_transform(
                                  os.path.join(args.data_pre, z[i])))
Ejemplo n.º 2
0
def predict(config, args):
    """Predict polygons for one image and write them next to it as a shapefile.

    The image at ``args.file_path`` is adjusted, cut into 256x256 tiles, run
    through the U-Net predictor (batch size 5), stitched back to the adjusted
    image's size, polygonized, and written to ``<path without extension>.shp``.

    :param config: configuration; ``config.model`` supplies UNetModel kwargs.
    :param args: CLI arguments; ``file_path`` and ``model_path`` are read.
    """
    import os

    image, transform, crs = open_image(args.file_path)
    adjusted_image = adjust_image(image)
    tiles = tile_image(adjusted_image, 256, 256)

    unet = UNetModel(**config.model)
    predictor = Predictor('unet_predictor', unet, need_setup=True)
    predictor.setup(path=args.model_path, load_weights=True)

    predictions = Wire(x=tiles, batch_size=5) | predictor

    height, width = adjusted_image.shape[0], adjusted_image.shape[1]
    remade_image = untile_image(predictions.predictions, height, width, 1)
    polygons = Wire(predictions=remade_image) | Polygonizer('polygons')

    # splitext strips only the final extension; slicing at the first '.'
    # broke on paths such as './scene.tif' or dotted directory names.
    shapefile_path = os.path.splitext(args.file_path)[0] + '.shp'
    (Wire(filename=shapefile_path, transform=make_transform(transform))
     + polygons
     | ShapefileCreator('shapefile', crs=crs.to_dict()))
Ejemplo n.º 3
0
def bot_message():
    """Handle a Bitrix24 webhook call for the chat bot.

    ``ONIMBOTMESSAGEADD`` events are parsed into a command: ``/about``,
    ``/search`` and ``/cmp`` are handled explicitly; anything else gets a
    greeting plus the keyboard. Every other event increments the portal's
    ``event_counter`` and, once it reaches 10, asks for predictions to be
    recalculated and resets the counter.

    :return: the string ``'true'`` in every case.
    """
    data = request.values.to_dict()

    if 'event' in data:
        event = data['event']
        # 'auth[domain]' is split around '.bitrix24.' into portal and language.
        domain_parts = data['auth[domain]'].split('.bitrix24.')
        domain, lang = domain_parts[0], domain_parts[1]

        auth_info = PortalAuth.query.filter_by(portal=domain).first()
        bot_app = BotApplication(domain, lang, auth_info.access_token,
                                 auth_info.refresh_token)

        if event == 'ONIMBOTMESSAGEADD':
            in_message = data['data[PARAMS][MESSAGE]']
            chat_id = data['data[PARAMS][FROM_USER_ID]']
            check = bot_app.check_message(in_message)
            command = check['command']

            if command == '/about':
                bot_app.send_about(chat_id=chat_id)

            elif command == '/search':
                bot_app.get_company_by_title(check['content'], chat_id=chat_id)

            elif command == '/cmp':
                index_app = IndexApplication(domain, lang,
                                             auth_info.access_token,
                                             auth_info.refresh_token)
                # Separate name so the request payload in `data` is not
                # overwritten by the company data.
                cmp_data = index_app.get_data(check['content'])
                data_frame = DataParser.get_data_frame(**cmp_data)
                predictions = Predictor(data_frame).make_predict()
                # One line per company; the entries used to be concatenated
                # with no separator and ran into each other.
                message = '\n'.join(
                    'Вероятность для компании ' + item['TITLE'] + ' -- ' +
                    str(item['PREDICT']) for item in predictions)
                bot_app.send_message(chat_id=chat_id, message=message)

            else:
                bot_app.send_message(chat_id=chat_id,
                                     message='Привет, ' +
                                     data['data[USER][FIRST_NAME]'] + ' ' +
                                     data['data[USER][LAST_NAME]'] + '!')
                bot_app.send_keyboard(chat_id=chat_id)

        else:
            # Any other event counts as a change on the portal.
            auth_info.event_counter += 1
            db.session.commit()
            if auth_info.event_counter >= 10:
                bot_app.send_message(
                    message=
                    'Произошло много изменений, пора пересчитать предсказания')
                bot_app.send_keyboard()
                auth_info.event_counter = 0
                db.session.commit()

        bot_app.save_auth()

    return 'true'
Ejemplo n.º 4
0
 def __init__(
         self,
         broker_host: str,
         broker_port: int,
         message_channel: str,
         checkpoint_dir: str,
         checkpoint_prefix: str,
         checkpoint_number: int,
         dictionary: str) -> None:
     """Create the MessageProcessingRunner.

     A ModelBuilder with fixed hyper-parameters is wrapped in a Predictor,
     its weights are restored from the requested checkpoint, and the
     predictor is handed to a ProcessingClient bound to the given broker
     host/port and message channel.
     """
     # Positional hyper-parameters; their meaning is defined by ModelBuilder
     # and Predictor, not here.
     model_builder = ModelBuilder(17, (64, 64), 500, 20, 1e-4)
     restored_predictor = Predictor(model_builder, dictionary, 20)
     restored_predictor.restore_checkpoint(checkpoint_dir, checkpoint_prefix,
                                           checkpoint_number)
     client = ProcessingClient(broker_host, broker_port, message_channel,
                               restored_predictor)
     self._processing_client = client
Ejemplo n.º 5
0
def load_best_model(model_dir, model_type="predictor"):
    """Load the ``best_model.pt`` checkpoint stored in *model_dir*.

    :param model_dir: directory that contains ``best_model.pt``.
    :param model_type: ``"predictor"`` wraps the weights in a Predictor;
        any other value wraps them in a Selector.
    :return: a Trainer with the checkpoint loaded (its config is printed).
    """
    import os

    # os.path.join instead of string concatenation: an empty model_dir no
    # longer turns into the absolute path "/best_model.pt".
    model_file = os.path.join(model_dir, "best_model.pt")
    print("Loading model from {}".format(model_file))
    model_opt = torch_utils.load_config(model_file)
    # Anything other than "predictor" falls back to the selector network.
    network_cls = Predictor if model_type == "predictor" else Selector
    model = Trainer(model_opt, network_cls(model_opt), model_type=model_type)
    model.load(model_file)
    helper.print_config(model_opt)
    return model
Ejemplo n.º 6
0
def upload():
    """Accept an uploaded log archive and queue it for processing.

    A new Predictor row is committed to get an id. The file is saved as
    ``<LOGS_DIR>/<id>.zip`` and the celery task ``wsgi.process_log_files``
    is sent with that id.

    :return: ``{"id": <id>}`` on success, or ``{"message": ...}`` when no
        file was submitted.
    """
    import os

    upload_file = request.files.get("file")
    # Browsers submit an empty part with filename '' when no file is chosen;
    # treat that like a missing file instead of saving an empty archive.
    if upload_file is None or not upload_file.filename:
        return {'message': 'file cannot be empty.'}

    predictor = Predictor()
    session.add(predictor)
    session.commit()

    predictor_id = str(predictor.id)
    upload_file.save(os.path.join(LOGS_DIR, predictor_id + ".zip"))

    celery.send_task('wsgi.process_log_files',
                     kwargs={"id": predictor_id})

    return {"id": predictor_id}
Ejemplo n.º 7
0
    def __init__(self, db, model_name):
        """Read credentials from the environment and open the VK sessions.

        Environment variables read: GROUP_TOKEN, GROUP_ID, SERVICE_TOKEN,
        APP_ID, CLIENT_SECRET and ADMIN_PWD; a missing one raises KeyError.

        :param db: database module providing ``create_session`` and the
            ``Groups`` model.
        :param model_name: forwarded to ``Predictor``.
        """
        env = os.environ
        group_token = env['GROUP_TOKEN']
        group_id = int(env['GROUP_ID'])
        service_token = env['SERVICE_TOKEN']
        app_id = int(env['APP_ID'])
        client_secret = env['CLIENT_SECRET']

        self.visited = set()

        self.admin_pwd = env['ADMIN_PWD']
        categories = [
            'физика', 'математика', 'лингвистика', 'информатика', 'литература',
            'химия', 'география', "психология", "обществознание", "история",
            "музыка", "астрономия", "маркетинг", "биология", "спорт",
            "искусство", "бизнес"
        ]
        self.new_cats = sorted(categories)

        self.predictor = Predictor(model_name)
        self.db = db
        self.db_session = db.create_session()

        self.group_session = vk_api.VkApi(token=group_token,
                                          api_version='5.126')
        self.service_session = vk_api.VkApi(app_id=app_id,
                                            token=service_token,
                                            client_secret=client_secret)
        self.long_poll = VkBotLongPoll(self.group_session, group_id)
        self.group_api = self.group_session.get_api()
        self.service_api = self.service_session.get_api()

        # Highest stored group_id (0 for an empty table), used for dataset
        # filtering.
        newest_row = (self.db_session.query(db.Groups.group_id)
                      .order_by(db.Groups.group_id.desc())
                      .first())
        self.latest_id = 0 if newest_row is None else newest_row[0]
Ejemplo n.º 8
0
def get_result():
    """Run the company prediction and present it as a page or chat messages.

    Request values read: ``companies`` (list of ids) and ``return_type``:
    ``'frame'`` renders ``model_predict.html`` with the predictions;
    ``'chat'`` posts one message per company through the index application
    and renders ``index.html``.

    :return: the rendered template, or an error body with HTTP 400 for any
        other ``return_type``. Previously the view returned None in that
        case, which Flask rejects with a server error.
    """
    return_type = request.values.get('return_type')
    # Validate before doing any remote work for a request we cannot answer.
    if return_type not in ('frame', 'chat'):
        return "return_type must be 'frame' or 'chat'", 400

    index_app = IndexApplication(**get_post())
    cmp_list = request.values.getlist('companies')
    data = index_app.get_data(cmp_list)
    data_frame = DataParser.get_data_frame(**data)
    predictions = Predictor(data_frame).make_predict()

    if return_type == 'frame':
        return render_template('model_predict.html', data=predictions)

    for item in predictions:
        message = 'Вероятность для компании ' + item[
            'TITLE'] + ' -- ' + str(item['PREDICT'])
        index_app.send_message(title=item['TITLE'], content=message)
    return render_template('index.html')
Ejemplo n.º 9
0
def register_callbacks(app: Dash):
    '''Register callbacks for the given app argument.

    Wires the manual grade form, the three concentration cards (imdd, mi,
    se), the tab switcher and the batch CSV upload to one shared
    ``Predictor`` instance.
    '''

    predictor = Predictor()

    # Course keys in form order. Each maps to an "input-<key>" component and
    # to the keyword argument of the same name on predictor.predict.
    courses = (
        "calculus1", "discrete_math", "intro_to_it", "prolog",
        "calculus2", "data_structure", "java", "mis", "stats",
        "algo_analysis", "computer_arch", "database", "linear_algebra", "oop",
    )
    # Grade points, indexed by the value selected in each course input.
    scores = [4.0, 3.7, 3.3, 3.0, 2.7, 2.3, 2.0, 1.7, 1.3, 1.0]

    def _decode_csv(contents: str) -> pd.DataFrame:
        '''Decode an upload ``contents`` string (``...base64,<payload>``) as CSV.'''
        decoded = base64.b64decode(contents.split("base64,")[1])
        return pd.read_csv(io.StringIO(decoded.decode('utf-8')))

    def _concentration_title(data: str, key: str) -> (str, str, str, str):
        '''Return (badge, percentage, header color, header class) for one card.

        The card for *key* is highlighted when *key* has the highest value in
        the JSON payload; on ties the first key in the payload wins.
        '''
        if not data:
            return "", "", "", ""

        percentages = json.loads(data)
        recommended = max(percentages, key=percentages.get) == key
        return (
            "Recommended" if recommended else "",
            "%.2f %%" % percentages[key],
            "primary" if recommended else "secondary",
            "text-white" if recommended else "text-primary",
        )

    @app.callback(
        Output("manual-output", "children"),
        [Input("submit-button", "n_clicks")],
        [State("input-" + course, "value") for course in courses]
    )
    def process_data(n_clicks: int, *values: int) -> str:
        '''Convert the selected grade indices into a JSON prediction payload.'''
        if not n_clicks:
            return json.dumps({})

        # -1 marks an unselected course; a cleared input arrives as None and
        # previously crashed on scores[None].
        if any(value is None or value == -1 for value in values):
            return json.dumps({'error': 'All course scores must be non-empty'})

        data = predictor.predict(
            **{course: scores[value] for course, value in zip(courses, values)})
        return json.dumps({'data': data})

    @app.callback(
        [
            Output('concentration-data', 'children'),
            Output('manual-welcome', 'className'),
            Output('manual-error', 'className'),
            Output('manual-error', 'children'),
            Output('manual-deck', 'className'),
        ],
        [Input('manual-output', 'children')]
    )
    def render_output(data: str) -> (str, str, str, str, str):
        '''Map the manual-output payload to card data and visibility classes.'''
        payload: dict = json.loads(data) if data else {}
        error = payload.get('error')
        result = payload.get('data')

        if error:
            return "", "d-none", "d-block", error, "d-none"
        if not result:
            return "", "d-block", "d-none", "", "d-none"
        return json.dumps(result), "d-none", "d-none", "", "d-block"

    @app.callback(
        [
            Output("imdd-recommended", "children"),
            Output("imdd-performance", "children"),
            Output("imdd-header", "color"),
            Output("imdd-header", "className"),
        ],
        [Input("concentration-data", "children")]
    )
    def render_imdd_title(data: str) -> (str, str, str, str):
        """Render the header of the imdd concentration card."""
        return _concentration_title(data, "imdd")

    @app.callback(
        [
            Output("mi-recommended", "children"),
            Output("mi-performance", "children"),
            Output("mi-header", "color"),
            Output("mi-header", "className"),
        ],
        [Input("concentration-data", "children")]
    )
    def render_mi_title(data: str) -> (str, str, str, str):
        """Render the header of the MI concentration card."""
        return _concentration_title(data, "mi")

    @app.callback(
        [
            Output("se-recommended", "children"),
            Output("se-performance", "children"),
            Output("se-header", "color"),
            Output("se-header", "className"),
        ],
        [Input("concentration-data", "children")]
    )
    def render_se_title(data: str) -> (str, str, str, str):
        """Render the header of the SE concentration card."""
        return _concentration_title(data, "se")

    @app.callback(
        [
            Output("manual-input-container", "className"),
            Output("manual-output-container", "className"),
            Output("batch-input-container", "className"),
            Output("batch-output-container", "className"),
        ],
        [Input("tabs", "active_tab")]
    )
    def render_tab_content(tab: str):
        '''Show the containers that belong to the active tab.'''
        if tab == "playground":
            return "d-block", "d-block", "d-none", "d-none"
        if tab == "batch":
            return "d-none", "d-none", "d-block", "d-block"

        return "d-none", "d-none", "d-none", "d-none"

    @app.callback(
        Output("batch-preview", "children"),
        [Input('upload-data', 'contents')],
        [State('upload-data', 'filename'),
         State('upload-data', 'last_modified')]
    )
    def preview_data(contents, filename, last_modified):
        '''Preview the first five rows of the uploaded CSV.'''
        if not filename:
            return html.P("The preview of your data will appear here")

        try:
            df = _decode_csv(contents).head()
            return [
                html.Small("Only the first 5 rows are shown"),
                BatchTable.render(df),
                dbc.Button("Process data",
                           outline=True, color="primary", id="process-button",
                           className="mt-3 mb-5")
            ]
        except Exception:
            # UI boundary: any decode/parse/render failure becomes an alert.
            return dbc.Alert("There is an error in uploading your data", color="danger")

    @app.callback(
        Output("batch-results-container", "children"),
        [Input("process-button", "n_clicks")],
        [State("upload-data", "contents")],
    )
    def process_batch(n_clicks, contents):
        '''Predict for every uploaded row and offer the result as a CSV download.'''
        if not n_clicks:
            return html.P("Predicted results for your data will appear here")

        try:
            result_df = predictor.predict_batch(_decode_csv(contents))
            csv_string = result_df.to_csv(index=False, encoding='utf-8')
            href = "data:text/csv;charset=utf-8," + urllib.parse.quote(csv_string)
            return [
                html.A(
                    "Download results",
                    href=href,
                    download="results.csv",
                    target="_blank",
                    className="btn btn-primary text-white mt-3",
                ),
                BatchTable.render(result_df),
            ]
        except Exception as e:
            # UI boundary: surface the failure reason to the user.
            return dbc.Alert("There is an error in processing your data: " + str(e), color="danger")
Ejemplo n.º 10
0
    # ====================== #
    # Begin Train on Predictor
    # ====================== #
    print("Training on iteration #%d for dualRE Predictor..." % num_iter)
    # Point the shared save-dir / dropout options at the predictor-specific
    # values before the config is saved and the model is built.
    opt["model_save_dir"] = opt["p_dir"]
    opt["dropout"] = opt["p_dropout"]

    # save config
    helper.save_config(opt,
                       opt["model_save_dir"] + "/config.json",
                       verbose=True)
    helper.print_config(opt)

    # prediction module
    predictor = Predictor(opt, emb_matrix=TOKEN.vocab.vectors)
    model = Trainer(opt, predictor, model_type="predictor")
    model.train(dataset_train, dataset_dev)

    # Evaluate: reload best_model.pt from the save dir (presumably the best
    # checkpoint written during model.train — confirm) and score each split.
    best_model_p = load_best_model(opt["model_save_dir"],
                                   model_type="predictor")
    print("Final evaluation #%d on train set..." % num_iter)
    evaluate(best_model_p, dataset_train, verbose=True)
    print("Final evaluation #%d on dev set..." % num_iter)
    # Element [2] of evaluate()'s result is recorded as F1 (assumes a
    # (precision, recall, f1) ordering — confirm against evaluate()).
    dev_f1 = evaluate(best_model_p, dataset_dev, verbose=True)[2]
    print("Final evaluation #%d on test set..." % num_iter)
    test_f1 = evaluate(best_model_p, dataset_test, verbose=True)[2]
    dev_f1_iter.append(dev_f1)
    test_f1_iter.append(test_f1)
    # NOTE(review): opt["p_dir"] was copied into opt["model_save_dir"] above,
    # so this reloads the same checkpoint as the earlier load_best_model call.
    # It looks redundant unless opt is changed in code not shown — confirm.
    best_model_p = load_best_model(opt["p_dir"], model_type="predictor")
Ejemplo n.º 11
0
    # Language-model loss; handed to the Trainer below.
    crit_lm = nn.CrossEntropyLoss()
    # No optimizer is created in this excerpt; the Trainer receives None.
    optimizer = None

    if args.gpu >= 0:
        # Select the requested GPU and move the model onto it.
        if_cuda = True
        torch.cuda.set_device(args.gpu)
        ner_model.cuda()
        # The boolean presumably tells the repacker whether batches go to
        # CUDA — confirm against CRFRepack_WC.
        packer = CRFRepack_WC(len(tag2idx), True)
    else:
        if_cuda = False
        packer = CRFRepack_WC(len(tag2idx), False)


    # Initialise the predictor and the evaluator.
    # predictor (label_seq=True, batch_size=50)
    predictor = Predictor(tag2idx, packer, label_seq = True, batch_size = 50)

    # evaluator: wraps the predictor; metric and prediction method come from args
    evaluator = Evaluator(predictor, packer, tag2idx, args.eva_matrix, args.pred_method)

    agent = Trainer(ner_model, packer, crit_ner, crit_lm, optimizer, evaluator, crf2corpus)

    # Evaluate on the dev and test sets of each training corpus.
    if args.local_eval:
        # assert len(train_args['dev_file']) == len(train_args['test_file'])
        num_corpus = len(train_args['dev_file'])


        # construct the pred and eval dataloader
        dev_tokens = []
        dev_labels = []
Ejemplo n.º 12
0
class Bot:
    def __init__(self, db, model_name):
        """Read credentials from the environment and open the VK sessions.

        Environment variables read: GROUP_TOKEN, GROUP_ID, SERVICE_TOKEN,
        APP_ID, CLIENT_SECRET and ADMIN_PWD; a missing one raises KeyError.

        :param db: database module providing ``create_session`` and the ORM
            models used by the bot.
        :param model_name: forwarded to ``Predictor``.
        """
        env = os.environ
        group_token = env['GROUP_TOKEN']
        group_id = int(env['GROUP_ID'])
        service_token = env['SERVICE_TOKEN']
        app_id = int(env['APP_ID'])
        client_secret = env['CLIENT_SECRET']

        # User ids that command_start has already answered with the
        # returning-user message.
        self.visited = set()

        self.admin_pwd = env['ADMIN_PWD']
        categories = [
            'физика', 'математика', 'лингвистика', 'информатика', 'литература',
            'химия', 'география', "психология", "обществознание", "история",
            "музыка", "астрономия", "маркетинг", "биология", "спорт",
            "искусство", "бизнес"
        ]
        self.new_cats = sorted(categories)

        self.predictor = Predictor(model_name)
        self.db = db
        self.db_session = db.create_session()

        self.group_session = vk_api.VkApi(token=group_token,
                                          api_version='5.126')
        self.service_session = vk_api.VkApi(app_id=app_id,
                                            token=service_token,
                                            client_secret=client_secret)
        self.long_poll = VkBotLongPoll(self.group_session, group_id)
        self.group_api = self.group_session.get_api()
        self.service_api = self.service_session.get_api()

        # Highest stored group_id (0 for an empty table), used for dataset
        # filtering.
        newest_row = (self.db_session.query(db.Groups.group_id)
                      .order_by(db.Groups.group_id.desc())
                      .first())
        self.latest_id = 0 if newest_row is None else newest_row[0]

    def send_message(self,
                     user_id: int,
                     message: str,
                     keyboard: str = None) -> None:
        """
        Send *message* to *user_id* with VK's messages.send
        (https://vk.com/dev/messages.send), then log the first 30 characters.

        :param user_id: recipient user ID
        :param message: message text
        :param keyboard: JSON keyboard to attach, or None for no keyboard
        :return: None
        """
        request_kwargs = dict(user_id=user_id,
                              random_id=get_random_id(),
                              message=message,
                              keyboard=keyboard)
        self.group_api.messages.send(**request_kwargs)
        preview = message[:30] + ('...' if len(message) > 30 else '')
        print(f'<-- message {preview} to {user_id} has been sent')

    def get_posts(self,
                  owner_id: int,
                  count: int = 1) -> Union[List[dict], dict]:
        """
        gets posts from user's or group's wall using method wall.get
        (https://vk.com/dev/wall.get)

        :param owner_id: wall's owner ID
        :param count: count of posts
        :return: list of dictionaries of dictionary, describing post
        """
        posts = self.service_api.wall.get(owner_id=owner_id, count=count)
        print(f'group {owner_id} posts received')
        try:
            if len(posts['items']) > 1:
                return posts['items']
            else:
                return posts['items'][0]
        except IndexError:
            print(f'error: {owner_id} {posts}')

    def get_subscriptions(self, user_id: int, count=100) -> List[int]:
        """
        gets user's subscriptions using method users.getSubscriptions
        (https://vk.com/dev/users.getSubscriptions)

        :param user_id: user ID
        :param count: get random {count} groups
        :return: list of numbers defining user IDs
        """
        subscriptions = self.service_api.users.getSubscriptions(
            user_id=user_id, extended=1)
        print(f'received subscriptions from '
              f'{"user" if user_id > 0 else "group"} {abs(user_id)}')
        ids = [
            i['id'] for i in subscriptions['items']
            if not i['is_closed'] and 'type' in i and 'deactivated' not in i
        ]
        return ids if len(ids) <= count else sample(ids, count)

    def get_group_info(
        self, group_id: int
    ) -> Union[Dict[str, Union[str, int]], List[Dict[str, Union[str, int]]]]:
        """
        gets information about one or more groups using method groups.getById
        (https://vk.com/dev/groups.getById)

        :param group_id: group ID
        :return: list of dictionaries of dictionary, describing information
        about group
        """
        info = self.service_api.groups.getById(group_id=group_id)
        print(f'received info from {group_id}')
        if len(info) == 1:
            return info[0]
        else:
            return info

    def listen(self) -> None:
        """
        Consume long-poll events and pass each new message to
        process_new_message; other event types are ignored.
        :return: None
        """
        for event in self.long_poll.listen():
            if event.type != VkBotEventType.MESSAGE_NEW:
                continue
            self.process_new_message(event)

    def process_new_message(self, event):
        from_id = event.object['message']['from_id']
        cmd = event.object['message']['text']
        print(f'--> {from_id} sent "{cmd}"')

        payload = json.loads(event.object['message'].get('payload', '{}'))

        if payload.get('button') == 'start_analysis':
            self.command_start_analysis(from_id)
        elif ('button' in payload
              and 'show_recommendation' in payload['button']):
            self.command_show_recommendation(from_id, payload)
        elif ''.join(filter(str.isalpha, cmd.lower())) == self.admin_pwd:
            self.command_admin(from_id)
        elif ('button' in payload and 'dataset_filter' in payload['button']):
            self.command_dataset_filter(from_id, payload)
        else:
            self.command_start(from_id)

    def command_start(self, from_id):
        """Send the start keyboard to *from_id* and set their status to 'started'.

        Users whose status row already has subjects also get a button for
        their previous recommendations and a different message, and are
        added to ``self.visited``. A status row is created when missing.
        """
        keyboard = VkKeyboard(one_time=True)
        keyboard.add_button('Начать анализ',
                            color=VkKeyboardColor.POSITIVE,
                            payload=json.dumps({'button': 'start_analysis'}))
        msg = ('Здравствуйте, я - Виталя, бот-рекомендатор. Я помогу вам '
               'определить ваши интересы и подскажу, где найти ещё больше '
               'полезных групп ВКонтакте. Начнём анализ?')
        # One lookup serves both the greeting choice and the status update
        # below (the same row used to be queried twice).
        user_status = self.db_session.query(self.db.UserStatuses).filter(
            self.db.UserStatuses.user_id == from_id).first()
        if user_status and user_status.subjects:
            keyboard.add_button('Перейти к рекомендациям',
                                color=VkKeyboardColor.SECONDARY,
                                payload=json.dumps(
                                    {'button': 'show_recommendation_1'}))
            # NOTE(review): the "welcome back" text goes only to users already
            # in self.visited, so a returning user's first reply in this
            # process is 'Нужно нажать на кнопку' ("you need to press the
            # button"). This looks inverted — confirm the intended order.
            msg = ('С возвращением! Желаете провести анализ снова или '
                   'посмотреть, что я рекомендовал вам в прошлый раз?'
                   if from_id in self.visited else 'Нужно нажать на кнопку')
            self.visited.add(from_id)
        self.send_message(from_id, msg, keyboard.get_keyboard())
        if user_status:
            user_status.status = 'started'
        else:
            self.db_session.add(
                self.db.UserStatuses(user_id=from_id, status='started'))
            print(f'=== user {from_id} added')
        self.db_session.commit()

    def command_start_analysis(self, from_id):
        """Analyse the user's subscriptions and reply with recommendations.

        Collects up to 10 recent non-ad posts from each group returned by
        ``get_subscriptions``. The predictor's top three categories are
        stored on the user's status row (joined with '&') and sent back.
        The reply ends with the first page (10 entries) of matching groups
        and page-navigation buttons.
        """
        try:
            group_ids = self.get_subscriptions(from_id)
        except vk_api.exceptions.ApiError:
            # Subscriptions could not be read (the reply tells the user their
            # profile is closed); offer a retry button.
            message = 'Ваш профиль закрыт, я не могу увидеть подписки'
            keyboard = VkKeyboard(one_time=True)
            keyboard.add_button('Теперь профиль открыт, начать анализ',
                                color=VkKeyboardColor.POSITIVE,
                                payload=json.dumps(
                                    {'button': 'start_analysis'}))
            self.send_message(from_id, message, keyboard.get_keyboard())
            return

        message = ('Анализ может занять несколько минут. Пожалуйста, '
                   'подождите.')
        self.send_message(from_id, message)

        texts = []
        for _id in group_ids:
            try:
                posts = self.get_posts(-_id, 10)
                # get_posts returns a bare dict when a wall has exactly one
                # post; wrap it so that group is no longer silently dropped.
                if isinstance(posts, dict):
                    posts = [posts]
                texts.append('\n'.join(
                    post['text'] for post in posts
                    if not post['marked_as_ads']))
            except TypeError:
                # get_posts returned None (empty wall) or a malformed post.
                continue

        prediction = list(map(itemgetter(0),
                              self.predictor.predict(texts)[:3]))

        user_status = self.db_session.query(self.db.UserStatuses).filter(
            self.db.UserStatuses.user_id == from_id).first()
        if user_status is None:
            # e.g. a stale "start analysis" button pressed without a status row
            user_status = self.db.UserStatuses(user_id=from_id)
            self.db_session.add(user_status)
        user_status.subjects = '&'.join(prediction)
        user_status.status = 'show_page'
        user_status.page = 1
        self.db_session.commit()

        message = 'В ходе анализа было выявлено, что вас ' \
                  'интересуют следующие категории групп:\n'
        message += '\n'.join(
            f'{i}. {category.capitalize()}'
            for i, category in enumerate(prediction, 1))
        self.send_message(from_id, message)

        keyboard = VkKeyboard(one_time=True)
        keyboard.add_button('Начать анализ повторно',
                            color=VkKeyboardColor.SECONDARY,
                            payload=json.dumps({'button': 'start_analysis'}))

        # Build the OR from however many categories came back; indexing
        # prediction[0..2] raised IndexError when there were fewer than three.
        group_ids = []
        if prediction:
            group_ids = self.db_session.query(self.db.GroupsIds).filter(
                or_(*(self.db.GroupsIds.subject == subject
                      for subject in prediction))).all()

        if group_ids:
            show_groups = group_ids[:10]
            message = 'Страница 1:\n'
            message += '\n'.join(
                f'{i}. {group.name} -- '
                f'https://vk.com/club{group.group_id} '
                for i, group in enumerate(show_groups, 1))
            # Ceiling division: 20 groups are 2 pages, not 3.
            page_count = -(-len(group_ids) // 10)
            # Only offer page 2 when it exists.
            next_page = 2 if page_count > 1 else 1

            keyboard.add_line()
            # "Previous" from page 1 wraps around to the last page.
            for target in (page_count, next_page):
                keyboard.add_button(f'Страница {target}',
                                    color=VkKeyboardColor.PRIMARY,
                                    payload=json.dumps({
                                        'button':
                                        f'show_recommendation_{target}'
                                    }))
        else:
            message = "Проанализировать ещё раз?"
        self.send_message(from_id, message, keyboard.get_keyboard())

    def command_show_recommendation(self, from_id, payload):
        """Show one page (10 groups) of the user's saved recommendations.

        :param from_id: VK id of the user who pressed the button.
        :param payload: button payload; ``payload['button']`` has the form
            ``'show_recommendation_<page>'``.
        """
        page = int(payload['button'].split('_')[2])
        user_status = self.db_session.query(self.db.UserStatuses).filter(
            self.db.UserStatuses.user_id == from_id).first()
        if not user_status or not user_status.subjects:
            # Stale button with nothing saved. The AttributeError raised here
            # used to propagate out of listen(); restart the dialogue instead.
            self.command_start(from_id)
            return

        subjects = user_status.subjects.split('&')
        # Build the OR from however many subjects are stored; indexing
        # subjects[0..2] raised IndexError when there were fewer than three.
        group_ids = self.db_session.query(self.db.GroupsIds).filter(
            or_(*(self.db.GroupsIds.subject == subject
                  for subject in subjects))).all()
        # Ceiling division, never below 1: 20 groups are 2 pages, not 3.
        page_count = max(1, -(-len(group_ids) // 10))

        show_groups = group_ids[(page - 1) * 10:page * 10]
        message = f'Страница {page}:\n'
        message += '\n'.join(
            f'{i}. {group.name} -- https://vk.com/club{group.group_id}'
            for i, group in enumerate(show_groups, 1))

        keyboard = VkKeyboard(one_time=True)
        keyboard.add_button('Начать анализ повторно',
                            color=VkKeyboardColor.SECONDARY,
                            payload=json.dumps({'button': 'start_analysis'}))
        keyboard.add_line()
        # Previous/next buttons wrap around at the first and last page.
        prev_page = page - 1 if page > 1 else page_count
        next_page = page + 1 if page < page_count else 1
        for target in (prev_page, next_page):
            keyboard.add_button(f'Страница {target}',
                                color=VkKeyboardColor.PRIMARY,
                                payload=json.dumps({
                                    'button':
                                    f'show_recommendation_{target}'
                                }))
        self.send_message(from_id, message, keyboard.get_keyboard())
        user_status.status = 'show_page'
        user_status.page = page
        self.db_session.commit()

    def command_admin(self, from_id):
        """Put ``from_id`` into admin mode.

        Prints a console trace and sends the admin-panel keyboard, which has a
        "filter dataset" button and an exit button. It then stores
        ``status='admin'`` for the user in ``UserStatuses``, inserting a new
        row when none exists, and commits.
        """
        print(f'*** {from_id} entered admin panel')

        admin_keyboard = VkKeyboard(one_time=True)
        button_specs = (
            ('Фильтровать датасет', VkKeyboardColor.PRIMARY,
             {'button': 'dataset_filter'}),
            ('Выйти', VkKeyboardColor.NEGATIVE, {'command': 'start'}),
        )
        for label, colour, button_payload in button_specs:
            admin_keyboard.add_button(label,
                                      color=colour,
                                      payload=json.dumps(button_payload))
        greeting = 'Вы вошли в панель администратора'
        self.send_message(from_id, greeting, admin_keyboard.get_keyboard())

        status_row = self.db_session.query(self.db.UserStatuses).filter(
            self.db.UserStatuses.user_id == from_id).first()
        if not status_row:
            # First contact: create the status row directly in admin mode.
            self.db_session.add(
                self.db.UserStatuses(user_id=from_id, status='admin'))
        else:
            status_row.status = 'admin'
        self.db_session.commit()

    def command_dataset_filter(self, from_id, payload):
        """Admin-only dataset labelling step.

        Non-admin users (including users without a ``UserStatuses`` row) get
        the "start analysis?" prompt instead.

        For admins, a ``payload['button']`` of the form
        ``'dataset_filter#<group_id>#<category index>'`` records the previous
        answer. Index ``-1`` means ``'other'``. The answer is stored only if
        ``group_id`` is newer than ``self.latest_id``. In that case a
        ``Groups`` row is stored with name/link copied from ``GroupsIds`` and
        the chosen subject, ``self.latest_id`` is advanced, and the session is
        committed. The admin is told the result either way.

        Next, the ``GroupsIds`` row with the smallest ``group_id`` above
        ``self.latest_id`` is presented. The keyboard has one button per
        ``self.new_cats`` entry (three per row), plus "none" and "finish". If
        no such group remains, only a "finish" button is offered.
        """
        user_status = self.db_session.query(self.db.UserStatuses).filter(
            self.db.UserStatuses.user_id == from_id).first()
        # A missing status row means "not an admin". Without the None check
        # this raised AttributeError on ``None.status``.
        if not user_status or user_status.status != 'admin':
            keyboard = VkKeyboard(one_time=True)
            keyboard.add_button('Начать анализ',
                                color=VkKeyboardColor.POSITIVE,
                                payload=json.dumps(
                                    {'button': 'start_analysis'}))
            msg = 'Начнём анализ?'
            self.send_message(from_id, msg, keyboard.get_keyboard())
            return

        if '#' in payload['button']:
            _, gr_id, cat_index = payload['button'].split('#')
            gr_id = int(gr_id)
            if gr_id > self.latest_id:
                self.latest_id = gr_id
                category = (self.new_cats[int(cat_index)]
                            if cat_index != '-1' else 'other')
                old_group = self.db_session.query(self.db.GroupsIds).get(
                    self.latest_id)
                self.db_session.add(
                    self.db.Groups(group_id=self.latest_id,
                                   name=old_group.name,
                                   subject=category,
                                   link=old_group.link))
                msg = (f"{old_group.name} теперь относится к группе "
                       f"{category.capitalize()}")
                # Persist the label now. Nothing else in this handler commits,
                # so it would otherwise depend on an unrelated later commit.
                self.db_session.commit()
            else:
                msg = f'Группа {gr_id} уже была добавлена'
            self.send_message(from_id, msg)

        group = self.db_session.query(self.db.GroupsIds).order_by(
            self.db.GroupsIds.group_id.asc()).filter(
                self.db.GroupsIds.group_id > self.latest_id).first()

        keyboard = VkKeyboard(one_time=True)
        if group is None:
            # Every group has been labelled; without this check the code
            # crashed on ``group.group_id``.
            keyboard.add_button('Завершить',
                                color=VkKeyboardColor.NEGATIVE,
                                payload=json.dumps({'command': 'start'}))
            self.send_message(from_id, 'Все группы уже размечены',
                              keyboard.get_keyboard())
            return

        msg = ('К какой категории относится эта группа?\n'
               f'https://vk.com/club{group.group_id}\n\n')
        for i, category in enumerate(self.new_cats):
            # The loop index is the category index decoded above
            # (self.new_cats[i] is this category).
            keyboard.add_button(
                category.capitalize(),
                color=VkKeyboardColor.SECONDARY,
                payload=json.dumps({
                    'button': f'dataset_filter#{group.group_id}#{i}'
                }))
            # Three category buttons per keyboard row.
            if (i + 1) % 3 == 0:
                keyboard.add_line()
        # Start a fresh row for the control buttons unless the loop just did.
        # Using len() also avoids a NameError on ``i`` when new_cats is empty.
        if len(self.new_cats) % 3 != 0:
            keyboard.add_line()
        keyboard.add_button('Ни к одной',
                            color=VkKeyboardColor.NEGATIVE,
                            payload=json.dumps({
                                'button': f'dataset_filter#{group.group_id}#-1'
                            }))
        keyboard.add_button('Завершить',
                            color=VkKeyboardColor.NEGATIVE,
                            payload=json.dumps({'command': 'start'}))
        self.send_message(from_id, msg, keyboard.get_keyboard())
Ejemplo n.º 13
0
import numpy as np
from app import app, db
from models import Drawings
from flask import render_template, request, jsonify
from model.predictor import Predictor
from helpers.bucket_helper import BucketHelper

# Module-level singletons shared by every request handler below. They are
# constructed once, at import time. The Predictor is pointed at the 'model/'
# directory (presumably where the trained model is stored — TODO confirm).
# BucketHelper is presumably a storage-bucket client — TODO confirm.
predictor = Predictor(path='model/')
bucket_helper = BucketHelper()


@app.route('/', methods=["POST", "GET", "OPTIONS"])
def main_page():
    """Serve the landing page template for any of the allowed methods."""
    template_name = 'index.html'
    return render_template(template_name)


@app.route('/about', methods=["GET"])
def about():
    """Serve the static "about" page template."""
    template_name = 'about.html'
    return render_template(template_name)


@app.route('/prediction_page', methods=["POST"])
def predict_img():
    predictor.decode_image(request)
    predictor.process_image()
    predicted_label, confidence, message = predictor.predict_image()

    response = {
        'message': message,
        'predicted_label': predicted_label,
        'confidence': float(confidence),
def planner(trajectory):
    """Plan a path through the maze by randomly growing a roadmap.

    Nodes are grown from ``origin`` towards uniformly sampled points until a
    node lands inside ``target``. A move is rejected if it leaves the map or
    hits a static wall. Inside ``obstacle_range`` it is also rejected if it
    hits the moving obstacle, whose x position per time step comes from
    ``Predictor.predict(trajectory)`` for ``tasks[105]``. The BFS path from
    origin to the reached node is drawn with matplotlib, and ``plt.show()``
    is called before the policy is returned.

    Args:
        trajectory: passed unchanged to ``predictor.predict``. Its expected
            format is defined by ``Predictor`` (not visible here).

    Returns:
        A list of ``[dx, dy]`` pairs, one per edge of the path. Each
        component is -1, 0 or 1 and gives the sign of the move along that
        axis. Returns ``None`` (before showing the plot) if the final
        obstacle re-check over the path's milestones fails.
    """
    # The map spans [-map_size, map_size] on both axes. One roadmap edge is
    # step_size sub-steps of step_resolution, i.e. agent_step units long.
    map_size = 12
    step_resolution = 0.25
    step_size = 10
    obstacle_width = 10
    obstacle_thickness = 5
    agent_step = step_resolution * step_size
    plt.figure()
    currentAxis = plt.gca()
    plt.axis([-map_size, map_size, -map_size, map_size])
    plt.title("RRG Route Map for the Maze Problem")
    plt.xlabel('X')
    plt.ylabel('Y')
    Node = namedtuple('Node', ['x', 'y'])
    # Predict the moving obstacle for this trajectory. The result is indexed
    # below as [0, step, 0], and steps >= 100 are rejected. So it is
    # presumably (batch, time_step, coordinate), with >= 100 steps and the
    # obstacle's x position in coordinate 0 — TODO confirm against Predictor.
    task_test = tasks[105]
    predictor = Predictor(task=task_test,
                          checkpoint_path='../data/checkpoints/checkpoint_' +
                          task_test.task_name + '.pt')
    _, obstacle_pos_prediction = predictor.predict(trajectory)

    def is_valid_move(start_node, dir_x, dir_y):
        # Return True if walking step_size sub-steps of step_resolution from
        # start_node along (dir_x, dir_y) stays on the map and misses every
        # wall. Inside obstacle_range it must also miss the obstacle at its
        # predicted position for that time step.
        # NOTE(review): each position is checked *before* it is advanced, so
        # the end point reached after the last sub-step is never checked.
        # NOTE(review): `global obstacle` rebinds a module-level name, not
        # planner's local `obstacle` assigned further down.
        global obstacle
        end_x = start_node.x
        end_y = start_node.y
        # Time step at which start_node is reached along the current BFS tree.
        start_step = calculate_step(start_node)

        for s in range(step_size):
            if (end_x < -map_size) or (end_x > map_size) or (
                    end_y < -map_size) or (end_y > map_size):
                return False

            for obstacle in walls:
                if obstacle.interfere_node(end_x, end_y):
                    return False

            if obstacle_range.interfere_node(end_x, end_y):
                cur_step = start_step + s

                # No prediction is used past step 99; the move is treated as
                # invalid (presumably the prediction horizon — TODO confirm).
                if cur_step >= 100:
                    return False

                # Obstacle placed at its predicted x position for cur_step.
                obstacle = Area(obstacle_pos_prediction[0, cur_step, 0].item(),
                                0, obstacle_width, obstacle_thickness)

                if obstacle.interfere_node(end_x, end_y):
                    return False

            end_x += dir_x * step_resolution
            end_y += dir_y * step_resolution

        return True

    def bfs(end_node, plot=False):
        # Breadth-first search over route_dict starting at origin. Returns the
        # node list from origin to end_node. With plot=True, explored nodes
        # and edges are drawn in gray.
        # NOTE(review): raises KeyError if end_node is unreachable, and
        # `child not in queue` is a linear scan of the deque.
        parents = {}
        finished = set()
        queue = deque()
        queue.append(origin)

        while queue:
            parent = queue.popleft()
            finished.add(parent)

            if plot:
                plt.plot([
                    parent.x,
                ], [
                    parent.y,
                ],
                         'o',
                         color='gray',
                         markersize=4)

            for child in route_dict[parent]:
                if child not in finished:
                    if plot:
                        plt.plot([parent.x, child.x], [parent.y, child.y],
                                 color='gray')

                    if child not in queue:
                        queue.append(child)
                        parents[child] = parent

        # Walk the parent links back from end_node, then reverse.
        cur_node = end_node
        path = [end_node]

        while cur_node != origin:
            cur_node = parents[cur_node]
            path.append(cur_node)

        path.reverse()

        return path

    def calculate_step(end_node):
        # Time steps to reach end_node: edges on its BFS path * step_size.
        # Recomputes a full BFS on every call.
        cur_path = bfs(end_node=end_node, plot=False)

        return (len(cur_path) - 1) * step_size

    # Scene: initial obstacle at x=0, two static walls, the goal area, and
    # the horizontal band in which the moving obstacle can appear.
    obstacle = Area(0, 0, obstacle_width, obstacle_thickness)
    wall0 = Area(-8, 7.5, 12, 5)
    wall1 = Area(8, -7.5, 12, 5)
    walls = [wall0, wall1]
    target = Area(-8, 11, 8, 2)
    obstacle_range = Area(0, 0, map_size * 2 + 4, 5)
    origin = Node(2, -11)
    # Adjacency sets of the roadmap, keyed by Node.
    route_dict = {origin: set()}

    done = False
    last_node = None

    # Grow the roadmap until a node lands in the target. There is no iteration
    # cap, so this loops forever if the target is unreachable.
    while not done:
        pos_rand_x = random.uniform(-map_size, map_size)
        pos_rand_y = random.uniform(-map_size, map_size)

        min_distance = float('inf')
        n_nearest = None

        for node in route_dict.keys():
            # NOTE(review): sqrt(|dx| + |dy|) is not the Euclidean distance
            # (squares look missing); it also changes what the agent_step
            # threshold below accepts — confirm intended metric.
            distance = sqrt(
                abs(node.x - pos_rand_x) + abs(node.y - pos_rand_y))

            if min_distance > distance:
                min_distance = distance
                n_nearest = node

        # Discard samples too far (by the metric above) from every node.
        if min_distance > agent_step:
            continue

        # choose_direction is defined elsewhere; its result scales agent_step
        # below, so presumably a per-axis direction in {-1, 0, 1} — confirm.
        direction_x, direction_y = choose_direction(n_nearest.x, n_nearest.y,
                                                    pos_rand_x, pos_rand_y)

        if not is_valid_move(n_nearest, direction_x, direction_y):
            continue

        new_node_x = n_nearest.x + direction_x * agent_step
        new_node_y = n_nearest.y + direction_y * agent_step

        new_node = Node(new_node_x, new_node_y)

        if new_node in route_dict:
            continue

        route_dict[new_node] = set()
        route_dict[n_nearest].add(new_node)
        route_dict[new_node].add(n_nearest)

        # Connect other existing nodes to the new node.
        # NOTE(review): exact float equality, only diagonal neighbours qualify,
        # and the edge is added only p -> new_node (not the reverse, unlike
        # the two lines above).
        for p in route_dict.keys():
            direction_x, direction_y = choose_direction(
                p.x, p.y, new_node.x, new_node.y)

            if abs(p.x - new_node.x) == agent_step and abs(p.y - new_node.y) == agent_step \
                    and is_valid_move(p, direction_x, direction_y):
                route_dict[p].add(new_node)

        if target.interfere_node(new_node_x, new_node_y):
            last_node = new_node
            done = True

    path = bfs(end_node=last_node, plot=True)

    # Re-check milestones in the obstacle band and draw the obstacle at the
    # predicted position for each such milestone.
    # NOTE(review): the collision test uses `obstacle` from the previous
    # milestone (or the initial x=0 one) before it is recomputed for
    # danger_step. danger_step is also not capped at 100 as in
    # is_valid_move — confirm intended order and bounds.
    for milestone in path:
        if obstacle_range.interfere_node(milestone.x, milestone.y):
            danger_step = calculate_step(milestone)

            if obstacle.interfere_node(milestone.x, milestone.y):
                return None

            obstacle = Area(obstacle_pos_prediction[0, danger_step, 0].item(),
                            0, obstacle_width, obstacle_thickness)
            currentAxis.add_patch(
                Rectangle((obstacle.min_x, obstacle.min_y),
                          obstacle.x_width,
                          obstacle.y_width,
                          fill=True,
                          facecolor="blue",
                          alpha=1))

    # Policy: sign of the x and y change along each consecutive pair of
    # path nodes.
    policy = []

    for i in range(len(path) - 1):
        if path[i + 1].x < path[i].x:
            policy_unit_x = -1
        elif path[i + 1].x > path[i].x:
            policy_unit_x = 1
        else:
            policy_unit_x = 0

        if path[i + 1].y < path[i].y:
            policy_unit_y = -1
        elif path[i + 1].y > path[i].y:
            policy_unit_y = 1
        else:
            policy_unit_y = 0

        policy.append([policy_unit_x, policy_unit_y])

    # Draw walls (green), obstacle band (translucent blue), target (purple).
    currentAxis.add_patch(
        Rectangle((wall0.min_x, wall0.min_y),
                  wall0.x_width,
                  wall0.y_width,
                  fill=True,
                  facecolor="green",
                  alpha=1))
    currentAxis.add_patch(
        Rectangle((wall1.min_x, wall1.min_y),
                  wall1.x_width,
                  wall1.y_width,
                  fill=True,
                  facecolor="green",
                  alpha=1))
    currentAxis.add_patch(
        Rectangle((obstacle_range.min_x, obstacle_range.min_y),
                  obstacle_range.x_width,
                  obstacle_range.y_width,
                  fill=True,
                  facecolor="blue",
                  alpha=0.2))
    currentAxis.add_patch(
        Rectangle((target.min_x, target.min_y),
                  target.x_width,
                  target.y_width,
                  fill=True,
                  facecolor="purple",
                  alpha=0.5))

    # Draw the chosen path in orange.
    for i in range(1, len(path)):
        plt.plot([
            path[i].x,
        ], [
            path[i].y,
        ], ' o', color='orange')
        plt.plot([path[i - 1].x, path[i].x], [path[i - 1].y, path[i].y],
                 color='orange',
                 lw=3)

    # Mark the start node with a red square.
    plt.plot([
        origin.x,
    ], [
        origin.y,
    ], 'rs', markersize=10)

    plt.show()

    return policy
Ejemplo n.º 15
0
    for i in range(pos_emb.shape[0]):
        pos_emb[i, i] = 1.0
    
    if args.eval_on_dev:
 
        ner_model = BiLSTM(word_emb, pos_emb, args.word_hid_dim, max(list(label2idx.values()))+1, args.dropout, args.batch, trainable_emb = args.trainable_emb)
        ner_model.rand_init()
        criterion = nn.NLLLoss()
        if args.cuda:
            ner_model.cuda()
        if args.opt == 'adam':
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, ner_model.parameters()), lr=args.lr)
        else:
            optimizer = optim.SGD(filter(lambda p: p.requires_grad, ner_model.parameters()), lr=args.lr, momentum=args.momentum)
        
        predictor = Predictor()
        evaluator = Evaluator(predictor, label2idx, word2idx, pos2idx, args)

        best_scores = []
        best_dev_f1_sum = 0
        patience = 0

        print('\n'*2)
        print('='*10 + 'Phase1, train on train_data, epoch=args.epochs' + '='*10)
        print('\n'*2)
        for epoch in range(1, args.epochs+1):
            
            loss = train_epoch(train_data, ner_model, optimizer, criterion, args)
            print("*"*10 + "epoch:{}, loss:{}".format(epoch, loss) + "*"*10)
            eval_result_train = evaluator.evaluate(train_data, ner_model, args, cuda = args.cuda)
            print("On train_data: ")
Ejemplo n.º 16
0
# NOTE(review): argparse `type=bool` converts any non-empty string to True,
# so `--cuda False` still enables CUDA. `--cpu` is the way to disable it.
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true')
args = parser.parse_args()

# torch is seeded from --seed, but Python's `random` always gets 1234
# (NOTE(review): confirm the fixed seed is intentional).
torch.manual_seed(args.seed)
random.seed(1234)
# --cpu overrides --cuda. The CUDA RNG is seeded only when CUDA stays on.
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)

# load opt
# Read the saved config from the checkpoint and build the Predictor and its
# PredictorTrainer wrapper. Then call load() with the same checkpoint file.
model_file = args.model_dir + '/' + args.model
print("Loading model from {}".format(model_file))
opt = torch_utils.load_config(model_file)
predictor = Predictor(opt)
model = PredictorTrainer(opt, predictor)
model.load(model_file)

# load vocab
# Field definitions (presumably torchtext's legacy `data.Field`). Tokens are
# lowercased and returned with their lengths. RELATION is a single label per
# example, with no unk/pad tokens.
TOKEN = data.Field(sequential=True,
                   batch_first=True,
                   lower=True,
                   include_lengths=True)
RELATION = data.Field(sequential=False, unk_token=None, pad_token=None)
POS = data.Field(sequential=True, batch_first=True)
NER = data.Field(sequential=True, batch_first=True)
PST = data.Field(sequential=True, batch_first=True)

fields = {
    'tokens': ('token', TOKEN),
Ejemplo n.º 17
0
    def run(self):
        """
        The entry point to the main node functionality - measuring default Configuration.
        When the default Configuration finishes its evaluation, the first set of Configurations will be
        sampled for evaluation (respectively, the queues for Configuration measurement results initialize).

        Steps, in order:
        1. Load and validate the experiment description. It comes from
           ``self.experiment_setup``, from ``argv[1]``, or from the bundled
           EnergyExperiment default.
        2. Create the Experiment.
        3. Connect to RabbitMQ and MongoDB and store the initial records.
        4. Start the stop-condition threads, the worker-service client, the
           repeater and the predictor.
        5. Register the queue consumers and publish the default
           Configuration for measurement.
        6. Process RabbitMQ events until ``self._is_interrupted`` becomes true.

        The RabbitMQ connection is closed on exit, including when a setup step
        raises after the connection has been opened.
        """
        self._state = self.State.RUNNING
        self.logger.info("Starting BRISE")
        self.sub.send('log', 'info', message="Starting BRISE")

        if not self.experiment_setup:
            # Check if main.py running with a specified experiment description file path
            if len(argv) > 1:
                exp_desc_file_path = argv[1]
            else:
                exp_desc_file_path = './Resources/EnergyExperiment/EnergyExperiment.json'
                log_msg = f"The Experiment Setup was not provided and the path to an experiment file was not specified." \
                          f" The default one will be executed: {exp_desc_file_path}"
                self.logger.warning(log_msg)
                self.sub.send('log', 'warning', message=log_msg)
            experiment_description, search_space = load_experiment_setup(
                exp_desc_file_path)
        else:
            experiment_description = self.experiment_setup[
                "experiment_description"]
            search_space = self.experiment_setup["search_space"]

        validate_experiment_description(experiment_description)
        os.makedirs(experiment_description["General"]["results_storage"],
                    exist_ok=True)

        # Initializing instance of Experiment - main data holder.
        self.experiment = Experiment(experiment_description, search_space)
        search_space.experiment_id = self.experiment.unique_id
        Configuration.set_task_config(
            self.experiment.description["TaskConfiguration"])

        # initialize connection to rabbitmq service
        # (int(os.getenv(...)) raises TypeError if the port variable is unset)
        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(
                os.getenv("BRISE_EVENT_SERVICE_HOST"),
                int(os.getenv("BRISE_EVENT_SERVICE_AMQP_PORT"))))
        # Everything after the connection is opened runs under try/finally, so
        # the connection is closed even when one of the setup steps raises.
        try:
            self.consume_channel = self.connection.channel()

            # initialize connection to the database
            self.database = MongoDB(os.getenv("BRISE_DATABASE_HOST"),
                                    int(os.getenv("BRISE_DATABASE_PORT")),
                                    os.getenv("BRISE_DATABASE_NAME"),
                                    os.getenv("BRISE_DATABASE_USER"),
                                    os.getenv("BRISE_DATABASE_PASS"))

            # write initial settings to the database
            self.database.write_one_record(
                "Experiment_description",
                self.experiment.get_experiment_description_record())
            self.database.write_one_record(
                "Search_space",
                get_search_space_record(self.experiment.search_space,
                                        self.experiment.unique_id))
            self.experiment.send_state_to_db()

            self.sub.send(
                'experiment',
                'description',
                global_config=self.experiment.description["General"],
                experiment_description=self.experiment.description,
                searchspace_description=self.experiment.search_space.serialize(
                    True))
            self.logger.debug(
                "Experiment description and global configuration sent to the API.")

            # Create and launch Stop Condition services in separate threads.
            launch_stop_condition_threads(self.experiment.unique_id)

            # Instantiate client for Worker Service, establish connection.
            self.wsc_client = WSClient(
                self.experiment.description["TaskConfiguration"],
                os.getenv("BRISE_EVENT_SERVICE_HOST"),
                int(os.getenv("BRISE_EVENT_SERVICE_AMQP_PORT")))

            # Initialize Repeater - encapsulate Configuration evaluation process to avoid results fluctuations.
            # (achieved by multiple Configuration evaluations on Workers - Tasks)
            RepeaterOrchestration(self.experiment)

            self.predictor: Predictor = Predictor(self.experiment.unique_id,
                                                  self.experiment.description,
                                                  self.experiment.search_space)

            self.consume_channel.basic_consume(
                queue='default_configuration_results_queue',
                auto_ack=True,
                on_message_callback=self.get_default_configurations_results)
            self.consume_channel.basic_consume(
                queue='configurations_results_queue',
                auto_ack=True,
                on_message_callback=self.get_configurations_results)
            self.consume_channel.basic_consume(queue='stop_experiment_queue',
                                               auto_ack=True,
                                               on_message_callback=self.stop)
            self.consume_channel.basic_consume(
                queue="get_new_configuration_queue",
                auto_ack=True,
                on_message_callback=self.send_new_configurations_to_measure)

            self.default_config_handler = get_default_config_handler(
                self.experiment)
            temp_msg = "Measuring default Configuration."
            self.logger.info(temp_msg)
            self.sub.send('log', 'info', message=temp_msg)
            default_parameters = self.experiment.search_space.generate_default()
            default_configuration = Configuration(default_parameters,
                                                  Configuration.Type.DEFAULT,
                                                  self.experiment.unique_id)
            default_configuration.experiment_id = self.experiment.unique_id
            dictionary_dump = {"configuration": default_configuration.to_json()}
            body = json.dumps(dictionary_dump)

            self.consume_channel.basic_publish(
                exchange='',
                routing_key='measure_new_configuration_queue',
                body=body)
            # Listen on all response queues until the _is_interrupted flag
            # becomes True.
            while not self._is_interrupted:
                self.consume_channel.connection.process_data_events(
                    time_limit=1)  # 1 second
        finally:
            if self.connection.is_open:
                self.connection.close()
def register_callbacks(app: Dash):
    '''Register callbacks for the given app argument

    Wires these Dash callbacks onto ``app``:

    * ``process_data`` runs on "submit-button" clicks. It calls
      ``Predictor.predict`` for each of the three chosen majors and writes a
      JSON string to ``output.children``.
    * ``render_output`` shows the welcome, error or result section and
      forwards the per-major results to ``major-data``.
    * ``render_first_title`` / ``render_second_title`` /
      ``render_third_title`` fill each major card's header. The card with
      the highest value is marked "Recommended".
    '''
    predictor = Predictor()
    # Card positions, in display order; also the keys of the result dict.
    positions = ("first", "second", "third")
    score_subjects = ("eng", "math", "bio", "chem", "phy", "econ", "geo",
                      "soc", "fin")

    @app.callback(Output("output", "children"),
                  [Input("submit-button", "n_clicks")],
                  [State(f"input-score-{subject}", "value")
                   for subject in score_subjects] +
                  [State(f"input-major-{position}", "value")
                   for position in positions])
    def process_data(n_clicks: int, eng: float, math: float, bio: float,
                     chem: float, phy: float, econ: float, geo: float,
                     soc: float, fin: float, major_first: int,
                     major_second: int, major_third: int) -> str:
        '''Process the data from the provided arguments

        Returns a JSON string:

        * ``{}`` before the first click;
        * ``{"error": ...}`` when any major is still ``-1`` (unselected);
        * otherwise ``{"data": {"first": ..., "second": ..., "third": ...}}``,
          holding the ``predictor.predict`` result for each chosen major.
        '''
        majors = (major_first, major_second, major_third)

        if not n_clicks:
            return json.dumps({})

        if any(major == -1 for major in majors):
            return json.dumps({'error': 'Please fill all the data'})

        scores = dict(eng=eng, math=math, bio=bio, chem=chem, phy=phy,
                      econ=econ, geo=geo, soc=soc, fin=fin)
        # The same scores are used for every major; one prediction per card.
        data = {position: predictor.predict(**scores, major_name=major)
                for position, major in zip(positions, majors)}

        return json.dumps({'data': data})

    @app.callback([
        Output('major-data', 'children'),
        Output('output-welcome', 'className'),
        Output('output-error', 'className'),
        Output('output-error', 'children'),
        Output('output-deck', 'className'),
    ], [Input('output', 'children')])
    def render_output(data: str) -> tuple:
        '''Show the welcome, error or result section for process_data output.

        Returns a tuple of five values: the major-data JSON, the welcome
        class, the error class, the error text and the deck class.
        '''
        payload: dict = json.loads(data)
        error = payload.get('error')
        results = payload.get('data')

        if error:
            return "", "d-none", "d-block", error, "d-none"

        if not results:
            return "", "d-block", "d-none", "", "d-none"

        return json.dumps(results), "d-none", "d-none", "", "d-block"

    def _title_outputs(position: str) -> list:
        '''Dash outputs of the header of the ``position`` card.'''
        return [
            Output(f"{position}-recommended", "children"),
            Output(f"{position}-proba", "children"),
            Output(f"{position}-header", "color"),
            Output(f"{position}-header", "className"),
        ]

    def _render_title(data: str, position: str) -> tuple:
        '''Header values for one major card.

        ``data`` is the JSON object from ``major-data``, keyed by position.
        The card with the largest value is marked "Recommended" and
        highlighted. On ties, the earliest key wins.
        '''
        if not data:
            return "", "", "", ""

        values = json.loads(data)
        best_position = max(values.items(), key=lambda item: item[1])[0]
        recommended = best_position == position
        recommended_str = "Recommended" if recommended else ""
        performance = "%.2f %%" % values[position]
        color = "primary" if recommended else "secondary"
        text_color = "text-white" if recommended else "text-primary"
        return recommended_str, performance, color, text_color

    @app.callback(_title_outputs("first"), [Input("major-data", "children")])
    def render_first_title(data: str) -> tuple:
        """
        Used for rendering the GUI element of the first major choice
        """
        return _render_title(data, "first")

    @app.callback(_title_outputs("second"), [Input("major-data", "children")])
    def render_second_title(data: str) -> tuple:
        """
        Used for rendering the GUI element of the second major choice
        """
        return _render_title(data, "second")

    @app.callback(_title_outputs("third"), [Input("major-data", "children")])
    def render_third_title(data: str) -> tuple:
        """
        Used for rendering the GUI element of the third major choice
        """
        return _render_title(data, "third")
Ejemplo n.º 19
0
        crit_ner.cuda()
        crit_lm.cuda()
        ner_model.cuda()
        packer = CRFRepack_WC(len(tag2idx), True)
    else:
        packer = CRFRepack_WC(len(tag2idx), False)

    if args.start_epoch != 0:
        args.start_epoch += 1
        args.epoch = args.start_epoch + args.epoch
        epoch_list = range(args.start_epoch, args.epoch)
    else:
        args.epoch += 1
        epoch_list = range(1, args.epoch)

    predictor = Predictor(tag2idx, packer, label_seq=True, batch_size=50)
    evaluator = Evaluator(predictor, packer, tag2idx, args.eva_matrix,
                          args.pred_method)

    trainer = Trainer(ner_model, packer, crit_ner, crit_lm, optimizer,
                      evaluator, crf2corpus, args.plateau)
    trainer.train(crf2train_dataloader, crf2dev_dataloader, dev_dataset_loader,
                  epoch_list, args)

    trainer.eval_batch_corpus(dev_dataset_loader, args.dev_file,
                              args.corpus2crf)

    try:
        print("Load from PICKLE")
        single_testset = pickle.load(
            open(args.pickle + "/temp_single_test.p", "rb"))
Ejemplo n.º 20
0
# NOTE(review): argparse `type=bool` converts any non-empty string to True,
# so `--cuda False` still enables CUDA. `--cpu` is the way to disable it.
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true')
args = parser.parse_args()

# torch is seeded from --seed, but Python's `random` always gets 1234
# (NOTE(review): confirm the fixed seed is intentional).
torch.manual_seed(args.seed)
random.seed(1234)
# --cpu overrides --cuda. The CUDA RNG is seeded only when CUDA stays on.
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)

# load opt
# Predictor checkpoint: read its saved config and build the network and its
# PredictorTrainer wrapper. Then call load() with the same file.
model_p_file = args.p_dir + '/' + args.p_model
print("Loading predictor from {}".format(model_p_file))
opt_p = torch_utils.load_config(model_p_file)
predictor = Predictor(opt_p)
model_p = PredictorTrainer(opt_p, predictor)
model_p.load(model_p_file)

# Selector checkpoint, restored the same way.
# NOTE(review): the selector network is instantiated from the Predictor class
# before being wrapped in SelectorTrainer — confirm this is intended.
model_s_file = args.s_dir + '/' + args.s_model
print("Loading selector from {}".format(model_s_file))
opt_s = torch_utils.load_config(model_s_file)
selector = Predictor(opt_s)
model_s = SelectorTrainer(opt_s, selector)
model_s.load(model_s_file)

# load vocab
# Token field (presumably torchtext's legacy `data.Field`): lowercased,
# batch-first, and returned together with sequence lengths.
TOKEN = data.Field(sequential=True,
                   batch_first=True,
                   lower=True,
                   include_lengths=True)
Ejemplo n.º 21
0
# Load the train/test splits and dataset statistics (stats["n_items"] is
# passed to the Predictor below).
X_train, X_test, y_train, y_test = read_train_test_dir(data_dir)
stats = load_dict_output(data_dir, "stats.json")
print("Dataset read complete...")



# Keep only the first n_test test rows. n_test comes from
# get_test_sample_size (presumably the row count rounded to a multiple of
# TEST_BATCH_SIZE — TODO confirm).
n_test = get_test_sample_size(X_test.shape[0], k=TEST_BATCH_SIZE)
X_test = X_test[:n_test, :]
y_test = y_test[:n_test, :]

# Column 0 is written out as 'user_id' below. Column 1 is used as the item
# column. Both are reshaped to (n, 1) column vectors, and so is y_test.
users_test = X_test[:, 0].reshape(-1,1)
items_test = X_test[:, 1].reshape(-1,1)
y_test = y_test.reshape(-1,1)


# Batched inference over the test pairs with `model` (defined outside this
# view).
predictor = Predictor(model=model, batch_size=TEST_BATCH_SIZE, users=users_test, items=items_test, y=y_test,
                      use_cuda=args.cuda, n_items=stats["n_items"])



preds = predictor.predict().reshape(-1,1)


# One row per test pair: user id, prediction, ground truth.
output = pd.DataFrame(np.concatenate((users_test, preds, y_test), axis=1),
                      columns = ['user_id', 'pred', 'y_true'])


if args.task == "choice":

    output, hit_ratio, ndcg = get_choice_eval_metrics(output, at_k=EVAL_K)

    print("hit ratio: {:.4f}".format(hit_ratio))