def fit(self, data):
    """Fit each per-counter sub-model on the steps that carry its counter value.

    Args:
        data (DictTree): provides ``iput``, ``oput.sub`` and ``oput.arg``. Column
            ``self.cnt_idx`` of every ``iput`` row is read as that step's counter,
            and all three fields are indexed with lists of row positions.
    """
    # rows_by_cnt[c] collects the positions of the rows whose counter equals c
    rows_by_cnt = [[] for _ in range(self.max_cnt)]
    for row, features in enumerate(data.iput):
        rows_by_cnt[int(features[self.cnt_idx])].append(row)

    for cnt, rows in enumerate(rows_by_cnt):
        if not rows:
            # No step carries this counter value: fit on a single all-zero
            # sample of matching widths instead of an empty data set.
            subset = DictTree(
                iput=np.zeros((1, len(data.iput[0]))),
                oput=DictTree(
                    sub=np.zeros(1, np.int32),
                    arg=np.zeros((1, len(data.oput.arg[0]))),
                ),
            )
        else:
            subset = DictTree(
                iput=data.iput[rows],
                oput=DictTree(
                    sub=data.oput.sub[rows],
                    arg=data.oput.arg[rows],
                ),
            )
        self.models[cnt].fit(subset)
def step(self, obs):
    """Run the skill stack until it yields a sub-call that is not a skill.

    The skill on top of ``self.stack`` is stepped repeatedly. A ``None``
    sub-name pops the skill and hands its value to the caller below it; a
    sub-name found in ``self.skill_set`` is pushed as a new frame; any other
    sub-name ends the call and is returned.

    Args:
        obs: observation passed to the first skill step as its return value.

    Returns:
        tuple: ``(sub_name, sub_arg, DictTree(steps=...))`` for a sub-call that
        is not in ``self.skill_set``, or ``(None, None, DictTree(steps=...))``
        once the stack is empty. ``steps`` records every skill step taken.
    """
    trace = []
    # The first skill step sees the previous action's name, with the new
    # observation as that action's return value.
    ret_name, ret_val = self.last_act_name, obs
    while self.stack:
        top = self.stack[-1]
        skill = self.skill_set[top.name]
        sub_name, sub_arg = skill.step(top.arg, top.cnt, ret_name, ret_val)
        trace.append(DictTree(
            name=top.name,
            arg=top.arg,
            cnt=top.cnt,
            ret_name=ret_name,
            ret_val=ret_val,
            sub_name=sub_name,
            sub_arg=sub_arg,
        ))
        print("{}({}, {}, {}, {}) -> {}({})".format(
            top.name, top.arg, top.cnt, ret_name, ret_val, sub_name, sub_arg))

        if sub_name is None:
            # The skill returned: pop it and report its value to its caller.
            self.stack.pop()
            ret_name, ret_val = top.name, sub_arg
            continue

        calls_skill = sub_name in self.skill_set
        top.cnt += 1
        if not calls_skill:
            # Not a skill: remember the name and hand the call to our caller.
            self.last_act_name = sub_name
            return sub_name, sub_arg, DictTree(steps=trace)

        # Nested skill call: it starts with a fresh counter and no return value.
        self.stack.append(DictTree(name=sub_name, arg=sub_arg, cnt=0))
        ret_name = None
        ret_val = None

    self.last_act_name = None
    return None, None, DictTree(steps=trace)
def get_skillset(self):
    """Describe every entry of ``self.skill_set`` and ``self.actions``.

    Each entry becomes a DictTree with ``step=None``, the class's
    ``arg_in_len`` / ``ret_out_len`` and its optional attributes (falling back
    to ``"log_poly2"`` for ``model_name``, ``[]`` for ``sub_skill_names`` and
    ``None`` otherwise). Entries with sub-skills also get ``ret_in_len`` and
    ``arg_out_len`` derived from their sub-skills.

    Side effect: ``self.stack`` and ``self.last_act_name`` are reset to None.

    Returns:
        DictTree: the descriptions, keyed by skill/action name.
    """
    described = {}
    for name, cls in list(self.skill_set.items()) + list(self.actions.items()):
        described[name] = DictTree(
            step=None,
            model_name=getattr(cls, 'model_name', "log_poly2"),
            arg_in_len=cls.arg_in_len,
            max_cnt=getattr(cls, 'max_cnt', None),
            sub_skill_names=getattr(cls, 'sub_skill_names', []),
            ret_out_len=cls.ret_out_len,
            min_valid_data=getattr(cls, 'min_valid_data', None),
            sub_arg_accuracy=getattr(cls, 'sub_arg_accuracy', None),
        )
    to_return = DictTree(described)

    for entry in to_return.values():
        if not entry.sub_skill_names:
            continue
        # widest value any sub-skill can return to this skill
        entry.ret_in_len = max(
            to_return[sub_name].ret_out_len for sub_name in entry.sub_skill_names)
        # the output must fit both this skill's own return value and the
        # widest argument any of its sub-skills accepts
        widest_sub_arg = max(
            to_return[sub_name].arg_in_len for sub_name in entry.sub_skill_names)
        entry.arg_out_len = max(entry.ret_out_len, widest_sub_arg)

    self.stack = None
    self.last_act_name = None
    return to_return
