def forward(ctx, inputs, first_idxs, max_size):
    """
    Scatter a packed per-item tensor into a zero-padded per-batch tensor.

    Args:
        ctx: Autograd context; the index tensor and the packed row count are
            stashed on it for the backward pass.
        inputs: FloatTensor of shape (F, D) holding the packed batch values,
            e.g. areas for faces in a batch of meshes.
        first_idxs: LongTensor of shape (N,), one entry per batch element;
            `first_idxs[i] = f` means batch element i starts at `inputs[f]`.
        max_size: Python int, the padded length of every batch element.

    Returns:
        inputs_padded: FloatTensor of shape (N, max_size, D). The values of
        batch element i, starting at `inputs[first_idxs[i]]`, are copied to
        `inputs_padded[i, :]`, with zeros padding out the remaining rows.

    Raises:
        ValueError: if `inputs` is not 2-D float32, `first_idxs` is not 1-D
            int64, or `max_size` is not an int. Checks run in that order and
            stop at the first failure.
    """
    # Each condition is wrapped in a lambda so that evaluation stays lazy and
    # ordered: a later check never runs once an earlier one has failed.
    validators = (
        (lambda: inputs.dim() == 2, "input can only be 2-dimensional."),
        (lambda: first_idxs.dim() == 1, "first_idxs can only be 1-dimensional."),
        (lambda: inputs.dtype == torch.float32, "input has to be of type torch.float32."),
        (lambda: first_idxs.dtype == torch.int64, "first_idxs has to be of type torch.int64."),
        (lambda: isinstance(max_size, int), "max_size has to be int."),
    )
    for is_valid, message in validators:
        if not is_valid():
            raise ValueError(message)

    ctx.save_for_backward(first_idxs)
    # Packed row count F; presumably kept so backward can size the packed
    # gradient.
    ctx.num_inputs = int(inputs.shape[0])

    # The native kernel is handed contiguous copies/views of both tensors.
    packed = inputs.contiguous()
    starts = first_idxs.contiguous()
    return _C.packed_to_padded(packed, starts, max_size)
def backward(ctx, grad_output):
    """
    Map the gradient of the padded output back to the packed input layout.

    Args:
        ctx: Autograd context written by ``forward``. ``saved_tensors[0]`` is
            ``first_idxs``, and ``num_inputs`` is F, the number of packed rows
            in the original ``inputs``.
        grad_output: Gradient w.r.t. the padded output; presumably of shape
            (N, max_size, D), matching ``forward``'s result. TODO confirm.

    Returns:
        (grad_input, None, None). ``grad_input`` is the gradient w.r.t. the
        packed ``inputs``, expected to be of shape (F, D). ``first_idxs`` and
        ``max_size`` receive no gradient.
    """
    grad_output = grad_output.contiguous()
    first_idxs = ctx.saved_tensors[0]
    # forward records the packed row count as ``ctx.num_inputs``. It never
    # sets ``ctx.max_size``, so reading that attribute would raise
    # AttributeError.
    num_inputs = ctx.num_inputs
    # The adjoint of a packed->padded scatter is the padded->packed gather.
    # Calling packed_to_padded again would treat the (already padded)
    # gradient as packed data and return a tensor of the wrong rank.
    grad_input = _C.padded_to_packed(grad_output, first_idxs, num_inputs)
    return grad_input, None, None