Exemplo n.º 1
0
def fit_process(uid):
    """
    Fit a user's classifier and store the serialized model in the database.

    ``ftclassifier.fit`` writes the trained model into a temporary file; the
    file's raw bytes are then saved with ``dal.update_classifier`` and the
    transaction is committed. The temporary file is always deleted, whether
    or not fitting/persisting succeeds.

    :param uid: id of the user whose classifier is fitted.
    :raises: re-raises any error from fitting or persisting after rolling
        back the session.
    """
    # Create a temporary model file. Only its path is needed (the fitter
    # writes to the path itself), so close the descriptor right away.
    fd, path = tempfile.mkstemp()
    os.close(fd)

    try:
        # Give this process a dedicated session.
        session = Session()
        try:
            ftclassifier.fit(session, uid, path)
            # sgdclassifier.fit(session, uid, path)

            # Persist the model to the database.
            with open(path, 'rb') as f:
                model_bytes = f.read()
            dal.update_classifier(session, uid, model_bytes)

            session.commit()
        except BaseException:
            session.rollback()
            raise
        finally:
            session.close()
            Session.remove()
    finally:
        # Delete the temporary model file on every exit path so that failed
        # fits do not leave stray files behind.
        os.unlink(path)
Exemplo n.º 2
0
Arquivo: app.py Projeto: cjbvt/SpareME
def populate():
    """
    Add the sample dataset to the database for the requesting user.

    Form fields:
        id_token -- token identifying the user (checked by verify_id_token).

    After the data is committed, ``classifier.fit`` is called for the user.
    Returns a ``(body, status)`` pair: 202 on success, 400 when the token is
    missing, unrecognized, revoked, or otherwise invalid. Database errors are
    re-raised after the transaction is rolled back.
    """
    try:
        uid = verify_id_token(request.form['id_token'])
    except KeyError:
        return "id_token required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "id_token unrecognized", status.HTTP_400_BAD_REQUEST
    except auth.AuthError as auth_error:
        revoked = auth_error.code == 'ID_TOKEN_REVOKED'
        message = "id_token revoked" if revoked else "id_token invalid"
        return message, status.HTTP_400_BAD_REQUEST

    db_session = Session()
    try:
        dal.populate(db_session, uid)
        db_session.commit()
    except BaseException:
        # Undo the partial transaction, then let the error propagate.
        db_session.rollback()
        raise
    finally:
        db_session.close()
        Session.remove()

    classifier.fit(uid)
    return "Sample data added for user", status.HTTP_202_ACCEPTED
Exemplo n.º 3
0
Arquivo: app.py Projeto: cjbvt/SpareME
def stats():
    """
    Return the requesting user's stats as a JSON document.

    Form fields:
        id_token -- token identifying the user (checked by verify_id_token).

    Returns a ``(body, status)`` pair: the JSON-encoded result of
    ``dal.get_stats`` with 200, or a plain-text message with 400 when the
    token is missing, unrecognized, revoked, or otherwise invalid. Database
    errors are re-raised after the transaction is rolled back.
    """
    try:
        uid = verify_id_token(request.form['id_token'])
    except KeyError:
        return "id_token required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "id_token unrecognized", status.HTTP_400_BAD_REQUEST
    except auth.AuthError as auth_error:
        revoked = auth_error.code == 'ID_TOKEN_REVOKED'
        message = "id_token revoked" if revoked else "id_token invalid"
        return message, status.HTTP_400_BAD_REQUEST

    db_session = Session()
    try:
        user_stats = dal.get_stats(db_session, uid)
        db_session.commit()
    except BaseException:
        # Undo the transaction, then let the error propagate.
        db_session.rollback()
        raise
    finally:
        db_session.close()
        Session.remove()

    body = json.dumps(user_stats)
    return body, status.HTTP_200_OK
Exemplo n.º 4
0
Arquivo: app.py Projeto: cjbvt/SpareME
def reset():
    """
    Delete all of the requesting user's data from the database.

    Form fields:
        id_token -- token identifying the user (checked by verify_id_token).

    Returns a ``(body, status)`` pair: 202 once ``dal.delete`` has been
    committed, or 400 when the token is missing, unrecognized, revoked, or
    otherwise invalid. Database errors are re-raised after rollback.
    """
    try:
        uid = verify_id_token(request.form['id_token'])
    except KeyError:
        return "id_token required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "id_token unrecognized", status.HTTP_400_BAD_REQUEST
    except auth.AuthError as auth_error:
        revoked = auth_error.code == 'ID_TOKEN_REVOKED'
        message = "id_token revoked" if revoked else "id_token invalid"
        return message, status.HTTP_400_BAD_REQUEST

    db_session = Session()
    try:
        dal.delete(db_session, uid)
        db_session.commit()
    except BaseException:
        # Undo the partial deletion, then let the error propagate.
        db_session.rollback()
        raise
    finally:
        db_session.close()
        Session.remove()

    return "User data deleted", status.HTTP_202_ACCEPTED