def get_loss(self, batch):
    """Compute the loss of ``batch`` and report it as training statistics.

    Returns:
        DictTree: the value of ``self._get_loss(batch)`` under ``loss`` and
        again under ``per_step.loss``.
    """
    batch_loss = self._get_loss(batch)
    return DictTree(
        loss=batch_loss,
        per_step=DictTree(loss=batch_loss),
    )
def pre_action_callback(act_name, metadata): nonlocal memory, ret_name obs = env.observe() top = memory.stack[-1] trace.data.steps.append( DictTree( mem_in=memory, ret_name=ret_name, ret_val=empty_tensor, obs=self._get_value(obs, teacher=False), mem_out=DictTree(steps=[ DictTree( name=top.name, arg=self._get_value(top.arg, teacher=False), cnt=top.cnt, ret_name=ret_name, ret_val=empty_tensor, sub_name=act_name, sub_arg=empty_tensor, ) ], stack=memory.stack[:-1] + [DictTree(top, cnt=top.cnt + 1)]), act_name=act_name, act_arg=empty_tensor, )) memory = trace.data.steps[-1].mem_out ret_name = act_name
def _process(agent_skill, data):
    """Encode recorded steps of one skill as model input/output arrays.

    Each input row is the padded argument, the step counter, a one-hot code of
    the returning sub-skill (index 0 stands for ``None``) and the padded return
    value. The outputs are the index of the called sub-skill and its padded
    argument.

    Args:
        agent_skill: provides ``sub_skills`` and ``skill_model``.
        data: sequence of steps.

    Returns:
        DictTree: ``len``, ``iput`` and ``oput`` (``sub``, ``arg``).
    """
    # TODO: this could be more efficient
    sub_skill_names = [None]
    sub_skill_names.extend(sub_skill.skill_name for sub_skill in agent_skill.sub_skills)

    def encode_input(step):
        model = agent_skill.skill_model
        return (utils.pad(step.arg, model.arg_in_len)
                + [step.cnt]
                + utils.one_hot(sub_skill_names.index(step.ret_name), model.num_sub)
                + utils.pad(step.ret_val, model.ret_in_len))

    def encode_arg(step):
        return utils.pad(step.sub_arg, agent_skill.skill_model.arg_out_len)

    iput = np.asarray([encode_input(step) for step in data])
    sub = np.asarray([sub_skill_names.index(step.sub_name) for step in data])
    arg = np.asarray([encode_arg(step) for step in data])
    return DictTree(
        len=len(data),
        iput=iput,
        oput=DictTree(sub=sub, arg=arg),
    )
def _php_configs(self, env):
    """Build the configuration of every procedure in a complete call tree.

    Procedures are named ``P`` followed by the branch indices on the path from
    the root, so the root is ``P``. The tree has ``self.config.levels`` levels
    below the root with ``self.config.branch`` children per node, and every
    procedure may additionally call each of ``env.actions``.

    Returns:
        dict: procedure name -> DictTree with ``teacher``, ``sub_names`` and an
        ``mlp`` model configuration.
    """
    primitive_names = list(env.actions.keys())
    depth = self.config.levels
    fanout = self.config.branch
    call_graph = collections.defaultdict(list)

    # Fill in call graph for all non-primitive actions
    for level in range(depth):
        for path in itertools.product(range(fanout), repeat=level + 1):
            parent = ''.join(map(str, path[:-1]))
            child = ''.join(map(str, path))
            call_graph['P' + parent].append('P' + child)

    # Add primitive actions to call graph
    for level in range(depth + 1):
        for path in itertools.product(range(fanout), repeat=level):
            call_graph['P' + ''.join(map(str, path))].extend(primitive_names)

    configs = {}
    for php_name, sub_names in call_graph.items():
        configs[php_name] = DictTree(
            teacher=self.make_teacher(env),
            sub_names=sub_names,
            model=DictTree(
                name='mlp',
                arg_in_size=0,
                ret_out_size=0,
                sub_layers=self.mlp_layers,
                arg_layers=self.mlp_layers,
            ),
        )
    return configs
def catalog(config):
    """Instantiate the model selected by ``config.name``.

    Names are tested in this order:
        * contains ``|``                      -> ModelSelector(config)
        * ``t_<name>``                        -> TimeDependentModel wrapping ``<name>``
        * ``log_lin...``                      -> LogisticLinearModel
        * ``log_poly<degree>``                -> LogisticPolynomialModel
        * ``log_mlp<degree>[<h1>, <h2>, ...]`` -> LogisticMLPModel
        * ``table``                           -> TableModel(config)

    Args:
        config (DictTree)

    Raises:
        NotImplementedError: if the name matches none of the patterns.
    """
    name = config.name

    if '|' in name:
        return ModelSelector(config)

    if name.startswith('t_'):
        return TimeDependentModel(DictTree(
            name=name[len('t_'):],
            cnt_idx=config.arg_in_len,
            max_cnt=config.max_cnt,
            num_sub=config.num_sub,
        ))

    if name.startswith('log_lin'):
        return LogisticLinearModel(DictTree(
            num_sub=config.num_sub,
        ))

    if name.startswith('log_poly'):
        return LogisticPolynomialModel(DictTree(
            num_sub=config.num_sub,
            degree=int(name[len('log_poly'):]),
        ))

    if name.startswith('log_mlp'):
        # the degree sits between the prefix and '[', the sizes inside the brackets
        bracket = name.index('[')
        sizes = name[bracket + 1:-1].split(', ')
        return LogisticMLPModel(DictTree(
            num_sub=config.num_sub,
            degree=int(name[len('log_mlp'):bracket]),
            hidden_sizes=[int(size) for size in sizes],
        ))

    if name == 'table':
        return TableModel(config)

    raise NotImplementedError(name)
def __init__(self, config):
    """Assemble the three stages: ``pre`` module, LSTM and ``post`` MLP.

    Args:
        config: provides ``pre`` and ``post`` sub-configurations plus
            ``features_in_size``, ``features_out_size``, ``rnn_layers`` and
            ``batch_first``. The ``pre`` output size and the ``post`` input
            size are set here to match the LSTM's input and hidden sizes.
    """
    super().__init__()

    pre_config = config.pre | DictTree(oput_size=config.features_in_size)
    self.pre = mnist.MNISTModule.Module(pre_config)

    self.lstm = torch.nn.LSTM(
        input_size=config.features_in_size,
        hidden_size=config.features_out_size,
        num_layers=config.rnn_layers,
        batch_first=config.batch_first,
        bidirectional=False,
    )

    post_config = config.post | DictTree(iput_size=config.features_out_size)
    self.post = mlp.make_mlp(post_config)
def evaluate(self, traces):
    """Compute the summed error of ``traces`` with the module in evaluation mode.

    ``self._error(trace)`` is awaited for every trace concurrently on the
    current asyncio event loop, with gradient tracking disabled.

    Args:
        traces: iterable of traces accepted by ``self._error``.

    Returns:
        DictTree: ``per_trace.error`` holds the sum of the per-trace errors.
    """
    self.eval()  # set to evaluation mode
    try:
        with torch.no_grad():
            error = asyncio.get_event_loop().run_until_complete(
                asyncio.gather(*[self._error(trace) for trace in traces]))
    finally:
        # Go back to training mode even when an error computation raises, so a
        # failed evaluation cannot leave the module stuck in evaluation mode.
        self.train()  # set to training mode
    return DictTree(per_trace=DictTree(error=sum(error)))
def __init__(self, config):
    """Describe all skills and actions and, if requested, load learned skills.

    ``self.skillset`` maps each class name in ``self.skills + self.actions`` to
    a DictTree of its properties. A class's own ``step`` is used only when
    ``config.rollable`` is set. When the agent is rollable and not a teacher,
    the ``step`` of every skill that has sub-skills is replaced by the result
    of ``load_skill`` from ``config.model_dirname``.
    """
    super(HierarchicalAgent, self).__init__(config)

    described = {}
    for skill in self.skills + self.actions:
        if config.rollable:
            step = getattr(skill, 'step', None)
        else:
            step = None
        described[skill.__name__] = DictTree(
            step=step,
            model_name=getattr(skill, 'model_name', self.default_model_name),
            arg_in_len=skill.arg_in_len,
            max_cnt=getattr(skill, 'max_cnt', None),
            sub_skill_names=getattr(skill, 'sub_skill_names', []),
            ret_out_len=skill.ret_out_len,
            min_valid_data=getattr(skill, 'min_valid_data', None),
            sub_arg_accuracy=getattr(skill, 'sub_arg_accuracy', None),
        )
    self.skillset = DictTree(described)

    for entry in self.skillset.values():
        if not entry.sub_skill_names:
            continue
        # widest value any sub-skill can return to this skill
        entry.ret_in_len = max(
            self.skillset[sub_name].ret_out_len for sub_name in entry.sub_skill_names)
        # the output must fit the skill's own return value and the widest
        # argument any of its sub-skills accepts
        widest_sub_arg = max(
            self.skillset[sub_name].arg_in_len for sub_name in entry.sub_skill_names)
        entry.arg_out_len = max(entry.ret_out_len, widest_sub_arg)

    if config.rollable and not config.teacher:
        for skill_name, entry in self.skillset.items():
            if entry.sub_skill_names:
                entry.step = load_skill(config.model_dirname, skill_name, entry)

    self.stack = None
    self.last_act_name = None
def batchify(batch):
    """Merge the arguments of several jobs into the arguments of one call.

    Tensor arguments are concatenated across all jobs with ``torch.cat``; any
    other argument is taken from the first job only.

    Args:
        batch: non-empty sequence of jobs with ``args`` and ``meta.num_args``.

    Returns:
        tuple: the list ``[arg0, arg1, ...]`` of merged positional arguments
        and the merged ``kwargs`` entry (``{}`` when there is none).
    """
    first = batch[0]
    merged = DictTree()
    for key, value in first.args.allitems():
        if not isinstance(value, torch.Tensor):
            merged[key] = value
            continue
        merged[key] = torch.cat([job.args[key] for job in batch])

    positional = [merged[f'arg{i}'] for i in range(first.meta.num_args)]
    return positional, merged.get('kwargs', {})
