Solution | Sum of the Elements of a Simplified Attention Output
https://www.nowcoder.com/practice/3ba85cb991d4471b81ad6d775447fc44
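The idea, as implemented below: build X as an n×m matrix of ones and W1, W2, W3 as m×h upper-triangular matrices of ones, compute Q = X@W1, K = X@W2, V = X@W3, then Y = softmax(Q·Kᵀ/√h)·V, and print round(sum(Y)) for every "n m h" line read from stdin.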
import sys
from numpy import ones, triu, transpose, sqrt, exp, sum, zeros  # numpy's sum sums all elements of a 2-D array

def attention(n, m, h):
    X = ones((n, m))                 # input: an n x m matrix of ones
    W1 = triu(ones((m, h)))          # weights: m x h upper-triangular matrices of ones
    W2 = triu(ones((m, h)))
    W3 = triu(ones((m, h)))
    Q = X @ W1                       # matrix product: use @, never * (elementwise)!
    K = X @ W2
    V = X @ W3
    S = (Q @ transpose(K)) / sqrt(h) # scaled attention scores, shape (n, n)

    def softmax(X):
        after_softmax = zeros((n, n))
        for l in range(len(X)):      # rows 0, ..., n-1
            fenzi = exp(X[l])        # numerator: a row vector
            fenmu = sum(exp(X[l]))   # denominator: a scalar
            after_softmax[l] = fenzi / fenmu  # normalized row
        return after_softmax

    Y = softmax(S) @ V
    ans = round(sum(Y))              # sum of all elements of Y, rounded
    return ans

for line in sys.stdin:
    a = line.split()
    print(attention(int(a[0]), int(a[1]), int(a[2])))
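
A side note on the softmax: the row-by-row loop above can also be written in vectorized form. This is just a sketch of an alternative, not part of the original solution (the name softmax_rows is mine); subtracting each row's maximum before exponentiating avoids overflow and does not change the result:

import numpy as np

def softmax_rows(S):
    # subtract each row's max before exponentiating (numerically safer, same output)
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)  # normalize every row to sum to 1

As a quick sanity check of the whole program: for the input line "2 2 2" all entries of S are equal, so every softmax row is [0.5, 0.5], Y works out to [[1, 2], [1, 2]], and the program prints 6.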

