Exemple #1
0
    def get_context_data(self, **kwargs):
        """Build the template context for a user's profile page.

        Context keys:
          * ``type_count``    -- OrderedDict of the type counts returned by
                                 ``get_collaboration_data``
          * ``results``       -- the ten most recently modified collaborations
          * ``emails``        -- the ten most recently received entries of the
                                 query returned by ``get_visible_threads``
          * ``list_activity`` -- mapping of mailing-list name -> number of
                                 visible entries in that list

        Entries passed in ``kwargs`` override these before the context is
        handed to the parent class.
        """
        profile_user = self.object

        # The viewer is resolved to a User row only when someone is logged in.
        viewer = None
        if self.request.user.is_authenticated():
            viewer = User.objects.get(username=self.request.user)

        collaborations, extra_counts = get_collaboration_data(
            viewer, profile_user)
        collaborations.sort(key=lambda item: item.modified, reverse=True)

        visible = get_visible_threads(viewer, profile_user)
        list_field = 'thread__mailinglist__name'
        per_list = (visible.values_list(list_field)
                    .annotate(Count(list_field))
                    .order_by(list_field))

        context = {
            'type_count': OrderedDict(extra_counts),
            'results': collaborations[:10],
            'emails': visible.order_by('-received_time')[:10],
            'list_activity': dict(per_list),
        }
        context.update(kwargs)
        return super(UserProfileDetailView, self).get_context_data(**context)
def getDbSeqRecord(db_record, id_field, seq_field, meta_fields=None,
                   delimiter=default_delimiter):
    """
    Convert a database record into a SeqRecord.

    Arguments:
    db_record = dictionary holding one database record
    id_field = name of the field that holds the sequence identifier
    seq_field = name of the field that holds the sequence
    meta_fields = optional list of extra fields to embed in the record's id;
                  names that are absent from db_record are skipped
    delimiter = tuple of delimiters for (fields, values, value lists)

    Returns:
    a SeqRecord whose id and name are both the string produced by
    flattenAnnotation from the ID and meta fields, or None when the
    identifier or the sequence is empty
    """
    # An empty identifier or an empty sequence cannot become a record.
    if not (db_record[id_field] and db_record[seq_field]):
        return None

    annotation = OrderedDict()
    annotation['ID'] = db_record[id_field]
    if meta_fields is not None:
        for field in meta_fields:
            if field in db_record:
                annotation[field] = db_record[field]
    label = flattenAnnotation(annotation, delimiter=delimiter)

    dna = Seq(db_record[seq_field], IUPAC.ambiguous_dna)
    return SeqRecord(dna, id=label, name=label, description='')
Exemple #3
0
    def get_git_branches(package_name, fetch=False):
        """
        Return an ordered mapping of branch name -> branch reference for the
        git repository containing `package_name`.

        The active branch comes first, or a "<Detached HEAD>" placeholder
        mapped to None when HEAD is detached. It is followed by the local
        heads and then the remote refs. Remote refs whose name contains
        '/pr/' are skipped. When `fetch` is true, each remote is fetched
        first, and a remote whose fetch fails is left out entirely. Call
        `check_git_repo` beforehand so that a local repository exists.
        """
        repo_root = imp.find_module(package_name)[1]
        repo = git.Repo(repo_root, search_parent_directories=True)

        try:
            current = repo.active_branch
            branches = OrderedDict([(current.name, current)])
        except TypeError:
            # Reading active_branch raises TypeError when HEAD is detached.
            branches = OrderedDict([("<Detached HEAD>", None)])

        for head in repo.heads:
            branches[head.name] = head

        for remote in repo.remotes:
            if fetch:
                try:
                    remote.fetch()
                except git.GitCommandError:
                    continue
            # Pull-request refs are not listed as branches.
            for ref in remote.refs:
                if '/pr/' not in ref.name:
                    branches[ref.name] = ref
        return branches
def _merge_fields_and_pk(pk, fields):
    fields_and_pk = OrderedDict()
    fields_and_pk['pk'] = pk
    fields_and_pk[pk.name] = pk
    fields_and_pk.update(fields)

    return fields_and_pk
Exemple #5
0
    def get_updates(self, v):
        """Build the update dictionary for one contrastive-divergence step.

        ``self.CD`` is run from the visible input ``v``. The cost is the mean
        free energy of ``v`` minus the mean free energy at the end of the
        sampling chain, plus ``self.regularization``. It is differentiated
        with respect to ``self.parameters`` while treating the chain end as
        a constant.

        Returns an OrderedDict that contains, in this order:
          * the CD updates,
          * the learning-rate updates,
          * ``param - lr[param] * grad`` for every parameter.
        """
        chain_end, cd_updates = self.CD(self, chain_start=v, cdk=self.CDk)

        positive_phase = T.mean(self.free_energy(v), axis=0)
        negative_phase = T.mean(self.free_energy(chain_end), axis=0)
        cost = positive_phase - negative_phase
        cost += self.regularization

        # Gradients must not flow back through the Gibbs sampling chain.
        grads = T.grad(cost, self.parameters, consider_constant=[chain_end])
        gradients = dict(zip(self.parameters, grads))

        lr, lr_updates = self.learning_rate(gradients)

        updates = OrderedDict(cd_updates)
        updates.update(lr_updates)
        for param, grad in gradients.items():
            updates[param] = param - lr[param] * grad
        return updates
Exemple #6
0
    def __new__(mcs, name, bases, attrs):
        """Create the class and gather its declared ``Field`` attributes.

        ``Field`` instances are removed from ``attrs`` and stored as
        ``declared_fields``, ordered by ``creation_counter``. After the class
        is created, fields from every class in the MRO are merged
        (least-derived first) into ``base_fields`` / ``declared_fields``.
        Any name defined in a class's own namespace removes the collected
        field of that name.
        """
        own_fields = sorted(
            ((key, value) for key, value in attrs.items()
             if isinstance(value, Field)),
            key=lambda pair: pair[1].creation_counter)
        for key, _ in own_fields:
            del attrs[key]
        attrs["declared_fields"] = OrderedDict(own_fields)

        new_class = super(DeclarativeFieldsMetaclass, mcs).__new__(mcs, name, bases, attrs)

        collected = OrderedDict()
        for klass in reversed(new_class.__mro__):
            if hasattr(klass, "declared_fields"):
                collected.update(klass.declared_fields)
            # A name defined directly on this class shadows the field.
            for attr in klass.__dict__:
                if attr in collected:
                    del collected[attr]

        new_class.base_fields = collected
        new_class.declared_fields = collected
        return new_class
Exemple #7
0
    def format(self, data, keys=None, group_by=None, domain=None):
        """Yield one formatted row per group, trimmed to a common width.

        Each key of ``data`` is a tuple. Its element at index 2 is replaced
        by the display name found by walking the domain's
        ``by_type_hierarchy``, using ``key[1]`` through ``key[-2]`` as the
        path. Rows are then processed as follows:
          * sorted by ``(key[0], resolved name)``,
          * passed through ``self._format.format_row``,
          * grouped by the first formatted column.

        Each yielded row is ``[group, total, *values]``. ``total`` comes from
        ``self.calculate_total_column``, and ``values`` is truncated to the
        length of the shortest group. Empty ``data`` yields nothing.

        ``keys`` and ``group_by`` are accepted for interface compatibility
        and are not used here.
        """
        value_chains = get_domain_configuration(domain).by_type_hierarchy

        def find_name(key, chain, depth):
            # Follow key[depth + 1] down the hierarchy. The name is taken at
            # depth len(key) - 3. Returns None when no element matches.
            for element in chain:
                if key[depth + 1] == element.val:
                    if depth == len(key) - 3:
                        return element.text
                    return find_name(key, element.next, depth + 1)
            return None

        renamed = OrderedDict()
        for key, row in data.items():
            new_key = list(key)
            new_key[2] = find_name(key, value_chains, 0)
            renamed[tuple(new_key)] = row

        ordered = sorted(renamed.items(),
                         key=lambda item: (item[0][0], item[0][2]))

        rows_dict = OrderedDict()
        for _, row in ordered:
            formatted_row = self._format.format_row(row)
            rows_dict.setdefault(formatted_row[0], []).append(formatted_row[1])

        if not rows_dict:
            # Nothing to report; min() below would raise on an empty sequence.
            return

        min_length = min(len(values) for values in rows_dict.values())
        for group, values in rows_dict.items():
            total_column = self.calculate_total_column(values)
            res = [group, total_column]
            res.extend(values[:min_length])
            yield res
Exemple #8
0
    def convert_events_to_dict(self, events):
        """
        Convert rows (e.g. sqlite3.Row objects) into an OrderedDict keyed by id.

        Each row is keyed by the value of its 'id' column and maps to a dict
        of its other columns. Rows without an 'id' column are skipped. The
        'services' string is split on ', ' into a list. Entries keep the
        order of ``events``, so the caller supplies rows already ordered by
        timestamp. Assuming columns (id, services, timestamp, duration,
        state), for example::

            events = [
                (7, 'storage, Cassandra', '2014-12-18 11:59:00.000', 120, 'plan'),
                (8, 'storage, Cassandra', '2014-12-19 09:07:00.000', 120, 'plan'),
            ]

        becomes::

            {7: {'services': ['storage', 'Cassandra'],
                 'timestamp': '2014-12-18 11:59:00.000',
                 'duration': 120,
                 'state': 'plan'},
             8: {'services': ['storage', 'Cassandra'],
                 'timestamp': '2014-12-19 09:07:00.000',
                 'duration': 120,
                 'state': 'plan'}}

        Returns None when ``events`` is empty.
        """
        if not events:
            return None

        converted = OrderedDict()
        for row in events:
            columns = row.keys()
            if 'id' not in columns:
                continue
            converted[row['id']] = {
                column: row[column] for column in columns if column != 'id'}

        for entry in converted.values():
            entry['services'] = entry['services'].split(', ')
        return converted
Exemple #9
0
def qs_sort_by_fields(queryset, field, order_by):
    """Recursively order ``queryset`` by several columns.

    queryset -- mapping of object id -> indexable record, where
                ``record[i]`` is the value of column ``i``
    field    -- index of the column sorted at this recursion level
    order_by -- one sort direction per column; a value whose ``int()`` is
                negative sorts that column in descending order

    Objects whose value is falsy have "no information". They come last for
    an ascending sort and first for a descending sort. Objects that share a
    value are ordered by the next column, until the last column of
    ``order_by`` has been used.

    Returns an OrderedDict of id -> record in the resulting order.
    """
    groups = {}
    missing = []
    for obj_id, record in queryset.items():
        value = record[field]
        if not value:
            missing.append(obj_id)
        else:
            # `in` / setdefault instead of Python-2-only dict.has_key.
            groups.setdefault(value, []).append(obj_id)

    if int(order_by[field]) < 0:
        ordered_groups = [('', missing)] + sorted(groups.items(), reverse=True)
    else:
        ordered_groups = sorted(groups.items()) + [('', missing)]

    is_last_field = field == len(order_by) - 1
    new_qs = OrderedDict()
    for _, ids in ordered_groups:
        # OrderedDict keeps the queryset order for ties on every Python version.
        subset = OrderedDict((obj_id, queryset[obj_id]) for obj_id in ids)
        if is_last_field or len(subset) <= 1:
            new_qs.update(subset)
        else:
            new_qs.update(qs_sort_by_fields(subset, field + 1, order_by))
    return new_qs
Exemple #10
0
    def create_post(self, path, **kw):
        """Write a new post file at ``path``.

        Recognised keyword arguments:
          * ``content``: the post body. A placeholder sentence is written
            when it is missing or empty.
          * ``onefile``: when true, the metadata is written at the top of the
            file inside a ``#+BEGIN_COMMENT`` / ``#+END_COMMENT`` block.
          * ``is_page``: accepted and ignored.

        All other keyword arguments are merged over
        ``self.default_metadata``. ``makedirs`` is called on the parent
        directory before the file is opened.
        """
        body = kw.pop('content', None)
        embed_metadata = kw.pop('onefile', False)
        kw.pop('is_page', False)

        metadata = OrderedDict(self.default_metadata)
        metadata.update(kw)
        makedirs(os.path.dirname(path))

        with codecs.open(path, "wb+", "utf8") as fd:
            if embed_metadata:
                fd.write("#+BEGIN_COMMENT\n")
                if write_metadata:
                    fd.write(write_metadata(metadata))
                else:
                    # Fallback: one reST-style ".. key: value" line per entry.
                    for key, value in metadata.items():
                        fd.write('.. {0}: {1}\n'.format(key, value))
                fd.write("#+END_COMMENT\n")
                fd.write("\n\n")

            fd.write(body if body else 'Write your post here.')
 def process_entries(bib_data):
     """Yield ``(key, fields)`` for every entry of ``bib_data``.

     ``fields`` is an OrderedDict built in three steps:
       1. ``'type'``: the entry's original type.
       2. Every rich field rendered as HTML, with the literal
          ``textsuperscript`` command (and its trailing space) removed.
       3. The items returned by ``process_person_roles(entry)``.
     """
     for key, entry in bib_data.items():
         fields = OrderedDict()
         fields['type'] = entry.original_type
         for name, value in entry.rich_fields.items():
             html = value.render_as("html")
             fields[name] = html.replace("\\textsuperscript ", "")
         fields.update(process_person_roles(entry))
         yield key, fields
def load_forecasts(data):
    """Yield one OrderedDict per forecast entry in ``data``.

    Each period must describe a single day. Each of its ``Rep`` entries
    carries an offset under the ``'$'`` key (presumably minutes from the
    start of the day, given the 3 * 60 slot length; confirm). Its other
    entries are raw values keyed by symbolic names, which
    ``load_field_mappings`` maps to field names.

    A yielded record contains:
      * ``valid_from`` / ``valid_to`` bounds for the slot,
      * ``supplier`` set to 'met_office',
      * ``pressure`` set to 0 unless a parsed field overrides it,
      * the parsed fields,
      * ``precipitation`` copied from ``precipitation_probability``.
    """
    field_names = load_field_mappings(data)
    interval = 3 * 60  # slot length, in the same unit as the '$' offset

    for period in get_period_list(data):
        validate_this_period_is_a_day(period)
        day_start = parse_utc_date(period['value'])
        for forecast in listify(period['Rep']):
            try:
                offset = int(forecast.pop('$'))
            except AttributeError:
                # Show the unexpected entry before re-raising.
                print(forecast)
                raise

            record = OrderedDict([
                ('valid_from', calculate_datetime(day_start, offset)),
                ('valid_to', calculate_datetime(day_start, offset + interval)),
                ('supplier', 'met_office'),
                ('pressure', 0),
            ])
            for symbolic_name, raw_value in forecast.items():
                record.update(parse_field(field_names[symbolic_name], raw_value))

            record['precipitation'] = record['precipitation_probability']
            yield record
Exemple #13
0
def top_diseases(das, request):
    """Return the chronic conditions in the report, most affected first.

    Requests the DAS ``chronicConditions`` report for the reporting and
    comparison windows given in ``request``. The request is restricted to a
    cohort when ``get_cohort`` returns one.

    Returns an OrderedDict mapping each condition's ``description`` to its
    report entry, sorted by the entry's ``withCondition`` value in
    descending order. The ``memberCount`` and ``memberMonths`` summary keys
    are excluded.
    """
    cohort = get_cohort(das, request)

    p = {
        "service": "report",
        "report": "chronicConditions",
        "reportingBasis": "ServiceDate",
        "reportingFrom": request["reportingFrom"],
        "reportingTo": request["reportingTo"],
        "comparisonFrom": request["comparisonFrom"],
        "comparisonTo": request["comparisonTo"],
        "order": "Admits:desc",
    }
    if cohort is not None:
        p["cohortId"] = cohort

    r = das.response(p)
    report = r.data["reporting"]["Default"]
    summary_keys = ("memberCount", "memberMonths")

    # Drop the summary keys before sorting. They were always discarded
    # afterwards, and their raw values (presumably scalar totals) only
    # polluted the comparison. The stable sort keeps the resulting order
    # unchanged.
    condition_keys = [key for key in report if key not in summary_keys]
    condition_keys.sort(key=lambda key: report[key]["withCondition"], reverse=True)

    conditions = OrderedDict()
    for key in condition_keys:
        entry = report[key]
        conditions[entry["description"]] = entry
    return conditions
Exemple #14
0
    def replace_params_with_objects(self, target_node, inline_func, call_object):
        """
        Substitute the inlined function's parameters inside ``target_node``.

        ``target_node`` is an AST node taken from the function being inlined,
        for example its return value. For each parameter of ``inline_func``,
        the node that should replace it is chosen in this order of
        precedence (highest first):

        * a keyword argument of the same name in ``call_object``,
        * the positional argument at the same position in ``call_object``,
        * the parameter's default value, if it has one,
        * otherwise None.

        A ``self`` parameter is mapped to the object the method is called
        on (``call_object.func.value``) and is placed first. The mapping is
        applied with ParamReplacer, and the visited node is returned.
        """
        args = inline_func.args
        params = [arg for arg in args.args if arg.id != "self"]
        has_self = len(params) != len(args.args)
        # Defaults belong to the trailing parameters. The offset must be
        # counted over the same list that is enumerated below (``self``
        # excluded). Counting ``self`` shifted every default by one position
        # for methods, so those defaults were never applied.
        default_offset = len(params) - len(args.defaults)

        arg_mapping = OrderedDict()
        for idx, arg in enumerate(params):
            if idx < len(call_object.args):
                arg_mapping[arg.id] = call_object.args[idx]
            elif idx >= default_offset:
                arg_mapping[arg.id] = args.defaults[idx - default_offset]
            else:
                arg_mapping[arg.id] = None

        for keyword in call_object.keywords:
            arg_mapping[keyword.arg] = keyword.value

        if has_self:
            # Bind "self" to the instance the method is called on, listed first.
            new_mapping = OrderedDict([("self", call_object.func.value)])
            new_mapping.update(arg_mapping)
            arg_mapping = new_mapping

        return ParamReplacer(arg_mapping).visit(target_node)
Exemple #15
0
def subdivision_dict(idCountry):
    """Return the subdivisions of a country, keyed by subdivision id.

    The membership service is queried for the first page. Further pages
    (starting at ``page=1``) are requested until ``_metadata.totalRecords``
    entries have been collected. Paging stops early if a page fails or adds
    no new records.

    :param idCountry: retrieve list of subdivisions for specified country
    :type idCountry: int
    :returns: ordered mapping of ``idSubdivision`` -> subdivision record, or
              an empty dict when the first request fails or returns nothing
    :rtype: dict
    """
    response = api.membership.subdivision.list(token=admin_session.get_token(),
                                               idCountry=idCountry)
    if response.status_code != 200:
        return dict()
    payload = response.json()
    if not payload.get('content'):
        return dict()

    total_records = payload['_metadata']['totalRecords']
    subdivisions = OrderedDict(
        (s['idSubdivision'], s) for s in payload['content'])

    page = 1
    while len(subdivisions) < total_records:
        response = api.membership.subdivision.list(
            token=admin_session.get_token(),
            idCountry=idCountry,
            params={'page': page})
        if response.status_code != 200:
            break
        before = len(subdivisions)
        subdivisions.update((s['idSubdivision'], s)
                            for s in response.json().get('content') or ())
        if len(subdivisions) == before:
            # The page was empty or held only known ids. Stop, rather than
            # request pages forever when totalRecords cannot be reached.
            break
        page += 1
    return subdivisions
Exemple #16
0
 def _as_dct_(self):
     """Return a copy of ``self.meta`` with an added ``'derived'`` entry.

     Without a target variable, ``derived`` holds 'Start Date' / 'End Date'
     strings. They come from the first and last values of the dataset's
     ``time`` variable, converted with ``nc.num2date``. ``derived`` is None
     when that conversion fails.

     With a target variable, ``derived`` is an OrderedDict built from the
     ``key = value`` lines of the temporal, spatial and level reports.
     """
     ret = self.meta.copy()
     if self.variable is None:
         # Without a target variable, attempt to set start and end dates.
         ds = nc.Dataset(self.uri, 'r')
         try:
             time = ds.variables['time']
             time_bounds = nc.num2date([time[0], time[-1]], time.units,
                                       calendar=time.calendar)
             derived = {'Start Date': str(time_bounds[0]),
                        'End Date': str(time_bounds[1])}
         except Exception:
             # Deliberately best-effort, but no longer a bare ``except:``, so
             # KeyboardInterrupt and SystemExit still propagate.
             warn('Time variable not found or improperly attributed. Setting "derived" key to None.')
             derived = None
         finally:
             ds.close()
     else:
         # With a target variable the derived values come from the reports.
         derived = OrderedDict()
         to_add = (self.get_temporal_report() + self.get_spatial_report()
                   + self.get_level_report())
         for row in to_add:
             try:
                 key, value = re.split(' = ', row, maxsplit=1)
             except ValueError:
                 # Only this known message may lack a ' = ' separator.
                 if row == 'No level dimension found.':
                     continue
                 raise
             derived[key.strip()] = value
     ret['derived'] = derived
     return ret
Exemple #17
0
 def handleSubmit(self, action):
     """Process a submitted form.

     Steps, in order:
       1. If data extraction reports errors, set the form-errors status and
          stop.
       2. If ``processActions`` reports errors, return
          ``setErrorsMessage(errors)``.
       3. If the context defines ``thanksPageOverride``, either redirect to
          it (``'redirect_to'``) or write the rendered traversal target into
          the response (``'traverse_to'``).

     Without an override nothing more happens here; the thanks page is
     handled in ``__call__``.
     """
     unsorted_data, errors = self.extractData()
     if errors:
         self.status = self.formErrorsMessage
         return
     unsorted_data = self.updateServerSideData(unsorted_data)
     errors = self.processActions(unsorted_data)
     if errors:
         return self.setErrorsMessage(errors)

     # Submitted values with keys in schema field order (not used further
     # in this method).
     data = OrderedDict(
         entry for entry in getFieldsInOrder(self.schema)
         if entry[0] in unsorted_data)
     data.update(unsorted_data)

     override = self.context.thanksPageOverride
     if not override:
         # We come back to the form itself; __call__ handles the thanks page.
         return

     override_action = self.context.thanksPageOverrideAction
     target = get_expression(self.context, override)
     if override_action == 'redirect_to':
         self.request.response.redirect(target)
     elif override_action == 'traverse_to':
         view = self.context.restrictedTraverse(target.encode('utf-8'))
         rendered = mapply(view, self.request.args, self.request).encode('utf-8')
         self.request.response.write(rendered)
    def construct_yaml_map(self, node):
        """Construct a YAML mapping as an OrderedDict that keeps key order.

        The empty OrderedDict is yielded first and filled in afterwards.
        This is PyYAML's two-step construction for generator-based
        constructors.

        Raises yaml.constructor.ConstructorError if ``node`` is not a
        mapping node or if one of its keys is unhashable.
        """
        data = OrderedDict()
        yield data

        # The previous unused ``self.construct_mapping(node)`` call built the
        # whole mapping a second time, before this type check could run.
        if not isinstance(node, yaml.MappingNode):
            raise yaml.constructor.ConstructorError(
                None, None,
                'expected a mapping node, but found %s' % node.id,
                node.start_mark)
        self.flatten_mapping(node)

        mapping = OrderedDict()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=False)
            try:
                hash(key)
            except TypeError as exc:
                raise yaml.constructor.ConstructorError(
                    'while constructing a mapping', node.start_mark,
                    'found unacceptable key (%s)' % exc, key_node.start_mark)
            mapping[key] = self.construct_object(value_node, deep=False)
        data.update(mapping)
Exemple #19
0
    def construct_yaml_map(self, node):
        """Yield an empty OrderedDict, then fill it from ``node``.

        The mapping comes from ``self.construct_mapping(node)``. A result of
        None leaves the yielded OrderedDict empty.
        """
        result = OrderedDict()
        yield result

        mapping = self.construct_mapping(node)
        if mapping is None:
            return
        result.update(mapping)
Exemple #20
0
  def on_epoch_end(self, epoch, logs=None):
    """Append one CSV row for ``epoch`` built from ``logs``.

    On the first call, the column order is fixed to the sorted keys of
    ``logs``, and the writer is created with ``self.sep`` as its delimiter.
    A header is written when ``self.append_header`` is true.

    Iterable values other than strings and 0-d numpy arrays are written as
    a quoted, bracketed list. Once ``self.model.stop_training`` is set,
    missing keys are written as 'NA'.
    """
    logs = logs or {}

    def handle_value(k):
      if isinstance(k, six.string_types):
        return k
      zero_dim = isinstance(k, np.ndarray) and k.ndim == 0
      if isinstance(k, Iterable) and not zero_dim:
        return '"[%s]"' % (', '.join(map(str, k)))
      return k

    if self.keys is None:
      self.keys = sorted(logs.keys())

    if self.model.stop_training:
      # Fill gaps with 'NA' so CSV parsers do not fail on this last epoch.
      logs = {k: (logs[k] if k in logs else 'NA') for k in self.keys}

    if not self.writer:

      class CustomDialect(csv.excel):
        delimiter = self.sep

      self.writer = csv.DictWriter(
          self.csv_file,
          fieldnames=['epoch'] + self.keys,
          dialect=CustomDialect)
      if self.append_header:
        self.writer.writeheader()

    row = OrderedDict([('epoch', epoch)])
    for key in self.keys:
      row[key] = handle_value(logs[key])
    self.writer.writerow(row)
    self.csv_file.flush()
