# Random input batch of shape (batch_size, seq_len, d).
x = torch.randn(batch_size, seq_len, d)
# One (d, sub_d) projection matrix per head, for queries, keys and values.
# The three lists are built strictly one after another, so the random draws
# happen in this order: all query weights, all key weights, all value weights.
w_qs, w_ks, w_vs = (
    [torch.randn(d, sub_d) for _ in range(h)] for _ in range(3)
)
print("========From Single Head START=========")
# Reference path: run scaled dot-product attention one head at a time and
# concatenate the per-head results along the last (feature) axis.
# NOTE: `os` is kept as the list name because the concatenation below reads
# it; be aware that it would shadow the stdlib `os` module if imported.
os = []  # per-head outputs, each (batch_size, seq_len, sub_d)
att_sep_soft = []  # per-head attention weights, after softmax
for head in range(h):
    # Project the input with this head's query / key / value matrices.
    q = torch.matmul(x, w_qs[head])
    k = torch.matmul(x, w_ks[head])
    v = torch.matmul(x, w_vs[head])
    # Score every (query n, key m) pair and scale by sqrt(sub_d).
    att = torch.einsum('bnd,bmd->bnm', q, k) / (sub_d ** 0.5)
    # Normalise over the key axis, which is the last axis here.
    att = torch.softmax(att, dim=-1)
    att_sep_soft.append(att)
    # Attention-weighted sum of the value vectors.
    o = torch.einsum('bnm,bmd->bnd', att, v)
    # Collect this head's output; without this the concatenation below
    # never sees any per-head result.
    os.append(o)
output = torch.cat(os, dim=-1)
print("output shape: ", output.shape)
print("========From Single Head END=========")
print("========VEC START=========")
# Stack the per-head projection matrices along a new trailing axis so that
# all heads are handled by a single einsum; each result is (d, sub_d, h).
w_qs_vec, w_ks_vec, w_vs_vec = (
    torch.stack(per_head, dim=-1) for per_head in (w_qs, w_ks, w_vs)
)
# Project the input for every head at once:
# (b, n, d) x (d, sub_d, h) -> (b, n, sub_d, h).
q_vec, k_vec, v_vec = (
    torch.einsum('bnd,dsh->bnsh', x, stacked)
    for stacked in (w_qs_vec, w_ks_vec, w_vs_vec)
)
# Scaled query/key scores per head, shaped (b, n, m, h) with n indexing
# queries and m indexing keys.
att = torch.einsum('bnsh,bmsh->bnmh', q_vec, k_vec) / (sub_d ** 0.5)
# Normalise over the key axis m, which is second from the end in this layout.
att = att.softmax(dim=-2)
# Weighted sum of values. The head axis is emitted before the feature axis so
# that flattening the last two axes lays the heads out one after another.
o = torch.einsum('bnmh,bmsh->bnhs', att, v_vec)
output2 = o.reshape(batch_size, seq_len, d)
print("output2 shape: ", output2.shape)
print("========VEC END=========")
# Summed absolute element-wise difference between `output` and `output2`.
err = (output - output2).abs().sum()