Exemplo n.º 5
0
Arquivo: app.py Projeto: cjbvt/SpareME
def predict():
    """
    Predict a label for every value in the given mapping of unlabeled text.

    Form fields:
        id_token       -- token identifying the user (checked by
                          verify_id_token).
        unlabeled_text -- JSON object mapping caller-chosen keys to the
                          texts to classify.

    Returns a ``(body, status)`` pair. On success the body is a JSON object
    mapping each key of ``unlabeled_text`` to its predicted label (200).
    A missing or invalid token, or a missing, malformed, or non-object
    ``unlabeled_text``, yields a plain-text message with 400. Database
    errors are re-raised after the transaction is rolled back.
    """
    try:
        id_token = request.form['id_token']
        uid = verify_id_token(id_token)
    except KeyError:
        return "id_token required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "id_token unrecognized", status.HTTP_400_BAD_REQUEST
    except auth.AuthError as exc:
        if exc.code == 'ID_TOKEN_REVOKED':
            return "id_token revoked", status.HTTP_400_BAD_REQUEST
        else:
            return "id_token invalid", status.HTTP_400_BAD_REQUEST
    try:
        unlabeled_text = json.loads(request.form['unlabeled_text'])
    except KeyError:
        return "unlabeled_text required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "unlabeled_text unrecognized", status.HTTP_400_BAD_REQUEST
    # Valid JSON that is not an object (a list, string, number, ...) has no
    # keys/values to pair up. Reject it as a client error instead of letting
    # it surface as an AttributeError (HTTP 500) below.
    if not isinstance(unlabeled_text, dict):
        return "unlabeled_text unrecognized", status.HTTP_400_BAD_REQUEST

    # Fix the key order once so keys and predictions stay aligned.
    keys = list(unlabeled_text)
    texts = [unlabeled_text[key] for key in keys]

    session = Session()
    try:
        predicted_labels = classifier.predict(session, uid, texts)
        session.commit()
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()
        Session.remove()
    predictions = dict(zip(keys, predicted_labels))
    return json.dumps(predictions), status.HTTP_200_OK
Exemplo n.º 6
0
Arquivo: app.py Projeto: cjbvt/SpareME
def add():
    """
    Store a labeled piece of text for the requesting user and refit.

    Form fields:
        id_token -- token identifying the user (checked by verify_id_token).
        label    -- label text to attach.
        text     -- the text being labeled.

    After ``dal.add_labeled_text`` is committed, ``classifier.fit`` is called
    for the user. Returns a ``(body, status)`` pair: 202 on success, 400 when
    any field is missing or the token is unrecognized, revoked, or otherwise
    invalid. Database errors are re-raised after rollback.
    """
    try:
        uid = verify_id_token(request.form['id_token'])
    except KeyError:
        return "id_token required", status.HTTP_400_BAD_REQUEST
    except ValueError:
        return "id_token unrecognized", status.HTTP_400_BAD_REQUEST
    except auth.AuthError as auth_error:
        revoked = auth_error.code == 'ID_TOKEN_REVOKED'
        message = "id_token revoked" if revoked else "id_token invalid"
        return message, status.HTTP_400_BAD_REQUEST

    # Each field is checked separately so the error names the missing one.
    try:
        label_value = request.form['label']
    except KeyError:
        return "label required", status.HTTP_400_BAD_REQUEST

    try:
        text_value = request.form['text']
    except KeyError:
        return "text required", status.HTTP_400_BAD_REQUEST

    db_session = Session()
    try:
        dal.add_labeled_text(db_session, uid, label_value, text_value)
        db_session.commit()
    except BaseException:
        # Undo the partial insert, then let the error propagate.
        db_session.rollback()
        raise
    finally:
        db_session.close()
        Session.remove()

    classifier.fit(uid)
    return "Labeled text added for user", status.HTTP_202_ACCEPTED
