Python 實現(xiàn)注意力機制
引言
隨著信息技術的發(fā)展,海量繁雜的信息向人們不斷襲來,信息無時無刻充斥在四周。然而人類所能接收的信息則是有限的,科研人員發(fā)現(xiàn)人類視覺系統(tǒng)在有限的視野之下卻有著龐大的視覺信息處理能力。在處理視覺數(shù)據(jù)的初期,人類視覺系統(tǒng)會迅速將注意力集中在場景中的重要區(qū)域上,這一選擇性感知機制極大地減少了人類視覺系統(tǒng)處理數(shù)據(jù)的數(shù)量,從而使人類在處理復雜的視覺信息時能夠抑制不重要的刺激,并將有限的神經(jīng)計算資源分配給場景中的關鍵部分,為更高層次的感知推理和更復雜的視覺處理任務(如物體識別、場景分類、視頻理解等)提供更易于處理且更相關的信息。借鑒人類視覺系統(tǒng)的這一特點,科研人員提出了注意力機制的思想。對于事物來說特征的重要性是不同的,反映在卷積網(wǎng)絡中即每張?zhí)卣鲌D的重要性是具有差異性的。注意力機制的核心思想是通過一定手段獲取到每張?zhí)卣鲌D重要性的差異,將神經(jīng)網(wǎng)絡的計算資源更多地投入更重要的任務當中,并利用任務結果反向指導特征圖的權重更新,從而高效快速地完成相應任務。
近兩年,注意力模型被廣泛使用在自然語言處理、圖像識別、語音識別等各種不同類型的深度學習任務當中。
如下圖所示,顏色越深的地方表示關注度越大,即注意力的權重越大。
故本項目將通過搭建 BiLSTM 的注意力機制模型來實現(xiàn)對時間數(shù)據(jù)的格式轉換,實現(xiàn)的最終結果如下:
注意力機制介紹
注意力機制最初在2014年作為RNN中編碼器-****框架的一部分來編碼長的輸入語句,后續(xù)被廣泛運用在RNN中。例如在機器翻譯中通常是用一個 RNN編碼器讀入上下文,得到一個上下文向量,一個RNN****以這個隱狀態(tài)為起始狀態(tài),依次生成目標的每一個單詞。但這種做法的缺點是:無論之前的上下文有多長,包含多少信息量,最終都要被壓縮成一個幾百維的向量。這意味著上下文越大,最終的狀態(tài)向量會丟失越多的信息。輸入語句長度增加后,最終****翻譯的結果會顯著變差。事實上,因為上下文在輸入時已知,一個模型完全可以在解碼的過程中利用上下文的全部信息,而不僅僅是最后一個狀態(tài)的信息,這就是注意力機制的基礎思想。
1.1
基本方法介紹
當前注意力機制的主流方法是將特征圖中的潛在注意力信息進行深度挖掘,最常見的是通過各種手段獲取各個特征圖通道間的通道注意力信息與特征圖內(nèi)部像素點之間的空間注意力信息,獲取的方法也包括但不僅限于卷積操作,矩陣操作構建相關性矩陣等,其共同的目的是更深層次,更全面的獲取特征圖中完善的注意力信息,于是如何更深的挖掘,從哪里去挖掘特征圖的注意力信息,將極有可能會成為未來注意力方法發(fā)展的方向之一。
目前,獲取注意力的方法基本基于通道間的注意力信息、空間像素點之間的注意力信息和卷積核選擇的注意力信息,是否能夠從新的方向去獲取特征圖更豐富的注意力信息,或者以新的方式或手段去獲取更精準的注意力信息也是未來需要關注的一個重點。
模型實驗
2.1
數(shù)據(jù)處理
讀取數(shù)據(jù)集json文件,并將每一個索引轉換為對應的one-hot編碼形式,并設置輸入數(shù)據(jù)最大長度為41。代碼如下:
with open('data/Time Dataset.json','r') as f: dataset = json.loads(f.read()) with open('data/Time Vocabs.json','r') as f: human_vocab, machine_vocab = json.loads(f.read()) human_vocab_size = len(human_vocab) machine_vocab_size = len(machine_vocab) m = len(dataset) def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty): m = len(dataset) X = np.zeros([m, Tx], dtype='int32') Y = np.zeros([m, Ty], dtype='int32') for i in range(m): data = dataset[i] X[i] = np.array(tokenize(data[0], human_vocab, Tx)) Y[i] = np.array(tokenize(data[1], machine_vocab, Ty)) Xoh = oh_2d(X, len(human_vocab)) Yoh = oh_2d(Y, len(machine_vocab)) return (X, Y, Xoh, Yoh)
2.2 網(wǎng)絡模型設置
其中Tx=41為序列的最大長度,Ty=5為序列長度,layer1 size設置為32為網(wǎng)絡層,1ayer2 size=64為注意力層,human vocab size=41表述human時間會用到41個不同的字符,machine vocab size=11表述machine時間會用到11個不同的字符。這里雙向LSTM作為Encoder編碼器,全連接層作為Decoder****。
代碼如下:
layer3 = Dense(machine_vocab_size, activation=softmax) def get_model(Tx, Ty, layer1_size, layer2_size, x_vocab_size, y_vocab_size): X = Input(shape=(Tx, x_vocab_size)) a1 = Bidirectional(LSTM(layer1_size, return_sequences=True), merge_mode='concat')(X) a2 = attention_layer(a1, layer2_size, Ty) a3 = [layer3(timestep) for timestep in a2] model = Model(inputs=[X], outputs=a3) return model
2.3
注意力網(wǎng)絡
為了達到反饋更新的作用,注意力網(wǎng)絡在每個輸出時間步上關注輸入的某些部分。_attention_表示哪些輸入與當前輸出步驟最相關。如果一個輸入步驟是相關的,那么它的注意力權重為1,否則為0。_context_是“輸入的摘要”。全局定義部分注意力層,以便每個注意力都有相同的層次。代碼如下:
def one_step_of_attention(h_prev, a): h_repeat = at_repeat(h_prev) i = at_concatenate([a, h_repeat]) i = at_dense1(i) i = at_dense2(i) attention = at_softmax(i) context = at_dot([attention, a]) return context def attention_layer(X, n_h, Ty): h = Lambda(lambda X: K.zeros(shape=(K.shape(X)[0], n_h)))(X) c = Lambda(lambda X: K.zeros(shape=(K.shape(X)[0], n_h)))(X) at_LSTM = LSTM(n_h, return_state=True) output = [] for _ in range(Ty): context = one_step_of_attention(h, X) h, _, c = at_LSTM(context, initial_state=[h, c]) output.append(h) return output
2.4 模型訓練評估
通過調(diào)用get_model函數(shù)獲取整個模型架構,并使用adam優(yōu)化器迭代更新,創(chuàng)建交叉熵損失函數(shù)最后訓練和評估。
代碼如下:
model = get_model(Tx, Ty, layer1_size, layer2_size, human_vocab_size, machine_vocab_size) opt = Adam(lr=0.05, decay=0.04, clipnorm=1.0) model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy']) outputs_train = list(Yoh_train.swapaxes(0,1)) model.fit([Xoh_train], outputs_train, epochs=30, batch_size=100) outputs_test = list(Yoh_test.swapaxes(0,1)) score = model.evaluate(Xoh_test, outputs_test) print('Test loss: ', score[0])
圖片
完整代碼:
鏈接:
https://pan.baidu.com/s/1d9delZAQ7gepH9T9um4dMQ
提取碼:a2ed
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。