コード例 #1
0
class GoogleAnalyticsBase(GoogleOAuth):
    """Shared properties and URL parameters for Google Analytics blocks.

    Concrete subclasses choose the Analytics API endpoint by overriding
    :meth:`get_url_suffix`.
    """

    # Analytics IDs to query; the default looks like a placeholder that must
    # be replaced with a real ID.
    queries = ListProperty(StringType,
                           title="Analytics IDs",
                           default=["ga:########"])
    metrics = ListProperty(StringType,
                           title="Analytics Metrics",
                           default=["ga:hits"])

    def get_google_scope(self):
        """ Required override for GoogleOAuth Block """
        return 'https://www.googleapis.com/auth/analytics.readonly'

    def get_url_suffix(self):
        """ Required override for GoogleOAuth Block

        Raises:
            NotImplementedError: always; subclasses supply the API path.
        """
        # ``return NotImplemented`` is meant for binary dunder methods and
        # would silently leak a sentinel object into URL building; an abstract
        # hook should fail loudly instead.
        raise NotImplementedError(
            "{} must override get_url_suffix".format(type(self).__name__))

    def get_url_parameters(self):
        """ Required override for GoogleOAuth Block

        Returns:
            dict: the ``ids`` (URL-unquoted ``self.current_query``) and the
            comma-joined ``metrics``, updated with ``get_addl_params()``.
        """
        # NOTE(review): current_query is set outside this class, presumably
        # to one entry of ``queries`` by the polling parent — confirm.
        params = {
            "ids": unquote(self.current_query),
            "metrics": ",".join(self.metrics())
        }

        # Include any additional parameters from the parent block
        params.update(self.get_addl_params())

        self.logger.debug("Accessing Analytics API using {0}".format(params))
        return params
コード例 #2
0
class SpreadsheetLookUp(EnrichSignals, Block):
    """Enrich signals with values looked up from an .xlsx spreadsheet.

    On start, every data row of the 'Sheet1' worksheet in ``source`` is
    indexed into a nested dict keyed by the columns named in ``match`` (one
    nesting level per match entry). Each leaf holds the columns selected by
    ``output_structure``. Incoming signals are resolved through the same keys.
    """

    source = StringProperty(title='Source File', order=0)
    match = ListProperty(Match, title='Match', order=1)
    output_structure = ListProperty(OutputStructure,
                                    title='Output Structure',
                                    order=2)

    version = VersionProperty('0.1.0')

    def start(self):
        """Read the workbook and build the nested lookup map."""
        self.map = {}
        rows = self._read_xlsx(xlrd.open_workbook(self.source()))
        for row in rows:
            value = {}
            for item in self.output_structure():
                # empty cells are stored as None rather than ''
                value[item.key()] = row[item.value()] or None
            self.nested_set(self.map,
                            [row[match.value()] for match in self.match()],
                            value)
        super().start()

    def process_signal(self, signal):
        """Walk the lookup map with each match key evaluated on *signal*.

        Raises KeyError when the signal's keys are not in the spreadsheet.
        """
        signal_dict = self.map
        for match in self.match():
            signal_dict = signal_dict[match.key(signal)]
        return self.get_output_signal(signal_dict, signal)

    @staticmethod
    def nested_set(dic, keys, value):
        """Set ``dic[k0][k1]...[kN] = value``, creating intermediate dicts."""
        # https://stackoverflow.com/a/13688108/11653218
        for key in keys[:-1]:
            dic = dic.setdefault(key, {})
        dic[keys[-1]] = value

    @staticmethod
    def _cell_to_str(cell):
        """Convert one xlrd cell to the string form used in the lookup map."""
        if cell.ctype in (2, 3):
            # number (2) and date (3) cells are floats in xlrd; drop a
            # trailing ".0" so 42.0 is stored as "42"
            if int(cell.value) == cell.value:
                return str(int(cell.value))
            return str(cell.value)
        if isinstance(cell.value, str):
            # text, empty and blank cells
            return cell.value.strip()
        # boolean (4) and error (5) cells hold ints, which have no .strip()
        return str(cell.value)

    @staticmethod
    def _read_xlsx(source, sheet='Sheet1'):
        """Return the data rows of *sheet* as dicts keyed by the header row.

        All cell values are converted to stripped strings.
        """
        open_sheet = source.sheet_by_name(sheet)
        labels = [n.value for n in open_sheet.row(0)]
        rows = []
        for i, row in enumerate(open_sheet.get_rows()):
            if i == 0:
                continue  # header row, already consumed as labels
            cells = [SpreadsheetLookUp._cell_to_str(cell) for cell in row]
            rows.append(dict(zip(labels, cells)))
        return rows
コード例 #3
0
class JWTCreate(EnrichSignals, JWTBase):
    """Create a signed JWT for each incoming signal.

    The token payload holds the configured claims. When ``exp_minutes``
    evaluates to an int, it also holds an ``exp`` claim from
    ``set_new_exp_time``. Successful signals carry a ``token`` attribute.
    Encoding failures go to the ``error`` output with a ``message`` attribute.
    """

    version = VersionProperty('0.1.0')
    exp_minutes = Property(title='Valid For Minutes (blank for no exp claim)',
                           order=3,
                           allow_none=True)
    claims = ListProperty(ClaimField, title='Claims', order=4, allow_none=True)

    def process_signal(self, signal, input_id=None):
        _key = self.key(signal)
        _algorithm = self.algorithm(signal)
        _exp_minutes = self.exp_minutes(signal)
        # claims allows None; treat it as "no extra claims"
        _claims = self.claims(signal) or []
        _newclaims = {}

        try:
            if isinstance(_exp_minutes, int):
                _newclaims['exp'] = self.set_new_exp_time(_exp_minutes)
            for claim in _claims:
                name = claim.name(signal)
                # 'exp' is reserved for exp_minutes; compare by value, not
                # identity (``is not 'exp'`` depended on string interning)
                if name != 'exp':
                    _newclaims[name] = claim.value(signal)

            _token = jwt.encode(_newclaims, _key,
                                algorithm=_algorithm.value)
            # PyJWT < 2 returns bytes, PyJWT >= 2 already returns str
            if isinstance(_token, bytes):
                _token = _token.decode('UTF-8')
            return self.get_output_signal({'token': _token}, signal)

        # jwt.encode throws ValueError if key is in wrong format
        except (PyJWTError, ValueError) as e:
            message = e.args[0] if e.args else str(e)
            self.notify_signals(
                self.get_output_signal({'message': message}, signal),
                'error')
コード例 #4
0
class NetworkConfig(PropertyHolder):
    """Settings that describe how the network is shaped and trained."""

    # One Dimensions entry per tensor axis. NOTE(review): the default
    # (-1, 28, 28, 1) looks like a batch of 28x28 single-channel images with
    # an open batch size — confirm against the model code.
    input_dim = ListProperty(Dimensions,
                             default=[{'value': size}
                                      for size in (-1, 28, 28, 1)],
                             title='Input Tensor Shape')
    learning_rate = FloatProperty(default=0.01, title='Learning Rate')
    loss = SelectProperty(LossFunctions,
                          default=LossFunctions.cross_entropy,
                          title='Loss Function')
    optimizer = SelectProperty(Optimizers,
                               default=Optimizers.GradientDescentOptimizer,
                               title="Optimizer")
    dropout = FloatProperty(default=0,
                            title='Dropout Percentage During Training')
    # Hidden from the UI; None means no fixed seed is supplied.
    random_seed = Property(visible=False,
                           allow_none=True,
                           default=None,
                           title="Random Seed")
