Understanding the Multi-Head Attention Structure
    n_batch = 128 
    n_src   = 32
    d_feat  = 200
    n_head  = 5                  # use 5 attention heads
    d_head  = d_feat // n_head   # self.d_head in the module; 200 // 5 = 40 per head
    # scaled dot-product attention: scores = Q @ K^T / sqrt(d_k)
    d_k = K_split.size(-1)       # = d_head = 40
     
    Input src:	[128, 32, 200]  	= [n_batch, n_src, d_feat]
    
    Q_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
    K_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
    V_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
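
    In code, Q_feat, K_feat and V_feat typically come from three learned linear projections of the input. A minimal sketch, assuming three nn.Linear(d_feat, d_feat) layers (the names lin_Q, lin_K, lin_V are illustrative, not taken from the original code):

    import torch
    import torch.nn as nn

    lin_Q = nn.Linear(d_feat, d_feat)            # 200 -> 200
    lin_K = nn.Linear(d_feat, d_feat)
    lin_V = nn.Linear(d_feat, d_feat)

    src = torch.randn(n_batch, n_src, d_feat)    # [128, 32, 200]
    Q_feat = lin_Q(src)                          # [128, 32, 200]
    K_feat = lin_K(src)                          # [128, 32, 200]
    V_feat = lin_V(src)                          # [128, 32, 200]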
    
    # split Q/K/V into heads for multi-head attention
    Q_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
    K_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
    V_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
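
    The split reshapes the 200-dim feature axis into (n_head, d_head) = (5, 40) and moves the head axis forward. A sketch of that reshaping (the exact view/permute calls are an assumption about the implementation):

    # [128, 32, 200] -> [128, 32, 5, 40] -> [128, 5, 32, 40]
    Q_split = Q_feat.view(n_batch, -1, n_head, d_head).permute(0, 2, 1, 3)
    K_split = K_feat.view(n_batch, -1, n_head, d_head).permute(0, 2, 1, 3)
    V_split = V_feat.view(n_batch, -1, n_head, d_head).permute(0, 2, 1, 3)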
    
    # the d_head dimension (40) is contracted away in Q @ K^T, so it does not appear here
    scores:   	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]
    attention:	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]
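
    A sketch of how the scores and attention weights could be computed per head, using the scaled dot-product formula above (softmax is taken over the last axis, i.e. over the 32 source positions):

    import math

    d_k = K_split.size(-1)                                            # 40
    scores = Q_split @ K_split.transpose(-2, -1) / math.sqrt(d_k)     # [128, 5, 32, 32]
    attention = torch.softmax(scores, dim=-1)                         # [128, 5, 32, 32]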
    
    # x_raw = torch.matmul(self.dropout(attention), V_split)   # dropout on the attention weights for regularization / better accuracy
    # matmul shapes: [128, 5, 32, 32] x [128, 5, 32, 40] -> [128, 5, 32, 40]
    x_raw:    	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
    # x_rsh1 = x_raw.permute(0, 2, 1, 3).contiguous()   # swap dims 1 and 2 (head and sequence axes)
    x_rsh1:   	[128, 32, 5, 40]  	= [n_batch, n_src, n_head, d_head]
    
    # concatenate the heads: 5 * 40 = 200
    x_rsh2:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
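
    Putting the two commented lines above together with the concatenation step (the dropout probability here is an illustrative value, not from the original code):

    dropout = nn.Dropout(p=0.1)
    x_raw  = torch.matmul(dropout(attention), V_split)       # [128, 5, 32, 40]
    x_rsh1 = x_raw.permute(0, 2, 1, 3).contiguous()          # [128, 32, 5, 40]
    x_rsh2 = x_rsh1.view(n_batch, n_src, n_head * d_head)    # [128, 32, 200]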
    
    Output x: 	[128, 32, 200]  	= [n_batch, n_src, d_feat]
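
    A standard multi-head attention block finishes with one more linear layer that mixes the concatenated heads. This is an assumption about the implementation (the shape table above does not show this step explicitly); lin_O is an illustrative name:

    lin_O = nn.Linear(d_feat, d_feat)
    x = lin_O(x_rsh2)                                # [128, 32, 200]
    assert x.shape == (n_batch, n_src, d_feat)       # back to the input shape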