def finalize(self, phps, actions, config=None):
    """Finalize with CategoricalMNISTModule as both the ``sub`` and ``arg`` module.

    Both module configurations carry ``p_img_starts`` and ``q_img_starts``
    from ``self.config``. They are combined with ``config`` using ``|``
    (presumably the right-hand side wins on conflicts; confirm against
    DictTree) before being handed to the parent implementation.
    """
    base_config = config or DictTree()

    image_starts = DictTree(
        p_img_starts=self.config.p_img_starts,
        q_img_starts=self.config.q_img_starts,
    )
    module_configs = DictTree(
        sub=image_starts | DictTree(module_cls=CategoricalMNISTModule),
        arg=image_starts | DictTree(module_cls=CategoricalMNISTModule),
    )
    super().finalize(phps, actions, base_config | module_configs)
def __init__(self, env, config):
    """Initialise the agent; non-teacher agents also get an RNN over observations.

    ``config`` is combined with ``DEFAULT_CONFIG`` before reaching the parent
    constructor. For a non-teacher agent, ``self.act_logits`` is built by
    ``rnn.catalog`` with an input size of ``env.obs_size`` and one output per
    entry of ``self.act_names``, combined with ``self.config.rnn``.
    """
    super().__init__(env, DEFAULT_CONFIG | config)
    if not self.config.teacher:
        size_config = DictTree(
            pre=DictTree(iput_size=env.obs_size),
            post=DictTree(oput_size=len(self.act_names)),
        )
        self.act_logits = rnn.catalog(size_config | self.config.rnn)
    self._opt = None
def _make_module(config):
    """Build an MLP whose input size depends on ``config.posterior``.

    The input size is ``config.q_iput_size`` when ``config.posterior`` is
    truthy and ``config.p_iput_size`` otherwise; ``layers`` and ``oput_size``
    are copied from ``config``.
    """
    cfg = DictTree(
        layers=config.layers,
        oput_size=config.oput_size,
    )
    cfg.iput_size = config.q_iput_size if config.posterior else config.p_iput_size
    return make_mlp(cfg)
def train(config):
    """Train one agent per task and pickle the results for the last task.

    Every task but the last is first trained on its own in ``independent``
    mode without validation. The last task is then trained once per non-empty
    combination of the ``validation``, ``training`` and ``independent`` modes
    (only ``independent`` when ``config.independent`` is set), with the
    earlier agents passed along. The results, keyed by the ``+``-joined mode
    names, are pickled under ``results/<domain>/<tasks>/``.

    Args:
        config: provides ``eval``, ``model``, ``domain``, ``hardware``,
            ``tasks``, ``data``, ``independent`` and ``full_batch``.

    Raises:
        OSError: if the results directory cannot be created.
    """
    all_agents = [
        agents.catalog(
            DictTree(evaluation=config.eval,
                     model_dirname=config.model,
                     domain_name=config.domain,
                     task_name=task_name,
                     hardware_name=config.hardware,
                     rollable=False,
                     teacher=False)) for task_name in config.tasks
    ]
    data_dirname = "{}/{}".format(config.data, config.domain)
    for agent in all_agents[:-1]:
        client.delete(agent)
        _train(
            data_dirname, agent, [],
            DictTree(modes=['independent'],
                     batch_size=None,
                     validate=False,
                     model_dirname="model/{}/{}".format(
                         config.domain, agent.task_name)))
    results = DictTree()
    if config.independent:
        modes_list = [['independent']]
    else:
        modes_list = itertools.product(['validation', ''], ['training', ''],
                                       ['independent', ''])
    for modes in modes_list:
        # drop the '' placeholders; an all-empty combination trains nothing
        modes = [mode for mode in modes if mode]
        if not modes:
            continue
        print("Training with modes: {}".format(', '.join(modes)))
        client.delete(all_agents[-1])
        results['+'.join(modes)] = _train(
            data_dirname, all_agents[-1], all_agents[:-1],
            DictTree(modes=modes,
                     batch_size=(None if config.full_batch else 1),
                     validate=True,
                     model_dirname="model/{}/{}_{}".format(
                         config.domain, ".".join(config.tasks),
                         "+".join(modes))))

    results_dirname = "results/{}/{}".format(config.domain, ".".join(config.tasks))
    try:
        os.makedirs(results_dirname)
    except OSError:
        # An already existing directory is fine; anything else (permissions,
        # a file in the way, ...) is a real error and must not be hidden.
        if not os.path.isdir(results_dirname):
            raise
    time_stamp = time.strftime("%Y-%m-%d %H-%M-%S", time.gmtime())
    results_fn = "{}/{}.{}.pkl".format(results_dirname, all_agents[-1].task_name, time_stamp)
    # the context manager closes (and flushes) the file even if pickling fails
    with open(results_fn, 'wb') as results_file:
        pickle.dump(results, results_file, protocol=2)