Exemple #21
0
def main():
    """Print a per-asset report followed by the aggregated statistics.

    Steps:
      1. Clear every memoization cache.
      2. Ask each asset section for its values, recording statistics in one
         shared ``StatsDict``.
      3. Print every statistic and then the grand total.

    Returns an OrderedDict of all values, keyed ``"<section> - <key>"``.
    """
    for cache in all_memoize_caches:
        cache.clear()

    all_values = OrderedDict()
    all_stats = stats.StatsDict()

    for section in get_asset_sections():
        print("{}:".format(section))
        values = get_asset(section).get_values(all_stats)
        for key, value in values.items():
            all_values["{} - {}".format(section, key)] = value
        print()

    for stat in all_stats.get_all_stats_ordered():
        stat.print_stat()
        print()

    print("Total: {:10,.2f}".format(all_stats.get_total()))
    print()
    return all_values
    def format(self, record):
        """Serialize a log record to a JSON string.

        If ``record.msg`` is a dict, its items are merged into the output and
        ``record.message`` is set to None. Otherwise ``record.message`` is
        the rendered message.

        The JSON object contains, in order:
          * every field named in ``self._required_fields``,
          * ``exc`` with the formatted traceback when the record carries
            exception info,
          * the dict message's items,
          * whatever ``merge_record_extra`` adds (skipping
            ``self._skip_fields``).
        """
        if isinstance(record.msg, dict):
            extras = record.msg
            record.message = None
        else:
            extras = {}
            record.message = record.getMessage()

        # asctime is only computed when it is one of the output fields.
        if "asctime" in self._required_fields:
            record.asctime = self.formatTime(record, self.datefmt)

        try:
            log_record = OrderedDict()
        except NameError:
            # Plain-dict fallback if the name OrderedDict is unavailable.
            log_record = {}

        log_record.update(
            (field, record.__dict__[field]) for field in self._required_fields)

        if record.exc_info:
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
            log_record['exc'] = record.exc_text

        log_record.update(extras)
        merge_record_extra(record, log_record, reserved=self._skip_fields)

        return json.dumps(log_record,
                          default=self.json_default,
                          cls=self.json_encoder)
Exemple #23
0
    def parse_xmlelement(self, xmlelement, schema, allow_none=True,
                         context=None):
        """Consume matching child elements and attributes, then build an instance.

        Returns None in two cases:
          * the type is an empty complexType (no attributes and no elements
            declared, e.g. ``<xsd:complexType name="x"/>``);
          * ``allow_none`` is true and the xml element has neither children
            nor attributes.

        Otherwise the children are parsed by each nested indicator (all,
        choice, group, sequence) and each declared attribute is parsed from
        the element's attributes. ``self`` is then called with the collected
        values as keyword arguments.
        """
        if not (self.attributes or self.elements):
            return None

        children = xmlelement.getchildren()
        remaining_attrs = copy.copy(xmlelement.attrib)
        if allow_none and not children and not remaining_attrs:
            return None

        init_kwargs = OrderedDict()

        # Child elements are always consumed through indicator elements.
        for name, indicator in self.elements_nested:
            parsed = indicator.parse_xmlelements(
                children, schema, name, context=context)
            if parsed:
                init_kwargs.update(parsed)

        for name, attribute in self.attributes:
            if not attribute.name:
                # Unnamed attribute (presumably an attribute wildcard) gets
                # every attribute not consumed so far.
                init_kwargs[name] = attribute.parse(remaining_attrs)
            elif attribute.qname.text in remaining_attrs:
                init_kwargs[name] = attribute.parse(
                    remaining_attrs.pop(attribute.qname.text))

        return self(**init_kwargs)
Exemple #24
0
def fit_results_to_dict(fit_results, min_bound=None, max_bound=None):
    '''Create a JSON-compatible dict from a FitResults object

    Parameters:
        fit_results (FitResults): object exposing ``type`` (the fit type)
            and ``params`` (the fitted parameters)
        min_bound: optional value stored under 'min' when the fit does not
            already provide a min
        max_bound: optional value stored under 'max' when the fit does not
            already provide a max

    Returns:
        OrderedDict starting with 'type' (the full distribution name),
        followed by the distribution parameters and then any bounds:
        uniform -> min, max (= params[0] + params[1]);
        norm -> mu, sigma;
        expon -> lambda (= 1 / params[1])

    Note:
        Supported fit types: 'norm', 'expon', 'uniform'. Any other type
        raises KeyError.
    '''
    names = {'norm': 'normal', 'expon': 'exponential', 'uniform': 'uniform'}
    converters = {
        'uniform': lambda p: [('min', p[0]), ('max', p[0] + p[1])],
        'norm': lambda p: [('mu', p[0]), ('sigma', p[1])],
        'expon': lambda p: [('lambda', 1.0 / p[1])],
    }

    result = OrderedDict([('type', names[fit_results.type])])
    result.update(converters[fit_results.type](fit_results.params))

    for key, bound in (('min', min_bound), ('max', max_bound)):
        if bound is not None and key not in result:
            result[key] = bound
    return result
def rougeS(refSent, sents):
    """Score each sentence by skip-bigram overlap with a reference (ROUGE-S).

    A skip-bigram is an ordered pair of tokens that need not be adjacent.
    A sentence's score is the number of its skip-bigrams that also occur in
    ``refSent``, divided by the number of token pairs in ``refSent``
    (n * (n - 1) / 2). Repeated skip-bigrams in a sentence are each
    counted.

    Returns a list of ``(score, sentence)`` tuples in the order of
    ``sents``. A reference with fewer than two tokens has no skip-bigrams,
    so every sentence scores 0.0 (previously this raised
    ZeroDivisionError).
    """
    ref_len = len(refSent)
    num_pairs = ref_len * (ref_len - 1) // 2
    # Ordered (earlier, later) token pairs of the reference. Tuples replace
    # the former "a_b" string keys, which made tokens containing '_' collide:
    # ['a_b', 'c'] and ['a', 'b_c'] both produced "a_b_c". A set also
    # replaces the Python-2-only dict.has_key lookup.
    ref_pairs = {(first, second)
                 for i, first in enumerate(refSent)
                 for second in refSent[i + 1:]}

    skip2 = []
    for sent in sents:
        if num_pairs == 0:
            skip2.append((0.0, sent))
            continue
        count = sum(1
                    for i, first in enumerate(sent)
                    for second in sent[i + 1:]
                    if (first, second) in ref_pairs)
        skip2.append((float(count) / num_pairs, sent))
    return skip2
Exemple #26
0
    def _get_schema_attributes(metacls, name, bases, dct):
        """Collect the schema fields for a class being created by ``metacls``.

        Fields are gathered in two passes:
          1. From every base that is an instance of ``metacls``, in
             ``bases`` order.
          2. From the ``Field`` instances in ``dct``, sorted by ``_index``.

        A field name that was already collected triggers
        ``metacls._field_dupe_warning``, and the later definition wins.

        Returns a dict with ``_fields`` (an OrderedDict of name -> Field) and
        ``_schema_name``.
        """
        fields = OrderedDict()
        for base in bases:
            if not isinstance(base, metacls):
                continue

            overlap = set(fields) & set(base._fields)
            if overlap:
                metacls._field_dupe_warning(name, overlap)
            fields.update(base._fields)

        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        new_fields = [(field_name, field_def)
                      for field_name, field_def in dct.items()
                      if isinstance(field_def, Field)]
        new_fields.sort(key=lambda fd: fd[1]._index)

        for field_name, field_def in new_fields:
            if field_name in fields:
                metacls._field_dupe_warning(name, (field_name,))
            fields[field_name] = field_def

        return {
            "_fields": fields,
            "_schema_name": name,
        }
Exemple #27
0
    def __new__(mcs, name, bases, attrs):
        """Create the class and collect its declared ``Transition`` attributes.

        ``Transition`` instances are moved out of ``attrs`` into
        ``declared_transitions``, ordered by ``creation_counter``. After the
        class is created, the transitions of the whole MRO are merged,
        least-derived first. An attribute set to None in any class of the
        MRO removes the transition of that name.
        """
        own = [(key, value) for key, value in attrs.items()
               if isinstance(value, Transition)]
        own.sort(key=lambda pair: pair[1].creation_counter)
        for key, _ in own:
            attrs.pop(key)
        attrs['declared_transitions'] = OrderedDict(own)

        new_class = super(DeclarativeTransitionsMetaclass, mcs).__new__(
            mcs, name, bases, attrs)

        merged = OrderedDict()
        for klass in reversed(new_class.__mro__):
            if hasattr(klass, 'declared_transitions'):
                merged.update(klass.declared_transitions)
            # Setting the name to None in a class hides the transition.
            for attr_name, value in klass.__dict__.items():
                if value is None and attr_name in merged:
                    del merged[attr_name]

        new_class.declared_transitions = merged
        return new_class
Exemple #28
0
def get_available_skins(selected=None):
    """Return an OrderedDict of skin name -> directory.

    Each directory holds the skin's "templates" and "media"
    subdirectories. Custom skins are read from ``ASKBOT_EXTRA_SKINS_DIR``
    when that setting is set.

    selected -- name of the preferred skin:
                * None (or another falsy value) returns every skin;
                * 'default' returns only the default skin;
                * any other name returns that skin followed by the default.

    The built-in 'default' skin is always included and always last, so a
    selected custom skin is guaranteed to be the first item.

    Raises ValueError when a custom skin is named 'default', or when
    ``selected`` is set but not found.
    """
    found = OrderedDict()
    custom_dir = getattr(django_settings, 'ASKBOT_EXTRA_SKINS_DIR', None)
    if custom_dir:
        found.update(get_skins_from_dir(custom_dir))

    if 'default' in found:
        raise ValueError('"default" is not an acceptable name for a custom skin')

    if selected in found:
        skins = OrderedDict([(selected, found[selected])])
    elif selected == 'default':
        skins = OrderedDict()
    elif selected:
        raise ValueError(
            'skin ' + str(selected)
            + ' not found, please check ASKBOT_EXTRA_SKINS_DIR setting '
            + 'or in the corresponding directory'
        )
    else:
        skins = found

    # The built-in skin always goes last.
    skins['default'] = askbot.get_install_directory()
    return skins
Exemple #29
0
class Resource(object):
    """
    Scanned details of a single file or directory (its relative path and
    collected infos), stored through a scan cache object.
    """

    def __init__(self, scan_cache_class, abs_path, base_is_dir, len_base_path):
        # The cache class is instantiated here, one cache object per resource.
        self.scan_cache_class = scan_cache_class()
        self.is_cached = False
        self.abs_path = abs_path
        self.base_is_dir = base_is_dir
        # Path kept relative to the original base path (always Unicode).
        self.rel_path = get_relative_path(
            as_posixpath(abs_path), len_base_path, base_is_dir)
        self.infos = OrderedDict([('path', self.rel_path)])

    def put_info(self, infos):
        """
        Merge ``infos`` into this resource's infos and store them in the
        cache. ``is_cached`` takes whatever the cache's ``put_info`` returns,
        which the original notes as True when the path was already cached.
        """
        self.infos.update(infos)
        self.is_cached = self.scan_cache_class.put_info(self.rel_path, self.infos)

    def get_info(self):
        """
        Return the infos stored in the cache for this resource's path.
        """
        return self.scan_cache_class.get_info(self.rel_path)
Exemple #30
0
def get_keyboard_codes():
    """
    Return a mapping of keycode integer values to their curses key-name.

    :rtype: dict

    Returns a dictionary of (code, name) pairs for curses keyboard constant
    values and their mnemonic names. For example, key ``260`` maps to
    ``u'KEY_LEFT'``. The pairs are derived from the curses module's
    attributes of the same name, with the following exceptions:

    * ``KEY_DELETE`` in place of ``KEY_DC``
    * ``KEY_INSERT`` in place of ``KEY_IC``
    * ``KEY_PGUP`` in place of ``KEY_PPAGE``
    * ``KEY_PGDOWN`` in place of ``KEY_NPAGE``
    * ``KEY_ESCAPE`` in place of ``KEY_EXIT``
    * ``KEY_SUP`` in place of ``KEY_SR``
    * ``KEY_SDOWN`` in place of ``KEY_SF``

    This function is the inverse of :func:`get_curses_keycodes`. Because of
    the override "mixins" listed above, the delete key's keycode maps to
    the ``KEY_DELETE`` mnemonic. The phrase ``KEY_DC`` therefore never
    appears for anyone using the return value to look up a key-name by
    keycode.
    """
    names_to_codes = OrderedDict(get_curses_keycodes())
    names_to_codes.update(CURSES_KEYCODE_OVERRIDE_MIXIN)

    # Invert name -> code into code -> name. Walking in insertion order lets
    # the last-inserted name win for a shared code ('KEY_DELETE' over
    # 'KEY_DC').
    return {code: name for name, code in names_to_codes.items()}
Exemple #31
0
def get_comment_authors(page_url,
                        num_authors=10,
                        depth=_utils.SEARCH_DEPTH,
                        post_limit=_utils.POST_LIMIT,
                        authors_ignore=None,
                        driver=None,
                        cookies=None,
                        silent=False):
    """Collect the authors of comments found on the posts of ``page_url``.

    Posts (``div[aria-labelledby]`` elements) are visited one at a time,
    scrolling the page to load more, up to ``post_limit`` posts.  Within each
    post, links whose ``href`` starts with ``ROOT_URL`` and contains
    ``comment_id=`` are trimmed (after the first ``&`` for
    ``/profile.php?id=`` links, otherwise after the first ``?``) to get the
    author URL.  "View more comments/replies" style controls are clicked and
    the post is re-scanned until no such control is clicked in a pass.

    Args:
        page_url: URL of the page to scan.
        num_authors: stop once this many authors have been collected.
        depth: when greater than 1, each newly found author's page is scanned
            recursively with ``depth - 1`` and the authors found there are
            collected instead of the author itself.
        post_limit: maximum number of posts to visit.
        authors_ignore: mapping of URLs to skip.  It is MUTATED: ``page_url``
            and every author found are added to it, which is how recursive
            calls share already-visited URLs.  A new one is created if None.
        driver: an existing Selenium driver; ``page_url`` is opened in a new
            window of it.  When None, one is created with ``init(cookies)``.
        cookies: passed to ``init`` when no driver is given.
        silent: suppress progress output (timeout warnings are still printed).

    Returns:
        A list of at most ``num_authors`` ``(author_url, author_name)`` tuples
        in the order they were found.
    """
    if not silent:
        print('START', page_url)
    if driver:
        _utils.selenium_open_new_window(driver, page_url)
    else:
        driver = init(cookies)
        driver.get(page_url)
    authors = OrderedDict()
    if authors_ignore is None:
        authors_ignore = OrderedDict()
    authors_ignore[page_url] = 1

    class PageEndException(Exception):
        pass

    class AuthorsEnoughException(Exception):
        pass

    try:
        labels = set()
        post, prev_page_len = None, -1
        prev_post = None
        for post_no in range(1, post_limit + 1):
            # Find the next post whose label has not been seen yet, scrolling
            # to the bottom to load more; give up (PageEndException) after
            # repeated scrolls leave the page length unchanged.
            tries = 0
            while True:
                if not silent:
                    print('post #{}...'.format(post_no))
                post = None
                posts = driver.find_elements_by_css_selector(
                    'div[aria-labelledby]')
                if not silent:
                    print(len(posts))
                for post_ in posts:
                    label = post_.get_attribute('aria-labelledby')
                    if label not in labels:
                        labels.add(label)
                        post = prev_post = post_
                        tries = 0
                        break
                else:
                    if not silent:
                        print('post #{} is not found'.format(post_no))
                    page_len = _utils.selenium_scroll_to_bottom(driver)
                    if page_len == prev_page_len:
                        if tries >= 2:
                            raise PageEndException()
                        if prev_post:
                            _utils.selenium_scroll_into_view(driver, prev_post)
                        tries += 1
                    else:
                        tries = 0
                    prev_page_len = page_len
                if post:
                    break

            if post:
                comment_elems, author_elems = set(), set()
                pass_no, need_more = 0, True
                # Re-scan the post while the previous pass clicked an
                # expanding control, since that loads new elements.
                while need_more:
                    need_more = False
                    pass_no += 1
                    if not silent:
                        print('post {}, pass {}'.format(post_no, pass_no))
                    for elem in (x for x in post.find_elements_by_tag_name('a')
                                 if x not in author_elems):
                        author_elems.add(elem)
                        author = elem.get_attribute('href')
                        if (author and author.startswith(ROOT_URL)
                                and 'comment_id=' in author):
                            if author.startswith(ROOT_URL +
                                                 '/profile.php?id='):
                                pos = author.find('&')
                            else:
                                pos = author.find('?')
                            if pos > 0:
                                author = author[:pos]
                            if author not in authors_ignore and not (
                                    author.endswith('.php')
                                    or author.endswith('/')):
                                # Only URLs with no further '/' after ROOT_URL.
                                if author[len(ROOT_URL) + 1:].find('/') < 0:
                                    try:
                                        author_name = \
                                            elem.find_element_by_tag_name(
                                                'span'
                                            ).text
                                        if not silent:
                                            print(author_name, author)
                                        authors_ignore[author] = 1
                                        if depth > 1:
                                            authors.update(
                                                get_comment_authors(
                                                    author,
                                                    num_authors=num_authors
                                                    - len(authors),
                                                    depth=depth - 1,
                                                    post_limit=post_limit,
                                                    authors_ignore=authors_ignore,
                                                    driver=driver,
                                                    silent=silent
                                                ))
                                        else:
                                            authors[author] = author_name
                                        if len(authors) >= num_authors:
                                            raise AuthorsEnoughException()
                                    except NoSuchElementException:
                                        pass

                    for elem in (
                            x for x in post.find_elements_by_tag_name('span')
                            if x not in comment_elems):
                        comment_elems.add(elem)
                        try:
                            text = elem.text
                            # Note: 'replied' already contains 'repl'.
                            is_expander = (
                                text.startswith('View') and
                                ('more comment' in text or 'more repl' in text)
                            ) or 'replied' in text
                            if is_expander:
                                if not silent:
                                    print('    [', text, ']')
                                need_more = True
                                action = webdriver.common.action_chains \
                                                         .ActionChains(driver)
                                action.move_to_element_with_offset(elem, 5, 5)
                                action.perform()
                                click_tries = 0
                                while True:
                                    try:
                                        elem.click()
                                        WebDriverWait(driver, 10) \
                                            .until(EC.staleness_of(elem))
                                        if click_tries:
                                            # End the '\r' warning line.
                                            print()
                                        break
                                    except TimeoutException:
                                        print(
                                            '\rWARNING: Comments loading '
                                            'timeout. Retrying...',
                                            end='')
                                        if click_tries >= 2:
                                            print(
                                                '\rWARNING: Comments loading '
                                                'timeout. Skipped    ')
                                            break
                                        click_tries += 1
                        except Exception:
                            # Best effort: the element may have gone stale or
                            # be unclickable; skip it.  (Narrowed from a bare
                            # `except:` so KeyboardInterrupt/SystemExit still
                            # propagate.)
                            pass

    except (PageEndException, AuthorsEnoughException):
        pass
    finally:
        # Always close the window, even if an unexpected error escapes.
        _utils.selenium_close_window(driver)

    if not silent:
        print(authors)
        print(len(authors))
    return list(authors.items())[:num_authors]
class ResourceObj:
    """A Redfish JSON payload resolved against its schema.

    Construction determines the payload's type and ``@odata.context``, loads
    the matching schema object, and collects the schema properties,
    additional (pattern / annotation) properties, unknown properties and
    outgoing links of the payload.
    """

    def __init__(self, name: str, uri: str, jsondata: dict, typename: str, context: str, parent=None, isComplex=False, forceType=False):
        """
        Args:
            name: name of this resource.
            uri: URI of the payload (used in log messages and URI checks).
            jsondata: the decoded JSON payload; must be a ``dict``.
            typename: type used when the payload has no ``@odata.type``
                (or always, when ``forceType`` is true).
            context: ``@odata.context`` to use; when None it is read from the
                payload or built with ``createContext``.
            parent: parent ``ResourceObj``; its ``typename`` is used to detect
                Registry resources and is passed to ``getHighestType``.
            isComplex: payload is a complex type, so a missing ``@odata.id`` /
                ``@odata.context`` is not reported as an error.
            forceType: use ``typename`` even when the payload has
                ``@odata.type``.

        Raises:
            ValueError: if ``jsondata`` is not a dict, no type can be
                determined, or no schema is found for the type.
        """
        self.initiated = False
        self.parent = parent
        self.uri, self.name = uri, name
        self.rtime = 0
        self.status = -1
        self.isRegistry = False
        self.errorIndex = {}

        oem = config.get('oemcheck', True)

        # Check if this is a Registry resource
        parent_type = parent.typename if parent is not None else None
        if parent_type is not None and getType(parent_type) == 'MessageRegistryFile':
            traverseLogger.debug('{} is a Registry resource'.format(self.uri))
            self.isRegistry = True
            self.context = None
            context = None

        # Check if we provide a valid json.  Done before dumping the payload
        # to the debug log so a non-dict payload gets the intended ValueError.
        self.jsondata = jsondata
        if not isinstance(self.jsondata, dict):
            traverseLogger.error("Resource no longer a dictionary...")
            raise ValueError('This Resource is no longer a Dictionary')

        traverseLogger.debug("payload: {}".format(json.dumps(self.jsondata, indent=4, sort_keys=True)))

        # Check for @odata.id (todo: regex)
        odata_id = self.jsondata.get('@odata.id')
        if odata_id is None and not isComplex:
            if self.isRegistry:
                traverseLogger.debug('{}: @odata.id missing, but not required for Registry resource'
                                     .format(self.uri))
            else:
                traverseLogger.error('{}: Json does not contain @odata.id'.format(self.uri))

        # Get our real type (check for version).
        # NOTE: the `is` / `is not` comparisons against `typename` in this
        # method appear to be deliberate identity checks: they test whether
        # `acquiredtype` is the very object passed in as `typename` (the
        # fallback), not whether two strings are equal.  Switching them to
        # ==/!= would change which payloads get their type upgraded below.
        acquiredtype = typename if forceType else jsondata.get('@odata.type', typename)
        if acquiredtype is None:
            traverseLogger.error(
                '{}:  Json does not contain @odata.type or NavType'.format(uri))
            raise ValueError('{}: Json does not contain @odata.type or NavType'.format(uri))
        if acquiredtype is not typename and isComplex:
            context = None

        if typename is not None:
            if not oem and 'OemObject' in typename:
                acquiredtype = typename

        # Record the namespaces referenced by this payload, unless OEM
        # checking is disabled and this is an OEM object.
        if currentService and (oem or 'OemObject' not in acquiredtype):
            if jsondata.get('@odata.type') is not None:
                currentService.metadata.add_service_namespace(getNamespace(jsondata.get('@odata.type')))
            if jsondata.get('@odata.context') is not None:
                # add the namespace to the set of namespaces referenced by this service
                ns = getNamespace(jsondata.get('@odata.context').split('#')[-1])
                if '/' not in ns and not ns.endswith('$entity'):
                    currentService.metadata.add_service_namespace(ns)

        # Provide a context for this (todo: regex)
        if context is None:
            context = self.jsondata.get('@odata.context')
            if context is None:
                context = createContext(acquiredtype)
                if self.isRegistry:
                    # If this is a Registry resource, @odata.context is not required; do our best to construct one
                    traverseLogger.debug('{}: @odata.context missing from Registry resource; constructed context {}'
                                         .format(acquiredtype, context))
                elif not isComplex:
                    traverseLogger.debug('{}:  Json does not contain @odata.context'.format(uri))

        self.context = context

        # Get Schema object
        self.schemaObj = rfSchema.getSchemaObject(acquiredtype, self.context)

        if self.schemaObj is None:
            traverseLogger.error("ResourceObject creation: No schema XML for {} {} {}".format(typename, acquiredtype, self.context))
            raise ValueError("No schema XML for {} {} {}".format(typename, acquiredtype, self.context))

        # Still using the fallback `typename` object (see NOTE above): ask the
        # schema for the highest version of the type.
        if acquiredtype is typename and not forceType:
            acquiredtype = self.schemaObj.getHighestType(typename, parent_type)
            if not isComplex:
                traverseLogger.warning(
                    'No @odata.type present, assuming highest type {} {}'.format(typename, acquiredtype))

        # Check if we provide a valid type (todo: regex)
        self.typename = acquiredtype
        typename = self.typename

        self.initiated = True

        # get our metadata
        metadata = currentService.metadata if currentService else None

        self.typeobj = rfSchema.getTypeObject(typename, self.schemaObj)

        self.propertyList = self.typeobj.getProperties(self.jsondata, topVersion=getNamespace(typename))
        # Payload names already accounted for; a set gives O(1) membership
        # tests and is extended with the additional properties further down.
        known_names = {prop.payloadName for prop in self.propertyList}

        # get additional
        self.additionalList = []
        propTypeObj = self.typeobj
        if propTypeObj.propPattern is not None and len(propTypeObj.propPattern) > 0:
            prop_pattern = propTypeObj.propPattern.get('Pattern', '.*')
            prop_type = propTypeObj.propPattern.get('Type', 'Resource.OemObject')

            regex = re.compile(prop_pattern)
            for key in [k for k in self.jsondata if k not in known_names and regex.fullmatch(k)]:
                val = self.jsondata.get(key)
                value_obj = rfSchema.PropItem(propTypeObj.schemaObj, propTypeObj.fulltype, key, val, customType=prop_type)
                self.additionalList.append(value_obj)

        if config['uricheck'] and self.typeobj.expectedURI is not None:
            my_id = self.jsondata.get('Id')
            self.errorIndex['bad_uri_schema_uri'] = not self.typeobj.compareURI(uri, my_id)
            self.errorIndex['bad_uri_schema_odata'] = not self.typeobj.compareURI(odata_id, my_id)

            if self.errorIndex['bad_uri_schema_uri']:
                traverseLogger.error('{}: URI not in Redfish.Uris: {}'.format(uri, self.typename))
                if my_id != uri.rsplit('/', 1)[-1]:
                    traverseLogger.error('Id {} in payload doesn\'t seem to match URI'.format(my_id))
            else:
                traverseLogger.debug('{} in Redfish.Uris: {}'.format(uri, self.typename))

            if self.errorIndex['bad_uri_schema_odata']:
                traverseLogger.error('{}: odata_id not in Redfish.Uris: {}'.format(odata_id, self.typename))
                # NOTE(review): this compares against `uri`, not `odata_id`;
                # confirm whether that is intended.
                if my_id != uri.rsplit('/', 1)[-1]:
                    traverseLogger.error('Id {} in payload doesn\'t seem to match URI'.format(my_id))
            else:
                traverseLogger.debug('{} in Redfish.Uris: {}'.format(odata_id, self.typename))

        # get annotation
        successService, annotationProps = getAnnotations(metadata, self.jsondata)
        if successService:
            self.additionalList.extend(annotationProps)

        # list illegitimate properties together
        known_names.update(prop.payloadName for prop in self.additionalList)
        self.unknownProperties = [k for k in self.jsondata
                                  if k not in known_names and '@odata' not in k]

        self.links = OrderedDict()

        sample = config.get('sample')
        linklimits = config.get('linklimits', {})
        self.links.update(self.typeobj.getLinksFromType(self.jsondata, self.context, self.propertyList, oem, linklimits, sample))

        self.links.update(getAllLinks(
            self.jsondata, self.additionalList, self.schemaObj, context=context, linklimits=linklimits,
            sample_size=sample, oemCheck=oem))

    def getResourceProperties(self):
        """Return the schema properties followed by at most 100 additional properties."""
        return self.propertyList + self.additionalList[:100]
Exemple #33
0
def predict_forward_with_existed_model(db_engine, project_path, model_id,
                                       as_of_date):
    """Predict forward given model_id and as_of_date and store the prediction in database

    Using the experiment config recorded for ``model_id``, this rebuilds the
    cohort, feature aggregations and imputations in the
    ``triage_production`` schema for ``as_of_date``. It then builds a
    production matrix and scores it with the model.

    Args:
        db_engine (sqlalchemy.engine.Engine): database engine
        project_path (str): path of the project storage, opened with
            ``ProjectStorage`` to access matrices and models
        model_id (int): The id of a given model in the database
        as_of_date (str): a date string like "YYYY-MM-DD"
    """
    logger.spam("In PREDICT LIST................")
    upgrade_db(db_engine=db_engine)
    project_storage = ProjectStorage(project_path)
    matrix_storage_engine = project_storage.matrix_storage_engine()
    # 1. Get feature and cohort config from database
    (train_matrix_uuid,
     matrix_metadata) = train_matrix_info_from_model_id(db_engine, model_id)
    experiment_config = experiment_config_from_model_id(db_engine, model_id)

    # 2. Generate cohort
    cohort_table_name = f"triage_production.cohort_{experiment_config['cohort_config']['name']}"
    cohort_table_generator = EntityDateTableGenerator(
        db_engine=db_engine,
        query=experiment_config['cohort_config']['query'],
        entity_date_table_name=cohort_table_name)
    cohort_table_generator.generate_entity_date_table(
        as_of_dates=[dt_from_str(as_of_date)])

    # 3. Generate feature aggregations
    feature_generator = FeatureGenerator(
        db_engine=db_engine,
        features_schema_name="triage_production",
        feature_start_time=experiment_config['temporal_config']
        ['feature_start_time'],
    )
    collate_aggregations = feature_generator.aggregations(
        feature_aggregation_config=experiment_config['feature_aggregations'],
        feature_dates=[as_of_date],
        state_table=cohort_table_name)
    feature_generator.process_table_tasks(
        feature_generator.generate_all_table_tasks(collate_aggregations,
                                                   task_type='aggregation'))

    # 4. Reconstruct feature dictionary from feature_names and generate imputation

    reconstructed_feature_dict = FeatureGroup()
    imputation_table_tasks = OrderedDict()

    for aggregation in collate_aggregations:
        feature_group, feature_names = get_feature_names(
            aggregation, matrix_metadata)
        reconstructed_feature_dict[feature_group] = feature_names

        # Make sure that the features imputed in training should also be imputed in production

        features_imputed_in_train = get_feature_needs_imputation_in_train(
            aggregation, feature_names)

        features_imputed_in_production = get_feature_needs_imputation_in_production(
            aggregation, db_engine)

        total_impute_cols = set(features_imputed_in_production) | set(
            features_imputed_in_train)
        # Columns not already imputation flags ('_imp') and not imputed.
        total_nonimpute_cols = {f for f in feature_names
                                if '_imp' not in f} - total_impute_cols

        task_generator = feature_generator._generate_imp_table_tasks_for

        imputation_table_tasks.update(
            task_generator(aggregation,
                           impute_cols=list(total_impute_cols),
                           nonimpute_cols=list(total_nonimpute_cols)))
    feature_generator.process_table_tasks(imputation_table_tasks)

    # 5. Build matrix
    db_config = {
        "features_schema_name": "triage_production",
        "labels_schema_name": "public",
        "cohort_table_name": cohort_table_name,
    }

    matrix_builder = MatrixBuilder(
        db_config=db_config,
        matrix_storage_engine=matrix_storage_engine,
        engine=db_engine,
        experiment_hash=None,
        replace=True,
    )

    feature_start_time = experiment_config['temporal_config'][
        'feature_start_time']
    label_name = experiment_config['label_config']['name']
    label_type = 'binary'
    cohort_name = experiment_config['cohort_config']['name']
    user_metadata = experiment_config['user_metadata']

    # Use timechop to get the time definition for production
    temporal_config = experiment_config["temporal_config"]
    temporal_config.update(
        temporal_params_from_matrix_metadata(db_engine, model_id))
    timechopper = Timechop(**temporal_config)
    prod_definitions = timechopper.define_test_matrices(
        train_test_split_time=dt_from_str(as_of_date),
        test_duration=temporal_config['test_durations'][0],
        test_label_timespan=temporal_config['test_label_timespans'][0])

    matrix_metadata = Planner.make_metadata(
        prod_definitions[-1],
        reconstructed_feature_dict,
        label_name,
        label_type,
        cohort_name,
        'production',
        feature_start_time,
        user_metadata,
    )

    matrix_metadata['matrix_id'] = f'{as_of_date}_model_id_{model_id}_risklist'

    matrix_uuid = filename_friendly_hash(matrix_metadata)

    matrix_builder.build_matrix(
        as_of_times=[as_of_date],
        label_name=label_name,
        label_type=label_type,
        feature_dictionary=reconstructed_feature_dict,
        matrix_metadata=matrix_metadata,
        matrix_uuid=matrix_uuid,
        matrix_type="production",
    )

    # 6. Predict the risk score for production
    predictor = Predictor(
        model_storage_engine=project_storage.model_storage_engine(),
        db_engine=db_engine,
        rank_order='best')

    predictor.predict(
        model_id=model_id,
        matrix_store=matrix_storage_engine.get_store(matrix_uuid),
        misc_db_parameters={},
        train_matrix_columns=matrix_storage_engine.get_store(
            train_matrix_uuid).columns())
def save_link_grammar(rules, output_grammar, grammar_rules=2, header='', footer=''):
    """Write ``rules`` to a Link Grammar dictionary file.

    Args:
        rules: list of rules, or a dict (any mapping subclass) which is first
            converted with ``rules2list(rules, grammar_rules)``.  Each rule is
            indexed as: ``rule[0]`` cluster id (written as a comment line),
            ``rule[1]`` words, ``rule[2]`` / ``rule[3]`` connector lists (when
            both are non-empty they are written as ``{a or ...} & {b or ...}``),
            and ``rule[4]`` extra disjuncts OR-ed on.
        output_grammar: an existing file to overwrite, or an existing
            directory in which ``dict_<N>C_<date>_0006.4.0.dict`` is created.
        grammar_rules: passed to ``rules2list``: 1 ⇒ connectors, 2+ ⇒ disjuncts.
        header: file header; when empty a default with the Grammar Learner
            version and a timestamp is used.  Version and locale lines are
            always appended.
        footer: file footer; when empty a summary comment is used.

    Returns:
        OrderedDict with ``grammar_file`` (path written), ``grammar_clusters``
        and ``grammar_rules`` counts.

    Raises:
        FileNotFoundError: if ``output_grammar`` is neither an existing file
            nor an existing directory.
    """
    # isinstance also accepts dict subclasses such as OrderedDict.
    if isinstance(rules, dict):
        rules = rules2list(rules, grammar_rules)

    line_list = []
    clusters = set()
    for rule in rules:
        left, right, extra = rule[2], rule[3], rule[4]
        line = ''
        if len(left) > 0 and len(right) > 0:
            line += '{' + ' or '.join(str(x) for x in left) + '} & {' + ' or '.join(str(y) for y in right) + '}'
        elif len(left) > 0:
            line += ' or '.join('(' + str(x) + ')' for x in left)
        elif len(right) > 0:
            line += ' or '.join('(' + str(x) + ')' for x in right)
        if len(extra) > 0:
            if line != '':
                line += ' or '
            line += ' or '.join('(' + str(x) + ')' for x in extra)

        cluster_number = '% ' + str(rule[0]) + '\n'  # comment line: cluster
        cluster_and_words = ' '.join('"' + word + '"' for word in rule[1]) + ':\n'
        line_list.append(cluster_number + cluster_and_words + line + ';\n')
        clusters.add(rule[0])

    line_list.sort()  # FIXME: overkill?

    if os.path.isfile(output_grammar):
        out_file = output_grammar
    elif os.path.isdir(output_grammar):
        out_file = os.path.join(
            output_grammar,
            'dict_' + str(len(clusters)) + 'C_' + str(UTC())[:10] + '_0006.4.0.dict')
    else:
        raise FileNotFoundError('File not found', output_grammar)

    # TODO: Link Grammar 5.4.x ⇒ 5.5.1: drop the legacy (5.4.4) branches.
    legacy_lg = linkgrammar.__version__ == '5.4.4'
    if header == '':
        learner_version = '0.6' if legacy_lg else '0.7'
        header = '% Grammar Learner v.' + learner_version + ' ' + str(UTC())
    header += '\n<dictionary-version-number>: V0v0v6+;\n<dictionary-locale>: EN4us+;'

    add_rules = 'UNKNOWN-WORD: XXX+;' if legacy_lg else '<UNKNOWN-WORD>: XXX+;'

    if footer == '':
        footer = '% ' + str(len(clusters)) + ' word clusters, ' + str(
            len(rules)) + ' Link Grammar rules.\n' + '% Link Grammar file saved to: ' + out_file
    lg = header + '\n\n' + '\n'.join(line_list) + '\n' + add_rules + '\n\n' + footer
    lg = lg.replace('@', '.')  # 80706 WSD: word@1 ⇒ word.1  FIXME:DEL?
    # Explicit UTF-8 so non-ASCII words are written the same on every locale.
    with open(out_file, 'w', encoding='utf-8') as f:
        f.write(lg)

    return OrderedDict([
        ('grammar_file', out_file),
        ('grammar_clusters', len(clusters)),
        ('grammar_rules', len(rules)),
    ])
Exemple #35
0
 def serialize_order_fn(self):
     """Return an OrderedDict of export column name -> serialized attribute value.

     Pairs come from ``cls_db_cols_and_serializer`` (attribute name,
     serializer callable); each key is ``EXPORT_COLUMNS.get(attribute)``.
     """
     return OrderedDict(
         (EXPORT_COLUMNS.get(attr_name), serialize(getattr(self, attr_name)))
         for attr_name, serialize in cls_db_cols_and_serializer
     )
Exemple #36
0
class TRPOMAML(MAMLAlgo):
    """
    Algorithm for TRPO MAML

    Args:
        policy (Policy): policy object
        name (str): tf variable scope
        step_size (float): trust region size for the meta policy optimization through TRPO
        inner_type (str): One of 'log_likelihood', 'likelihood_ratio', 'dice', choose which inner update to use.
            'dice' is accepted but not implemented by ``_adapt_objective_sym``, so building the graph with it
            raises NotImplementedError.
        exploration (bool): whether to use E-MAML or MAML
        inner_lr (float) : gradient step size used for inner step
        meta_batch_size (int): number of meta-learning tasks
        num_inner_grad_steps (int) : number of gradient updates taken per maml iteration
        trainable_inner_step_size (boolean): whether make the inner step size a trainable variable

    Raises:
        ValueError: if ``inner_type`` is not one of the values above.
    """
    def __init__(self,
                 *args,
                 name="trpo_maml",
                 step_size=0.01,
                 inner_type='likelihood_ratio',
                 exploration=False,
                 **kwargs):
        super(TRPOMAML, self).__init__(*args, **kwargs)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if inner_type not in ("log_likelihood", "likelihood_ratio", "dice"):
            raise ValueError(
                "inner_type must be one of 'log_likelihood', 'likelihood_ratio', "
                "'dice'; got {!r}".format(inner_type))
        self.step_size = step_size
        self.inner_type = inner_type
        self.name = name
        self._optimization_keys = [
            'observations', 'actions', 'advantages', 'agent_infos'
        ]

        self.exploration = exploration
        if exploration:  # add adjusted average rewards to optimization keys
            self._optimization_keys.append('adj_avg_rewards')

        self.optimizer = ConjugateGradientOptimizer()

        self.build_graph()

    def _adapt_objective_sym(self, action_sym, adv_sym, dist_info_old_sym,
                             dist_info_new_sym):
        """Build the inner-loop surrogate loss for one task.

        'likelihood_ratio': -mean(likelihood_ratio_sym(action, old, new) * adv).
        'log_likelihood': -mean(log_likelihood_sym(action, new) * adv);
        ``dist_info_old_sym`` is not used in this case.

        Raises:
            NotImplementedError: for any other ``inner_type`` (including 'dice').
        """
        if self.inner_type == 'likelihood_ratio':
            with tf.variable_scope("likelihood_ratio"):
                likelihood_ratio_adapt = self.policy.distribution.likelihood_ratio_sym(
                    action_sym, dist_info_old_sym, dist_info_new_sym)
            with tf.variable_scope("surrogate_loss"):
                surr_obj_adapt = -tf.reduce_mean(
                    likelihood_ratio_adapt * adv_sym)

        elif self.inner_type == 'log_likelihood':
            with tf.variable_scope("log_likelihood"):
                log_likelihood_adapt = self.policy.distribution.log_likelihood_sym(
                    action_sym, dist_info_new_sym)
            with tf.variable_scope("surrogate_loss"):
                surr_obj_adapt = -tf.reduce_mean(
                    log_likelihood_adapt * adv_sym)

        else:
            raise NotImplementedError

        return surr_obj_adapt

    def build_graph(self):
        """
        Creates the computation graph

        Notes:
            Pseudocode:
            for task in meta_batch_size:
                make_vars
                init_init_dist_sym
            for step in num_inner_grad_steps:
                for task in meta_batch_size:
                    make_vars
                    update_init_dist_sym
            set objectives for optimizer
        """
        # --- Create Variables ---
        # NOTE: the original author was unsure whether the exploration term is
        # mathematically right for more than one inner step.

        with tf.variable_scope(self.name):
            self.step_sizes = self._create_step_size_vars()

            # --- Build inner update graph for adapting the policy and sampling trajectories ---
            # this graph is only used for adapting the policy and not computing the meta-updates
            self.adapted_policies_params, self.adapt_input_ph_dict = self._build_inner_adaption()

            # ----- Build graph for the meta-update -----
            self.meta_op_phs_dict = OrderedDict()
            obs_phs, action_phs, adv_phs, dist_info_old_phs, all_phs_dict = self._make_input_placeholders(
                'step0')
            self.meta_op_phs_dict.update(all_phs_dict)

            distribution_info_vars, current_policy_params = [], []

        # Step-0 distributions use the real policy parameters (params=None).
        # This loop is deliberately left outside the algorithm's variable scope.
        for i in range(self.meta_batch_size):
            dist_info_sym = self.policy.distribution_info_sym(obs_phs[i],
                                                              params=None)
            distribution_info_vars.append(dist_info_sym)  # step 0
            current_policy_params.append(
                self.policy.policy_params
            )  # set to real policy_params (tf.Variable)

        # Keep the step-0 tensors for the E-MAML exploration term below.
        initial_distribution_info_vars = distribution_info_vars
        initial_action_phs = action_phs

        with tf.variable_scope(self.name):
            # --- Inner updates ---
            for step_id in range(1, self.num_inner_grad_steps + 1):
                adapted_policy_params = []

                # inner adaptation step for each task
                for i in range(self.meta_batch_size):
                    surr_loss = self._adapt_objective_sym(
                        action_phs[i], adv_phs[i], dist_info_old_phs[i],
                        distribution_info_vars[i])

                    adapted_params_var = self._adapt_sym(
                        surr_loss, current_policy_params[i])

                    adapted_policy_params.append(adapted_params_var)

                # Create new placeholders for the next step
                obs_phs, action_phs, adv_phs, dist_info_old_phs, all_phs_dict = self._make_input_placeholders(
                    'step%i' % step_id)
                self.meta_op_phs_dict.update(all_phs_dict)

                # dist_info_vars_for_next_step
                distribution_info_vars = [
                    self.policy.distribution_info_sym(
                        obs_phs[i], params=adapted_policy_params[i])
                    for i in range(self.meta_batch_size)
                ]
                current_policy_params = adapted_policy_params

            # --- Outer objective ---
            surr_objs, outer_kls = [], []

            # Create placeholders
            # meta-objective
            for i in range(self.meta_batch_size):
                likelihood_ratio = self.policy.distribution.likelihood_ratio_sym(
                    action_phs[i], dist_info_old_phs[i],
                    distribution_info_vars[i])
                outer_kl = tf.reduce_mean(
                    self.policy.distribution.kl_sym(dist_info_old_phs[i],
                                                    distribution_info_vars[i]))

                surr_obj = -tf.reduce_mean(likelihood_ratio * adv_phs[i])

                if self.exploration:
                    # add adj_avg_reward placeholder
                    adj_avg_rewards = tf.placeholder(
                        dtype=tf.float32,
                        shape=[None],
                        name='adj_avg_rewards' + '_' +
                        str(self.num_inner_grad_steps) + '_' + str(i))
                    self.meta_op_phs_dict[
                        'step%i_task%i_%s' %
                        (self.num_inner_grad_steps, i,
                         'adj_avg_rewards')] = adj_avg_rewards

                    log_likelihood_initial = self.policy.distribution.log_likelihood_sym(
                        initial_action_phs[i],
                        initial_distribution_info_vars[i])
                    surr_obj += -tf.reduce_mean(
                        adj_avg_rewards) * tf.reduce_mean(
                            log_likelihood_initial)

                surr_objs.append(surr_obj)
                outer_kls.append(outer_kl)

            mean_outer_kl = tf.reduce_mean(tf.stack(outer_kls))
            # Mean over meta tasks
            meta_objective = tf.reduce_mean(tf.stack(surr_objs, 0))

            self.optimizer.build_graph(
                loss=meta_objective,
                target=self.policy,
                input_ph_dict=self.meta_op_phs_dict,
                leq_constraint=(mean_outer_kl, self.step_size),
            )

    def optimize_policy(self,
                        all_samples_data,
                        log=True,
                        prefix='',
                        verbose=False):
        """
        Performs MAML outer step

        Args:
            all_samples_data (list) : list of lists of lists of samples (each is a dict) split by gradient update and
             meta task
            log (bool) : whether to log statistics
            prefix (str) : prefix prepended to each logged key
            verbose (bool) : whether to log progress messages

        Returns:
            None
        """
        meta_op_input_dict = self._extract_input_dict_meta_op(
            all_samples_data, self._optimization_keys)
        if verbose:
            logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(meta_op_input_dict)

        if verbose:
            logger.log("Computing loss before")
        loss_before = self.optimizer.loss(meta_op_input_dict)
        if verbose:
            logger.log("Optimizing")
        self.optimizer.optimize(meta_op_input_dict)
        if verbose:
            logger.log("Computing loss after")
        loss_after = self.optimizer.loss(meta_op_input_dict)

        if verbose:
            logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(meta_op_input_dict)
        if log:
            logger.logkv(prefix + 'MeanKLBefore', mean_kl_before)
            logger.logkv(prefix + 'MeanKL', mean_kl)

            logger.logkv(prefix + 'LossBefore', loss_before)
            logger.logkv(prefix + 'LossAfter', loss_after)
            logger.logkv(prefix + 'dLoss', loss_before - loss_after)
Exemple #37
0
            continue

        # Ok
        print b.brickid

    sys.exit(0)

    #allI = set()
    allI = OrderedDict()

    for b in B:
        wcs = wcs_for_brick(b)
        I = ccds_touching_wcs(wcs, T)
        print >> sys.stderr, 'Brick', b, ':', len(I), 'CCDs'
        #allI.update(I)
        allI.update([(i, True) for i in I])
    #print 'Total of', len(allI), 'CCDs touch'
    #T.cut(np.array(list(allI)))

    print >> sys.stderr, len(B), 'bricks,', len(allI), 'CCDs'

    #for i in list(allI):

    # g,r,z full focal planes, 2014-08-18
    #I = np.flatnonzero(T.expnum == 349664)
    #I = np.flatnonzero(T.expnum == 349667)
    #I = np.flatnonzero(T.expnum == 349589)

    #for im in T.cpimage[:10]:
    #    print >>sys.stderr, 'im >>%s<<' % im, im.startswith('CP20140818')
    #I = np.flatnonzero(np.array([im.startswith('CP20140818') for im in T.cpimage]))
def get_newborn_with_low_birth_weight_map(domain,
                                          config,
                                          loc_level,
                                          show_test=False):
    """Build the map payload for the "low birth weight" indicator.

    Sums ``low_birth_weight_in_month`` and ``born_in_month`` from
    ``AggChildHealthMonthly`` per location at ``loc_level`` and returns the
    structure (fills, legend, per-location data) consumed by the map view.

    Args:
        domain: project domain; passed to ``apply_exclude`` when test data
            is hidden.
        config: filter dict passed as keyword filters to
            ``AggChildHealthMonthly.objects.filter``. Its ``'month'`` entry is
            unpacked into ``datetime(...)`` (presumably a (year, month, day)
            sequence -- confirm with callers). ``config`` itself is NOT
            modified.
        loc_level: location level prefix used to build the
            ``'<loc_level>_name'`` / ``'<loc_level>_map_location_name'``
            column names.
        show_test: when False, rows are filtered through ``apply_exclude``.

    Returns:
        dict with keys ``slug``, ``label``, ``fills``, ``rightLegend`` and
        ``data``.
    """
    def get_data_for(filters):
        # Work on a copy: converting 'month' in the caller's dict would make
        # a second call with the same config fail (datetime(*datetime)
        # raises TypeError).
        filters = dict(filters)
        filters['month'] = datetime(*filters['month'])
        queryset = AggChildHealthMonthly.objects.filter(**filters).values(
            '%s_name' % loc_level,
            '%s_map_location_name' % loc_level).annotate(
                low_birth=Sum('low_birth_weight_in_month'),
                in_month=Sum('born_in_month'),
            ).order_by('%s_name' % loc_level,
                       '%s_map_location_name' % loc_level)
        if not show_test:
            queryset = apply_exclude(domain, queryset)
        return queryset

    data_for_map, in_month_total, low_birth_total, average = generate_data_for_map(
        get_data_for(config), loc_level, 'low_birth', 'in_month', 20, 60)

    # Legend buckets in display order; 20/60 match the thresholds passed to
    # generate_data_for_map above.
    fills = OrderedDict([
        ('0%-20%', MapColors.PINK),
        ('20%-60%', MapColors.ORANGE),
        ('60%-100%', MapColors.RED),
        ('defaultFill', MapColors.GREY),
    ])

    gender_ignored, age_ignored, chosen_filters = chosen_filters_to_labels(
        config)

    # Guard the denominator so an empty month does not divide by zero.
    denominator = float(in_month_total or 1)

    return {
        "slug": "low_birth",
        "label":
        "Percent Newborns with Low Birth Weight{}".format(chosen_filters),
        "fills": fills,
        "rightLegend": {
            "average":
            average,
            "info":
            _(("Percentage of newborns with born with birth weight less than 2500 grams."
               "<br/><br/>"
               "Newborns with Low Birth Weight are closely associated with foetal and neonatal "
               "mortality and morbidity, inhibited growth and cognitive development, and chronic "
               "diseases later in life")),
            "extended_info": [{
                'indicator':
                'Total Number of Newborns born in given month{}:'.format(
                    chosen_filters),
                'value':
                indian_formatted_number(in_month_total)
            }, {
                'indicator':
                'Number of Newborns with LBW in given month{}:'.format(
                    chosen_filters),
                'value':
                indian_formatted_number(low_birth_total)
            }, {
                'indicator':
                '% newborns with LBW in given month{}:'.format(chosen_filters),
                'value':
                '%.2f%%' % (low_birth_total * 100 / denominator)
            }, {
                # NOTE(review): this is the share of newborns NOT flagged as
                # low birth weight, not of unweighed newborns; a real
                # "unweighed" figure needs a weighed count -- confirm intent.
                'indicator':
                '% Unweighed{}:'.format(chosen_filters),
                'value':
                '%.2f%%' % ((in_month_total - low_birth_total) * 100 /
                            denominator)
            }]
        },
        "data": dict(data_for_map),
    }
Exemple #39
0
    def output(self, export_path):
        """
        Export the Clinv inventory to a spreadsheet file.

        Collects the rows for each resource type from the ``_export_*``
        helpers, one sheet per resource type in a fixed order, and writes
        the resulting book with ``pyexcel.save_book_as``.

        Parameters:
            export_path (str): Path of the spreadsheet to write; ``~`` is
                expanded. (An .ods path is the intended use; any default
                such as ~/.local/share/clinv/inventory.ods is presumably
                supplied by the caller -- this method has none.)

        Returns:
            None: the book is written to ``export_path``.
        """
        # Sheet name -> rows. Built from an ordered list so the sheets appear
        # in this order and the helpers are called in this order.
        book = OrderedDict([
            ("Projects", self._export_projects()),
            ("Services", self._export_services()),
            ("Informations", self._export_informations()),
            ("EC2", self._export_ec2()),
            ("RDS", self._export_rds()),
            ("Route53", self._export_route53()),
            ("S3", self._export_s3()),
            ("People", self._export_people()),
        ])

        pyexcel.save_book_as(
            bookdict=book, dest_file_name=os.path.expanduser(export_path),
        )
