コード例 #1
0
def genModel(input_symbols, output_symbols):
	"""Sample a random rule grammar from sampled random choices.

	Builds, in order:
	  1. nPrims primitive rules via generate_prim (blank primitives allowed),
	  2. nxrules higher-order rules, each keyed on a fresh input symbol and
	     built from sample_LHS / sample_RHS (p_stop_rhs=0.35),
	  3. one final rule picked by a fair coin (sample_bernoulli(0.5)): either
	     'u1 u2' -> '[u2] [u1]' or the module-level icon_concat_rule.

	Args:
		input_symbols: pool of input words; symbols are taken from it with
			popFromList. Also passed as-is to Grammar.
		output_symbols: pool of output tokens, handed to generate_prim.

	Returns:
		Grammar(rules, input_symbols).

	NOTE(review): the random choices here (rangeSample, sample_bernoulli and the
	generate_prim / sample_* helpers) look like pyprob-style sampling. pyprob
	identifies choices by call site, so keep the call order and structure stable
	when editing; confirm before refactoring.
	"""
	# NOTE(review): no copy is made, so input_symbol_options aliases the caller's
	# input_symbols. If popFromList mutates its argument in place, input_symbols
	# (also passed to Grammar below) is modified too; verify popFromList.
	input_symbol_options = input_symbols
	output_symbol_options = output_symbols

	# Number of primitive rules. Presumably an integer drawn from [4, 9]; whether
	# the endpoints are inclusive depends on rangeSample (not shown). Confirm.
	nPrims = rangeSample([4,9]).long().item()
	rules = []
	for i in range(nPrims):
		# generate_prim returns the new rule plus the remaining symbol pools.
		rule, input_symbol_options, output_symbol_options = generate_prim(i, input_symbol_options, 
																			output_symbol_options,
																			blank_prim=True)
		rules.append(rule)

	# Number of higher-order rules; range semantics as for nPrims above.
	nxrules = rangeSample([3,7]).long().item()

	for i in range(nxrules):
		#sample_lhs makes sure not to use words twice for simple defs
		# Each higher-order rule is keyed on a fresh input word from the pool.
		input_symbol, input_symbol_options = popFromList(input_symbol_options)

		LHS, used_vars_lhs = sample_LHS(i, input_symbol)
		RHS = sample_RHS(i, used_vars_lhs, p_stop_rhs=0.35)
		
		rules.append(Rule(LHS,RHS))
	
	# Fair coin choosing the final combining rule.
	last_rule = sample_bernoulli(0.5)
	if last_rule == 1.0: # TODO (original marker; intent not stated)
		# SCAN-style: output u2's expansion before u1's.
		rules.append(Rule('u1 u2','[u2] [u1]'))
	else:
		# Concatenate a u-argument with an x-argument: 'u1 x1' -> '[u1] [x1]'.
		rules.append(icon_concat_rule)

	return Grammar(rules, input_symbols) # TODO (original marker; intent not stated)
コード例 #2
0
def generate_prim(i, input_symbol_options, output_symbol_options, blank_prim=True, obs_prim=None):
	"""Create one primitive rule from the next available input symbol.

	Args:
		i: index of this primitive; not used in the visible body.
		input_symbol_options: remaining input words; one is taken with
			popFromList and becomes the rule's LHS.
		output_symbol_options: remaining output tokens; one is taken with
			popFromList unless the primitive is blank.
		blank_prim: when true, sample_bernoulli(1/7.) decides whether the
			primitive maps to the empty string. A blank primitive consumes no
			output symbol (the output pool is returned unchanged).
		obs_prim: optional observed rule object; see the note in the body.

	Returns:
		(rule, remaining input symbols, remaining output symbols).

	NOTE(review): sample_bernoulli is a pyprob call site; keep this function's
	structure stable so sample addresses do not shift. Confirm before refactoring.
	"""
	if obs_prim:
		# NOTE(review): LHS_obs, RHS_obs and is_blank are computed but never used
		# below, so obs_prim does not influence the sampled rule. This looks like
		# unfinished conditioning code; confirm intent.
		LHS_obs = obs_prim.LHS_str
		RHS_obs = obs_prim.RHS_str
		is_blank = RHS_obs == ''


	input_symbol, input_symbol_options = popFromList(input_symbol_options)

	# The `if` relies on the truthiness of the sampled value (1 means blank).
	if blank_prim and sample_bernoulli(1/7.):
		rule = Rule(input_symbol, '') 
	else:
		output_symbol, output_symbol_options = popFromList(output_symbol_options)
		rule = Rule(input_symbol,output_symbol)

	return rule, input_symbol_options, output_symbol_options