def finalize(self, phps, actions, config=None):
    """Finalize with CategoricalMLPModule as both the ``sub`` and ``arg`` module.

    The ``sub`` module uses ``self.config.sub_layers`` and the ``arg`` module
    ``self.config.arg_layers``. They are combined with ``config`` using ``|``
    (presumably the right-hand side wins on conflicts; confirm against
    DictTree) before being handed to the parent implementation.
    """
    base_config = config or DictTree()
    module_configs = DictTree(
        sub=DictTree(
            module_cls=CategoricalMLPModule,
            layers=self.config.sub_layers,
        ),
        arg=DictTree(
            module_cls=CategoricalMLPModule,
            layers=self.config.arg_layers,
        ),
    )
    super().finalize(phps, actions, base_config | module_configs)
def _split_data(data, idxs):
    """Select the entries ``idxs`` from every field of a data set.

    Args:
        data (DictTree): provides ``iput``, ``oput.sub`` and ``oput.arg``.
        idxs: index applied to each of the three fields.

    Returns:
        DictTree: same layout as ``data``, holding only the selected entries.
    """
    iput = data.iput[idxs]
    oput = DictTree(
        sub=data.oput.sub[idxs],
        arg=data.oput.arg[idxs],
    )
    return DictTree(iput=iput, oput=oput)
def get_loss(self, batch):
    """Sum the per-trace losses of ``batch``, each trace seeing its RNN context.

    All traces are packed into one sequence batch and run through
    ``self.ctx`` on ``self.device``; the padded result is split back per trace
    and attached as ``ctx``. The per-trace ``_get_loss`` coroutines are then
    run concurrently on the current asyncio event loop.

    NOTE(review): ``pack_sequence`` is called with its defaults, so ``batch``
    is presumably ordered by decreasing trace length; confirm with callers.

    Returns:
        DictTree: ``loss`` is the summed loss; ``per_step`` holds every
        per-step statistic summed over traces (computed without gradients).
    """
    sequences = [trace.data.all for trace in batch]
    packed_batch = rnn_utils.pack_sequence(sequences)
    packed_ctx, _ = self.ctx(packed_batch.to(self.device))
    padded_ctx, _ = rnn_utils.pad_packed_sequence(packed_ctx.cpu(), batch_first=True)

    pending = []
    for trace, ctx in zip(batch, padded_ctx):
        pending.append(self._get_loss(trace | DictTree(ctx=ctx)))
    loop = asyncio.get_event_loop()
    all_stats = loop.run_until_complete(asyncio.gather(*pending))

    stats = DictTree(
        loss=torch.stack([trace_stats.loss for trace_stats in all_stats]).sum(),
        per_step=DictTree(),
    )
    with torch.no_grad():
        # the first trace's statistics decide which keys are aggregated
        for key, _ in all_stats[0].per_step.allitems():
            values = [trace_stats.per_step[key] for trace_stats in all_stats]
            stats.per_step[key] = torch.stack(values).sum()
    return stats
async def forward(self, iput):
    """
    Sample or score an output under the prior module p and, optionally, the posterior module q.

    iput:
        p = (p_iput, p_mask)
        q = (q_iput, q_mask) [optional]
        oput_size [optional]
        true_oput [optional]
        eval_oput [optional]

    modes:
        no q + no true_oput + no eval_oput = rollout:
            sample p(oput | iput, mask)
        no q + no true_oput + eval_oput = evaluate:
            sample p(oput | iput, mask), and compute error(oput, eval_oput)
        no q + true_oput + no eval_oput = get_loss (act_arg):
            compute -log p(oput | iput, mask)
        q + true_oput + no eval_oput = get_loss (annotated):
            compute -log p(oput | iput, mask) - log q(oput | iput, mask)
        q + no true_oput + no eval_oput = get_loss (unannotated):
            sample q(oput | iput, mask), and compute D[q(. | iput, mask) || p(. | iput, mask)]
            and log q(oput | iput, mask)

    Every loss additionally carries ENTROPY_WEIGHT times the negative entropy of each
    distribution that was evaluated.

    res:
        oput
        error [in evaluate]
        loss [in get_loss]
        log_p [in get_loss]
        log_q [in get_loss]
    """
    p_log_prob = self._get_log_prob(self.p_module, iput.p_iput, iput.get('p_mask'), iput.get('oput_size'))
    if 'q_iput' in iput:
        q_log_prob = self._get_log_prob(self.q_module, iput.q_iput, iput.get('q_mask'), iput.get('oput_size'))
        # NOTE(review): the default of get() is evaluated eagerly, so a sample is drawn
        # from q even when true_oput is supplied and the sample is then discarded.
        oput = iput.get('true_oput', self._sample(q_log_prob))
        log_p = self._log_prob(p_log_prob, oput)
        log_q = self._log_prob(q_log_prob, oput)
        if 'true_oput' in iput:
            # annotated: negative log-likelihood of the given output under both p and q
            loss = -log_p - log_q
        else:
            # unannotated: divergence between q and p
            loss = self._dkl(q_log_prob, p_log_prob)
        res = DictTree(
            oput=oput,
            # TODO: have configurable entropy_weight
            loss=loss + ENTROPY_WEIGHT * (self._neg_entropy(p_log_prob) + self._neg_entropy(q_log_prob)),
            log_p=log_p,
            log_q=log_q,
        )
    else:
        # NOTE(review): as above, a sample is drawn from p even when true_oput is supplied.
        oput = iput.get('true_oput', self._sample(p_log_prob))
        res = DictTree(oput=oput, )
        if 'true_oput' in iput:
            log_p = self._log_prob(p_log_prob, iput.true_oput)
            res.loss = -log_p + ENTROPY_WEIGHT * self._neg_entropy(
                p_log_prob)
            res.log_p = log_p
            # no posterior in this mode: report a zero log_q shaped like log_p
            res.log_q = torch.zeros_like(log_p)
    if 'eval_oput' in iput:
        res.error = self._error(oput, iput.eval_oput)
    return res
