进击数据挖掘十大算法(五):AdaBoost

引言

提升(boosting)方法是一种常用的统计学习方法,应用广泛且有效。在分类问题中,它通过改变训练样本的权重,学习多个分类器,并将这些分类器进行线性组合,提高分类的性能。

提升方法基于这样一种思想:对于一个复杂任务来说,将多个专家的判断进行适当的综合所得出的判断,要比其中任何一个专家的单独判断好。实际上,就是“三个臭皮匠顶个诸葛亮”的道理。提升方法就是从弱学习算法出发,反复学习,得到一系列弱分类器,然后组合这些弱分类器,构成一个强分类器。大多数提升方法都是改变训练数据的概率分布(训练数据的权值分布),针对不同的训练数据分布调用弱学习算法学习一系列的弱分类器。

这样,对提升方法来说,有两个问题:一是每一轮如何改变训练数据的权值分布;二是如何将弱分类器组合成一个强分类器。关于第一个问题,AdaBoost的做法是提高那些被前一轮弱分类器错误分类样本的权值,而降低那些被正确分类样本的权值。这样一来,那些没有得到正确分类的数据,由于其权值的加大而收到后一轮的弱分类器的更大关注。至于第二个问题,AdaBoost采取加权多数表决的方法,即加大分类误差率小的弱分类器的权值,使其在表决中起较大的作用;减小分类误差率大的弱分类器的权值,使其在表决中起较小的作用。

AdaBoost的巧妙之处就在于它将这些想法自然且有效的实现在一种算法里。

一、AdaBoost算法

假定给定一个二分类的训练数据集:

其中,每个样本点由实例与标记组成。实例 $x_i\in X \subseteq R^n$ ,标记 $y_i \in Y = \{-1, +1\}$,X是实例空间,Y是标记集合。 AdaBoost利用以下算法,从训练数据中学习一系列弱分类器或基本分类器,并将这些弱分类器线性组合成一个强分类器。

1.1 算法 — AdaBoost

  • 输入:训练数据集 $T$;弱学习算法;

  • 输出:最终分类器 $G(x)$

(1) 初始化训练数据的权值分布

(2) 对 $m=1,2,\cdots,M$

  • (a) 使用具有权值分布 $D_m$ 的训练数据集学习,得到基本分类器
  • (b) 计算 $G_m(x)$ 在训练数据集上的分类误差率
  • (c) 计算 $G_m(x)$ 的系数
  • (d) 更新训练数据集的权值分布

(3) 构建基本分类器的线性组合

得到最终分类器

1.2 AdaBoost算法说明

步骤(1):假设训练数据集具有均匀的权值分布,即每个训练样本在基本分类器的学习中作用相同,这一假设保证第1步能够在原始数据上学习基本分类器 $G_1(x)$。

步骤(2): AdaBoost反复学习基本的分类器,在每一轮 $m=1,2,\cdots, M$ 顺次执行下列操作:

  • (a) 使用当前分布 $D_m$ 加权的训练数据集,学习基本分类器 $G_m(x)$。

  • (b) 计算基本分类器 $G_m(x)$ 在加权训练数据集上的分类误差率

  • (c) 计算基本分类器 $G_m(x)$ 的系数 $a_m$。$a_m$ 表示 $G_m(x)$ 在最终分类器中的重要性。由公式可知,分类误差率越小的基本分类器在最终分类器中的作用越大。

  • (d) 更新训练数据的权值分布为下一轮做准备,被基本分类器误分类样本的权值得以扩大,而被正确分类样本的权值得以缩小。因此,误分类样本在下一轮学习中起更大的作用。不改变所给的训练数据,而不断改变训练数据的权值分布,使得训练数据在基本分类器的学习中起不同的作用,这是AdaBoost的一个特点

步骤(3): 线性组合 $f(x)$ 实现 $M$ 个基本分类器的加权表决。系数 $a_m$ 表示了基本分类器 $G_m(x)$ 的重要性。这里,所有 $a_m$ 之和并不为1$f(x)$的符号决定实例 $x$ 的类,$f(x)$ 的绝对值表示分类的确信度。利用基本分类器的线性组合构建最终分类器是AdaBoost的另一特点

二、AdaBoost实例

序号 1 2 3 4 5 6 7 8 9 10
$x$ 0 1 2 3 4 5 6 7 8 9
$y$ 1 1 1 -1 -1 -1 1 1 1 -1

(1) 初始化数据权值分布

(2) 对 $m=1$ :

  • (a) 在权值分布为 $D_1$ 的训练数据上,阈值 $v$ 取2.5时分类误差率最低,故基本分类器为
  • (b) $G_1(x)$ 在训练数据集上的误差率 $e_1 = P(G_1(x_i)\neq y_i)=0.3$

  • (c) 计算 $G_1(x) 的系数:a_1=\frac{1}{2}\ln\frac{1-e_1}{e_1}=0.4236$

  • (d) 更新训练数据的权值分布

分类器 $sign[f_1(x)]$ 在训练数据集上有3个误分类点。

(3) 对 $m=2$:

  • (a) 在权值分布为 $D_2$ 的训练数据上,阈值 $v$ 取8.5时分类误差率最低,故基本分类器为
  • (b) $G_2(x)$ 在训练数据集上的误差率 $e_2 = P(G_2(x_i)\neq y_i)=0.2143$

  • (c) 计算 $G_2(x) 的系数:a_2=\frac{1}{2}\ln\frac{1-e_2}{e_2}=0.6496$

  • (d) 更新训练数据的权值分布

分类器 $sign[f_2(x)]$ 在训练数据集上有3个误分类点。

(4) 对 $m=3$:

  • (a) 在权值分布为 $D_3$ 的训练数据上,阈值 $v$ 取5.5时分类误差率最低,故基本分类器为
  • (b) $G_3(x)$ 在训练数据集上的误差率 $e_3 = P(G_3(x_i)\neq y_i)=0.1820$

  • (c) 计算 $G_3(x) 的系数:a_3=\frac{1}{2}\ln\frac{1-e_3}{e_3}=0.7514$

  • (d) 更新训练数据的权值分布

分类器 $sign[f_2(x)]$ 在训练数据集上误分类点个数为0。

于是最终分类器为

三、Reference

  • 李航《统计学习方法》

(注:公式证明过程略,可以去B站找对应视频,这篇只用作日后回忆)