专业的编程技术博客社区

网站首页 > 博客文章 正文

机器学习之:跑通第一个入门算法(如何跑通github代码)

baijin 2024-09-26 06:50:45 博客文章 3 ℃ 0 评论

时光闹钟app开发者,请关注我,后续分享更精彩!

坚持原创,共同进步!

前言

上篇介绍了机器学习算法的入门(详细请点击 "机器学习入门之:概念介绍" 查阅),本文将带大家一起编写初学者的第一个机器学习算法。好了,废话少序,直接进入正题。

K近邻算法(KNN)

KNN算法是数据挖掘分类技术中最简单的方法之一,思想是每个样本都可以用它最接近的k个邻居来代表。如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
例如下图中,绿色圆圈表示待测样本类别,当k取3时,有2个红色三角形1个蓝色正方形,因此类型与红色三角形相同;当k取5时,有2个红色三角形3个蓝色正方形,因此类型与蓝色正方形相同。k的取值会直接影响最终的结果。

对于K近邻算法而言,最关键的就是计算待测点和其它样本之间的距离,然后找到最邻近的K个邻居,统计这些邻居中各种类型的个数,获取个数最多的类型。

鸢尾花分类问题

问题介绍:一名植物学爱好者收集了鸢尾花的一些测量数据:花瓣的长度和宽度以及花萼的长度和宽度。他还有一些鸢尾花分类的测量数据,这些花已经被植物学专家鉴定为属于versicolor、 setosa 或 virginica 三个品种之一。对于测量数据,他可以确定每朵鸢尾花所属品种。我们的目标是构建一个机器学习模型,从这些已知品种的鸢尾花测量数据中学习,从而能够预测未知鸢尾花的品种。

鸢尾花数据集主要属性

总的数据量:150条测量数据,包含3个类别,每个类别有50个样本;

每条记录包含5项基本信息:花萼的长度、花萼的宽度、花瓣的长度、花瓣的宽度以及鸢尾花的类别。

鸢尾花数据集文件 iris.txt,文件内容如下:

"Sepal.Length" "Sepal.Width" "Petal.Length" "Petal.Width" "Species"
"1" 5.1 3.5 1.4 0.2 "setosa"
"2" 4.9 3 1.4 0.2 "setosa"
"3" 4.7 3.2 1.3 0.2 "setosa"
"4" 4.6 3.1 1.5 0.2 "setosa"
"5" 5 3.6 1.4 0.2 "setosa"
"6" 5.4 3.9 1.7 0.4 "setosa"
"7" 4.6 3.4 1.4 0.3 "setosa"
"8" 5 3.4 1.5 0.2 "setosa"
"9" 4.4 2.9 1.4 0.2 "setosa"
"10" 4.9 3.1 1.5 0.1 "setosa"
"11" 5.4 3.7 1.5 0.2 "setosa"
"12" 4.8 3.4 1.6 0.2 "setosa"
"13" 4.8 3 1.4 0.1 "setosa"
"14" 4.3 3 1.1 0.1 "setosa"
"15" 5.8 4 1.2 0.2 "setosa"
"16" 5.7 4.4 1.5 0.4 "setosa"
"17" 5.4 3.9 1.3 0.4 "setosa"
"18" 5.1 3.5 1.4 0.3 "setosa"
"19" 5.7 3.8 1.7 0.3 "setosa"
"20" 5.1 3.8 1.5 0.3 "setosa"
"21" 5.4 3.4 1.7 0.2 "setosa"
"22" 5.1 3.7 1.5 0.4 "setosa"
"23" 4.6 3.6 1 0.2 "setosa"
"24" 5.1 3.3 1.7 0.5 "setosa"
"25" 4.8 3.4 1.9 0.2 "setosa"
"26" 5 3 1.6 0.2 "setosa"
"27" 5 3.4 1.6 0.4 "setosa"
"28" 5.2 3.5 1.5 0.2 "setosa"
"29" 5.2 3.4 1.4 0.2 "setosa"
"30" 4.7 3.2 1.6 0.2 "setosa"
"31" 4.8 3.1 1.6 0.2 "setosa"
"32" 5.4 3.4 1.5 0.4 "setosa"
"33" 5.2 4.1 1.5 0.1 "setosa"
"34" 5.5 4.2 1.4 0.2 "setosa"
"35" 4.9 3.1 1.5 0.2 "setosa"
"36" 5 3.2 1.2 0.2 "setosa"
"37" 5.5 3.5 1.3 0.2 "setosa"
"38" 4.9 3.6 1.4 0.1 "setosa"
"39" 4.4 3 1.3 0.2 "setosa"
"40" 5.1 3.4 1.5 0.2 "setosa"
"41" 5 3.5 1.3 0.3 "setosa"
"42" 4.5 2.3 1.3 0.3 "setosa"
"43" 4.4 3.2 1.3 0.2 "setosa"
"44" 5 3.5 1.6 0.6 "setosa"
"45" 5.1 3.8 1.9 0.4 "setosa"
"46" 4.8 3 1.4 0.3 "setosa"
"47" 5.1 3.8 1.6 0.2 "setosa"
"48" 4.6 3.2 1.4 0.2 "setosa"
"49" 5.3 3.7 1.5 0.2 "setosa"
"50" 5 3.3 1.4 0.2 "setosa"
"51" 7 3.2 4.7 1.4 "versicolor"
"52" 6.4 3.2 4.5 1.5 "versicolor"
"53" 6.9 3.1 4.9 1.5 "versicolor"
"54" 5.5 2.3 4 1.3 "versicolor"
"55" 6.5 2.8 4.6 1.5 "versicolor"
"56" 5.7 2.8 4.5 1.3 "versicolor"
"57" 6.3 3.3 4.7 1.6 "versicolor"
"58" 4.9 2.4 3.3 1 "versicolor"
"59" 6.6 2.9 4.6 1.3 "versicolor"
"60" 5.2 2.7 3.9 1.4 "versicolor"
"61" 5 2 3.5 1 "versicolor"
"62" 5.9 3 4.2 1.5 "versicolor"
"63" 6 2.2 4 1 "versicolor"
"64" 6.1 2.9 4.7 1.4 "versicolor"
"65" 5.6 2.9 3.6 1.3 "versicolor"
"66" 6.7 3.1 4.4 1.4 "versicolor"
"67" 5.6 3 4.5 1.5 "versicolor"
"68" 5.8 2.7 4.1 1 "versicolor"
"69" 6.2 2.2 4.5 1.5 "versicolor"
"70" 5.6 2.5 3.9 1.1 "versicolor"
"71" 5.9 3.2 4.8 1.8 "versicolor"
"72" 6.1 2.8 4 1.3 "versicolor"
"73" 6.3 2.5 4.9 1.5 "versicolor"
"74" 6.1 2.8 4.7 1.2 "versicolor"
"75" 6.4 2.9 4.3 1.3 "versicolor"
"76" 6.6 3 4.4 1.4 "versicolor"
"77" 6.8 2.8 4.8 1.4 "versicolor"
"78" 6.7 3 5 1.7 "versicolor"
"79" 6 2.9 4.5 1.5 "versicolor"
"80" 5.7 2.6 3.5 1 "versicolor"
"81" 5.5 2.4 3.8 1.1 "versicolor"
"82" 5.5 2.4 3.7 1 "versicolor"
"83" 5.8 2.7 3.9 1.2 "versicolor"
"84" 6 2.7 5.1 1.6 "versicolor"
"85" 5.4 3 4.5 1.5 "versicolor"
"86" 6 3.4 4.5 1.6 "versicolor"
"87" 6.7 3.1 4.7 1.5 "versicolor"
"88" 6.3 2.3 4.4 1.3 "versicolor"
"89" 5.6 3 4.1 1.3 "versicolor"
"90" 5.5 2.5 4 1.3 "versicolor"
"91" 5.5 2.6 4.4 1.2 "versicolor"
"92" 6.1 3 4.6 1.4 "versicolor"
"93" 5.8 2.6 4 1.2 "versicolor"
"94" 5 2.3 3.3 1 "versicolor"
"95" 5.6 2.7 4.2 1.3 "versicolor"
"96" 5.7 3 4.2 1.2 "versicolor"
"97" 5.7 2.9 4.2 1.3 "versicolor"
"98" 6.2 2.9 4.3 1.3 "versicolor"
"99" 5.1 2.5 3 1.1 "versicolor"
"100" 5.7 2.8 4.1 1.3 "versicolor"
"101" 6.3 3.3 6 2.5 "virginica"
"102" 5.8 2.7 5.1 1.9 "virginica"
"103" 7.1 3 5.9 2.1 "virginica"
"104" 6.3 2.9 5.6 1.8 "virginica"
"105" 6.5 3 5.8 2.2 "virginica"
"106" 7.6 3 6.6 2.1 "virginica"
"107" 4.9 2.5 4.5 1.7 "virginica"
"108" 7.3 2.9 6.3 1.8 "virginica"
"109" 6.7 2.5 5.8 1.8 "virginica"
"110" 7.2 3.6 6.1 2.5 "virginica"
"111" 6.5 3.2 5.1 2 "virginica"
"112" 6.4 2.7 5.3 1.9 "virginica"
"113" 6.8 3 5.5 2.1 "virginica"
"114" 5.7 2.5 5 2 "virginica"
"115" 5.8 2.8 5.1 2.4 "virginica"
"116" 6.4 3.2 5.3 2.3 "virginica"
"117" 6.5 3 5.5 1.8 "virginica"
"118" 7.7 3.8 6.7 2.2 "virginica"
"119" 7.7 2.6 6.9 2.3 "virginica"
"120" 6 2.2 5 1.5 "virginica"
"121" 6.9 3.2 5.7 2.3 "virginica"
"122" 5.6 2.8 4.9 2 "virginica"
"123" 7.7 2.8 6.7 2 "virginica"
"124" 6.3 2.7 4.9 1.8 "virginica"
"125" 6.7 3.3 5.7 2.1 "virginica"
"126" 7.2 3.2 6 1.8 "virginica"
"127" 6.2 2.8 4.8 1.8 "virginica"
"128" 6.1 3 4.9 1.8 "virginica"
"129" 6.4 2.8 5.6 2.1 "virginica"
"130" 7.2 3 5.8 1.6 "virginica"
"131" 7.4 2.8 6.1 1.9 "virginica"
"132" 7.9 3.8 6.4 2 "virginica"
"133" 6.4 2.8 5.6 2.2 "virginica"
"134" 6.3 2.8 5.1 1.5 "virginica"
"135" 6.1 2.6 5.6 1.4 "virginica"
"136" 7.7 3 6.1 2.3 "virginica"
"137" 6.3 3.4 5.6 2.4 "virginica"
"138" 6.4 3.1 5.5 1.8 "virginica"
"139" 6 3 4.8 1.8 "virginica"
"140" 6.9 3.1 5.4 2.1 "virginica"
"141" 6.7 3.1 5.6 2.4 "virginica"
"142" 6.9 3.1 5.1 2.3 "virginica"
"143" 5.8 2.7 5.1 1.9 "virginica"
"144" 6.8 3.2 5.9 2.3 "virginica"
"145" 6.7 3.3 5.7 2.5 "virginica"
"146" 6.7 3 5.2 2.3 "virginica"
"147" 6.3 2.5 5 1.9 "virginica"
"148" 6.5 3 5.2 2 "virginica"
"149" 6.2 3.4 5.4 2.3 "virginica"
"150" 5.9 3 5.1 1.8 "virginica"

