PT v1

Introduction

The classic Point Transformer work, which ultimately became a widely used backbone for point cloud learning.

The Attention Operator

The authors introduce a vector attention formulation designed for unstructured data such as point clouds.

It is also a local attention: the query comes from the query point itself, while the keys and values come from its neighboring points.
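
In the paper, the vector attention is written as

    y_i = Σ_{x_j ∈ X(i)} ρ( γ( φ(x_i) − ψ(x_j) + δ ) ) ⊙ ( α(x_j) + δ ),    δ = θ(p_i − p_j)

where X(i) is the k-nearest-neighbor set of point i, φ, ψ, α are pointwise linear projections, γ is a small MLP that turns the difference into a per-channel attention vector, ρ is a softmax over the neighbors, and ⊙ is element-wise multiplication. The simplified code below follows this structure.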


import torch
import torch.nn as nn
import torch.nn.functional as F

N = 4         # number of query points
K = 5         # number of neighbors per point
C_in = 64     # input feature dimension
C_out = 128   # output feature dimension

# query point features (N, C_in)
x_i = torch.randn(N, C_in)

# neighbor point features (N, K, C_in)
x_j = torch.randn(N, K, C_in)

# relative position encoding (N, K, C_out)
delta = torch.randn(N, K, C_out)

class ScalarAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.q = nn.Linear(in_dim, out_dim)
        self.k = nn.Linear(in_dim, out_dim)
        self.v = nn.Linear(in_dim, out_dim)

    def forward(self, x_i, x_j, delta):
        # 1. Linear projections
        q_i = self.q(x_i)                      # (N, C)
        k_j = self.k(x_j)                      # (N, K, C)
        v_j = self.v(x_j)                      # (N, K, C)

        # 2. Attention weights (one scalar per neighbor)
        attn = (k_j @ q_i.unsqueeze(-1)).squeeze(-1)  # → (N, K)
        attn = attn / q_i.shape[-1] ** 0.5            # scale as in standard dot-product attention
        attn = attn + delta.mean(dim=-1)              # add position bias (simplified)
        attn = F.softmax(attn, dim=1)                 # → (N, K)

        # 3. Weighted sum over the neighborhood
        out = torch.sum(attn.unsqueeze(-1) * v_j, dim=1)  # (N, C)
        return out

class VectorAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.phi = nn.Linear(in_dim, out_dim)    # query projection φ
        self.psi = nn.Linear(in_dim, out_dim)    # key projection ψ
        self.alpha = nn.Linear(in_dim, out_dim)  # value projection α

        self.gamma = nn.Sequential(  # MLP that produces the attention vectors
            nn.Linear(out_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, x_i, x_j, delta):
        phi_i = self.phi(x_i).unsqueeze(1)             # (N, 1, C)
        psi_j = self.psi(x_j)                          # (N, K, C)
        v_j = self.alpha(x_j) + delta                  # (N, K, C), position encoding added to the value as in the paper

        # feature difference plus position encoding: φ(x_i) − ψ(x_j) + δ
        diff = phi_i - psi_j + delta                   # (N, K, C)

        # vector attention: one weight per neighbor AND per channel
        attn = self.gamma(diff)                        # (N, K, C)
        attn = F.softmax(attn, dim=1)                  # (N, K, C), normalized over the K neighbors

        # weighted sum (per-channel modulation)
        out = torch.sum(attn * v_j, dim=1)             # (N, C)
        return out

scalar_layer = ScalarAttention(C_in, C_out)
vector_layer = VectorAttention(C_in, C_out)

out_scalar = scalar_layer(x_i, x_j, delta)
out_vector = vector_layer(x_i, x_j, delta)

print("Scalar Attention Output:", out_scalar.shape)  # → (N, C_out)
print("Vector Attention Output:", out_vector.shape)  # → (N, C_out)

Position Encoding

The position encoding is produced by a small MLP of the form (Linear + ReLU + Linear), applied to the relative 3D coordinates p_i − p_j (the θ function in the formula above).
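
A minimal sketch of this encoding, assuming query coordinates of shape (N, 3) and neighbor coordinates of shape (N, K, 3); the class name PositionEncoding and the hidden width hidden_dim are illustrative choices, not fixed by the paper:

class PositionEncoding(nn.Module):
    # theta: maps relative 3D coordinates p_i − p_j to a (N, K, C_out) encoding
    def __init__(self, out_dim, hidden_dim=64):  # hidden_dim is an assumed value
        super().__init__()
        self.theta = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, p_i, p_j):
        # p_i: (N, 3) query coordinates, p_j: (N, K, 3) neighbor coordinates
        rel = p_i.unsqueeze(1) - p_j           # (N, K, 3) relative positions
        return self.theta(rel)                 # (N, K, C_out)

pos_enc = PositionEncoding(C_out)
delta_learned = pos_enc(torch.randn(N, 3), torch.randn(N, K, 3))
print("Position Encoding Output:", delta_learned.shape)   # → (N, K, C_out)

The output delta_learned can replace the random delta used in the attention examples above.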