Example #1
    def __call__(self,
                 model_output,
                 inputs,
                 length,
                 i_ex,
                 name=None,
                 targets=None,
                 estimates=None):
        start_ind, end_ind = self.matcher.get_init_inds(model_output)
        if i_ex == 0:
            # Only for the first element
            model_output.tree.compute_matching_dists(
                {},
                matching_fcn=self.matcher,
                left_parents=AttrDict(timesteps=start_ind),
                right_parents=AttrDict(timesteps=end_ind))

        name = 'images' if name is None else name
        estimates = torch.stack([
            node.subgoal[name][i_ex]
            for node in model_output.tree.depth_first_iter()
        ])
        leave = torch.stack([
            node.subgoal.c_n_prime[i_ex]
            for node in model_output.tree.depth_first_iter()
        ]).byte().any(1)
        return estimates[leave], None
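
Every example on this page builds its outputs with AttrDict, a dictionary whose keys can also be read and written as attributes. The exact class comes from the examples' own utility code; the minimal sketch below (a hypothetical re-implementation, not the library's class) only illustrates the access pattern the examples rely on.

    class AttrDict(dict):
        """Dictionary whose keys are also accessible as attributes."""
        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError as e:
                raise AttributeError(key) from e

        def __setattr__(self, key, value):
            self[key] = value

    # attribute access and item access are interchangeable
    out = AttrDict(timesteps=3)
    out.images = 'frames'
    assert out['images'] == out.images and out.timesteps == 3
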
Example #2
    def get_all_samples(self,
                        model_output,
                        inputs,
                        length,
                        name=None,
                        targets=None,
                        estimates=None):
        start_ind, end_ind = self.matcher.get_init_inds(model_output)

        # compute matching distributions over the tree
        model_output.tree.compute_matching_dists(
            {},
            matching_fcn=self.matcher,
            left_parents=AttrDict(timesteps=start_ind),
            right_parents=AttrDict(timesteps=end_ind))

        name = 'images' if name is None else name
        estimates = torch.stack([
            node.subgoal[name]
            for node in model_output.tree.depth_first_iter()
        ])
        leave = torch.stack([
            node.subgoal.c_n_prime
            for node in model_output.tree.depth_first_iter()
        ]).byte().any(-1)

        pruned_seqs = [
            estimates[:, i][leave[:, i]] for i in range(leave.shape[1])
        ]

        return pruned_seqs, None
Example #3
    def forward(self, context=None, x_prime=None, more_context=None, z=None):
        """

        :param context: context to be fed at each timestep
        :param x_prime: observation at the next step
        :param more_context: additional context, also fed at each timestep
        :param z: (optional) if not None, z is used directly instead of being sampled
        :return:
        """
        # TODO to get rid of more_context, make an interface that allows context structures
        output = AttrDict()

        # the input is only used to read the batch size at the moment
        output.p_z = self.prior(torch.zeros_like(x_prime))
        if x_prime is not None:
            output.q_z = self.inf(
                self.inf_lstm(x_prime, context, more_context).output)

        if z is None:
            if self._sample_prior:
                z = Gaussian(output.p_z).sample()
            else:
                z = Gaussian(output.q_z).sample()

        pred_input = [z, context, more_context]

        output.x = self.gen_lstm(*pred_input).output
        return output
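
The `_sample_prior` switch above is the usual VAE-style choice between sampling the latent from the prior (generation) or from the inference network's posterior (training). The `Gaussian` wrapper is specific to this codebase; the sketch below shows the same idea with `torch.distributions`, assuming (for illustration only) that the distribution parameters are a concatenation of mean and log-sigma.

    import torch
    from torch import distributions as D

    def sample_gaussian(params):
        # assumption: params = concat([mean, log_sigma], dim=-1), as is common for such heads
        mu, log_sigma = torch.chunk(params, 2, dim=-1)
        return D.Normal(mu, log_sigma.exp()).rsample()   # rsample keeps gradients (reparameterization)

    # mirrors the switch in the example: prior at generation time, posterior during training
    p_z = torch.zeros(4, 8)            # toy parameters for a 4-dim latent
    z = sample_gaussian(p_z)
    print(z.shape)                     # torch.Size([4, 4])
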
Example #4
 def forward(self, inputs):
     self.get_timesteps(inputs)
     actions_pred = self.action_pred(inputs.state_t0[:, :, None, None],
                                     inputs.state_t1[:, :, None, None])
     output = AttrDict()
     output.actions = torch.squeeze(actions_pred)
     return output
Example #5
    def _plan(self, image, goal_image, step):
        print("planning at t{}".format(self.t))
        input_dict = AttrDict(I_0=self._env2planner(image),
                              I_g=self._env2planner(goal_image),
                              start_ind=torch.Tensor([0]).long(),
                              end_ind=torch.Tensor(
                                  [self._hp.params['max_seq_len'] - 1]).long())
        with self.planner.val_mode():
            planner_output = self.planner(input_dict)
            # perform pruning for the balanced tree
            image_plan, _ = self.planner.dense_rec.get_sample_with_len(
                0, self._hp.params['max_seq_len'], planner_output, input_dict,
                'basic')

        # first image is copy of the initial frame -> omit
        self.image_plan = image_plan[1:]
        self.action_plan = planner_output.actions.detach().cpu().numpy()[0] \
            if 'actions' in planner_output else None

        planner_output.dense_rec = AttrDict(images=image_plan[None])
        self.planner_outputs.append((step, planner_output))
        self.current_exec_step = 0

        if self.verbose:
            npy_to_gif(
                self.planner2npy_img(planner_output.dense_rec.images[0]),
                self.log_dir_verb + '/plan_t{}'.format(self.t, step))
Example #6
    def forward(self, e0, eg):
        """Returns the logits of a OneHotCategorical distribution."""
        output = AttrDict()
        output.seq_len_logits = remove_spatial(self.p(e0, eg))
        output.seq_len_pred = OneHotCategorical(logits=output.seq_len_logits)

        return output
Example #7
    def act(self,
            t=None,
            i_tr=None,
            images=None,
            state=None,
            goal=None,
            goal_image=None):
        # Note: goal_image provides n (2) images starting from the last images of the trajectory
        self.t = t
        self.i_tr = i_tr
        self.goal_image = goal_image

        if self.policy.has_image_input:
            inputs = AttrDict(
                I_0=self._preprocess_input(images[t]),
                I_g=self._preprocess_input(goal_image[-1] if len(goal_image.shape) > 4 else goal_image),
                hidden_var=self.hidden_var)
        else:
            current = state[-1:, :2]
            goal = goal[-1:, :2]
            # goal_state = np.concatenate([state[-1:, -2:], state[-1:, 2:]], axis=-1)
            inputs = AttrDict(I_0=current,
                              I_g=goal,
                              hidden_var=self.hidden_var)

        actions, self.hidden_var = self.policy(inputs)

        output = AttrDict()
        output.actions = actions.data.cpu().numpy()[0]
        return output
Example #8
    def forward(self, inputs, phase='train'):
        """
        forward pass at training time
        """
        if 'enc_traj_seq' not in inputs:
            enc_traj_seq, _ = self.encoder(inputs.traj_seq[:, 0]) if self._hp.train_first_action_only \
                                    else batch_apply(self.encoder, inputs.traj_seq)
            if self._hp.train_first_action_only:
                enc_traj_seq = enc_traj_seq[:, None]
            enc_traj_seq = enc_traj_seq.detach() if self.detach_enc else enc_traj_seq

        enc_goal, _ = self.encoder(inputs.I_g)
        n_dim = len(enc_goal.shape)
        fused_enc = torch.cat((enc_traj_seq, enc_goal[:, None].repeat(
            1, enc_traj_seq.shape[1], *([1] * (n_dim - 1)))),
                              dim=2)
        #fused_enc = torch.cat((enc_traj_seq, enc_goal[:, None].repeat(1, enc_traj_seq.shape[1], 1, 1, 1)), dim=2)

        if self._hp.reactive:
            actions_pred = batch_apply(self.policy, fused_enc)
        else:
            policy_output = self.policy(fused_enc)
            actions_pred = policy_output

        # remove last time step to match ground truth if training on full sequence
        actions_pred = actions_pred[:, :-1] if not self._hp.train_first_action_only else actions_pred

        output = AttrDict()
        output.actions = remove_spatial(actions_pred) if len(actions_pred.shape) > 3 else actions_pred
        return output
Example #9
 def loss(self, inputs, outputs, log_error_arr=False):
     losses = AttrDict()
     
     losses.kl = KLDivLoss2(self._hp.kl_weight) \
         (outputs.q_z, outputs.p_z, log_error_arr=log_error_arr)
     
     return losses
Example #10
    def loss(self, outputs, targets, weights, pad_mask, weight, log_sigma):
        predictions = outputs.tree.bf.images
        gt_match_dists = outputs.gt_match_dists

        # Compute likelihood
        loss_val = batch_cdist(predictions, targets, reduction='sum')

        log_sigmas = log_sigma - WeightsHacker.hack_weights(
            torch.ones_like(loss_val)).log()
        n = np.prod(predictions.shape[2:])
        loss_val = 0.5 * loss_val * torch.pow(torch.exp(-log_sigmas), 2) \
                   + n * (log_sigmas + 0.5 * np.log(2 * np.pi))

        # Weigh by matching probability
        match_weights = gt_match_dists
        # Note: this is now unnecessary since both tree models handle it already
        match_weights = match_weights * pad_mask[:, None]
        loss_val = loss_val * match_weights * weights

        losses = AttrDict()
        losses.dense_img_rec = PenaltyLoss(weight,
                                           breakdown=2)(loss_val,
                                                        log_error_arr=True,
                                                        reduction=[-1, -2])

        # if self._hp.top_bias > 0.0:
        #     losses.n_top_bias_nodes = PenaltyLoss(
        #         self._hp.supervise_match_weight)(1 - WeightsHacker.get_n_top_bias_nodes(targets, weights))

        return losses
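
The arithmetic in Example #10 is the negative log-likelihood of an isotropic Gaussian with a learned `log_sigma`: with `loss_val` holding the summed squared error over `n` elements, the NLL is 0.5 * err * exp(-2 * log_sigma) + n * (log_sigma + 0.5 * log(2 * pi)). The sketch below checks that formula in plain NumPy and ignores the repo-specific `WeightsHacker` reweighting.

    import numpy as np

    def gaussian_nll(sq_err_sum, log_sigma, n):
        """NLL of n i.i.d. Gaussian dimensions with shared scale sigma = exp(log_sigma)."""
        return 0.5 * sq_err_sum * np.exp(-2 * log_sigma) + n * (log_sigma + 0.5 * np.log(2 * np.pi))

    # with sigma = 1 (log_sigma = 0) this reduces to half the squared error plus a constant
    print(gaussian_nll(sq_err_sum=4.0, log_sigma=0.0, n=2))   # 2.0 + log(2 * pi) ≈ 3.84
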
Example #11
    def get_data_config(self, conf_module):
        # get default data config

        path = os.path.join(
            get_dataset_path(conf_module.configuration['dataset_name']),
            'dataset_spec.py')
        data_conf_file = imp.load_source('dataset_spec', path)
        data_conf = AttrDict()
        data_conf.dataset_spec = AttrDict(data_conf_file.dataset_spec)

        # update with custom params if available
        update_data_conf = {}
        if hasattr(conf_module, 'data_config'):
            update_data_conf = conf_module.data_config
        elif conf_module.configuration.dataset_name is not None:
            update_data_conf = importlib.import_module(
                'gcp.datasets.configs.' +
                conf_module.configuration.dataset_name).config

        for key in update_data_conf:
            if key == "dataset_spec":
                data_conf.dataset_spec.update(update_data_conf.dataset_spec)
            else:
                data_conf[key] = update_data_conf[key]

        if 'fps' not in data_conf:
            data_conf.fps = 4
        return data_conf
Example #12
    def forward(self, root, inputs):
        outputs = AttrDict()
        # TODO implement soft interpolation

        sg_times, sg_encs = [], []
        for segment in root:
            sg_times.append(segment.subgoal.ind)
            sg_encs.append(segment.subgoal.e_g_prime)
        sg_times = torch.stack(sg_times, dim=1)
        sg_encs = torch.stack(sg_encs, dim=1)

        # compute time difference weights
        seq_length = self._hp.max_seq_len
        target_ind = torch.arange(end=seq_length, dtype=sg_times.dtype)
        time_diffs = torch.abs(target_ind[None, None, :] -
                               sg_times[:, :, None])
        weights = nn.functional.softmax(-time_diffs, dim=-1)

        # compute weighted sum outputs
        weighted_sg = weights[:, :, :, None, None, None] * \
                      sg_encs.unsqueeze(2).repeat(1, 1, seq_length, 1, 1, 1)
        outputs.encodings = torch.sum(weighted_sg, dim=1)
        outputs.update(
            self._dense_decode(inputs, outputs.encodings, seq_length))
        return outputs
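
Example #12 spreads each subgoal encoding over the target time steps with weights softmax(-|t - t_sg|) (normalized over time) and then sums the contributions of all subgoals at every step. The toy sketch below reproduces that weighting with vector encodings instead of convolutional feature maps, so the trailing singleton dimensions from the example are dropped.

    import torch
    import torch.nn.functional as F

    # one batch element, two subgoals placed at times 1.0 and 4.0, target sequence of length 6
    sg_times = torch.tensor([[1.0, 4.0]])                    # [batch, n_subgoals]
    sg_encs = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])       # [batch, n_subgoals, enc_dim]
    target_ind = torch.arange(6, dtype=sg_times.dtype)

    time_diffs = torch.abs(target_ind[None, None, :] - sg_times[:, :, None])  # [batch, n_subgoals, time]
    weights = F.softmax(-time_diffs, dim=-1)                 # normalized over the time dimension

    # weighted sum over the subgoal dimension: one encoding per target time step
    encodings = torch.sum(weights[:, :, :, None] * sg_encs[:, :, None, :], dim=1)
    print(encodings.shape)                                   # torch.Size([1, 6, 2])
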
Example #13
    def reset(self, reset_state):
        super().reset()

        if reset_state is None:
            start_pos = self.env.mj2mw(
                self.state_sampler.sample(self._hp.init_pos))
            start_angle = 2 * np.pi * np.random.rand()
            goal_pos = self.env.mj2mw(
                self.state_sampler.sample(self._hp.goal_pos))
        else:
            start_pos = reset_state[:2]
            start_angle = reset_state[2]
            goal_pos = reset_state[-2:]

        reset_state = AttrDict(start_pos=start_pos,
                               start_angle=start_angle,
                               goal=goal_pos)

        img_obs = self.env.reset(reset_state)
        self.goal_pos = goal_pos
        qpos_full = np.concatenate((start_pos, np.array([start_angle])))

        obs = AttrDict(
            images=np.expand_dims(img_obs, axis=0),  # add camera dimension
            qpos_full=qpos_full,
            goal=goal_pos,
            env_done=False,
            state=np.concatenate((qpos_full, goal_pos)),
            topdown_image=self.render_pos_top_down(qpos_full, self.goal_pos))
        self._post_step(start_pos)
        self._initial_shortest_dist = self.comp_shortest_dist(
            start_pos, goal_pos)
        return obs, reset_state
Example #14
    def full_seq_forward(self, inputs):
        if 'model_enc_seq' in inputs:
            enc_seq_1 = inputs.model_enc_seq[:, 1:]
            if self._hp.train_im0_enc and 'enc_traj_seq' in inputs:
                enc_seq_0 = inputs.enc_traj_seq.reshape(
                    inputs.enc_traj_seq.shape[:2] +
                    (self._hp.nz_enc, ))[:, :-1]
                enc_seq_0 = enc_seq_0[:, :enc_seq_1.shape[1]]
            else:
                enc_seq_0 = inputs.model_enc_seq[:, :-1]
        else:
            enc_seq = batch_apply(self.encoder, inputs.traj_seq)
            enc_seq_0, enc_seq_1 = enc_seq[:, :-1], enc_seq[:, 1:]

        if self.detach_enc:
            enc_seq_0 = enc_seq_0.detach()
            enc_seq_1 = enc_seq_1.detach()

        # TODO quite sure the concatenation is automatic
        actions_pred = batch_apply(self.action_pred,
                                   torch.cat([enc_seq_0, enc_seq_1], dim=2))

        output = AttrDict()
        output.actions = actions_pred  #remove_spatial(actions_pred)
        if 'actions' in inputs:
            output.action_targets = inputs.actions
            output.pad_mask = inputs.pad_mask
        return output
Example #15
 def apply_tree(self, tree, inputs):
     # recursive_add_dim = make_recursive(lambda x: add_n_dims(x, n=1, dim=1))
     start_ind, end_ind = self.get_init_inds(inputs)
     tree.apply_fn({},
                   fn=self,
                   left_parents=AttrDict(timesteps=start_ind),
                   right_parents=AttrDict(timesteps=end_ind))
Example #16
    def __call__(self, inputs, subgoal, left_parent, right_parent):
        out = AttrDict()
        out.c_n = self.attentive_matching(inputs, subgoal)
        out.c_n_prime, out.cdf, out.p_n = self.propagate_matching(subgoal, left_parent, right_parent, out.c_n)
        out.ind = torch.argmax(out.c_n_prime, dim=1)

        return out
Example #17
    def make_traj(self, agent_data, obs, policy_out):
        traj = AttrDict()

        if not self.do_not_save_images:
            traj.images = obs['images']
        traj.states = obs['state']
        
        action_list = [action['actions'] for action in policy_out]
        traj.actions = np.stack(action_list, 0)
        
        traj.pad_mask = get_pad_mask(traj.actions.shape[0], self.max_num_actions)
        traj = pad_traj_timesteps(traj, self.max_num_actions)

        if 'robosuite_xml' in obs:
            traj.robosuite_xml = obs['robosuite_xml'][0]
        if 'robosuite_env_name' in obs:
            traj.robosuite_env_name = obs['robosuite_env_name'][0]
        if 'robosuite_full_state' in obs:
            traj.robosuite_full_state = obs['robosuite_full_state']

        # minimal state that contains all information to position entities in the env
        if 'regression_state' in obs:
            traj.regression_state = obs['regression_state']

        return traj
Example #18
 def get_default_params(self):
     params = AttrDict(
         normalize=True,
         activation=nn.LeakyReLU(0.2, inplace=True),
         normalization=self.builder.normalization,
         normalization_params=AttrDict()
     )
     return params
Example #19
    def loss(self, inputs, outputs):
        losses = AttrDict()

        if 'existence_predictor' in outputs:
            losses.existence_predictor = BCELogitsLoss()(
                outputs.existence_predictor.existence,
                outputs.tree.df.match_dist.sum(2).float())
        return losses
Example #20
 def reconstruction_loss(self, inputs, outputs, weights):
     losses = AttrDict()
 
     outputs.soft_matched_estimates = self.criterion.get_soft_estimates(outputs.gt_match_dists,
                                                                        outputs.tree.bf.images)
     losses.update(self.criterion.loss(
         outputs, inputs.traj_seq, weights, inputs.pad_mask, self._hp.dense_img_rec_weight, self.decoder.log_sigma))
 
     return losses
Example #21
    def forward(self, root, inputs):
        # TODO implement stopping probability prediction
        # TODO make the low-level network not predict subgoals
        batch_size, time = self._hp.batch_size, self._hp.max_seq_len
        outputs = AttrDict()

        lstm_inputs = self._get_lstm_inputs(root, inputs)
        lstm_outputs = self.lstm(lstm_inputs, time)
        outputs.encodings = torch.stack(lstm_outputs, dim=1)
        outputs.update(self._dense_decode(inputs, outputs.encodings, time))
        return outputs
Example #22
 def forward(self, input, hidden_state, length=None):
     """
     :param input: tensor of shape batch x time x channels
     :return:
     """
     if length is None: length = input.shape[1]
     initial_state = AttrDict(hidden_state=hidden_state)
     outputs = super().forward(AttrDict(cell_input=input),
                               length=length,
                               initial_inputs=initial_state)
     return outputs
Example #23
    def assert_begin(inputs, initial_inputs, static_inputs):
        initial_inputs = initial_inputs or AttrDict()
        static_inputs = static_inputs or AttrDict()
        assert not (static_inputs.keys() & inputs.keys()), 'Static inputs and inputs overlap'
        assert not (static_inputs.keys() & initial_inputs.keys()), 'Static inputs and initial inputs overlap'
        assert not (inputs.keys() & initial_inputs.keys()), 'Inputs and initial inputs overlap'

        return initial_inputs, static_inputs
Example #24
 def forward(self, inputs):
     self.get_timesteps(inputs)
     enc = self.encoder.forward(
         torch.cat([inputs.img_t0, inputs.img_t1], dim=1))[0]
     output = AttrDict()
     out = self.action_pred(enc)
     if self._hp.pred_states:
         output.actions, output.states = torch.split(
             torch.squeeze(out), [2, 2], 1)
     else:
         output.actions = torch.squeeze(out)
     return output
Example #25
    def loss(self, inputs, model_output):
        losses = AttrDict()

        # action prediction loss
        n_actions = model_output.actions.shape[1]
        losses.action_reconst = L2Loss(1.0)(model_output.actions, inputs.actions[:, :n_actions],
                                            weights=broadcast_final(inputs.pad_mask[:, :n_actions], inputs.actions))

        # compute total loss
        #total_loss = torch.stack([loss[1].value * loss[1].weight for loss in losses.items()]).sum()
        #losses.total = AttrDict(value=total_loss)
        # losses.total = total_loss*torch.tensor(np.nan)   # for checking if backprop works
        return losses
Example #26
    def _get_lstm_inputs(self, root, inputs):
        """
        :param root:
        :return:
        """
        device = inputs.reference_tensor.device
        batch_size, time = self._hp.batch_size, self._hp.max_seq_len
        fullseq_shape = [batch_size, time] + list(inputs.enc_e_0.shape[1:])
        lstm_inputs = AttrDict()

        # collect start and end indexes and values of all segments
        e_0s = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
        e_gs = torch.zeros(fullseq_shape, dtype=torch.float32, device=device)
        start_inds, end_inds = torch.zeros((batch_size, time), dtype=torch.float32, device=device), \
                               torch.zeros((batch_size, time), dtype=torch.float32, device=device)
        reset_indicator = torch.zeros((batch_size, time),
                                      dtype=torch.uint8,
                                      device=device)
        # traverse the tree in breadth-first order
        for segment in root.full_tree():
            if segment.depth == 0:  # if leaf-node
                start_ind = torch.ceil(segment.start_ind).type(
                    torch.LongTensor)
                end_ind = torch.floor(segment.end_ind).type(torch.LongTensor)
                batchwise_assign(reset_indicator, start_ind, 1)

                # TODO iterating over batch must be gone
                for ex in range(self._hp.batch_size):
                    if start_ind[ex] > end_ind[ex]:
                        continue  # happens if start and end floats have no int in between
                    # +1 to include the end_ind frame
                    e_0s[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_0[ex]
                    e_gs[ex, start_ind[ex]:end_ind[ex] + 1] = segment.e_g[ex]
                    start_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.start_ind[ex]
                    end_inds[ex, start_ind[ex]:end_ind[ex] + 1] = segment.end_ind[ex]

        # perform linear interpolation
        time_steps = torch.arange(time, dtype=torch.float, device=device)
        inter = (time_steps - start_inds) / (end_inds - start_inds + 1e-7)

        lstm_inputs.reset_indicator = reset_indicator
        lstm_inputs.cell_input = (e_gs - e_0s) * broadcast_final(inter, e_gs) + e_0s
        lstm_inputs.reset_input = torch.cat([e_gs, e_0s], dim=2)

        return lstm_inputs
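
The core of `_get_lstm_inputs` is a per-step linear interpolation between each segment's start and goal encodings: `inter` ramps from 0 at `start_ind` to 1 at `end_ind`, and `cell_input = (e_g - e_0) * inter + e_0`. A scalar toy version of that ramp, for a single segment covering steps 2..5 of an 8-step sequence, looks like this:

    import torch

    time = 8
    e_0, e_g = torch.tensor(0.0), torch.tensor(1.0)          # toy start / goal "encodings"
    start_inds = torch.full((time,), 2.0)                    # in the real code these are filled per segment
    end_inds = torch.full((time,), 5.0)

    time_steps = torch.arange(time, dtype=torch.float)
    inter = (time_steps - start_inds) / (end_inds - start_inds + 1e-7)
    cell_input = (e_g - e_0) * inter + e_0
    print(cell_input)   # 0 at t=2, 1 at t=5, linear in between (and extrapolated outside the segment)
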
Example #27
    def __init__(self, args_in=None, hyperparams=None):
        parser = argparse.ArgumentParser(description='run parallel data collection')
        parser.add_argument('experiment', type=str, help='experiment name')
        parser.add_argument('--nworkers', type=int, help='use multiple threads or not', default=1)
        parser.add_argument('--gpu_id', type=int, help='the starting gpu_id', default=0)
        parser.add_argument('--ngpu', type=int, help='the number of gpus to use', default=1)
        parser.add_argument('--gpu', type=int, help='the gpu to use', default=-1)
        parser.add_argument('--nsplit', type=int, help='number of splits', default=-1)
        parser.add_argument('--isplit', type=int, help='split id', default=-1)
        parser.add_argument('--iex', type=int, help='if different from -1, only run this example index', default=-1)
        parser.add_argument('--data_save_postfix', type=str, help='appends to the data_save_dir path', default='')
        parser.add_argument('--nstart_goal_pairs', type=int, help='max number of start goal pairs', default=None)
        parser.add_argument('--resume_from', type=int, help='from which traj idx to continue generating', default=None)

        args = parser.parse_args(args_in)

        print("Resume from")
        print(args.resume_from)

        if args.gpu != -1:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

        if hyperparams is None:
            hyperparams_file = args.experiment
            loader = importlib.machinery.SourceFileLoader('mod_hyper', hyperparams_file)
            spec = importlib.util.spec_from_loader(loader.name, loader)
            mod = importlib.util.module_from_spec(spec)
            loader.exec_module(mod)
            hyperparams = AttrDict(mod.config)

        self.args = args
        self.hyperparams = postprocess_hyperparams(hyperparams, args)
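
Example #27 loads the experiment's hyperparameter file as a Python module at run time via `importlib`. The helper below isolates just that loading pattern; the path and the expectation that the file defines a `config` dict are assumptions taken from the example, not a general API.

    import importlib.machinery
    import importlib.util

    def load_config_module(path, name='mod_hyper'):
        """Load a Python source file as a module without adding it to sys.path."""
        loader = importlib.machinery.SourceFileLoader(name, path)
        spec = importlib.util.spec_from_loader(loader.name, loader)
        mod = importlib.util.module_from_spec(spec)
        loader.exec_module(mod)
        return mod

    # hypothetical usage, mirroring the example:
    # mod = load_config_module('/path/to/experiment_conf.py')
    # hyperparams = AttrDict(mod.config)
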
Example #28
    def __call__(self, inputs, subgoal, left_parent, right_parent):
        super().build_network()

        timesteps = self.comp_timestep(left_parent.timesteps,
                                       right_parent.timesteps,
                                       subgoal.fraction)
        return AttrDict(timesteps=timesteps)
Example #29
 def forward(self, input):
     """
     :param input: tensor of shape batch x time x channels
     :return:
     """
     return super().forward(AttrDict(cell_input=input),
                            length=input.shape[1]).output
Example #30
    def forward(self, *cell_input, **cell_kwinput):
        """
        at every time-step the input to the dense-reconstruction LSTM is a tuple of (last_state, e_0, e_g)
        :param cell_input:
        :param reset_indicator:
        :return:
        """
        # TODO allow ConvLSTM
        if cell_kwinput:
            cell_input = cell_input + list(zip(*cell_kwinput.items()))[1]

        if self.hidden is None:
            self.reset()

        cell_input = concat_inputs(*cell_input)
        inp_extra_dim = []

        if not self._hp.use_conv_lstm:
            # TODO put in the embed module
            # this keeps trailing dimensions (should be all shape 1)
            inp_extra_dim = list(cell_input.shape[2:])
            cell_input = cell_input.view(-1, self.input_size)

        embedded = self.embed(cell_input)
        h_in = embedded
        for i in range(self.n_layers):
            self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        output = self.output(h_in)
        return AttrDict(output=output.view(list(output.shape) + inp_extra_dim))