MVANet——小范围内捕捉高分辨率细节而在大范围内不损失精度的强大的背景消除模型

一、概述

前景提取(背景去除)是现代计算机视觉的关键挑战之一,在各种应用中的重要性与日俱增。在图像编辑和视频制作中有效地去除背景不仅能提高美学价值,还能提高工作流程的效率。在要求精确度的领域,如医学图像分析和自动驾驶技术中的物体识别,背景去除也发挥着重要作用。主要的挑战是在高分辨率图像中捕捉小区域的精细细节,同时保持大区域的精确度。迄今为止,还没有一种方法能将细节再现与整体精度相结合。然而,一种名为 MVANet 的新方法为这一挑战提供了创新的解决方案。

MVANet 采用的独特方法受到人类视觉的启发。正如人类从多个角度观察物体一样,MVANet 也从多个角度分析物体。这种方法可以在不丢失细节的情况下提高整体精度。此外,多视角的整合还可实现远距离视觉交互,这是传统方法难以实现的。

市场营销、娱乐、医疗保健和安全等各行各业对背景消除技术的需求与日俱增。在网上购物中,它可使产品的前景更加突出,从而提高购买意愿。它对于使用虚拟背景的视频会议应用以及视频制作中绿屏的替代技术也很重要。随着所有这些应用成为焦点,前景提取性能的提高将对整个行业产生重大影响。

这种新方法已经证明了它的有效性。特别是在 DIS-5K 数据集上,它在精度和速度上都优于目前的 SOTA;MVANet 有潜力成为前景提取任务的新标准,并有望在未来获得更广泛的应用。

二、算法架构

图 1:MVANet 概述。

MVANet 的整体结构与 UNet 类似,如图 1 所示。编码器使用一个远景(G)和一个近景(Lm)作为输入,远景和近景由 M(本文中为 M=4)不重叠的局部斑块组成。

G 和 Lm 构成一个多视角补丁序列,分批输入特征提取器,生成多级特征图 Ei(i=1,2,3,4,5)。每个 Ei 包含远景和近景的表示。最高级别的特征图 E5 沿批次维度被分成两组不同的全局和局部特征,并被输入多视图完成定位模块(MCLM,图 2-a)。2-a),并将其输入 MCLM(MCLM,图 2-a)。

该解码器类似于 FPN(Lin et.al, 2017)架构,但在每个解码阶段都插入了一个即时多视图完成细化模块(MCRM,图 2-b)。每个阶段的输出用于重建 SDO 地图(只有前景的地图)和计算损失。图 1 的右下方显示了多视角整合。局部特征合并后输入到 Conv Head,以便与全局特征进行细化和串联。

图 2:MCLM 和 MCRM 架构。

学习的损失函数

如图 1 所示,解码器每一层的输出和最终预测都加入了监督。

具体来说,前者由三个部分组成:ll、lg 和 la,分别代表细化模块中的组合局部表征、全局表征和标记注意图。每个侧输出都需要一个单独的卷积层来获得单通道预测。后者用 lf 表示。这些组件结合使用了二元交叉熵(BCE)损失和加权 IoU 损失,这在大多数分割任务中都很常用。

最终的学习损失函数如下式所示。本文设置 λg=0.3,λh=0.3。

三、试验

数据集和评估指标

数据集

本文使用 DIS5K 基准数据集进行实验。该数据集包含 225 个类别的 5,470 张高分辨率图像(2K、4K 或更大尺寸)。数据集分为三个部分

  • DIS-TR:3 000 幅训练图像。
  • DIS-VD:470 幅验证图像。
  • DIS-TE:2,000 张测试图像,分为四个子集(DIS-TE1、2、3 和 4),每个子集有 500 张图像,几何复杂度依次增加

DIS5K 数据集因其高分辨率图像、详细的结构和出色的注释质量,比其他分割数据集更具挑战性,需要先进的模型来捕捉复杂的细节。

评估指标

