PermutoSDF: Fast Multi-View Reconstruction with Implicit Surfaces using Permutohedral Lattices
编者语:主要是把 NGP 的思想推广到了一种对高维输入更友好的格点 (lattice) 上
背景:permutohedral lattice (排列多面体点阵)
背景数学知识详见我的 这篇笔记
Motivation
续上述笔记之言,Instant-NGP 其实就是采用 $\mathbb{Z}^d$ 的 regular grid 的 lattice,而这种 lattice 的 Delaunay cell 的顶点个数是 $2^d$,随着维度 d 提升需要计算的顶点个数(i.e. 权重的个数)指数上升,因此其计算效率对高维较为不友好。
而本篇的思想就是使用 Permutohedral lattice $A^\ast_d$ 来代替 $\mathbb{Z}^d$,因为 $A^\ast_d$ 的 Delaunay cell 的顶点个数是 $(d+1)$ ,是所有 lattice 中 Delaunay cell 顶点个数随维度 d 上升最慢的,因此针对高维输入的计算效率将大大提高。
permutohedral_encoding 底层代码解读
维度补全 & scale 定义
EncodingFixedParams()
结构体创建时传入的参数 sigmas_list
即为 [self@PermutoEncoding].scale_per_level
参数,该参数是 PermutoEncoding
传入的,该参数的一个典型例子如 scale_list
:
coarsest_scale=1.0
finest_scale=0.0001
scale_list=np.geomspace(coarsest_scale, finest_scale, num=nr_levels)
这个正常的逐层的空间缩放参数在 EncodingFixedParams::compute_scale_factor_tensor
被计算转存到 m_scale_factor
中,计算方式如下:
torch::Tensor compute_scale_factor_tensor(const std::vector<float> sigmas_list, const int pos_dim){
int nr_resolutions=sigmas_list.size();
torch::Tensor scale_factor_tensor=torch::zeros({ nr_resolutions, pos_dim }, torch::dtype(torch::kFloat32).device(torch::kCUDA, 0) );
double invStdDev = 1.0;
for(int res_idx=0; res_idx<nr_resolutions; res_idx++){
for (int i = 0; i < pos_dim; i++) {
// 1 除以 (i+1)(i+2) 的和开根号, i 是当前第几个维度 (从0开始索引)
scale_factor_tensor[res_idx][i] = 1.0 / (std::sqrt((double) (i + 1) * (i + 2))) * invStdDev;
// 配置时的从 1.0 到 0.0001 几何平均的 scale_list -> 传入 python module 的 scale_per_level -> 在这里除到了分母上
scale_factor_tensor[res_idx][i]=scale_factor_tensor[res_idx][i]/ sigmas_list[res_idx];
}
}
return scale_factor_tensor;
}
而 m_scale_factor
的利用方式如下(在 forward_gpu
, backward_gpu
中均如此计算):
// 下面这段把 d 维向量升到 (d+1) 维的齐次坐标 (坐标和为0)
// random_shift 主要是不同层、不同维度分量施加不同的随机位置偏移,最大可能避免 hash collision;
// 受 python module 的 appply_random_shift_per_level: bool 控制,如果为 False 这里就是零tensor
// scale 的定义 和 elevated 计算做了特殊处理; 最终能够使得:
// 1. elevated 之和仍然可以保证为0;
// 2. elevated 每个分量的变化幅度接近,即使最大维度非常大,变化幅度的量级也仍在个位数;
float sm = 0;
#pragma unroll
for (int i = pos_dim; i > 0; i--) {
// float cf = (pos[i-1] +random_shift_constant[level*pos_dim + i-1] ) * scale_factor_constant[level*pos_dim + i-1];
float cf = (pos[i-1] +random_shift_monolithic[level][i-1] ) * scale_factor[level][i-1];
elevated[i] = sm - i * cf; // 第i个 elevated 放置的是 第i-1 个 pos 元素;
sm += cf;
}
elevated[0] = sm; // elevated 的第0个放的就是补全成齐次坐标的分量维度, 让整体坐标和为0
其中的 (std::sqrt((double) (i + 1) * (i + 2)))
(i.e. $\sqrt{(i+1)(i+2)}, i\in{0,1,\dots,\text{pos_dim}-1}$) 部分比较奇怪,尝试推导:
假定对某层,其 sigmas_list
中的值的倒数为 $s$,假定当前 pos_dim=4
(4维输入):
- i=4:
- $\text{cf}\leftarrow\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack$
- $\text{elevated}\lbrack 4\rbrack \leftarrow -\frac{4s}{\sqrt{4\times 5}} \text{pos} \lbrack 3 \rbrack$
- $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack$
- i=3:
- $\text{cf}\leftarrow\frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack$
- $\text{elevated}\lbrack 3\rbrack \leftarrow = \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack -\frac{3s}{\sqrt{3\times 4}} \text{pos} \lbrack 2 \rbrack$
- $\text{sm}= \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack$
- i=2:
- $\text{cf}\leftarrow\frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack$
- $\text{elevated}\lbrack 2\rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack-\frac{2s}{\sqrt{2\times 3}} \text{pos} \lbrack 1 \rbrack$
- $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack$
- i=1:
- $\text{cf}\leftarrow\frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$
- $\text{elevated}\lbrack 1\rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack -\frac{s}{\sqrt{1\times2}} \text{pos} \lbrack 0 \rbrack$
- $\text{sm}=\frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack + \frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$
- 最后:
- $\text{elevated} \lbrack 0 \rbrack \leftarrow \frac{s}{\sqrt{4\times5}} \text{pos}\lbrack 3\rbrack + \frac{s}{\sqrt{3\times4}} \text{pos}\lbrack 2\rbrack + \frac{s}{\sqrt{2\times3}} \text{pos}\lbrack 1\rbrack + \frac{s}{\sqrt{1\times2}} \text{pos}\lbrack 0\rbrack$
验证 $\text{elevated}$ 之和为0,没有问题:
$\left( 4\times \frac{s}{\sqrt{4\times5}} - \frac{4s}{\sqrt{4\times5}} \right) \text{pos}\lbrack 3\rbrack + \left( 3 \times \frac{s}{\sqrt{3\times 4}} - \frac{3s}{\sqrt{3\times 4}} \right) \text{pos} \lbrack 2 \rbrack + \left( 2\times \frac{s}{\sqrt{2\times3}} - \frac{2s}{\sqrt{2\times 3}} \right) \text{pos} \lbrack 1 \rbrack + \left( 1\times \frac{s}{\sqrt{1\times2}} - \frac{s}{\sqrt{1\times2}} \right) \text{pos} \lbrack 0 \rbrack = 0$
- 首先:补全齐次坐标其实没什么限制,只要能从原始的 $d$ 维坐标出发,定义任何一种 可逆的 连续的 到 $(d+1)$ 维齐次坐标的映射都可以。那么既然如此,我们可以挑选出满足我们要求的、性质比较优良的补全方式。
- 然后:最简单暴力的补全方式 $\text{elevated} \lbrack 0 \rbrack = 1-\sum_{k=1}^{d} \text{elevated} \lbrack k \rbrack$ 的主要问题是补全这个维度的量级可以变得很大,比如当我们控制其他维度坐标在 $\lbrack 0,1 \rbrack$之间的话,他们是有可能全部取到 1 的,也就是说 $\text{elevated} \lbrack 0 \rbrack$ 的量级将是 $\lbrack -d+1,1 \rbrack$,这其实是不希望看到的;我们还是比较希望不同维度的量级都在一个比较接近的程度的
- 按照上述方式定义后,假定原始 输入 pos 各个分量的取值范围均在 [0,1] 之间,那么 $\text{elevated}$ 各个分量的量级:
i | 最小值 | 最大值 |
---|---|---|
0 | 0 | $\left( \frac{1}{\sqrt{4\times 5}} + \frac{1}{\sqrt{3\times 4}} + \frac{1}{\sqrt{2\times 3}} + \frac{1}{\sqrt{1\times 2}} \right) s \approx 1.628s$ 通项形如 $\sqrt{\frac{n+1}{n}}-\sqrt{\frac{n}{n+1}}$,无法相互抵消;$n\rightarrow \inf$ 时和极限为无穷; 和值随项数增大变化缓慢,项数=1~20,和值为 (0.7, 1.1, 1.4, 1.6, 1.8, 1.9, 2.0, 2.2, …, 3.1) |
1 | $-\frac{\sqrt{1}}{\sqrt{2}}s\approx -0.707s$ | $\left(\frac{1}{\sqrt{4\times5}} + \frac{1}{\sqrt{3\times4}} + \frac{1}{\sqrt{2\times3}}\right)s \approx 0.921s$ 和值随项数增大变化缓慢,项数=1~19,和值为 (0.4, 0.6, 0.9, 1.1, 1.2, 1.3, 1.5, 1.6, …, 2.3) |
2 | $-\frac{\sqrt{2}}{\sqrt{3}}s \approx -0.816s$ | $\left(\frac{1}{\sqrt{4\times5}} + \frac{1}{\sqrt{3\times4}}\right)s \approx 0.512s$ 和值随项数增大变化缓慢,项数=1~18,和值为 (0.2, 0.5, 0.6, 0.8, 0.9, 1.1, 1.2, 1.3, …, 1.8) |
3 | $-\frac{\sqrt{3}}{\sqrt{4}}s\approx -0.866s$ | $\left(\frac{1}{\sqrt{4\times5}}\right)s \approx 0.224 s$ 和值随项数增大变化缓慢,项数=1~17,和值为 (0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, …, 1.5) |
4 | $-\frac{\sqrt{4}}{\sqrt{5}}s \approx -0.894 s$ | 0 |
假想大值n | $-\frac{\sqrt{n}}{\sqrt{n+1}}s$ $n\rightarrow \inf$ 时,趋近于 $-1\cdot s$ | 0 |
可以看到,对于所有分量:
- 变化幅度最小发生在升维变量的最后一个维度,最小幅度为随d越大越接近1的某个值
- 变化幅度最大发生在升维变量德尔第0个维度也就是补全的维度,最大幅度随维度d提升上升缓慢,如d=20时也只有3.1,d=200时只有 5.3,d=2000 时只有 7.6
另外:
- 由于 Permutohedral 的 lattice 操作全部发生在补全成 (d+1) 维齐次坐标的体系下,和原始的坐标并没有太好的直接对应关系,因此在多个分辨率中无法加入 Dense 层,也不存在邻居的概念。
double_backward 几个函数的区别
左侧,double_backward_from_positions_gpu_1
,事实上就是从 dL_ddLdx 计算 dL_dparams;
右侧,double_backward_from_positions_gpu_2
,事实上就是从 dL_ddLdx 计算 dL_ddLdy
没有用上的 double_backward_from_positions_gpu
函数就是两个都计算
考虑现有实现能否进一步加快
[-X-] fwd 时额外保存 rank, rem0, elevated 等中间变量,直接用于 bwd / bwdbwd 中
- 经测试是行不通的:因为显存写入的速度是低于计算的速度的,这样的操作会使得 fwd 速度变慢为原来的10倍,从 12ms 左右飙升至 120 ms
将排序从 d*(d+1)/2 步 缩减为 O(dlogd)
从 这篇笔记 中,我们知道,计算 permutohedral lattice 插值过程 是一个 O(d^2) 的耗时 。
- 其中一步,是需要计算 $\vec{\Delta}=\vec{x}-\vec{y}_0$ 的排序索引,这一步目前的实现是 O(d^2) 的,可以考虑改为 O(dlogd) 的形式
- 另外一步,是计算 (d+1) 个 余-k 点的 (d+1) 维坐标,只能是 O(d^2) 的耗时,无法优化;而且这一步时耗时的大头,质心坐标的计算、哈希计算、参数值读取都是在每一个循环中完成的。
代码中 rank 的计算事实上就是 $\vec{\Delta}$ 向量从大到小排序的排名,或者说逆排序索引(排序后的向量使用这个索引torch.gather 可以恢复排序前的向量)
_, rank, rem0, elevated = _backend.permuto_enc_fwd(meta, positions, lattice_values, None, None, None, None, False, True)
delta = elevated - rem0
_rank = torch.argsort(torch.argsort(delta, descending=True))
print(torch.equal(rank, _rank))
delta_sorted = delta.sort(descending=True, dim=-1).values
_delta = torch.gather(delta_sorted, -1, rank.long())
print(torch.equal(delta, _delta))
rank
tensor([[[3, 5, 4, 7, 0, 6, 1, 2]]], device='cuda:0', dtype=torch.int32)
_rank
tensor([[[3, 5, 4, 7, 0, 6, 1, 2]]], device='cuda:0')
_delta
tensor([[[ 0.9514, -2.4141, -0.4438, -3.9063, 3.7429, -3.8749, 3.6033,
2.3415]]], device='cuda:0')
delta
tensor([[[ 0.9514, -2.4141, -0.4438, -3.9063, 3.7429, -3.8749, 3.6033,
2.3415]]], device='cuda:0')