Exemple #40
0
class Trainer(object):
    """Main class for data parallel training.

    This class supports synchronous distributed data parallel training,
    where multiple workers each have a full model replica and gradients
    are accumulated across workers before each update. We use
    :class:`~torch.nn.parallel.DistributedDataParallel` to handle
    communication of the gradients across workers.
    """
    def __init__(self,
                 args,
                 task,
                 model,
                 criterion,
                 dummy_batch=None,
                 oom_batch=None):
        """Set up the trainer; the optimizer and LR scheduler are built lazily.

        Args:
            args: parsed options namespace (read here: ``cpu``, ``fp16``,
                ``fast_stat_sync``; other methods read more fields).
            task: task object that loads datasets and runs train/valid steps.
            model: model to train; cast to half (fp16) and/or moved to CUDA.
            criterion: loss criterion; cast/moved the same way as ``model``.
            dummy_batch: batch substituted for empty samples so every worker
                still runs forward/backward. If None, :meth:`train_step`
                uses the first sample it receives.
            oom_batch: batch replayed by :meth:`handle_ooms`; falls back to
                ``dummy_batch``.
        """
        self.args = args
        self.task = task

        # copy model and criterion to current device
        self._criterion = criterion
        self._model = model
        self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        if self.cuda:
            self._criterion = self._criterion.cuda()
            self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch or dummy_batch

        # Optimizer, LR scheduler and the (possibly distributed) wrappers are
        # created on first access through the properties below.
        self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._prev_grad_norm = None
        self._wrapped_criterion = None
        self._wrapped_model = None

        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
        # It is less flexible and syncs only the default stats.
        # Layout: [sample_size, nsentences, loss, nll_loss, ntokens, ooms].
        self._all_reduce_list = [0.0] * 6
        self.fast_stat_sync = args.fast_stat_sync

        self.init_meters(args)

    def init_meters(self, args):
        """(Re)create the ordered dict of training/validation meters.

        ``loss_scale`` is only added when ``args.fp16`` is set.
        """
        self.meters = OrderedDict()
        self.meters["train_loss"] = AverageMeter()
        self.meters["train_nll_loss"] = AverageMeter()
        self.meters["valid_loss"] = AverageMeter()
        self.meters["valid_nll_loss"] = AverageMeter()
        self.meters["wps"] = TimeMeter()  # words per second
        self.meters["ups"] = TimeMeter()  # updates per second
        self.meters["wpb"] = AverageMeter()  # words per batch
        self.meters["bsz"] = AverageMeter()  # sentences per batch
        self.meters["gnorm"] = AverageMeter()  # gradient norm
        self.meters["clip"] = AverageMeter()  # % of updates clipped
        self.meters["oom"] = AverageMeter()  # out of memory
        if args.fp16:
            self.meters["loss_scale"] = AverageMeter()  # dynamic loss scale
        self.meters["wall"] = TimeMeter()  # wall time in seconds
        self.meters["train_wall"] = StopwatchMeter(
        )  # train wall time in seconds

    @property
    def criterion(self):
        """Criterion used for training; wrapped for DDP only when it has
        parameters, training is distributed and BMUF is off."""
        if self._wrapped_criterion is None:
            if (utils.has_parameters(self._criterion)
                    and self.args.distributed_world_size > 1
                    and not self.args.use_bmuf):
                self._wrapped_criterion = models.DistributedFairseqModel(
                    self.args, self._criterion)
            else:
                self._wrapped_criterion = self._criterion
        return self._wrapped_criterion

    @property
    def model(self):
        """Model used for training; wrapped for DDP when training is
        distributed and BMUF is off."""
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model)
            else:
                self._wrapped_model = self._model
        return self._wrapped_model

    @property
    def optimizer(self):
        """Optimizer, built on first access."""
        if self._optimizer is None:
            self._build_optimizer()
        return self._optimizer

    @property
    def lr_scheduler(self):
        """LR scheduler, built together with the optimizer on first access."""
        if self._lr_scheduler is None:
            self._build_optimizer()  # this will initialize self._lr_scheduler
        return self._lr_scheduler

    def _build_optimizer(self):
        """Build the optimizer over all trainable model/criterion parameters.

        Chooses a memory-efficient FP16, FP16 or plain optimizer from
        ``args``, optionally wraps it in BMUF, then builds the LR scheduler.
        """
        params = list(
            filter(
                lambda p: p.requires_grad,
                chain(self.model.parameters(), self.criterion.parameters()),
            ))

        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print(
                    "| WARNING: your device does NOT support faster training with --fp16, "
                    "please switch to FP32 which is likely to be faster")
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                    self.args, params)
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(
                    self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print(
                    "| NOTICE: your device may support faster training with --fp16"
                )
            self._optimizer = optim.build_optimizer(self.args, params)

        if self.args.use_bmuf:
            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)

        # We should initialize the learning rate scheduler immediately after
        # building the optimizer, so that the initial learning rate is set.
        self._lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.args, self.optimizer)
        self._lr_scheduler.step_update(0)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file.

        Only the master worker writes. ``extra_state`` is modified in place:
        the current meters are stored under ``"train_meters"``.
        """
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state["train_meters"] = self.meters
            checkpoint_utils.save_state(
                filename,
                self.args,
                self.get_model().state_dict(),
                self.get_criterion(),
                self.optimizer,
                self.lr_scheduler,
                self.get_num_updates(),
                self._optim_history,
                extra_state,
            )

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file.

        Args:
            filename: checkpoint path; a missing file is not an error.
            reset_optimizer: skip restoring optimizer/LR-scheduler state.
            reset_lr_scheduler: skip restoring only the LR-scheduler state.
            optimizer_overrides: passed to the optimizer's load_state_dict.
            reset_meters: keep the current meters instead of the saved ones.

        Returns:
            the checkpoint's ``extra_state`` dict, or None if no file exists.

        Raises:
            Exception: if the model/criterion parameters cannot be loaded
                (the underlying error is chained as ``__cause__``).
        """
        extra_state, self._optim_history, last_optim_state = None, [], None

        bexists = PathManager.isfile(filename)
        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state["model"],
                                                 strict=True,
                                                 args=self.args)
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(state["criterion"],
                                                         strict=True)
            except Exception as e:
                # Chain the cause so the actual mismatch stays visible.
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(
                        filename)) from e

            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]
            last_optim_state = state.get("last_optimizer_state", None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] ==
                self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] ==
                self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(
                    last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state,
                                           optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            print("| loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()))

            self.lr_step(epoch)

            if "train_meters" in extra_state and not reset_meters:
                self.meters.update(extra_state["train_meters"])
                del extra_state["train_meters"]

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print("| no existing checkpoint found {}".format(filename))

        return extra_state

    def get_train_iterator(
        self,
        epoch,
        combine=True,
        load_dataset=True,
        data_selector=None,
        shard_batch_itr=True,
    ):
        """Return an EpochBatchIterator over the training set for a given epoch.

        Args:
            epoch: epoch number passed to dataset loading and the iterator.
            combine: forwarded to ``task.load_dataset``.
            load_dataset: (re)load the training subset before iterating.
            data_selector: forwarded to ``task.load_dataset``.
            shard_batch_itr: shard batches across distributed workers; when
                False every worker sees all batches (1 shard, id 0).
        """
        if load_dataset:
            print("| loading train data for epoch {}".format(epoch))
            self.task.load_dataset(
                self.args.train_subset,
                epoch=epoch,
                combine=combine,
                data_selector=data_selector,
            )
        return self.task.get_batch_iterator(
            dataset=self.task.dataset(self.args.train_subset),
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                self.task.max_positions(),
                self.model.max_positions(),
                self.args.max_tokens,
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.
            required_batch_size_multiple,
            seed=self.args.seed,
            num_shards=self.args.distributed_world_size
            if shard_batch_itr else 1,
            shard_id=self.args.distributed_rank if shard_batch_itr else 0,
            num_workers=self.args.num_workers,
            epoch=epoch,
        )

    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update.

        Args:
            samples: list of mini-batches; gradients are accumulated over all
                of them before a single optimizer step.
            dummy_batch: if True, only run forward/backward (no stats
                aggregation, no optimizer step) and return None.
            raise_oom: re-raise out-of-memory errors instead of recovering.

        Returns:
            the aggregated logging output dict, or None for a dummy step,
            when every worker ran out of memory, or on gradient overflow.
        """
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters["train_wall"].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get(
                            "nsentences", 0.0)
                        self._all_reduce_list[2] += logging_output.get(
                            "loss", 0.0)
                        self._all_reduce_list[3] += logging_output.get(
                            "nll_loss", 0.0)
                        self._all_reduce_list[4] += logging_output.get(
                            "ntokens", 0.0)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    print(
                        "| WARNING: attempting to recover from OOM in forward/backward pass",
                        file=sys.stderr,
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        # Record this step's local OOM count exactly once. (Adding the
        # running total inside the loop re-counted earlier OOMs once per
        # remaining sample.)
        if self.fast_stat_sync:
            self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(
                self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (all_reduce_list_tensor[0:1] *
                 torch.log(torch.cuda.DoubleTensor([2]))))
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output["nsentences"],
                logging_output["loss"],
                logging_output["nll_loss"],
                logging_output["ntokens"],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = zip(
                *distributed_utils.all_gather_list(
                    [
                        logging_outputs, sample_sizes, ooms,
                        self._prev_grad_norm
                    ],
                    max_size=getattr(self.args, 'all_gather_list_size', 16384),
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                # Every worker applied the same synced update last step, so
                # the norms must agree unless they all overflowed.
                assert all(
                    norm == prev_norms[0] for norm in prev_norms
                ) or all(
                    math.isnan(norm) or math.isinf(norm) for norm in prev_norms
                ), "Fatal error: gradients are inconsistent between workers"

        self.meters["oom"].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion())
            sample_size = self.task.grad_denom(sample_sizes,
                                               self.get_criterion())

        if not all(k in logging_output for k in ["ntokens", "nsentences"]):
            raise Exception(
                ("Please update the {}.aggregate_logging_outputs() method to "
                 "return ntokens and nsentences").format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                # In DDP: multiply gradients by #GPUs/#sample_size because
                # gradients are accumulated and divided by #GPUs before.
                # In BMUF: during non-sync gradients are divided by #sample_size
                # whereas during sync (while calculating global model): sync accumulate
                # gradients and divided by #GPUs and now multiply by #GPUs/#sample_size
                if self._sync_stats():
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / float(sample_size))
                else:
                    self.optimizer.multiply_grads(1 / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get("ntokens", 0)
            nsentences = logging_output.get("nsentences", 0)
            self.meters["wps"].update(ntokens)
            self.meters["ups"].update(1.0)
            self.meters["wpb"].update(ntokens)
            self.meters["bsz"].update(nsentences)
            self.meters["gnorm"].update(grad_norm)
            self.meters["clip"].update(1.0 if grad_norm > self.args.clip_norm
                                       and self.args.clip_norm > 0 else 0.0)
            self.meters["train_loss"].update(logging_output.get("loss", 0),
                                             sample_size)
            if "train_acc" in self.meters:
                self.meters["train_acc"].update(logging_output.get("acc", 0),
                                                sample_size)

            if "nll_loss" in logging_output:
                self.meters["train_nll_loss"].update(
                    logging_output.get("nll_loss", 0), ntokens)

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print("| WARNING: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                print("| ERROR: OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            self.meters["loss_scale"].reset()
            self.meters["loss_scale"].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters["train_wall"].stop()

        return logging_output

    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode.

        Args:
            sample: validation batch; None runs the dummy batch and discards
                its results so all workers stay in lock-step.
            raise_oom: re-raise out-of-memory errors instead of retrying once.

        Returns:
            the aggregated logging output dict (also fed into the
            ``valid_*`` meters).
        """
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        print(
                            "| WARNING: ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        # Retry with None for a placeholder dummy batch so its
                        # results are still discarded on the retry.
                        return self.valid_step(
                            None if ignore_results else sample,
                            raise_oom=True)
                raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list(
                    [logging_output, sample_size],
                    max_size=getattr(self.args, 'all_gather_list_size', 16384),
                ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion())
        sample_size = self.task.grad_denom(sample_size, self.get_criterion())

        # update meters for validation
        ntokens = logging_output.get("ntokens", 0)
        self.meters["valid_loss"].update(logging_output.get("loss", 0),
                                         sample_size)
        if "valid_acc" in self.meters:
            self.meters["valid_acc"].update(logging_output.get("acc", 0),
                                            sample_size)

        if "nll_loss" in logging_output:
            self.meters["valid_nll_loss"].update(
                logging_output.get("nll_loss", 0), ntokens)

        return logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, dummy_batch=True)
        self.zero_grad()
        # A dummy step returns before clearing the fast-sync buffer; drop its
        # stats so they are not reported as part of the first real update.
        self.clear_buffered_stats()

    def handle_ooms(self, number_of_ooms):
        """
        c10d accumulates/syncs gradients between gpus during backward pass.
        In case of OOMs, gpus may fail to sync, so we manually iterate
        extra to make sure each gpu makes same number of iterations.
        """
        # NOTE(review): with fast_stat_sync these replayed steps also add the
        # OOM batch's stats to the current step's buffered stats.
        for _ in range(number_of_ooms):
            self.train_step([self._oom_batch], True)

    def zero_grad(self):
        """Zero the gradients of all optimized parameters."""
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        """Reset the fast-stat-sync accumulator to six zeros."""
        self._all_reduce_list = [0.0] * 6

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(self.get_num_updates())

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_criterion(self):
        """Get the (non-wrapped) criterion instance."""
        return self._criterion

    def get_meter(self, name):
        """Get a specific meter by name."""
        if name not in self.meters:
            return None
        return self.meters[name]

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates
        self.lr_step_update()

    def _prepare_sample(self, sample):
        """Move a sample to CUDA and cast float32 tensors to half for fp16.

        Returns None for a None or empty sample.
        """
        if sample is None or len(sample) == 0:
            return None

        if self.cuda:
            sample = utils.move_to_cuda(sample)

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.half()
            return t

        if self.args.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        return sample

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)

    def _sync_stats(self):
        # Return True if it's using multiple GPUs and DDP or multiple GPUs with
        # BMUF and it's a bmuf sync with warmup iterations completed before.
        return self.args.distributed_world_size > 1 and (
            (not self.args.use_bmuf) or
            (self.args.use_bmuf and
             (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and
             (self.get_num_updates() + 1) > self.args.warmup_iterations))

    def _log_oom(self, exc):
        """Report an OOM exception and, if available, per-device memory
        summaries on stderr."""
        msg = "| OOM: Ran out of memory with exception: {}".format(exc)
        # TODO: print should really go to logger, this print goes
        # to stderr, which is buffered, which in many cases is not
        # printed out if another exception happens.
        # NB(jerry): added a flush to mitigate this
        print(msg, file=sys.stderr)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                print(torch.cuda.memory_summary(device=device_idx),
                      file=sys.stderr)
        sys.stderr.flush()
class Network:
    """Generic network abstraction.

    Acts as a convenience wrapper for a parameterized network construction
    function, providing several utility methods and convenient access to
    the inputs/outputs/weights.

    Network objects can be safely pickled and unpickled for long-term
    archival purposes. The pickling works reliably as long as the underlying
    network construction function is defined in a standalone Python module
    that has no side effects or application-specific imports.

    Args:
        name: Network name. Used to select TensorFlow name and variable scopes.
        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
        static_kwargs: Keyword arguments to be passed in to the network construction function.

    Attributes:
        name: User-specified name, defaults to build func name if None.
        scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
        static_kwargs: Arguments passed to the user-supplied build func.
        components: Container for sub-networks. Passed to the build func, and retained between calls.
        num_inputs: Number of input tensors.
        num_outputs: Number of output tensors.
        input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
        output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
        input_shape: Short-hand for input_shapes[0].
        output_shape: Short-hand for output_shapes[0].
        input_templates: Input placeholders in the template graph.
        output_templates: Output tensors in the template graph.
        input_names: Name string for each input.
        output_names: Name string for each output.
        own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
        vars: All variables (local_name => var).
        trainables: All trainable variables (local_name => var).
        var_global_to_local: Mapping from variable global names to local names.
    """

    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()

    def _init_fields(self) -> None:
        """Reset every attribute to its empty default (used by __init__, __setstate__ and clone)."""
        self.name = None
        self.scope = None
        self.static_kwargs = util.EasyDict()
        self.components = util.EasyDict()
        self.num_inputs = 0
        self.num_outputs = 0
        self.input_shapes = [[]]
        self.output_shapes = [[]]
        self.input_shape = []
        self.output_shape = []
        self.input_templates = []
        self.output_templates = []
        self.input_names = []
        self.output_names = []
        self.own_vars = OrderedDict()
        self.vars = OrderedDict()
        self.trainables = OrderedDict()
        self.var_global_to_local = OrderedDict()

        self._build_func = None  # User-supplied build function that constructs the network.
        self._build_func_name = None  # Name of the build function.
        self._build_module_src = None  # Full source code of the module containing the build function.
        self._run_cache = dict()  # Cached graph data for Network.run().

    def _init_graph(self) -> None:
        """Build the template graph and collect inputs, outputs and variables.

        Network inputs are the build function's positional-or-keyword parameters
        that have no default value.
        """
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(''):
            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())

    def reset_own_vars(self) -> None:
        """Re-initialize all variables of this network, excluding sub-networks."""
        tfutil.run([var.initializer for var in self.own_vars.values()])

    def reset_vars(self) -> None:
        """Re-initialize all variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.vars.values()])

    def reset_trainables(self) -> None:
        """Re-initialize all trainable variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.trainables.values()])

    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

        Inputs given as None are replaced by zeros shaped like the template input,
        using the batch size of the first non-None input.
        """
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self.components

        # Build TensorFlow graph to evaluate the network.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr

    def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
        """Get the local name of a given variable, without any surrounding name scopes.

        Accepts either a variable/tensor or its global name. Any ":<output_index>"
        suffix is stripped before the lookup, because var_global_to_local is keyed
        by names without it while ``var.name`` includes it (e.g. "G/dense/weight:0").
        """
        assert isinstance(var_or_global_name, str) or tfutil.is_tf_expression(var_or_global_name)
        global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
        return self.var_global_to_local[global_name.split(":")[0]]

    def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
        """Find a variable by its local name; a variable/tensor argument is returned unchanged."""
        assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
        return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name

    def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
        """Get the value of a given variable as NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
        return self.find_var(var_or_local_name).eval()

    def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
        """Set the value of a given variable based on the given NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
        tfutil.set_vars({self.find_var(var_or_local_name): new_value})

    def __getstate__(self) -> dict:
        """Pickle export."""
        state = dict()
        state["version"]            = 3
        state["name"]               = self.name
        state["static_kwargs"]      = dict(self.static_kwargs)
        state["components"]         = dict(self.components)
        state["build_module_src"]   = self._build_module_src
        state["build_func_name"]    = self._build_func_name
        state["variables"]          = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
        return state

    def __setstate__(self, state: dict) -> None:
        """Pickle import.

        SECURITY: this executes the module source stored in the pickle via exec();
        only load network pickles from trusted sources.
        """
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
        tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})

    def clone(self, name: str = None, **new_static_kwargs) -> "Network":
        """Create a clone of this network with its own copy of the variables."""
        # pylint: disable=protected-access
        net = object.__new__(Network)
        net._init_fields()
        net.name = name if name is not None else self.name
        net.static_kwargs = util.EasyDict(self.static_kwargs)
        net.static_kwargs.update(new_static_kwargs)
        net._build_module_src = self._build_module_src
        net._build_func_name = self._build_func_name
        net._build_func = self._build_func
        net._init_graph()
        net.copy_vars_from(self)
        return net

    def copy_own_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, excluding sub-networks."""
        names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def copy_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, including sub-networks."""
        names = [name for name in self.vars.keys() if name in src_net.vars]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def copy_trainables_from(self, src_net: "Network") -> None:
        """Copy the values of all trainable variables from the given network, including sub-networks."""
        names = [name for name in self.trainables.keys() if name in src_net.trainables]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def _compatible_trainable_names(self, src_net: "Network") -> List[str]:
        """Return names of trainables present in both networks with equal shapes; print why others are skipped."""
        names = []
        for name, var in self.trainables.items():
            if name not in src_net.trainables:
                print("Not restoring (not present):     {}".format(name))
            elif var.shape != src_net.trainables[name].shape:
                print("Not restoring (different shape): {}".format(name))
            else:
                names.append(name)
        return names

    def copy_compatible_trainables_from(self, src_net: "Network") -> None:
        """Copy the compatible values of all trainable variables from the given network, including sub-networks.

        Trainables missing from ``src_net`` or with a different shape are skipped
        (and reported on stdout).
        """
        names = self._compatible_trainable_names(src_net)
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def apply_swa(self, src_net: "Network", epoch: int) -> None:
        """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks.

        Each compatible trainable becomes
        ``src * (1 / (epoch + 1)) + self * (1 - 1 / (epoch + 1))``; with
        ``epoch == 0`` this simply copies ``src_net``'s values. Incompatible
        trainables are skipped (and reported on stdout).

        Args:
            src_net: Network whose current weights are blended into this one.
            epoch: Controls the blend weight ``1 / (epoch + 1)`` given to ``src_net``.
        """
        names = self._compatible_trainable_names(src_net)
        scale_new_data = 1.0 / (epoch + 1)
        scale_moving_average = (1.0 - scale_new_data)
        # Blend the fetched values in NumPy instead of building `src * a + dst * b`
        # TF expressions: creating new ops on every call would keep growing the graph.
        src_values = tfutil.run([src_net.vars[name] for name in names])
        dst_values = tfutil.run([self.vars[name] for name in names])
        tfutil.set_vars({self.vars[name]: src * scale_new_data + dst * scale_moving_average
                         for name, src, dst in zip(names, src_values, dst_values)})

    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
        """Create new network with the given parameters, and copy all variables from this network."""
        if new_name is None:
            new_name = self.name
        static_kwargs = dict(self.static_kwargs)
        static_kwargs.update(new_static_kwargs)
        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
        net.copy_vars_from(self)
        return net

    def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
        """Construct a TensorFlow op that updates the variables of this network
        to be slightly closer to those of the given network."""
        with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
            ops = []
            for name, var in self.vars.items():
                if name in src_net.vars:
                    cur_beta = beta if name in self.trainables else beta_nontrainable
                    new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
                    ops.append(var.assign(new_value))
            return tf.group(*ops)

    def run(self,
            *in_arrays: Tuple[Union[np.ndarray, None], ...],
            input_transform: dict = None,
            output_transform: dict = None,
            return_as_list: bool = False,
            print_progress: bool = False,
            minibatch_size: int = None,
            num_gpus: int = 1,
            assume_frozen: bool = False,
            custom_inputs=None,
            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

        Args:
            input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
            print_progress:     Print progress to the console? Useful for very large input arrays.
            minibatch_size:     Maximum minibatch size to use, None = disable batching.
            num_gpus:           Number of GPUs to use.
            assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
            dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
            custom_inputs:      Optional sequence of callables, one per network input. Each is called with the input name
                                (on /gpu:0) and returns the tensor to use instead of the default placeholder.
        """
        assert len(in_arrays) == self.num_inputs
        assert not all(arr is None for arr in in_arrays)
        assert input_transform is None or util.is_top_level_function(input_transform["func"])
        assert output_transform is None or util.is_top_level_function(output_transform["func"])
        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)

        # Take the batch size from the first array actually provided: individual
        # inputs (including the first) may be None and are replaced with zeros below.
        num_items = next(arr for arr in in_arrays if arr is not None).shape[0]
        if minibatch_size is None:
            # max(..., 1) keeps the range() step below non-zero for empty input.
            minibatch_size = max(num_items, 1)

        # Construct unique hash key from all arguments that affect the TensorFlow graph.
        # Using custom_inputs changes how the input tensors are built, so it is part of
        # the key. NOTE: only its presence is recorded -- different custom input builders
        # combined with otherwise identical arguments share one cached graph.
        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus,
                   assume_frozen=assume_frozen, custom_inputs=custom_inputs is not None, dynamic_kwargs=dynamic_kwargs)

        def unwind_key(obj):
            # Convert the nested key dict into a deterministic repr()-able structure;
            # callables are represented by their top-level function name.
            if isinstance(obj, dict):
                return [(k, unwind_key(value)) for k, value in sorted(obj.items())]
            if callable(obj):
                return util.get_top_level_function_name(obj)
            return obj
        key = repr(unwind_key(key))

        # Build graph.
        if key not in self._run_cache:
            with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
                if custom_inputs is not None:
                    with tf.device("/gpu:0"):
                        in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
                else:
                    with tf.device("/cpu:0"):
                        in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

                out_split = []
                for gpu in range(num_gpus):
                    with tf.device("/gpu:%d" % gpu):
                        net_gpu = self.clone() if assume_frozen else self
                        in_gpu = in_split[gpu]

                        if input_transform is not None:
                            in_kwargs = dict(input_transform)
                            in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                        assert len(in_gpu) == self.num_inputs
                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                        if output_transform is not None:
                            out_kwargs = dict(output_transform)
                            out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                        assert len(out_gpu) == self.num_outputs
                        out_split.append(out_gpu)

                with tf.device("/cpu:0"):
                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                    self._run_cache[key] = in_expr, out_expr

        # Run minibatches.
        in_expr, out_expr = self._run_cache[key]
        out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
            mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin: mb_end] = src

        # Done.
        if print_progress:
            print("\r%d / %d" % (num_items, num_items))

        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
        return out_arrays

    def list_ops(self) -> List[TfExpression]:
        """Return the graph ops inside this network's scope.

        Ops under ``<scope>/_*`` (internal helper scopes such as ``_Run`` and
        ``_MovingAvg``) are excluded.
        """
        include_prefix = self.scope + "/"
        exclude_prefix = include_prefix + "_"
        ops = tf.get_default_graph().get_operations()
        ops = [op for op in ops if op.name.startswith(include_prefix)]
        ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
        return ops

    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
        """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
        individual layers of the network. Mainly intended to be used for reporting."""
        layers = []

        def recurse(scope, parent_ops, parent_vars, level):
            # Ignore specific patterns.
            if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
                return

            # Filter ops and vars by scope.
            global_prefix = scope + "/"
            local_prefix = global_prefix[len(self.scope) + 1:]
            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
            if not cur_ops and not cur_vars:
                return

            # Filter out all ops related to variables.
            for var in [op for op in cur_ops if op.type.startswith("Variable")]:
                var_prefix = var.name + "/"
                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

            # Scope does not contain ops as immediate children => recurse deeper.
            contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
            if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
                visited = set()
                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                    token = rel_name.split("/")[0]
                    if token not in visited:
                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                        visited.add(token)
                return

            # Report layer.
            layer_name = scope[len(self.scope) + 1:]
            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
            layer_trainables = [var for _name, var in cur_vars if var.trainable]
            layers.append((layer_name, layer_output, layer_trainables))

        recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
        return layers

    def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
        """Print a summary table of the network structure."""
        rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
        rows += [["---"] * 4]
        total_params = 0

        for layer_name, layer_output, layer_trainables in self.list_layers():
            num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
            weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")]
            weights.sort(key=lambda x: len(x.name))
            if len(weights) == 0 and len(layer_trainables) == 1:
                weights = layer_trainables
            total_params += num_params

            if not hide_layers_with_no_params or num_params != 0:
                num_params_str = str(num_params) if num_params > 0 else "-"
                output_shape_str = str(layer_output.shape)
                weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
                rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]

        rows += [["---"] * 4]
        rows += [["Total", str(total_params), "", ""]]

        widths = [max(len(cell) for cell in column) for column in zip(*rows)]
        print()
        for row in rows:
            print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
        print()

    def setup_weight_histograms(self, title: str = None) -> None:
        """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
        if title is None:
            title = self.name

        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
            for local_name, var in self.trainables.items():
                if "/" in local_name:
                    p = local_name.split("/")
                    name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
                else:
                    name = title + "_toplevel/" + local_name

                tf.summary.histogram(name, var)
Exemple #42
0
class ExpectedSAC(SoftActorCritic):
    r"""
    Soft actor-critic variant whose value-function target estimates

    E_{a \sim \pi(. | s)}[Q(s, a) - \log \pi(a | s)]

    with a configurable strategy for each of the two terms: in closed form
    ('exact'), at the policy's mean action ('mean_action'), or from a single
    sampled action ('sample').
    """
    def __init__(self,
                 *args,
                 expected_qf_estim_strategy='exact',
                 expected_log_pi_estim_strategy='exact',
                 **kwargs):
        r"""
        :param args: positional arguments forwarded to ``SoftActorCritic``.
        :param expected_qf_estim_strategy: String describing how to estimate
            E[Q(s, A)]:
                'exact': estimate exactly by convolving Q with Gaussian
                'mean_action': estimate with Q(s, E[A])
                'sample': estimate with one sample of Q(s, A)
        :param expected_log_pi_estim_strategy: String describing how to
            estimate E[log \pi(A | s)]
                'exact': compute in closed form
                'mean_action': estimate with log \pi(E[A] | s)
                'sample': estimate with one sample of log \pi(A | s)
        :param kwargs: keyword arguments forwarded to ``SoftActorCritic``.
        """
        super().__init__(*args, **kwargs)
        assert expected_qf_estim_strategy in [EXACT, MEAN_ACTION, SAMPLE]
        assert expected_log_pi_estim_strategy in [EXACT, MEAN_ACTION, SAMPLE]
        self.expected_qf_estim_strategy = expected_qf_estim_strategy
        self.expected_log_pi_estim_strategy = expected_log_pi_estim_strategy

    def _do_training(self):
        """Run one optimisation step on the Q-function, V-function and policy.

        After the three optimizer steps this calls
        ``self._update_target_network()`` and refreshes ``self.eval_statistics``.
        """
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        q_pred = self.qf(obs, actions)
        v_pred = self.vf(obs)
        # Make sure policy accounts for squashing functions like tanh correctly!
        # Entropy / log-prob-of-mean are only requested when the chosen
        # log-pi strategy needs them.
        (new_actions, policy_mean, policy_log_std, log_pi, entropy,
         policy_stds, log_pi_mean) = self.policy(
             obs,
             return_log_prob=True,
             return_entropy=(self.expected_log_pi_estim_strategy == EXACT),
             return_log_prob_of_mean=(
                 self.expected_log_pi_estim_strategy == MEAN_ACTION),
         )

        # QF loss: regress Q(s, a) onto r + gamma * (1 - done) * V_target(s').
        target_v_values = self.target_vf(next_obs)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_v_values
        qf_loss = self.qf_criterion(q_pred, q_target.detach())

        # VF loss: regress V(s) onto E[Q(s, A)] - E[log pi(A | s)].
        q_new_actions = self.qf(obs, new_actions)
        if self.expected_qf_estim_strategy == EXACT:
            expected_q = self.qf(obs, policy_mean, action_stds=policy_stds)
        elif self.expected_qf_estim_strategy == MEAN_ACTION:
            expected_q = self.qf(obs, policy_mean)
        elif self.expected_qf_estim_strategy == SAMPLE:
            expected_q = q_new_actions
        else:
            raise TypeError(
                "Invalid E[Q(s, a)] estimation strategy: {}".format(
                    self.expected_qf_estim_strategy))
        if self.expected_log_pi_estim_strategy == EXACT:
            # The policy is only asked for its entropy under this strategy
            # (see ``return_entropy`` above), so only negate it here.
            expected_log_pi_target = -entropy
        elif self.expected_log_pi_estim_strategy == MEAN_ACTION:
            expected_log_pi_target = log_pi_mean
        elif self.expected_log_pi_estim_strategy == SAMPLE:
            expected_log_pi_target = log_pi
        else:
            raise TypeError(
                "Invalid E[log pi(a|s)] estimation strategy: {}".format(
                    self.expected_log_pi_estim_strategy))
        v_target = expected_q - expected_log_pi_target
        vf_loss = self.vf_criterion(v_pred, v_target.detach())

        # Policy loss.
        # paper says to do + but Tuomas said that's a typo. Do Q - V.
        log_policy_target = q_new_actions - v_pred
        policy_loss = (log_pi * (log_pi - log_policy_target).detach()).mean()
        policy_reg_loss = self.policy_reg_weight * ((policy_mean**2).mean() +
                                                    (policy_log_std**2).mean())
        policy_loss = policy_loss + policy_reg_loss

        # Update networks.
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self._update_target_network()

        # Save some statistics for eval.
        self.eval_statistics = OrderedDict()
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(q_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
Exemple #43
0
def _process_signature(xsd_type, args, kwargs):
    """Return a dict with the args/kwargs mapped to the field name.

    Special handling is done for Choice elements since we need to record which
    element the user intends to use.

    :param fields: List of tuples (name, element)
    :type fields: list
    :param args: arg tuples
    :type args: tuple
    :param kwargs: kwargs
    :type kwargs: dict


    """
    result = OrderedDict()
    # Process the positional arguments. args is currently still modified
    # in-place here
    if args:
        args = list(args)
        num_args = len(args)

        for element_name, element in xsd_type.elements_nested:
            values, args = element.parse_args(args)
            if not values:
                break
            result.update(values)

    if args:
        for attribute_name, attribute in xsd_type.attributes:
            result[attribute_name] = args.pop(0)

    if args:
        raise TypeError(
            "__init__() takes at most %s positional arguments (%s given)" % (
                len(result), num_args))

    # Process the named arguments (sequence/group/all/choice). The
    # available_kwargs set is modified in-place.
    available_kwargs = set(kwargs.keys())
    for element_name, element in xsd_type.elements_nested:
        if element.accepts_multiple:
            values = element.parse_kwargs(kwargs, element_name, available_kwargs)
        else:
            values = element.parse_kwargs(kwargs, None, available_kwargs)

        if values is not None:
            for key, value in values.items():
                if key not in result:
                    result[key] = value

    # Process the named arguments for attributes
    if available_kwargs:
        for attribute_name, attribute in xsd_type.attributes:
            if attribute_name in available_kwargs:
                available_kwargs.remove(attribute_name)
                result[attribute_name] = kwargs[attribute_name]

    if available_kwargs:
        raise TypeError((
            "%s() got an unexpected keyword argument %r. " +
            "Signature: (%s)"
        ) % (
            xsd_type.qname or 'ComplexType',
            next(iter(available_kwargs)),
            xsd_type.signature()))

    return result
Exemple #44
0
    def _classic_approx(k, n, m, negative_evals=False):
        """Approximate arcsin(1/x) for controlled-rotation.

        This method calculates the binning of arcsin(1/x) function using k
        bits fixed point numbers and n bit accuracy.

        Args:
            k (int): register length
            n (int): num bits following most-significant qubit taken into account
            m (int): length of sub string of n-qubit pattern
            negative_evals (bool): flag for using first qubit as sign bit

        Returns:
            dict: Dictionary containing values of approximated and binned values.
        """
        def bin_to_num(binary):
            """Convert to numeric"""
            num = np.sum([
                2**-(n + 1) for n, i in enumerate(reversed(binary)) if i == '1'
            ])
            return num

        def get_est_lamb(pattern, fo, n, k):
            """Estimate the bin mid point and return the float value"""
            if fo - n > 0:
                remainder = sum(
                    [2**-i for i in range(k - (fo - n - 1), k + 1)])
                return bin_to_num(pattern) + remainder / 2
            return bin_to_num(pattern)

        # pylint: disable=import-outside-toplevel
        from collections import OrderedDict
        output = OrderedDict()
        fo = None
        for fo in range(k - 1, n - 1, -1):
            # skip first bit if negative ev are used
            if negative_evals and fo == k - 1:
                continue
            # init bit string
            vec = ['0'] * k
            # set most significant bit
            vec[fo] = '1'
            # iterate over all 2^m combinations = sub string in n-bit pattern
            for pattern_ in itertools.product('10', repeat=m):
                app_pattern_array = []
                lambda_array = []
                fo_array = []
                # iterate over all 2^(n-m) combinations
                for appendpat in itertools.product('10', repeat=n - m):
                    # combine both generated patterns
                    pattern = pattern_ + appendpat
                    vec[fo - n:fo] = pattern
                    # estimate bin mid point
                    e_l = get_est_lamb(vec.copy(), fo, n, k)
                    lambda_array.append(e_l)
                    fo_array.append(fo)
                    app_pattern_array.append(list(reversed(appendpat)))

                # rewrite first-one to correct index in QuantumRegister
                fo_pos = k - fo - 1
                if fo_pos in list(output.keys()):
                    prev_res = output[fo_pos]
                else:
                    prev_res = []

                output.update({
                    fo_pos:
                    prev_res + [(list(
                        reversed(pattern_)), app_pattern_array, lambda_array)]
                })

        # last iterations, only last n bits != 0
        last_fo = fo
        vec = ['0'] * k
        for pattern_ in itertools.product('10', repeat=m):
            app_pattern_array = []
            lambda_array = []
            fo_array = []
            for appendpat in itertools.product('10', repeat=n - m):
                pattern = list(pattern_ + appendpat).copy()
                if '1' not in pattern and (not negative_evals):
                    continue
                if '1' not in pattern and negative_evals:
                    e_l = 0.5
                else:
                    vec[last_fo - n:last_fo] = list(pattern)
                    e_l = get_est_lamb(vec.copy(), last_fo, n, k)
                lambda_array.append(e_l)
                fo_array.append(last_fo - 1)
                app_pattern_array.append(list(reversed(appendpat)))

            fo_pos = k - last_fo
            if fo_pos in list(output.keys()):
                prev_res = output[fo_pos]
            else:
                prev_res = []

            output.update({
                fo_pos:
                prev_res +
                [(list(reversed(pattern_)), app_pattern_array, lambda_array)]
            })

        return output
def run(config, args):
    """Tabulate try pushes by the method used to make them.

    Each commit message returned by the ``try_commit_messages`` query is
    classified by the first entry of ``d`` whose ``test`` substring it
    contains. The ``other`` entry's empty test matches any message.

    :returns: a table (list of rows). The first row is the header
        ``['Method', 'Pushes', 'Percent', 'Users', 'Push / User']``, followed
        by one row per method seen (including ``total``), sorted by push
        count in descending order. Percent and Push / User are 0.0 when their
        denominator is zero (e.g. an empty query result).
    """
    data = run_query('try_commit_messages', config, args)['data']

    count = defaultdict(int)
    count['total'] = len(data['message'])

    users = defaultdict(set)
    users['total'] = set(data['user'])

    # Order is important as the search stops after the first successful test.
    d = OrderedDict()
    d.update(subcommand('syntax'))
    d['vanilla'] = {
        'test': 'try:',
        'method': 'vanilla try syntax',
    }
    d.update(subcommand('fuzzy'))
    d.update(subcommand('again'))
    d.update(subcommand('empty'))
    d.update(subcommand('release'))
    d.update(subcommand('coverage'))
    d.update(subcommand('chooser'))
    d['other'] = {
        'test': '',
        'method': 'other',
    }
    d['total'] = {
        'test': None,
        'method': 'total',
    }

    for user, message in zip(data['user'], data['message']):
        for k, v in d.items():
            if v['test'] in message:
                count[k] += 1
                users[k].add(user)
                break

    def fmt(key):
        """Build one table row for method ``key``, avoiding division by zero."""
        pushes = count[key]
        n_users = len(users[key])
        total = count['total']
        percent = round(float(pushes) / total * 100, 1) if total else 0.0
        per_user = round(float(pushes) / n_users, 2) if n_users else 0.0
        return [d[key]['method'], pushes, percent, n_users, per_user]

    table = [['Method', 'Pushes', 'Percent', 'Users', 'Push / User']]
    for key in sorted(count, key=count.__getitem__, reverse=True):
        table.append(fmt(key))
    return table
Exemple #46
0
def dict_representer(dumper, data):
    """Represent ``data`` with its ``label`` key first and other keys sorted.

    :param dumper: object whose ``represent_dict`` builds the output node.
    :param data: mapping to represent.
    :returns: the result of ``dumper.represent_dict`` on the reordered
        mapping wrapped in ``DictRenderWrapper``.
    """
    leading = [("label", data["label"])] if "label" in data else []
    ordered = OrderedDict(leading)
    # Re-assigning "label" below keeps it in first position.
    for key, value in sorted(data.items()):
        ordered[key] = value
    wrapped = DictRenderWrapper(ordered)
    return dumper.represent_dict(wrapped)
class CnnEvaluator:
    """
    CNN-based evaluator that scores how well sentences match image features.

    Sentences are embedded by a bank of convolution + max-pool filters of
    heights ``filter_hs`` over the word embeddings, followed by a linear
    projection (``Wfc_sent``). Images are embedded by a linear projection
    (``WIemb``). The two embeddings are then compared with a similarity
    score. When ``params['advers_gen']`` is set, the conv layers are created
    up front so that ``build_advers_eval`` can reuse their weights.
    """

    # ========================================================================================
    def __init__(self, params, Wemb=None):
        """Create the evaluator's parameters.

        :param params: configuration dict. Keys read here: word_encoding_size,
            image_feat_size, aux_inp_size, n_fmaps_psz, filter_hs,
            vocabulary_size, sent_encoding_size, en_aux_inp, advers_gen,
            maxlen, n_gen_samples, conv_non_linear.
        :param Wemb: optional word-embedding parameter shared with the
            generator. When None, a fresh embedding matrix is initialised.
        """
        self.word_encoding_size = params.get('word_encoding_size', 512)
        image_feat_size = params.get('image_feat_size', 512)
        aux_inp_size = params.get('aux_inp_size', -1)

        self.n_fmaps_psz = params.get('n_fmaps_psz', 100)
        self.filter_hs = params.get('filter_hs', [])

        # Used for dropout.
        self.use_noise = theano.shared(numpy_floatX(0.))

        vocabulary_size = params.get('vocabulary_size', -1)
        self.sent_enc_size = params.get('sent_encoding_size', -1)

        model = OrderedDict()
        if Wemb is None:
            # word encoder
            model['Wemb'] = initwTh(vocabulary_size - 1, self.word_encoding_size)
        # image encoder (no bias parameter is created)
        model['WIemb'] = initwTh(image_feat_size, self.sent_enc_size, inittype='xavier')
        # projects the concatenated CNN filter outputs to the sentence embedding
        # (no bias parameter is created)
        model['Wfc_sent'] = initwTh(self.n_fmaps_psz * len(self.filter_hs),
                                    self.sent_enc_size, inittype='xavier')

        update_list = ['Wemb', 'Wfc_sent', 'WIemb']
        self.regularize = ['Wemb', 'Wfc_sent', 'WIemb']

        if params.get('en_aux_inp', 0) and not params['advers_gen']:
            # auxiliary-input encoder
            model['WIemb_aux'] = initwTh(aux_inp_size, self.sent_enc_size)
            model['b_Img_aux'] = np.zeros((self.sent_enc_size)).astype(config.floatX)

        self.model_th = self.init_tparams(model)

        # Share the Word embeddings with the generator model
        if Wemb is not None:
            self.model_th['Wemb'] = Wemb
        self.updateP = OrderedDict()
        for vname in update_list:
            self.updateP[vname] = self.model_th[vname]

        # Instantiate a conv layer already so we don't end up creating new weights
        if params['advers_gen']:
            filter_w = self.word_encoding_size
            self.conv_layers = []
            max_sent_len = params.get('maxlen', 0)
            for filter_h in self.filter_hs:
                filter_shape = (self.n_fmaps_psz, params['n_gen_samples'], filter_h, filter_w)
                pool_size = (max_sent_len - filter_h + 1, self.word_encoding_size - filter_w + 1)
                conv_layer = batch2DConvPoolLayer(filter_shape=filter_shape,
                                                  poolsize=pool_size,
                                                  non_linear=params['conv_non_linear'])
                self.conv_layers.append(conv_layer)
                self.updateP.update(conv_layer.params)
                self.regularize.extend(conv_layer.regularize)
                self.model_th.update(conv_layer.params)

    # ========================================================================================
    def init_tparams(self, params):
        """Wrap each value of ``params`` in a theano shared variable named by its key."""
        tparams = OrderedDict()
        for kk, pp in params.items():
            tparams[kk] = theano.shared(pp, name=kk)
        return tparams

    # ========================================================================================
    # BUILD CNN evaluator forward propagation model
    def build_model(self, tparams, options, xI=None, prior_inp_list=None):
        """Build the evaluator graph for (word-index matrix, image) inputs.

        :param tparams: dict of theano shared parameters. NOTE(review): this
            path reads ``tparams['b_Img']`` and, via ``sent_conv_layer``,
            ``tparams['bfc_sent']``; ``__init__`` creates neither, so the
            caller must supply them.
        :param options: configuration dict (batch_size, use_dropout,
            drop_prob_encoder, mode, sim_smooth_factor, en_aux_inp,
            drop_prob_aux, cost_margin, plus the keys ``sent_conv_layer``
            reads).
        :param xI: optional image input that is used directly (only the bias
            is added). When None, a symbolic image-feature matrix is created
            and projected with ``WIemb``.
        :param prior_inp_list: extra theano inputs prepended to the inputs of
            ``f_pred_sim_prob`` and ``f_pred_sim_scr``. Defaults to none.
        :returns: ``(use_noise, inp_list, [f_pred_sim_prob, f_pred_sim_scr,
            f_sent_emb], cost, sim_score, tparams)``
        """
        if prior_inp_list is None:
            prior_inp_list = []
        trng = RandomStreams()

        # Used for dropout.
        use_noise = theano.shared(numpy_floatX(0.))

        xWi = T.matrix('xW', dtype='int64')
        # Now input is transposed compared to the generator!!
        xW = xWi.T
        n_words = xW.shape[1]

        # Append one all-zero row after the learned word embeddings.
        Words = T.concatenate([tparams['Wemb'], T.alloc(numpy_floatX(0.), 1, self.word_encoding_size)], axis=0)
        embW = Words[xW.flatten()].reshape([options['batch_size'], 1, n_words, self.word_encoding_size])

        if options.get('use_dropout', 0):
            embW = dropout_layer(embW, use_noise, trng, options['drop_prob_encoder'], shp=embW.shape)

        sent_emb, cnn_out, tparams = self.sent_conv_layer(tparams, options, embW, options['batch_size'], use_noise, trng)

        if xI is None:
            xI = T.matrix('xI', dtype=config.floatX)
            xI_is_inp = True
        else:
            xI_is_inp = False

        if options.get('mode', 'batchtrain') != 'batchtrain':
            posSamp = T.ivector('posSamp')

        if xI_is_inp:
            embImg = T.dot(xI, tparams['WIemb']) + tparams['b_Img']
        else:
            embImg = xI + tparams['b_Img']

        if options.get('use_dropout', 0):
            embImg = dropout_layer(embImg, use_noise, trng, options['drop_prob_encoder'], shp=embImg.shape)

        #-------------------------------------------------------------------------------------------------------------#
        # Curr prob is computed by applying softmax over (I0,c0), (I0,c1),... (I0,cn-1) pairs
        # It could also be computed with (I0,c0), (I1,c0),... (In,c0) pairs, but will lead to different discrimination
        # Maybe even sum of the two could be used
        #-------------------------------------------------------------------------------------------------------------#
        probMatchImg, sim_score = multimodal_cosine_sim_softmax(embImg, sent_emb, tparams, options.get('sim_smooth_factor', 1.0))
        inp_list = [xWi]
        if xI_is_inp:
            inp_list.append(xI)

        if options.get('en_aux_inp', 0):
            xAux = T.matrix('xAux', dtype=config.floatX)
            embAux = T.dot(xAux, tparams['WIemb_aux']) + tparams['b_Img_aux']
            # NOTE(review): xAuxEmb (with dropout) is never used; the
            # similarity below is computed from embAux. Confirm intent.
            xAuxEmb = dropout_layer(embAux, use_noise, trng, options['drop_prob_aux'], shp=embAux.shape)
            inp_list.append(xAux)
            probMatchAux, sim_scoreAux = multimodal_cosine_sim_softmax(embAux, sent_emb, tparams, options.get('sim_smooth_factor', 1.0))
        else:
            probMatchAux = T.alloc(numpy_floatX(0.), 1, 1)

        probMatch = (probMatchImg + probMatchAux) / 2.

        # For each row, pick the strongest competitor: the highest-probability
        # column that is not the row's own (diagonal) entry.
        sortedProb = T.argsort(probMatch, axis=1)
        batch_idces = T.arange(probMatch.shape[0])
        opponents = T.switch(T.eq(sortedProb[:, -1], batch_idces), sortedProb[:, -2], sortedProb[:, -1])

        # Rows whose diagonal entry beats the best competitor by less than cost_margin.
        violator_mask = (probMatch.diagonal() - probMatch[batch_idces, opponents]) < (options.get('cost_margin', 0.02))
        n_violators = violator_mask.sum()

        if options.get('mode', 'batchtrain') == 'batchtrain':
            # Margin violators get weight 3, other rows weight 1.
            cost = [-((T.log(probMatch.diagonal()) * (1 + 2.0 * violator_mask)).sum()) / probMatch.shape[0]]
        else:
            cost = [-(T.log(probMatch[0, posSamp]).sum()) / posSamp.shape[0]]

        cost.append(n_violators)
        cost.append((probMatch.diagonal() - probMatch[batch_idces, opponents]))

        f_pred_sim_prob = theano.function(prior_inp_list + inp_list, [probMatchImg, probMatchAux, probMatch, opponents], name='f_pred_sim_prob')
        f_pred_sim_scr = theano.function(prior_inp_list + inp_list[:2], sim_score, name='f_pred_sim_scr')
        f_sent_emb = theano.function(inp_list[:1], cnn_out, name='f_sent_emb')

        if options.get('mode', 'batchtrain') != 'batchtrain':
            inp_list.append(posSamp)

        return use_noise, inp_list, [f_pred_sim_prob, f_pred_sim_scr, f_sent_emb], cost, sim_score, tparams

    # ========================================================================================
    # BUILD CNN evaluator forward propagation model with taking direct inputs from lstm gen
    def build_advers_eval(self, tparams, options, gen_inp_list=None, gen_out=None, genUpdates=None, genLens=None):
        """Build the evaluator graph for adversarial training.

        Requires the conv layers created in ``__init__`` (``advers_gen`` set).

        :param tparams: dict of theano shared parameters (Wemb, Wfc_sent, WIemb).
        :param options: configuration dict (maxlen, use_dropout,
            drop_prob_eval, eval_loss, regc, batch_size).
        :param gen_inp_list: generator inputs. Its first entry is used as the
            image input when ``gen_out`` is given.
        :param gen_out: optional generator output. When given, it is padded
            along axis 2 to the reference length and concatenated after the
            reference samples along axis 0.
        :param genUpdates: updates passed to every compiled theano function.
        :param genLens: unused.
        :returns: ``(inp_list, [f_pred_sim_prob, f_pred_cost, f_sent_emb],
            [costEval, costGen], p_out, tparams)``. ``costGen`` is ``[]``
            when ``gen_out`` is None.
        :raises ValueError: if ``options['eval_loss']`` is neither
            'contrastive' nor 'wass'.
        """
        trng = RandomStreams()

        # Prepend an all-zero row to the word embeddings.
        zero_guy = T.alloc(numpy_floatX(0.), 1, self.word_encoding_size)
        Word_Vecs = T.concatenate([zero_guy, tparams['Wemb']], axis=0)

        # These are of dimensions B x n_samp x time x Vocab
        if gen_out is None:
            discrim_inp = T.tensor4(name='disc_inp')
            inp_list = [discrim_inp]
            n_ref_samps = discrim_inp.shape[0]
        else:
            refData_inp = tensor.tensor4(name='disc_ref_inp')
            n_ref_samps = refData_inp.shape[0]
            n_words = refData_inp.shape[2]
            n_gen_words = gen_out.shape[2]
            z_shape = list(gen_out.shape)
            z_shape[2] = n_words - n_gen_words
            # Zero-pad the generated sequences along axis 2 when they are
            # shorter than the reference ones.
            gen_pad = ifelse(tensor.gt(n_words, n_gen_words),
                             tensor.concatenate([gen_out, tensor.zeros(z_shape)], axis=2),
                             gen_out)
            discrim_inp = tensor.concatenate([refData_inp, gen_pad], axis=0)
            inp_list = [refData_inp]

        # Embed this input into size B x n_samp x time x word_vec_dim
        embW = T.dot(discrim_inp, Word_Vecs)

        max_sent_len = options.get('maxlen', 0)
        layer1_inputs = []
        for i, filter_h in enumerate(self.filter_hs):
            pool_size = (max_sent_len - filter_h + 1, 1)
            self.conv_layers[i].build(embW, poolsize=pool_size)
            # flatten all the filter outputs to a single vector
            cout = self.conv_layers[i].output.flatten(2)
            layer1_inputs.append(cout)

        layer1_input = T.concatenate(layer1_inputs, axis=1)

        # Now apply dropout on the cnn output
        if options.get('use_dropout', 0):
            cnn_out = dropout_layer(layer1_input, self.use_noise, trng, options['drop_prob_eval'], layer1_input.shape)
        else:
            cnn_out = layer1_input

        # Now transform this into a sent embedding (no bias term)
        sent_emb = T.dot(cnn_out, tparams['Wfc_sent'])

        # Now to embed the image feature vector and calculate a similarity score
        if gen_out is None:
            xImg = T.matrix('xI', dtype=config.floatX)
        else:
            xImg = gen_inp_list[0]

        # Compute image embedding (no bias term)
        embImg = T.dot(xImg, tparams['WIemb'])

        # l2norm presumably row-normalises, making `scores` cosine similarities.
        m_img = l2norm(embImg)
        m_sent = l2norm(sent_emb)

        scores = T.dot(m_img, m_sent.T)

        p_out = scores.diagonal()
        if gen_out is not None:
            p_out = T.concatenate([p_out, 0.5 * (scores[:, n_ref_samps:].diagonal() + 1.0)])
            for inp in gen_inp_list:
                if inp not in inp_list:
                    inp_list.append(inp)
            print(inp_list)
        else:
            inp_list.append(xImg)

        xTarg = T.fvector('targ')
        inp_list.append(xTarg)
        eval_loss = options.get('eval_loss', 'contrastive')
        if eval_loss == 'contrastive':
            probMatch = T.nnet.softmax(scores * 2.0)
            costEval = -((T.log(probMatch[:, :n_ref_samps].diagonal()) * xTarg).sum())
            if gen_out is not None:
                costGen = -((T.log(probMatch[:, n_ref_samps:].diagonal())).sum())
            else:
                costGen = []
        elif eval_loss == 'wass':
            costEval = (scores[:, :n_ref_samps].diagonal() * xTarg).mean() - (scores[:, :n_ref_samps].diagonal() * (1. - xTarg)).mean()
            if gen_out is not None:
                costGen = -(scores[:, n_ref_samps:].diagonal()).mean()
                costEval += costGen
            else:
                # Previously left undefined, which raised NameError at return.
                costGen = []
            costEval = -costEval
        else:
            raise ValueError("Unknown eval_loss: %r" % (eval_loss,))

        # L2 regularisation: 0.5 * regc * sum of squared parameters, scaled
        # once after summing. Scaling inside the loop would compound the
        # factor on earlier terms.
        if options.get('regc', 0.) > 0.:
            self.reg_cost = theano.shared(numpy_floatX(0.), name='reg_c')
            reg_c = T.as_tensor_variable(numpy_floatX(options['regc']), name='reg_c')
            for p in self.regularize:
                self.reg_cost = self.reg_cost + (self.model_th[p] ** 2).sum()
            self.reg_cost *= 0.5 * reg_c
            costEval += (self.reg_cost / options['batch_size'])

        f_pred_cost = theano.function(inp_list, costEval, name='f_pred_sim_scr', updates=genUpdates)

        # inp_list[:-1] drops the target vector xTarg.
        f_pred_sim_prob = theano.function(inp_list[:-1], [p_out], name='f_pred_sim_prob', updates=genUpdates)
        f_sent_emb = theano.function(inp_list[:-1], [m_sent, m_img, scores], name='f_sent_emb', updates=genUpdates)

        return inp_list, [f_pred_sim_prob, f_pred_cost, f_sent_emb], [costEval, costGen], p_out, tparams

    def contrastive_loss(self, im, s, margin=0.1):
        """
        Compute a hinge-based contrastive loss between image and sentence embeddings.

        :returns: ``(total_cost, cost_s, cost_im)``, where the per-pair cost
            matrices have their diagonals zeroed.
        """
        # compute image-sentence score matrix
        scores = T.dot(im, s.T)
        diagonal = scores.diagonal()

        # compare every diagonal score to scores in its column (i.e, all contrastive images for each sentence)
        cost_s = T.maximum(0, margin - diagonal + scores)
        # compare every diagonal score to scores in its row (i.e, all contrastive sentences for each image)
        cost_im = T.maximum(0, margin - diagonal.reshape((-1, 1)) + scores)

        # clear diagonals
        cost_s = fill_diagonal(cost_s, 0.)
        cost_im = fill_diagonal(cost_im, 0.)

        return cost_s.sum() + cost_im.sum(), cost_s, cost_im

    # ========================================================================================
    ####################################################################################
    # Defines the convolution layer on sentences.
    # -- Input is word embeddings stacked as a n_word * enc_size "image"
    # -- Filters are all of width equal to enc_size, height varies (3,4,5 grams etc.)
    # -- Also pooling is taking max over entire filter output, i.e each filter output
    #    is converted to a single number!
    # -- Output is stacking all the filter outputs to a single vector,
    #    sz = (batch-size,  n_filters)
    ####################################################################################
    def sent_conv_layer(self, tparams, options, embW, batch_size, use_noise, trng, n_samp=1):
        """Build fresh conv layers over ``embW`` and project to a sentence embedding.

        Replaces ``self.conv_layers`` and adds the new layers' parameters to
        ``self.updateP``, ``self.regularize`` and ``tparams``.

        :returns: ``(sent_emb, cnn_out, tparams)``. NOTE(review): reads
            ``tparams['bfc_sent']``, which ``__init__`` does not create.
        """
        rng = np.random.RandomState()
        max_sent_len = options.get('maxlen', 0)
        filter_shapes = []
        self.conv_layers = []
        pool_sizes = []
        filter_w = self.word_encoding_size
        layer1_inputs = []
        for filter_h in self.filter_hs:
            filter_shapes.append((self.n_fmaps_psz, n_samp, filter_h, filter_w))
            if max_sent_len > 0:
                image_shape = [batch_size, n_samp, max_sent_len, self.word_encoding_size]
            else:
                image_shape = None
            pool_sizes.append((max_sent_len - filter_h + 1, self.word_encoding_size - filter_w + 1))
            conv_layer = LeNetConvPoolLayer(rng, input=embW, image_shape=image_shape, filter_shape=filter_shapes[-1],
                                            poolsize=pool_sizes[-1], non_linear=options['conv_non_linear'])
            # flatten all the filter outputs to a single vector
            cout = conv_layer.output.flatten(2)
            self.conv_layers.append(conv_layer)
            layer1_inputs.append(cout)
            self.updateP.update(conv_layer.params)
            self.regularize.extend(conv_layer.regularize)
            tparams.update(conv_layer.params)

        layer1_input = T.concatenate(layer1_inputs, axis=1)

        # Now apply dropout on the cnn output
        if options.get('use_dropout', 0):
            cnn_out = dropout_layer(layer1_input, use_noise, trng, options['drop_prob_cnn'], layer1_input.shape)
        else:
            cnn_out = layer1_input

        # Now transform this into a sent embedding
        sent_emb = T.dot(cnn_out, tparams['Wfc_sent']) + tparams['bfc_sent']

        return sent_emb, cnn_out, tparams
Exemple #48
0
def get_solutes(totals, ks_constants, ptargets):
    """Assemble an ordered mapping of solute concentrations.

    ``ptargets`` holds p-values (negative log10 concentrations); they are
    converted with ``10.0 ** -p``. The result starts as a copy of ``totals``.
    It is then extended or overwritten with every species whose required
    total and/or equilibrium constant is present.

    "H" must be a key of ``ptargets``. The free ligands "F", "CO3" and "PO4"
    come from the targets when given. Otherwise they are computed from the
    matching total ("F", "CO2", "PO4") if present, and are 0.0 if not.
    Insertion order of the returned OrderedDict follows the order in which
    species are first assigned.

    Returns:
        OrderedDict mapping solute name -> concentration.
    """
    targets = OrderedDict(
        (name, 10.0 ** -pvalue) for name, pvalue in ptargets.items())
    totals = totals.copy()
    solutes = OrderedDict(totals)

    def available(total_key, ks_key):
        """True when both the total and the equilibrium constant are given."""
        return total_key in totals and ks_key in ks_constants

    # H must always be a target (for now).
    h = targets["H"]
    solutes["H"] = h

    # Free ligands: target value, else computed from the total, else zero.
    if "F" in targets:
        f = targets["F"]
        solutes["F"] = f
    elif "F" in totals:
        f = get_F(h, totals, ks_constants)
        solutes["F"] = f
    else:
        f = 0.0

    if "CO3" in targets:
        co3 = targets["CO3"]
        solutes["CO3"] = co3
    elif "CO2" in totals:
        co3 = get_CO3(h, totals, ks_constants)
        solutes["CO3"] = co3
    else:
        co3 = 0.0

    if "PO4" in targets:
        po4 = targets["PO4"]
        solutes["PO4"] = po4
    elif "PO4" in totals:
        po4 = get_PO4(h, totals, ks_constants)
        solutes["PO4"] = po4
    else:
        po4 = 0.0

    # Simple acid/base pairs.
    if "H2O" in ks_constants:
        solutes["OH"] = get_OH(h, ks_constants)
    if available("SO4", "HSO4"):
        solutes["HSO4"] = get_HSO4(h, totals, ks_constants)
        solutes["SO4"] = get_SO4(h, totals, ks_constants)
    if available("H2S", "H2S"):
        solutes["H2S"] = get_H2S(h, totals, ks_constants)
        solutes["HS"] = get_HS(h, totals, ks_constants)
    if available("BOH3", "BOH3"):
        solutes["BOH3"] = get_BOH3(h, totals, ks_constants)
        solutes["BOH4"] = get_BOH4(h, totals, ks_constants)
    if available("NH3", "NH4"):
        solutes["NH3"] = get_NH3(h, totals, ks_constants)
        solutes["NH4"] = get_NH4(h, totals, ks_constants)
    if available("H4SiO4", "H4SiO4"):
        solutes["H3SiO4"] = get_H3SiO4(h, totals, ks_constants)
        solutes["H4SiO4"] = get_H4SiO4(h, totals, ks_constants)
    if available("NO2", "HNO2"):
        solutes["HNO2"] = get_HNO2(h, totals, ks_constants)
        solutes["NO2"] = get_NO2(h, totals, ks_constants)

    # Arguments shared by every metal/ligand getter below.
    ligand_args = (h, f, co3, po4, totals, ks_constants)

    # Fluoride complexes are only formed when F is a solver target.
    if "F" in targets and "F" in totals:
        if available("Ca", "CaF"):
            solutes["CaF"] = get_CaF(*ligand_args)
        if available("Mg", "MgF"):
            solutes["MgF"] = get_MgF(*ligand_args)
    if available("F", "HF"):
        solutes["HF"] = get_HF(h, f, ks_constants)

    # Carbonate complexes are only formed when CO3 is a solver target.
    if "CO3" in targets and "CO2" in totals:
        if available("Mg", "MgCO3"):
            solutes["MgCO3"] = get_MgCO3(*ligand_args)
        if available("Ca", "CaCO3"):
            solutes["CaCO3"] = get_CaCO3(*ligand_args)
        if available("Sr", "SrCO3"):
            solutes["Sr"] = get_Sr(co3, totals, ks_constants)
            solutes["SrCO3"] = get_SrCO3(co3, totals, ks_constants)
    if available("CO2", "HCO3"):
        solutes["HCO3"] = get_HCO3(h, co3, ks_constants)
        if "H2CO3" in ks_constants:
            solutes["CO2"] = get_CO2(h, co3, ks_constants)

    # Phosphate protonation series.
    if available("PO4", "HPO4"):
        solutes["HPO4"] = get_HPO4(h, po4, ks_constants)
        if "H2PO4" in ks_constants:
            solutes["H2PO4"] = get_H2PO4(h, po4, ks_constants)
            if "H3PO4" in ks_constants:
                solutes["H3PO4"] = get_H3PO4(h, po4, ks_constants)

    # Phosphate complexes are only formed when PO4 is a solver target.
    if "PO4" in targets and "PO4" in totals:
        if "Mg" in totals:
            if "MgH2PO4" in ks_constants:
                solutes["MgH2PO4"] = get_MgH2PO4(*ligand_args)
            if "MgHPO4" in ks_constants:
                solutes["MgHPO4"] = get_MgHPO4(*ligand_args)
            if "MgPO4" in ks_constants:
                solutes["MgPO4"] = get_MgPO4(*ligand_args)
        if "Ca" in totals:
            if "CaH2PO4" in ks_constants:
                solutes["CaH2PO4"] = get_CaH2PO4(*ligand_args)
            if "CaHPO4" in ks_constants:
                solutes["CaHPO4"] = get_CaHPO4(*ligand_args)
            if "CaPO4" in ks_constants:
                solutes["CaPO4"] = get_CaPO4(*ligand_args)

    # Free metal ions.
    if "Mg" in totals:
        solutes["Mg"] = get_Mg(*ligand_args)
        if "MgOH" in ks_constants:
            solutes["MgOH"] = get_MgOH(*ligand_args)
    if "Ca" in totals:
        solutes["Ca"] = get_Ca(*ligand_args)
    return solutes
Exemple #49
0
    def encodeInvoker(self, invoker):
        '''
        Create an encode exploit for the invoker.

        Every input of the invoker contributes encoders to a path:
        - primitive inputs are encoded from the data dictionary under their name;
        - query inputs contribute one criteria encoder per criteria entry and,
          for criteria that are AsOrdered, an order-getter encoder.
        A query whose owner the invoker output is of gets its entries merged into
        the top level, provided none of its names collide with existing entries;
        otherwise its encoders are nested under the input name. Inputs that are
        neither primitive nor TypeQuery contribute nothing.

        @param invoker: Invoker
            The invoker to create a parameters encoder for.
        @return: callable(**data)
            The exploit that provides the invoker encoding; it calls the encoders
            with an extra 'target' deque and returns that deque.
        '''
        assert isinstance(invoker, Invoker), 'Invalid invoker %s' % invoker

        children = OrderedDict()
        ordered = OrderedDict()
        for inp in invoker.inputs:
            assert isinstance(inp, Input)
            inpType = inp.type
            assert isinstance(inpType, Type)

            if inpType.isPrimitive:
                children[inp.name] = self.encodePrimitive(inpType, getterOnDict(inp.name))
                continue
            if not isinstance(inpType, TypeQuery):
                continue

            queryChildren, queryOrdered = OrderedDict(), OrderedDict()
            queryGetter = getterOnDict(inp.name)
            for entryName, criteriaClass in inpType.query.criterias.items():
                entryGetter = getterChain(
                    queryGetter,
                    getterOnObjIfIn(entryName, inpType.criteriaEntryTypeFor(entryName)))
                queryChildren[entryName] = self.encodeCriteria(typeFor(criteriaClass), entryGetter)
                if issubclass(criteriaClass, AsOrdered):
                    queryOrdered[entryName] = self.encodeGetOrder(
                        inpType.criteriaEntryTypeFor(entryName), queryGetter)

            # A main query is flattened into the top level only when none of its
            # names clash with entries that are already there.
            mergeable = (invoker.output.isOf(inpType.owner)
                         and set(queryChildren.keys()).isdisjoint(children.keys())
                         and set(queryOrdered).isdisjoint(ordered))
            if mergeable:
                children.update(queryChildren)
                ordered.update(queryOrdered)
            else:
                children[inp.name] = self.encodePath(queryChildren)
                ordered[inp.name] = self.encodePath(queryOrdered)

        exploitOrder = None
        if ordered:
            # Ordering is skipped (with an error log) if an order parameter name
            # is already used by another child.
            for orderName in (self.nameOrderAsc, self.nameOrderDesc):
                if orderName in children:
                    log.error('Name conflict for \'%s\' in %s', orderName, invoker)
                    break
            else:
                exploitOrder = self.encodeOrder(self.encodePath(ordered))
        exploitPath = self.encodePath(children)

        def exploit(**data):
            target = deque()
            data['target'] = target
            exploitPath(**data)
            if exploitOrder:
                exploitOrder(**data)
            return target

        return exploit
Exemple #50
0
class TD3Trainer(TorchTrainer):
    """
    Twin Delayed Deep Deterministic policy gradients (TD3).

    Both critics regress toward a shared target built from the minimum of the
    two target critics, evaluated at a noise-perturbed target-policy action.
    The policy and all three target networks are updated only every
    ``policy_and_target_update_period`` calls to :meth:`train`.
    """

    def __init__(
            self,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            target_policy,
            target_policy_noise=0.2,
            target_policy_noise_clip=0.5,

            discount=0.99,
            reward_scale=1.0,

            policy_learning_rate=1e-3,
            qf_learning_rate=1e-3,
            policy_and_target_update_period=2,
            tau=0.005,
            qf_criterion=None,
            optimizer_class=optim.Adam,
    ):
        """
        :param policy, qf1, qf2: online networks that are optimized.
        :param target_qf1, target_qf2, target_policy: target copies, soft-updated
            toward their online networks with rate ``tau``.
        :param target_policy_noise: scale applied to ``ptu.randn`` noise that is
            added to target-policy actions.
        :param target_policy_noise_clip: the noise is clamped to
            ``[-target_policy_noise_clip, target_policy_noise_clip]``.
        :param discount: factor applied to the bootstrapped target Q-value.
        :param reward_scale: multiplier applied to rewards in the target.
        :param policy_learning_rate: learning rate of the policy optimizer.
        :param qf_learning_rate: learning rate of both critic optimizers.
        :param policy_and_target_update_period: policy/target update interval,
            in training steps.
        :param tau: soft-update rate for the target networks.
        :param qf_criterion: loss between Q predictions and targets; defaults to
            ``nn.MSELoss()``.
        :param optimizer_class: optimizer constructor, called as
            ``optimizer_class(params, lr=...)``.
        """
        # Initialise base-class state, as the other trainers do.
        super().__init__()
        if qf_criterion is None:
            qf_criterion = nn.MSELoss()
        self.qf1 = qf1
        self.qf2 = qf2
        self.policy = policy
        self.target_policy = target_policy
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.target_policy_noise = target_policy_noise
        self.target_policy_noise_clip = target_policy_noise_clip

        self.discount = discount
        self.reward_scale = reward_scale

        self.policy_and_target_update_period = policy_and_target_update_period
        self.tau = tau
        self.qf_criterion = qf_criterion

        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_learning_rate,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_learning_rate,
        )
        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_learning_rate,
        )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

    def train(self, np_batch):
        """
        Perform one TD3 training step on a batch of numpy arrays.

        The batch is converted with ``np_to_pytorch_batch`` and must provide
        'rewards', 'terminals', 'observations', 'actions' and
        'next_observations'. Both critics are updated on every call. The
        policy and target networks are updated when the step counter is a
        multiple of ``policy_and_target_update_period``. Diagnostics are
        recorded on the first call and on the first call after each
        :meth:`end_epoch`.
        """
        batch = np_to_pytorch_batch(np_batch)
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        # --- Critic targets -------------------------------------------------
        # Target-policy smoothing: perturb the target action with clipped noise.
        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(
            noise,
            -self.target_policy_noise_clip,
            self.target_policy_noise_clip
        )
        noisy_next_actions = next_actions + noise

        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        # --- Critic losses --------------------------------------------------
        # The optimized loss uses the configured criterion (MSE by default,
        # which equals the mean of the squared Bellman errors). The raw squared
        # errors are kept for diagnostics only.
        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target) ** 2
        qf1_loss = self.qf_criterion(q1_pred, q_target)

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf2_loss = self.qf_criterion(q2_pred, q_target)

        # --- Update critics -------------------------------------------------
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        # --- Delayed policy and target updates ------------------------------
        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                # The policy was not updated this step; evaluate it for logging
                # only, without building an autograd graph.
                with torch.no_grad():
                    policy_actions = self.policy(obs)
                    q_output = self.qf1(obs, policy_actions)
                    policy_loss = - q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 1',
                ptu.get_numpy(bellman_errors_1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 2',
                ptu.get_numpy(bellman_errors_2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics recorded by the most recent logging step."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request that the next :meth:`train` call records fresh statistics."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All online and target networks managed by this trainer."""
        return [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_policy,
            self.target_qf1,
            self.target_qf2,
        ]

    def get_snapshot(self):
        """Networks to save: both critics, the policy and the target policy."""
        return dict(
            qf1=self.qf1,
            qf2=self.qf2,
            trained_policy=self.policy,
            target_policy=self.target_policy,
        )
    def evaluate(self,
                 step: int,
                 summary_writer,
                 prefix='validation ',
                 num_cached=5,
                 num_threads=3,
                 rnd_gen=None,
                 plotter=None,
                 model=None):
        """Evaluate on ``self.dataset`` and write summaries for ``step``.

        For every minibatch, ``self.summary_ops`` (update operations) are run
        and ``self.summary_tensors`` (plus the plotter's tensors, if any) are
        fetched. Entries flagged in ``self.summary_tensor_is_op`` keep only
        their latest value. All other entries are accumulated across
        minibatches, skipping non-finite scalars, and reduced afterwards:
        scalars are averaged, arrays are concatenated, and entries that
        collected nothing become NaN. Scalar results are written as simple
        values and printed; array results are written as histograms.

        Args:
            step: step recorded with every summary.
            summary_writer: receives the summaries (``add_summary``/``flush``).
            prefix: prepended to every summary tag and to the printed output.
            num_cached, num_threads, rnd_gen: forwarded to
                ``self.dataset.batch_loader``.
            plotter: optional; its tensors are fetched alongside the summaries
                and it is asked to plot after every minibatch.
            model: unused; kept for interface compatibility.
        """
        # Reset streaming measures
        self.reset_tensors()

        # Extra tensors the plotter needs evaluated with every minibatch
        plot_tensors = plotter.get_tensors() if plotter is not None else []
        n_summary = len(self.summary_tensors)

        # Set up progress bar
        _pbw = ['Evaluating on {}set:'.format(prefix), progressbar.ETA()]
        progress = progressbar.ProgressBar(widgets=_pbw,
                                           maxval=self.dataset.n_mbs -
                                           1).start()

        #
        # Iterate over dataset minibatches
        #
        mb_validation = self.dataset.batch_loader(num_cached=num_cached,
                                                  num_threads=num_threads,
                                                  rnd_gen=rnd_gen)
        with Timer(verbose=True, name="Evaluate on {}set".format(prefix)):
            summary_values_filled = None

            for mb_i, mb in enumerate(mb_validation):

                # Abort if indicated by file
                check_kill_file(self.workspace)

                feed_dict = {self.model.X: mb['X'], self.model.y_: mb['y']}
                if mb.get('pixel_weights', None) is not None:
                    feed_dict[self.model.pixel_weights] = mb['pixel_weights']

                evaluated_tensors = self.session.run(
                    [*self.summary_ops, *self.summary_tensors, *plot_tensors],
                    feed_dict=feed_dict)

                # Discard return values from summary_ops (=update operations);
                # what remains is [*summary_tensors, *plot_tensors].
                evaluated_tensors = evaluated_tensors[len(self.summary_ops):]
                summary_values = evaluated_tensors[:n_summary]

                # Perform plotting (plot values follow the summary values)
                if plotter is not None:
                    plotter.set_tensor_values(
                        evaluated_tensors[n_summary:n_summary + len(plot_tensors)])
                    plotter.plot(evaluate_tensors=False)

                # On the first minibatch create one slot per summary key: a list
                # accumulator for non-op tensors, a plain value for op tensors.
                if summary_values_filled is None:
                    summary_values_filled = OrderedDict(
                        (key, None if self.summary_tensor_is_op[key_i] else [])
                        for key_i, (key, _) in enumerate(
                            zip(self.summary_tensor_dict.keys(), summary_values)))

                # Incorporate this minibatch's values
                for key_i, key in enumerate(summary_values_filled.keys()):
                    value = summary_values[key_i]
                    if self.summary_tensor_is_op[key_i]:
                        # Tensors backed by an update op: keep the latest value
                        summary_values_filled[key] = value
                    elif isinstance(value, np.ndarray) or np.isfinite(value):
                        # Keep arrays and finite scalars; drop non-finite scalars
                        summary_values_filled[key].append(value)

                # Update progress bar and clear minibatch
                progress.update(mb_i)
                mb.clear()
                del mb

        progress.finish()

        if summary_values_filled is None:
            # The loader yielded no minibatches: nothing to reduce or write.
            summary_values_filled = OrderedDict()

        #
        # Reduce accumulated values of tensors that do not have an update op
        #
        for key_i, key in enumerate(summary_values_filled.keys()):
            if self.summary_tensor_is_op[key_i]:
                continue
            collected = summary_values_filled[key]
            if not collected:
                summary_values_filled[key] = np.nan
            elif isinstance(collected[0], np.ndarray):
                summary_values_filled[key] = np.concatenate(collected)
            else:
                summary_values_filled[key] = np.mean(collected)

        #
        # Write summaries: scalars as simple values, arrays as histograms
        #
        values_to_print = OrderedDict()
        for key, value in summary_values_filled.items():
            if isinstance(value, np.ndarray):
                hist = custom_tensorflow_histogram(value, bins=100)
                summary = tf.Summary(
                    value=[tf.Summary.Value(tag=prefix + key, histo=hist)])
            else:
                values_to_print[key] = value
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag=prefix + key,
                                     simple_value=float(value))
                ])
            summary_writer.add_summary(summary, step)

        print("{}scores:\n\tstep {}, {}".format(prefix, step, values_to_print))
        summary_writer.flush()
        sys.stdout.flush()
