欢迎来到《pytorch深度学习教程》系列的第六篇!在前面的五篇中,我们已经介绍了Python、numpy及pytorch的基本使用,进行了梯度及神经网络的实践。今天,我们将深入理解激活函数并进行简单的实践学习
欢迎订阅专栏进行系统学习:
深度学习保姆教程_tRNA做科研的博客-CSDN博客
目录
激活函数
1.线性和非线性函数
线性函数
非线性函数
2.Sigmoid、Tanh和ReLU
Sigmoid函数
Tanh函数
ReLU(修正线性单元)
选择合适的激活函数
3.其他激活函数
Leaky ReLU(ReLU的一种变体)
Parametric ReLU(PReLU)
Exponential Linear Unit(ELU)
Swish
选择合适的激活函数
4.结语
激活函数
激活函数为神经网络引入了非线性,使它们能够学习复杂的模式。它们决定了神经元基于其输入的输出。
1.线性和非线性函数
理解线性和非线性函数之间的基本差异在机器学习中至关重要。它们构成了构建复杂模型的基础。
线性函数
线性函数展示了输入和输出之间的直线关系。它们的特点是变化率恒定。
一般形式: y = mx + b
- m: 斜率,表示变化率。
- b: 截距,表示线与y轴的交点。
示例:
import numpy as np
import matplotlib.pyplot as plt
def linear_function(x, m, b):
"""
线性函数
:param x: 输入值
:param m: 斜率
:param b: 截距
:return: 输出值
"""
return m * x + b
# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)
# 计算对应的y值,斜率为2,截距为1
y = linear_function(x, 2, 1)
# 绘制图形
plt.plot(x, y)
plt.xlabel('x') # x轴标签
plt.ylabel('y') # y轴标签
plt.title('linear function') # 图形标题
plt.show() # 显示图形
非线性函数
非线性函数不遵循直线模式。它们引入了复杂性,并允许模型捕捉数据中的复杂关系。
常见示例:
- 多项式函数:y = ax^2 + bx + c
- 指数函数:y = a^x
- 对数函数:y = log(x)
- 三角函数:sin(x), cos(x), tan(x)
示例:
import numpy as np
import matplotlib.pyplot as plt
# 定义多个非线性函数
def quadratic_function(x):
"""
二次函数
:param x: 输入值
:return: 输出值
"""
return x**2
def cubic_function(x):
"""
三次函数
:param x: 输入值
:return: 输出值
"""
return x**3
def exponential_function(x):
"""
指数函数
:param x: 输入值
:return: 输出值
"""
return np.exp(x)
def logarithmic_function(x):
"""
对数函数
:param x: 输入值
:return: 输出值
"""
return np.log(np.abs(x) + 1) # 使用绝对值避免负数对数
# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)
# 计算各个函数对应的y值
y_quadratic = quadratic_function(x)
y_cubic = cubic_function(x)
y_exponential = exponential_function(x)
y_logarithmic = logarithmic_function(x)
# 创建一个新的图形
plt.figure(figsize=(10, 6))
# 绘制各个函数的图像
plt.plot(x, y_quadratic, label='Quadratic ($x^2$)', color='blue')
plt.plot(x, y_cubic, label='Cubic ($x^3$)', color='red')
plt.plot(x, y_exponential, label='Exponential ($e^x$)', color='green')
plt.plot(x, y_logarithmic, label='Logarithmic ($\\log(|x|+1)$)', color='orange')
# 添加图例
plt.legend()
# 设置轴标签和标题
plt.xlabel('x') # x轴标签
plt.ylabel('y') # y轴标签
plt.title('function') # 图形标题
# 显示网格
plt.grid(True)
# 显示图形
plt.show()
为什么非线性在机器学习中至关重要
- 复杂模式: 真实世界的数据通常表现出非线性关系。
- 决策边界: 非线性函数使模型能够学习复杂的决策边界。
- 深度学习: 非线性激活函数对于深度神经网络是必不可少的。
实际应用
- 线性回归: 基于线性关系预测连续数值。
- 逻辑回归: 使用非线性的sigmoid函数将数据分类到不同的类别。
- 神经网络: 使用多层非线性函数来学习复杂的模式。
2.Sigmoid、Tanh和ReLU
激活函数是神经网络的心脏和灵魂。它们引入了非线性,使模型能够学习复杂的模式。让我们探索一些最常用的激活函数:Sigmoid、Tanh和ReLU。
Sigmoid函数
Sigmoid函数将任何实数映射到0和1之间的值。它常用于二分类问题的输出层。
示例:
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
"""
Sigmoid激活函数
:param x: 输入值
:return: 输出值
"""
return 1 / (1 + np.exp(-x))
# 生成从-10到10的100个等间隔点
x = np.linspace(-10, 10, 100)
# 计算对应的y值
y = sigmoid(x)
# 绘制图形
plt.plot(x, y)
plt.xlabel('x') # x轴标签
plt.ylabel('y') # y轴标签
plt.title('Sigmoid') # 图形标题
plt.show() # 显示图形
挑战:
- 梯度消失问题: 梯度可能变得非常小,从而减慢训练速度。
- 不是零中心化的: 输出总是正的,这可能影响收敛。
Tanh函数
Tanh函数将输入值映射到-1到1的范围内。由于它是零中心化的,因此通常比Sigmoid更受青睐。
示例:
import numpy as np
import matplotlib.pyplot as plt
def tanh(x):
"""
Tanh激活函数
:param x: 输入值
:return: 输出值
"""
return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
# 生成从-10到10的100个等间隔点
x = np.linspace(-10, 10, 100)
# 计算对应的y值
y = tanh(x)
# 绘制图形
plt.plot(x, y)
plt.xlabel('x') # x轴标签
plt.ylabel('y') # y轴标签
plt.title('Tanh') # 图形标题
plt.show() # 显示图形
ReLU(修正线性单元)
ReLU函数是目前使用最广泛的激活函数。它输出输入值和0之间的最大值。
示例:
import numpy as np
import matplotlib.pyplot as plt
def relu(x):
"""
ReLU激活函数
:param x: 输入值
:return: 输出值
"""
return np.maximum(0, x)
# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)
# 计算对应的y值
y = relu(x)
# 绘制图形
plt.plot(x, y)
plt.xlabel('x') # x轴标签
plt.ylabel('y') # y轴标签
plt.title('ReLU函数') # 图形标题
plt.show() # 显示图形
ReLU的优势:
- 计算效率高。
- 缓解了梯度消失问题。
选择合适的激活函数
激活函数的选择取决于问题和神经网络的架构。
- Sigmoid: 常用于二分类问题的输出层。
- Tanh: 在隐藏层中通常比Sigmoid表现更好。
- ReLU: 由于其简单性和高效性,是隐藏层中最受欢迎的选择。
3.其他激活函数
虽然Sigmoid、Tanh和ReLU是基础的激活函数,但激活函数的世界提供了多种多样的选项,以适应不同的神经网络架构和问题领域。
Leaky ReLU(ReLU的一种变体)
Leaky ReLU是ReLU函数的一个变体,旨在通过为负输入引入一个小的、非零的梯度来解决“死亡ReLU”问题。
公式:
LeakyReLU(x) = max(αx, x)
其中α是一个小的正常数(通常为0.01)。
import numpy as np
import matplotlib.pyplot as plt
# 定义Leaky ReLU函数
def leaky_relu(x, alpha=0.01):
return np.maximum(alpha * x, x)
# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
y = leaky_relu(x)
# 绘制Leaky ReLU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='Leaky ReLU')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Leaky ReLU Activation Function')
plt.legend()
plt.grid(True)
plt.show()
可以看到0之前的线是不等于0的非常小的数值
Parametric ReLU(PReLU)
PReLU是Leaky ReLU的一个扩展,其中负输入的斜率是一个可学习的参数。
PReLU的数学表达式如下:
f(x) = max(αx, x)
import numpy as np
import matplotlib.pyplot as plt
# 定义PReLU函数
def prelu(x, alpha):
return np.where(x >= 0, x, alpha * x)
# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
alpha = 0.05 # 初始设定α为0.01,实际应用中α是可学习的参数
y = prelu(x, alpha)
# 绘制PReLU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='PReLU with α={}'.format(alpha))
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Parametric ReLU (PReLU) Activation Function')
plt.legend()
plt.grid(True)
plt.show()
Exponential Linear Unit(ELU)
ELU试图结合ReLU和tanh的优点。它对负输入输出负值,有助于梯度流动。
ELU的数学表达式如下:
f(x) = { α(e^x - 1) if x ≤ 0
x if x > 0 }
其中,α是一个超参数,通常设置为一个小的常数,例如0.1。当x为负值时,ELU函数的输出是α乘以(e^x - 1),这确保了即使在负值区域,梯度也不会完全消失,从而允许网络继续学习。
import numpy as np
import matplotlib.pyplot as plt
# 定义ELU函数
def elu(x, alpha=0.1):
return np.where(x >= 0, x, alpha * (np.exp(x) - 1))
# 创建一个输入数组
x = np.linspace(-3, 3, 1000)
y = elu(x)
# 绘制ELU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='ELU with α={}'.format(alpha))
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Exponential Linear Unit (ELU) Activation Function')
plt.legend()
plt.grid(True)
plt.show()
Swish
Swish是一个自门控的激活函数,平滑地插值于线性和ReLU行为之间。
公式为:
Swish(x) = x * sigmoid(βx)
其中β是一个可学习的参数。
Swish激活函数的特点包括:
- 自我门控(Self-gating):Swish函数通过x*Sigmoid(βx)的形式实现,简化了gating机制,允许其直接替代ReLU等单输入激活函数,而无需改变网络结构。
- 避免梯度消失问题:Swish函数的导数始终大于0,这有助于缓解梯度消失问题。
- 平滑性:Swish函数具有平滑性,有利于优化和泛化。
import numpy as np
import matplotlib.pyplot as plt
# 定义Swish激活函数
def swish(x, beta=1):
return x * (1 / (1 + np.exp(-beta * x)))
# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
y = swish(x)
# 绘制Swish函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='Swish Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Swish Activation Function Visualization')
plt.legend()
plt.grid(True)
plt.show()
选择合适的激活函数
最佳的激活函数取决于多种因素:
- 问题类型: 分类、回归或生成任务。
- 网络架构: 网络的深度和复杂性。
- 数据特性: 输入数据的分布。
- 计算资源: 一些激活函数的计算成本更高。
实验和微调
确定最佳激活函数的最佳方法是通过实验。尝试不同的选项并评估它们在特定任务上的性能。
4.结语
以上就是激活函数本次的教程,如果有什么问题欢迎评论区一起讨论!