def validate(model, data, sub_arg_accuracy=None):
    """Compare a model's predictions on ``data`` with the recorded outputs.

    Args:
        model: object whose ``predict(iput)`` returns ``sub`` and ``arg``.
        data (DictTree): provides ``iput``, ``oput.sub`` and ``oput.arg``.
        sub_arg_accuracy: largest accepted per-column root-mean-square
            argument error, or None to get the raw statistics back.

    Returns:
        Without ``sub_arg_accuracy``: a DictTree with ``data_len``,
        ``sub_corr`` (all sub predictions correct) and ``arg_mse`` (squared
        argument errors summed over the first axis; despite the name it is a
        sum, not a mean). With it: whether all sub predictions are correct
        and every argument column's RMSE is within ``sub_arg_accuracy``.
    """
    pred = model.predict(data.iput)
    sub_corr = (pred.sub == data.oput.sub).all()
    arg_mse = ((pred.arg - data.oput.arg) ** 2).sum(0)

    if DEBUG:
        arg_rmse = (arg_mse / len(data.iput)) ** .5
        for shown in (data.iput, data.oput.sub, pred.sub, data.oput.arg, pred.arg,
                      sub_corr, arg_rmse):
            print(shown)

    if sub_arg_accuracy is not None:
        arg_corr = ((arg_mse / len(data.iput)) ** .5 <= sub_arg_accuracy).all()
        return sub_corr and arg_corr

    return DictTree(data_len=len(data.iput), sub_corr=sub_corr, arg_mse=arg_mse)
def __init__(self, config):
    """Add the hardware's skills to the agent and, when evaluating, load models.

    The skill set of the hardware named ``config.hardware_name`` is merged
    into ``self.skill_set``, and ``self.skillset`` becomes a plain dict of all
    skills and actions. Skills with sub-skills get ``ret_in_len`` and
    ``arg_out_len`` derived from their sub-skills. Learned steps are loaded
    only when ``config.model_dirname`` is not the string ``'None'`` and
    ``config.evaluation`` is the string ``'True'``.
    """
    super().__init__(config)
    hw = hws.catalog(DictTree(hardware_name=config.hardware_name))
    self.skill_set.update(hw.skill_set)

    combined = {}
    for name, skill in list(self.skill_set.items()) + list(self.actions.items()):
        combined[name] = skill
    self.skillset = combined

    for skill in list(self.skill_set.values()):
        if not skill.sub_skill_names:
            continue
        # widest value any sub-skill can return to this skill
        skill.ret_in_len = max(
            self.skillset[sub_name].ret_out_len for sub_name in skill.sub_skill_names)
        # the output must fit the skill's own return value and the widest
        # argument any of its sub-skills accepts
        widest_sub_arg = max(
            self.skillset[sub_name].arg_in_len for sub_name in skill.sub_skill_names)
        skill.arg_out_len = max(skill.ret_out_len, widest_sub_arg)

    # Both settings are compared as strings. A teacher check
    # ("and not config.teacher") is deliberately not applied here.
    has_models = not (config.model_dirname == 'None')
    if has_models and (config.evaluation == 'True'):
        for skill_name, skill in self.skill_set.items():
            if skill.sub_skill_names:
                skill.step = self.load_skill(config.model_dirname, skill_name, skill)
