混合密度網(wǎng)絡(MDN)進行多元回歸詳解和代碼示例(1)
來源:Deephub Imba
“回歸預測建模是逼近從輸入變量 (X) 到連續(xù)輸出變量 (y) 的映射函數(shù) (f) [...] 回歸問題需要預測具體的數(shù)值。具有多個輸入變量的問題通常被稱為多元回歸問題 例如,預測房屋價值,可能在 100,000 美元到 200,000 美元之間
這是另一個區(qū)分分類問題和回歸問題的視覺解釋如下:
另外一個例子
DENSITY “密度” 是什么意思?這是一個快速的通俗示例:
假設正在為必勝客運送比薩?,F(xiàn)在記錄剛剛進行的每次交付的時間(以分鐘為單位)。交付 1000 次后,將數(shù)據(jù)可視化以查看工作表現(xiàn)如何。這是結果:
這是披薩交付時間數(shù)據(jù)分布的“密度”。平均而言,每次交付需要 30 分鐘(圖中的峰值)。它還表示,在 95% 的情況下(2 個標準差2sd ),交付需要 20 到 40 分鐘才能完成。密度種類代表時間結果的“頻率”。“頻率”和“密度”的區(qū)別在于:
頻率:如果你在這條曲線下繪制一個直方圖并對所有的 bin 進行計數(shù),它將求和為任何整數(shù)(取決于數(shù)據(jù)集中捕獲的觀察總數(shù))。
密度:如果你在這條曲線下繪制一個直方圖并計算所有的 bin,它總和為 1。我們也可以將此曲線稱為概率密度函數(shù) (pdf)。
用統(tǒng)計術語來說,這是一個漂亮的正態(tài)/高斯分布。這個正態(tài)分布有兩個參數(shù):
均值
標準差:“標準差是一個數(shù)字,用于說明一組測量值如何從平均值(平均值)或預期值中展開。低標準偏差意味著大多數(shù)數(shù)字接近平均值。高標準差意味著數(shù)字更加分散?!?/span>
均值和標準差的變化會影響分布的形狀。例如:
有許多具有不同類型參數(shù)的各種不同分布類型。例如:
現(xiàn)在讓我們看看這 3 個分布:
如果我們采用這種雙峰分布(也稱為一般分布):
混合密度網(wǎng)絡使用這樣的假設,即任何像這種雙峰分布的一般分布都可以分解為正態(tài)分布的混合(該混合也可以與其他類型的分布一起定制 例如拉普拉斯):
混合密度網(wǎng)絡也是一種人工神經(jīng)網(wǎng)絡。這是神經(jīng)網(wǎng)絡的經(jīng)典示例:
輸入層(黃色)、隱藏層(綠色)和輸出層(紅色)。
如果我們將神經(jīng)網(wǎng)絡的目標定義為學習在給定一些輸入特征的情況下輸出連續(xù)值。在上面的例子中,給定年齡、性別、教育程度和其他特征,那么神經(jīng)網(wǎng)絡就可以進行回歸的運算。
密度網(wǎng)絡
密度網(wǎng)絡也是神經(jīng)網(wǎng)絡,其目標不是簡單地學習輸出單個連續(xù)值,而是學習在給定一些輸入特征的情況下輸出分布參數(shù)(此處為均值和標準差)。在上面的例子中,給定年齡、性別、教育程度等特征,神經(jīng)網(wǎng)絡學習預測期望工資分布的均值和標準差。預測分布比預測單個值具有很多的優(yōu)勢,例如能夠給出預測的不確定性邊界。這是解決回歸問題的“貝葉斯”方法。下面是預測每個預期連續(xù)值的分布的一個很好的例子:
下面的圖片向我們展示了每個預測實例的預期值分布:
混合密度網(wǎng)絡
最后回到正題,混合密度網(wǎng)絡的目標是在給定特定輸入特征的情況下,學習輸出混合在一般分布中的所有分布的參數(shù)(此處為均值、標準差和 Pi)。新參數(shù)“Pi”是混合參數(shù),它給出最終混合中給定分布的權重/概率。
最終結果如下:
上面的定義和理論基礎已經(jīng)介紹完畢,下面我們開始代碼的演示:
import numpy as np
import pandas as pd
from mdn_model import MDN
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.kernel_ridge import KernelRidge
plt.style.use('ggplot')
生成著名的“半月”型的數(shù)據(jù)集:
X, y = make_moons(n_samples=2500, noise=0.03)
y = X[:, 1].reshape(-1,1)
X = X[:, 0].reshape(-1,1)
x_scaler = StandardScaler()
y_scaler = StandardScaler()
X = x_scaler.fit_transform(X)
y = y_scaler.fit_transform(y)
plt.scatter(X, y, alpha = 0.3)
繪制目標值 (y) 的密度分布:
sns.kdeplot(y.ravel(), shade=True)
通過查看數(shù)據(jù),我們可以看到有兩個重疊的簇:
這時一個很好的多模態(tài)分布(一般分布)。如果我們在這個數(shù)據(jù)集上嘗試一個標準的線性回歸來用 X 預測 y:
model = LinearRegression()
model.fit(X.reshape(-1,1), y.reshape(-1,1))
y_pred = model.predict(X.reshape(-1,1))
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Linear Regression')
sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'Linear Pred dist')
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')
效果必須不好!現(xiàn)在讓嘗試一個非線性模型(徑向基函數(shù)核嶺回歸):
model = KernelRidge(kernel = 'rbf')
model.fit(X, y)
y_pred = model.predict(X)
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Non Linear Regression')
sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'NonLinear Pred dist')
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')
雖然結果也不盡如人意,但是比上面的線性回歸要好很多了。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。