def __call__(self, model_output, inputs, length, i_ex, name=None, targets=None, estimates=None):
    start_ind, end_ind = self.matcher.get_init_inds(model_output)
    if i_ex == 0:  # Only for the first element
        model_output.tree.compute_matching_dists(
            {},
            matching_fcn=self.matcher,
            left_parents=AttrDict(timesteps=start_ind),
            right_parents=AttrDict(timesteps=end_ind))

    name = 'images' if name is None else name
    estimates = torch.stack([
        node.subgoal[name][i_ex]
        for node in model_output.tree.depth_first_iter()
    ])
    leave = torch.stack([
        node.subgoal.c_n_prime[i_ex]
        for node in model_output.tree.depth_first_iter()
    ]).byte().any(1)
    return estimates[leave], None
def get_all_samples(self, model_output, inputs, length, name=None, targets=None, estimates=None):
    start_ind, end_ind = self.matcher.get_init_inds(model_output)
    # compute matching distributions over the whole tree
    model_output.tree.compute_matching_dists(
        {},
        matching_fcn=self.matcher,
        left_parents=AttrDict(timesteps=start_ind),
        right_parents=AttrDict(timesteps=end_ind))

    name = 'images' if name is None else name
    estimates = torch.stack([
        node.subgoal[name]
        for node in model_output.tree.depth_first_iter()
    ])
    leave = torch.stack([
        node.subgoal.c_n_prime
        for node in model_output.tree.depth_first_iter()
    ]).byte().any(-1)
    pruned_seqs = [
        estimates[:, i][leave[:, i]] for i in range(leave.shape[1])
    ]
    return pruned_seqs, None
def forward(self, context=None, x_prime=None, more_context=None, z=None):
    """
    :param context: to be fed at each timestep
    :param x_prime: observation at next step
    :param more_context: also to be fed at each timestep
    :param z: (optional) if not None, z is used directly and not sampled
    :return:
    """
    # TODO to get rid of more_context, make an interface that allows context structures
    output = AttrDict()
    output.p_z = self.prior(torch.zeros_like(x_prime))  # the input is only used to read the batch size atm
    if x_prime is not None:
        output.q_z = self.inf(self.inf_lstm(x_prime, context, more_context).output)

    if z is None:
        if self._sample_prior:
            z = Gaussian(output.p_z).sample()
        else:
            z = Gaussian(output.q_z).sample()

    pred_input = [z, context, more_context]
    output.x = self.gen_lstm(*pred_input).output
    return output
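# A minimal sketch of the `Gaussian(...).sample()` calls above, assuming the
# distribution wrapper splits its input tensor into mean and log-sigma halves
# (the class name and split convention are assumptions, not the codebase's
# exact API):
import torch

class GaussianSketch:
    """Diagonal Gaussian parameterized by a single [mu, log_sigma] tensor."""
    def __init__(self, params):
        self.mu, self.log_sigma = params.chunk(2, dim=-1)

    def sample(self):
        # reparameterization trick: mu + sigma * eps, eps ~ N(0, I)
        return self.mu + torch.exp(self.log_sigma) * torch.randn_like(self.mu)

# e.g. GaussianSketch(torch.zeros(8, 2 * 32)).sample() draws a standard-normal
# z of shape [8, 32], matching the zero-parameter prior above.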
def forward(self, inputs):
    self.get_timesteps(inputs)
    # add singleton spatial dims so the states can be fed to the conv head
    actions_pred = self.action_pred(inputs.state_t0[:, :, None, None],
                                    inputs.state_t1[:, :, None, None])
    output = AttrDict()
    output.actions = torch.squeeze(actions_pred)
    return output
def _plan(self, image, goal_image, step):
    print("planning at t{}".format(self.t))
    input_dict = AttrDict(
        I_0=self._env2planner(image),
        I_g=self._env2planner(goal_image),
        start_ind=torch.Tensor([0]).long(),
        end_ind=torch.Tensor([self._hp.params['max_seq_len'] - 1]).long())
    with self.planner.val_mode():
        planner_output = self.planner(input_dict)

    # perform pruning for the balanced tree
    image_plan, _ = self.planner.dense_rec.get_sample_with_len(
        0, self._hp.params['max_seq_len'], planner_output, input_dict, 'basic')

    # first image is a copy of the initial frame -> omit
    self.image_plan = image_plan[1:]
    self.action_plan = planner_output.actions.detach().cpu().numpy()[0] \
        if 'actions' in planner_output else None

    planner_output.dense_rec = AttrDict(images=image_plan[None])
    self.planner_outputs.append((step, planner_output))
    self.current_exec_step = 0

    if self.verbose:
        npy_to_gif(self.planner2npy_img(planner_output.dense_rec.images[0]),
                   self.log_dir_verb + '/plan_t{}'.format(self.t))
def forward(self, e0, eg):
    """Returns the logits of a OneHotCategorical distribution."""
    output = AttrDict()
    output.seq_len_logits = remove_spatial(self.p(e0, eg))
    output.seq_len_pred = OneHotCategorical(logits=output.seq_len_logits)
    return output
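# Usage sketch for the length predictor output above; `OneHotCategorical` is
# torch.distributions.OneHotCategorical, and the shapes are illustrative:
import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 30)              # batch of 4, max sequence length 30
dist = OneHotCategorical(logits=logits)
one_hot = dist.sample()                  # [4, 30], one-hot over possible lengths
seq_len = one_hot.argmax(dim=-1)         # predicted length index per example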
def act(self, t=None, i_tr=None, images=None, state=None, goal=None, goal_image=None):
    # Note: goal_image provides n (here 2) images, starting from the last image of the trajectory
    self.t = t
    self.i_tr = i_tr
    self.goal_image = goal_image

    if self.policy.has_image_input:
        inputs = AttrDict(
            I_0=self._preprocess_input(images[t]),
            I_g=self._preprocess_input(goal_image[-1] if len(goal_image.shape) > 4 else goal_image),
            hidden_var=self.hidden_var)
    else:
        current = state[-1:, :2]
        goal = goal[-1:, :2]
        inputs = AttrDict(I_0=current, I_g=goal, hidden_var=self.hidden_var)

    actions, self.hidden_var = self.policy(inputs)

    output = AttrDict()
    output.actions = actions.data.cpu().numpy()[0]
    return output
def forward(self, inputs, phase='train'):
    """Forward pass at training time."""
    if 'enc_traj_seq' not in inputs:
        enc_traj_seq, _ = self.encoder(inputs.traj_seq[:, 0]) if self._hp.train_first_action_only \
            else batch_apply(self.encoder, inputs.traj_seq)
        if self._hp.train_first_action_only:
            enc_traj_seq = enc_traj_seq[:, None]
    else:
        enc_traj_seq = inputs.enc_traj_seq  # use the precomputed encodings provided in inputs
    enc_traj_seq = enc_traj_seq.detach() if self.detach_enc else enc_traj_seq

    enc_goal, _ = self.encoder(inputs.I_g)
    n_dim = len(enc_goal.shape)
    fused_enc = torch.cat(
        (enc_traj_seq,
         enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))),
        dim=2)

    if self._hp.reactive:
        actions_pred = batch_apply(self.policy, fused_enc)
    else:
        actions_pred = self.policy(fused_enc)

    # remove last time step to match ground truth if training on full sequence
    actions_pred = actions_pred[:, :-1] if not self._hp.train_first_action_only else actions_pred

    output = AttrDict()
    output.actions = remove_spatial(actions_pred) if len(actions_pred.shape) > 3 else actions_pred
    return output
def loss(self, inputs, outputs, log_error_arr=False):
    losses = AttrDict()
    losses.kl = KLDivLoss2(self._hp.kl_weight)(
        outputs.q_z, outputs.p_z, log_error_arr=log_error_arr)
    return losses
def loss(self, outputs, targets, weights, pad_mask, weight, log_sigma):
    predictions = outputs.tree.bf.images
    gt_match_dists = outputs.gt_match_dists

    # Compute likelihood
    loss_val = batch_cdist(predictions, targets, reduction='sum')

    log_sigmas = log_sigma - WeightsHacker.hack_weights(
        torch.ones_like(loss_val)).log()
    n = np.prod(predictions.shape[2:])
    loss_val = 0.5 * loss_val * torch.pow(torch.exp(-log_sigmas), 2) \
        + n * (log_sigmas + 0.5 * np.log(2 * np.pi))

    # Weigh by matching probability
    match_weights = gt_match_dists
    match_weights = match_weights * pad_mask[:, None]  # Note: this is now unnecessary since both tree models handle it already
    loss_val = loss_val * match_weights * weights

    losses = AttrDict()
    losses.dense_img_rec = PenaltyLoss(weight, breakdown=2)(
        loss_val, log_error_arr=True, reduction=[-1, -2])
    return losses
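# The likelihood term above is a Gaussian negative log-likelihood with a
# learned log-sigma; a standalone sketch of the formula (assuming `err` is the
# squared error already summed over the n dimensions of each prediction):
import numpy as np
import torch

def gaussian_nll(err, log_sigma, n):
    """0.5 * err / sigma^2 + n * (log(sigma) + 0.5 * log(2*pi))"""
    return 0.5 * err * torch.exp(-2 * log_sigma) + n * (log_sigma + 0.5 * np.log(2 * np.pi))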
def get_data_config(self, conf_module):
    # get default data config
    path = os.path.join(
        get_dataset_path(conf_module.configuration['dataset_name']),
        'dataset_spec.py')
    data_conf_file = imp.load_source('dataset_spec', path)
    data_conf = AttrDict()
    data_conf.dataset_spec = AttrDict(data_conf_file.dataset_spec)

    # update with custom params if available
    update_data_conf = {}
    if hasattr(conf_module, 'data_config'):
        update_data_conf = conf_module.data_config
    elif conf_module.configuration.dataset_name is not None:
        update_data_conf = importlib.import_module(
            'gcp.datasets.configs.' + conf_module.configuration.dataset_name).config

    for key in update_data_conf:
        if key == "dataset_spec":
            data_conf.dataset_spec.update(update_data_conf.dataset_spec)
        else:
            data_conf[key] = update_data_conf[key]

    if 'fps' not in data_conf:
        data_conf.fps = 4
    return data_conf
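# `imp.load_source` is deprecated (and removed in Python 3.12); an equivalent
# importlib-based loader, should the `imp` dependency need replacing:
import importlib.util

def load_source(module_name, path):
    """Load a Python file as a module, mirroring imp.load_source."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module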
def forward(self, root, inputs):
    outputs = AttrDict()

    # TODO implement soft interpolation
    sg_times, sg_encs = [], []
    for segment in root:
        sg_times.append(segment.subgoal.ind)
        sg_encs.append(segment.subgoal.e_g_prime)
    sg_times = torch.stack(sg_times, dim=1)
    sg_encs = torch.stack(sg_encs, dim=1)

    # compute time difference weights
    seq_length = self._hp.max_seq_len
    target_ind = torch.arange(end=seq_length, dtype=sg_times.dtype)
    time_diffs = torch.abs(target_ind[None, None, :] - sg_times[:, :, None])
    weights = nn.functional.softmax(-time_diffs, dim=-1)

    # compute weighted sum outputs
    weighted_sg = weights[:, :, :, None, None, None] \
        * sg_encs.unsqueeze(2).repeat(1, 1, seq_length, 1, 1, 1)
    outputs.encodings = torch.sum(weighted_sg, dim=1)

    outputs.update(self._dense_decode(inputs, outputs.encodings, seq_length))
    return outputs
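# Standalone sketch of the time-distance weighting above (illustrative shapes,
# spatial dims dropped): every output step receives a softmax-weighted mix of
# subgoal encodings, weighted by how close each subgoal's time is to that step.
import torch
import torch.nn.functional as F

B, S, T, D = 2, 3, 10, 4                 # batch, subgoals, sequence length, enc dim
sg_times = torch.rand(B, S) * T          # predicted subgoal timesteps
sg_encs = torch.randn(B, S, D)           # subgoal encodings

target_ind = torch.arange(T, dtype=sg_times.dtype)
time_diffs = torch.abs(target_ind[None, None, :] - sg_times[:, :, None])  # [B, S, T]
weights = F.softmax(-time_diffs, dim=-1)                                  # [B, S, T]
encodings = torch.einsum('bst,bsd->btd', weights, sg_encs)                # [B, T, D]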
def reset(self, reset_state):
    super().reset()

    if reset_state is None:
        start_pos = self.env.mj2mw(self.state_sampler.sample(self._hp.init_pos))
        start_angle = 2 * np.pi * np.random.rand()
        goal_pos = self.env.mj2mw(self.state_sampler.sample(self._hp.goal_pos))
    else:
        start_pos = reset_state[:2]
        start_angle = reset_state[2]
        goal_pos = reset_state[-2:]

    reset_state = AttrDict(start_pos=start_pos,
                           start_angle=start_angle,
                           goal=goal_pos)

    img_obs = self.env.reset(reset_state)
    self.goal_pos = goal_pos
    qpos_full = np.concatenate((start_pos, np.array([start_angle])))

    obs = AttrDict(
        images=np.expand_dims(img_obs, axis=0),  # add camera dimension
        qpos_full=qpos_full,
        goal=goal_pos,
        env_done=False,
        state=np.concatenate((qpos_full, goal_pos)),
        topdown_image=self.render_pos_top_down(qpos_full, self.goal_pos))

    self._post_step(start_pos)
    self._initial_shortest_dist = self.comp_shortest_dist(start_pos, goal_pos)
    return obs, reset_state
def full_seq_forward(self, inputs):
    if 'model_enc_seq' in inputs:
        enc_seq_1 = inputs.model_enc_seq[:, 1:]
        if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
            enc_seq_0 = inputs.enc_traj_seq.reshape(
                inputs.enc_traj_seq.shape[:2] + (self._hp.nz_enc,))[:, :-1]
            enc_seq_0 = enc_seq_0[:, :enc_seq_1.shape[1]]
        else:
            enc_seq_0 = inputs.model_enc_seq[:, :-1]
    else:
        enc_seq = batch_apply(self.encoder, inputs.traj_seq)
        enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]

    if self.detach_enc:
        enc_seq_0 = enc_seq_0.detach()
        enc_seq_1 = enc_seq_1.detach()

    # TODO quite sure the concatenation is automatic
    actions_pred = batch_apply(self.action_pred,
                               torch.cat([enc_seq_0, enc_seq_1], dim=2))

    output = AttrDict()
    output.actions = actions_pred
    if 'actions' in inputs:
        output.action_targets = inputs.actions
        output.pad_mask = inputs.pad_mask
    return output
def apply_tree(self, tree, inputs):
    start_ind, end_ind = self.get_init_inds(inputs)
    tree.apply_fn({},
                  fn=self,
                  left_parents=AttrDict(timesteps=start_ind),
                  right_parents=AttrDict(timesteps=end_ind))
def __call__(self, inputs, subgoal, left_parent, right_parent):
    out = AttrDict()
    out.c_n = self.attentive_matching(inputs, subgoal)
    out.c_n_prime, out.cdf, out.p_n = self.propagate_matching(
        subgoal, left_parent, right_parent, out.c_n)
    out.ind = torch.argmax(out.c_n_prime, dim=1)
    return out
def make_traj(self, agent_data, obs, policy_out):
    traj = AttrDict()
    if not self.do_not_save_images:
        traj.images = obs['images']
    traj.states = obs['state']

    action_list = [action['actions'] for action in policy_out]
    traj.actions = np.stack(action_list, 0)

    traj.pad_mask = get_pad_mask(traj.actions.shape[0], self.max_num_actions)
    traj = pad_traj_timesteps(traj, self.max_num_actions)

    if 'robosuite_xml' in obs:
        traj.robosuite_xml = obs['robosuite_xml'][0]
    if 'robosuite_env_name' in obs:
        traj.robosuite_env_name = obs['robosuite_env_name'][0]
    if 'robosuite_full_state' in obs:
        # minimal state that contains all information to position entities in the env
        traj.robosuite_full_state = obs['robosuite_full_state']
    if 'regression_state' in obs:
        traj.regression_state = obs['regression_state']

    return traj
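# `get_pad_mask` is assumed to mark valid timesteps before padding; a minimal
# sketch consistent with how it is used above (the actual helper may differ):
import numpy as np

def get_pad_mask_sketch(n_steps, max_steps):
    """1.0 for real timesteps, 0.0 for padding; shape [max_steps]."""
    mask = np.zeros(max_steps, dtype=np.float32)
    mask[:n_steps] = 1.0
    return mask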
def get_default_params(self):
    params = AttrDict(
        normalize=True,
        activation=nn.LeakyReLU(0.2, inplace=True),
        normalization=self.builder.normalization,
        normalization_params=AttrDict())
    return params
def loss(self, inputs, outputs):
    losses = AttrDict()
    if 'existence_predictor' in outputs:
        losses.existence_predictor = BCELogitsLoss()(
            outputs.existence_predictor.existence,
            outputs.tree.df.match_dist.sum(2).float())
    return losses
def reconstruction_loss(self, inputs, outputs, weights):
    losses = AttrDict()
    outputs.soft_matched_estimates = self.criterion.get_soft_estimates(
        outputs.gt_match_dists, outputs.tree.bf.images)
    losses.update(self.criterion.loss(
        outputs, inputs.traj_seq, weights, inputs.pad_mask,
        self._hp.dense_img_rec_weight, self.decoder.log_sigma))
    return losses
def forward(self, root, inputs):
    # TODO implement stopping probability prediction
    # TODO make the low-level network not predict subgoals
    batch_size, time = self._hp.batch_size, self._hp.max_seq_len
    outputs = AttrDict()

    lstm_inputs = self._get_lstm_inputs(root, inputs)
    lstm_outputs = self.lstm(lstm_inputs, time)
    outputs.encodings = torch.stack(lstm_outputs, dim=1)
    outputs.update(self._dense_decode(inputs, outputs.encodings, time))
    return outputs
def forward(self, input, hidden_state, length=None):
    """
    :param input: tensor of shape batch x time x channels
    :return:
    """
    if length is None:
        length = input.shape[1]
    initial_state = AttrDict(hidden_state=hidden_state)
    outputs = super().forward(AttrDict(cell_input=input),
                              length=length,
                              initial_inputs=initial_state)
    return outputs
def assert_begin(inputs, initial_inputs, static_inputs):
    initial_inputs = initial_inputs or AttrDict()
    static_inputs = static_inputs or AttrDict()
    assert not (static_inputs.keys() & inputs.keys()), 'Static inputs and inputs overlap'
    assert not (static_inputs.keys() & initial_inputs.keys()), 'Static inputs and initial inputs overlap'
    assert not (inputs.keys() & initial_inputs.keys()), 'Inputs and initial inputs overlap'
    return initial_inputs, static_inputs
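# Usage sketch: the three dicts must have pairwise-disjoint keys (plain dicts
# work here too, since dict key views support set intersection):
inputs = {'cell_input': 0}
initial, static = assert_begin(inputs, {'hidden_state': 0}, {'context': 0})  # disjoint keys -> passes
try:
    assert_begin({'context': 0}, initial, static)
except AssertionError as e:
    print(e)  # 'Static inputs and inputs overlap'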
def forward(self, inputs):
    self.get_timesteps(inputs)
    enc = self.encoder.forward(
        torch.cat([inputs.img_t0, inputs.img_t1], dim=1))[0]

    output = AttrDict()
    out = self.action_pred(enc)
    if self._hp.pred_states:
        output.actions, output.states = torch.split(torch.squeeze(out), [2, 2], 1)
    else:
        output.actions = torch.squeeze(out)
    return output
def loss(self, inputs, model_output):
    losses = AttrDict()

    # action prediction loss
    n_actions = model_output.actions.shape[1]
    losses.action_reconst = L2Loss(1.0)(
        model_output.actions,
        inputs.actions[:, :n_actions],
        weights=broadcast_final(inputs.pad_mask[:, :n_actions], inputs.actions))

    return losses
def _get_lstm_inputs(self, root, inputs):
    """Assembles, for every output timestep, the start/goal encodings of the
    enclosing segment and the linear interpolation between them that
    conditions the dense LSTM decoder."""
    device = inputs.reference_tensor.device
    batch_size, time = self._hp.batch_size, self._hp.max_seq_len
    fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
    lstm_inputs = AttrDict()

    # collect start and end indices and values of all segments
    e_0s = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    e_gs = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
    start_inds = torch.zeros((batch_size, time), dtype=torch.float32, device=device)
    end_inds = torch.zeros((batch_size, time), dtype=torch.float32, device=device)
    reset_indicator = torch.zeros((batch_size, time), dtype=torch.uint8, device=device)

    for segment in root.full_tree():   # traverses the tree in breadth-first order
        if segment.depth == 0:         # leaf node
            start_ind = torch.ceil(segment.start_ind).type(torch.LongTensor)
            end_ind = torch.floor(segment.end_ind).type(torch.LongTensor)
            batchwise_assign(reset_indicator, start_ind, 1)
            # TODO iterating over batch must be gone
            for ex in range(self._hp.batch_size):
                if start_ind[ex] > end_ind[ex]:
                    continue  # happens if start and end floats have no int in between
                # +1 to include the end_ind frame
                e_0s[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_0[ex]
                e_gs[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_g[ex]
                start_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.start_ind[ex]
                end_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.end_ind[ex]

    # perform linear interpolation
    time_steps = torch.arange(time, dtype=torch.float, device=device)
    inter = (time_steps - start_inds) / (end_inds - start_inds + 1e-7)
    lstm_inputs.reset_indicator = reset_indicator
    lstm_inputs.cell_input = (e_gs - e_0s) * broadcast_final(inter, e_gs) + e_0s
    lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)
    return lstm_inputs
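# Standalone sketch of the interpolation at the end of _get_lstm_inputs: a
# timestep t inside segment [s, e] receives (1 - a) * e_0 + a * e_g, with
# a = (t - s) / (e - s). The clamp is only for this isolated illustration; in
# the method above, each step is already assigned to its enclosing segment.
import torch

s, e = 2.0, 7.0                                 # segment boundaries
e_0, e_g = torch.zeros(4), torch.ones(4)        # start/goal encodings (dim 4)
t = torch.arange(10, dtype=torch.float)
a = ((t - s) / (e - s + 1e-7)).clamp(0, 1)      # interpolation fraction per step
cell_input = (e_g - e_0)[None] * a[:, None] + e_0[None]   # [10, 4]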
def __init__(self, args_in=None, hyperparams=None):
    parser = argparse.ArgumentParser(description='run parallel data collection')
    parser.add_argument('experiment', type=str, help='experiment name')
    parser.add_argument('--nworkers', type=int, default=1, help='number of parallel worker threads')
    parser.add_argument('--gpu_id', type=int, default=0, help='the starting gpu_id')
    parser.add_argument('--ngpu', type=int, default=1, help='the number of gpus to use')
    parser.add_argument('--gpu', type=int, default=-1, help='the gpu to use')
    parser.add_argument('--nsplit', type=int, default=-1, help='number of splits')
    parser.add_argument('--isplit', type=int, default=-1, help='split id')
    parser.add_argument('--iex', type=int, default=-1, help='if different from -1, only run this example index')
    parser.add_argument('--data_save_postfix', type=str, default='', help='appended to the data_save_dir path')
    parser.add_argument('--nstart_goal_pairs', type=int, default=None, help='max number of start-goal pairs')
    parser.add_argument('--resume_from', type=int, default=None, help='trajectory index from which to resume generation')
    args = parser.parse_args(args_in)

    print("Resume from", args.resume_from)

    if args.gpu != -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    if hyperparams is None:
        hyperparams_file = args.experiment
        loader = importlib.machinery.SourceFileLoader('mod_hyper', hyperparams_file)
        spec = importlib.util.spec_from_loader(loader.name, loader)
        mod = importlib.util.module_from_spec(spec)
        loader.exec_module(mod)
        hyperparams = AttrDict(mod.config)

    self.args = args
    self.hyperparams = postprocess_hyperparams(hyperparams, args)
def __call__(self, inputs, subgoal, left_parent, right_parent):
    super().build_network()
    timesteps = self.comp_timestep(left_parent.timesteps,
                                   right_parent.timesteps,
                                   subgoal.fraction)
    return AttrDict(timesteps=timesteps)
def forward(self, input):
    """
    :param input: tensor of shape batch x time x channels
    :return:
    """
    return super().forward(AttrDict(cell_input=input),
                           length=input.shape[1]).output
def forward(self, *cell_input, **cell_kwinput):
    """
    At every time step, the input to the dense-reconstruction LSTM is a tuple of (last_state, e_0, e_g).
    :param cell_input: positional cell inputs
    :param cell_kwinput: keyword cell inputs, appended to the positional ones
    :return:
    """
    # TODO allow ConvLSTM
    if cell_kwinput:
        cell_input = cell_input + list(zip(*cell_kwinput.items()))[1]

    if self.hidden is None:
        self.reset()

    cell_input = concat_inputs(*cell_input)
    inp_extra_dim = []

    if not self._hp.use_conv_lstm:
        # TODO put in the embed module
        # keep trailing dimensions (these should all be of size 1)
        inp_extra_dim = list(cell_input.shape[2:])
        cell_input = cell_input.view(-1, self.input_size)

    embedded = self.embed(cell_input)
    h_in = embedded
    for i in range(self.n_layers):
        self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
        h_in = self.hidden[i][0]
    output = self.output(h_in)
    return AttrDict(output=output.view(list(output.shape) + inp_extra_dim))
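# Standalone illustration of the kwarg handling at the top of forward():
# keyword values are appended to the positional tuple in insertion order
# before everything is concatenated into a single cell input.
def flatten_call_args(*args, **kwargs):
    if kwargs:
        args = args + tuple(zip(*kwargs.items()))[1]  # tuple of kwarg values
    return args

assert flatten_call_args(1, 2, e_0=3, e_g=4) == (1, 2, 3, 4)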