def genModel(input_symbols, output_symbols):
    """Sample a random grammar: primitives, higher-order rules, one composition rule.

    Every random choice goes through the sampling helpers (``rangeSample``,
    ``generate_prim``, ``popFromList``, ``sample_LHS``/``sample_RHS``,
    ``sample_bernoulli``), always in the same order.

    Args:
        input_symbols: pool of input words. Primitive and higher-order rules take
            their words from it via ``popFromList``.
        output_symbols: pool of output tokens used by the primitive rules.

    Returns:
        ``Grammar(rules, input_symbols)``.
    """
    # NOTE(review): these are aliases, not copies. Whether the caller's pools
    # are mutated depends on popFromList, which is not visible here -- confirm.
    remaining_inputs = input_symbols
    remaining_outputs = output_symbols
    rules = []

    # Primitive rules; generate_prim hands back the updated pools each time.
    n_prims = rangeSample([4,9]).long().item()
    for prim_idx in range(n_prims):
        prim_rule, remaining_inputs, remaining_outputs = generate_prim(
            prim_idx, remaining_inputs, remaining_outputs, blank_prim=True)
        rules.append(prim_rule)

    # Higher-order rules. Each one takes a word out of the remaining input
    # pool, so words used by earlier rules are presumably not offered again.
    n_xrules = rangeSample([3,7]).long().item()
    for rule_idx in range(n_xrules):
        head_word, remaining_inputs = popFromList(remaining_inputs)
        lhs, lhs_vars = sample_LHS(rule_idx, head_word)
        rhs = sample_RHS(rule_idx, lhs_vars, p_stop_rhs=0.35)
        rules.append(Rule(lhs, rhs))

    # TODO: final composition rule, chosen with probability 0.5 each:
    # the SCAN-style 'u1 u2' swap, or the shared icon_concat_rule.
    use_scan_rule = sample_bernoulli(0.5) == 1.0
    rules.append(Rule('u1 u2', '[u2] [u1]') if use_scan_rule else icon_concat_rule)
    return Grammar(rules, input_symbols)  # TODO
def generate_prim(i, input_symbol_options, output_symbol_options, blank_prim=True, obs_prim=None):
    """Sample a single primitive rule mapping one input word to one output token.

    Args:
        i: index of this primitive; not used in the body.
        input_symbol_options: remaining input words. One is taken via ``popFromList``.
        output_symbol_options: remaining output tokens. One is taken via
            ``popFromList`` unless the primitive is blank.
        blank_prim: if true, a Bernoulli(1/7) draw decides whether the primitive
            maps to the empty string (and then uses no output token).
        obs_prim: accepted for interface compatibility but currently ignored.
            The old body read ``obs_prim.LHS_str``/``RHS_str`` into locals that
            were never used. TODO: condition the sample on this observed rule.

    Returns:
        ``(rule, input_symbol_options, output_symbol_options)``, with the pools
        as returned by ``popFromList``.
    """
    input_symbol, input_symbol_options = popFromList(input_symbol_options)
    # Short-circuit: the Bernoulli draw only happens when blanks are allowed.
    if blank_prim and sample_bernoulli(1/7.):
        rule = Rule(input_symbol, '')
    else:
        output_symbol, output_symbol_options = popFromList(output_symbol_options)
        rule = Rule(input_symbol, output_symbol)
    return rule, input_symbol_options, output_symbol_options
def generate_prims(nprims, input_symbol_options, output_symbol_options, blank_prim=False):
    """Build ``nprims`` primitive rules from the front of the symbol pools.

    Primitive ``k`` maps ``input_symbol_options[k]`` to ``output_symbol_options[k]``.
    When ``blank_prim`` is true, each primitive instead maps to the empty string
    with probability 1/7 (one ``random.random()`` draw per primitive).

    Note: a blank primitive still uses up its positional output symbol. That
    output is neither assigned to any rule nor returned, because both pools are
    advanced by exactly ``nprims``.

    Args:
        nprims: number of primitives to create.
        input_symbol_options: indexable, sliceable pool of input symbols
            (a list or 1-D array).
        output_symbol_options: indexable, sliceable pool of output symbols.

    Returns:
        ``(rules, input_symbol_options[nprims:], output_symbol_options[nprims:])``.
    """
    rules = []
    for k in range(nprims):
        if blank_prim and random.random() < 1./7:
            rules.append(Rule(input_symbol_options[k], ''))
        else:
            rules.append(Rule(input_symbol_options[k], output_symbol_options[k]))
    return rules, input_symbol_options[nprims:], output_symbol_options[nprims:]