コード例 #5
0
class MSSQLConditions(object):
    """Mixin that renders configured conditions into a SQL WHERE clause."""

    conditions = ListProperty(Conditions,
                              title='Conditions',
                              default=[],
                              order=15)
    combine_condition = SelectProperty(AndOrOperator,
                                       title='Combine Condition',
                                       default="AND",
                                       order=14)

    def _get_where_conditions(self, signal, table, cursor):
        """Build a parameterized WHERE clause for *signal*.

        Returns:
            tuple: (clause, params). The clause is '' when no conditions are
            configured, otherwise ' WHERE col op ? <AND|OR> ...'. ``params``
            lists the condition values in placeholder order.
        """
        joiner = ' {} '.format(self.combine_condition().value)
        clauses = []
        params = []
        for cond in self.conditions():
            # column names are validated, values go through '?' placeholders
            checked = self.validate_column(cond.column(signal), table, cursor)
            clauses.append(
                '{} {} ?'.format(checked, cond.operation(signal).value))
            params.append(cond.value(signal))
        if not clauses:
            return "", params
        return ' WHERE ' + joiner.join(clauses), params
コード例 #6
0
class When(Block):
    """Route each signal through the first Case whose ``when`` equals the
    evaluated ``subject``.

    NOTE(review): this definition looks truncated — ``else_signals`` is never
    filled and neither list is notified in the visible code.
    """
    subject = Property(default=None, title='Subject', allow_none=True, order=0)
    cases = ListProperty(Case, title='Cases', default=[], order=1)
    version = VersionProperty('0.1.0')

    def process_signals(self, in_sigs):
        """Apply the first matching case's attributes to each signal."""
        then_signals = []
        else_signals = []

        for signal in in_sigs:

            subject = self.subject(signal)
            # cases are tried in order; only the first match is applied
            for case in self.cases():
                if subject != case.when(signal):
                    continue

                # exclude -> start from an empty Signal, otherwise mutate the
                # incoming signal in place
                sig = Signal() if case.exclude(signal) else signal

                for attr in case.attributes():
                    # title and formula are evaluated against the original
                    # incoming signal
                    title = attr.title(signal)
                    value = attr.formula(signal)
                    setattr(sig, title, value)

                then_signals.append(sig)
                break
コード例 #7
0
class GetEncodingFromFile(Block):
    """Build serialized face-encoding records from image files.

    For each signal, every path in ``image_paths`` is loaded with
    ``face_recognition``. The first face encoding of each image is pickled
    and base64-encoded. One output signal is emitted per input signal, with
    ``user_id``, ``name`` and the list of serialized encodings.
    """

    image_paths = ListProperty(StringType, title='Image Path', default=[])
    # keyword was misspelled ``defult``, which is not the ``default`` option
    uid = StringProperty(title='User ID', default='')
    sname = StringProperty(title='Save Name', default='')
    version = VersionProperty("2.1.0")

    def save_encoding(self, file_path, save_name, user_id):
        """Return a record with the serialized encodings of *file_path*.

        Args:
            file_path (list): image paths to encode.
            save_name (str): stored under 'name'.
            user_id (str): stored under 'user_id'.
        """
        serialized_encoding = []

        for f in file_path:
            image = face_recognition.load_image_file(f)
            # NOTE(review): raises IndexError when no face is found in f
            face_encoding = face_recognition.face_encodings(image)[0]
            serialized_encoding.append(
                base64.b64encode(pickle.dumps(face_encoding)).decode())

        entry = {
            'user_id': user_id,
            'name': save_name,
            'encoding': serialized_encoding
        }

        return entry

    def process_signals(self, signals):
        add_face_signals = []
        for signal in signals:
            confirmation = self.save_encoding(self.image_paths(signal),
                                              self.sname(signal),
                                              self.uid(signal))
            add_face_signals.append(Signal(confirmation))

        self.notify_signals(add_face_signals)
class AttributeSelector(Block):
    """
    Keep or drop selected attributes of incoming signals and notify the
    filtered signals.

    Properties:
    mode(select): WHITELIST keeps only the listed attributes, BLACKLIST
                  keeps everything except them
    attributes(list): names of the incoming signal attributes the mode
                      applies to
    """

    version = VersionProperty("1.1.0")
    mode = SelectProperty(Behavior,
                          title='Selector Mode',
                          default=Behavior.BLACKLIST,
                          order=0)
    attributes = ListProperty(StringType,
                              title='Incoming signal attributes',
                              default=[],
                              order=1)

    def process_signals(self, signals):
        outgoing = []
        for incoming in signals:
            # hidden attributes are included so they can be selected as well
            as_dict = incoming.to_dict(include_hidden=True)
            selected = set(self.attributes(incoming))
            present = set(as_dict.keys()).intersection(selected)

            if self.mode() is Behavior.WHITELIST:
                self.logger.debug('whitelisting...')
                absent = selected.difference(present)
                if absent:
                    self.logger.warning(
                        'specified an attribute that is not in the '
                        'incoming signal: {}'.format(absent))
                filtered = Signal({name: as_dict[name] for name in present})
                self.logger.debug(
                    'Allowing incoming attributes: {}'.format(present))
            elif self.mode() is Behavior.BLACKLIST:
                self.logger.debug('blacklisting...')
                filtered = Signal({
                    name: value
                    for name, value in as_dict.items()
                    if name not in present
                })
                self.logger.debug(
                    'Ignoring incoming attributes: {}'.format(present))

            outgoing.append(filtered)

        self.notify_signals(outgoing)
コード例 #9
0
class Case(PropertyHolder):
    """One branch of a When block: the value to match, whether to start from
    an empty signal, and the attributes to set on a match."""

    when = Property(title='When', order=0, default='')
    attributes = ListProperty(SignalField,
                              order=2,
                              default=[],
                              title="Attributes")
    exclude = BoolProperty(title='Exclude existing attributes?',
                           order=1,
                           default=False)
コード例 #10
0
class MSSQLUpdate(MSSQLTabledBase, MSSQLConditions):
    """Run one parameterized UPDATE per signal and report the rows touched.

    All statements of a batch share one cursor and are committed together.
    A single signal with the total ``Rows updated`` is notified.
    """

    version = VersionProperty("1.0.1")
    column_values = ListProperty(ColumnValue,
                                 title='Column Values',
                                 default=[],
                                 order=2)

    def process_signals(self, signals):
        if self.is_connecting:
            self.logger.error(
                'Connection already in progress. Dropping signals.')
            return

        cursor = self._get_cursor()
        try:
            # the table name does not depend on the signal
            table = self.table()
            total_rows = 0
            for signal in signals:
                # determine query to execute
                column_values, params = \
                    self._get_column_values(signal, table, cursor)
                conditions, where_params = \
                    self._get_where_conditions(signal, table, cursor)
                params.extend(where_params)
                update = \
                    'UPDATE {} SET {}'.format(table, column_values) + conditions
                self.logger.debug('Executing: {} with params {}'.format(
                    update, params))

                row_count = cursor.execute(update, params).rowcount
                self.logger.debug('{} rows returned for signal: {}'.format(
                    row_count, signal.to_dict()))
                total_rows += row_count

            self.logger.debug('Rows updated: {}'.format(total_rows))

            # commit only when every statement succeeded
            cursor.commit()
        finally:
            # release the cursor even if a statement raised
            cursor.close()

        self.notify_signals([Signal({'Rows updated': total_rows})])

    def _get_column_values(self, signal, table, cursor):
        """Return ('col1 = ?, col2 = ?', [value1, value2]) for *signal*."""
        assignments = []
        params = []
        for column_value in self.column_values():
            column = self.validate_column(column_value.column(signal), table,
                                          cursor)
            assignments.append('{} = ?'.format(column))
            params.append(column_value.value(signal))

        return ', '.join(assignments), params
コード例 #11
0
class UnpackBytes(EnrichSignals, Block):
    """Decode raw byte attributes of incoming signals into numbers.

    Each ``new_attributes`` entry names a bytes value, a format and an
    endianness. The value is decoded with ``unpack`` and stored under the
    entry's key. Signals for which nothing could be decoded are dropped.
    """

    new_attributes = ListProperty(NewAttributes,
                                  title='New Signal Attributes',
                                  default=[{
                                      'format': 'integer',
                                      'endian': 'big',
                                      'key': '{{ $key }}',
                                      'value': '{{ $value }}'
                                  }])
    version = VersionProperty("0.1.2")

    @staticmethod
    def _format_char(_type, length):
        """Return the struct format character for *length* bytes of *_type*
        ('int', 'uint' or 'float'), or None if unsupported."""
        if _type in ('int', 'uint'):
            fmt_char = {2: 'h', 4: 'i', 8: 'q'}.get(length)
            # unsigned variants are upper-case; an unsupported length must
            # stay None rather than crash on None.upper()
            if fmt_char is not None and _type == 'uint':
                fmt_char = fmt_char.upper()
            return fmt_char
        if _type == 'float':
            # 'e' (half precision) was added in python 3.6
            return {2: 'e', 4: 'f', 8: 'd'}.get(length)
        return None

    def process_signals(self, signals):
        outgoing_signals = []
        for signal in signals:
            new_signal_dict = {}
            for attr in self.new_attributes():
                _bytes = attr.value(signal)
                _type = attr.format(signal).value
                _endian = attr.endian(signal).value
                fmt_char = self._format_char(_type, len(_bytes))
                if fmt_char is None:
                    self.logger.error('cannot unpack {} bytes into {}'.format(
                        len(_bytes), _type))
                    continue
                # NOTE(review): assumes the endian value is a struct
                # byte-order prefix such as '>' or '<' — confirm
                fmt = _endian + fmt_char
                try:
                    value = unpack(fmt, _bytes)[0]
                except error as e:
                    if e.args and e.args[-1] == 'bad char in struct format':
                        self.logger.error('Python >= 3.6 is required to '
                                          'unpack 2 bytes into a float')
                    raise
                new_signal_dict[attr.key(signal)] = value
            if new_signal_dict:
                new_signal = self.get_output_signal(new_signal_dict, signal)
                outgoing_signals.append(new_signal)
        self.notify_signals(outgoing_signals)
コード例 #12
0
class MSSQLRawQuery(EnrichSignals, MSSQLBase):
    """Execute a user-supplied parameterized query once per signal.

    Result rows become signals on the 'results' output. Statements without a
    result set are committed and report 'affected_rows'. When nothing at all
    is produced, one {'results': None} signal goes to 'no_results'.
    """

    version = VersionProperty("1.0.1")
    query = StringProperty(
        title='Parameterized Query (use ? for any user-supplied values)',
        default='SELECT * FROM table where id=?',
        order=1)
    params = ListProperty(ParamField,
                          title='Substitution Parameters (In Order)',
                          default=[],
                          order=2)

    def process_signals(self, signals, **kwargs):
        if self.is_connecting:
            self.logger.error(
                'Connection already in progress. Dropping signals.')
            return
        if not signals:
            # nothing to query or enrich; the no_results path needs signals[0]
            return

        cursor = self._get_cursor()
        output_signals = []
        try:
            for signal in signals:
                _query = self.query()
                _params = [param.parameter(signal)
                           for param in self.params(signal)]

                if _params:
                    result = cursor.execute(_query, _params)
                else:
                    result = cursor.execute(_query)

                try:
                    rows = result.fetchall()
                except Exception:
                    # presumably raised for statements with no result set
                    # (INSERT/UPDATE/DELETE) — commit and report the rowcount
                    cursor.commit()
                    output_signals.append(
                        self.get_output_signal(
                            {'affected_rows': result.rowcount}, signal))
                    continue

                columns = [r[0] for r in cursor.description]
                for row in rows:
                    output_signals.append(
                        self.get_output_signal(dict(zip(columns, row)),
                                               signal))
        finally:
            cursor.close()

        if output_signals:
            self.notify_signals(output_signals, output_id='results')
        else:
            output_signals.append(
                self.get_output_signal({'results': None}, signals[0]))
            self.notify_signals(output_signals, output_id='no_results')
コード例 #13
0
class XeroManualJournals(Block):
    """Post the configured manual journals to Xero for every incoming signal.

    NOTE(review): process_signals collects the API responses in
    ``response_signal`` but never notifies them in the visible code — the
    definition may be truncated.
    """
    manual_journal_entries = ListProperty(ManualJournals,
                                          title='Manual Journal Entries',
                                          default=[])
    version = VersionProperty("0.1.3")
    consumer_key = StringProperty(title='Xero Consumer Key',
                                  default='[[XERO_CONSUMER_KEY]]',
                                  allow_none=False)

    def __init__(self):
        # client objects are created in configure()
        self.xero = None
        self.credentials = None
        super().__init__()

    def configure(self, context):
        """Load the RSA private key and create the Xero client."""
        super().configure(context)

        con_key = self.consumer_key()
        # relative path: resolved against the process working directory
        with open('blocks/xero/keys/privatekey.pem') as keyfile:
            rsa_private_key = keyfile.read()

        self.credentials = PrivateCredentials(con_key, rsa_private_key)
        self.xero = Xero(self.credentials)

    def start(self):
        super().start()

    def process_signals(self, signals):
        """PUT one POSTED manual journal per configured entry per signal."""
        response_signal = []
        for signal in signals:

            for man_jour in self.manual_journal_entries():
                line_items_list = []
                for jour_line in man_jour.journal_lines():
                    # NOTE(review): only line_amount is evaluated against the
                    # signal; description and account code are not
                    line_items_list.append({
                        'Description':
                        jour_line.line_description(),
                        'LineAmount':
                        jour_line.line_amount(signal),
                        'AccountCode':
                        jour_line.account_code()
                    })
                # the first element of the put() response is wrapped in a
                # Signal
                response_signal.append(
                    Signal(
                        self.xero.manualjournals.put({
                            'Narration':
                            man_jour.narration(signal),
                            'Status':
                            'POSTED',
                            'JournalLines':
                            line_items_list
                        })[0]))
コード例 #14
0
class GoogleAnalyticsRealtime(GoogleAnalyticsBase):
    """Google Analytics block that queries the realtime data endpoint."""

    # Redeclared so the default metric is a realtime ("rt:") metric.
    metrics = ListProperty(StringType,
                           title="Analytics Metrics",
                           default=["rt:activeUsers"])
    dimensions = ListProperty(StringType,
                              title="Analytics Dimensions",
                              default=["rt:city"])
    version = VersionProperty("1.0.1")

    def get_url_suffix(self):
        """ Required override for GoogleOAuth Block """
        suffix = 'analytics/v3/data/realtime'
        return suffix

    def get_url_parameters(self):
        """Base URL parameters plus the comma-joined ``dimensions``."""
        url_params = super().get_url_parameters()
        url_params.update(dimensions=",".join(self.dimensions()))
        return url_params
コード例 #15
0
class TwilioSMS(TerminatorBlock):
    """Send an SMS to every configured recipient for each incoming signal.

    Each recipient is messaged from its own thread. A failed Twilio request
    is retried once.
    """

    recipients = ListProperty(Recipient, title='Recipients', default=[])
    creds = ObjectProperty(TwilioCreds, title='Credentials')
    from_ = StringProperty(title='From', default='[[TWILIO_NUMBER]]')
    message = Property(title='Message', default='')
    version = VersionProperty("1.0.0")

    def __init__(self):
        super().__init__()
        self._client = None

    def configure(self, context):
        super().configure(context)
        self._client = TwilioRestClient(self.creds().sid(),
                                        self.creds().token())

    def process_signals(self, signals):
        for s in signals:
            self._send_sms(s)

    def _send_sms(self, signal):
        """Evaluate the message and start one sender thread per recipient."""
        try:
            message = self.message(signal)

            for rcp in self.recipients():
                Thread(target=self._broadcast_msg, args=(rcp, message)).start()

        except Exception as e:
            self.logger.error("Message evaluation failed: {0}: {1}".format(
                type(e).__name__, str(e)))

    def _broadcast_msg(self, recipient, message, retry=False):
        """Send '<name>: <message>' to *recipient*, retrying once on a
        TwilioRestException."""
        body = "%s: %s" % (recipient.name(), message)
        try:
            # Twilio sends back some useless XML; the response is discarded.
            self._client.messages.create(to=recipient.number(),
                                         from_=self.from_(),
                                         body=body)
        except TwilioRestException as e:
            # %s formats any status value, so logging cannot itself raise
            # (%d would raise TypeError for a non-int status)
            self.logger.error("Status %s" % e.status)
            if not retry:
                self.logger.debug("Retrying failed request")
                self._broadcast_msg(recipient, message, True)
            else:
                self.logger.error("Retry request failed")
        except Exception as e:
            self.logger.error("Error sending SMS to %s (%s): %s" %
                              (recipient.name(), recipient.number(), e))
コード例 #16
0
class MongoDBAggregation(MongoDBBase):
    """ A block for finding and grouping multiple documents together."""
    pipeline = ListProperty(AggregationPipe,
                            title="Aggregation Pipeline",
                            default=[AggregationPipe()])

    def execute_query(self, collection, signal):
        """Evaluate each pipeline stage against *signal* and run the
        aggregation on *collection*, returning the resulting cursor."""
        # the stage's ``pipe`` property itself is handed to
        # evaluate_expression, which evaluates it against the signal
        stages = [self.evaluate_expression(stage.pipe, signal)
                  for stage in self.pipeline()]

        self.logger.debug("Searching aggregation {}".format(stages))

        return collection.aggregate(stages, **(self.query_args()))
コード例 #17
0
class ConditionalModifier(Block):
    """ Conditional Modifier block.

    Adds new fields to input signals. For every entry of *fields*, the value
    of the attribute *title* is determined by its *lookup*, a list of
    formula/value pairs. The formulas are evaluated in order, and the value
    paired with the first truthy formula is assigned. If no formula matches,
    the attribute is set to None.

    When *exclude* is set, the output signals are new signals that contain
    only these fields.
    """

    fields = ListProperty(SignalField, title='Fields', default=[], order=0)
    exclude = BoolProperty(default=False, title='Exclude existing fields?')
    version = VersionProperty("1.1.0")

    def process_signals(self, signals):
        # evaluate the property once; it does not depend on a signal
        exclude = self.exclude()
        fresh_signals = []

        for signal in signals:

            # if we are including only the specified fields, create
            # a new, empty signal object; otherwise modify in place
            tmp = Signal() if exclude else signal

            # iterate over the specified fields, evaluating the formula
            # in the context of the original signal
            for field in self.fields():
                value = self._evaluate_lookup(field.lookup(), signal)
                setattr(tmp, field.title(), value)
            fresh_signals.append(tmp)

        self.notify_signals(fresh_signals)

    def _evaluate_lookup(self, lookup, signal):
        """Return the value of the first entry whose formula is truthy for
        *signal*, or None when none match."""
        for lu in lookup:
            if lu.formula(signal):
                return lu.value(signal)
        return None
コード例 #18
0
class Modifier(Block):
    """ A nio block for enriching signals.

    By default, the modifier block adds attributes to
    existing signals as specified. If the 'exclude' flag is
    set, the block instantiates new (generic) signals and
    passes them along with *only* the specified fields.

    Properties:
        - fields(list): List of attribute names and corresponding values to
                        add to the incoming signals.
        - exclude(bool): If `True`, output signals only contain the
                         attributes specified by `fields`.
    """

    exclude = BoolProperty(default=False,
                           title='Exclude existing fields?',
                           order=0)
    fields = ListProperty(SignalField, title='Fields', default=[], order=1)
    version = VersionProperty("1.1.0")

    def process_signals(self, signals):
        # evaluate the property once and use the evaluated bool throughout
        exclude = self.exclude()
        fresh_signals = []

        for signal in signals:

            # if we are including only the specified fields, create
            # a new, empty signal object; otherwise modify in place
            tmp = Signal() if exclude else signal

            # iterate over the specified fields, evaluating the formula
            # in the context of the original signal
            for field in self.fields():
                value = field.formula(signal)
                title = field.title(signal)
                setattr(tmp, title, value)

            fresh_signals.append(tmp)

        self.notify_signals(fresh_signals)
コード例 #19
0
class Sortable:
    """ A Mongo block mixin that allows you to sort results """

    sort = ListProperty(Sort, title='Sort', default=[])

    def __init__(self):
        super().__init__()
        # (key, direction) pairs, filled in by configure()
        self._sort = []

    def configure(self, context):
        """Cache the configured sort rules as (key, direction value) pairs."""
        super().configure(context)

        self._sort = [
            (rule.key(), rule.direction().value)
            for rule in self.sort()
        ]

    def query_args(self):
        """Extend the parent's query arguments with the 'sort' option."""
        args = super().query_args()
        args['sort'] = self._sort
        return args
コード例 #20
0
class TextClassifier(Block):
    """Classify ``signal.sample`` text with a TF-IDF + linear SGD pipeline.

    The pipeline is fitted on ``training_set`` during configure. Each
    classified signal gets a ``target`` attribute. Signals that cannot be
    classified are logged and dropped.
    """

    training_set = ListProperty(TrainingSetDataPoint,
                                title="Training Set",
                                default=[])
    version = VersionProperty('0.1.0')

    def __init__(self):
        super().__init__()
        # NOTE(review): ``n_iter`` was removed from SGDClassifier in
        # scikit-learn 0.21 in favor of ``max_iter``; confirm the pinned
        # scikit-learn version before upgrading.
        self._classifier = Pipeline([('vect', CountVectorizer()),
                                     ('tfidf', TfidfTransformer()),
                                     ('clf',
                                      SGDClassifier(loss='hinge',
                                                    penalty='l2',
                                                    alpha=1e-3,
                                                    n_iter=5,
                                                    random_state=42))])

    def configure(self, context):
        super().configure(context)
        try:
            training_data = []
            training_targets = []
            for data_point in self.training_set():
                training_data.append(data_point.data())
                training_targets.append(data_point.target())
            self._classifier.fit(training_data, training_targets)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate
            self.logger.warning("No training data available during configure")

    def process_signals(self, signals):
        predicted_signals = []
        for signal in signals:
            try:
                predicted = self._classifier.predict([signal.sample])
                signal.target = predicted[0]
                predicted_signals.append(signal)
            except Exception:
                self.logger.warning("Classifier does not have training data",
                                    exc_info=True)
        self.notify_signals(predicted_signals)
コード例 #21
0
class BufferStatusUpdate(TerminatorBlock):
    """Post one Buffer status update per incoming signal to all profiles."""

    text = Property(default='{{$text}}', title='Status Update Text')
    profile_ids = ListProperty(ProfileID,
                               default=[ProfileID()],
                               title='Profile IDs')
    access_token = StringProperty(default='[[BUFFER_ACCESS_TOKEN]]',
                                  title='Access Token')
    version = VersionProperty("1.0.0")

    def process_signals(self, signals):
        for s in signals:
            try:
                text = self.text(s)
            except Exception as e:
                self.logger.error("Text evaluation failed: {0}: {1}".format(
                    type(e).__name__, str(e)))
                continue
            data = {
                'access_token': self.access_token(),
                'text': text,
                'profile_ids[]':
                [pid.profile_id() for pid in self.profile_ids()]
            }
            self._status_update(data)

    def _status_update(self, payload, timeout=10):
        """POST *payload* to Buffer and log the outcome.

        Args:
            payload (dict): form data with 'access_token', 'text' and
                'profile_ids[]'.
            timeout (float): seconds to wait for Buffer; without a timeout
                requests can block indefinitely.
        """
        try:
            response = requests.post(POST_URL, data=payload, timeout=timeout)
        except requests.exceptions.RequestException as e:
            # log network failures like HTTP errors so the rest of the batch
            # is still posted
            self.logger.error("Buffer Status Update to {} with text '{}' "
                              "failed: {}: {}".format(
                                  payload['profile_ids[]'], payload['text'],
                                  type(e).__name__, e))
            return
        status = response.status_code
        if status != 200:
            self.logger.error("Buffer Status Update to {} with text '{}' "
                              "failed with status {}".format(
                                  payload['profile_ids[]'], payload['text'],
                                  status))
        else:
            self.logger.debug(
                "Buffer Status Update to {} with text '{}'".format(
                    payload['profile_ids[]'], payload['text']))
