Example #1
        def wrapper(*args):
            # TODO support **kwargs
            # TODO make cuda/cpu an option for decorator or discover it automatically

            # If any argument contains a numpy array, convert the arguments to torch
            # tensors, call fn, and convert the result back to numpy afterwards.
            found_tensor = find_element(args)
            if isinstance(found_tensor, torch.Tensor):
                convert = False
            elif isinstance(found_tensor, np.ndarray):
                convert = True
            else:
                raise NotImplementedError

            if convert:
                convert_to = lambda el: ar2ten(el, 'cpu') if isinstance(el, np.ndarray) else el
                args = rmap(convert_to, args)

            result = fn(*args)

            if convert:
                convert_fro = lambda el: ten2ar(el) if isinstance(el, torch.Tensor) else el
                result = rmap(convert_fro, result)

            return result
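This and the following examples all revolve around rmap from the blox library, which appears to apply a function recursively to every leaf of a nested structure (lists, tuples, dicts / AttrDicts) and rebuild the same structure around the results. Below is a minimal sketch of that presumed behaviour; rmap_sketch is an illustrative stand-in, not the actual blox implementation (which also supports extras such as the target_class/only_target options seen in example #10).

def rmap_sketch(fn, collection):
    # Recursively map fn over nested dicts/lists/tuples, preserving the container types.
    if isinstance(collection, dict):
        return type(collection)({k: rmap_sketch(fn, v) for k, v in collection.items()})
    if isinstance(collection, (list, tuple)):
        return type(collection)(rmap_sketch(fn, v) for v in collection)
    return fn(collection)  # leaf (e.g. a tensor or array): apply the function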
Example #2
        def wrapper(*args, **kwargs):
            # TODO support conversion of **kwargs

            # If any argument contains a torch tensor, convert the arguments to numpy,
            # call fn, and convert the result back to tensors on the original device.
            found_tensor = find_element(args)
            if isinstance(found_tensor, torch.Tensor):
                convert = True
            elif isinstance(found_tensor, np.ndarray):
                convert = False
            else:
                raise NotImplementedError

            if convert:
                convert_to = lambda el: ten2ar(el) if isinstance(el, torch.Tensor) else el
                args = rmap(convert_to, args)

            result = fn(*args, **kwargs)

            if convert:
                convert_fro = lambda el: (ar2ten(el, found_tensor.device)
                                          if isinstance(el, np.ndarray) else el)
                result = rmap(convert_fro, result)

            return result
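Taken together, examples #1 and #2 are the bodies of two complementary conversion decorators: the first lets a torch function transparently accept and return numpy arrays, the second lets a numpy function accept and return torch tensors. The self-contained sketch below shows the same idea for flat (non-nested) arguments; np_compatible and double are hypothetical names, and the original code additionally uses find_element/rmap to handle arbitrarily nested arguments.

import numpy as np
import torch

def np_compatible(fn):
    # Hypothetical decorator in the spirit of example #1: if numpy arrays are passed in,
    # convert them to tensors, call the torch function fn, and convert the result back.
    def wrapper(*args):
        if any(isinstance(a, np.ndarray) for a in args):
            args = [torch.from_numpy(a) if isinstance(a, np.ndarray) else a for a in args]
            return fn(*args).numpy()
        return fn(*args)
    return wrapper

@np_compatible
def double(x):
    return 2 * x

print(double(np.arange(3)))     # numpy in -> numpy out: [0 2 4]
print(double(torch.arange(3)))  # tensor in -> tensor out: tensor([0, 2, 4])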
Example #3
 def forward(self, x, output_length, conditioning_length, context=None):
     """
     
     :param x: the modelled sequence, batch x time x  x_dim
     :param length: the desired length of the output sequence. Note, this includes all conditioning frames except 1
     :param conditioning_length: the length on which the prediction will be conditioned. Ground truth data are observed
     for this length
     :param context: a context sequence. Prediction is conditioned on all context up to and including this moment
     :return:
     """
     lstm_inputs = AttrDict()
     outputs = AttrDict()
     if context is not None:
         lstm_inputs.more_context = context
 
     if not self._sample_prior:
         outputs.q_z = self.inference(x, context)
         lstm_inputs.z = Gaussian(outputs.q_z).sample()
 
     outputs.update(self.generator(inputs=lstm_inputs,
                                   length=output_length + conditioning_length,
                                   static_inputs=AttrDict(batch_size=x.shape[0])))
     # Conditioning currently works by zeroing out the loss on the KL divergence and returning fewer
     # frames. That way the network can pass the info directly through z. Conditioning could also be
     # implemented by feeding the frames directly into the predictor; that would require passing
     # previous frames to the VRNNCell and using a fake frame to condition the 0th frame on.
     outputs = rmap(lambda ten: ten[:, conditioning_length:], outputs)
     outputs.conditioning_length = conditioning_length
     return outputs
Example #4
File: tree.py  Project: orybkin/video-gcp
    def predict_sequence(self, inputs, outputs, start_ind, end_ind, phase):
        layerwise_inputs = self.filter_layerwise_inputs(inputs)
        start_node, end_node = self._create_initial_nodes(inputs)

        outputs.tree = root = SubgoalTreeLayer()
        tree_inputs = [
            layerwise_inputs, start_ind, end_ind, start_node, end_node
        ]
        tree_inputs = rmap(lambda x: add_n_dims(x, n=1, dim=1), tree_inputs)
        tree_inputs = [inputs] + tree_inputs

        root.produce_tree(*tree_inputs, self.tree_module,
                          self._hp.hierarchy_levels)
        outputs.dense_rec = self.dense_rec(root, inputs)

        if 'traj_seq' in inputs and phase == 'train':
            # compute the matching between tree nodes and frames of the input sequence;
            # needed for the loss, the inverse model, etc.
            self.tree_module.compute_matching(inputs, outputs)

        # TODO the binding has to move to this class
        outputs.pruned_prediction = self.tree_module.binding.prune_sequence(
            inputs, outputs)

        # add pruned reconstruction if necessary
        if not outputs.dense_rec and self._hp.matching_type == 'balanced':
            # TODO this has to be unified with the balanced tree case
            outputs.pruned_prediction = self.dense_rec.get_all_samples_with_len(
                outputs.end_ind, outputs, inputs, pruning_scheme='basic')[0]

        return outputs
Example #5
    def get_matched_sequence(self, tree, key):
        latents = tree.bf[key]
        indices = tree.bf.match_dist.argmax(1)
        # Two-dimensional indexing
        matched_sequence = rmap(lambda x: batchwise_index(x, indices), latents)

        return matched_sequence
Example #6
 def update(self, val):
     # Keep a running average; val may be a nested structure (e.g. a dict of tensors),
     # which is accumulated and averaged leaf-by-leaf via rmap_list / rmap.
     self.val = val
     if self.sum is None:
         self.sum = val
     else:
         self.sum = rmap_list(lambda x, y: x + y, [self.sum, val])
     self.count += 1
     self.avg = rmap(lambda x: x / self.count, self.sum)
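Example #6 also uses rmap_list, which presumably applies a function to the corresponding leaves of several identically-structured inputs (here, elementwise addition of self.sum and val). A minimal sketch under that assumption follows; rmap_list_sketch is a hypothetical stand-in, not the blox code.

def rmap_list_sketch(fn, structs):
    # Apply fn across the corresponding leaves of several identically-shaped nested structures.
    first = structs[0]
    if isinstance(first, dict):
        return type(first)({k: rmap_list_sketch(fn, [s[k] for s in structs]) for k in first})
    if isinstance(first, (list, tuple)):
        return type(first)(rmap_list_sketch(fn, [s[i] for s in structs])
                           for i in range(len(first)))
    return fn(*structs)  # leaves: e.g. fn(sum_leaf, val_leaf) for the running sum above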
Example #7
    def decode_seq(self, inputs, encodings):
        """ Decodes a sequence of images given the encodings

        :param inputs:
        :param encodings:
        :param seq_len:
        :return:
        """

        # TODO skip from the goal as well
        # Repeat each per-batch element along a new time axis to match the length of the encoding sequence.
        extend_to_seq = lambda x: x[:, None][:, [0] * encodings.shape[1]].contiguous()
        seq_skips = rmap(extend_to_seq, inputs.skips)
        pixel_source = rmap(extend_to_seq, [inputs.I_0, inputs.I_g])

        return batch_apply(self,
                           input=encodings,
                           skips=seq_skips,
                           pixel_source=pixel_source)
Example #8
    def predict_sequence(self, inputs, outputs, start_ind, end_ind, phase):
        filtered_inputs = self.one_step_planner._filter_inputs_for_model(
            inputs, phase)
        layerwise_inputs = self.filter_layerwise_inputs(inputs)
        start_node, end_node = self._create_initial_nodes(inputs)

        outputs.tree = tree = root = SubgoalTreeLayer()
        tree_inputs = [
            layerwise_inputs, start_ind, end_ind, start_node, end_node
        ]
        tree_inputs = rmap(lambda x: add_n_dims(x, n=1, dim=1), tree_inputs)
        tree_inputs = [filtered_inputs] + tree_inputs

        self.produce_tree(root, tree, tree_inputs, inputs, outputs)

        outputs.dense_rec = self.dense_rec(root, inputs)

        # add pruned reconstruction if necessary
        if not outputs.dense_rec and self._hp.matching_type == 'balanced':
            outputs.pruned_prediction = self.dense_rec.get_all_samples_with_len(
                outputs.end_ind, outputs, inputs, pruning_scheme='basic')[0]

        return outputs
Example #9
File: vrnn.py  Project: orybkin/blox-nn
    def forward(self, x, output_length, conditioning_length, context=None):
        """
        
        :param x: the modelled sequence, batch x time x  x_dim
        :param length: the desired length of the output sequence. Note, this includes all conditioning frames except 1
        :param conditioning_length: the length on which the prediction will be conditioned. Ground truth data are observed
        for this length
        :param context: a context sequence. Prediction is conditioned on all context up to and including this moment
        :return:
        """
        lstm_inputs = AttrDict(x_prime=x[:, 1:])
        if context is not None:
            context = pad(context, pad_front=1, dim=1)
            lstm_inputs.update(more_context=context[:, 1:])

        initial_inputs = AttrDict(x=x[:, :conditioning_length])

        self.lstm.cell.init_state(x[:, 0], more_context=context)
        outputs = self.lstm(inputs=lstm_inputs,
                            initial_seq_inputs=initial_inputs,
                            length=output_length + conditioning_length - 1)
        outputs = rmap(lambda ten: ten[:, conditioning_length - 1:], outputs)
        return outputs
Example #10
File: struct.py  Project: orybkin/blox-nn
def save(struct, path):
    # TODO remove this once you can save structures with torch.save
    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_inputs = rmap(AttrDict, struct, target_class=Struct, only_target=True)
    torch.save(save_inputs, path)
Example #11
File: struct.py  Project: orybkin/blox-nn
 def contiguous(self):
     return rmap(lambda x: x.contiguous(), self)
Example #12
File: struct.py  Project: orybkin/blox-nn
 def __getitem__(self, *args, **kwargs):
     return rmap(lambda x: x.__getitem__(*args, **kwargs), self)
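Examples #11 and #12 show the pattern behind the Struct class: pointwise operations such as contiguous() or indexing are forwarded to every leaf through rmap, so indexing the struct presumably indexes every tensor it contains. The toy class below illustrates that delegation pattern on a plain dict; LeafMap is an invented stand-in, not blox's Struct.

import torch

class LeafMap(dict):
    # Toy container that forwards operations to every value it holds.
    def _map(self, fn):
        return LeafMap({k: fn(v) for k, v in self.items()})

    def __getitem__(self, key):
        if isinstance(key, str):               # plain dict access by field name
            return dict.__getitem__(self, key)
        return self._map(lambda x: x[key])     # anything else: index every leaf

    def contiguous(self):
        return self._map(lambda x: x.contiguous())

s = LeafMap(a=torch.zeros(4, 3), b=torch.ones(4, 5))
rows = s[0]                                    # slices both tensors at once
print(rows['a'].shape, rows['b'].shape)        # torch.Size([3]) torch.Size([5])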