コード例 #3
0
def generate_prims(nprims,input_symbol_options,output_symbol_options, blank_prim=False):
	"""Pair the first `nprims` input symbols with output symbols as primitive rules.

	Primitive k maps input_symbol_options[k] to output_symbol_options[k]. When
	`blank_prim` is true, each primitive instead maps to the empty string with
	probability 1/7 (one random.random() draw per primitive). A blank primitive
	still consumes its output slot: output symbol k is skipped, not reused.

	Args:
		nprims: number of primitives to create.
		input_symbol_options: sliceable sequence (list or numpy array) of input words.
		output_symbol_options: sliceable sequence of output tokens.
		blank_prim: allow blank (empty-output) primitives.

	Returns:
		(rules, input_symbol_options[nprims:], output_symbol_options[nprims:]).

	Raises:
		IndexError: if either sequence has fewer than `nprims` entries.
	"""
	rules = []
	for i in range(nprims):
		# The random draw happens before indexing, matching the historical order
		# of side effects. random.random() is only called when blank_prim is set.
		if blank_prim and random.random() < 1./7:
			rules.append(Rule(input_symbol_options[i], ''))
		else:
			rules.append(Rule(input_symbol_options[i], output_symbol_options[i]))

	return rules, input_symbol_options[nprims:], output_symbol_options[nprims:]
コード例 #4
0
def generate_random_scan_rules(nprims, nurules, nxrules, input_symbols, output_symbols):
	"""Sample a random SCAN-like grammar for synthesis.

	The grammar is built in three steps:
	  1. `nprims` primitive rules from shuffled copies of the symbol lists, with
	     blank (empty-output) primitives allowed,
	  2. `nxrules` higher-order rules, each keyed on a fresh input word and
	     built from sample_LHS(..., only_x=False) / sample_RHS(..., p_stop_rhs=0.35),
	  3. one closing rule chosen with probability 0.5: either
	     'u1 u2' -> '[u2] [u1]' or the module-level icon_concat_rule.

	Args:
		nprims: number of primitive rules (author's note: roughly 4-8).
		nurules: number of dedicated 'u' rules. Only 0 is supported because
			the generator for such rules is disabled; any other value fails
			the assertion.
		nxrules: number of higher-order rules (author's note: roughly 3-7).
		input_symbols: list of input words. It is copied before shuffling and
			then passed unchanged to Grammar.
		output_symbols: list of output tokens; copied before shuffling.

	Returns:
		Grammar(rules, input_symbols).
	"""
	assert nurules == 0

	# Shuffle private copies so the caller's lists are left as they were.
	words = input_symbols.copy()
	tokens = output_symbols.copy()
	random.shuffle(words)
	random.shuffle(tokens)

	rules, words, _ = generate_prims(nprims, words, tokens, blank_prim=True)

	# Author's note: ideally both u- and x-variables should appear, otherwise
	# the grammar has little compositional structure.
	for k in range(nxrules):
		# sample_LHS is relied on not to reuse words for simple definitions.
		lhs, lhs_vars = sample_LHS(words[k], only_x=False)
		rules.append(Rule(lhs, sample_RHS(lhs_vars, p_stop_rhs=0.35)))

	# Closing rule: SCAN-style reversal of two u-arguments, or u/x concatenation.
	closing = Rule('u1 u2', '[u2] [u1]') if random.random() < 0.5 else icon_concat_rule
	rules.append(closing)

	return Grammar(rules, input_symbols)
