Point Transformer is a classic work that eventually became a standard backbone for point-cloud learning.
The authors introduced a vector attention mechanism tailored to point clouds, which are unstructured data.
It is also a form of local attention: the Query comes from the query point itself, while the Keys and Values come from its surrounding neighbors.
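For reference, the vector attention in the Point Transformer paper reads

$$ y_i = \sum_{x_j \in \mathcal{X}(i)} \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\big) \odot \big(\alpha(x_j) + \delta\big), \qquad \delta = \theta(p_i - p_j), $$

where $\varphi$, $\psi$, $\alpha$ are pointwise linear projections, $\gamma$ is a small MLP that produces a per-channel attention vector, $\rho$ is a softmax over the neighborhood $\mathcal{X}(i)$, and $\odot$ is the element-wise product. The demo below implements both a conventional scalar attention and this vector attention for comparison.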
import torch
import torch.nn as nn
import torch.nn.functional as F

N = 4        # number of query points
K = 5        # number of neighbors per point
C_in = 64    # input feature dimension
C_out = 128  # output feature dimension

# query point features (N, C_in)
x_i = torch.randn(N, C_in)
# neighbor point features (N, K, C_in)
x_j = torch.randn(N, K, C_in)
# relative position encoding (N, K, C_out)
delta = torch.randn(N, K, C_out)
class ScalarAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.q = nn.Linear(in_dim, out_dim)
        self.k = nn.Linear(in_dim, out_dim)
        self.v = nn.Linear(in_dim, out_dim)

    def forward(self, x_i, x_j, delta):
        # 1. projections
        q_i = self.q(x_i)  # (N, C)
        k_j = self.k(x_j)  # (N, K, C)
        v_j = self.v(x_j)  # (N, K, C)
        # 2. attention weights: one scalar per neighbor (scaled dot product)
        attn = (k_j @ q_i.unsqueeze(-1)).squeeze(-1) / q_i.shape[-1] ** 0.5  # (N, K)
        attn = attn + delta.mean(dim=-1)  # fold the position encoding into a scalar bias (simplified)
        attn = F.softmax(attn, dim=1)     # (N, K), normalized over the neighbors
        # 3. weighted sum over the neighborhood
        out = torch.sum(attn.unsqueeze(-1) * v_j, dim=1)  # (N, C)
        return out
class VectorAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.phi = nn.Linear(in_dim, out_dim)    # query projection φ
        self.psi = nn.Linear(in_dim, out_dim)    # key projection ψ
        self.alpha = nn.Linear(in_dim, out_dim)  # value projection α
        self.gamma = nn.Sequential(              # MLP producing the attention vectors
            nn.Linear(out_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, x_i, x_j, delta):
        phi_i = self.phi(x_i).unsqueeze(1)  # (N, 1, C)
        psi_j = self.psi(x_j)               # (N, K, C)
        v_j = self.alpha(x_j)               # (N, K, C)
        # feature difference plus position encoding (the paper's subtraction relation)
        diff = phi_i - psi_j + delta        # (N, K, C)
        # vector attention: one weight per neighbor *and* per channel
        attn = self.gamma(diff)             # (N, K, C)
        attn = F.softmax(attn, dim=1)       # (N, K, C), normalized over the neighbors
        # weighted sum with per-channel modulation; delta is also added to the
        # values, following the Point Transformer formulation
        out = torch.sum(attn * (v_j + delta), dim=1)  # (N, C)
        return out
scalar_layer = ScalarAttention(C_in, C_out)
vector_layer = VectorAttention(C_in, C_out)

out_scalar = scalar_layer(x_i, x_j, delta)
out_vector = vector_layer(x_i, x_j, delta)

print("Scalar Attention Output:", out_scalar.shape)  # (N, C_out)
print("Vector Attention Output:", out_vector.shape)  # (N, C_out)
The position encoding applies a small MLP (Linear → ReLU → Linear) to the relative 3D coordinates to produce delta.
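A minimal sketch of that encoder, wired to the knn_group helper above; the class name PositionEncoding and the hidden width (taken equal to out_dim) are assumptions, while the Linear-ReLU-Linear structure follows the text:

class PositionEncoding(nn.Module):
    # delta = theta(p_i - p_j); hidden width is an assumption
    def __init__(self, out_dim):
        super().__init__()
        self.theta = nn.Sequential(
            nn.Linear(3, out_dim),       # lift relative 3D coordinates
            nn.ReLU(inplace=True),
            nn.Linear(out_dim, out_dim)  # project to the attention feature dimension
        )

    def forward(self, rel_pos):
        # rel_pos: (N, K, 3) relative coordinates from the kNN grouping above
        return self.theta(rel_pos)       # (N, K, C_out), i.e. the delta used earlier

pos_enc = PositionEncoding(C_out)
neighbor_feats, rel_pos = knn_group(points, feats, K)
delta_real = pos_enc(rel_pos)            # (N, K, C_out), a learned replacement for the random delta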