Exemple #52
0
class CQLTrainer(TorchTrainer):
    def __init__(
        self,
        env,
        policy,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
        policy_eval_start=0,
        num_qs=2,

        # CQL
        min_q_version=3,
        temp=1.0,
        min_q_weight=1.0,

        ## sort of backup
        max_q_backup=False,
        deterministic_backup=True,
        num_random=10,
        with_lagrange=False,
        lagrange_thresh=0.0,
    ):
        """Set up networks, optimizers and hyper-parameters for CQL training.

        Args:
            env: environment; its ``action_space.shape`` provides the default
                target entropy.
            policy, qf1, qf2: online networks, each with its own optimizer.
            target_qf1, target_qf2: target critics, soft-updated with
                ``soft_target_tau``.
            discount, reward_scale: used when building the Q target.
            policy_lr: learning rate of the policy (and entropy-alpha) optimizer.
            qf_lr: learning rate of the critic (and Lagrange-alpha) optimizers.
            optimizer_class: optimizer constructor, called as
                ``optimizer_class(params, lr=...)``.
            plotter, render_eval_paths: stored on the instance.
            use_automatic_entropy_tuning: learn ``log_alpha`` toward
                ``target_entropy``.
            target_entropy: explicit target entropy. When ``None``, it defaults
                to ``-prod(env.action_space.shape)``. Any other value, including
                ``0.0``, is used as given.
            policy_eval_start: compared with the counter that
                ``train_from_torch`` increments on each call; below it, a
                behaviour-cloning policy loss is used.
            num_qs: number of critics used (1 or 2).
            min_q_version, temp, min_q_weight: CQL regularizer settings
                (version 3 is the importance-sampled variant).
            max_q_backup, deterministic_backup, num_random: Q-backup and
                sampling options.
            with_lagrange, lagrange_thresh: learn ``log_alpha_prime`` with
                ``lagrange_thresh`` as the target action gap.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # Compare against None so that an explicit target of 0.0 is honoured.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(
                    self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            self.log_alpha_prime = ptu.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = optimizer_class(
                [self.log_alpha_prime],
                lr=qf_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.policy_eval_start = policy_eval_start

        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1

        self.num_qs = num_qs

        ## min Q
        self.temp = temp
        self.min_q_version = min_q_version
        self.min_q_weight = min_q_weight

        self.softmax = torch.nn.Softmax(dim=1)
        self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        # Continuous-action trainer. When True, _get_policy_actions returns only
        # the sampled actions (no log-probabilities).
        self.discrete = False

    def _get_tensor_values(self, obs, actions, network=None):
        """Evaluate ``network`` on each observation paired with its block of actions.

        ``actions`` must hold ``num_repeat`` consecutive rows per observation:
        ``actions.shape[0] == obs.shape[0] * num_repeat``, and the actions for
        observation ``i`` occupy rows ``i * num_repeat`` to
        ``(i + 1) * num_repeat - 1``. Each observation is repeated
        ``num_repeat`` times to line up with those rows. Observations of any
        rank are supported.

        Returns:
            ``network(obs_tiled, actions)`` reshaped to
            ``(obs.shape[0], num_repeat, 1)``. This assumes the network emits one
            value per row.

        Raises:
            ValueError: if the action batch is not a multiple of the obs batch.
        """
        batch_size = obs.shape[0]
        # Integer arithmetic instead of int(a / b), so a mismatched batch is
        # reported rather than silently truncated.
        num_repeat, remainder = divmod(actions.shape[0], batch_size)
        if remainder:
            raise ValueError(
                'actions batch ({}) is not a multiple of obs batch ({})'.format(
                    actions.shape[0], batch_size))
        # repeat_interleave keeps the copies of each observation adjacent,
        # matching the row layout of `actions`.
        obs_temp = obs.repeat_interleave(num_repeat, dim=0)
        preds = network(obs_temp, actions)
        return preds.view(batch_size, num_repeat, 1)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """Sample ``num_actions`` reparameterized actions for every observation.

        Each observation is repeated ``num_actions`` times, with its copies kept
        adjacent, and passed through ``network`` with ``reparameterize=True`` and
        ``return_log_prob=True``. Observations of any rank are supported.

        Returns:
            In continuous mode (``self.discrete`` false), ``(actions, log_pi)``.
            ``actions`` has one row per tiled observation, and ``log_pi`` is
            reshaped to ``(obs.shape[0], num_actions, 1)``. In discrete mode,
            only ``actions`` is returned.
        """
        obs_temp = obs.repeat_interleave(num_actions, dim=0)
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.discrete:
            return new_obs_actions
        return new_obs_actions, new_obs_log_pi.view(
            obs.shape[0], num_actions, 1)

    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha * log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            """
            For the initial few epochs, try doing behaivoral cloning, if needed
            conventionally, there's not much difference in performance with having 20k 
            gradient steps here, or not having it
            """
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)

        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        if not self.max_q_backup:
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )

            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        if self.max_q_backup:
            """when using max q backup"""
            next_actions_temp, _ = self._get_policy_actions(
                next_obs, num_actions=10, network=self.policy)
            target_qf1_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.target_qf1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.target_qf2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        ## add CQL
        random_actions_tensor = ptu.uniform(
            (q2_pred.shape[0] * self.num_random, actions.shape[-1]))
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(
            obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
            next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.qf1)
        q2_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.qf2)
        q1_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.qf2)

        cat_q1 = torch.cat(
            [q1_rand,
             q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1)
        cat_q2 = torch.cat(
            [q2_rand,
             q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1)
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)

        if self.min_q_version == 3:
            # importance sammpled version
            random_density = np.log(0.5**curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat([
                q1_rand - random_density, q1_next_actions -
                new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()
            ], 1)
            cat_q2 = torch.cat([
                q2_rand - random_density, q2_next_actions -
                new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()
            ], 1)

        min_qf1_loss = torch.logsumexp(
            cat_q1 / self.temp,
            dim=1,
        ).mean() * self.min_q_weight * self.temp
        min_qf2_loss = torch.logsumexp(
            cat_q2 / self.temp,
            dim=1,
        ).mean() * self.min_q_weight * self.temp
        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                      min=0.0,
                                      max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss -
                                          self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss -
                                          self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        qf1_loss = qf1_loss + min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss
        """
        Update networks
        """
        # Update the Q-functions iff
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau)
        if self.num_qs > 1:
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['min QF1 Loss'] = np.mean(
                ptu.get_numpy(min_qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(
                    ptu.get_numpy(qf2_loss))
                self.eval_statistics['min QF2 Loss'] = np.mean(
                    ptu.get_numpy(min_qf2_loss))

            if not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 in-distribution values',
                        ptu.get_numpy(q1_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 in-distribution values',
                        ptu.get_numpy(q2_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 random values',
                        ptu.get_numpy(q1_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 random values',
                        ptu.get_numpy(q2_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 next_actions values',
                        ptu.get_numpy(q1_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 next_actions values',
                        ptu.get_numpy(q2_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict('actions',
                                              ptu.get_numpy(actions)))
                self.eval_statistics.update(
                    create_stats_ordered_dict('rewards',
                                              ptu.get_numpy(rewards)))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            if self.num_qs > 1:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Q2 Predictions',
                        ptu.get_numpy(q2_pred),
                    ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            if not self.discrete:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy mu',
                        ptu.get_numpy(policy_mean),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy log std',
                        ptu.get_numpy(policy_log_std),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                self.eval_statistics['Alpha_prime'] = alpha_prime.item()
                self.eval_statistics['min_q1_loss'] = ptu.get_numpy(
                    min_qf1_loss).mean()
                self.eval_statistics['min_q2_loss'] = ptu.get_numpy(
                    min_qf2_loss).mean()
                self.eval_statistics[
                    'threshold action gap'] = self.target_action_gap
                self.eval_statistics[
                    'alpha prime loss'] = alpha_prime_loss.item()

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the ``eval_statistics`` dict itself (not a copy)."""
        statistics = self.eval_statistics
        return statistics

    def end_epoch(self, epoch):
        """Re-arm eval-statistics collection for the next training step.

        ``epoch`` is accepted for interface compatibility and is not used.
        """
        setattr(self, '_need_to_update_eval_statistics', True)

    @property
    def networks(self):
        """List of the policy, both Q-functions and both target Q-functions."""
        return [self.policy, self.qf1, self.qf2,
                self.target_qf1, self.target_qf2]

    def get_snapshot(self):
        """Return the policy and the online/target Q-networks keyed by attribute name."""
        snapshot = {'policy': self.policy}
        snapshot['qf1'] = self.qf1
        snapshot['qf2'] = self.qf2
        snapshot['target_qf1'] = self.target_qf1
        snapshot['target_qf2'] = self.target_qf2
        return snapshot
    def perform_logging(self, itr, paths, eval_policy, train_video_paths,
                        loss):
        """Collect evaluation rollouts and log videos and scalar metrics.

        Args:
            itr: current training iteration; used as the logging step and in
                the metrics file name.
            paths: training trajectories; each is indexed with ``"reward"``.
            eval_policy: policy used to sample evaluation trajectories.
            train_video_paths: training video rollouts, or None when no
                videos were collected for this iteration.
            loss: a dict of named loss values (merged into the logged
                scalars) or a single value logged as ``"Training loss"``.

        Side effects: prints progress, logs through ``self.logger`` and, when
        ``self.log_metrics`` is set, writes ``metrics_<itr>.json`` into
        ``self.params['logdir']``.
        """
        import json

        # collect eval trajectories, for logging
        print("\nCollecting data for eval...")
        eval_paths, _eval_envsteps_this_batch = sample_trajectories(
            self.env, eval_policy, self.params['eval_batch_size'],
            self.params['ep_len'])

        # save eval rollouts as videos in tensorboard event file
        if self.log_video and train_video_paths is not None:
            print('\nCollecting video rollouts eval')
            eval_video_paths = sample_n_trajectories(self.env, eval_policy,
                                                     MAX_NVIDEO, MAX_VIDEO_LEN,
                                                     True)

            # save train/eval videos
            print('\nSaving train rollouts as videos...')
            self.logger.log_paths_as_videos(train_video_paths,
                                            itr,
                                            fps=self.fps,
                                            max_videos_to_save=MAX_NVIDEO,
                                            video_title='train_rollouts')
            self.logger.log_paths_as_videos(eval_video_paths,
                                            itr,
                                            fps=self.fps,
                                            max_videos_to_save=MAX_NVIDEO,
                                            video_title='eval_rollouts')

        if not self.log_metrics:
            return

        # returns and episode lengths, for logging
        train_returns = [path["reward"].sum() for path in paths]
        eval_returns = [eval_path["reward"].sum() for eval_path in eval_paths]
        train_ep_lens = [len(path["reward"]) for path in paths]
        eval_ep_lens = [len(eval_path["reward"]) for eval_path in eval_paths]

        # decide what to log
        logs = OrderedDict()
        logs["Eval_AverageReturn"] = np.mean(eval_returns)
        logs["Eval_StdReturn"] = np.std(eval_returns)
        logs["Eval_MaxReturn"] = np.max(eval_returns)
        logs["Eval_MinReturn"] = np.min(eval_returns)
        logs["Eval_AverageEpLen"] = np.mean(eval_ep_lens)

        logs["Train_AverageReturn"] = np.mean(train_returns)
        logs["Train_StdReturn"] = np.std(train_returns)
        logs["Train_MaxReturn"] = np.max(train_returns)
        logs["Train_MinReturn"] = np.min(train_returns)
        logs["Train_AverageEpLen"] = np.mean(train_ep_lens)

        logs["Train_EnvstepsSoFar"] = self.total_envsteps
        logs["TimeSinceStart"] = time.time() - self.start_time
        if isinstance(loss, dict):
            logs.update(loss)
        else:
            logs["Training loss"] = loss

        # Record the first logged training return. Also fill it in when it is
        # still unset and logging starts after iteration 0 (e.g. a resumed
        # run); otherwise reading self.initial_return below would fail.
        if itr == 0 or getattr(self, 'initial_return', None) is None:
            self.initial_return = np.mean(train_returns)
        logs["Initial_DataCollection_AverageReturn"] = self.initial_return

        # perform the logging
        for key, value in logs.items():
            print('{} : {}'.format(key, value))
            self.logger.log_scalar(value, key, itr)
        print('Done logging...\n\n')

        # dumping: float() turns numpy scalars into JSON-serialisable floats
        metric_file = os.path.join(self.params['logdir'],
                                   f'metrics_{itr}.json')
        json_logs = OrderedDict((k, float(v)) for k, v in logs.items())
        with open(metric_file, 'w', encoding='utf-8') as fw:
            json.dump(json_logs, fw)
        self.logger.flush()