コード例 #5
0
ファイル: agent.py プロジェクト: pengking/rulesynthesis
def parse_rules(rules, input_symbols=None):
    """Build a Grammar from tokenised rules.

    Each element of `rules` is a token sequence of the form
    ``lhs tokens... '->' rhs tokens...``. The tokens on each side of the first
    '->' are rejoined with single spaces to form the Rule's LHS and RHS strings.

    Args:
        rules: iterable of token sequences, each containing '->'.
        input_symbols: input vocabulary for the Grammar. It is required and
            must be truthy, despite the None default.

    Returns:
        Grammar(parsed_rules, input_symbols).

    Raises:
        AssertionError: if input_symbols is falsy.
        ParseError: if a rule contains no '->'.
    """
    assert input_symbols
    parsed = []
    for tokens in rules:
        try:
            arrow_at = tokens.index('->')
        except ValueError:
            raise ParseError from None
        lhs_text = ' '.join(tokens[:arrow_at])
        rhs_text = ' '.join(tokens[arrow_at + 1:])
        parsed.append(Rule(lhs_text, rhs_text))
    return Grammar(parsed, input_symbols)
コード例 #6
0
def generate_random_rules(nprims,nrules,input_symbols,output_symbols, sort_prims=False):
	"""Sample a random grammar of primitives plus higher-order rules.

	Shuffles copies of the symbol lists with numpy's global RNG, then builds
	`nprims` primitive rules (no blank primitives) and `nrules` higher-order
	rules, each keyed on a fresh input word. Finally it appends icon_concat_rule.

	Args:
		nprims: number of primitive rules.
		nrules: number of higher-order rules.
		input_symbols: input vocabulary; needs at least nprims + nrules entries.
		output_symbols: output tokens; needs at least nprims entries.
		sort_prims: if true, sort the primitive rules by their LHS string.

	Returns:
		Grammar(rules, input_symbols).

	Raises:
		AssertionError: if nprims + nrules exceeds len(input_symbols).
	"""
	assert(nprims+nrules <= len(input_symbols))
	# np.array already returns a fresh copy (copy=True by default), so the
	# in-place shuffles never touch the caller's sequences. The former extra
	# np.copy around it was redundant.
	input_symbol_options = np.array(input_symbols)
	output_symbol_options = np.array(output_symbols)
	np.random.shuffle(input_symbol_options)
	np.random.shuffle(output_symbol_options)
	rules, input_symbol_options, output_symbol_options = generate_prims(nprims, input_symbol_options, output_symbol_options)
	if sort_prims:
		rules = sorted(rules, key=lambda rule: rule.LHS_str)
	for i in range(nrules):
		# Words left after the primitives key the higher-order rules.
		input_symbol = input_symbol_options[i]
		LHS, used_vars_lhs = sample_LHS(input_symbol)
		RHS = sample_RHS(used_vars_lhs)
		rules.append(Rule(LHS, RHS))
	rules.append(icon_concat_rule)
	return Grammar(rules, input_symbols)
コード例 #7
0
import random #AHHH
"""
nxrules = random.choice((3,4,5,6,7))
nurules = 0
nprims = random.choice(range(4,9))

	input_symbols = input_lang.symbols
	output_symbols = output_lang.symbols
"""

#helper funs:
p_lhs_onearg = 0.4 # probability that we have a single argument. Otherwise, two arguments
p_stop_rhs = 0.6 # prob. of stopping on the right hand side #0.4 for longtail
vars_input = ['u1','u2','x1','x2']
vars_output = ['['+v+']' for v in vars_input]
icon_concat_rule = Rule('u1 x1','[u1] [x1]')

