【知识】PyTorch中不同优化器的特点和使用

news/2025/2/25 15:54:38

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

目录

1. SGD(随机梯度下降)

2. Adam(自适应矩估计)

3. AdamW

4. Adagrad

5. Adadelta

6. Adafactor

7. SparseAdam

8. Adamax

9. LBFGS

10. RMSprop

11. Rprop(弹性反向传播)

12. ASGD(平均随机梯度下降)

13. NAdam(Nesterov 加速自适应矩估计)

14. RAdam(修正 Adam)

15. Adafactor(自适应因子化梯度)

16. AMSGrad 

性能考虑

总结


torch.optim — PyTorch 2.6 documentation

1. SGD(随机梯度下降)

  • 用途:适用于小型到中型模型的基本优化。

  • 特点

    • 通过负梯度方向更新参数。

    • 可以包含动量(momentum)以加速学习并减少震荡。

    • 简单且广泛使用,但需要仔细调整学习率。

python">import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
model = nn.Linear(10, 1)  # 一个线性模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# 训练循环
for input, target in dataloader:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

2. Adam(自适应矩估计)

  • 用途:深度学习模型,尤其是需要 L2 正则化时。

  • 特点

    • 根据一阶和二阶矩估计为每个参数计算自适应学习率。

    • 支持学习率衰减的无偏估计。

    • 通常在适当设置下比 SGD 收敛更快。

python">optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False)

3. AdamW

  • 用途:迁移学习、视觉任务,以及权重衰减关键的场景。

  • 特点

    • 将权重衰减与梯度解耦,使其更有效。

    • 在某些场景下性能超过 Adam 和 SGD。

python">optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)

4. Adagrad

  • 用途:处理稀疏数据,例如自然语言处理或图像识别。

  • 特点

    • 累积之前的平方梯度以调整学习率。

    • 随着训练的进行,学习率单调递减,有助于收敛。

python">optimizer = optim.Adagrad(model.parameters(), lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10)

5. Adadelta

  • 用途:文本数据处理和图像分类。

  • 特点

    • 通过使用窗口和解决 Adagrad 的学习率递减问题。

    • 维护平方梯度和平方参数更新的运行平均值。

python">optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0.0)

6. Adafactor

  • 用途:大规模模型、大批量或长序列(例如深度学习在网页文本语料库上的应用)。

  • 特点

    • 通过使用近似值替换二阶矩来减少计算开销。

    • 专为非常大的模型设计,不会增加训练时间。

python">optimizer = optim.Adafactor(model.parameters(), lr=0.05, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False)

7. SparseAdam

  • 用途:具有稀疏梯度数据的模型,例如 NLP 中的嵌入层。

  • 特点

    • 优化稀疏张量更新;结合 SparseAdam 用于密集张量和 Adagrad 用于稀疏更新。

    • 专为具有许多零值的参数设计。

python">optimizer = optim.SparseAdam(model.sparse_parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)

8. Adamax

  • 用途:类似于 Adam,但基于无穷范数,某些问题上更稳定。

  • 特点

    • 使用过去梯度的最大值而不是平均值。

python">optimizer = optim.Adamax(model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

9. LBFGS

  • 用途:无约束优化问题、回归以及需要二阶信息的问题。

  • 特点

    • 使用梯度评估近似海森矩阵的拟牛顿方法。

    • 比 SGD 或 Adam 需要更多内存和计算资源。

python">optimizer = optim.LBFGS(model.parameters(), lr=1.0, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn=None)

# 使用 LBFGS 需要提供一个闭包(closure)来重新评价模型
def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss

optimizer.step(closure)

10. RMSprop

  • 用途:卷积神经网络和递归神经网络。

  • 特点

    • 维护平方梯度的运行平均值,并对参数更新进行归一化。

    • 解决 Adagrad 学习率单调递减的问题。

python">optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0.0, momentum=0.0, centered=False)

11. Rprop(弹性反向传播)

  • 用途:神经网络中梯度大小不重要的场景。

  • 特点

    • 仅使用梯度的符号来更新参数,根据梯度符号变化调整学习率。

python">optimizer = optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50.0))

12. ASGD(平均随机梯度下降)

  • 用途:促进某些模型的泛化。

  • 特点

    • 维护优化过程中遇到的参数的运行平均值。