class VltValidation(object):
    """Compare discovered VLT pairs against the intended primary/secondary pairs.

    The Ansible module takes three required list parameters:

    * ``show_vlt``: per-switch query results. Each entry is read for
      ``item.inv_name.inv_name`` (the source switch) and ``msg.result`` (a
      list of ``rpc-reply`` documents carrying
      ``data.topology-oper-data.vlt-domain``).
    * ``show_system_network_summary``: entries with ``node-mac`` and
      ``inv_name``, used to map a peer's system MAC to an inventory name.
    * ``intended_vlt_pairs``: entries with ``primary`` and ``secondary``.

    ``perform_action`` exits the module with the list of detected issues.
    """

    def __init__(self):
        self.module = AnsibleModule(argument_spec=self.get_fields())
        self.show_vlt = self.module.params['show_vlt']
        self.show_system_network_summary = self.module.params[
            'show_system_network_summary']
        self.intended_vlt_pairs = self.module.params['intended_vlt_pairs']
        self.exit_msg = OrderedDict()

    def get_fields(self):
        """Return the Ansible argument spec: three required list parameters."""
        spec_fields = {
            'show_vlt': {
                'type': 'list',
                'required': True
            },
            'show_system_network_summary': {
                'type': 'list',
                'required': True
            },
            'intended_vlt_pairs': {
                'type': 'list',
                'required': True
            }
        }
        return spec_fields

    def get_switch_inv_name_from_mac(self, mac):
        """Return the ``inv_name`` whose ``node-mac`` equals *mac*, ignoring case.

        Returns None when *mac* is None or no summary entry matches. Entries
        without a ``node-mac`` are skipped.
        """
        if mac is None:
            return None
        wanted = mac.lower()
        for show_system in self.show_system_network_summary:
            node_mac = show_system.get("node-mac")
            if node_mac is not None and node_mac.lower() == wanted:
                return show_system.get("inv_name")
        return None

    def validate_vlt_pairs(self, actual_vlt_dict):
        """Check each intended pair against *actual_vlt_dict*.

        *actual_vlt_dict* maps a primary switch name to the role-keyed dict
        built by ``parse_vlt_output``. Returns a list of issue dicts whose
        ``error_type`` is one of ``vlt_config_missing`` (no VLT data for the
        intended primary), ``peer_missing`` (no secondary reported),
        ``secondary_mismatch`` (a different secondary reported) or
        ``peer_down`` (expected secondary whose status is not ``"up"``).
        Pairs that match and are up produce no entry.
        """
        final_out = list()
        for intended_vlt in self.intended_vlt_pairs:
            intended_primary = intended_vlt.get("primary")
            intended_secondary = intended_vlt.get("secondary")
            actual_vlt = actual_vlt_dict.get(intended_primary)
            if actual_vlt is None:
                final_out.append({
                    "intended_primary": intended_primary,
                    "intended_secondary": intended_secondary,
                    "error_type": "vlt_config_missing",
                    "possible_reason": "vlt is not configured",
                })
                continue

            actual_secondary = actual_vlt.get("secondary")
            secondary_status = actual_vlt.get("secondary_status")
            if actual_secondary is None:
                final_out.append({
                    "intended_primary": intended_primary,
                    "intended_secondary": intended_secondary,
                    "error_type": "peer_missing",
                    "possible_reason":
                        "peer info is not configured or peer interface is down",
                })
            elif intended_secondary != actual_secondary:
                reason = "config mismatch as {0} is expected, but the actual secondary is {1} ".format(
                    intended_secondary, actual_secondary)
                final_out.append({
                    "error_type": "secondary_mismatch",
                    "intended_primary": intended_primary,
                    "intended_secondary": intended_secondary,
                    "secondary": actual_secondary,
                    "possible_reason": reason,
                })
            elif secondary_status != "up":
                final_out.append({
                    "intended_primary": intended_primary,
                    "intended_secondary": intended_secondary,
                    "secondary": actual_secondary,
                    "error_type": "peer_down",
                    "possible_reason": "peer interface is down",
                })
        return final_out

    def _parse_vlt_result(self, sub_result, source_switch):
        """Build the role-keyed dict for one rpc-reply result (empty if no VLT domain).

        Keys are the reported roles (e.g. ``primary``/``secondary``) mapped
        to switch names, plus ``<role>_mac`` and, for the peer,
        ``<role>_status``. Missing intermediate sections yield an empty dict
        instead of raising.
        """
        vlt_dict = {}
        rpc_reply = (sub_result or {}).get("rpc-reply") or {}
        data = rpc_reply.get("data") or {}
        topo_oper_data = data.get("topology-oper-data") or {}
        vlt_domain = topo_oper_data.get("vlt-domain")
        if vlt_domain is None:
            return vlt_dict

        local_info = vlt_domain.get("local-info")
        if local_info is not None:
            local_role = local_info.get("role")
            # A role is required to build the keys (None + "_mac" raises).
            if local_role is not None:
                vlt_dict[local_role] = source_switch
                vlt_dict[local_role + "_mac"] = local_info.get("system-mac")

        peer_info = vlt_domain.get("peer-info")
        if peer_info is not None:
            peer_role = peer_info.get("role")
            if peer_role is not None:
                peer_mac = peer_info.get("system-mac")
                vlt_dict[peer_role] = self.get_switch_inv_name_from_mac(
                    peer_mac)
                vlt_dict[peer_role + "_mac"] = peer_mac
                vlt_dict[peer_role + "_status"] = peer_info.get("peer-status")
        return vlt_dict

    def parse_vlt_output(self):
        """Return ``{primary_switch: role-keyed VLT dict}`` from ``self.show_vlt``.

        Only the first record seen for a given primary switch is kept, so the
        same pair reported by both of its members is stored once.
        """
        show_vlt_dict = {}
        for show_list in self.show_vlt:
            source_switch = None
            item = show_list.get("item")
            if item is not None:
                inv_info = item.get("inv_name")
                if inv_info is not None:
                    source_switch = inv_info.get("inv_name")
            msg = show_list.get("msg")
            if msg is None:
                continue
            for sub_result in msg.get("result") or []:
                vlt_dict = self._parse_vlt_result(sub_result, source_switch)
                if vlt_dict:
                    show_vlt_dict.setdefault(vlt_dict.get("primary"), vlt_dict)
        return show_vlt_dict

    def perform_action(self):
        """Parse, validate and exit the module with the results (or fail)."""
        try:
            actual_vlt_dict = self.parse_vlt_output()
            final_out = self.validate_vlt_pairs(actual_vlt_dict)
            self.exit_msg.update({"results": final_out})
            self.module.exit_json(changed=False, msg=self.exit_msg)
        except Exception as e:
            self.module.fail_json(msg=to_native(e),
                                  exception=traceback.format_exc())
 def _initialize_current_observations(self):
     """Merge the observations recorded for every entity into one OrderedDict.

     Entities are visited in ``self.observations_by_entity`` order; a later
     entity's entry overwrites an earlier one with the same key.
     """
     merged = OrderedDict()
     for entity in self.observations_by_entity:
         recorded = self._record_entity_observations(entity)
         merged.update(recorded)
     return merged
