本發(fā)明涉及圖數(shù)據(jù)挖掘領(lǐng)域,具體來說涉及圖神經(jīng)網(wǎng)絡(luò)知識蒸餾領(lǐng)域,更具體地說,涉及一種基于圖神經(jīng)網(wǎng)絡(luò)架構(gòu)挑選的知識蒸餾方法。
背景技術(shù):
1、圖神經(jīng)網(wǎng)絡(luò)(graph?neural?networks,簡稱gnns)作為一種有效的圖數(shù)據(jù)處理工具,在近年來受到了廣泛關(guān)注。通過有效地捕捉圖結(jié)構(gòu)中的相互依賴關(guān)系,gnns能夠?qū)D中的節(jié)點和邊進(jìn)行高效地推斷和預(yù)測,已經(jīng)在社交網(wǎng)絡(luò)分析、推薦系統(tǒng)、化學(xué)分子分析等領(lǐng)域展現(xiàn)出強(qiáng)大的潛力。然而,在實際應(yīng)用中,如何構(gòu)建一個性能優(yōu)異的圖神經(jīng)網(wǎng)絡(luò)模型依然是一個具有挑戰(zhàn)性的問題。
2、近年來出現(xiàn)了圖神經(jīng)網(wǎng)絡(luò)知識蒸餾(圖蒸餾)方法,通過將預(yù)訓(xùn)練教師圖模型的知識轉(zhuǎn)移到學(xué)生圖模型中,來增強(qiáng)學(xué)生表現(xiàn)。圖蒸餾的核心思想是通過提取教師模型的表示學(xué)習(xí)能力和知識,將其傳遞給學(xué)生圖模型,從而進(jìn)一步提升學(xué)生圖模型的預(yù)測準(zhǔn)確性。這種方法在一定程度上為gnns在圖數(shù)據(jù)分析和應(yīng)用上,提供了一種有效的模型性能增強(qiáng)方案。
3、然而,目前的圖蒸餾方法主要集中在知識的傳遞的優(yōu)化上,所選擇的學(xué)生圖模型的結(jié)構(gòu)可能并不適用教師圖模型所應(yīng)用的分類任務(wù),即教師和學(xué)生的網(wǎng)絡(luò)結(jié)構(gòu)可能不匹配,以致影響學(xué)生圖模型的性能和泛化能力。
4、需要說明的是:本背景技術(shù)僅用于介紹本發(fā)明的相關(guān)信息,以便于幫助理解本發(fā)明的技術(shù)方案,但并不意味著相關(guān)信息必然是現(xiàn)有技術(shù)。相關(guān)信息與本發(fā)明方案一同提交和公開,在沒有證據(jù)表明相關(guān)信息已在本發(fā)明的申請日以前公開的情況下,相關(guān)信息不應(yīng)被視為現(xiàn)有技術(shù)。
技術(shù)實現(xiàn)思路
1、因此,本發(fā)明的目的在于克服上述現(xiàn)有技術(shù)的缺陷,提供一種基于圖神經(jīng)網(wǎng)絡(luò)架構(gòu)挑選的知識蒸餾方法。
2、本發(fā)明的目的是通過以下技術(shù)方案實現(xiàn)的:
3、根據(jù)本發(fā)明的第一方面,提供一種基于圖神經(jīng)網(wǎng)絡(luò)架構(gòu)挑選的知識蒸餾方法,包括:獲取預(yù)先定義的圖神經(jīng)網(wǎng)絡(luò)架構(gòu)的搜索空間,其定義架構(gòu)搜索的范圍;利用神經(jīng)架構(gòu)搜索技術(shù)和預(yù)設(shè)分類任務(wù)對應(yīng)的驗證集,采用強(qiáng)化學(xué)習(xí)機(jī)制從所述搜索空間中搜索執(zhí)行該分類任務(wù)的最優(yōu)圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),所述預(yù)設(shè)分類任務(wù)與訓(xùn)練教師圖模型時所對應(yīng)的分類任務(wù)相同;基于預(yù)設(shè)的知識蒸餾方式,利用所述教師圖模型指導(dǎo)采用所述最優(yōu)圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的學(xué)生圖模型進(jìn)行節(jié)點分類,得到經(jīng)訓(xùn)練的學(xué)生圖模型。
4、可選的,強(qiáng)化學(xué)習(xí)機(jī)制包括:設(shè)置參數(shù)化的策略網(wǎng)絡(luò),利用策略網(wǎng)絡(luò)基于所述搜索空間進(jìn)行多次模擬,每次模擬得到一個仿真網(wǎng)絡(luò)結(jié)構(gòu);利用仿真網(wǎng)絡(luò)結(jié)構(gòu)構(gòu)建的模型在驗證集上的性能指標(biāo)作為獎勵,指導(dǎo)策略網(wǎng)絡(luò)優(yōu)化網(wǎng)絡(luò)參數(shù);基于優(yōu)化參數(shù)后的策略網(wǎng)絡(luò),從搜索空間中確定最優(yōu)圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)。
5、可選的,所述預(yù)設(shè)的蒸餾方式包括:獲取訓(xùn)練集,其包括多個樣本和每個樣本對應(yīng)的標(biāo)簽,所述樣本為包括節(jié)點和邊的圖數(shù)據(jù),所述標(biāo)簽為圖數(shù)據(jù)中至少部分節(jié)點所屬的類別真值;將訓(xùn)練集中的樣本輸入教師圖模型和學(xué)生圖模型,得到教師圖模型的分類預(yù)測值和學(xué)生圖模型的分類預(yù)測值;利用預(yù)設(shè)的損失函數(shù)、教師圖模型的分類預(yù)測值和學(xué)生圖模型的分類預(yù)測值,指導(dǎo)更新學(xué)生圖模型的參數(shù),以降低預(yù)測損失。
6、可選的,預(yù)設(shè)的損失函數(shù)用于計算一次訓(xùn)練所采用的所有節(jié)點的節(jié)點損失的均值,單個節(jié)點損失采用以下kl散度損失函數(shù)計算:
7、lkl=kl(yt||ys)
8、其中,lkl表示kl散度損失函數(shù),yt是教師圖模型對輸入樣本中單個節(jié)點輸出的分類預(yù)測值,ys是學(xué)生圖模型對yt所對應(yīng)的同一節(jié)點輸出的分類預(yù)測值,kl(·)表示kl散度函數(shù)。
9、可選的,預(yù)設(shè)的損失函數(shù)用于計算一次訓(xùn)練所采用的所有節(jié)點的節(jié)點損失的均值,單個的節(jié)點損失采用以下加權(quán)損失函數(shù)計算:
10、ls=λlkl+(1-λ)lce
11、其中,ls表示加權(quán)損失函數(shù),lkl表示用于計算教師圖模型和學(xué)生圖模型對同一節(jié)點輸出的分類預(yù)測值間損失的kl散度損失函數(shù),lce表示用于計算學(xué)生圖模型對設(shè)有標(biāo)簽的節(jié)點輸出的分類預(yù)測值和標(biāo)簽之間損失的交叉熵?fù)p失函數(shù),λ表示lkl的權(quán)重。
12、可選的,kl散度損失函數(shù)為:
13、lkl=kl(yt||ys)
14、其中,lkl表示kl散度損失函數(shù),yt是教師圖模型對輸入樣本中單個節(jié)點輸出的分類預(yù)測值,ys是學(xué)生圖模型對yt所對應(yīng)的同一節(jié)點輸出的分類預(yù)測值,kl(·)表示kl散度函數(shù)。
15、可選的,圖數(shù)據(jù)為論文引用關(guān)系圖數(shù)據(jù),節(jié)點代表論文,所述邊代表論文之間的引用關(guān)系,所述標(biāo)簽為節(jié)點所屬的主題類別;
16、所述學(xué)生圖模型進(jìn)行節(jié)點分類包括:從輸入的論文引用關(guān)系圖數(shù)據(jù)中提取各節(jié)點的節(jié)點特征,并根據(jù)節(jié)點特征預(yù)測節(jié)點的主題分類預(yù)測值。
17、根據(jù)本發(fā)明的第二方面,提供一種對圖數(shù)據(jù)中節(jié)點進(jìn)行分類的方法,包括:獲取待分類的圖數(shù)據(jù)以及按照第一方面所述的方法得到的經(jīng)訓(xùn)練的學(xué)生圖模型;利用所述經(jīng)訓(xùn)練的學(xué)生圖模型對待分類的圖數(shù)據(jù)中的節(jié)點進(jìn)行分類。
18、根據(jù)本發(fā)明的第三方面,提供一種電子設(shè)備,包括:一個或多個處理器;以及存儲器,其中存儲器用于存儲可執(zhí)行指令;所述一個或多個處理器被配置為經(jīng)由執(zhí)行所述可執(zhí)行指令以實現(xiàn)第一方面和/或第二方面所述方法的步驟。
19、與現(xiàn)有技術(shù)相比,本發(fā)明的優(yōu)點在于:
20、本發(fā)明實施例提供了一種基于圖神經(jīng)網(wǎng)絡(luò)架構(gòu)挑選的知識蒸餾方法,一方面,利用神經(jīng)架構(gòu)搜索技術(shù)和預(yù)設(shè)分類任務(wù)對應(yīng)的驗證集,采用強(qiáng)化學(xué)習(xí)機(jī)制從所述搜索空間中搜索執(zhí)行該分類任務(wù)的最優(yōu)圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),所述預(yù)設(shè)分類任務(wù)與訓(xùn)練教師圖模型時所對應(yīng)的分類任務(wù)相同,不僅能夠搜索出適應(yīng)于固定任務(wù)的最優(yōu)圖神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),以解決師生網(wǎng)絡(luò)結(jié)構(gòu)不匹配的問題,又能夠利用教師圖模型的知識來指導(dǎo)學(xué)生圖模型的訓(xùn)練,從而提高學(xué)生圖模型的性能。
1.一種基于圖神經(jīng)網(wǎng)絡(luò)架構(gòu)挑選的知識蒸餾方法,其特征在于,包括:
2.根據(jù)權(quán)利要求1所述的方法,其特征在于,所述強(qiáng)化學(xué)習(xí)機(jī)制包括:
3.根據(jù)權(quán)利要求1或2所述的方法,其特征在于,所述預(yù)設(shè)的蒸餾方式包括:
4.根據(jù)權(quán)利要求3所述的方法,其特征在于,所述預(yù)設(shè)的損失函數(shù)用于計算一次訓(xùn)練所采用的所有節(jié)點的節(jié)點損失的均值,單個節(jié)點損失采用以下kl散度損失函數(shù)計算:
5.根據(jù)權(quán)利要求3所述的方法,其特征在于,所述預(yù)設(shè)的損失函數(shù)用于計算一次訓(xùn)練所采用的所有節(jié)點的節(jié)點損失的均值,單個的節(jié)點損失采用以下加權(quán)損失函數(shù)計算:
6.根據(jù)權(quán)利要求5所述的方法,其特征在于,kl散度損失函數(shù)為:
7.根據(jù)權(quán)利要求3所述的方法,其特征在于,所述圖數(shù)據(jù)為論文引用關(guān)系圖數(shù)據(jù),節(jié)點代表論文,所述邊代表論文之間的引用關(guān)系,所述標(biāo)簽為節(jié)點所屬的主題類別;
8.一種對圖數(shù)據(jù)中節(jié)點進(jìn)行分類的方法,包括:
9.一種計算機(jī)可讀存儲介質(zhì),其特征在于,其上存儲有計算機(jī)程序,所述計算機(jī)程序可被處理器執(zhí)行以實現(xiàn)權(quán)利要求1至8之一所述方法的步驟。
10.一種電子設(shè)備,其特征在于,包括: