PyTorch深度学习简明实战
上QQ阅读APP看书,第一时间看更新

2.1 机器学习基础

什么是机器学习呢?所谓机器学习,就是让计算机从数据中学习到规律,从而做出预测。很多时候,我们很难直接编写一个算法解决问题,例如一张图片,很难编写算法直接正确预测这张图片中显示的是猫还是狗。为了解决这个问题,人们想到了数据驱动的方法,也就是让计算机从现有的大量的带标签的图片中学习规律,一旦计算机学习到了其中的规律,当我们输入一张新的图片给计算机时,它就可以准确地预测这张图片显示的到底是猫还是狗。

这里有两个关键的因素,一是大量的可学习数据,例如带标签的猫、狗图片;二是学习的主体,我们一般称之为模型。如何理解模型呢?读者可以把模型看成是一个映射函数,它包含一些参数,这些参数可以与输入进行计算得到一个输出,我们一般称之为预测结果。例如,输入一张图片到模型中,图片与模型参数计算得到一个映射结果,这就是预测结果。所谓模型学习的过程,就是模型修正其参数、改进映射关系的过程。可以简单地把模型的学习过程总结如下,以预测图片是猫还是狗为例,步骤如下。

(1)创建模型。

(2)输入一张带标签的图片。

(3)使用模型对此图片做出预测。

(4)将预测结果与实际标签比较,产生的差距一般称为损失。

(5)以减小损失为优化目标,根据损失优化模型参数。

(6)循环重复上述第(2)~(5)步。

下面用一个例子来演示创建模型、优化模型的整个过程。