python">optimizer = optim.ASGD(model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0.0)

13. NAdam(Nesterov 加速自适应矩估计)

  • 用途:结合 Nesterov 动量和 Adam。

  • 特点

    • 结合 Nesterov 加速梯度(NAG)以提供更稳定的收敛。

python">optimizer = optim.NAdam(model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, momentum_decay=0.004)

14. RAdam(修正 Adam)

  • 用途:需要自适应学习率但希望减少方差的场景。

  • 特点

    • 根据梯度方差动态调整学习率。

python">optimizer = optim.RAdam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

15. Adafactor(自适应因子化梯度)

  • 用途:大规模模型的内存高效优化。

  • 特点

    • 通过将大梯度分解为小成分来减少内存使用。

python">optimizer = optim.Adafactor(model.parameters(), lr=0.3, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True, warmup_init=False)

16. AMSGrad 

  • 用途: Adam 优化器的一种改进版本,旨在解决 Adam 在某些情况下可能不收敛的问题。它通过保留梯度的历史信息来防止学习率过早下降,从而提高训练的稳定性和收敛性。
  • 特点
    • 自适应学习率:AMSGrad 自适应地调整学习率,以便更好地训练神经网络。

    • 防止震荡:它可以防止 Adam 算法中的震荡现象,从而提高训练效果。

    • 改进收敛性:通过优化二阶动量,避免了 Adam 算法可能遭遇的收敛问题,特别适合长时间训练或解决深层网络难题。

python"># 初始化 AMSGrad 优化器,通过amsgrad参数设置
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)

性能考虑

  • 这些优化器的性能可能因硬件和问题的性质而异。PyTorch 将优化器分为以下几类:

    • For-loop:基本实现,但由于内核调用较慢。

    • Foreach:使用多张量操作以加快处理速度。

    • Fused:将步骤合并为单个内核以实现最大速度。

总结

  • 选择优化器取决于问题的复杂性、数据的稀疏性和硬件的可用性。像 Adam 或 AdamW 这样的自适应算法因其通用有效性而被广泛使用,而像 SGD 这样的简单方法在适当调整超参数时是最优的。


http://www.niftyadmin.cn/n/5865671.html

相关文章

在PyTorch使用UNet进行图像分割【附源码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

DeepSeek 提示词:高效的提示词设计

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…

【基于SprintBoot+Mybatis+Mysql】电脑商城项目之加入购物车和显示购物车列表

🧸安清h:个人主页 🎥个人专栏:【Spring篇】【计算机网络】【Mybatis篇】 🚦作者简介:一个有趣爱睡觉的intp,期待和更多人分享自己所学知识的真诚大学生。 目录 🚀1.加入购物车-数…

git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库

git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库 git 克隆及拉取github项目到本地 先在自己的用户文件夹新建一个项目文件夹,取名为项目名 例如这样 C:\Users\HP\yzj-再打开一个终端页面&…

数据库的MVCC如何理解?

数据库的MVCC如何理解? MVCC(多版本并发控制,Multi-Version Concurrency Control)是数据库系统中的一种并发控制机制,用于允许多个事务在不互相干扰的情况下并行执行,同时保持数据的一致性和隔离性。 MVC…

基于python+django的宠物商店-宠物管理系统源码+运行步骤

该系统是基于pythondjango开发的宠物商店-宠物管理系统。是给师妹开发的课程作业。现将源码开放给大家。大家学习过程中,如遇问题可以在github咨询作者。加油 演示地址 前台地址: http://pet.gitapp.cn 后台地址: http://pet.gitapp.cn/adm…

mac升级系统后聚焦Spotlight Search功能无法使用,进入安全模式可解

mac升级系统后,聚焦功能无法使用,表现为: 1)快捷键无法唤起聚焦框 2)点击右上角 聚焦图标(放大镜),没有任何反应 解决方案: 1)聚焦重建索引,无…

ArcGIS Pro中生成带计曲线等高线的全面指南

一、引言 在地理信息系统(GIS)领域,等高线作为表达地形起伏的重要视觉元素,被广泛应用于地图制作、空间分析以及地形可视化等方面。ArcGIS Pro,作为Esri公司推出的新一代GIS平台,提供了强大的空间分析和地…