感觉很简单的一件事,花费了三天多的时间才搞定,Transformers有相应的类,但是传递参数后的返回值,怎么处理一直没搞明白,找别人的代码调通之后,发现得到的结果不好,那么一定是代码有问题,最后找到他的代码:ChineseNMT/README.md at master · hemingkx/ChineseNMT (github.com),他写的长,就拿来试试。
但是呢,他使用的是哈弗公布的Transformer的代码,我用的不是那个版本,贴出来我的一段代码,如果你的和我的一样,你就明白了
下面是我在他beam search的基础上修改后的代码
标注日期的1015或者带2014的部分是我做了修改的地方,可以参考原始文件。并且为了代码能正确运行,我对注意力机制的代码也做了一定的修改,这不妨碍平时使用,其中的if语句我把如果attn_mask和scores在shape[0]不一致的时候,我把attn_mask第0维进行了扩充
使用得到输出结果
我的Transformer内部是这样的关系,Decoder的输出是未经过projection的
希望对你有帮助。