def generate_random_scan_rules(nprims, nurules, nxrules, input_symbols, output_symbols):
    """Sample a random SCAN-style grammar for synthesis.

    Args:
        nprims: number of primitive rules (the author's notes suggest 4-8).
        nurules: number of dedicated 'u' rules. Only 0 is supported, because the
            opposite/around generator is disabled.
        nxrules: number of higher-order rules over x/u variables (notes suggest 3-7).
        input_symbols: input vocabulary; a shuffled ``.copy()`` is consumed.
        output_symbols: output vocabulary; a shuffled ``.copy()`` is consumed.

    Returns:
        ``Grammar(rules, input_symbols)``.

    Raises:
        AssertionError: if ``nurules`` is non-zero or either vocabulary is too
            small for the requested number of rules.
    """
    assert nurules == 0
    # Fail fast here rather than with an IndexError partway through sampling.
    # Primitives and x-rules each use a distinct input symbol, and
    # generate_prims may index one output symbol per primitive.
    assert nprims + nxrules <= len(input_symbols), (
        f"need {nprims + nxrules} input symbols, got {len(input_symbols)}")
    assert nprims <= len(output_symbols), (
        f"need {nprims} output symbols, got {len(output_symbols)}")

    input_symbol_options = input_symbols.copy()
    output_symbol_options = output_symbols.copy()
    random.shuffle(input_symbol_options)
    random.shuffle(output_symbol_options)

    rules, input_symbol_options, output_symbol_options = generate_prims(
        nprims, input_symbol_options, output_symbol_options, blank_prim=True)

    # Each x-rule uses an input word not already assigned to a primitive.
    for i in range(nxrules):
        input_symbol = input_symbol_options[i]
        LHS, used_vars_lhs = sample_LHS(input_symbol, only_x=False)
        RHS = sample_RHS(used_vars_lhs, p_stop_rhs=0.35)
        rules.append(Rule(LHS, RHS))

    # TODO: ensure both u and x variables get used; otherwise there is little
    # real compositionality in the sampled grammar.
    if random.random() < 0.5:
        # SCAN-style composition: 'u1 u2' -> '[u2] [u1]'.
        rules.append(Rule('u1 u2', '[u2] [u1]'))
    else:
        rules.append(icon_concat_rule)
    return Grammar(rules, input_symbols)
def parse_rules(rules, input_symbols=None):
    """Build a Grammar from rules written as ``lhs -> rhs``.

    Args:
        rules: iterable of rules. Each rule is either a sequence of tokens with a
            ``'->'`` token, e.g. ``['x1', 'twice', '->', '[x1]', '[x1]']``, or a
            whitespace-separated string such as ``'x1 twice -> [x1] [x1]'``.
        input_symbols: input vocabulary for the Grammar; must be non-empty.

    Returns:
        ``Grammar`` built from one ``Rule(lhs, rhs)`` per input rule. Each side
        is its tokens joined by single spaces.

    Raises:
        AssertionError: if ``input_symbols`` is missing or empty.
        ParseError: if a rule has no ``'->'`` token.
    """
    assert input_symbols
    parsed = []
    for rule in rules:
        # A plain string would otherwise be sliced and joined character by
        # character, silently producing garbage like 'd a x'.
        tokens = rule.split() if isinstance(rule, str) else rule
        if '->' not in tokens:
            raise ParseError
        arrow = tokens.index('->')
        lhs = ' '.join(tokens[:arrow])
        rhs = ' '.join(tokens[arrow + 1:])
        parsed.append(Rule(lhs, rhs))
    return Grammar(parsed, input_symbols)
def generate_random_rules(nprims, nrules, input_symbols, output_symbols, sort_prims=False):
    """Sample a random grammar of primitives plus higher-order rules.

    Args:
        nprims: number of primitive rules.
        nrules: number of higher-order rules built with ``sample_LHS``/``sample_RHS``.
        input_symbols: input vocabulary; a shuffled array copy is consumed.
        output_symbols: output vocabulary; a shuffled array copy is consumed.
        sort_prims: if true, the primitive rules are sorted by ``rule.LHS_str``.

    Returns:
        ``Grammar(rules, input_symbols)``, whose last rule is ``icon_concat_rule``.

    Raises:
        AssertionError: if there are too few input or output symbols.
    """
    assert nprims + nrules <= len(input_symbols), (
        f"need {nprims + nrules} input symbols, got {len(input_symbols)}")
    # generate_prims indexes one output symbol per primitive.
    assert nprims <= len(output_symbols), (
        f"need {nprims} output symbols, got {len(output_symbols)}")

    # np.array copies by default, so shuffling never touches the caller's data.
    input_symbol_options = np.array(input_symbols)
    output_symbol_options = np.array(output_symbols)
    np.random.shuffle(input_symbol_options)
    np.random.shuffle(output_symbol_options)

    rules, input_symbol_options, output_symbol_options = generate_prims(
        nprims, input_symbol_options, output_symbol_options)
    if sort_prims:
        rules = sorted(rules, key=lambda rule: rule.LHS_str)

    # Higher-order rules use input words that were not given to a primitive.
    for i in range(nrules):
        input_symbol = input_symbol_options[i]
        LHS, used_vars_lhs = sample_LHS(input_symbol)
        RHS = sample_RHS(used_vars_lhs)
        rules.append(Rule(LHS, RHS))
    rules.append(icon_concat_rule)
    return Grammar(rules, input_symbols)