采用以下指标评估绩效

  • 最大 F 值:测量准确性和重复性的最大得分,β² 设置为 0.3。
  • 加权 F 值:与 F 值类似,但已加权。
  • 结构相似性测量(Sm):评估预测值与真实值之间的结构相似性,同时考虑领域和对象识别。
  • 电子测量:用于评估像素与图像之间的匹配程度。
  • 平均绝对误差 (MAE):计算预测地图与真实值之间的平均误差。

这些指标有助于了解该模型在识别和分割 DIS5K 数据集中具有复杂结构的物体方面的性能。

实验结果

定量评估

表 1 将拟议的 MVANet 与其他 11 个著名的相关模型(F3Net、GCPANet、PFNet、BSANet、ISDNet、IFA、IS-Net、FPDIS、UDUN、PGNet 和 InSPyReNet)进行了比较。为进行公平比较,输入大小标准化为 1024 × 1024。结果表明,在所有数据集的不同指数上,MVANet 都明显优于其他模型。特别是在 F、Em、Sm 和 MAE 方面,MVANet 分别比 InSPyReNet 高出 2.5%、2.1%、0.5% 和 0.4%。

此外,还评估了 InSPyReNet 和 MVANet 的推理速度。两者都在英伟达 RTX 3090 GPU 上进行了测试。由于采用了简单的单流设计,MVANet 的推理速度达到了 4.6 FPS,而 InSPyReNet 为 2.2 FPS。

表 1.DIS5K 的定量评估。

定性评估

为了直观地展示所提方法的高预测准确性,我们将测试集中所选图像的输出结果可视化。如图 3 所示,即使在复杂的场景中,建议的方法也能准确定位物体并捕捉边缘细节。特别是,建议的方法能够准确区分椅子的完整分割和每个网格的内部,而其他方法则会受到明显的黄色纱布和阴影的干扰(见下行)。

图 3.DIS5K 中的定性评估。

四、代码测试

下载源码

git clone https://github.com/qianyu-dlut/MVANet.git
cd MVANet

环境配置

conda create -n mvanet python==3.8
conda activate mvanet
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U openmim
mim install mmcv-full==1.3.17
pip install -r requirements.txt

测试代码:

import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import no_grad
from torchvision import transforms
from model.MVANet import inf_MVANet
import ttach as tta

# 参数设置
model_path = 'saved_model/Model_80.pth'  # 修改为你的模型路径
image_directory = 'data/images'  # 修改为你的图像目录路径
output_directory = 'datamasks'  # 预测结果保存路径

# 图像变换
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 定义 TTA 变换
tta_transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.Scale(scales=[0.75, 1, 1.25], interpolation='bilinear', align_corners=False),
])


def load_model(model_path):
    net = inf_MVANet().cuda()
    # 加载模型参数
    pretrained_dict = torch.load(model_path, map_location='cuda')
    model_dict = net.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    net.eval()
    return net


def predict_image(net, img_path):
    # 加载图像并进行预处理
    img = Image.open(img_path).convert('RGB')
    w_, h_ = img.size
    img_resize = img.resize([1024, 1024], Image.BILINEAR)
    img_var = img_transform(img_resize).unsqueeze(0).cuda()

    # 预测结果
    masks = []
    with no_grad():
        for transformer in tta_transforms:
            img_transformed = transformer.augment_image(img_var)
            model_output = net(img_transformed)
            deaug_mask = transformer.deaugment_mask(model_output)
            masks.append(deaug_mask)

        prediction = torch.mean(torch.stack(masks, dim=0), dim=0).sigmoid()

    # 将预测结果转换为图像
    prediction_img = transforms.ToPILImage()(prediction.data.squeeze(0).cpu())
    prediction_img = prediction_img.resize((w_, h_), Image.BILINEAR)

    return img, prediction_img


def process_directory(net, image_dir, output_dir):
    # 创建保存目录
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 遍历目录中的所有图像
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    for idx, image_file in enumerate(image_files):
        img_path = os.path.join(image_dir, image_file)
        print(f"Processing {idx + 1}/{len(image_files)}: {img_path}")

        # 预测并显示结果
        original_img, prediction_img = predict_image(net, img_path)

        # 保存预测结果
        prediction_path = os.path.join(output_dir, f"prediction_{image_file}")
        prediction_img.save(prediction_path)

        # 显示图像
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(original_img)
        axs[0].set_title("Original Image")
        axs[0].axis('off')

        axs[1].imshow(prediction_img, cmap='gray')
        axs[1].set_title("Predicted Mask")
        axs[1].axis('off')

        plt.show()