Exemple #56
0
    def _canonical_key(cls, args, kwargs):
        """Resolve positional and keyword coordinates into a canonical key.

        ``cls.extra_dims`` maps each extra dimension name, in order, to the
        collection of values valid along it. ``args`` fill dimensions
        positionally in that order; ``kwargs`` fill them by name.

        Returns:
            ``(coords, key)`` where ``coords`` is an OrderedDict of
            dimension -> value in ``extra_dims`` order and ``key`` is
            ``tuple(coords.items())``.

        Raises:
            TypeError: unknown dimension names, too many positional values,
                a dimension given both positionally and by keyword, or a
                dimension left without a value.
            ValueError: a value that is not valid along its dimension.
        """
        extra_dims = cls.extra_dims
        dimensions_set = set(extra_dims)
        if not set(kwargs) <= dimensions_set:
            extra = sorted(set(kwargs) - dimensions_set)
            raise TypeError(
                "%s does not have the following %s: %s\n"
                "Valid dimensions are: %s" % (
                    cls.__name__,
                    s("dimension", extra),
                    ", ".join(extra),
                    ", ".join(extra_dims),
                ), )

        if len(args) > len(extra_dims):
            raise TypeError(
                "%s has %d extra %s but %d %s given" % (
                    cls.__name__,
                    len(extra_dims),
                    s("dimension", extra_dims),
                    len(args),
                    plural("was", "were", args),
                ), )

        # Sentinel marking dimensions that have not received a value yet.
        unset = object()
        coords = OrderedDict(zip(extra_dims, repeat(unset)))
        positional = dict(zip(extra_dims, args))
        coords.update(positional)
        added = set(positional)

        for key, value in kwargs.items():
            if key in added:
                # Report the offending dimension, not the whole coords map.
                raise TypeError(
                    "%s got multiple values for dimension %r" % (
                        cls.__name__,
                        key,
                    ), )
            coords[key] = value
            added.add(key)

        missing = sorted(k for k, v in coords.items() if v is unset)
        if missing:
            raise TypeError(
                "no coordinate provided to %s for the following %s: %s" % (
                    cls.__name__,
                    s("dimension", missing),
                    ", ".join(missing),
                ), )

        # validate that all of the provided values exist along their given
        # dimensions
        for key, value in coords.items():
            if value not in cls.extra_dims[key]:
                raise ValueError(
                    "%r is not a value along the %s dimension of %s" % (
                        value,
                        key,
                        cls.__name__,
                    ), )

        return coords, tuple(coords.items())
Exemple #57
0
async def async_process_ha_core_config(hass: HomeAssistant, config: Dict) -> None:
    """Process the [homeassistant] section from the configuration.

    Validates ``config`` with ``CORE_CONFIG_SCHEMA`` and applies it to
    ``hass``. It creates ``hass.auth`` if that attribute does not exist yet
    and awaits ``hass.config.async_load()``. It then applies the YAML
    location/name/elevation/time zone, the whitelisted external dirs and the
    unit system. Finally it builds the entity customizations stored in
    ``hass.data[DATA_CUSTOMIZE]``.

    This method is a coroutine.
    """
    config = CORE_CONFIG_SCHEMA(config)

    # Only load auth during startup.
    # (hasattr guard: hass.auth is created at most once.)
    if not hasattr(hass, "auth"):
        auth_conf = config.get(CONF_AUTH_PROVIDERS)

        # No providers configured: fall back to the built-in provider.
        if auth_conf is None:
            auth_conf = [{"type": "homeassistant"}]

        # Default MFA module list when none is configured.
        mfa_conf = config.get(
            CONF_AUTH_MFA_MODULES,
            [{"type": "totp", "id": "totp", "name": "Authenticator app"}],
        )

        setattr(
            hass, "auth", await auth.auth_manager_from_config(hass, auth_conf, mfa_conf)
        )

    # Load the stored core config before applying YAML values, so the
    # assignments below overwrite whatever the load set.
    # NOTE(review): precedence inferred from statement order; confirm
    # against hass.config.async_load.
    await hass.config.async_load()

    hac = hass.config

    # Any location/unit key present in YAML marks the core config as
    # YAML-sourced.
    if any(
        k in config
        for k in [
            CONF_LATITUDE,
            CONF_LONGITUDE,
            CONF_NAME,
            CONF_ELEVATION,
            CONF_TIME_ZONE,
            CONF_UNIT_SYSTEM,
        ]
    ):
        hac.config_source = SOURCE_YAML

    # Copy each present YAML key onto the matching hass.config attribute.
    for key, attr in (
        (CONF_LATITUDE, "latitude"),
        (CONF_LONGITUDE, "longitude"),
        (CONF_NAME, "location_name"),
        (CONF_ELEVATION, "elevation"),
    ):
        if key in config:
            setattr(hac, attr, config[key])

    if CONF_TIME_ZONE in config:
        hac.set_time_zone(config[CONF_TIME_ZONE])

    # Init whitelist external dir
    # ("www" under the config dir is always included; configured dirs are
    # added on top.)
    hac.whitelist_external_dirs = {hass.config.path("www")}
    if CONF_WHITELIST_EXTERNAL_DIRS in config:
        hac.whitelist_external_dirs.update(set(config[CONF_WHITELIST_EXTERNAL_DIRS]))

    # Customize
    # Start from the top-level customizations (glob entries keep their order),
    # then merge in each package's CONF_CORE customize section. Packages
    # whose section fails CUSTOMIZE_CONFIG_SCHEMA are logged and skipped.
    cust_exact = dict(config[CONF_CUSTOMIZE])
    cust_domain = dict(config[CONF_CUSTOMIZE_DOMAIN])
    cust_glob = OrderedDict(config[CONF_CUSTOMIZE_GLOB])

    for name, pkg in config[CONF_PACKAGES].items():
        pkg_cust = pkg.get(CONF_CORE)

        if pkg_cust is None:
            continue

        try:
            pkg_cust = CUSTOMIZE_CONFIG_SCHEMA(pkg_cust)
        except vol.Invalid:
            _LOGGER.warning("Package %s contains invalid customize", name)
            continue

        cust_exact.update(pkg_cust[CONF_CUSTOMIZE])
        cust_domain.update(pkg_cust[CONF_CUSTOMIZE_DOMAIN])
        cust_glob.update(pkg_cust[CONF_CUSTOMIZE_GLOB])

    hass.data[DATA_CUSTOMIZE] = EntityValues(cust_exact, cust_domain, cust_glob)

    # Unit system: an explicit CONF_UNIT_SYSTEM wins (imperial if it equals
    # CONF_UNIT_SYSTEM_IMPERIAL, metric otherwise). Otherwise the deprecated
    # CONF_TEMPERATURE_UNIT picks metric for Celsius and imperial for
    # anything else, and logs a deprecation warning.
    if CONF_UNIT_SYSTEM in config:
        if config[CONF_UNIT_SYSTEM] == CONF_UNIT_SYSTEM_IMPERIAL:
            hac.units = IMPERIAL_SYSTEM
        else:
            hac.units = METRIC_SYSTEM
    elif CONF_TEMPERATURE_UNIT in config:
        unit = config[CONF_TEMPERATURE_UNIT]
        if unit == TEMP_CELSIUS:
            hac.units = METRIC_SYSTEM
        else:
            hac.units = IMPERIAL_SYSTEM
        _LOGGER.warning(
            "Found deprecated temperature unit in core "
            "configuration expected unit system. Replace '%s: %s' "
            "with '%s: %s'",
            CONF_TEMPERATURE_UNIT,
            unit,
            CONF_UNIT_SYSTEM,
            hac.units.name,
        )