def sample_bernoulli(p):
	"""Draw a Bernoulli(p) sample through pyprob.

	Implemented as a two-category Categorical with probabilities [1-p, p].
	The sampled category index is therefore 1 with probability p and 0
	otherwise. Callers test its truthiness or compare it with 1.0.

	NOTE(review): the return value is whatever pyprob.sample returns (presumably
	a 0-dim torch tensor); confirm. pyprob identifies random choices by call
	site, so restructuring this function may change sample addresses.
	"""
	return pyprob.sample( pyprob.distributions.Categorical(torch.tensor([1-p, p])))

def getNPrims(g):
	"""Count the primitive rules in grammar `g`.

	A rule counts as primitive when its LHS_list has exactly one element.
	"""
	return sum(1 for rule in g.rules if len(rule.LHS_list) == 1)

def getNHOrules(g):
	nHOrules = 0
	for rule in g.rules:
		if len(rule.LHS_list) > 1:
コード例 #8
0
def exact_perm_rules():
	"""Return the hand-written SCAN grammar as a list of Rule objects.

	Rules are created in the order listed: the four action primitives, the
	right- and left-direction rules (plain, opposite, around, each for 'turn'
	and for a u1 action), then the conjunction and repetition rules.
	"""
	actions = (
		('run', 'RUN'),
		('jump', 'JUMP'),
		('walk', 'WALK'),
		('look', 'LOOK'),
	)
	right_turns = (
		('turn right', 'RTURN'),
		('u1 right', 'RTURN [u1]'),
		('turn opposite right', 'RTURN RTURN'),
		('u1 opposite right', 'RTURN RTURN [u1]'),
		('turn around right', 'RTURN RTURN RTURN RTURN'),
		('u1 around right', 'RTURN [u1] RTURN [u1] RTURN [u1] RTURN [u1]'),
	)
	left_turns = (
		('turn left', 'LTURN'),
		('u1 left', 'LTURN [u1]'),
		('turn opposite left', 'LTURN LTURN'),
		('u1 opposite left', 'LTURN LTURN [u1]'),
		('turn around left', 'LTURN LTURN LTURN LTURN'),
		('u1 around left', 'LTURN [u1] LTURN [u1] LTURN [u1] LTURN [u1]'),
	)
	combinators = (
		('x1 and x2', '[x1] [x2]'),
		('x2 after x1', '[x1] [x2]'),
		('x1 twice', '[x1] [x1]'),
		('x1 thrice', '[x1] [x1] [x1]'),
	)
	table = actions + right_turns + left_turns + combinators
	return [Rule(lhs, rhs) for lhs, rhs in table]
コード例 #9
0
def exact_perm_doubled_rules():
	"""Return the SCAN grammar with directions as standalone words.

	'right' and 'left' are primitives of their own and 'turn' maps to the empty
	output. 'opposite' and 'around' are therefore each written once, over a
	direction variable u2, and 'u1 u2' -> '[u2] [u1]' applies a bare direction.
	Rules are created in the order listed.
	"""
	primitives = (
		('run', 'RUN'),
		('jump', 'JUMP'),
		('walk', 'WALK'),
		('look', 'LOOK'),
		('turn', ''),
		('right', 'RTURN'),
		('left', 'LTURN'),
	)
	combinators = (
		('u1 opposite u2', '[u2] [u2] [u1]'),
		('u1 around u2', '[u2] [u1] [u2] [u1] [u2] [u1] [u2] [u1]'),
		('x1 and x2', '[x1] [x2]'),
		('x2 after x1', '[x1] [x2]'),
		('x1 twice', '[x1] [x1]'),
		('x1 thrice', '[x1] [x1] [x1]'),
		('u1 u2', '[u2] [u1]'),
	)
	return [Rule(lhs, rhs) for lhs, rhs in primitives + combinators]