TensorFlow 2 中文文档 - 回归预测燃油效率
TensorFlow2 文档系列文章链接:
- TensorFlow 2 / 2.0 中文文档 (Jul 9, 2019)
- TensorFlow 2 中文文档 - MNIST 图像分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - IMDB 文本分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - 特征工程结构化数据分类 (Jul 9, 2019)
- TensorFlow 2 中文文档 - 回归预测燃油效率 (Jul 11, 2019)
- TensorFlow 2 中文文档 - 过拟合与欠拟合 (Jul 12, 2019)
- TensorFlow 2 中文文档 - 保存与加载模型 (Jul 13, 2019)
- TensorFlow 2 中文文档 - 卷积神经网络分类 CIFAR-10 (Jul 19, 2019)
- TensorFlow 2 中文文档 - TFHub 迁移学习 (Jul 19, 2019)
- TensorFlow 2 中文文档 - RNN LSTM 文本分类 (Jul 22, 2019)
源代码/数据集已上传到 Github - tensorflow2-docs-zh
TF2.0 TensorFlow 2 / 2.0 文档:Regression 回归
主要内容:使用回归预测烟油效率。
回归通常用来预测连续值,比如价格和概率。分类问题不一样,类别是固定的,目的是判断属于哪一类。比如给你一堆猫和狗的图片,判断一张图片是猫还是狗就是一个典型的分类问题。
接下来使用的是经典的 Auto MPG 数据集,这个数据集包括气缸(cylinders),排量(displayment),马力(horsepower) 和重量(weight)等属性。我们需要利用这些属性搭建模型,预测汽车的燃油效率(fuel efficiency)。
模型搭建使用tf.keras
API。
1 | import pathlib |
Auto MPG 数据集
获取数据
1 | # 下载数据集到本地 |
MPG | 气缸 | 排量 | 马力 | 重量 | 加速度 | 年份 | 产地 | |
---|---|---|---|---|---|---|---|---|
0 | 18.0 | 8 | 307.0 | 130.0 | 3504.0 | 12.0 | 70 | 1 |
1 | 15.0 | 8 | 350.0 | 165.0 | 3693.0 | 11.5 | 70 | 1 |
2 | 18.0 | 8 | 318.0 | 150.0 | 3436.0 | 11.0 | 70 | 1 |
清洗数据
检查是否有 NA 值。
1 | dataset.isna().sum() |
1 | MPG 0 |
直接去除含有NA值的行(马力)
1 | dataset = dataset.dropna() |
在获取的数据集中,Origin
(产地)不是数值类型,需转为独热编码。
1 | origin = dataset.pop('产地') |
MPG | 气缸 | 排量 | 马力 | 重量 | 加速度 | 年份 | 美国 | 欧洲 | 日本 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 18.0 | 8 | 307.0 | 130.0 | 3504.0 | 12.0 | 70 | 1.0 | 0.0 | 0.0 |
1 | 15.0 | 8 | 350.0 | 165.0 | 3693.0 | 11.5 | 70 | 1.0 | 0.0 | 0.0 |
2 | 18.0 | 8 | 318.0 | 150.0 | 3436.0 | 11.0 | 70 | 1.0 | 0.0 | 0.0 |
划分训练集与测试集
1 | # 训练集 80%, 测试集 20% |
检查数据
快速看一看训练集中属性两两之间的关系吧。
1 | # 解决中文乱码问题 |
matplotlib 中文乱码看这里:matplotlib图例中文乱码?
你还可以使用train_dataset.describle()
快速浏览每一属性的平均值、标准差、最小值、最大值等信息,能够帮助你快速地识别出不合理的数据。
1 | train_stats = train_dataset.describe() |
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
气缸 | 314.0 | 5.477707 | 1.699788 | 3.0 | 4.00 | 4.0 | 8.00 | 8.0 |
排量 | 314.0 | 195.318471 | 104.331589 | 68.0 | 105.50 | 151.0 | 265.75 | 455.0 |
… | … | … | … | … | … | …. | … | … |
分离 label
1 | # 分离 label |
归一化数据
通常训练前需要归一化数据,不同属性使用的计量单位不一样,值的范围不一样,训练就会很困难。比如其中一个属性的范围是[0.1, 0.5],而另一个属性的范围是[1000, 5000],那数值大的属性就容易对训练产生干扰,很可能导致训练不能收敛,或者是数值小的属性在模型中几乎没有发挥作用。归一化将不同范围的数据映射到[0,1]的空间内,可以有效地避免这个问题。
1 | def norm(x): |
模型
搭建模型
我们的模型包含2个全连接的隐藏层构成,输出层返回一个连续值。
1 | def build_model(): |
1 | Model: "sequential_1" |
训练模型
在之前的案例,比如结构化数据分类,我们调用model.fit
会打印出训练的进度。我们可以禁用默认的行为,并自定义训练进度条。
1 | import sys |
1 | [==================================================] 1000/1000 |
训练过程都存储在了history
对象中,我们可以借助 matplotlib 将训练过程可视化。
1 | hist = pd.DataFrame(history.history) |
loss | mae | mse | val_loss | val_mae | val_mse | epoch | |
---|---|---|---|---|---|---|---|
997 | 3.132053 | 1.142280 | 3.132053 | 9.711935 | 2.361466 | 9.711935 | 997 |
998 | 3.021109 | 1.093424 | 3.021109 | 9.488593 | 2.298264 | 9.488593 | 998 |
999 | 3.028849 | 1.132241 | 3.028849 | 9.453931 | 2.275017 | 9.453931 | 999 |
1 | def plot_history(history): |
从图中,我们可以看到,从100 epoch开始,训练集的loss仍旧继续降低,但验证集的loss却在升高,说明过拟合了,训练应该早一点结束。接下来,我们使用 keras.callbacks.EarlyStopping
,每一波(epoch)训练结束时,测试训练情况,如果训练不再有效果(验证集的loss,即val_loss 不再下降),则自动地停止训练。
1 | model = build_model() |
1 | [=== ] 70/1000 |
在第 70 epoch 时,停止了训练。
接下来使用测试集来评估训练效果。
1 | loss, mae, mse = model.evaluate(normed_test_data, test_labels, verbose=0) |
从图中我们也可以看出,1.9比验证集还略低一点。
预测
最后,我们使用测试集中的数据来预测 MPG 值。
1 | test_pred = model.predict(normed_test_data).flatten() |
看起来,模型训练得还不错。
结论
- 均方误差(Mean Squared Error, MSE) 常作为回归问题的损失函数(loss function),与分类问题不太一样。
- 同样,评价指标(evaluation metrics)也不一样,分类问题常用准确率(accuracy),回归问题常用平均绝对误差 (Mean Absolute Error, MAE)
- 每一列数据都有不同的范围,每一列,即每一个feature的数据需要分别缩放到相同的范围。常用归一化的方式,缩放到[0, 1]。
- 如果训练数据过少,最好搭建一个隐藏层少的小的神经网络,避免过拟合。
- 早停法(Early Stoping)也是防止过拟合的一种方式。
返回文档首页
完整代码:Github - auto_mpg_regression.ipynb
参考文档:Regression: Predict fuel efficiency
附 推荐
- 一篇文章入门 Python
专题: TensorFlow2 文档
本文发表于 2019-07-11,最后修改于 2022-09-09。
本站永久域名「 geektutu.com 」,也可搜索「 极客兔兔 」找到我。
上一篇 « TensorFlow 2 中文文档 - 特征工程结构化数据分类 下一篇 » TensorFlow 2 中文文档 - 过拟合与欠拟合