代码实现

代码是基于Python实现,详细如下:
(注:别忘记把上述iris.txt文件复制下来,放在和以下Python代码同级目录下)

import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from collections import Counter

def init_data():  # 初始化数据
    with open("iris.txt", mode="r", encoding="utf-8") as fp:
        lines = fp.readlines()  # 按行读取数据
    iris_datas = []
    for i in range(1, len(lines)):  # 从第2行开始读取
        iris_datas.append(lines[i].replace("\n", "").replace("\"", "").split())

    
    iris_datas = np.array(iris_datas)  # 将数据转化为多维数组

    labels = iris_datas[:, -1]  # 获取标签数据,最后一列
    # print(labels)
    f_datas = iris_datas[:, 1:-1]  # 获取特征数据,第2列到倒数第2列
    f_datas = f_datas.astype(np.float)  # 改变数据类型
    # print(f_datas)
    return f_datas, labels

def draw(f_datas, labels):  # 绘制数据图
    d_dict = defaultdict(list)  # 创建字典,默认值为列表类型
    for i in range(len(labels)):  # 循环遍历每一个样本
        d_dict[labels[i]].append(f_datas[i])  # 将同一品种的鸢尾花放在一起
    styles = ["ro", "b+", "g*"]  # 设置一些样式
    plt.rcParams["font.family"] = "STSong"  # 设置支持中文
    plt.subplot(1, 2, 1)  # 添加子块
    for i, (key, values) in enumerate(d_dict.items()):
        values = np.array(values)  # 获取每一个品种对应的样本
        plt.plot(values[:, 0], values[:, 1], styles[i], label=key)
    plt.legend()  # 显示图例
    plt.title("花萼的分布图")  # 显示标题
    plt.xlabel("花萼的长度")  # 显示X轴标签
    plt.ylabel("花萼的宽度")  # 显示Y轴标签
    plt.subplot(1, 2, 2)  # 添加子块
    for i, (key, values) in enumerate(d_dict.items()):
        values = np.array(values)
        plt.plot(values[:, 2], values[:, 3], styles[i], label=key)
    plt.legend()  # 显示图例
    plt.title("花瓣的分布图")  # 显示标题
    plt.xlabel("花瓣的长度")  # 显示X轴标签
    plt.ylabel("花瓣的宽度")  # 显示Y轴标签
    plt.subplots_adjust(wspace=0.4)  # 调整位置
    plt.show()  # 显示图片

