Beispiel #1
0
    def parseExpression(self, expr, name=None, expandedTerm=None):
        """ Convert a lagrangian expression written by the user in the model
        file from str to TensorObject.

        Three kinds of expressions are handled :
          1) 't(group, rep)' : the representation matrices of 'rep' for the
             gauge group 'group' (given either by its name or by its type) ;
          2) 'cgc(group, ...)' : an invariant tensor (CGC), specified either
             from field names or from Dynkin labels, optionally followed by
             the 1-based position of the invariant to select ;
          3) any other expression built from quantities already stored in
             self.definitions.

        Parameters
        ----------
        expr : str
            The user expression (right-hand side of the definition).
        name : str or None
            The left-hand side of the definition, possibly carrying indices
            (e.g. 'A[i,j]'). When given, its indices must coincide with the
            free indices of each term of expr (case 3).
        expandedTerm : list or None
            Case 3 only : if a list is given, the expanded expression is
            accumulated into its first element (which is created if the list
            is empty).

        Returns
        -------
        TensorObject

        As much as possible, the user input is validated : on invalid input
        an error message is logged through loggingCritical and the program
        exits."""

        originalExpr = expr
        errorExpr = (name + ' : ' if name is not None else '') + str(originalExpr)

        ##########
        # Case 1 : expr is a representation matrix
        ##########

        if expr[:2].lower() == 't(':
            args = expr[2:-1]
            gp = args.split(',')[0]

            if gp + ',' not in args:
                loggingCritical(
                    f"\nError : representation matrix {expr} should have exactly two arguments : group and rep"
                )
                exit()
            # NOTE(review): eval() executes arbitrary text from the model file.
            # If reps are always int / list literals, ast.literal_eval would be
            # a safer choice.
            rep = eval(args.replace(gp + ',', ''))

            # The group may be referred to by its name or by its type
            if gp in self.model.gaugeGroups:
                gp = self.model.gaugeGroups[gp]
            else:
                for g in self.model.gaugeGroups.values():
                    if g.type == gp:
                        gp = g
                        break
                if isinstance(gp, str):
                    loggingCritical(
                        f"\nError in 'Definitions': gauge group '{gp}' is unknown."
                    )
                    exit()

            # DimR -> Dynkin labels
            if isinstance(rep, int):
                rep = self.idb.get(gp.type, 'dynkinLabels', rep)

            repMats = gp.repMat(tuple(rep))

            # Stack the matrices into a single rank-3 sparse tensor whose first
            # index labels the matrix
            shape = tuple([len(repMats), *repMats[0].shape])
            dic = {}
            for i, mat in enumerate(repMats):
                for k, v in mat._smat.items():
                    dic[(i, *k)] = v

            # This is for latex output
            expr = Function('t')(Symbol(gp.type), Symbol(str(rep)))

            return TensorObject(copy=(name, shape, dic),
                                fromDef=name,
                                expr=expr)

        ##########
        # Case 2 : expr is a CGC
        ##########

        if expr[:4].lower() == 'cgc(':
            args = expr[4:-1].split(',')

            # Read the gauge group
            gp = args[0]
            args = args[1:]

            # Splitting on ',' also split lists / tuples apart : glue the
            # pieces back together until the brackets are balanced
            i = 0
            while i < len(args):
                nOpen = sum(args[i].count(ch) for ch in '[({')
                nClose = sum(args[i].count(ch) for ch in '])}')
                if nOpen > nClose:
                    if i + 1 >= len(args):
                        loggingCritical(
                            f"\nError in CGC '{name}: {expr}' : unbalanced brackets."
                        )
                        exit()
                    args[i] = args[i] + ', ' + args[i + 1]
                    args.pop(i + 1)
                else:
                    i += 1

            # Read the fields : stop at the first int or tuple, which is the
            # optional position argument
            fields = []
            for el in args:
                if el.isnumeric() or ('(' in el and ')' in el):
                    break
                fields.append(el)
            args = args[len(fields):]

            # Determine which syntax was used
            if gp in self.model.gaugeGroups and all(
                    el in self.model.Particles for el in fields):
                fieldNames = fields
            elif gp in [g.type for g in self.model.gaugeGroupsList] and all(
                    el not in self.model.Particles for el in fields):
                fieldNames = []
            else:
                loggingCritical(
                    "\nError : CGC syntax is 'cgc(groupName, field1, field2 [, field3 [, field4, [CGC number]]])' or "
                    +
                    "cgc(group, dynkins1, dynkins2 [, dynkins3 [, dynkins4, [CGC number]]]). The group and particles must be defined above."
                )
                loggingCritical(
                    f"Please rewrite the term '{name}: {expr}' accordingly.")
                exit()

            # 0-based index of the invariant to select (user input is 1-based)
            N = 0
            if args:
                if len(args) > 1:
                    loggingCritical(
                        f"\nError in {name}: {expr} ; too much arguments to cgc() function."
                    )
                    exit()
                try:
                    N = int(args[0]) - 1
                except ValueError:
                    loggingCritical(
                        f"\nError in CGC '{name}: {expr}' : position argument must be an integer."
                    )
                    exit()
                if N < 0:
                    loggingCritical(
                        f"\nError in {name}: {expr} ; the position argument must be a non-zero positive integer."
                    )
                    exit()

            if fieldNames != []:
                gpName, gType = gp, self.model.gaugeGroups[gp].type
                reps = [self.model.Particles[p].Qnb[gpName] for p in fields]
            else:
                # NOTE(review): eval() on model-file text, see Case 1.
                gType, reps = gp, [eval(labels) for labels in fields]

            cgc = self.idb.get(gType,
                               'invariants',
                               reps,
                               pyrateNormalization=True,
                               realBasis=GaugeGroup.realBasis)

            if len(cgc) == 0:
                loggingCritical(
                    f"Error: no invariant can be formed from the reps provided in '{name}'."
                )
                exit()
            if N > len(cgc) - 1:
                loggingCritical(
                    f"\nError in {name}: {expr} ; the position argument cannot be larger than {len(cgc)} here."
                )
                exit()

            result = cgc[N]

            shape = result.dim[:result.rank]
            dic = {k[:result.rank]: v for k, v in result.dic.items()}

            # This is for latex output
            expr = Function('cgc')(Symbol(gType),
                                   *([Symbol(el) for el in fields] + [N]))

            return TensorObject(copy=(name, shape, dic),
                                fromDef=name,
                                expr=expr,
                                fields=fieldNames)

        ##########
        # Case 3 : an expression involving the already defined quantities
        ##########
        localDict = {}

        # Replace each defined name by a placeholder symbol 'symb_<n>_' (longest
        # names first so that a name containing another one is handled first).
        # 'sqrt' is temporarily written '#' to protect it from substitution.
        count = 0
        expr = expr.replace('sqrt', '#').replace('Sqrt', '#')
        for k, v in sorted(self.definitions.items(), key=lambda x: -len(x[0])):
            expr = expr.replace(k, f'@_{count}_')
            localDict[f'symb_{count}_'] = v.symbol
            count += 1
        expr = expr.replace('@', 'symb')
        expr = expr.replace('#', 'sqrt')

        def sympyParse(expr):
            if '^' in expr:
                loggingCritical(
                    f"\nError in expression '{errorExpr}' : powers must be written using the '**' operator"
                )
                exit()
            return parse_expr(expr,
                              local_dict=localDict,
                              transformations=standard_transformations[1:] +
                              (implicit_multiplication, ),
                              evaluate=False)

        # A) Replacements to format the expr string
        expr = expr.replace(']', '] ').strip()
        expr = expr.replace(' +',
                            '+').replace(' -',
                                         '-').replace(' *',
                                                      '*').replace(' /', '/')
        expr = expr.replace(' )', ')')
        expr = expr.replace('] ', ']*')

        for k, v in localDict.items():
            if isinstance(v, Symbol):
                expr = expr.replace(k, k + ' ')

        # B) Parse the string. Only genuine parsing errors are caught here : the
        # SystemExit raised by sympyParse's own check must propagate.
        try:
            expr = sympyParse(expr)
        except Exception:
            loggingCritical(f"\nError while parsing the term '{errorExpr}'.")
            exit()

        rep = {}
        if expr.find(Pow) != set():
            # Now we want to expand the term, keeping any form (a*b)**2 unexpanded :
            # the exponent c is hidden in a symbol 'n_<c>' and decoded below
            a, b, c = [Wild(x, exclude=(1, )) for x in 'abc']
            rep = expr.replace((a * b)**c,
                               lambda a, b, c: (a * b)**Symbol('n_' + str(c)),
                               map=True)

        if rep == {} or rep[1] == {}:
            termList = expand(expr).as_coeff_add()[1]
        else:
            termList = expand(rep[0], power_base=False).as_coeff_add()[1]

        # C) Parse the left hand side of the definition (if given)
        Linds = []
        if name is not None and '[' in name and ']' in name:
            Lbase = name[:name.find('[')]
            Linds = [
                Symbol(el)
                for el in name[name.find('[') + 1:name.find(']')].split(',')
            ]

        # D) Validate and compute the expression
        rhsResult = 0
        commonFreeInds = None

        for term in termList:
            split = splitPow(term)

            rationalFactors = [el for el in split if el.is_number]
            terms = tuple([el for el in split if el not in rationalFactors])
            coeff = Mul(*rationalFactors)

            # Handle expr**N now : write it as N copies of expr, each copy
            # carrying renamed indices (i -> i_1, i_2, ...)
            newTerms = []
            for subTerm in terms:
                if isinstance(subTerm, Pow):
                    base, exp = subTerm.base, subTerm.exp
                    if isinstance(exp, Symbol):
                        # Decode the 'n_<c>' symbol introduced above
                        exp = int(exp.name[2:])

                    indCopies = {}
                    indexed = base.find(Indexed)
                    if indexed != set():
                        indices = flatten([el.indices for el in indexed])
                        for ind in indices:
                            if ind not in indCopies:
                                indCopies[ind] = [
                                    Symbol(str(ind) + f'_{p}')
                                    for p in range(1, exp)
                                ]

                    newTerms.append(base)
                    for p in range(0, exp - 1):
                        sub = {ind: copies[p] for ind, copies in indCopies.items()}
                        newTerms.append(base.subs(sub))
                else:
                    newTerms.append(subTerm)

            # Flatten products again, moving numeric factors into coeff
            terms = []
            for subTerm in newTerms:
                if isinstance(subTerm, (Mul, Pow)):
                    for el in splitPow(subTerm):
                        if not el.is_number:
                            terms.append(el)
                        else:
                            coeff *= el
                else:
                    if not subTerm.is_number:
                        terms.append(subTerm)
                    else:
                        coeff *= subTerm

            if expandedTerm is not None:
                if expandedTerm == []:
                    expandedTerm.append(Mul(coeff, *terms, evaluate=False))
                else:
                    expandedTerm[0] += Mul(coeff, *terms, evaluate=False)

            # Collect the indices of each factor and check their number
            inds = []
            indRanges = {}
            for field in terms:
                if isinstance(field, Symbol):
                    continue
                try:
                    fieldInds = field.indices
                except AttributeError:
                    loggingCritical(
                        f"\nError (in term '{expr}') while reading the quantity '{field}'. It seems that indices are missing."
                    )
                    exit()

                fieldDef = self.definitions[str(field.base)]
                if fieldDef.dim is not None and len(fieldInds) != fieldDef.dim:
                    loggingCritical(
                        f"\nError (in term '{expr}'): the quantity {field.base} should carry exactly {fieldDef.dim} indices"
                    )
                    exit()

                inds += list(fieldInds)
                for p, ind in enumerate(fieldInds):
                    indRanges[ind] = (fieldDef, p)

            # Free indices appear once, contracted ones twice
            freeInds = []
            for ind in set(inds):
                nOcc = inds.count(ind)
                if nOcc == 1:
                    freeInds.append(ind)
                if nOcc > 2:
                    loggingCritical(
                        f"\nError: in term '{term}', the index '{ind}' appears more than twice."
                    )
                    exit()

            # freeInds comes from iterating a set, so its order is arbitrary :
            # compare as sets
            if commonFreeInds is None:
                commonFreeInds = freeInds
            elif set(freeInds) != set(commonFreeInds):
                loggingCritical(
                    f"\nError : each term of the sum '{expr}' must contain the same free indices."
                )
                exit()
            if name is not None and set(freeInds) != set(Linds):
                loggingCritical(
                    f"\nError in term {term}: there should be {len(set(Linds))} free indices"
                    + (' -> ' +
                       str(tuple(set(Linds))) if set(Linds) != set() else ''))
                exit()

            # Now that the term is validated, construct the resulting tensor object
            contractArgs = []
            for field in terms:
                if not isinstance(field, Symbol):
                    base, tensInds = field.base, field.indices
                else:
                    base, tensInds = field, []

                tens = self.definitions[str(base)]
                tens.update(len(tensInds))

                contractArgs.append(tens(*[Wild(str(el)) for el in tensInds]))

            freeDummies = [Wild(str(el)) for el in Linds]
            tmp = tensorContract(*contractArgs,
                                 value=coeff,
                                 freeDummies=freeDummies,
                                 doit=True)

            if not isinstance(tmp, dict):
                tmp = expand(tmp)

            # Accumulate the term, either as a scalar or as a sparse tensor
            if rhsResult == 0:
                # Copy so that later in-place additions don't alter tmp
                rhsResult = dict(tmp) if isinstance(tmp, dict) else tmp
            elif not isinstance(rhsResult, dict):
                rhsResult += tmp
            else:
                for k, v in tmp.items():
                    v = expand(v)
                    if k in rhsResult:
                        rhsResult[k] += v
                    else:
                        rhsResult[k] = v

                    # Drop entries that cancelled out
                    if rhsResult[k] == 0:
                        del rhsResult[k]

        if not isinstance(rhsResult, dict):
            return TensorObject(copy=('' if name is None else name, (), {
                (): rhsResult
            }),
                                fromDef=name,
                                expr=expr)

        # The range of each free index is read from the quantity carrying it
        ranges = []
        for freeInd in Linds:
            fieldDef, pos = indRanges[freeInd]
            ranges.append(fieldDef.range[pos])

        try:
            return TensorObject(copy=(Lbase, ranges, rhsResult),
                                fromDef=name,
                                expr=expr)
        except Exception:
            loggingCritical(
                f"\nError while parsing the term '{errorExpr}': please check the consistency of contracted indices."
            )
            exit()