def concatenate_as(tensor_list, tensor_as, dim, mode="bilinear"):
    """Resize the means and variances of all items, then concatenate each group.

    Args:
        tensor_list: Collection of indexable items; element 0 of each item is
            treated as its mean and element 1 as its variance. It is iterated
            twice, so it must be re-iterable (e.g. a list, not a generator).
        tensor_as: Indexable reference. Only ``tensor_as[0]`` is handed to
            ``resize2D_as``, and it is used for the means and the variances
            alike.
        dim: Dimension along which ``torch.cat`` joins the resized tensors.
        mode: Forwarded unchanged to ``resize2D_as`` as ``mode``
            (presumably an interpolation mode -- confirm against
            ``resize2D_as``).

    Returns:
        A ``(means, variances)`` tuple holding the two concatenated tensors.
    """

    def _resize_component(index):
        # Pick one component (0 = mean, 1 = variance) from every item and
        # resize it against the first element of the reference.
        return [
            resize2D_as(item[index], tensor_as[0], mode=mode)
            for item in tensor_list
        ]

    # Both groups are fully resized before either concatenation runs.
    resized_means = _resize_component(0)
    resized_variances = _resize_component(1)

    return (
        torch.cat(resized_means, dim=dim),
        torch.cat(resized_variances, dim=dim),
    )
def concatenate_as(tensor_list, tensor_as, dim, mode="bilinear"):
    """Resize every tensor against a reference and concatenate the results.

    Args:
        tensor_list: Iterable of tensors; each one is passed through
            ``resize2D_as`` together with ``tensor_as``.
        tensor_as: Reference object handed to ``resize2D_as`` for every
            element (presumably it defines the target 2D size -- confirm
            against ``resize2D_as``).
        dim: Dimension along which ``torch.cat`` joins the resized tensors.
        mode: Forwarded unchanged to ``resize2D_as`` as ``mode``.

    Returns:
        The result of ``torch.cat`` over the resized tensors along ``dim``.
    """
    resized = []
    for tensor in tensor_list:
        resized.append(resize2D_as(tensor, tensor_as, mode=mode))
    return torch.cat(resized, dim=dim)