if __name__ == '__main__':
    # 加载模型与处理图像目录
    model = load_model(model_path)
    process_directory(model, image_directory, output_directory)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

五、总结

在这篇评论文章中,我们将高精度前景提取(背景去除)建模为一个多视角物体识别问题,提供了一个高效、简单的多视角聚合网络。这样做的目的是更好地平衡模型设计、准确性和推理速度。
为解决多视图的目标对准问题,提出了多视图完成定位模块,以联合计算目标的共同关注区域。此外,提出的多视图完成细化模块被嵌入到每个解码器块中,以充分整合互补的本地信息,减少单视图补丁中语义的缺失。这样,只需一个卷积层就能实现最终的视图细化。
广泛的实验表明,所提出的方法性能良好。MVANet 有潜力成为前景提取任务的新标准,并有望在未来得到更广泛的应用。

源码下载地址:https://download.csdn.net/download/matt45m/90335556


http://www.niftyadmin.cn/n/5841053.html

相关文章

【力扣】438.找到字符串中所有字母异位词

AC截图 题目 思路 我一开始是打算将窗口内的s子字符串和p字符串都重新排序&#xff0c;然后判断是否相等&#xff0c;再之后进行窗口滑动。不过缺点是会超时。 class Solution { public:vector<int> findAnagrams(string s, string p) {vector<int> vec;if(s.siz…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.12 连续数组:为什么contiguous这么重要?

2.12 连续数组&#xff1a;为什么contiguous这么重要&#xff1f; 目录 #mermaid-svg-wxhozKbHdFIldAkj {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-wxhozKbHdFIldAkj .error-icon{fill:#552222;}#mermaid-svg-…

【C++语言】卡码网语言基础课系列----13. 链表的基础操作I

文章目录 背景知识链表1、虚拟头节点(dummyNode)2、定义链表节点3、链表的插入 练习题目链表的基础操作I具体代码实现 小白寄语诗词共勉 背景知识 链表 与数组不同&#xff0c;链表的元素存储可以是连续的&#xff0c;也可以是不连续的&#xff0c;每个数据除了存储本身的信息…

ESP32 Wroom (无串口芯片的简版C3) 烧录

烧录前按住boot, 然后按下reset&#xff08;EN&#xff09;, 松开手烧录完按下reset (EN), 才进入running状态

【漫话机器学习系列】074.异方差(Heteroscedasticity)

异方差&#xff08;Heteroscedasticity&#xff09; 异方差&#xff08;Heteroscedasticity&#xff09;是指在回归分析中&#xff0c;误差项的方差不恒定的现象。通常&#xff0c;我们假设回归模型中的误差项具有恒定方差&#xff08;即同方差性&#xff0c;homoscedasticity…

Windows编译FreeRDP步骤

1. **安装必要工具** powershell # 安装 Visual Studio 2022 (勾选"C桌面开发"组件) # 安装 CMake: https://cmake.org/download/ # 安装 Git: https://git-scm.com/ 2. **安装依赖项** powershell # 使用vcpkg包管理 git clone https://github.com/Microsoft/vcpk…

基于python的Kimi AI 聊天应用

因为这几天deepseek有点状况&#xff0c;导致apikey一直生成不了&#xff0c;用kimi练练手。这是一个基于 Moonshot AI 的 Kimi 接口开发的聊天应用程序&#xff0c;使用 Python Tkinter 构建图形界面。 项目结构 项目由三个主要Python文件组成&#xff1a; 1. main_kimi.py…

K个不同子数组的数目--滑动窗口--字节--亚马逊

Stay hungry, stay foolish 题目描述 给定一个正整数数组 nums和一个整数 k&#xff0c;返回 nums 中 「好子数组」 的数目。 如果 nums 的某个子数组中不同整数的个数恰好为 k&#xff0c;则称 nums 的这个连续、不一定不同的子数组为 「好子数组 」。 例如&#xff0c;[1,2,…