网站首页 > 博客文章 正文
这是我的第297篇原创文章。
一、引言
在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:
- model.add()
- neurons(隐含层神经元数量)
- init_mode(初始权重方法)
- activation(激活函数)
- dropout(丢弃率)
- model.compile()
- loss(损失函数)
- optimizer(优化器)
- learning rate(学习率)
- momentum(动量)
- weight decay(权重衰减系数)
- model.fit()
- batch size(批量大小)
- epochs(迭代周期数)
一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优在深度学习中,Epoch(周期)和 Batch Size(批大小)是训练神经网络时经常使用的两个重要的超参数。
- Epoch(周期):一个Epoch就是将所有训练样本训练一次的过程。然而,当一个Epoch的样本(也就是所有的训练样本)数量可能太过庞大(对于计算机而言),就需要把它分成多个小块,也就是就是分成多个Batch 来进行训练。
- Batch(批 / 一批样本):将整个训练样本分成若干个Batch。
- Batch_Size(批大小):每批样本的大小。即1次迭代所使用的样本量。
- Iteration(一次迭代):训练一个Batch就是一次Iteration(这个概念跟程序语言中的迭代器相似)每次迭代更新1次网络结构的参数
- step(一步):训练一个样本就是一个step。
比如我有1000个训练样本,bachsize设置为10,则数据分成了100个batch,所有训练样本训练一次即一个epoch需要100个iteration,训练一个batch就是一次iteration。本文采用网格搜索选择Epoch和Batch_size。
二、实现过程
2.1 准备数据
dataset:
dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)
2.2 数据划分
# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)
2.3 创建模型
需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。
def create_model():
# 创建模型
model = Sequential()
model.add(Dense(50, input_shape=(8, ), kernel_initializer='uniform', activation='relu'))
model.add(Dropout(0.05))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
model = KerasClassifier(model=create_model)
这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier有批量大小、迭代次数的参数,不需要自定义表示。
2.4 定义网格搜索参数
param_grid = {'batch_size': [20, 40], 'epochs': [10, 50]}
param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。value是key可取的值,也就是要尝试的方案。
2.5 进行参数搜索
from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model, param_grid=param_grid)
grid_result = grid.fit(X, Y)
使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。
2.6 总结搜索结果
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
print("%f (%f) with: %r" % (mean, stdev, param))
结果:
经过网格搜索,批量大小的最优选择是20,迭代次数最优选择是50。
作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。
原文链接:
猜你喜欢
- 2025-05-10 机器学习实操步骤:收集数据、数据准备、选择一个模型、训练、评估、参数调整、预测
- 2025-05-10 全新大模型参数高效微调方法SSF:仅需训练0.3M的参数,效果卓越
- 2025-05-10 深度学习模型写作指南(深度学习模型训练论文)
- 2025-05-10 观察|深度学习需持续收集海量数据,自动驾驶开发有哪些挑战
- 2025-05-10 超参数自动调参库介绍(超参数设置)
- 2025-05-10 机器人EI?这些刊反馈超快!(2021年机器人领域顶级期刊)
- 2025-05-10 一文彻底搞懂深度学习 - 训练和推理(Training vs Inference)
- 2025-05-10 深度学习基础知识题库大全(深度学习入门理论教材)
你 发表评论:
欢迎- 07-07Xiaomi Enters SUV Market with YU7 Launch, Targeting Tesla with Bold Pricing and High-Tech Features
- 07-07Black Sesame Maps Expansion Into Robotics With New Edge AI Strategy
- 07-07Wuhan's 'Black Tech' Powers China's Cross-Border Push with Niche Electronics and Scientific Firepower
- 07-07Maven 干货 全篇共:28232 字。预计阅读时间:110 分钟。建议收藏!
- 07-07IT运维必会的30个工具(it运维工具软件)
- 07-07开源项目有你需要的吗?(开源项目什么意思)
- 07-07自动化测试早就跑起来了,为什么测试管理还像在走路?
- 07-07Cursor 最强竞争对手来了,专治复杂大项目,免费一个月
- 最近发表
-
- Xiaomi Enters SUV Market with YU7 Launch, Targeting Tesla with Bold Pricing and High-Tech Features
- Black Sesame Maps Expansion Into Robotics With New Edge AI Strategy
- Wuhan's 'Black Tech' Powers China's Cross-Border Push with Niche Electronics and Scientific Firepower
- Maven 干货 全篇共:28232 字。预计阅读时间:110 分钟。建议收藏!
- IT运维必会的30个工具(it运维工具软件)
- 开源项目有你需要的吗?(开源项目什么意思)
- 自动化测试早就跑起来了,为什么测试管理还像在走路?
- Cursor 最强竞争对手来了,专治复杂大项目,免费一个月
- Cursor 太贵?这套「Cline+OpenRouter+Deepseek+Trae」组合拳更香
- 为什么没人真的用好RAG,坑都在哪里? 谈谈RAG技术架构的演进方向
- 标签列表
-
- ifneq (61)
- 字符串长度在线 (61)
- messagesource (56)
- aspose.pdf破解版 (56)
- promise.race (63)
- 2019cad序列号和密钥激活码 (62)
- window.performance (66)
- qt删除文件夹 (72)
- mysqlcaching_sha2_password (64)
- ubuntu升级gcc (58)
- nacos启动失败 (64)
- ssh-add (70)
- jwt漏洞 (58)
- macos14下载 (58)
- yarnnode (62)
- abstractqueuedsynchronizer (64)
- source~/.bashrc没有那个文件或目录 (65)
- springboot整合activiti工作流 (70)
- jmeter插件下载 (61)
- 抓包分析 (60)
- idea创建mavenweb项目 (65)
- vue回到顶部 (57)
- qcombobox样式表 (68)
- tomcatundertow (58)
- pastemac (61)
本文暂时没有评论,来添加一个吧(●'◡'●)