Exemplo n.º 7
0
    def execute_batches(batches, quiet = False):
        """
        Run the jobs of each batch concurrently, one batch after another.

        Each batch is added to the session and committed, then its jobs are
        turned into awaitables via ``job.task(thread_pool=pool)`` and started
        on the event loop. At most ``ADMINWARE_MAX_SSH`` (default 100) jobs
        are active at once, and starts are spaced ``ADMINWARE_START_DELAY``
        (default 0.2) seconds apart. While this runs, SIGINT cancels every
        task on the loop, which stops the remaining batches. Cleanup (pool
        shutdown, final commit, ``Session.remove()``) runs in a ``finally``.

        :param batches: iterable of batch models to execute.
        :param quiet: when true, suppress all progress messages.
        """
        def run_print(string):
            if quiet: return
            click.echo(string)

        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running on newer Pythons; consider an explicit loop if this moves to 3.12+.
        loop = asyncio.get_event_loop()

        # asyncio.Task.all_tasks() was removed in Python 3.9. Prefer the
        # module-level asyncio.all_tasks() (3.7+) and fall back to the
        # classmethod on older interpreters.
        all_tasks = getattr(asyncio, 'all_tasks', None) or asyncio.Task.all_tasks

        def handler_interrupt():
            run_print('Interrupt Received! ')
            run_print('Cancelling the jobs...')
            for task in all_tasks(loop = loop):
                task.cancel()

        # setdefault also writes the default back into os.environ.
        max_ssh = int(os.environ.setdefault('ADMINWARE_MAX_SSH', '100'))
        start_delay = float(os.environ.setdefault('ADMINWARE_START_DELAY', '0.2'))
        pool = concurrent.futures.ThreadPoolExecutor(max_workers = max_ssh)

        async def start_tasks(tasks):
            active_tasks = []
            def remove_done_tasks():
                # Drops at most one finished task per call; breaking right
                # after remove() avoids mutating the list mid-iteration.
                for active_task in active_tasks:
                    if active_task.finished():
                        active_tasks.remove(active_task)
                        break

            async def add_tasks():
                for task in tasks:
                    # Wait for a free slot so that at most max_ssh tasks
                    # are active at the same time.
                    while len(active_tasks) >= max_ssh:
                        remove_done_tasks()
                        await asyncio.sleep(0.01)
                    asyncio.ensure_future(task, loop = loop)
                    active_tasks.append(task)
                    run_print('Starting Job: {}'.format(task.node))
                    await asyncio.sleep(start_delay)

            async def await_finished():
                while len(active_tasks) > 0:
                    remove_done_tasks()
                    await asyncio.sleep(0.01)

            try:
                await add_tasks()
                run_print('Waiting for jobs to finish...')
            # Even if adding tasks is interrupted, keep polling until every
            # task that was already started reports finished.
            finally: await await_finished()

        session = Session()
        # Install the interrupt handler only for the duration of the run; it
        # is removed again in the finally block below.
        loop.add_signal_handler(signal.SIGINT, handler_interrupt)
        try:
            for batch in batches:
                session.add(batch)
                session.commit()
                # Ensure the models are loaded from the db
                batch.jobs
                batch.shell_variables
                run_print('Executing: {}'.format(batch.name()))
                tasks = map(lambda j: j.task(thread_pool = pool), batch.jobs)
                loop.run_until_complete(start_tasks(tasks))
        # Since Python 3.8 asyncio.CancelledError no longer derives from
        # concurrent.futures.CancelledError, so both must be caught for an
        # interrupt-driven cancellation to end the run quietly.
        except (asyncio.CancelledError, concurrent.futures.CancelledError):
            pass
        finally:
            try:
                run_print('Cleaning up...')
                pool.shutdown(wait = True)
                run_print('Saving...')
                session.commit()
                Session.remove()
                run_print('Done')
            finally:
                loop.remove_signal_handler(signal.SIGINT)
Exemplo n.º 8
0
def remove_session(sender, response, **extra):
    """
    Discard the current ``Session`` by calling ``Session.remove()``.

    All arguments are ignored. The ``(sender, response, **extra)`` signature
    looks like a Flask/blinker ``request_finished`` receiver.
    NOTE(review): confirm which signal this is connected to. Presumably
    ``Session`` is a SQLAlchemy ``scoped_session``, so this releases the
    per-request session.
    """
    Session.remove()