{"id":53762,"date":"2025-02-16T12:47:29","date_gmt":"2025-02-16T04:47:29","guid":{"rendered":"https:\/\/fwq.ai\/blog\/53762\/"},"modified":"2025-02-16T12:47:29","modified_gmt":"2025-02-16T04:47:29","slug":"%e7%94%a8%e5%b0%8f%e6%a8%a1%e5%9e%8b%e5%90%88%e6%88%90%e8%a1%a8%e6%a0%bc%e6%95%b0%e6%8d%ae","status":"publish","type":"post","link":"https:\/\/fwq.ai\/blog\/53762\/","title":{"rendered":"\u7528\u5c0f\u6a21\u578b\u5408\u6210\u8868\u683c\u6570\u636e"},"content":{"rendered":"<p>\u5408\u6210\u6570\u636e\u751f\u6210\u89e3\u51b3\u4e86\u591a\u4e2a\u57fa\u672c\u6311\u6218\uff1a\u6570\u636e\u96c6\u4e2d\u7684\u7c7b\u522b\u4e0d\u5e73\u8861\u3001\u6570\u636e\u9690\u79c1\u8981\u6c42\u3001\u6570\u636e\u83b7\u53d6\u6210\u672c\u4f18\u5316\u548c\u5b9e\u9a8c\u5468\u671f\u52a0\u901f\u3002\u4f20\u7edf\u65b9\u6cd5\uff08\u5982 SMOTE [1]\uff09\u901a\u8fc7\u5728\u73b0\u6709\u6570\u636e\u70b9\u4e4b\u95f4\u8fdb\u884c\u63d2\u503c\u6765\u4e3a\u5c11\u6570\u7c7b\u751f\u6210\u5408\u6210\u6837\u672c\u3002\u4e4b\u524d\u7684\u535a\u5ba2\u6587\u7ae0 [2] \u5bf9\u8868\u683c\u5408\u6210\uff08\u6570\u503c\uff09\u6570\u636e\u751f\u6210\u7684\u751f\u6210\u65b9\u6cd5\u8fdb\u884c\u4e86\u5168\u9762\u8bc4\u4f30\uff0c\u5305\u62ec\u751f\u6210\u5bf9\u6297\u7f51\u7edc (GAN)\u3001\u53d8\u5206\u81ea\u52a8\u7f16\u7801\u5668 (VAE)\u3001\u9ad8\u65af Copula\u3001\u8d1d\u53f6\u65af\u7f51\u7edc\u548c\u6761\u4ef6\u8868\u683c GAN (CTGAN)\u3002<\/p>\n<p>\u8fd9\u7bc7\u6587\u7ae0\u7814\u7a76\u4e86\u5229\u7528\u5c0f\u8bed\u8a00\u6a21\u578b (SLM) \u751f\u6210\u5408\u6210\u8868\u683c\u6570\u503c\u6570\u636e\u7684\u65b0\u65b9\u6cd5\u3002\u4e0e\u4e4b\u524d\u7684\u7814\u7a76\u4fdd\u6301\u8fde\u7eed\u6027\uff0c\u6211\u4eec\u4e13\u6ce8\u4e8e\u5355\u4e00\u8868\u683c\u6570\u636e\uff0c\u7279\u522b\u662f\u5206\u6790\u6765\u81ea NASA \u827e\u59c6\u65af\u9884\u6d4b\u5353\u8d8a\u4e2d\u5fc3\u7684\u6da1\u6247\u53d1\u52a8\u673a\u9000\u5316\u6a21\u62df\u6570\u636e\u96c6 [3][4]\u3002\u6709\u5173\u6570\u636e\u96c6\u7279\u5f81\u548c\u7814\u7a76\u52a8\u673a\uff0c\u8bfb\u8005\u53ef\u4ee5\u53c2\u8003\u4e4b\u524d\u7684\u51fa\u7248\u7269\u3002<\/p>\n<p>\u8be5\u7814\u7a76\u8003\u5bdf\u4e86\u56db\u79cd\u5173\u952e\u65b9\u6cd5\uff1a<\/p>\n<ul>\n<li>\u5177\u6709\u9886\u57df\u7279\u5b9a\u7ea6\u675f\u7684 SLM \u5fae\u8c03<\/li>\n<li>\u4f7f\u7528\u6570\u503c\u6807\u8bb0\u5668\u548c\u81ea\u5b9a\u4e49\u635f\u5931\u51fd\u6570\u8fdb\u884c\u9ad8\u7ea7\u5fae\u8c03<\/li>\n<li>Transformer GAN \u548c\u6761\u4ef6 Transformer GAN \u67b6\u6784<\/li>\n<li>\u8bed\u8a00\u6a21\u578b GAN (LM-GAN) \u5b9e\u73b0<\/li>\n<\/ul>\n<p>\u5c06\u8bed\u8a00\u6a21\u578b\u5f52\u7c7b\u4e3a\u201c\u5c0f\u578b\u201d\u5728\u81ea\u7136\u8bed\u8a00\u5904\u7406\u9886\u57df\u8868\u73b0\u51fa\u65f6\u95f4\u53d8\u5316\u3002\u4e00\u4e2a\u503c\u5f97\u6ce8\u610f\u7684\u4f8b\u5b50\u662f GPT-2\uff0c\u5b83\u5728 2019 \u5e74\u53d1\u5e03\u65f6\u5177\u6709 15 \u4ebf\u4e2a\u53c2\u6570\uff0c\u88ab\u5f52\u7c7b\u4e3a\u5927\u578b\u6a21\u578b\uff0c\u4f46\u73b0\u5728\u6309\u7167\u5f53\u4ee3\u6807\u51c6\u88ab\u8ba4\u4e3a\u662f\u5c0f\u578b\u7684\u3002\u5f53\u524d\u5206\u7c7b (2024) \u5c06 SLM \u5b9a\u4e49\u4e3a\u5305\u542b 3-100 \u4ebf\u4e2a\u53c2\u6570\u7684\u6a21\u578b\uff0c\u800c\u5927\u578b\u8bed\u8a00\u6a21\u578b (LLM) \u901a\u5e38\u5305\u542b\u6570\u5343\u4ebf\u4e2a\u53c2\u6570\u3002SLM \u9488\u5bf9\u8d44\u6e90\u6548\u7387\u548c\u8fb9\u7f18\u90e8\u7f72\u573a\u666f\u8fdb\u884c\u4e86\u4f18\u5316\uff0c\u4ee3\u8868\u6027\u6a21\u578b\u5305\u62ec Phi 3 [8]\u3001Galactica \u548c Gemma\u3002<\/p>\n<p>SLM 
The architectural diversity of SLMs mirrors that of their larger counterparts and spans a variety of attention mechanisms:

- Multi-Head Attention (MHA)
- Multi-Query Attention (MQA)
- Grouped-Query Attention (GQA)
- Multi-Head Latent Attention (MLA)

These models also vary considerably in their structural components [26], including:

- Feed-forward network implementations (standard FFN, gated FFN)
- Activation function choices (ReLU, GELU, GELU-tanh, SiLU)
- Vocabulary sizes (from under 50K to over 250K)
- Training data volumes (from millions to 6T tokens)

This study focuses on applying these SLM architectures to synthetic data generation.

The article evaluates four advanced approaches to synthetic tabular data generation with small language models (SLMs), and briefly covers prompt engineering with structured constraints: fine-tuning an SLM for a specific manufacturing objective, advanced fine-tuning with a custom loss function, hybrid architectures that combine Transformers with GANs, and a novel Language Model GAN (LM-GAN) framework. The implementation uses Microsoft's Phi-3.5 mini [13] for the fine-tuning experiments, DistilGPT-2 [23] for the custom loss function integration, and purpose-built architectures for the Transformer GAN and LM-GAN models. Each approach builds progressively on SLM capabilities, culminating in the LM-GAN architecture, which shows excellent preservation of the statistical properties of synthetic manufacturing data.

## 1. Prompt Engineering

Prompt engineering is a basic but effective approach to synthetic data generation. It involves constructing a structured input for the language model that contains field descriptions and specifications, domain-specific constraints, and sample data rows for few-shot learning. The approach allows generation parameters to be specified, including the output volume and conditioning constraints. The literature describes two main prompting paradigms: natural-language descriptions [10] and structured CSV formats [9], each optimizing different aspects of token utilization. Advanced prompting techniques combine random value substitution to increase data diversity with hierarchical grouping mechanisms for conditional generation [11][12]. A minimal prompt sketch is shown below.
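To make the structure concrete, the sketch below assembles such a prompt from a few real rows. The field names, example values, and prompt wording are illustrative assumptions, not taken from the article.

```python
# A minimal prompt-construction sketch for CSV-style few-shot generation.
# Field names, constraints, and wording are hypothetical placeholders.
import pandas as pd

def build_prompt(df: pd.DataFrame, n_examples: int = 5, n_new_rows: int = 10) -> str:
    header = ",".join(df.columns)
    # few-shot examples drawn from the real data
    examples = "\n".join(
        ",".join(f"{v:.4g}" for v in row) for row in df.sample(n_examples).to_numpy()
    )
    return (
        "You generate synthetic turbofan sensor readings as CSV rows.\n"
        f"Columns: {header}\n"
        "Constraints: values must stay within the ranges seen in the examples.\n"
        f"Examples:\n{examples}\n"
        f"Now generate {n_new_rows} new rows in the same CSV format:"
    )

# Usage: prompt = build_prompt(df); completion = slm.generate(prompt)  # slm: any SLM client
```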
While prompt engineering offers considerable potential for synthetic data generation, this article concentrates on the more advanced modeling frameworks and leaves a thorough exploration of prompting techniques to the existing literature and to the reader.

## 2. SLM Fine-Tuning

The first technique we explore is fine-tuning a language model. For this, the article considers the Microsoft Phi 3.5 mini instruct model [13]. It is a decoder-only transformer model with 3.82B parameters, trained on 3.4 trillion tokens. It has a context length of 128k tokens and is specifically optimized for basic reasoning tasks, code generation, and mathematical problem solving.

To make things concrete, let us formalize the problem of tabular numerical data generation.

Let D be a training dataset with n samples, `D = {X1, X2, X3, …, Xn}`. Each sample consists of a set of m key-value pairs, with the field names as keys.

These are laid out over n rows, with each key separated from its value by a special token δ. For this exercise, `:::` is chosen as the token. A concatenation operator C performs this step and produces the combined string.

The resulting text is then tokenized: a tokenization function T maps the concatenated string to a sequence of tokens.

Finally, these tokens are used to fine-tune the language model, which predicts the next token given the preceding tokens.

For synthetic data generation, given a seed token s, the model samples from the learned distribution X̂.

Figure 1: SLM fine-tuning for synthetic data generation

Figure 1 above illustrates the process. The steps labeled `F{i}` belong to the fine-tuning stage, and the steps labeled `G{i}` belong to generation. The serialization step is sketched below.
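As an illustration, a row can be serialized into the `field ::: value` text format described above roughly as follows; the column names (s1, s2, …) are generic placeholders, matching the article's use of generic field names.

```python
# Sketch: serialize tabular rows into "key ::: value" training text.
# Column names here (s1, s2, ...) are generic placeholders.
import pandas as pd

SEP = " ::: "         # the special token delta chosen in the article
FIELD_DELIM = " , "   # separator between key-value pairs

def row_to_text(row: pd.Series) -> str:
    # C(X_i): concatenate "key ::: value" pairs into one string
    return FIELD_DELIM.join(f"{col}{SEP}{row[col]}" for col in row.index)

df = pd.DataFrame({"s1": [518.67, 519.02], "s2": [641.82, 642.11]})
corpus = [row_to_text(r) for _, r in df.iterrows()]
# e.g. "s1 ::: 518.67 , s2 ::: 641.82"
```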
The choice of seed token depends on how the model was fine-tuned. If the fields are shuffled before concatenation, the seed token can be any field name. Alternatively, the model can be conditioned on a fixed set of field names that determine the distributions of the remaining fields. Both techniques have been evaluated in the literature, with results that vary by dataset [10]. While this article uses generic field names, there is evidence that descriptive field names can further improve the fidelity of the synthetic samples [14].

For fine-tuning, the publicly available weights of the Phi 3.5 mini model (via Hugging Face) [15] were used as the pretrained weights. The model was fine-tuned with the AdamW optimizer [16], a constant learning-rate schedule, and a learning rate of 2e-4. Compared with standard Adam, AdamW offers better generalization and has consistently performed well in fine-tuning tasks [17][18], because it decouples weight decay from the tracking of the first and second moments. For fine-tuning, the Trainer class from the Hugging Face transformers library [19] was used together with Low-Rank Adaptation (LoRA) [20] and BitsAndBytes quantization [21].

The following configuration provides a quick snapshot of the fine-tuning setup:

```yaml
model_id: &model_id "microsoft/Phi-3.5-mini-instruct"

tokenizer_config:
    max_length: 350
    truncation: True
    padding: "max_length"

training_env:
    model_dir: "opt/ml/phi-model"        # directory for fine-tuned model weights
    cache_dir: &cache_dir "/tmp/.cache"  # directory for pretrained weights (downloaded from Hugging Face)
    merge_dir: "/tmp/phi-model"          # directory for merged model weights

model_config:
    trust_remote_code: True
    cache_dir: *cache_dir
    device_map: "auto"
    torch_dtype: "float16"
    attn_implementation: "flash_attention_2"

bnb_config:
    load_in_4bit: True
    bnb_4bit_use_double_quant: True
    bnb_4bit_quant_type: "nf4"
    bnb_4bit_compute_dtype: "bfloat16"

lora_config:
    r: 8
    lora_alpha: 16
    lora_dropout: 0.1
    bias: "none"
    task_type: "CAUSAL_LM"

training_config:
    per_device_train_batch_size: 4
    per_device_eval_batch_size: 1
    gradient_accumulation_steps: 2
    gradient_checkpointing: True
    learning_rate: 0.0002
    lr_scheduler_type: "constant"
    num_train_epochs: 1
    logging_strategy: "steps"
    logging_steps: 10
    log_on_each_node: False
    bf16: True
    ddp_find_unused_parameters: False
    fsdp: ""          # FSDP turned off
    fsdp_config: null
    output_dir: "outputs"
    report_to: none
    optim: adamw_torch
    save_strategy: epoch
    max_grad_norm: 0.3
    warmup_ratio: 0.03
```
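For readers who want to see how such a configuration is wired together, the sketch below shows one plausible way to load Phi-3.5 mini with 4-bit quantization and LoRA adapters using the Hugging Face stack. It mirrors the YAML above but is an illustrative reconstruction, not the article's training script.

```python
# Sketch: Phi-3.5 mini with 4-bit quantization + LoRA for causal-LM fine-tuning.
# Mirrors the YAML config above; dataset preparation is omitted.
import torch
from transformers import (AutoModelForCausalLM, BitsAndBytesConfig,
                          Trainer, TrainingArguments)
from peft import LoraConfig, get_peft_model

model_id = "microsoft/Phi-3.5-mini-instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config,
    device_map="auto", trust_remote_code=True,
)
model = get_peft_model(model, LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM",
))

args = TrainingArguments(
    output_dir="outputs", num_train_epochs=1, learning_rate=2e-4,
    per_device_train_batch_size=4, gradient_accumulation_steps=2,
    lr_scheduler_type="constant", bf16=True, logging_steps=10,
    optim="adamw_torch", max_grad_norm=0.3, warmup_ratio=0.03,
    save_strategy="epoch",
)
# trainer = Trainer(model=model, args=args, train_dataset=tokenized_rows)  # tokenized_rows: serialized rows
# trainer.train()
```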
\"no\"\n    output_dir: \"outputs\"\n    report_to: none\n    optim: adamw_torch       \n    save_strategy: epoch\n    max_grad_norm: 0.3 \n    warmup_ratio: 0.03<\/code><\/pre>\n<p>\u4e0b\u56fe 2 \u63d0\u4f9b\u4e86\u539f\u59cb\uff08\u6d4b\u8bd5\uff09\u5206\u5e03\u548c\u5408\u6210\u5206\u5e03\u4e4b\u95f4\u7684 Kullback-Leibler (KL) \u6563\u5ea6\u3002KL \u6563\u5ea6\u503c\u8de8\u5ea6\u7ea6\u4e3a\u4e24\u4e2a\u6570\u91cf\u7ea7\uff0c\u8303\u56f4\u4ece ~0.5 \u5230 ~40\uff0c\u8868\u660e\u5408\u6210\u6570\u636e\u8d28\u91cf\u5728\u4e0d\u540c\u8bbe\u7f6e\u4e0b\u5b58\u5728\u663e\u8457\u5dee\u5f02\u3002\u5bf9\u4e8e\u5176\u4ed6\u5b57\u6bb5\uff0c\u89c2\u5bdf\u5230\u7684 KL \u6563\u5ea6\u503c\u8f83\u4f4e\uff0c\u8868\u793a\u4fdd\u771f\u5ea6\u8f83\u9ad8\u3002<\/p>\n<p>  \u56fe 2 \u5fae\u8c03 Phi-3.5 \u2014 KL \u6563\u5ea6 <\/p>\n<h2>3\u3001\u9ad8\u7ea7\u5fae\u8c03<\/h2>\n<p>\u63a5\u4e0b\u6765\uff0c\u8fd9\u7bc7\u6587\u7ae0\u7814\u7a76\u4e86\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u8003\u8651 KL \u6563\u5ea6\u7684\u5f71\u54cd\uff0c\u5e76\u5c06\u5176\u89c6\u4e3a\u635f\u5931\u51fd\u6570\u7684\u4e00\u90e8\u5206\u3002\u5728\u672c\u8282\u4e2d\uff0c\u6211\u4eec\u5c06\u7814\u7a76\u4f7f\u7528 DistilGPT-2 \u8bed\u8a00\u6a21\u578b\u751f\u6210\u5408\u6210\u6570\u636e\u7684\u9ad8\u7ea7\u5fae\u8c03\u7b56\u7565\u7684\u5b9e\u73b0\u3002\u5b83\u662f GPT2[22] \u7684\u538b\u7f29\u7248\u672c\uff0c\u901a\u8fc7\u77e5\u8bc6\u84b8\u998f\u5f00\u53d1\u800c\u6210\u3002\u5b83\u6709 8200 \u4e07\u4e2a\u53c2\u6570\uff0c\u662f\u4f7f\u7528\u77e5\u8bc6\u84b8\u998f\u5f00\u53d1\u7684\uff0c\u5728 1.24 \u4ebf\u4e2a\u53c2\u6570\u7248\u672c\u7684 GPT-2 \u7684\u76d1\u7763\u4e0b\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\uff0c\u5927\u7ea6\u662f\u7236\u6a21\u578b\u7684\u4e00\u534a\u5927\u5c0f\u3002\u5176\u67b6\u6784\u5305\u542b 6 \u4e2a\u8f6c\u6362\u5668\u5757\uff08GPT2\uff1a12\uff09\uff0c\u540c\u65f6\u4fdd\u7559\u6bcf\u4e2a\u5757\u7684 12 \u4e2a\u6ce8\u610f\u529b\u5934\u548c 768 \u7684\u5d4c\u5165\u7ef4\u5ea6\u3002\u8be5\u6a21\u578b\u4fdd\u7559\u4e86 1024 \u4e2a\u6807\u8bb0\u7684\u4e0a\u4e0b\u6587\u7a97\u53e3\u548c 50,257 \u4e2a\u6807\u8bb0\u7684\u8bcd\u6c47\u91cf[23]\u3002<\/p>\n<p>\u672c\u6587\u63a2\u8ba8\u4e86\u4f7f\u7528\u4e24\u79cd\u7b56\u7565\u5bf9 GPT2 \u8fdb\u884c\u5fae\u8c03<\/p>\n<h3>3.1 \u6570\u503c\u5206\u8bcd\u5668<\/h3>\n<p>\u6570\u503c\u5206\u8bcd\u5668\u7684\u52a8\u673a\u6709\u4e24\u4e2a\u65b9\u9762\u3002\u9996\u5148\uff0c\u5b83\u4e3a\u6570\u503c\u4e2d\u7684\u6240\u6709\u6570\u5b57\u63d0\u4f9b\u76f8\u540c\u7684\u6743\u91cd\uff0c\u800c\u9ed8\u8ba4\u6807\u8bb0\u5668\u53ef\u4ee5\u5728\u4e0d\u540c\u7684\u6570\u5b57\u5904\u62c6\u5206\u4ee5\u751f\u6210\u6807\u8bb0\u3002\u7b2c\u4e8c\u4e2a\u52a8\u673a\u662f\u4e3a\u5217\u503c\u3001\u5206\u9694\u7b26\u6807\u8bb0\uff08 <code>:::<\/code>\uff09\u548c\u5b57\u6bb5\u5206\u9694\u7b26\uff08 <code>,<\/code>\uff09\u521b\u5efa\u4e13\u7528\u7684\u5355\u6570\u6807\u8bb0\u3002\u5b83\u53ef\u4ee5\u6b63\u5f0f\u5b9a\u4e49\u4e3a\u57fa\u672c GPT-2 \u6807\u8bb0\u5668 Tb \u7684\u6269\u5c55\uff0c\u5982\u4e0b\u6240\u793a\uff1a<\/p>\n<p>Tn \u662f\u62c6\u5206\u6570\u503c\u7684\u6570\u503c\u6807\u8bb0\u51fd\u6570\u3002<\/p>\n<p>\u6807\u8bb0\u5668\u7684\u8bcd\u6c47\u8868\u5b9a\u4e49\u4e3a\uff1a<\/p>\n<p>\u5176\u4e2d V \u662f\u7ed3\u679c\u8bcd\u6c47\u8868\uff0cVb \u662f\u57fa\u672c GPT-2 \u8bcd\u6c47\u8868\uff0cF \u662f\u5b57\u6bb5\u540d\u79f0\u96c6\uff0cD \u662f\u6570\u5b57\u548c\u5c0f\u6570\u70b9\u96c6\u3002<\/p>\n<p>\u53ef\u4ee5\u5728\u4e0b\u9762\u627e\u5230\u6b64\u4ee3\u7801\uff1a<\/p>\n<pre><code>from transformers import 
## 3. Advanced Fine-Tuning

Next, the article examines the effect of taking KL divergence into account during training, as part of the loss function. In this section we look at an advanced fine-tuning strategy for synthetic data generation built on the DistilGPT-2 language model. DistilGPT-2 is a compressed version of GPT-2 [22] developed through knowledge distillation: it has 82 million parameters, was pretrained under the supervision of the 124-million-parameter version of GPT-2, and is roughly half the size of its parent model. Its architecture contains 6 transformer blocks (GPT-2: 12) while retaining 12 attention heads per block and an embedding dimension of 768. The model keeps the 1024-token context window and the 50,257-token vocabulary [23].

The article explores fine-tuning GPT-2 with two strategies.

### 3.1 Numerical Tokenizer

The motivation for a numerical tokenizer is twofold. First, it gives equal weight to every digit in a numerical value, whereas the default tokenizer may split a number at arbitrary digit boundaries when producing tokens. Second, it creates dedicated single tokens for the column names, the separator token (`:::`), and the field delimiter (`,`). It can be defined formally as an extension of the base GPT-2 tokenizer Tb, where Tn is the numerical tokenization function that splits numerical values into digits.

The tokenizer's vocabulary is defined as V = Vb ∪ F ∪ D, where V is the resulting vocabulary, Vb is the base GPT-2 vocabulary, F is the set of field names, and D is the set of digits and the decimal point.

The code can be found below:

```python
import re
from typing import Dict, List, Union

import torch
from transformers import AutoTokenizer


class NumTokenizer:
    """Wraps a base GPT-2 tokenizer: field names stay as base tokens while
    numerical values are split into individual digit/decimal-point tokens."""

    def __init__(self, model_id, reqd_cols, sep_token):
        self.base_tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.base_tokenizer.pad_token = self.base_tokenizer.eos_token  # set pad token
        self.base_tokenizer.special_tokens = [sep_token]
        self.base_tokenizer.special_tokens.extend(reqd_cols)
        self.base_tokenizer.special_tokens.extend(',')
        self.sep_token = sep_token
        # matches ":::123.45"-style value segments that follow the separator
        self.num_pattern = re.compile(rf'{sep_token.strip()}\d+(?:\.\d+)?')

    def tokenize_num(self, num_text):
        # split a numerical value into per-character tokens (digits and '.')
        return [t for t in num_text]

    def __call__(self, text: Union[str, List[str]], padding: bool = True, truncation: bool = True,
                 max_length: int = None, return_tensors: str = None, **kwargs) -> Dict:
        tl = [text] if isinstance(text, str) else text
        encoded_inputs = [self.encode(t) for t in tl]
        if max_length is None:
            max_length = max(len(e) for e in encoded_inputs)
        if padding:
            encoded_inputs = [enc + [self.base_tokenizer.pad_token_id] * (max_length - len(enc))
                              for enc in encoded_inputs]
        if truncation:
            encoded_inputs = [enc[:max_length] for enc in encoded_inputs]

        if isinstance(text, str):
            output = {"input_ids": encoded_inputs[0], "attention_mask": [1] * len(encoded_inputs[0])}
        else:
            output = {"input_ids": encoded_inputs, "attention_mask": [[1] * len(enc) for enc in encoded_inputs]}

        if return_tensors:
            output = {k: torch.tensor(v) for k, v in output.items()}
        return output

    def tokenize(self, text):
        # alternate between base tokenization for column names and digit-level
        # tokenization for the numerical values that follow the separator
        col_names = self.num_pattern.split(text)
        col_values = [n.replace(self.sep_token.strip(), '').strip() for n in self.num_pattern.findall(text)]
        tokens = []
        for col_name, col_value in zip(col_names, col_values + ['']):
            tokens.extend(self.base_tokenizer.tokenize(col_name))
            tokens.extend(self.tokenize_num(col_value))
        return tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenize(text)
        return self.base_tokenizer.convert_tokens_to_ids(tokens)

    def decode(self, token_ids, **kwargs):
        return self.base_tokenizer.decode(token_ids)

    def __getattr__(self, name):
        return getattr(self.base_tokenizer, name)
```

### 3.2 Custom Trainer with a KL Divergence Loss

The second strategy introduces a hybrid loss function designed to balance preserving the linguistic structure (the field names) and the numerical structure (the field values). The loss combines a cross-entropy term for textual coherence with a KL divergence term for numerical accuracy, weighted by the parameters α and β, set to 0.6 and 0.4 respectively.
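The article's equation images for this loss are not reproduced here; based on the description and the training code that follows, the combined objective plausibly has the following form:

$$
\mathcal{L} = \alpha\,\mathcal{L}_{\text{text}} + \beta\,\mathcal{L}_{\text{num}},
\qquad
\mathcal{L}_{\text{text}} = -\frac{1}{N}\sum_{i=1}^{N}\log p_\theta\!\left(y_i \mid y_{<i}\right),
\qquad
\mathcal{L}_{\text{num}} = \sum_{j} p_j \log \frac{p_j}{q_j},
\qquad \alpha = 0.6,\ \beta = 0.4
$$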
This formulation lets the model learn both the structural patterns of the tabular text representation and the underlying distribution of the numerical values. The goal is to address the fundamental challenge of synthetic data generation: preserving the semantic relationships between fields as well as the statistical properties of the numerical values.

The text loss L_text is computed with the cross-entropy function, which takes the negative log-likelihood of the true class label under the predicted probability distribution for each element and averages these losses over all elements.

For the numerical values, L_num is the KL divergence between the predicted and true distributions.

The implementation is shown below:

```python
import torch
import torch.nn.functional as F
import transformers


def kl_div(p, q):
    # element-wise KL divergence between two (already normalized) distributions
    return (p * torch.log(p / q)).sum(-1)


class CustomTrainer(transformers.Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, num_items_in_batch, return_outputs=False):
        labels = inputs.pop("labels")
        labels = torch.where(labels == -100, tokenizer.pad_token_id, labels)
        outputs = model(**inputs)
        logits = outputs.logits
        logits_m = torch.argmax(logits[:, :-1, :], dim=-1)
        labels_m = labels[:, 1:].contiguous()

        # decode predictions and labels back to "field ::: value" text, then parse them
        # into {field: value} dicts (text2dict is an external helper, not shown here)
        pred_texts = tokenizer.batch_decode(logits_m, skip_special_tokens=False)
        label_texts = tokenizer.batch_decode(labels_m, skip_special_tokens=False)
        pred_list = text2dict(pred_texts, df_meta_data, sep_token)
        label_list = text2dict(label_texts, df_meta_data, sep_token)

        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]

        # column-name (text) loss
        text_loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=tokenizer.pad_token_id,
        )

        # column-value loss with KL divergence
        num_losses = []
        for pred_dict, label_dict in zip(pred_list, label_list):
            for key in pred_dict.keys():
                if key != '' and key in label_dict:
                    pred_value = pred_dict[key]
                    true_value = label_dict[key]
                    if pred_value is not None and true_value is not None:
                        num_losses.append(kl_div(
                            torch.tensor([pred_value], device=logits.device),
                            torch.tensor([true_value], device=logits.device)))

        num_loss = torch.mean(torch.stack(num_losses)) if num_losses else 0

        alpha = 0.6  # weight for the text loss
        beta = 0.4   # weight for the numerical loss
        combined_loss = alpha * text_loss + beta * num_loss

        if return_outputs:
            return combined_loss, outputs
        return combined_loss
```
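The trainer above relies on a `text2dict` helper (and on module-level `tokenizer`, `df_meta_data`, and `sep_token` objects) that the article does not show. A plausible sketch of `text2dict`, which parses decoded `field ::: value` text back into dictionaries of floats, is given below; the original implementation may differ.

```python
# Hypothetical sketch of the text2dict helper used by CustomTrainer:
# parse decoded "field ::: value , field ::: value" strings into {field: float} dicts.
from typing import Dict, List, Optional


def text2dict(texts: List[str], df_meta_data, sep_token: str) -> List[Dict[str, Optional[float]]]:
    known_fields = set(df_meta_data.columns) if hasattr(df_meta_data, "columns") else set(df_meta_data)
    parsed = []
    for text in texts:
        record: Dict[str, Optional[float]] = {}
        for chunk in text.split(','):
            if sep_token.strip() not in chunk:
                continue
            name, _, value = chunk.partition(sep_token.strip())
            name = name.strip()
            if name in known_fields:
                try:
                    record[name] = float(value.strip())
                except ValueError:
                    record[name] = None   # generated value was not a valid number
        parsed.append(record)
    return parsed
```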
The KL divergence measurements across the different columns reveal significant differences in the model's ability to capture the underlying distributions. Overall, the quality of these distributions improves relative to the plain fine-tuning, but outliers for fields s16 and s18 indicate a considerable gap between the original and synthetic distributions for those features. Note that the model was not evaluated with different hyperparameters; it was simply fine-tuned for 3 epochs.

Figure 3: KL divergence for the fine-tuned DistilGPT-2

## 4. Transformer GAN

There is prior work on using Transformer models to generate synthetic tabular data [30][31]. This article presents an approach that combines the Transformer architecture with a Generative Adversarial Network (GAN) for synthetic tabular data generation. The proposed Transformer GAN framework uses the Transformer's self-attention mechanism in the generator, paired with a custom discriminator network.

We explore two variants: a vanilla Transformer GAN and a Conditional Transformer GAN, each offering distinct synthetic data generation capabilities. The vanilla architecture uses a Transformer-based generator that processes an input noise vector through multiple self-attention layers, while the conditional variant incorporates additional context information to guide the generation process. The architecture also integrates positional encodings to help the model capture ordering patterns in the tabular data.

The figure below outlines the components of the Transformer GAN and its implementation:

Figure 4: Transformer GAN for synthetic data generation

### 4.1 Vanilla Transformer GAN

To state this formally, training involves two neural networks, a generator and a discriminator, trained adversarially. The Transformer acts as the generator network, denoted G, which takes a latent vector z as input and produces a synthetic data sample x̂ = G(z). Here F_emb denotes the input embedding layer, PE the positional encoding, and F_out the output linear layer.
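The article's equation image for the generator is not reproduced; from the description and the `TransformerGenerator` code below, a plausible form of the composition is:

$$
\hat{x} = G(z) = F_{\text{out}}\Big(\mathrm{TransformerEncoder}\big(F_{\text{emb}}(z) + \mathrm{PE}\big)\Big)
$$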
To generate structured output such as tabular data, the model needs to be trained with the correct field order and sequence [28]. This is accomplished with positional encodings, which provide both relative and absolute position information. As noted above, the positional encoding has the same dimension as the input embedding, so the two can be added. For a given model dimension d_model, the positional encoding [25] at position pos and dimension i is defined as PE(pos, 2i) = sin(pos / n^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / n^(2i/d_model)), with n = 10000.

In this implementation, i ranges from 0 to d_model/2. This is the sinusoidal positional encoding, in which the value for each position-dimension pair is computed with alternating sine and cosine functions: even dimensions (2i) use the sine, odd dimensions (2i+1) use the cosine, and this is repeated for every position. The approach is particularly valuable in our tabular setting because it lets the model capture both local and global positional relationships.

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_positions=1024, n=10000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_positions * d_model).reshape(max_positions, d_model)
        k = torch.arange(0, max_positions).unsqueeze(1)
        i = torch.arange(d_model // 2)
        div_term = (n ** ((2 * i) / d_model))
        theta = 1 / div_term
        pe[:, 0::2] = torch.sin(k * theta)  # even dimensions
        pe[:, 1::2] = torch.cos(k * theta)  # odd dimensions
        self.pe = pe.to(device)

    def forward(self, x):
        # add the pre-computed positional encoding to the input embeddings
        x = x + self.pe[:x.size()[0], :]
        return x
```

The discriminator network, denoted D, takes a data sample x as input and predicts the probability D(x) ∈ (0, 1) that it is real rather than generated. In essence, the discriminator is a stack of linear transformations, with σ denoting the LeakyReLU activation for every layer except the last, which uses a sigmoid.

```python
class TransformerGenerator(nn.Module):
    def __init__(self, input_dim, model_dim, num_heads, num_layers, feedforward_dim):
        super(TransformerGenerator, self).__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_encoding = PositionalEncoding(d_model=model_dim)
        encoder_layer = nn.TransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout=0.2)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(model_dim, input_dim)

    def forward(self, x):
        emb = self.embedding(x)
        # PositionalEncoding.forward already returns emb + PE
        x = self.pos_encoding(emb)
        x = self.transformer_encoder(x)
        return self.fc_out(x)


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
```
The following hyperparameters are used to train the networks:

```python
# Set hyperparameters
batch_size = 32
num_epochs = 40
lr = 0.0001
condition_dim = 3
input_dim = len(df.columns) - condition_dim
model_dim = 512
num_heads = 16
num_layers = 18
feedforward_dim = 512
```

The training loop combines conventional GAN training with additional optimization techniques designed for tabular data synthesis. These additional strategies were introduced after initial training runs led to mode collapse [29]: the generator loss became very large and negative while the discriminator loss shrank and converged to zero. To address this, label smoothing (1 - smoothing) is applied to the discriminator targets to prevent over-confidence. The discriminator loss (d_loss) is computed as the sum of the binary cross-entropy losses on real and fake samples, optimizing the discriminator's ability to distinguish real from synthetic manufacturing data.

The generator training stage introduces several techniques to improve the quality of the synthetic data. The input noise vector is dynamically scaled by a random factor between 0 and 2, which injects variability into the generated samples. The generator's loss combines two terms: a binary cross-entropy loss (g_loss_bce) that encourages the generator to produce samples that fool the discriminator, and a Kullback-Leibler divergence term (g_loss_kl) that measures the statistical distance between the generated and real data distributions. The KL term is weighted by kl_div_weight to balance its contribution to the overall loss. Both the generator and discriminator optimizers use learning-rate schedulers to adapt the learning rate during training, which can improve convergence stability.

```python
for epoch in range(num_epochs):
    for batch in dataloader:
        real_data = batch[0].to(device)
        batch_size = real_data.size(0)

        # --- Discriminator step ---
        optimizer_D.zero_grad()

        real_labels = torch.ones(batch_size, 1, device=device) * (1 - smoothing)
        real_outputs = model.discriminator(real_data)
        real_loss = criterion_bce(real_outputs, real_labels)

        z = torch.randn(batch_size, input_dim, device=device)
        fake_data = model.generator(z).to(device)
        fake_labels = torch.zeros(batch_size, 1, device=device) * smoothing
        fake_outputs = model.discriminator(fake_data)
        fake_loss = criterion_bce(fake_outputs, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # --- Generator step ---
        optimizer_G.zero_grad()

        z = torch.randn(batch_size, input_dim, device=device)
        scale = torch.rand(batch_size, 1, device=device) * 2  # random scale between 0 and 2
        z = z * scale

        fake_data = model.generator(z)
        fake_outputs = model.discriminator(fake_data)
        g_loss_bce = criterion_bce(fake_outputs, real_labels)

        # KL divergence between generated and real batches
        g_loss_kl = torch.abs(criterion_kl(fake_data, real_data))
        g_loss = g_loss_bce + kl_div_weight * g_loss_kl
        g_loss.backward()

        optimizer_G.step()
        scheduler_G.step()
        scheduler_D.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
```
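Once trained, generating new rows amounts to sampling noise vectors and passing them through the generator. The article does not show this step; a minimal sketch, assuming the same `model`, `input_dim`, and a fitted scaler used to normalize the training data, is:

```python
# Sketch: sampling synthetic rows from the trained Transformer GAN generator.
# `model`, `input_dim`, `device`, and `scaler` are assumed from the training setup above.
import torch

@torch.no_grad()
def sample_rows(n_rows: int):
    model.generator.eval()
    z = torch.randn(n_rows, input_dim, device=device)
    synth = model.generator(z).cpu().numpy()
    # if the training data was standardized, map back to the original scale
    return scaler.inverse_transform(synth)

synthetic_batch = sample_rows(1000)
```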
### 4.2 Conditional Transformer GAN

Given a real data distribution X, a condition space C, and a latent space Z, the conditional generator G and discriminator D are defined as mappings G: Z × C → X and D: X × C → (0, 1).

The conditional generator architecture combines a dual embedding strategy for the latent input and the condition information, where F_emb is the input embedding layer, F_cond is the condition embedding layer, and PE is the positional encoding. In this particular dataset, settings1, settings2, and settings3 form the condition space. The generator also includes an optional CNN-based embedding path that processes the input through convolutional layers, providing an alternative feature-extraction mechanism.

The conditional discriminator processes real/generated samples together with the condition, where [x; c] denotes the concatenation of the input and condition vectors. The condition vector c is taken from the first three components of the input: c = x_{1:3}, x′ = x_{4:n}.

```python
class ConditionalTransformerGenerator(nn.Module):
    def __init__(self, input_dim, condition_dim, model_dim, num_heads, num_layers, feedforward_dim):
        super(ConditionalTransformerGenerator, self).__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.condition_embedding = nn.Linear(condition_dim, model_dim)
        self.pos_encoding = PositionalEncoding(d_model=model_dim)
        encoder_layer = nn.TransformerEncoderLayer(model_dim, num_heads, feedforward_dim, dropout=0.2)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(model_dim, input_dim)
        # optional CNN-based embedding path
        self.cnn_embedding = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool1d(2)
        )
        self.linear_proj = nn.Linear(640, model_dim)

    def forward(self, x, condition):
        if cnn_embeddings:  # module-level flag selecting the CNN embedding path
            x = x.unsqueeze(1)  # add a channel dimension for the CNN
            emb = self.cnn_embedding(x)
            emb = emb.view(emb.size(0), -1)  # flatten
            emb = self.linear_proj(emb)
        else:
            emb = self.embedding(x)
        condition_emb = self.condition_embedding(condition)
        x = emb + condition_emb
        # PositionalEncoding.forward already returns x + PE
        x = self.pos_encoding(x)
        x = self.transformer_encoder(x)
        return self.fc_out(x)


class ConditionalDiscriminator(nn.Module):
    def __init__(self, input_dim, condition_dim):
        super(ConditionalDiscriminator, self).__init__()
        self.condition_dim = condition_dim
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim + condition_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x, condition):
        if condition.ndim == 1:
            condition = condition.unsqueeze(0).repeat(x.size(0), 1)
        x = torch.cat([x, condition], dim=1)  # [x; c]
        return self.model(x)
```
The training loop resembles that of the Transformer GAN. The key difference between the Conditional Transformer GAN and the vanilla Transformer GAN loops lies in how the data is handled. In the conditional loop, the input data is split into condition vectors (the operational settings) and target features (the sensor values), with the first three features serving as the condition that guides generation. During each training iteration, these condition vectors are explicitly paired with both the real and the generated data, influencing the generator's synthesis and the discriminator's evaluation. The discriminator, in turn, judges authenticity by considering the generated/real samples concatenated with their corresponding conditions, so its decisions are made in the context of the operational parameters. This differs from the standard Transformer GAN loop, where the generator operates on random noise alone and the discriminator evaluates data without any conditioning context. A condensed sketch of one conditional training step is shown below.
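The article does not reproduce the conditional training loop; a condensed sketch of a single step, assuming the same optimizers, losses, and label smoothing as the vanilla loop above, could look like this:

```python
# Sketch: one Conditional Transformer GAN training step.
# Optimizers, criteria, smoothing, and dimensions are assumed to match the vanilla loop.
real = batch[0].to(device)
condition, target = real[:, :condition_dim], real[:, condition_dim:]  # c = x_{1:3}, x' = x_{4:n}

# Discriminator step: real and generated samples are judged together with their conditions
optimizer_D.zero_grad()
real_labels = torch.ones(target.size(0), 1, device=device) * (1 - smoothing)
fake_labels = torch.zeros(target.size(0), 1, device=device)
z = torch.randn(target.size(0), input_dim, device=device)
fake_target = model.generator(z, condition)
d_loss = (criterion_bce(model.discriminator(target, condition), real_labels) +
          criterion_bce(model.discriminator(fake_target.detach(), condition), fake_labels))
d_loss.backward()
optimizer_D.step()

# Generator step: the same condition guides synthesis
optimizer_G.zero_grad()
fake_target = model.generator(torch.randn(target.size(0), input_dim, device=device), condition)
g_loss = criterion_bce(model.discriminator(fake_target, condition), real_labels)
g_loss.backward()
optimizer_G.step()
```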
## 5. LM-GAN

The article next extends the Transformer GAN idea to language models, specifically SLMs, and introduces the Language Model GAN (LM-GAN) architecture, which uses a pretrained DistilGPT-2 language model as the generator. The subsequent empirical analysis, covering KL divergence measurements, principal component analysis (PCA), and histograms, demonstrates the advantages of LM-GAN. The article argues that the LM-GAN architecture produces synthetic samples that preserve the statistical properties and multimodal characteristics of the original sensor measurements while avoiding mode collapse, a common challenge in GAN training.

Figure 5: LM-GAN for synthetic data generation

The formulation of the LM-GAN largely parallels that of the Transformer GAN. The generator network in this architecture is the pretrained DistilGPT-2 model; the key difference is that the DistilGPT-2 generator is parameterized by θ, mapping the input token space Z to the output logits space X. The generator takes an input token sequence and its corresponding attention mask and processes them through multiple transformer blocks consisting of self-attention and feed-forward layers. The architecture keeps DistilGPT-2's original configuration and adapts the token embeddings to the manufacturing-specific vocabulary, including the numerical values and column identifiers.

```python
class Generator(nn.Module):
    def __init__(self, model):
        super(Generator, self).__init__()
        self.model = model  # pretrained DistilGPT-2 model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,   # teacher forcing: language-modeling loss against the inputs
        )
        return outputs.logits.squeeze(1), outputs.loss
```

The generator returns the pair (X_syn, L_gen), where X_syn denotes the synthetic data logits and L_gen the language-modeling loss. During the forward pass the generator therefore produces two outputs: logits representing the vocabulary probability distribution at each position of the sequence, and a language-modeling loss computed with teacher forcing. The logits have shape (batch_size, sequence_length, vocabulary_size) and capture the model's prediction at every token position, while the language-modeling loss helps preserve the linguistic structure learned during pretraining.
The generator's parameters are fine-tuned during GAN training, allowing it to adapt its pretrained representations to the specific patterns and distributions present in the manufacturing sensor data while retaining its ability to generate coherent sequences.

The discriminator D_ϕ with parameters ϕ is defined as a composition of: an embedding layer e : R^{n×v} → R^{n×d}, a bidirectional LSTM [32] h_lstm : R^{n×d} → R^{2d}, a classification layer f_ϕ, and a sigmoid activation σ. The discriminator is thus a hybrid architecture combining an embedding layer, a bidirectional LSTM (256 hidden units), and a dense neural network to distinguish real from synthetic data. The final classification stage consists of dense layers with LeakyReLU activations. When processing generated samples, the discriminator first converts the generator's logits to token IDs with an argmax operation.

```python
class Discriminator(nn.Module):
    def __init__(self, vocab_size):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, 128)
        self.lstm = nn.LSTM(128, 256, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids):
        if input_ids.dim() == 3:  # input is logits (batch_size, sequence_length, vocab_size)
            input_ids = torch.argmax(input_ids, dim=-1)  # convert to token ids

        embedded = self.embedding(input_ids.int())
        lstm_out, _ = self.lstm(embedded)
        lstm_out = lstm_out[:, -1, :]  # take the last hidden state
        validity = self.classifier(lstm_out)
        return validity
```

Training uses the Adam optimizer [33] with an initial learning rate of 2e-5 for both the generator and discriminator networks, combined with ReduceLROnPlateau schedulers [34] that adjust the learning rates dynamically based on the loss trajectory.

```python
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optimizer, mode='min', factor=0.5, patience=2)
d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(d_optimizer, mode='min', factor=0.5, patience=2)
```

The process starts with the numerical tokenizer introduced in the advanced fine-tuning section above. It handles the numerical values and column identifiers with custom separator tokens, ensuring a precise representation of the manufacturing measurements. In each training iteration, the real data sequences are first tokenized and fed to the discriminator, which learns to assign high probabilities to genuine manufacturing data patterns.
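The training loop that follows calls `wasserstein_loss`, `kth_order_loss`, and `compute_gradient_penalty`, whose definitions are not included in the article. Plausible sketches consistent with how they are used (a WGAN-GP-style critic objective plus a moment-matching term) are given below; the original implementations may differ.

```python
# Hypothetical sketches of the helper losses used in the LM-GAN training loop.
import torch


def wasserstein_loss(critic_scores):
    # WGAN-style objective term: the mean critic score of a batch
    return critic_scores.mean()


def kth_order_loss(generated, real, k=2):
    # moment-matching term: penalize differences in the first k moments per feature
    loss = 0.0
    for order in range(1, k + 1):
        loss = loss + torch.mean(torch.abs(generated.pow(order).mean(dim=0) -
                                           real.pow(order).mean(dim=0)))
    return loss


def compute_gradient_penalty(discriminator, real, fake):
    # generic WGAN-GP gradient penalty on random interpolations between real and fake samples
    alpha = torch.rand(real.size(0), 1, device=real.device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interpolated)
    grads = torch.autograd.grad(outputs=scores, inputs=interpolated,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
```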
The generator uses the DistilGPT-2 architecture to produce synthetic sequences, the discriminator evaluates those sequences, and the adversarial loss is computed with binary cross-entropy.

```python
for batch_idx, batch in enumerate(dataloader):
    real_texts = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    batch_size = real_texts.size(0)

    # --- Generator step ---
    g_optimizer.zero_grad()

    gen_logits, causal_lm_loss = generator(real_texts, attention_mask)
    fake_checks = discriminator(gen_logits)

    r = real_texts.squeeze().float()
    g = torch.argmax(gen_logits, dim=-1).float()

    w_loss = wasserstein_loss(fake_checks)
    k_loss = kth_order_loss(g, r, k=2)

    g_loss = (
        1.0 * causal_lm_loss +   # language-modeling loss (teacher forcing)
        0.8 * w_loss +           # Wasserstein loss
        0.2 * k_loss             # k-th order loss for numerical accuracy
    )

    g_loss.backward()
    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
    g_optimizer.step()

    # --- Discriminator step ---
    d_optimizer.zero_grad()

    real_checks = discriminator(real_texts)
    fake_checks = discriminator(g)

    d_loss = wasserstein_loss(fake_checks) - wasserstein_loss(real_checks)

    gradient_penalty = compute_gradient_penalty(discriminator, r, g)
    d_loss += 10 * gradient_penalty

    d_loss.backward()
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
    d_optimizer.step()
```

The training process combines teacher forcing in the generator, where the language-modeling loss is paired with the adversarial loss so that coherent sequences are generated while the model adapts to the manufacturing data distribution. The ReduceLROnPlateau schedulers monitor the loss metrics (patience of 2 epochs) and halve the learning rate when improvement stalls, which helps stabilize training and prevent oscillation. Gradient-penalty computation and gradient clipping are crucial stabilization mechanisms in LM-GAN training. The gradient penalty is computed by interpolating between real and generated samples, and torch.nn.utils.clip_grad_norm_ is applied to both generator and discriminator parameters with a maximum norm threshold of 1.0 to prevent exploding gradients. This dual optimization process, combined with the specialized tokenization strategy, lets the model learn both the statistical properties of the data and the latent relationships between the different sensor measurements, while the scheduling mechanisms promote robust convergence of both networks throughout training.
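After training, synthetic rows are obtained by sampling from the fine-tuned language model and parsing the generated text back into numbers. The article does not show this step; a minimal sketch, assuming the `NumTokenizer` from above and a seed field name, is:

```python
# Sketch: sampling a synthetic row from the LM-GAN generator after training.
# Assumes `generator.model` is the fine-tuned DistilGPT-2 and `num_tokenizer` is the NumTokenizer.
import torch

@torch.no_grad()
def sample_row(seed_field: str = "s1", max_new_tokens: int = 256) -> str:
    generator.model.eval()
    ids = num_tokenizer(seed_field)["input_ids"]
    seed = torch.tensor(ids, device=device).unsqueeze(0)  # shape (1, seq_len)
    out = generator.model.generate(
        seed,
        max_new_tokens=max_new_tokens,
        do_sample=True,          # stochastic decoding produces varied rows
        top_k=50,
        pad_token_id=num_tokenizer.pad_token_id,
    )
    return num_tokenizer.decode(out[0])  # "s1 ::: ... , s2 ::: ..." text, parsed downstream

print(sample_row())
```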
Comparative analysis between the synthetic data and the original manufacturing data shows that the LM-GAN model captures the statistical properties and latent patterns of the sensor measurements remarkably well. The distribution plots show agreement between real and synthetic data across all 21 sensor channels, with the synthetic data (dashed orange) closely following the temporal dynamics of the original measurements (solid blue).

Figure 6: Distribution comparison for LM-GAN

The histogram comparison shows consistent distribution matching across the different sensor ranges, particularly for the operational settings (s1-s3) and the key sensor measurements.

Figure 7: Histogram analysis for LM-GAN

Notably, the PCA visualization shows excellent preservation of the data manifold: the synthetic samples (red points) mix closely with the real data points (blue points) in the principal components, indicating that the model has captured the complex correlations between the different sensor measurements.

Figure 8: PCA analysis for LM-GAN

This comprehensive evaluation confirms that the LM-GAN architecture preserves not only the marginal distributions of the individual sensors but also the complex relationships and operational patterns present in the manufacturing process data.

## 6. Ablation Study

To quantitatively assess the fidelity of the synthetic data generation methods, we conducted a comprehensive ablation study measuring the statistical similarity between the original and synthetic datasets. The analysis uses three main evaluation frameworks.

First, a comparison of summary statistics (minimum, maximum, mean, and standard deviation) between the original and synthetic datasets:

Figure 9: Statistical comparison

Second, aggregate fidelity scores based on the Kolmogorov-Smirnov test:

Figure 10: Aggregate fidelity scores

Third, preservation of feature relationships, assessed by comparing correlation matrices. Below are the correlation matrices of the original and synthetic distributions produced by the LM-GAN architecture:

Figure 11: Correlation matrix comparison

A sketch of how such per-column fidelity metrics can be computed is shown below.
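The exact scoring code behind Figures 9 and 10 is not included in the article; a minimal sketch, assuming per-column two-sample Kolmogorov-Smirnov tests with `scipy` and simple summary statistics with `pandas`, is:

```python
# Sketch: per-column summary statistics and Kolmogorov-Smirnov fidelity scores.
# The aggregation choice (mean of 1 - KS statistic) is an assumption, not the article's exact metric.
import pandas as pd
from scipy.stats import ks_2samp


def compare_stats(real: pd.DataFrame, synth: pd.DataFrame) -> pd.DataFrame:
    # min / max / mean / std side by side for every column
    return pd.concat({"real": real.describe().loc[["min", "max", "mean", "std"]],
                      "synthetic": synth.describe().loc[["min", "max", "mean", "std"]]}, axis=1)


def ks_fidelity(real: pd.DataFrame, synth: pd.DataFrame) -> pd.Series:
    # 1 - KS statistic per column: 1.0 means identical empirical distributions
    return pd.Series({col: 1.0 - ks_2samp(real[col], synth[col]).statistic for col in real.columns})


# aggregate_score = ks_fidelity(df_real, df_synth).mean()
# corr_diff = (df_real.corr() - df_synth.corr()).abs()
```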
## 7. Closing Remarks

This two-part research series provides a comprehensive analysis of synthetic data generation techniques for industrial manufacturing applications. Part one established baseline performance metrics by evaluating traditional generative models including GAN, VAE, Gaussian Copula, Bayesian networks, and CTGAN, showing the strong performance of probabilistic methods such as Bayesian networks and Gaussian Copula in preserving statistical distributions.

Part two advances the field by introducing novel applications of small language models (SLMs) to synthetic data generation. The work progresses from basic prompt engineering to sophisticated architectures such as LM-GAN, demonstrating steadily improving synthetic data fidelity. The comparative analysis shows that advanced fine-tuning with a custom loss function significantly improves on the basic SLM approach, while the proposed LM-GAN architecture achieves the strongest performance in preserving both marginal distributions and complex inter-feature relationships.

The main findings across both studies are:

- Traditional probabilistic models provide a strong performance baseline for manufacturing data synthesis
- SLM-based approaches offer added flexibility in capturing domain-specific constraints
- Hybrid architectures that combine language models with adversarial training excel at preserving statistical properties
- The proposed LM-GAN framework successfully addresses the mode collapse commonly seen in traditional GANs while maintaining data fidelity