import random #AHHH """ nxrules = random.choice((3,4,5,6,7)) nurules = 0 nprims = random.choice(range(4,9)) input_symbols = input_lang.symbols output_symbols = output_lang.symbols """ #helper funs: p_lhs_onearg = 0.4 # probability that we have a single argument. Otherwise, two arguments p_stop_rhs = 0.6 # prob. of stopping on the right hand side #0.4 for longtail vars_input = ['u1','u2','x1','x2'] vars_output = ['['+v+']' for v in vars_input] icon_concat_rule = Rule('u1 x1','[u1] [x1]') def sample_bernoulli(p): return pyprob.sample( pyprob.distributions.Categorical(torch.tensor([1-p, p]))) def getNPrims(g): nPrims = 0 for rule in g.rules: if len(rule.LHS_list) == 1: nPrims += 1 return nPrims def getNHOrules(g): nHOrules = 0 for rule in g.rules: if len(rule.LHS_list) > 1:
def exact_perm_rules():
    """Return a fixed, hand-written rule list for SCAN-style commands.

    Each direction phrase ('right'/'left', with 'opposite'/'around') has its own
    dedicated rules. The order is: primitives, right-turn rules, left-turn
    rules, then the x1/x2 conjunction and repetition rules.
    """
    rule_table = (
        # primitives
        ('run', 'RUN'),
        ('jump', 'JUMP'),
        ('walk', 'WALK'),
        ('look', 'LOOK'),
        # right turns
        ('turn right', 'RTURN'),
        ('u1 right', 'RTURN [u1]'),
        ('turn opposite right', 'RTURN RTURN'),
        ('u1 opposite right', 'RTURN RTURN [u1]'),
        ('turn around right', 'RTURN RTURN RTURN RTURN'),
        ('u1 around right', 'RTURN [u1] RTURN [u1] RTURN [u1] RTURN [u1]'),
        # left turns
        ('turn left', 'LTURN'),
        ('u1 left', 'LTURN [u1]'),
        ('turn opposite left', 'LTURN LTURN'),
        ('u1 opposite left', 'LTURN LTURN [u1]'),
        ('turn around left', 'LTURN LTURN LTURN LTURN'),
        ('u1 around left', 'LTURN [u1] LTURN [u1] LTURN [u1] LTURN [u1]'),
        # conjunction and repetition
        ('x1 and x2', '[x1] [x2]'),
        ('x2 after x1', '[x1] [x2]'),
        ('x1 twice', '[x1] [x1]'),
        ('x1 thrice', '[x1] [x1] [x1]'),
    )
    return [Rule(lhs, rhs) for lhs, rhs in rule_table]
def exact_perm_doubled_rules():
    """Return a fixed rule list where turn directions are factored into variables.

    Unlike the fully spelled-out variant, 'turn' maps to the empty string and
    'right'/'left' are primitives. 'opposite'/'around' are written over u1/u2,
    and a final 'u1 u2' rule swaps the order of its two arguments.
    """
    rule_table = (
        # primitives, including the empty 'turn' and the two directions
        ('run', 'RUN'),
        ('jump', 'JUMP'),
        ('walk', 'WALK'),
        ('look', 'LOOK'),
        ('turn', ''),
        ('right', 'RTURN'),
        ('left', 'LTURN'),
        # direction modifiers over u-variables
        ('u1 opposite u2', '[u2] [u2] [u1]'),
        ('u1 around u2', '[u2] [u1] [u2] [u1] [u2] [u1] [u2] [u1]'),
        # conjunction and repetition
        ('x1 and x2', '[x1] [x2]'),
        ('x2 after x1', '[x1] [x2]'),
        ('x1 twice', '[x1] [x1]'),
        ('x1 thrice', '[x1] [x1] [x1]'),
        # composition of a u-phrase with a direction
        ('u1 u2', '[u2] [u1]'),
    )
    return [Rule(lhs, rhs) for lhs, rhs in rule_table]