n_batch = 128
n_src = 32
d_feat = 200
n_head = 5 # number of heads for multi-head attention
self.d_head = self.d_feat // self.n_head # 200 // 5 = 40
scores = Q K^T / sqrt(d_k)   # scaled dot-product attention
d_k = K_split.size(-1) # 40
Input src: [128, 32, 200] = [n_batch, n_src, d_feat]
Q_feat: [128, 32, 200] = [n_batch, n_src, d_feat]
K_feat: [128, 32, 200] = [n_batch, n_src, d_feat]
V_feat: [128, 32, 200] = [n_batch, n_src, d_feat]
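A minimal sketch of how src and the Q/K/V features above could be produced; the projection-layer names lin_Q, lin_K, lin_V are assumptions, not taken from the original code.

import torch
import torch.nn as nn

n_batch, n_src, d_feat, n_head = 128, 32, 200, 5
d_head = d_feat // n_head  # 40

# hypothetical per-feature linear projections for queries, keys, and values
lin_Q = nn.Linear(d_feat, d_feat)
lin_K = nn.Linear(d_feat, d_feat)
lin_V = nn.Linear(d_feat, d_feat)

src = torch.randn(n_batch, n_src, d_feat)  # [128, 32, 200]
Q_feat = lin_Q(src)  # [128, 32, 200]
K_feat = lin_K(src)  # [128, 32, 200]
V_feat = lin_V(src)  # [128, 32, 200]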
# split for multi-head attention
Q_split: [128, 5, 32, 40] = [n_batch, n_head, n_src, d_head]
K_split: [128, 5, 32, 40] = [n_batch, n_head, n_src, d_head]
V_split: [128, 5, 32, 40] = [n_batch, n_head, n_src, d_head]
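Continuing the sketch, one common way to get the [n_batch, n_head, n_src, d_head] layout is a view followed by a permute (assumed here, not necessarily the original implementation):

# split d_feat (200) into n_head (5) heads of d_head (40), then move the head axis forward
Q_split = Q_feat.view(n_batch, n_src, n_head, d_head).permute(0, 2, 1, 3)  # [128, 5, 32, 40]
K_split = K_feat.view(n_batch, n_src, n_head, d_head).permute(0, 2, 1, 3)  # [128, 5, 32, 40]
V_split = V_feat.view(n_batch, n_src, n_head, d_head).permute(0, 2, 1, 3)  # [128, 5, 32, 40]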
# the d_head dim (40) is contracted away by the matmul
scores: [128, 5, 32, 32] = [n_batch, n_head, n_src, n_src]
attention: [128, 5, 32, 32] = [n_batch, n_head, n_src, n_src]
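Sketch of the scaled dot-product scores and the softmax that produces the attention weights above (softmax over the last axis is the standard choice and assumed here):

import math
import torch.nn.functional as F

d_k = K_split.size(-1)  # 40
# [128, 5, 32, 40] @ [128, 5, 40, 32] -> [128, 5, 32, 32]
scores = torch.matmul(Q_split, K_split.transpose(-2, -1)) / math.sqrt(d_k)
attention = F.softmax(scores, dim=-1)  # [128, 5, 32, 32], each row sums to 1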
# x_raw = torch.matmul(self.dropout(attention), V_split)  # dropout on the attention weights acts as regularization
# matmul operand shapes: [128, 5, 32, 32] @ [128, 5, 32, 40]
x_raw: [128, 5, 32, 40] = [n_batch, n_head, n_src, d_head]
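Filling in the step quoted above as standalone code; the dropout rate 0.1 is an assumed value:

dropout = nn.Dropout(p=0.1)  # rate is an assumption
# batched matmul: [128, 5, 32, 32] @ [128, 5, 32, 40] -> [128, 5, 32, 40]
x_raw = torch.matmul(dropout(attention), V_split)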
# x_rsh1 = x_raw.permute(0,2,1,3).contiguous() -> swap dims 1 and 2 (n_head and n_src)
x_rsh1: [128, 32, 5, 40] = [n_batch, n_src, n_head, d_head]
# concatenate the heads: 5 * 40 = 200
x_rsh2: [128, 32, 200] = [n_batch, n_src, d_feat]
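The permute plus a contiguous view merges the 5 heads of width 40 back into d_feat = 200 (sketch):

x_rsh1 = x_raw.permute(0, 2, 1, 3).contiguous()  # [128, 32, 5, 40]
x_rsh2 = x_rsh1.view(n_batch, n_src, n_head * d_head)  # [128, 32, 200]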
Output x: [128, 32, 200] = [n_batch, n_src, d_feat]
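A final shape check that the multi-head output recovers the input shape; whether the original module also applies an output projection afterwards is not stated, and the lin_out name below is hypothetical:

x = x_rsh2  # [128, 32, 200]
assert x.shape == (n_batch, n_src, d_feat)
# many implementations add an output projection here, e.g.:
# lin_out = nn.Linear(d_feat, d_feat); x = lin_out(x)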