def knn(test_data, train_datas, train_labels, k):
    nums = train_datas.shape[0]  # 获取已知样本的数量
    test_datas = np.tile(test_data, (nums, 1))
    d_1 = test_datas - train_datas  # 相应位置相减
    d_2 = np.square(d_1)  # 求平方
    d_3 = np.sum(d_2, axis=1)  # 按行求和
    d_4 = np.sqrt(d_3)  # 开平方,得到距离
    index = np.argsort(d_4)  # 排序获取排序后元素的索引
    count = Counter(train_labels[index[:k]])  # 统计最邻近的k个邻居的标签
    print(count)
    return count.most_common()[0][0]  # 返回出现次数最多的标签

def test(f_datas, labels):
    index = np.arange(len(labels))
    np.random.shuffle(index)
    test_data = f_datas[index[-1]]
    test_label = labels[index[-1]]
    train_datas = f_datas[index[:-1]]
    train_labels = labels[index[:-1]]

    # print("--- index:", index)
    # print("--- test_data:", test_data, "train_datas:", train_datas)

    predict_label = knn(test_data, train_datas, train_labels, k=10)
    if predict_label == test_label:
        print("预测准确")
    else:
        print("预测错误,预测类别为:", predict_label, "实际的类别为:", test_label)

f_datas, l_datas = init_data()
# 查看数据可视化分析,取消下面注释
# draw(f_datas, l_datas)
for i in range(100):
    test(f_datas, l_datas)

运行上述代码,执行100次预测。每次预测,随机打乱分类标签集合,取最后一个标签集合数据为测试集,其他数据为训练集。通过KNN算法对测试集数据进行预测,输出评估结果。

数据可视化分析

以图表的形式观察鸢尾花花瓣长度和宽度与鸢尾花品种的关系、鸢尾花花萼长度和宽度与鸢尾花品种的关系。关键代码如下(上述代码块中的draw函数)。

def draw(f_datas, labels):  # 绘制数据图
    d_dict = defaultdict(list)  # 创建字典,默认值为列表类型
    for i in range(len(labels)):  # 循环遍历每一个样本
        d_dict[labels[i]].append(f_datas[i])  # 将同一品种的鸢尾花放在一起
    styles = ["ro", "b+", "g*"]  # 设置一些样式
    plt.rcParams["font.family"] = "STSong"  # 设置支持中文
    plt.subplot(1, 2, 1)  # 添加子块
    for i, (key, values) in enumerate(d_dict.items()):
        values = np.array(values)  # 获取每一个品种对应的样本
        plt.plot(values[:, 0], values[:, 1], styles[i], label=key)
    plt.legend()  # 显示图例
    plt.title("花萼的分布图")  # 显示标题
    plt.xlabel("花萼的长度")  # 显示X轴标签
    plt.ylabel("花萼的宽度")  # 显示Y轴标签
    plt.subplot(1, 2, 2)  # 添加子块
    for i, (key, values) in enumerate(d_dict.items()):
        values = np.array(values)
        plt.plot(values[:, 2], values[:, 3], styles[i], label=key)
    plt.legend()  # 显示图例
    plt.title("花瓣的分布图")  # 显示标题
    plt.xlabel("花瓣的长度")  # 显示X轴标签
    plt.ylabel("花瓣的宽度")  # 显示Y轴标签
    plt.subplots_adjust(wspace=0.4)  # 调整位置
    plt.show()  # 显示图片

生成的数据图效果如下,通过图可以观察到花瓣的长度和宽度与鸢尾花品种的相关性更强。
取消Python代码最后的注释,重新运行代码

# 查看数据可视化分析,取消下面注释
draw(f_datas, l_datas)

代码运行结果如下图:

总结

本文通过使用KNN机器算法,实现鸢尾花的分类问题。介绍了KNN算法的原理,鸢尾花Python代码实现,希望对刚入门的初学者们有所帮助。

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表