def _validate(agent_skill, shared_data, validate=True, model_dirname=None):
    """Fit a model on ``shared_data`` and accept it if it validates on the skill's data.

    Args:
        agent_skill: provides ``skill_model``, ``sub_arg_accuracy``, ``data``
            and ``skill_name``.
        shared_data: data set handed to ``model.fit`` as is.
        validate: when false the fitted model is accepted unconditionally.
        model_dirname: directory in which an accepted model is pickled as
            ``<skill_name>.pkl``; nothing is saved when None.

    Returns:
        The validation verdict (True when ``validate`` is false).

    Raises:
        OSError: if ``model_dirname`` cannot be created.
    """
    model = models.catalog(
        DictTree(
            name=agent_skill.skill_model.name,
            arg_in_len=agent_skill.skill_model.arg_in_len,
            max_cnt=agent_skill.skill_model.max_cnt,
            num_sub=agent_skill.skill_model.num_sub,
            sub_arg_accuracy=agent_skill.sub_arg_accuracy,
        ))
    model.fit(shared_data)
    if validate:
        valid_data = _process(agent_skill, agent_skill.data)
        validated = models.validate(model, valid_data, agent_skill.sub_arg_accuracy)
    else:
        validated = True
    if validated:
        agent_skill.skill_model.model = model
        if model_dirname is not None:
            try:
                os.makedirs(model_dirname)
            except OSError:
                # An already existing directory is fine; any other failure is
                # a real error and must not be silently ignored.
                if not os.path.isdir(model_dirname):
                    raise
            model_fn = "{}/{}.pkl".format(model_dirname, agent_skill.skill_name)
            # the context manager closes (and flushes) the file even on error
            with open(model_fn, 'wb') as model_file:
                pickle.dump(model, model_file, protocol=2)
    return validated
def _train(agent_skill, shared_data, validate=True, model_dirname=None):
    """Cross-validate a model for one skill and, if accepted, fit it on all data.

    With ``validate`` set, the skill's own data is split into at most
    ``NUM_FOLDS`` shuffled folds; each fold's model is fitted on the remaining
    own data plus ``shared_data`` and validated on the held-out part. An
    accepted model is refitted on the own and shared data together, stored on
    ``agent_skill.skill_model.model`` and optionally pickled.

    Args:
        agent_skill: provides ``skill_model``, ``sub_arg_accuracy``, ``data``
            and ``skill_name``.
        shared_data: extra training steps, or None for none. Neither it nor
            ``agent_skill.data`` is modified.
        validate: when false the model is accepted without cross-validation.
        model_dirname: directory in which an accepted model is pickled as
            ``<skill_name>.pkl``; nothing is saved when None.

    Returns:
        The validation verdict (True when ``validate`` is false).

    Raises:
        OSError: if ``model_dirname`` cannot be created.
    """
    model = models.catalog(DictTree(
        name=agent_skill.skill_model.name,
        arg_in_len=agent_skill.skill_model.arg_in_len,
        max_cnt=agent_skill.skill_model.max_cnt,
        num_sub=agent_skill.skill_model.num_sub,
        sub_arg_accuracy=agent_skill.sub_arg_accuracy,
    ))
    # Normalise once so that both the folds and the final fit accept None.
    shared = [] if shared_data is None else list(shared_data)
    own_data = agent_skill.data
    if validate:
        num_folds = min(len(own_data), NUM_FOLDS)
        # shuffle is passed by keyword: it is keyword-only in current
        # scikit-learn (assuming ms is sklearn.model_selection)
        kf = ms.KFold(num_folds, shuffle=True)
        validation = []
        for new_train_idxs, valid_idxs in kf.split(own_data):
            train_data = _process(agent_skill, [own_data[idx] for idx in new_train_idxs] + shared)
            valid_data = _process(agent_skill, [own_data[idx] for idx in valid_idxs])
            model.fit(train_data)
            validation.append(models.validate(model, valid_data))
        validated = models.total_validation(validation, agent_skill.sub_arg_accuracy)
    else:
        validated = True
    if validated:
        # Build a new list: extending agent_skill.data in place would leak the
        # shared steps into the skill's own data.
        all_data = _process(agent_skill, list(own_data) + shared)
        model.fit(all_data)
        agent_skill.skill_model.model = model
        if model_dirname is not None:
            try:
                os.makedirs(model_dirname)
            except OSError:
                # An already existing directory is fine; any other failure is
                # a real error and must not be silently ignored.
                if not os.path.isdir(model_dirname):
                    raise
            model_fn = "{}/{}.pkl".format(model_dirname, agent_skill.skill_name)
            # the context manager closes (and flushes) the file even on error
            with open(model_fn, 'wb') as model_file:
                pickle.dump(model, model_file, protocol=2)
    return validated
def observe(self, expert=False):
    """
    Observe environment.

    The ``expert`` flag is accepted but not used.

    :return: DictTree whose ``value`` and ``expert_value`` are both the
        concatenation of ``self.last_obs`` (presumably a one-hot encoding of
        the current floor; confirm where ``last_obs`` is set)
    """
    current = torch.cat(self.last_obs)
    return DictTree(value=current, expert_value=current)
def __init__(self, env, config):
    """Initialise the agent; non-teacher agents also get a context RNN.

    The context RNN ``self.ctx`` is built by ``rnn.catalog`` from
    ``self.posterior_rnn_config`` with ``batch_first=True``. Its input size is
    the sum of two action-name slots (``len(self.act_names)`` each),
    ``env.ret_out_size``, ``env.obs_size`` and ``env.arg_in_size``.
    """
    super().__init__(env, config)
    if not self.config.teacher:
        num_acts = len(self.act_names)
        part_sizes = (num_acts, env.ret_out_size, env.obs_size, num_acts, env.arg_in_size)
        observable_size = sum(part_sizes)
        ctx_config = DictTree(features_in_size=observable_size, batch_first=True)
        self.ctx = rnn.catalog(ctx_config | self.posterior_rnn_config)
    self._opt = None
