目录

目录

PermutoSDF: Fast Multi-View Reconstruction with Implicit Surfaces using Permutohedral Lattices


编者语:主要是把 NGP 的思想推广到了一种对高维输入更友好的格点 (lattice) 上

背景:permutohedral lattice (排列多面体点阵)

背景数学知识详见我的 这篇笔记

Motivation

续上述笔记之言,Instant-NGP 其实就是采用 $\mathbb{Z}^d$ 的 regular grid 的 lattice,而这种 lattice 的 Delaunay cell 的顶点个数是 $2^d$,随着维度 d 提升需要计算的顶点个数(i.e. 权重的个数)指数上升,因此其计算效率对高维较为不友好。

而本篇的思想就是使用 Permutohedral lattice $A^\ast_d$ 来代替 $\mathbb{Z}^d$,因为 $A^\ast_d$ 的 Delaunay cell 的顶点个数是 $(d+1)$ ,是所有 lattice 中 Delaunay cell 顶点个数随维度 d 上升最慢的,因此针对高维输入的计算效率将大大提高。

permutohedral_encoding 底层代码解读

维度补全 & scale 定义

EncodingFixedParams() 结构体创建时传入的参数 sigmas_list 即为 [self@PermutoEncoding].scale_per_level 参数,该参数是 PermutoEncoding 传入的,该参数的一个典型例子如 scale_list

coarsest_scale=1.0 
finest_scale=0.0001 
scale_list=np.geomspace(coarsest_scale, finest_scale, num=nr_levels)

这个正常的逐层的空间缩放参数在 EncodingFixedParams::compute_scale_factor_tensor 被计算转存到 m_scale_factor 中,计算方式如下:

torch::Tensor compute_scale_factor_tensor(const std::vector<float> sigmas_list, const int pos_dim){
    int nr_resolutions=sigmas_list.size();

    torch::Tensor scale_factor_tensor=torch::zeros({ nr_resolutions, pos_dim }, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0)  );
    double invStdDev = 1.0;
    for(int res_idx=0; res_idx<nr_resolutions; res_idx++){
        for (int i = 0; i < pos_dim; i++) {
            // 1 除以 (i+1)(i+2) 的和开根号, i 是当前第几个维度 (从0开始索引)
            scale_factor_tensor[res_idx][i] =  1.0 / (std::sqrt((double) (i + 1) * (i + 2))) * invStdDev;

            // 配置时的从 1.0 到 0.0001 几何平均的 scale_list -> 传入 python module 的 scale_per_level -> 在这里除到了分母上
            scale_factor_tensor[res_idx][i]=scale_factor_tensor[res_idx][i]/ sigmas_list[res_idx];
        }
    }

    return scale_factor_tensor;
}

m_scale_factor 的利用方式如下(在 forward_gpu , backward_gpu 中均如此计算):

// 下面这段把 d 维向量升到 (d+1) 维的齐次坐标 (坐标和为0)
// random_shift 主要是不同层、不同维度分量施加不同的随机位置偏移,最大可能避免 hash collision; 
//     受 python module 的 appply_random_shift_per_level: bool 控制,如果为 False 这里就是零tensor
// scale 的定义 和 elevated 计算做了特殊处理; 最终能够使得: 
//     1. elevated 之和仍然可以保证为0; 
//     2. elevated 每个分量的变化幅度接近,即使最大维度非常大,变化幅度的量级也仍在个位数; 
float sm = 0;
#pragma unroll
for (int i = pos_dim; i > 0; i--) {
    // float cf = (pos[i-1] +random_shift_constant[level*pos_dim + i-1]  ) * scale_factor_constant[level*pos_dim + i-1];
    float cf = (pos[i-1] +random_shift_monolithic[level][i-1]  ) * scale_factor[level][i-1];
    elevated[i] = sm - i * cf; // 第i个 elevated 放置的是 第i-1 个 pos 元素;
    sm += cf;
}
elevated[0] = sm; // elevated 的第0个放的就是补全成齐次坐标的分量维度, 让整体坐标和为0

其中的 (std::sqrt((double) (i + 1) * (i + 2))) (i.e. $\sqrt{(i+1)(i+2)}, i\in{0,1,\dots,\text{pos_dim}-1}$) 部分比较奇怪,尝试推导:

假定对某层,其 sigmas_list 中的值的倒数为 $s$,假定当前 pos_dim=4 (4维输入):

  • i=4:
    • $\text{cf}\leftarrow\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack$
    • $\text{elevated}\lbrack 4\rbrack \leftarrow -\frac{4s}{\sqrt{4\times 5}} \text{pos} \lbrack 3 \rbrack$
    • $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack$
  • i=3:
    • $\text{cf}\leftarrow\frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack$
    • $\text{elevated}\lbrack 3\rbrack \leftarrow = \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack -\frac{3s}{\sqrt{3\times 4}} \text{pos} \lbrack 2 \rbrack$
    • $\text{sm}= \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack$
  • i=2:
    • $\text{cf}\leftarrow\frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack$
    • $\text{elevated}\lbrack 2\rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack-\frac{2s}{\sqrt{2\times 3}} \text{pos} \lbrack 1 \rbrack$
    • $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack$
  • i=1:
    • $\text{cf}\leftarrow\frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$
    • $\text{elevated}\lbrack 1\rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack -\frac{s}{\sqrt{1\times2}} \text{pos} \lbrack 0 \rbrack$
    • $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack + \frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$
  • 最后:
    • $\text{elevated} \lbrack 0 \rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack + \frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$

验证 $\text{elevated}$ 之和为0,没有问题:

$\left( 4\times \frac{s}{\sqrt{4\times5}} - \frac{4s}{\sqrt{4\times5}} \right) \text{pos}\lbrack 3\rbrack + \left( 3 \times \frac{s}{\sqrt{3\times 4}} - \frac{3s}{\sqrt{3\times 4}} \right) \text{pos} \lbrack 2 \rbrack + \left( 2\times \frac{s}{\sqrt{2\times3}} - \frac{2s}{\sqrt{2\times 3}} \right) \text{pos} \lbrack 1 \rbrack + \left( 1\times \frac{s}{\sqrt{1\times2}} - \frac{s}{\sqrt{1\times2}} \right) \text{pos} \lbrack 0 \rbrack = 0$

猜测:这样操作主要是为了最后补全的分量也仍然在和其他几个坐标相同的量级上
  • 首先:补全齐次坐标其实没什么限制,只要能从原始的 $d$ 维坐标出发,定义任何一种 可逆的 连续的 到 $(d+1)$ 维齐次坐标的映射都可以。那么既然如此,我们可以挑选出满足我们要求的、性质比较优良的补全方式。
  • 然后:最简单暴力的补全方式 $\text{elevated} \lbrack 0 \rbrack = 1-\sum_{k=1}^{d} \text{elevated} \lbrack k \rbrack$ 的主要问题是补全这个维度的量级可以变得很大,比如当我们控制其他维度坐标在 $\lbrack 0,1 \rbrack$之间的话,他们是有可能全部取到 1 的,也就是说 $\text{elevated} \lbrack 0 \rbrack$ 的量级将是 $\lbrack -d+1,1 \rbrack$,这其实是不希望看到的;我们还是比较希望不同维度的量级都在一个比较接近的程度的
  • 按照上述方式定义后,假定原始 输入 pos 各个分量的取值范围均在 [0,1] 之间,那么 $\text{elevated}$ 各个分量的量级:
i最小值最大值
00$\left( \frac{1}{\sqrt{4\times 5}} + \frac{1}{\sqrt{3\times 4}} + \frac{1}{\sqrt{2\times 3}} + \frac{1}{\sqrt{1\times 2}} \right) s \approx 1.628s$
通项形如 $\sqrt{\frac{n+1}{n}}-\sqrt{\frac{n}{n+1}}$,无法相互抵消;$n\rightarrow \inf$ 时和极限为无穷;
和值随项数增大变化缓慢,项数=1~20,和值为 (0.7, 1.1, 1.4, 1.6, 1.8, 1.9, 2.0, 2.2, …, 3.1)
1$-\frac{\sqrt{1}}{\sqrt{2}}s\approx -0.707s$$\left(\frac{1}{\sqrt{4\times5}} + \frac{1}{\sqrt{3\times4}} + \frac{1}{\sqrt{2\times3}}\right)s \approx 0.921s$
和值随项数增大变化缓慢,项数=1~19,和值为 (0.4, 0.6, 0.9, 1.1, 1.2, 1.3, 1.5, 1.6, …, 2.3)
2$-\frac{\sqrt{2}}{\sqrt{3}}s \approx -0.816s$$\left(\frac{1}{\sqrt{4\times5}} + \frac{1}{\sqrt{3\times4}}\right)s \approx 0.512s$
和值随项数增大变化缓慢,项数=1~18,和值为 (0.2, 0.5, 0.6, 0.8, 0.9, 1.1, 1.2, 1.3, …, 1.8)
3$-\frac{\sqrt{3}}{\sqrt{4}}s\approx -0.866s$$\left(\frac{1}{\sqrt{4\times5}}\right)s \approx 0.224 s$
和值随项数增大变化缓慢,项数=1~17,和值为 (0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, …, 1.5)
4$-\frac{\sqrt{4}}{\sqrt{5}}s \approx -0.894 s$0
假想大值n$-\frac{\sqrt{n}}{\sqrt{n+1}}s$
$n\rightarrow \inf$ 时,趋近于 $-1\cdot s$
0

