
Understanding the Multi-Head Attention structure

imsmile2000 2023. 3. 22. 17:47
n_batch = 128 
n_src   = 32
d_feat  = 200
n_head  = 5 # attention is computed with 5 heads
self.d_head = self.d_feat // self.n_head   # 200 // 5 = 40
scores = Q*K^T / sqrt(d_k)   # scaled dot-product attention
d_k = K_split.size(-1) = 40
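A minimal sketch of how this setup could look in a PyTorch module's __init__. The class and layer names (MultiHeadAttention, lin_Q, lin_K, lin_V) and the dropout rate are assumptions for illustration, not taken from the original code.

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_feat=200, n_head=5, dropout_p=0.1):
        super().__init__()
        assert d_feat % n_head == 0
        self.d_feat = d_feat
        self.n_head = n_head
        self.d_head = d_feat // n_head      # 200 // 5 = 40
        # hypothetical projection layers producing Q/K/V features
        self.lin_Q = nn.Linear(d_feat, d_feat)
        self.lin_K = nn.Linear(d_feat, d_feat)
        self.lin_V = nn.Linear(d_feat, d_feat)
        self.dropout = nn.Dropout(dropout_p)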
 
Input src:	[128, 32, 200]  	= [n_batch, n_src, d_feat]

Q_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
K_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
V_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
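A sketch of the projections that keep the shape at [n_batch, n_src, d_feat]: assuming one linear layer per Q/K/V (the layer names are assumptions), each maps d_feat to d_feat, so the shapes do not change yet.

import torch
import torch.nn as nn

n_batch, n_src, d_feat = 128, 32, 200
src = torch.randn(n_batch, n_src, d_feat)   # [128, 32, 200]

# one linear projection per Q/K/V (layer names are assumptions)
lin_Q, lin_K, lin_V = (nn.Linear(d_feat, d_feat) for _ in range(3))

Q_feat = lin_Q(src)   # [128, 32, 200]
K_feat = lin_K(src)   # [128, 32, 200]
V_feat = lin_V(src)   # [128, 32, 200]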

# split for multi-head attention
Q_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
K_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
V_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
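The split is a reshape followed by a permute: d_feat (200) is factored into n_head * d_head (5 * 40), and the head axis is moved next to the batch axis. A short sketch for Q (K and V are split the same way):

import torch

n_batch, n_src, d_feat, n_head = 128, 32, 200, 5
d_head = d_feat // n_head                   # 40
Q_feat = torch.randn(n_batch, n_src, d_feat)

# reshape to [n_batch, n_src, n_head, d_head], then move the head axis to dim 1
Q_split = Q_feat.view(n_batch, n_src, n_head, d_head).permute(0, 2, 1, 3)
print(Q_split.shape)                        # torch.Size([128, 5, 32, 40])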

# the d_head dimension (40) is contracted away by the matmul
scores:   	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]
attention:	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]
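A sketch of the scaled dot-product step: Q_split is multiplied by K_split transposed on its last two dims, scaled by sqrt(d_k), and softmax over the last axis gives the attention weights.

import torch
import torch.nn.functional as F

n_batch, n_head, n_src, d_head = 128, 5, 32, 40
Q_split = torch.randn(n_batch, n_head, n_src, d_head)
K_split = torch.randn(n_batch, n_head, n_src, d_head)

d_k = K_split.size(-1)   # 40
# [128, 5, 32, 40] @ [128, 5, 40, 32] -> [128, 5, 32, 32]
scores = torch.matmul(Q_split, K_split.transpose(-2, -1)) / (d_k ** 0.5)
attention = F.softmax(scores, dim=-1)       # [128, 5, 32, 32]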

# x_raw = torch.matmul(self.dropout(attention), V_split); dropout on the attention weights acts as regularization
# matmul shapes: [128, 5, 32, 32] x [128, 5, 32, 40]
x_raw:    	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
# x_rsh1 = x_raw.permute(0,2,1,3).contiguous() -> swap dims 1 and 2
x_rsh1:   	[128, 32, 5, 40]  	= [n_batch, n_src, n_head, d_head]
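A sketch of this step in isolation (the dropout rate is an assumption): the weighted sum over V_split restores the d_head axis, and the permute moves the head axis back next to d_head so the heads can be concatenated.

import torch
import torch.nn as nn

n_batch, n_head, n_src, d_head = 128, 5, 32, 40
attention = torch.softmax(torch.randn(n_batch, n_head, n_src, n_src), dim=-1)
V_split   = torch.randn(n_batch, n_head, n_src, d_head)
dropout   = nn.Dropout(p=0.1)               # dropout rate is an assumption

# [128, 5, 32, 32] @ [128, 5, 32, 40] -> [128, 5, 32, 40]
x_raw  = torch.matmul(dropout(attention), V_split)
# swap dims 1 and 2 so each position's heads sit side by side
x_rsh1 = x_raw.permute(0, 2, 1, 3).contiguous()   # [128, 32, 5, 40]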

# concatenate the heads: 40 * 5 = 200
x_rsh2:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
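The concatenation is just a reshape that flattens the head and d_head axes back into d_feat:

import torch

n_batch, n_src, n_head, d_head = 128, 32, 5, 40
x_rsh1 = torch.randn(n_batch, n_src, n_head, d_head)    # [128, 32, 5, 40]

# flatten the last two axes: 5 heads * 40 dims per head = 200 = d_feat
x_rsh2 = x_rsh1.view(n_batch, n_src, n_head * d_head)   # [128, 32, 200]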

Output x: 	[128, 32, 200]  	= [n_batch, n_src, d_feat]
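Putting the whole trace together as one forward pass, assuming the module skeleton sketched earlier. The final linear layer lin_O is an extra assumption (many implementations project the concatenated heads back to d_feat before returning), not something stated in the shape trace above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_feat=200, n_head=5, dropout_p=0.1):
        super().__init__()
        self.n_head, self.d_head = n_head, d_feat // n_head
        self.lin_Q = nn.Linear(d_feat, d_feat)
        self.lin_K = nn.Linear(d_feat, d_feat)
        self.lin_V = nn.Linear(d_feat, d_feat)
        self.lin_O = nn.Linear(d_feat, d_feat)  # output projection (assumption)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, src):
        n_batch, n_src, d_feat = src.shape
        # project, then split into heads: [n_batch, n_head, n_src, d_head]
        split = lambda t: t.view(n_batch, n_src, self.n_head, self.d_head).permute(0, 2, 1, 3)
        Q, K, V = split(self.lin_Q(src)), split(self.lin_K(src)), split(self.lin_V(src))
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        attention = F.softmax(scores, dim=-1)
        x_raw  = torch.matmul(self.dropout(attention), V)      # [n_batch, n_head, n_src, d_head]
        x_rsh1 = x_raw.permute(0, 2, 1, 3).contiguous()        # [n_batch, n_src, n_head, d_head]
        x_rsh2 = x_rsh1.view(n_batch, n_src, d_feat)           # [n_batch, n_src, d_feat]
        return self.lin_O(x_rsh2)

x = MultiHeadAttention()(torch.randn(128, 32, 200))
print(x.shape)   # torch.Size([128, 32, 200])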