def delete(domain_name, agent_name):
    """Delete everything stored for one agent of one domain and commit.

    The agent's SubSkill rows are removed first, then its AgentSkill rows and
    finally every SkillModel that no AgentSkill refers to any more.

    Returns:
        DictTree: ``deleted`` holds ``agent_name``.
    """
    if DEBUG:
        print("delete({}, {})".format(domain_name, agent_name))

    agent_sub_skills = SubSkill.query.filter(
        SubSkill.domain_name == domain_name).filter(
        SubSkill.agent_name == agent_name)
    agent_sub_skills.delete()

    agent_skills = AgentSkill.query.filter(
        AgentSkill.domain_name == domain_name).filter(
        AgentSkill.agent_name == agent_name)
    agent_skills.delete()

    orphaned_models = SkillModel.query.filter(~SkillModel.agent_skills.any())
    orphaned_models.delete(synchronize_session='fetch')

    db.session.commit()
    return DictTree(deleted=agent_name)
async def _get_loss(self, trace):
    """Replay one trace through the module and accumulate its loss.

    For an annotated trace the recorded memories are used: the replay starts
    from the first step's ``mem_in`` and each step's ``mem_out`` is passed to
    the module. Otherwise the replay starts from ``self.reset`` on the trace's
    ``init_arg``, and the memories produced by the module are written back
    onto the steps.

    Args:
        trace: provides ``metadata`` (``annotated``, ``init_arg``),
            ``data.steps`` and one ``ctx`` entry per step.

    Returns:
        DictTree: ``loss`` to optimise and ``per_step`` statistics
        (``score``, ``log_p``, ``log_q``), each summed over the trace.
    """
    # TODO: time stats
    loss = []
    log_p = []
    log_q = []
    if trace.metadata.annotated:
        memory = trace.data.steps[0].mem_in
    else:
        memory = self.reset(DictTree(value=trace.metadata.init_arg))
    # zip stops at the shorter of the two, so surplus ctx entries are ignored
    for step, ctx in zip(trace.data.steps, trace.ctx):
        iput = DictTree(
            mem_in=memory,
            ret_name=step.ret_name,
            ret_val=step.ret_val,
            obs=step.obs,
            ctx=ctx,
            act_name=step.act_name,
            act_arg=step.act_arg,
        )
        if trace.metadata.annotated:
            iput.mem_out = step.mem_out
        oput = await self(iput)
        # a single step may contribute several loss / log-probability terms
        loss.extend(oput.loss)
        log_p.extend(oput.log_p)
        log_q.extend(oput.log_q)
        if not trace.metadata.annotated:
            # record the memories used and produced on the trace itself
            step.mem_in = memory
            step.mem_out = oput.mem_out
        memory = oput.mem_out
    if trace.metadata.annotated:
        loss = torch.stack(loss).sum()
        with torch.no_grad():
            return DictTree(
                per_step=DictTree(
                    score=loss,
                    log_p=torch.stack(log_p).sum(),
                    log_q=torch.stack(log_q).sum(),
                ),
                loss=loss,
            )
    else:
        # score-function trick
        loss = torch.stack(loss)
        score = loss.sum()
        log_q = torch.stack(log_q)
        # every loss term after the first is detached and weighted by the
        # cumulative log_q of all terms before it
        loss = score + (loss.detach()[1:] * log_q[:-1].cumsum(0)).sum()
        with torch.no_grad():
            return DictTree(
                per_step=DictTree(
                    score=score,
                    log_p=torch.stack(log_p).sum(),
                    log_q=log_q.sum(),
                ),
                loss=loss,
            )
def train(agent, config):
    """Send a PUT request for ``agent`` to the server with ``config`` as body.

    ``config`` is serialised to JSON with ``MyEncoder`` and sent to
    ``<SERVER_URL>/agent/<domain_name>/<task_name>/``.

    Returns:
        DictTree: the decoded JSON body of the server's response.
    """
    payload = json.dumps(config, cls=MyEncoder)
    url = '{}/agent/{}/{}/'.format(SERVER_URL, agent.domain_name, agent.task_name)
    response = req.put(url, data=payload)
    return DictTree(response.json())
def __init__(self, config):
    """Create one candidate model per ``|``-separated name in ``config.name``.

    Args:
        config (DictTree): provides ``name`` and ``sub_arg_accuracy``. Each
            candidate is built by ``catalog`` from ``config`` with ``name``
            replaced by the candidate's own name.
    """
    self.sub_arg_accuracy = config.sub_arg_accuracy
    candidates = []
    for model_name in config.name.split('|'):
        candidates.append(catalog(config | DictTree(name=model_name)))
    self.models = candidates
    self.selector = None