Exemple #58
0
class Forall(object):
    """
    Models a single instance of forall quantification.
    Defines a method called Forall that can be used multiple times to construct
    a sequence of quantifiers.
    This class should always hold the current quantification structure.
    """
    def __init__(self, is_first=True, bind_variables=None, **bind_variable):
        """
        Register one bind variable, passed as a single keyword argument
        ``name=value``. Per the original design, value is a StaticState or a
        StaticTransition.

        :param is_first: when False, ``value._required_binding`` holds the
            name of an earlier bind variable. That name is replaced by the
            variable's object, taken from ``bind_variables``.
        :param bind_variables: ordered name -> bind variable mapping built by
            earlier quantifiers. It is updated in place, so it is shared with
            the quantifier it came from. None starts a new mapping.
        :raises TypeError: if no bind variable keyword argument is given.
        """
        if not bind_variable:
            raise TypeError("Forall() requires a bind variable, "
                            "e.g. Forall(q=...)")

        # Only the first keyword argument is used. Keyword arguments keep
        # their call order (PEP 468), so this is the first one written.
        bind_variable_name = next(iter(bind_variable))
        bind_variable_obj = bind_variable[bind_variable_name]

        self.bind_variables = bind_variables
        # Resolve the bind variable this one depends on: _required_binding
        # arrives as a name and is replaced by the referenced object.
        if not is_first:
            bind_variable_obj._required_binding = \
                self.bind_variables[bind_variable_obj._required_binding]

        bind_variable_final = bind_variable_obj.complete_instantiation(
            bind_variable_name)
        if self.bind_variables is None:
            self.bind_variables = OrderedDict(
                [(bind_variable_name, bind_variable_final)])
        else:
            self.bind_variables[bind_variable_name] = bind_variable_final

        # Live view over the bind variables' values.
        self._bind_variables = self.bind_variables.values()

        # defined by calling Formula
        self._formula = None

        # this is set to True when self.Formula() is called
        # it's used to decide whether arithmetic operations should be
        # added to atoms' stacks or not, to prevent double evaluation
        self._instantiation_complete = False

    def __repr__(self):
        if self._formula is None:
            return "Forall(%s)" % self.bind_variables
        else:
            return "Forall(%s).Check(%s)" % \
                   (self.bind_variables, self.get_formula_instance())

    def Forall(self, **bind_variable):
        """Add a nested quantifier that shares this one's bind variables."""
        return Forall(is_first=False,
                      bind_variables=self.bind_variables,
                      **bind_variable)

    def get(self, key):
        """Return the bind variable registered under *key*."""
        return self.bind_variables[key]

    # syntactic sugar
    def Check(self, formula_lambda):
        return self.Formula(formula_lambda)

    def Formula(self, formula_lambda):
        """
        Store the formula lambda, which itself returns a formula when given
        bind variables, for later use.
        """
        self._formula = formula_lambda
        # generate instantiated formula to compute its atoms
        self._formula_atoms = \
            formula_tree.get_positive_formula_alphabet(self.get_formula_instance(
                first_time=True
            ))
        self._instantiation_complete = True
        return self

    def get_formula_instance(self, first_time=False):
        """
        Instantiate the formula by calling the stored lambda with the bind
        variables named by its positional parameters.

        With ``first_time`` set, those bind variables have
        ``_arithmetic_build`` switched on for the duration of the call. The
        flag is always switched back off, even if the lambda raises.
        """
        # inspect.getargspec was removed in Python 3.11; getfullargspec's
        # ``args`` is the same list of positional parameter names.
        argument_names = inspect.getfullargspec(self._formula).args
        bind_variables = [self.bind_variables[name] for name in argument_names]
        if not first_time:
            return self._formula(*bind_variables)

        # enable "_arithmetic_build" so arithmetic operations are recorded
        for bind_variable in bind_variables:
            bind_variable._arithmetic_build = True
        try:
            return self._formula(*bind_variables)
        finally:
            for bind_variable in bind_variables:
                bind_variable._arithmetic_build = False
class DotMap(MutableMapping, OrderedDict):
    def __init__(self, *args, **kwargs):
        self._map = OrderedDict()
        self._dynamic = True
        if kwargs:
            if '_dynamic' in kwargs:
                self._dynamic = kwargs['_dynamic']
        if args:
            d = args[0]
            # for recursive assignment handling
            trackedIDs = {id(d): self}
            if isinstance(d, dict):
                for k, v in self.__call_items(d):
                    if isinstance(v, dict):
                        if id(v) in trackedIDs:
                            v = trackedIDs[id(v)]
                        else:
                            v = self.__class__(v, _dynamic=self._dynamic)
                            trackedIDs[id(v)] = v
                    if type(v) is list:
                        l = []
                        for i in v:
                            n = i
                            if isinstance(i, dict):
                                n = self.__class__(i, _dynamic=self._dynamic)
                            l.append(n)
                        v = l
                    self._map[k] = v
        if kwargs:
            for k, v in self.__call_items(kwargs):
                if k is not '_dynamic':
                    self._map[k] = v

    def __call_items(self, obj):
        if hasattr(obj, 'iteritems') and ismethod(getattr(obj, 'iteritems')):
            return obj.iteritems()
        else:
            return obj.items()

    def items(self):
        return self.iteritems()

    def iteritems(self):
        return self.__call_items(self._map)

    def __iter__(self):
        return self._map.__iter__()

    def next(self):
        return self._map.next()

    def __setitem__(self, k, v):
        self._map[k] = v

    def __getitem__(self, k):
        if k not in self._map and self._dynamic and k != '_ipython_canary_method_should_not_exist_':
            # automatically extend to new DotMap
            self[k] = self.__class__()
        return self._map[k]

    def __setattr__(self, k, v):
        if k in {
                '_map', '_dynamic', '_ipython_canary_method_should_not_exist_'
        }:
            super(DotMap, self).__setattr__(k, v)
        else:
            self[k] = v

    def __getattr__(self, k):
        if k in {
                '_map', '_dynamic', '_ipython_canary_method_should_not_exist_'
        }:
            return super(DotMap, self).__getattr__(k)

        try:
            v = super(self.__class__, self).__getattribute__(k)
            return v
        except AttributeError:
            pass

        return self[k]

    def __delattr__(self, key):
        return self._map.__delitem__(key)

    def __contains__(self, k):
        return self._map.__contains__(k)

    def __add__(self, other):
        if self.empty():
            return other
        else:
            self_type = type(self).__name__
            other_type = type(other).__name__
            msg = "unsupported operand type(s) for +: '{}' and '{}'"
            raise TypeError(msg.format(self_type, other_type))

    def __str__(self):
        items = []
        for k, v in self.__call_items(self._map):
            # recursive assignment case
            if id(v) == id(self):
                items.append('{0}={1}(...)'.format(k, self.__class__.__name__))
            else:
                items.append('{0}={1}'.format(k, repr(v)))
        joined = ', '.join(items)
        out = '{0}({1})'.format(self.__class__.__name__, joined)
        return out

    def __repr__(self):
        return str(self)

    def toDict(self):
        d = {}
        for k, v in self.items():
            if issubclass(type(v), DotMap):
                # bizarre recursive assignment support
                if id(v) == id(self):
                    v = d
                else:
                    v = v.toDict()
            elif type(v) in (list, tuple):
                l = []
                for i in v:
                    n = i
                    if issubclass(type(i), DotMap):
                        n = i.toDict()
                    l.append(n)
                if type(v) is tuple:
                    v = tuple(l)
                else:
                    v = l
            d[k] = v
        return d

    def pprint(self, pformat='dict'):
        if pformat == 'json':
            print(dumps(self.toDict(), indent=4, sort_keys=True))
        else:
            pprint(self.toDict())

    def empty(self):
        return (not any(self))

    # proper dict subclassing
    def values(self):
        return self._map.values()

    # ipython support
    def __dir__(self):
        return self.keys()

    @classmethod
    def parseOther(self, other):
        if issubclass(type(other), DotMap):
            return other._map
        else:
            return other

    def __cmp__(self, other):
        other = DotMap.parseOther(other)
        return self._map.__cmp__(other)

    def __eq__(self, other):
        """Compare contents with a dict or DotMap; anything else is unequal."""
        rhs = DotMap.parseOther(other)
        return isinstance(rhs, dict) and self._map.__eq__(rhs)

    def __ge__(self, other):
        """Delegate ``>=`` to the backing mapping (DotMap operands unwrapped)."""
        return self._map.__ge__(DotMap.parseOther(other))

    def __gt__(self, other):
        """Delegate ``>`` to the backing mapping (DotMap operands unwrapped)."""
        return self._map.__gt__(DotMap.parseOther(other))

    def __le__(self, other):
        """Delegate ``<=`` to the backing mapping (DotMap operands unwrapped)."""
        return self._map.__le__(DotMap.parseOther(other))

    def __lt__(self, other):
        """Delegate ``<`` to the backing mapping (DotMap operands unwrapped)."""
        return self._map.__lt__(DotMap.parseOther(other))

    def __ne__(self, other):
        """Return the negation of the ``__eq__`` comparison.

        Mirrors ``__eq__``: after unwrapping a DotMap operand, anything that
        is not a dict is unequal, so this returns ``True`` rather than
        leaking ``dict.__ne__``'s ``NotImplemented`` to direct callers.
        """
        other = DotMap.parseOther(other)
        if not isinstance(other, dict):
            return True
        return self._map.__ne__(other)

    def __delitem__(self, key):
        """Remove *key* from the backing mapping (KeyError if absent)."""
        del self._map[key]

    def __len__(self):
        """Return the number of entries in the backing mapping."""
        return len(self._map)

    def clear(self):
        """Remove every entry, keeping the same backing mapping object."""
        backing = self._map
        backing.clear()

    def copy(self):
        """Return a new instance of the same class built from this one."""
        cls = self.__class__
        return cls(self)

    def __copy__(self):
        """Support ``copy.copy`` by delegating to :meth:`copy`."""
        duplicate = self.copy()
        return duplicate

    def __deepcopy__(self, memo=None):
        """Support ``copy.deepcopy`` by delegating to :meth:`copy`.

        *memo* is ignored. NOTE(review): how deep the result actually is
        depends on how ``__init__`` copies nested values, which is not visible
        here -- confirm nested mutable values (e.g. lists) are duplicated.
        """
        duplicate = self.copy()
        return duplicate

    def get(self, key, default=None):
        """Return the value under *key*, or *default* when it is absent."""
        backing = self._map
        return backing.get(key, default)

    def has_key(self, key):
        """Python 2 style membership test; equivalent to ``key in self``."""
        present = key in self._map
        return present

    def iterkeys(self):
        """Return an iterator over the keys.

        ``iter(mapping)`` is exactly what ``dict.iterkeys()`` returned on
        Python 2, and it also works on Python 3, where ``iterkeys`` was
        removed (the old call raised AttributeError there).
        """
        return iter(self._map)

    def itervalues(self):
        """Return an iterator over the values.

        Uses the mapping's own ``itervalues()`` where it exists (Python 2)
        and otherwise iterates ``values()`` (Python 3, where ``itervalues``
        was removed and the old call raised AttributeError).
        """
        legacy = getattr(self._map, 'itervalues', None)
        if legacy is not None:
            return legacy()
        return iter(self._map.values())

    def keys(self):
        """Return the keys of the backing mapping."""
        backing = self._map
        return backing.keys()

    def pop(self, key, default=None):
        """Remove *key* and return its value.

        Unlike ``dict.pop`` with no default, a missing key does not raise
        KeyError: *default* (None unless given) is returned instead.
        """
        backing = self._map
        return backing.pop(key, default)

    def popitem(self):
        """Remove and return a ``(key, value)`` pair from the backing mapping."""
        backing = self._map
        return backing.popitem()

    def setdefault(self, key, default=None):
        """Return the value for *key*, storing *default* first if absent."""
        backing = self._map
        return backing.setdefault(key, default)

    def update(self, *args, **kwargs):
        """Update the backing mapping like ``dict.update``.

        Positional arguments (if any) are applied first, then keyword
        arguments. Values are stored as given, without conversion.
        """
        backing = self._map
        if args:
            backing.update(*args)
        backing.update(kwargs)

    def viewitems(self):
        """Return a live view of the ``(key, value)`` pairs.

        Uses the mapping's ``viewitems()`` where it exists (Python 2.7) and
        otherwise ``items()``, which already returns a view on Python 3
        (where ``viewitems`` was removed and the old call raised
        AttributeError).
        """
        legacy = getattr(self._map, 'viewitems', None)
        if legacy is not None:
            return legacy()
        return self._map.items()

    def viewkeys(self):
        """Return a live view of the keys.

        Uses the mapping's ``viewkeys()`` where it exists (Python 2.7) and
        otherwise ``keys()``, which already returns a view on Python 3
        (where ``viewkeys`` was removed and the old call raised
        AttributeError).
        """
        legacy = getattr(self._map, 'viewkeys', None)
        if legacy is not None:
            return legacy()
        return self._map.keys()

    def viewvalues(self):
        """Return a live view of the values.

        Uses the mapping's ``viewvalues()`` where it exists (Python 2.7) and
        otherwise ``values()``, which already returns a view on Python 3
        (where ``viewvalues`` was removed and the old call raised
        AttributeError).
        """
        legacy = getattr(self._map, 'viewvalues', None)
        if legacy is not None:
            return legacy()
        return self._map.values()

    @classmethod
    def fromkeys(cls, seq, value=None):
        """Build a new instance whose keys come from *seq*, all mapped to *value*.

        The backing mapping is replaced by an ``OrderedDict``, so key order
        follows *seq*.
        """
        instance = cls()
        instance._map = OrderedDict.fromkeys(seq, value)
        return instance

    def __getstate__(self):
        """Pickle support: the state is the instance ``__dict__`` itself."""
        return vars(self)

    def __setstate__(self, d):
        """Pickle support: merge the saved state *d* into ``__dict__``."""
        vars(self).update(d)

    # bannerStr
    def _getListStr(self, items):
        """Format *items* as a bracketed block for :meth:`bannerStr`.

        An empty iterable gives ``[]``. Otherwise the opening bracket is
        followed by a newline, then each item on its own line indented by two
        spaces, then the closing bracket.
        """
        body = ''.join('  {}\n'.format(entry) for entry in items)
        if body:
            body = '\n' + body
        return '[' + body + ']'

    def _getValueStr(self, k, v):
        """Format one ``key value`` entry for :meth:`bannerStr`.

        A value whose string form spans several lines is pushed onto the line
        after the key; a ``list`` is rendered with :meth:`_getListStr`.
        """
        outV = v
        text = str(v)
        if '\n' in text:
            # push to next line; concatenate the str() form so non-string
            # values with a multi-line __str__ no longer raise TypeError
            outV = '\n' + text
        if type(v) is list:
            outV = self._getListStr(v)
        return '{} {}'.format(k, outV)

    def _getSubMapDotList(self, pre, name, subMap):
        """Flatten *subMap* into ``prefix.key value`` lines.

        The prefix is *name*, or ``pre.name`` when *pre* is non-empty.
        Non-empty nested DotMaps are expanded recursively with the longer
        prefix; every other value becomes one line via :meth:`_getValueStr`.
        """
        prefix = name if pre == '' else '{}.{}'.format(pre, name)
        lines = []
        for key, value in subMap.items():
            if isinstance(value, DotMap) and value != DotMap():
                lines.extend(self._getSubMapDotList(prefix, key, value))
                continue
            lines.append('{}.{}'.format(prefix, self._getValueStr(key, value)))
        return lines

    def _getSubMapStr(self, name, subMap):
        """Render *subMap* as a section headed ``== name ==``.

        Non-empty nested maps of this class are flattened to dotted lines via
        :meth:`_getSubMapDotList`; other entries become ``key value`` lines.
        """
        cls = self.__class__
        section = ['== {} =='.format(name)]
        for key, value in subMap.items():
            if isinstance(value, cls) and value != cls():
                section.extend(self._getSubMapDotList('', key, value))
            else:
                section.append('{}'.format(self._getValueStr(key, value)))
        return '\n'.join(section)

    def bannerStr(self):
        """Return a human-readable multi-line banner of the map's contents.

        Nested maps of this class become ``== key ==`` sections; other entries
        become ``key value`` lines. A ``-`` separator line follows every
        section that has another entry after it, and the banner always ends
        with ``--``.
        """
        section_marker = self.__class__.__name__
        last_kind = None
        lines = []
        for key, value in self.items():
            if last_kind == section_marker:
                lines.append('-')
            if isinstance(value, self.__class__):
                lines.append(self._getSubMapStr(key, value))
                last_kind = section_marker
            else:
                lines.append(self._getValueStr(key, value))
                last_kind = 'other'
        lines.append('--')
        return '\n'.join(lines)
class BayesianBinarySensor(BinarySensorDevice):
    """Representation of a Bayesian sensor."""
    def __init__(self, name, prior, observations, probability_threshold,
                 device_class):
        """Set up the Bayesian sensor state and its observation lookups."""
        self._name = name
        self._observations = observations
        self._probability_threshold = probability_threshold
        self._device_class = device_class
        # off until async_update compares probability against the threshold
        self._deviation = False
        # the running probability starts out equal to the prior
        self.prior = self.probability = prior

        # observation id -> triggered observation dict, or None when not met
        self.current_observations = OrderedDict()

        # entity_id -> list of observation dicts that depend on that entity
        self.observations_by_entity = self._build_observations_by_entity()

        # observation "platform" -> callable deciding whether it is met
        self.observation_handlers = dict(
            numeric_state=self._process_numeric_state,
            state=self._process_state,
            template=self._process_template,
        )

    async def async_added_to_hass(self):
        """
        Call when entity about to be added.

        All relevant update logic for instance attributes occurs within this closure.
        Other methods in this class are designed to avoid directly modifying instance
        attributes, by instead focusing on returning relevant data back to this method.

        The goal of this method is to ensure that `self.current_observations` and `self.probability`
        are set on a best-effort basis when this entity is registered with hass.

        In addition, this method must register the state listener defined within, which
        will be called any time a relevant entity changes its state.
        """
        @callback
        def async_threshold_sensor_state_listener(entity, _old_state,
                                                  new_state):
            """
            Handle sensor state changes.

            When a state changes, we must update our list of current observations,
            then calculate the new probability.
            """
            # Home Assistant passes new_state=None when a tracked entity is
            # removed; like an unknown state, there is nothing to observe.
            if new_state is None or new_state.state == STATE_UNKNOWN:
                return

            self.current_observations.update(
                self._record_entity_observations(entity))
            self.probability = self._calculate_new_probability()

            # True asks HA to run async_update (which re-evaluates the
            # threshold) before writing the new state.
            self.hass.async_add_job(self.async_update_ha_state, True)

        # Evaluate every observation once so state is correct before the
        # first change event arrives.
        self.current_observations.update(
            self._initialize_current_observations())
        self.probability = self._calculate_new_probability()
        async_track_state_change(
            self.hass,
            self.observations_by_entity,
            async_threshold_sensor_state_listener,
        )

    def _initialize_current_observations(self):
        """Evaluate the observations of every tracked entity.

        Returns an OrderedDict of observation id -> triggered observation
        dict (or None), merged across all entities.
        """
        merged = OrderedDict()
        for entity_id in self.observations_by_entity:
            merged.update(self._record_entity_observations(entity_id))
        return merged

    def _record_entity_observations(self, entity):
        """Evaluate every observation that depends on *entity*.

        Each observation is checked with the handler registered for its
        "platform". Returns an OrderedDict keyed by observation id whose value
        is a copy of the observation tagged with ``entity_id=entity`` when it
        is met, or None when it is not.
        """
        results = OrderedDict()
        for obs in self.observations_by_entity[entity]:
            handler = self.observation_handlers[obs["platform"]]
            results[obs["id"]] = (
                {"entity_id": entity, **obs} if handler(obs) else None
            )
        return results

    def _calculate_new_probability(self):
        """Fold every triggered observation into the prior.

        Observations recorded as None (not met) are skipped. When an
        observation has no "prob_given_false", ``1 - prob_given_true`` is used.
        """
        probability = self.prior
        triggered = (obs for obs in self.current_observations.values()
                     if obs is not None)
        for obs in triggered:
            given_true = obs["prob_given_true"]
            given_false = obs.get("prob_given_false", 1 - given_true)
            probability = update_probability(probability, given_true,
                                             given_false)
        return probability

    def _build_observations_by_entity(self):
        """
        Build and return data structure of the form below.

        {
            "sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
            "sensor.sensor2": [{"id": 2, ...}],
            ...
        }

        Each "observation" must be recognized uniquely, and it should be possible
        for all relevant observations to be looked up via their `entity_id`.

        Note: each configured observation dict is mutated in place to receive
        its "id" (its index in the configured list). An observation with
        neither "entity_id" nor "value_template" is not attached to any entity.
        """
        observations_by_entity = {}
        for ind, obs in enumerate(self._observations):
            obs["id"] = ind

            if "entity_id" in obs:
                entity_ids = [obs["entity_id"]]
            elif "value_template" in obs:
                entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
            else:
                # Previously entity_ids was unbound here (first iteration) or
                # silently reused from the previous observation.
                entity_ids = []

            for e_id in entity_ids:
                observations_by_entity.setdefault(e_id, []).append(obs)

        return observations_by_entity

    def _process_numeric_state(self, entity_observation):
        """Return True if the observed entity's numeric state is in range.

        The observation's optional "below"/"above" bounds are passed to
        ``condition.async_numeric_state``, along with None in the position
        after them and the observation dict itself as the last argument.
        """
        entity_id = entity_observation["entity_id"]
        lower = entity_observation.get("above")
        upper = entity_observation.get("below")
        return condition.async_numeric_state(
            self.hass, entity_id, upper, lower, None, entity_observation,
        )

    def _process_state(self, entity_observation):
        """Return True if the observed entity is in the "to_state" state."""
        entity_id = entity_observation["entity_id"]
        return condition.state(self.hass, entity_id,
                               entity_observation.get("to_state"))

    def _process_template(self, entity_observation):
        """Return True if the observation's value template renders true.

        The template object is bound to this entity's hass before it is
        evaluated; the observation dict is passed as template variables.
        """
        tmpl = entity_observation.get(CONF_VALUE_TEMPLATE)
        tmpl.hass = self.hass
        return condition.async_template(self.hass, tmpl, entity_observation)

    @property
    def name(self):
        """Return the sensor's name as given at construction."""
        sensor_name = self._name
        return sensor_name

    @property
    def is_on(self):
        """Return True once async_update found probability >= threshold."""
        deviating = self._deviation
        return deviating

    @property
    def should_poll(self):
        """No polling needed; updates are pushed by the state listener."""
        polling = False
        return polling

    @property
    def device_class(self):
        """Return the device class given at construction."""
        configured_class = self._device_class
        return configured_class

    @property
    def device_state_attributes(self):
        """Return the state attributes of the sensor.

        Includes every current observation slot (None when not met), the
        distinct entity ids of the observations that were met, the
        probability rounded to two decimals, and the threshold.
        """
        observations = list(self.current_observations.values())
        occurred_entities = {
            obs.get("entity_id") for obs in observations if obs is not None
        }
        return {
            ATTR_OBSERVATIONS: observations,
            ATTR_OCCURRED_OBSERVATION_ENTITIES: list(occurred_entities),
            ATTR_PROBABILITY: round(self.probability, 2),
            ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
        }

    async def async_update(self):
        """Turn the sensor on when the probability reaches the threshold."""
        reached = self.probability >= self._probability_threshold
        self._deviation = bool(reached)