Vision-RWKV:Efficient and Scalable Visual Perception with RWKV-Like Architectures
- 来源:https://arxiv.org/pdf/2403.02308
- 发表:ICLR 2025
- 作者单位:OpenGVLab, Shanghai AI Laboratory, The Chinese University of Hong Kong, Fudan University, Nanjing University, Tsinghua University, SenseTime Research
- 动机:
- 作为 Vision Transformer (ViT) 的低成本替代方案,旨在降低计算复杂度和内存消耗,同时保持 ViT 的优势,例如捕获长距离依赖和处理稀疏输入的能力。
- VRWKV 旨在替代 Vision Transformer (ViT),在保持相似性能和可扩展性的同时,降低计算成本和内存消耗。
- VRWKV 的设计目标是保留 RWKV 架构的核心结构和优势,并进行必要的修改,使其能够灵活应用于视觉任务,支持稀疏输入,并确保在扩展后训练过程的稳定性。
- VRWKV 尤其适用于 ViT 难以承受全局注意力的高计算开销的任务。
- 主要模块和思想:
- 整体架构:采用类似于 ViT 的块堆叠图像编码器设计,由一个 patch embedding 层和 L 个相同的 VRWKV 编码器层堆叠而成。每个编码器层包含一个空间混合 (spatial-mix) 模块和一个通道混合 (channel-mix) 模块。
- 空间混合模块:作为注意力机制,用于执行线性复杂度的全局注意力计算,包含 Q-Shift 操作,双向 WKV (Bi-WKV)模块,以及层归一化。
- 通道混合模块:作为前馈网络 (FFN),用于在通道维度执行特征融合。
- Q-Shift (四向移位):一种为视觉任务量身定制的 token 移位方法,允许所有 tokens 与其相邻的 tokens 进行移位和线性插值。Q-Shift 使得不同通道的注意力机制获得关注相邻标记的先验,而无需引入许多额外的 FLOP。Q-Shift 操作还增加了每个标记的感受野,从而大大增强了标记在后验层中的覆盖范围。
- 双向全局注意力:将原始的因果 RWKV 注意力机制修改为双向全局注意力机制,使模型能够实现全局注意力,并以线性计算复杂度计算全局注意力。通过修改 RWKV 注意力机制中的指数,将绝对位置偏差转换为相对偏差,从而增强模型能力并确保可扩展性和稳定性。双向注意力机制使模型能够实现全局注意力,而原始 RWKV 注意力在内部具有因果掩码。
- 线性复杂度双向注意力:双向注意力可以等价地表示为求和形式和 RNN 形式。
- 有效感受野 (ERF):Q-Shift 方法扩展了感受野的核心范围,增强了全局注意力的归纳偏置。
- 图像到 Tokens 的转换及 Q-Shift 操作:
- 将$H×W×3$ 的图像转换为 $HW/p²$ 个 patches,其中 p 代表 patch 的大小。
- 这些 patches 经过线性投影,并加上位置嵌入(position embedding),得到形状为 $T×C$ 的图像 tokens,其中 $T = HW/p²$。
- Q-Shift 操作允许所有 tokens 与其相邻的 tokens 进行移位和线性插值,代码实现如下:
import torch
def q_shift(input, shift_pixel=1, gamma=1/4, patch_resolution=None):
"""
对输入的特征图进行通道层面的空间偏移操作。
参数:
input (torch.Tensor): 输入的特征图张量,形状为 (B, N, C)。
shift_pixel (int): 像素偏移量,默认为 1。
gamma (float): 通道分组的比例,默认为 1/4。
patch_resolution (tuple): 输入特征图的块分辨率,默认为 None。
返回:
torch.Tensor: 处理后的特征图张量,形状为 (B, N, C)。
"""
assert gamma <= 1/4, "gamma 必须小于等于 1/4" # 确保 gamma 的值不超过 1/4
B, N, C = input.shape # 获取输入张量的形状
input = input.transpose(1, 2).reshape(B, C, patch_resolution[0], patch_resolution[1]) #将特征向量转换为类似于图像的二维形式
B, C, H, W = input.shape # 获取重塑后输入张量的形状
output = torch.zeros_like(input) # 创建与输入张量形状相同的全零输出张量
# 将输入张量的不同通道部分进行不同的像素偏移,并赋值给输出张量
output[:, 0:int(C*gamma), :, shift_pixel:W] = input[:, 0:int(C*gamma), :, 0:W-shift_pixel] # 前 gamma 比例的通道向右偏移
output[:, int(C*gamma):int(C*gamma*2), :, 0:W-shift_pixel] = input[:, int(C*gamma):int(C*gamma*2), :, shift_pixel:W] # gamma 到 2*gamma 比例的通道向左偏移
output[:, int(C*gamma*2):int(C*gamma*3), shift_pixel:H, :] = input[:, int(C*gamma*2):int(C*gamma*3), 0:H-shift_pixel, :] # 2*gamma 到 3*gamma 比例的通道向下偏移
output[:, int(C*gamma*3):int(C*gamma*4), 0:H-shift_pixel, :] = input[:, int(C*gamma*3):int(C*gamma*4), shift_pixel:H, :] # 3*gamma 到 4*gamma 比例的通道向上偏移
output[:, int(C*gamma*4):, ...] = input[:, int(C*gamma*4):, ...] # 剩余通道直接复制
return output.flatten(2).transpose(1, 2) # 将输出张量展平并转置,恢复原始形状
RWKV-UNet:Improving UNet with Long-Range Cooperation for Effective Medical Image Segmentation
- 来源:https://arxiv.org/pdf/2501.08458
- 发表于:Arxiv 预印本,未见刊
- 作者单位:Zhejiang Univeristy Youtu Lab, Tencent
- 动机:
- 旨在解决医学图像分割中 CNN 和 Transformer 模型面临的局限性。
- CNN 在捕获长距离依赖关系方面存在不足,而 Transformer 则具有较高的计算复杂度。
- 为了解决这些问题,RWKV-UNet 提出了一种将 RWKV(Receptance Weighted Key Value)结构集成到 U-Net 架构中的新模型。
- RWKV-UNet 的提出受到了现有 U-Net 变体和其他方法的启发,旨在结合 RWKV 的优势,克服 CNN 和 Transformer 在医学图像分割中的局限性。
- 关键特点和组成部分:
- 整体架构:RWKV-UNet 采用 U-Net 的对称 U 型结构,包含一个编码器、一个解码器以及带有 Cross-Channel Mix (CCM) 模块的跳跃连接。
- 编码器:通过堆叠 IR(Inverted Residual)块和 IR-RWKV 块 构建。
- IR-RWKV 块:结合了 Vision RWKV 中的空间混合模块 (Spatial-Mix) 和深度可分离卷积 (DW-Conv),以结合全局和局部依赖关系,并权衡模型成本和准确性。
- IR 块:包含一个点卷积层、一个 DW-Conv 层和一个带有局部和全局残差跳跃连接的点卷积层,去除了 IR-RWKV 中的空间混合以及展开和折叠过程。
- 解码器:基于 CNN,包含一个点卷积层和一个 9×9 DW-Conv 层,后跟一个点操作层和一个上采样操作。
- Cross-Channel Mix (CCM) 模块:可以有效地提取跨多尺度编码器特征的通道信息。通过捕获沿通道的丰富全局上下文,CCM 进一步增强了长程上下文信息的提取。
- RWKV 的优势:RWKV 结合了 RNN 的线性和 Transformer 的并行处理优势。与 Transformer 相比,RWKV 计算复杂度较低,并且能够处理高分辨率图像,而无需窗口操作。
- 跳跃连接:利用跳跃连接来融合编码器和解码器的特征,从而保留更多细节,显著提高模型分割效果。
- Cross-Channel Mix (CCM) 模块:是 RWKV-UNet 架构中用于多尺度特征融合的关键组成部分。它的灵感来源于 Vision RWKV (VRWKV) 中的 Channel Mix 模块。CCM 模块旨在有效提取跨多尺度编码器特征的通道信息,通过捕获丰富的全局上下文信息,进一步增强长距离上下文信息的提取。
BSBP-RWKV:Background Suppression with Boundary Preservation for Efficient Medical Image Segmentation
- 文献来源:https://openreview.net/forum?id=ULD5RCk0oo
- 发表于:ACM-MM
- 作者单位:University of Science and Technology of China
- 动机:解决医学图像分割中背景噪声干扰和分割效率低下的问题。
- 创新点:
- DWT-PMD RWKV 块:结合了 PMD 在噪声抑制和边界细节保留方面的优势,并利用 RWKV 的高效结构,旨在抑制背景噪声干扰,同时保留病灶区域的边界细节。
- 多步 Runge-Kutta 卷积块:利用其高精度特征提取能力,摄取目标主体特征以及 DWT-PMD RWKV 块的边界输出,并将两者整合以提高形状感知分割质量。
- 形状细化损失函数:在空间域和频率域中对齐预测的病灶区域形状与真实掩码,有助于模型跳出局部最优,并在预测掩码和 ground truth 在空间域中表现出高相似性时进一步细化掩码。该损失函数通过结合空间域损失($L_{spac}$) 和频率域损失 ($L_{Lre}$),在空间域和频率域中对齐预测的病灶区域形状与真实掩码 (GT masks),从而帮助模型跳出局部最优,并在预测掩码和 ground truth 在空间域中表现出高相似性时进一步细化掩码。
- DWT-PMD RWKV 块的工作原理:通过 PMD 抑制噪声,通过 DWT 提取边界细节,并通过 RWKV 实现高效计算,从而有效地抑制背景噪声干扰,同时保留病灶区域的边界细节。
相关