コード例 #22
0
ファイル: stock_block.py プロジェクト: nio-blocks/stock
class Stock(RESTPolling):
    """Poll Yahoo's YQL finance table for quotes of the configured symbols.

    Each quote in the response becomes one output signal.
    """

    _yql_base_url = ("https://query.yahooapis.com/v1/public/yql?"
                     "q=select%20*%20from%20yahoo.finance.quote%20where"
                     "%20symbol%20in%20({0})&format=json"
                     "&env=store%3A%2F%2Fdatatables.org%2Falltableswithkeys")

    queries = ListProperty(StringType, title='Symbols/Tickers', default=[])
    version = VersionProperty("1.0.2")

    def _prepare_url(self, paging=False):
        """Set ``self._url`` for all symbols and return the request headers.

        ``paging`` is unused here; this endpoint is not paged.
        """
        # each symbol is double-quoted inside the YQL "in (...)" list
        sym_str = ",".join(['"' + sym + '"' for sym in self.queries()])
        self._url = self._yql_base_url.format(sym_str)

        return {"Content-Type": "application/json"}

    def _process_response(self, resp):
        """Return (signals, None): one Signal per quote and no next page."""
        body = resp.json()

        try:
            results = body['query']['results']
        except (KeyError, TypeError):
            # a key is missing, or a value on the path is not a dict
            self.logger.error("Invalid response format: {0}".format(body))
            results = None

        if not results or not isinstance(results, dict):
            self.logger.warning("No results found: {0}".format(results))
            return [], None

        # If only one query, then quote is a dict instead of a list
        quote = results.get('quote', [])
        if not isinstance(quote, list):
            quote = [quote]

        # No paging
        return [Signal(s) for s in quote], None
コード例 #23
0
ファイル: voice_block.py プロジェクト: niolabs/demo_mlir
class TwilioVoice(TerminatorBlock):
    """ Place a Twilio voice call to every configured recipient per signal.

    For each signal the ``message`` expression is evaluated and stored in
    ``self._messages`` under a fresh hex id. Twilio is then asked to call
    each recipient with ``<url>?msg_id=<id>`` as the callback URL. A
    ``Speak`` handler is mounted on this block's web server at
    ``endpoint``; presumably it serves the stored message back to Twilio
    (TODO confirm against ``Speak``).

    """

    recipients = ListProperty(Recipient, title='Recipients', default=[])
    creds = ObjectProperty(TwilioCreds, title='Credentials')
    from_ = StringProperty(default='[[TWILIO_NUMBER]]', title='From')
    url = StringProperty(default='', title='Callback URL')
    message = Property(
        default='An empty voice message',
        title='Message')
    port = IntProperty(title='Port', default=8184)
    host = StringProperty(title='Host', default='[[NIOHOST]]')
    endpoint = StringProperty(title='Endpoint', default='')
    version = VersionProperty("1.0.0")

    def __init__(self):
        super().__init__()
        self._client = None
        # msg_id -> evaluated message.
        # NOTE(review): nothing in this class removes entries; confirm
        # Speak does, or this dict grows with every signal.
        self._messages = {}
        self._server = None

    def configure(self, context):
        super().configure(context)
        # NOTE(review): ``sid``/``token`` are read without being called,
        # unlike the other property reads in this block; confirm against
        # TwilioCreds.
        self._client = TwilioRestClient(self.creds().sid,
                                        self.creds().token)
        conf = {
            'host': self.host(),
            'port': self.port()
        }
        self.configure_server(conf, Speak(self.endpoint(), self))

    def start(self):
        super().start()
        # Start Web Server
        self.start_server()

    def stop(self):
        super().stop()
        # Stop Web Server
        self.stop_server()

    def process_signals(self, signals):
        for s in signals:
            self._place_calls(s)

    def _place_calls(self, signal):
        """ Store the signal's message and spawn one call per recipient. """
        try:
            msg = self.message(signal)
            msg_id = uuid4().hex
            self._messages[msg_id] = msg
            for rcp in self.recipients():
                spawn(target=self._call, recipient=rcp, message_id=msg_id)
        except Exception as e:
            self.logger.error(
                "Message evaluation failed: {0}: {1}".format(
                    type(e).__name__, str(e))
            )

    def _call(self, recipient, message_id, retry=False):
        """ Ask Twilio to call ``recipient``; retry once on a Twilio error.

        Args:
            recipient (Recipient): one entry of ``recipients``.
            message_id (str): key into ``self._messages``, passed to Twilio
                as the ``msg_id`` query parameter of the callback URL.
            retry (bool): True when this invocation is already the retry.

        """
        try:
            # No trailing commas: ``x = value,`` would build a 1-tuple and
            # send that to Twilio instead of the value itself.
            # NOTE(review): ``recipient.number`` is not called, unlike
            # ``self.from_()``; confirm whether it must be evaluated.
            to = recipient.number
            from_ = self.from_()
            url = "%s?msg_id=%s" % (self.url(), message_id)
            self.logger.debug("Making call to {}, from {}, with callback url"
                              " {}".format(to, from_, url))
            # Twilio sends back some useless XML. Don't care.
            self._client.calls.create(
                to=to,
                from_=from_,
                url=url
            )
        except TwilioRestException as e:
            # format() rather than %d so a missing status can't raise here
            self.logger.error("Status {}".format(e.status))
            if not retry:
                self.logger.debug("Retrying failed request")
                self._call(recipient, message_id, True)
            else:
                self.logger.error("Retry request failed")
        except Exception as e:
            self.logger.error("Error sending voice {}: {}".format(
                recipient, e
            ))
コード例 #24
0
class SignalField(PropertyHolder):
    """ Pair an attribute name with a list of ``LookupProperty`` entries. """

    # name shown as "Attribute Name"; listed first in the UI (order=0)
    title = StringProperty(title="Attribute Name", default='', order=0)
    # the lookup entries, listed second (order=1)
    lookup = ListProperty(LookupProperty, title='Lookup', order=1, default=[])
コード例 #25
0
class PlotlyDash(TerminatorBlock):
    """ Serve a live-updating Plotly Dash graph of incoming signal values.

    Each configured series has a trace dict in ``self.data_dict`` keyed by
    series name. ``process_signals`` updates those traces and refreshes
    ``self.data``, which the Dash interval callback re-reads every
    ``update_interval`` seconds.

    """

    version = VersionProperty("0.1.1")
    graph_series = ListProperty(Series, title='Data Series', default=[])
    x_axis = Property(
        title='Independent Variable',
        default='{{ datetime.datetime.utcnow() }}',
        allow_none=False
    )
    title = StringProperty(
        title='Title', default='Plotly Title', allow_none=False)
    num_data_points = IntProperty(
        title='How many points to display', default=20, allow_none=True)
    port = IntProperty(title='Port', default=8050)
    update_interval = IntProperty(title='Update Interval (seconds)', default=1)

    def __init__(self):
        self._main_thread = None
        self.app = dash.Dash()
        self.data_dict = {}
        self.data = []
        super().__init__()

    def configure(self, context):
        super().configure(context)
        self.data_dict = {
            s.name(): {'x': [], 'y': [], 'name': s.name()}
            for s in self.graph_series()
        }

    def start(self):
        """ Build the layout and routes, then start the server thread.

        Everything the server needs is set up *before* the thread is
        spawned, so the thread can never race with (or overwrite) it.

        """
        self.data = self.data_dict_to_data_list(self.data_dict)
        figure = {'data': self.data, 'layout': {'title': self.title()}}
        app_layout = [
            dcc.Graph(id=self.title(), figure=figure),
            dcc.Interval(id='interval-component',
                         interval=self.update_interval() * 1000)
        ]

        self.app.layout = html.Div(app_layout)

        @self.app.callback(Output(self.title(), 'figure'),
                           events=[Event('interval-component', 'interval')])
        def update_graph_live():
            return {'data': self.data, 'layout': {'title': self.title()}}

        def shutdown_server():
            func = request.environ.get('werkzeug.server.shutdown')
            if func is None:
                self.logger.warning('Not running with the Werkzeug Server')
                return
            func()

        @self.app.server.route('/shutdown', methods=['GET'])
        def shutdown():
            shutdown_server()
            return 'OK'

        self._main_thread = spawn(self._server)
        self.logger.debug('server started on localhost:{}'.format(self.port()))
        super().start()

    def stop(self):
        """ Ask the server to shut down, then wait (bounded) for its thread. """
        # seconds to wait for the shutdown request and for the thread
        timeout = 5
        # http://flask.pocoo.org/snippets/67/
        try:
            requests.get('http://localhost:{}/shutdown'.format(self.port()),
                         timeout=timeout)
            self.logger.debug('shutting down server ...')
        except Exception as e:
            self.logger.warning(
                'shutdown_server callback failed: {}'.format(e))
        if self._main_thread is not None:
            self._main_thread.join(timeout)
            if self._main_thread.is_alive():
                self.logger.warning('_main_thread did not exit')
            else:
                self.logger.debug('_main_thread joined')
        super().stop()

    def process_signals(self, signals):
        """ Add each signal's point(s) to every series and refresh the data.

        If ``x_axis`` evaluates to a list, the series' ``x``/``y`` are
        replaced by the evaluated values. Otherwise one point is appended
        and the series is trimmed to the newest ``num_data_points`` points
        (no trimming when ``num_data_points`` is None). Each series'
        ``kwargs`` are copied into its trace dict.

        """
        limit = self.num_data_points()
        for signal in signals:
            # Evaluate once so the type check and the stored value agree;
            # the default expression calls utcnow().
            x = self.x_axis(signal)
            for series in self.graph_series():
                trace = self.data_dict[series.name()]
                if isinstance(x, list):
                    trace['x'] = x
                    trace['y'] = series.y_axis(signal)
                else:
                    trace['x'].append(x)
                    trace['y'].append(series.y_axis(signal))
                    trace['x'] = self._trim(trace['x'], limit)
                    trace['y'] = self._trim(trace['y'], limit)
                trace.update(series.kwargs())

        self.data = self.data_dict_to_data_list(self.data_dict)

    @staticmethod
    def _trim(values, limit):
        """ Return the newest ``limit`` items of ``values`` (all if None). """
        if limit is None:
            return values
        excess = len(values) - limit
        # rebind to a new list rather than mutate: the Dash callback reads
        # these traces from the server thread
        return values[excess:] if excess > 0 else values

    @staticmethod
    def data_dict_to_data_list(dict):
        """ Return the per-series trace dicts as a list for Plotly. """
        return list(dict.values())

    def _server(self):
        # The layout is set in start() before this thread is spawned; do not
        # reset it here or the configured graph may be replaced.
        # if debug isn't passed the server breaks silently
        self.app.run_server(debug=False, port=self.port(), host='0.0.0.0')
コード例 #26
0
class Data(PropertyHolder):
    """ Hold a list of ``Param`` entries plus a form-encoding flag. """

    # configured parameters; empty by default
    params = ListProperty(Param, default=[], title="Parameters")
    # whether the data should be form-encoded; off by default
    form_encode_data = BoolProperty(title="Form-Encode Data?", default=False)
コード例 #27
0
class PlotlyDash(Block):
    """ Serve a live-updating Plotly Dash graph of incoming signal values.

    Each configured series has a trace dict in ``self.data_dict`` keyed by
    series name. ``process_signals`` appends points to those traces and
    the Dash interval callback re-reads ``self.data`` once per second.

    """

    version = VersionProperty('0.1.0')
    graph_series = ListProperty(Series, title='Data Series', default=[])
    x_axis = Property(
        title='Independent Variable',
        default='{{ datetime.datetime.utcnow() }}',
        allow_none=False
    )
    title = StringProperty(
        title='Title', default='Plotly Title', allow_none=False)
    num_data_points = IntProperty(
        title='How many points to display', default=20, allow_none=True)

    def __init__(self):
        self._main_thread = None
        self.app = dash.Dash()
        # NOTE(review): older Dash releases spell this key 'supress';
        # newer ones use 'suppress_callback_exceptions' - confirm against
        # the installed version.
        self.app.config.supress_callback_exceptions = True
        self.data_dict = {}
        self.data = []
        super().__init__()

    def start(self):
        """ Build traces, layout and callback, then start the server thread.

        The layout is set *before* the thread is spawned, so the thread can
        never race with (or overwrite) it.

        """
        self.data_dict = {
            s.name(): {'x': [], 'y': [], 'name': s.name()}
            for s in self.graph_series()
        }
        self.data = self.data_dict_to_data_list(self.data_dict)
        figure = {'data': self.data, 'layout': {'title': self.title()}}
        app_layout = [
            dcc.Graph(id=self.title(), figure=figure),
            dcc.Interval(id='interval-component', interval=1 * 1000)
        ]

        self.app.layout = html.Div(app_layout)

        @self.app.callback(Output(self.title(), 'figure'),
                           events=[Event('interval-component', 'interval')])
        def update_graph_live():
            return {'data': self.data, 'layout': {'title': self.title()}}

        self._main_thread = spawn(self._server)
        self.logger.debug('server started on localhost:8050')
        super().start()

    def stop(self):
        """ Wait up to one second for the server thread, then stop. """
        # Nothing in this block asks the server to shut down, so the join
        # presumably times out while run_server() is still serving.
        if self._main_thread is not None:
            self._main_thread.join(1)
            if self._main_thread.is_alive():
                self.logger.warning('server thread still running after stop')
            else:
                self.logger.debug('server stopped')
        super().stop()

    def process_signals(self, signals):
        """ Append each signal's point to every series and refresh the data.

        Each series keeps only its newest ``num_data_points`` points (no
        trimming when ``num_data_points`` is None).

        """
        limit = self.num_data_points()
        for signal in signals:
            # Evaluate once per signal; the default expression calls utcnow()
            x = self.x_axis(signal)
            for series in self.graph_series():
                trace = self.data_dict[series.name()]
                trace['x'].append(x)
                trace['y'].append(series.y_axis(signal))
                trace['x'] = self._trim(trace['x'], limit)
                trace['y'] = self._trim(trace['y'], limit)
        self.data = self.data_dict_to_data_list(self.data_dict)

    @staticmethod
    def _trim(values, limit):
        """ Return the newest ``limit`` items of ``values`` (all if None). """
        if limit is None:
            return values
        excess = len(values) - limit
        # rebind to a new list rather than mutate: the Dash callback reads
        # these traces from the server thread
        return values[excess:] if excess > 0 else values

    @staticmethod
    def data_dict_to_data_list(dict):
        """ Return the per-series trace dicts as a list for Plotly. """
        return list(dict.values())

    def _server(self):
        # The layout is set in start() before this thread is spawned; do not
        # reset it here or the configured graph may be replaced.
        self.app.run_server(debug=False)
コード例 #28
0
ファイル: email_block.py プロジェクト: niolabs/demo_mlir
class Email(TerminatorBlock):
    """ A block for sending email.

    Properties:
        to (list(Identity)): A list of recipient identities (name/email).
        server (SMTPConfig): host, port, account, etc. for SMTP server.
        message (Message): The message contents and sender name.

    """
    version = VersionProperty("0.1.0")
    to = ListProperty(Identity, title='Receiver', default=[])
    server = ObjectProperty(SMTPConfig, title='Server', allow_none=False)
    message = ObjectProperty(Message, title='Message', allow_none=True)

    def __init__(self):
        super().__init__()
        self._retry_conn = None

    def process_signals(self, signals):
        """ For each signal object, build the configured message and send
        it to each recipient.

        A new SMTP connection is opened for every batch. The method does
        not return until a send has been attempted for every recipient of
        every signal. The connection is always disconnected before
        returning, even if sending raises part-way through.

        Args:
            signals (list(Signal)): The signals to process.

        Returns:
            None

        """
        # make a new connection to the SMTP server each time we get a new
        # batch of signals.
        smtp_conn = SMTPConnection(self.server(), self.logger)
        try:
            smtp_conn.connect()
        except Exception as e:
            self.logger.error('Aborting sending emails. '
                              '{} signals discarded: {}: {}'.format(
                                  len(signals), type(e).__name__, e))
            return

        try:
            # handle each incoming signal
            for signal in signals:
                try:
                    subject = self.message().subject(signal)
                except Exception as e:
                    subject = self.get_defaults()['message'].subject()
                    self.logger.error(
                        "Email subject evaluation failed: {0}: {1}".format(
                            type(e).__name__, str(e)))

                try:
                    body = self.message().body(signal)
                except Exception as e:
                    body = self.get_defaults()['message'].body()
                    self.logger.error(
                        "Email body evaluation failed: {0}: {1}".format(
                            type(e).__name__, str(e)))

                self._send_to_all(smtp_conn, subject, body, signal)
        finally:
            # drop the SMTP connection after each round of signals
            smtp_conn.disconnect()

    def _send_to_all(self, conn, subject, body, signal):
        """ Build a message based on the provided content and send it to
        each of the configured recipients.

        A failure for one recipient is logged and does not stop delivery
        to the others.

        Args:
            conn (SMTPConnection): The connection over which to send
                the message.
            subject (str): The desired subject line of the message.
            body (str): The desired message body.
            signal (Signal): Signal used to evaluate each recipient's
                name and email.

        Returns:
            None

        """
        sender = self.message().sender()
        msg = self._construct_msg(subject, body)
        for rcp in self.to():
            try:
                # customize the message to each recipient. Assigning to
                # msg['To'] *appends* a header, so remove the previous
                # recipient's header first (a no-op when absent).
                del msg['To']
                msg['To'] = rcp.name(signal)
                address = rcp.email(signal)
                conn.sendmail(sender, address, msg.as_string())
                self.logger.debug("Sent mail to: {}".format(address))
            except Exception as e:
                self.logger.error("Failed to send mail: {}".format(e))

    def _construct_msg(self, subject, body):
        """ Construct the multipart message. Mail clients unable to
        render HTML will default to plaintext.

        Args:
            subject (str): The subject line.
            body (str): The message body.

        Returns:
            msg (MIMEMultipart): A message containing generic
                headers, and HTML version, and a plaintext version.

        """
        msg = MIMEMultipart('alternative')
        msg['Subject'] = subject
        msg['From'] = self.message().sender()

        plain_part = MIMEText(body, 'plain')
        msg.attach(plain_part)

        html_part = MIMEText(HTML_MSG_FORMAT.format(body), 'html')
        msg.attach(html_part)

        return msg
コード例 #29
0
ファイル: twitter_block.py プロジェクト: nio-blocks/twitter
class Twitter(TwitterStreamBlock):
    """ A block for communicating with the Twitter Streaming API.
    Reads Tweets in real time, notifying other blocks via NIO's signal
    interface at a configurable interval.

    Properties:
        phrases (list(str)): The list of phrases to track.
        follow (list(str)): The list of users to track.
        fields (list(str)): Outgoing signals will pull these fields
            from incoming tweets. When empty/unset, all fields are
            included.
        language (list(str)): Only get tweets of the specified language.
        filter_level (FilterLevel): Minimum value of the filter_level Tweet
            attribute.
        locations (list(Location)): A comma-separated list of longitude,
            latitude pairs specifying a set of bounding boxes to filter
            Tweets by.
        notify_freq (timedelta): The interval between signal notifications.
        creds: Twitter app credentials, see above. Defaults to global settings.
        rc_interval (timedelta): Time to wait between receipts (either tweets
            or heartbeats) before attempting to reconnect to Twitter Streaming.

    """

    version = VersionProperty("2.0.0")
    phrases = ListProperty(StringType, default=[], title='Query Phrases')
    follow = ListProperty(StringType, default=[], title='Follow Users')
    fields = ListProperty(StringType, default=[], title='Included Fields')
    language = ListProperty(StringType, default=['en'], title='Language')
    filter_level = SelectProperty(FilterLevel,
                                  default=FilterLevel.none,
                                  title='Filter Level')
    locations = ListProperty(Location, default=[], title='Locations')

    streaming_host = 'stream.twitter.com'
    streaming_endpoint = '1.1/statuses/filter.json'
    users_endpoint = 'https://api.twitter.com/1.1/users/lookup.json'

    def __init__(self):
        super().__init__()
        self._user_ids = []

    def _start(self):
        self._set_user_ids()

    def _set_user_ids(self):
        """ Resolve the ``follow`` screen names into ``self._user_ids``. """
        # Start from scratch so a restart does not accumulate duplicates.
        self._user_ids = []
        follow = self.follow()
        if not follow:
            return
        auth = OAuth1(self.creds().consumer_key(),
                      self.creds().app_secret(),
                      self.creds().oauth_token(),
                      self.creds().oauth_token_secret())
        # user ids can be grabbed 100 at a time.
        for i in range(0, len(follow), 100):
            data = {"screen_name": ','.join(follow[i:i + 100])}
            resp = requests.post(self.users_endpoint, data=data, auth=auth)
            if resp.status_code != 200:
                self.logger.warning(
                    "User lookup failed with status {}".format(
                        resp.status_code))
                continue
            for user in resp.json():
                user_id = user.get('id_str')
                if user_id is not None:
                    self._user_ids.append(user_id)
        self.logger.debug("Following {} users".format(len(self._user_ids)))

    def get_params(self):
        """ Build the filter-stream request parameters. """
        params = {
            'stall_warnings': 'true',
            'delimited': 'length',
            'track': ','.join(self.phrases()),
            'follow': ','.join(self._user_ids),
            'filter_level': self.filter_level().name
        }
        if self.language():
            params['language'] = ','.join(self.language())
        if self.locations():
            locations = []
            for location in self.locations():
                locations.append(str(location.southwest().longitude()))
                locations.append(str(location.southwest().latitude()))
                locations.append(str(location.northeast().longitude()))
                locations.append(str(location.northeast().latitude()))
            params['locations'] = ','.join(locations)
        return params

    def get_request_method(self):
        return "POST"

    def filter_results(self, data):
        """ Filters incoming tweet objects to include only the configured
        fields (or all of them, if self.fields is empty).

        """
        fields = self.fields()
        # If they did not specify which fields, just give them everything
        if not fields:
            return data

        result = {}
        for f in fields:
            try:
                result[f] = data[f]
            except (KeyError, TypeError):
                self.logger.error("Invalid Twitter field: %s" % f)

        return result

    def create_signal(self, data):
        """ Queue ``data`` as a limit, other (stream notice) or tweet signal. """
        for msg in PUB_STREAM_MSGS:
            if data and msg in data:

                # Log something about the message
                report = "{} notice".format(PUB_STREAM_MSGS[msg])
                if msg == "disconnect":
                    code = int(data['disconnect']['code'])
                    # codes are 1-based indexes into DISCONNECT_REASONS;
                    # guard against codes outside the known table
                    if 1 <= code <= len(DISCONNECT_REASONS):
                        reason = DISCONNECT_REASONS[code - 1]
                    else:
                        reason = "unknown reason (code {})".format(code)
                    report += ": {}".format(reason)
                elif msg == "warning":
                    report += ": {}".format(data['message'])
                self.logger.debug(report)

                # Calculate total limit for limit signals
                if msg == "limit":
                    # lock when calculating limit
                    with self._get_result_lock('limit'):
                        self._calculate_limit(data)

                # Anything that is not 'limit' or 'tweet' is considered 'other'
                if msg != "limit":
                    msg = "other"

                # Add a signal to the appropriate list
                with self._get_result_lock(msg):
                    self._result_signals[msg].append(Signal(data))

                return

        # If we didn't return yet, the message is a regular tweet.
        self.logger.debug("It's a tweet!")
        data = self.filter_results(data)
        if data:
            with self._get_result_lock('tweets'):
                self._result_signals['tweets'].append(Signal(data))

    def _calculate_limit(self, data):
        """ Calculate total limit count for limit signals """
        track = data.get('limit', {}).get('track', 0)
        if track > self._limit_count:
            limit = track - self._limit_count
            self._limit_count = track
        else:
            limit = 0
        data['count'] = limit
        data['cumulative_count'] = track