可以看到,对于所有分量:

  • 变化幅度最小发生在升维变量的最后一个维度,最小幅度为随d越大越接近1的某个值
  • 变化幅度最大发生在升维变量德尔第0个维度也就是补全的维度,最大幅度随维度d提升上升缓慢,如d=20时也只有3.1,d=200时只有 5.3,d=2000 时只有 7.6

另外:

  • 由于 Permutohedral 的 lattice 操作全部发生在补全成 (d+1) 维齐次坐标的体系下,和原始的坐标并没有太好的直接对应关系,因此在多个分辨率中无法加入 Dense 层,也不存在邻居的概念。

double_backward 几个函数的区别

左侧,double_backward_from_positions_gpu_1,事实上就是从 dL_ddLdx 计算 dL_dparams;

右侧,double_backward_from_positions_gpu_2,事实上就是从 dL_ddLdx 计算 dL_ddLdy

没有用上的 double_backward_from_positions_gpu 函数就是两个都计算

https://longtimenohack.com/posts/paper_reading/2023cvpr_rosu_permutosdf/image-20230912094658982.png

考虑现有实现能否进一步加快

  • [-X-] fwd 时额外保存 rank, rem0, elevated 等中间变量,直接用于 bwd / bwdbwd 中

    • 经测试是行不通的:因为显存写入的速度是低于计算的速度的,这样的操作会使得 fwd 速度变慢为原来的10倍,从 12ms 左右飙升至 120 ms
  • 将排序从 d*(d+1)/2 步 缩减为 O(dlogd)

    • 这篇笔记 中,我们知道,计算 permutohedral lattice 插值过程 是一个 O(d^2) 的耗时 。

      • 其中一步,是需要计算 $\vec{\Delta}=\vec{x}-\vec{y}_0$ 的排序索引,这一步目前的实现是 O(d^2) 的,可以考虑改为 O(dlogd) 的形式
      • 另外一步,是计算 (d+1) 个 余-k 点的 (d+1) 维坐标,只能是 O(d^2) 的耗时,无法优化;而且这一步时耗时的大头,质心坐标的计算、哈希计算、参数值读取都是在每一个循环中完成的。

代码中 rank 的计算事实上就是 $\vec{\Delta}$ 向量从大到小排序的排名,或者说逆排序索引(排序后的向量使用这个索引torch.gather 可以恢复排序前的向量)

_, rank, rem0, elevated = _backend.permuto_enc_fwd(meta, positions, lattice_values, None, None, None, None, False, True)
delta = elevated - rem0
_rank = torch.argsort(torch.argsort(delta, descending=True))

print(torch.equal(rank, _rank))

delta_sorted = delta.sort(descending=True, dim=-1).values
_delta = torch.gather(delta_sorted, -1, rank.long())
print(torch.equal(delta, _delta))
rank
tensor([[[3, 5, 4, 7, 0, 6, 1, 2]]], device='cuda:0', dtype=torch.int32)
_rank
tensor([[[3, 5, 4, 7, 0, 6, 1, 2]]], device='cuda:0')

_delta
tensor([[[ 0.9514, -2.4141, -0.4438, -3.9063,  3.7429, -3.8749,  3.6033,
           2.3415]]], device='cuda:0')
delta
tensor([[[ 0.9514, -2.4141, -0.4438, -3.9063,  3.7429, -3.8749,  3.6033,
           2.3415]]], device='cuda:0')