コード例 #30
0
class TensorFlow(EnrichSignals, Block):
    """ Build, train, test and run a layered TensorFlow network.

    The graph is built in ``configure`` from ``layers`` and
    ``network_config``. What happens to a signal depends on its input:

    * ``train``: runs one optimizer step.
    * ``test``: evaluates loss and prediction with keep probability 1.
    * any other input: only predicts.

    Signals are read for ``batch`` and, for train/test, ``labels``.

    """

    layers = ListProperty(Layers,
                          title='Network Layers',
                          default=[{
                              'count': 10,
                              'activation': 'softmax',
                              'initial_weights': 'random',
                              'bias': True
                          }])
    network_config = ObjectProperty(NetworkConfig,
                                    title='ANN Configuration',
                                    default=NetworkConfig())
    models = ObjectProperty(ModelManagement,
                            title='Model Management',
                            default=ModelManagement())
    version = VersionProperty("0.5.0")

    def __init__(self):
        super().__init__()
        self.X = None
        self.XX = None
        self.Y_ = None
        self.prob_keep = None
        self.train_step = None
        self.correct_prediction = None
        self.prediction = None
        self.sess = None
        self.loss_function = None
        self.saver = None
        self.iter = 0
        self.summaries = None
        self.summary_writer = None

    def configure(self, context):
        """ Build the TensorFlow graph and session from the configuration.

        Raises:
            ValueError: if the last layer is a dropout layer, or the
                configured loss function is not supported.

        """
        super().configure(context)
        output_layer_num = len(self.layers()) - 1
        if self.layers()[-1].activation().value == 'dropout':
            # dropout layers are stored under 'layer{i}_d' below, so the
            # output lookups by 'layer{N}' / 'layer{N}_logits' would fail
            raise ValueError('The output layer cannot be a dropout layer')
        if self.network_config().random_seed() is not None:
            tf.set_random_seed(self.network_config().random_seed())
        # input tensor shape; a configured -1 becomes a variable (None) dim
        shape = []
        for dim in self.network_config().input_dim():
            if dim.value.value == -1:
                shape.append(None)
            else:
                shape.append(dim.value.value)
        self.X = tf.placeholder(tf.float32, shape=shape, name='INPUT')
        # specify desired output (labels): one column per output-layer unit
        shape = [None, self.layers()[-1].count()]
        self.Y_ = tf.placeholder(tf.float32, shape=shape, name='LABELS')
        self.prob_keep = tf.placeholder(tf.float32, name='PROB_KEEP')
        layers_logits = {}
        prev_layer = self.X
        for i, layer in enumerate(self.layers()):
            name = 'layer{}'.format(i)
            with tf.name_scope(name):
                if layer.activation().value != 'dropout':
                    # multiply together the known dims of the previous layer
                    flattened = 1
                    for dim in prev_layer.shape:
                        if dim.value is not None:
                            flattened *= dim.value
                    # TODO: Flatten only if not convolutional layer
                    XX = tf.reshape(prev_layer, [-1, flattened])
                    initializer = getattr(tf, layer.initial_weights().value)
                    W = tf.Variable(
                        initializer([XX.shape[-1].value, layer.count()]),
                        name='{}_WEIGHTS'.format(name))
                    b = tf.Variable(initializer([layer.count()]),
                                    name='{}_BIASES'.format(name))
                    if self.models().tensorboard_int():
                        with tf.name_scope('weights'):
                            tf.summary.histogram('weights', W)
                        with tf.name_scope('biases'):
                            tf.summary.histogram('biases', b)
                    logits = tf.matmul(XX, W)
                    if layer.bias.value:
                        logits = logits + b
                    if i == output_layer_num:
                        # keep the pre-activation output for the
                        # logits-based loss function
                        layers_logits[name + '_logits'] = logits
                    layers_logits[name] = \
                        getattr(tf.nn, layer.activation().value)(logits)
                else:
                    name = 'layer{}_d'.format(i)
                    layers_logits[name] = tf.nn.dropout(
                        prev_layer, self.prob_keep)
                prev_layer = layers_logits[name]
        Y = layers_logits['layer{}'.format(output_layer_num)]
        Y_logits = layers_logits['layer{}_logits'.format(output_layer_num)]
        loss_name = self.network_config().loss().value
        if loss_name == 'cross_entropy':
            self.loss_function = tf.reduce_mean(abs(self.Y_ * tf.log(Y)))
        elif loss_name == 'softmax_cross_entropy_with_logits':
            self.loss_function = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=Y_logits,
                                                        labels=self.Y_))
        elif loss_name == 'mean_absolute_error':
            self.loss_function = tf.reduce_mean(abs(self.Y_ - Y))
        else:
            # fail clearly instead of passing None to minimize() below
            raise ValueError(
                'Unsupported loss function: {}'.format(loss_name))
        if self.models().tensorboard_int():
            with tf.name_scope('loss'):
                tf.summary.scalar(loss_name, self.loss_function)
        optimizer = getattr(tf.train,
                            self.network_config().optimizer().value)
        self.train_step = optimizer(
            self.network_config().learning_rate()).minimize(
                self.loss_function)
        self.prediction = Y
        if self.models().load_file() or self.models().save_file():
            self.saver = tf.train.Saver(max_to_keep=None)
        self.sess = tf.Session()
        if self.models().tensorboard_int():
            label = self.models().tensorboard_tag()
            self.summaries = tf.summary.merge_all()
            self.summary_writer = tf.summary.FileWriter(
                '{}/{}'.format(self.models().tensorboard_dir(), label),
                self.sess.graph)
            self.logger.debug('TensorBoard summary label: {}'.format(label))
        if self.models().load_file():
            self.saver.restore(self.sess, self.models().load_file())
        else:
            self.sess.run(tf.global_variables_initializer())

    def process_signals(self, signals, input_id=None):
        """ Train, test or predict on each signal; notify one output each.

        Every output carries ``input_id``, ``loss`` (None when only
        predicting) and ``prediction``.

        """
        new_signals = []
        for signal in signals:
            if input_id == 'train':
                if self.models().tensorboard_int():
                    # _train also fetches summaries on the same iterations
                    if self.iter % self.models().tensorboard_int() == 0:
                        summary, _, loss, predict = self._train(signal)
                        self.summary_writer.add_summary(summary, self.iter)
                    else:
                        _, loss, predict = self._train(signal)
                    self.iter += 1
                else:
                    _, loss, predict = self._train(signal)
            elif input_id == 'test':
                loss, predict = self._test(signal)
            else:
                loss = None
                predict = self._predict(signal)
            output = {
                'input_id': input_id,
                'loss': loss,
                'prediction': predict
            }
            new_signals.append(self.get_output_signal(output, signal))
        self.notify_signals(new_signals)

    def stop(self):
        """ Optionally save the model, then close writer and session. """
        if self.models().save_file():
            self.logger.debug('saving model to {}'.format(
                self.models().save_file()))
            self.saver.save(self.sess, self.models().save_file())
        if self.models().tensorboard_int():
            self.summary_writer.close()
        self.sess.close()
        super().stop()

    def _train(self, signal):
        """ Run one training step; prepends summaries on TensorBoard steps. """
        batch_X = signal.batch
        batch_Y_ = signal.labels
        fetches = [self.train_step, self.loss_function, self.prediction]
        # value fed to the PROB_KEEP placeholder used by dropout layers
        dropout_rate = 1 - self.network_config().dropout()
        if self.models().tensorboard_int():
            if self.iter % self.models().tensorboard_int() == 0:
                fetches = [self.summaries] + fetches
        return self.sess.run(fetches,
                             feed_dict={
                                 self.X: batch_X,
                                 self.Y_: batch_Y_,
                                 self.prob_keep: dropout_rate
                             })

    def _test(self, signal):
        """ Return ``[loss, prediction]`` with keep probability 1. """
        batch_X = signal.batch
        batch_Y_ = signal.labels
        fetches = [self.loss_function, self.prediction]
        return self.sess.run(fetches,
                             feed_dict={
                                 self.X: batch_X,
                                 self.Y_: batch_Y_,
                                 self.prob_keep: 1
                             })

    def _predict(self, signal):
        """ Return the network output for ``signal.batch``. """
        batch_X = signal.batch
        fetches = self.prediction
        return self.sess.run(fetches,
                             feed_dict={
                                 self.X: batch_X,
                                 self.prob_keep: 1
                             })