xbyak/test/test_by_xed.py

388 lines
11 KiB
Python
Raw Normal View History

2024-10-09 19:14:21 -07:00
import re
import math
import sys
class Reg:
    """A register operand, identified solely by its name (e.g. 'rax', 'k7')."""
    def __init__(self, s):
        self.name = s
    def __str__(self):
        return self.name
    def __eq__(self, rhs):
        # Operands of another kind (None, an int immediate, a Memory) are not
        # registers; defer to Python's default handling instead of raising
        # AttributeError on rhs.name.
        if not isinstance(rhs, Reg):
            return NotImplemented
        return self.name == rhs.name
    def __lt__(self, rhs):
        # Only a `name` attribute is required, so mask registers (Reg) and
        # decorators (Attr) can be sorted together in one attribute list.
        return self.name < rhs.name
# Vector register names, ordered xmm0..xmm31, ymm0..ymm31, zmm0..zmm31.
g_xmmTbl = [f'{kind}mm{i}' for kind in 'xyz' for i in range(32)]

# Every register name the parsers recognise: general-purpose, opmask (k1-k7;
# k0 is not listed), tmm tiles and the vector registers above.
_gpr16Names = 'ax cx dx bx sp bp si di'.split()
g_regTbl = (
    ['e' + r for r in _gpr16Names]                            # eax .. edi
    + _gpr16Names                                              # ax .. di
    + 'al cl dl bl ah ch dh bh'.split()
    + [f'k{i}' for i in range(1, 8)]                           # k1 .. k7
    + ['r' + r for r in _gpr16Names]                           # rax .. rdi
    + [f'r{i}' for i in range(8, 32)]                          # r8 .. r31
    + [f'r{i}{sfx}' for sfx in 'dwb' for i in range(8, 32)]    # r8d..r31d, r8w..r31w, r8b..r31b
    + 'spl bpl sil dil'.split()
    + [f'tmm{i}' for i in range(8)]
    + g_xmmTbl
)
2024-10-09 19:14:21 -07:00
# define global constants: one Reg per entry of g_regTbl (rax, zmm31, k7, ...)
# so operands can be written as bare names, e.g. Memory(8, rax, rcx).
globals().update({regName: Reg(regName) for regName in g_regTbl})

# Opmask registers k1..k7 in order; g_maskTbl[n-1] is kn.
g_maskTbl = [globals()[f'k{n}'] for n in range(1, 8)]
2024-10-09 19:14:21 -07:00
# Punctuation separating operands/decorators in xbyak and xed text; each of
# these characters is mapped to a space before the line is split into tokens.
g_replaceCharTbl = '{}();|,'
g_replaceChar = str.maketrans(g_replaceCharTbl, ' ' * len(g_replaceCharTbl))
# xbyak size keywords; entry i denotes an operand of 1<<i bytes.
g_sizeTbl = 'byte word dword qword xword yword zword'.split()
# xed names for 16/32/64-byte operands; entry i denotes 1<<(i+4) bytes.
g_xedSizeTbl = 'xmmword ymmword zmmword'.split()
# Each xbyak attribute constant paired with the spelling xed prints for it.
_attrSpellings = (
    ('T_sae', 'sae'),
    ('T_rn_sae', 'rne-sae'),
    ('T_rd_sae', 'rd-sae'),
    ('T_ru_sae', 'ru-sae'),
    ('T_rz_sae', 'rz-sae'),
    ('T_z', 'z'),
)
g_attrTbl = [xbyakName for xbyakName, _ in _attrSpellings]
g_attrXedTbl = [xedName for _, xedName in _attrSpellings]
2024-10-09 19:14:21 -07:00
class Attr:
    """An instruction decorator (e.g. 'T_sae', 'T_z'), identified by its xbyak name."""
    def __init__(self, s):
        self.name = s
    def __str__(self):
        return self.name
    def __eq__(self, rhs):
        # Defer to the other operand for non-Attr values instead of raising
        # AttributeError on rhs.name.
        if not isinstance(rhs, Attr):
            return NotImplemented
        return self.name == rhs.name
    def __lt__(self, rhs):
        # Name-based so Attr and Reg (opmask registers) sort together.
        return self.name < rhs.name
2024-10-09 19:14:21 -07:00
# One Attr global per xbyak attribute name (T_sae, T_rn_sae, ..., T_z).
globals().update({attrName: Attr(attrName) for attrName in g_attrTbl})
2024-10-10 17:55:14 -07:00
def newReg(s):
    """Coerce a register given by name into a Reg.

    An exact str is wrapped in a new Reg; any other value (None, an existing
    Reg, ...) is returned unchanged.
    """
    if type(s) is not str:
        return s
    return Reg(s)
2024-10-09 19:14:21 -07:00
class Memory:
    """A memory operand such as `qword [rax+rcx*4+0x10]`.

    size      -- operand size in bytes; 0 means unspecified and prints as `ptr`
    base      -- base register (Reg, a register-name str, or None)
    index     -- index register (Reg, a register-name str, or None)
    scale     -- index multiplier (printed only when > 1)
    disp      -- signed displacement
    broadcast -- True for an embedded-broadcast operand (printed as `ptr_b`)
    """
    def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=False):
        self.size = size
        self.base = newReg(base)
        self.index = newReg(index)
        self.scale = scale
        self.disp = disp
        self.broadcast = broadcast
    def __str__(self):
        s = 'ptr' if self.size == 0 else g_sizeTbl[int(math.log2(self.size))]
        if self.broadcast:
            s += '_b'
        s += ' ['
        needPlus = False
        if self.base:
            s += str(self.base)
            needPlus = True
        if self.index:
            if needPlus:
                s += '+'
            s += str(self.index)
            if self.scale > 1:
                s += f'*{self.scale}'
            needPlus = True
        if self.disp:
            if self.disp < 0:
                # Print `rax-0x10` rather than `rax+-0x10`.
                s += '-' + hex(-self.disp)
            else:
                if needPlus:
                    s += '+'
                s += hex(self.disp)
        s += ']'
        return s
    def __eq__(self, rhs):
        # A register or immediate in the same operand slot is simply unequal.
        if not isinstance(rhs, Memory):
            return NotImplemented
        if self.broadcast != rhs.broadcast:
            return False
        # xbyak uses ptr if it is automatically detected, so xword == ptr is
        # true: a size of 0 matches any size.  Broadcast operands skip the
        # size check entirely.
        if not self.broadcast and self.size > 0 and rhs.size > 0 and self.size != rhs.size:
            return False
        return (self.base == rhs.base and self.index == rhs.index
                and self.scale == rhs.scale and self.disp == rhs.disp)
2024-10-09 19:14:21 -07:00
2024-10-10 17:55:14 -07:00
def parseBroadcast(s):
    """Strip a broadcast marker from s and report whether one was present.

    Returns (text, isBroadcast).  Any '_b' substring (xbyak `ptr_b`) is
    removed everywhere; otherwise the first xed-style '{1toN}' suffix found
    is removed (all occurrences of that exact text).
    """
    if '_b' in s:
        return (s.replace('_b', ''), True)
    found = re.search(r'\{1to\d+\}', s)
    return (s, False) if found is None else (s.replace(found.group(0), ''), True)
def parseMemory(s, broadcast=False):
    """Parse a memory operand in xbyak or xed syntax into a Memory.

    The text may carry a size keyword (from g_sizeTbl or g_xedSizeTbl), an
    optional 'ptr', an optional '_b' broadcast marker and a bracketed address
    of registers (optionally '*scale'), decimal or 0x-hex displacements and
    '+'/'-' signs.  Spaces and letter case are ignored.

    broadcast -- when already True the text is not searched for a broadcast
                 marker; the result is marked as broadcast either way.

    Raises ValueError when there is no bracketed address, when the address
    has more registers than the base/index slots hold, or when a term is
    neither a known register nor a number.
    """
    org_s = s
    s = s.replace(' ', '').lower()
    if not broadcast:
        s, broadcast = parseBroadcast(s)

    # Size keyword: xbyak names (1<<i bytes) are tried before xed names
    # (1<<(i+4) bytes); at most one keyword is consumed.
    size = 0
    sizeWords = [(w, 1 << i) for i, w in enumerate(g_sizeTbl)]
    sizeWords += [(w, 1 << (i + 4)) for i, w in enumerate(g_xedSizeTbl)]
    for word, nbytes in sizeWords:
        if s.startswith(word):
            size = nbytes
            s = s[len(word):]
            break

    if s.startswith('ptr'):
        s = s[3:]
    if s.startswith('_b'):
        broadcast = True
        s = s[2:]

    bracket = re.match(r'\[(.*)\]', s)
    if bracket is None:
        raise ValueError(f'bad format {org_s=}')

    base = index = None
    scale = 0
    disp = 0
    pendingOp = ''  # the sign token seen immediately before the current term
    for term, mul, op in re.findall(r'([a-z0-9]+)(?:\*([0-9]+))?|([+-])', bracket.group(1)):
        if op:
            pendingOp = op
            continue
        negative = pendingOp == '-'
        pendingOp = ''
        if term not in g_regTbl:
            value = int(term, 16 if term.startswith('0x') else 10)
            disp += -value if negative else value
        elif base is None and (not mul or int(mul) == 1):
            base = term
        elif index is None:
            index = term
            scale = int(mul) if mul else 1
        else:
            raise ValueError(f'bad format2 {s=}')
    return Memory(size, base, index, scale, disp, broadcast)
2024-10-09 19:14:21 -07:00
class Nmemonic:
    """A parsed instruction: mnemonic name, operand list and decorator list.

    attrs holds decorators such as the opmask register (a Reg like k7) and
    Attr values like T_z.  It is stored sorted by name so the order in which
    decorators were written does not affect equality.
    """
    def __init__(self, name, args=None, attrs=None):
        self.name = name
        self.args = [] if args is None else args
        # list.sort() sorts in place and returns None, which used to leave
        # self.attrs as None (so __eq__ ignored decorators entirely) while
        # mutating the caller's list.  sorted() returns a new list instead.
        self.attrs = [] if attrs is None else sorted(attrs)
    def __str__(self):
        parts = []
        for i, arg in enumerate(self.args):
            text = str(arg)
            if i == 0:
                # xbyak style: decorators attach to the first operand, e.g. `zmm1|k1|T_z`
                text += ''.join(f'|{e}' for e in self.attrs)
            parts.append(text)
        return f'{self.name}({", ".join(parts)});'
    def __eq__(self, rhs):
        if not isinstance(rhs, Nmemonic):
            return NotImplemented
        return self.name == rhs.name and self.args == rhs.args and self.attrs == rhs.attrs
2024-10-09 19:14:21 -07:00
def parseNmemonic(s):
    """Parse one instruction line, written in xbyak or xed syntax, into a Nmemonic.

    Two spellings of the same instruction (taken from parseNmemonicTest):
      xbyak: 'vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);'
      xed:   'vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5'

    The first token becomes the mnemonic name.  The remaining tokens become
    operands (int immediates, Reg registers, Memory operands) or decorators
    (an opmask Reg and Attr values), which are collected in attrs.  Memory
    tokens are handed to parseMemory, which raises ValueError on text it
    cannot parse.
    """
    args = []
    attrs = []
    # A broadcast marker ('_b' or '{1toN}') is stripped once from the whole
    # line; the resulting flag is passed to every memory operand below.
    (s, broadcast) = parseBroadcast(s)
    # replace xm0 with xmm0 (likewise ymN -> ymmN, zmN -> zmmN)
    while True:
        r = re.search(r'([xyz])m(\d\d?)', s)
        if not r:
            break
        s = s.replace(r.group(0), r.group(1) + 'mm' + r.group(2))
    # check 'zmm0{k7}' (xed opmask form): record the mask, remove it from the text
    r = re.search(r'({k[1-7]})', s)
    if r:
        idx = int(r.group(1)[2])  # the digit in '{kN}'
        attrs.append(g_maskTbl[idx-1])
        s = s.replace(r.group(1), '')
    # check 'zmm0|k7' (xbyak opmask form): record the mask, remove it from the text
    r = re.search(r'(\|\s*k[1-7])', s)
    if r:
        idx = int(r.group(1)[-1])
        attrs.append(g_maskTbl[idx-1])
        s = s.replace(r.group(1), '')
    # Map the separator characters of g_replaceCharTbl to spaces so that
    # s.split() yields one token per operand or decorator.
    s = s.translate(g_replaceChar)
    # reconstruct memory string: a memory operand may have been split into
    # several tokens ('dword', 'ptr', '[rax+0x40]').  Starting at a size
    # keyword or a 'ptr...' token, glue tokens together until one has ']'.
    v = []
    inMemory = False
    for e in s.split():
        if inMemory:
            v[-1] += e
            if ']' in e:
                inMemory = False
        else:
            v.append(e)
            if e in g_sizeTbl or e in g_xedSizeTbl or e.startswith('ptr'):
                v[-1] += ' ' # to avoid 'byteptr'
                if ']' not in v[-1]:
                    inMemory = True
    name = v[0]
    # Classify the remaining tokens; the branch order matters because the
    # final branch treats anything unrecognised as a memory operand.
    for e in v[1:]:
        if e.startswith('0x'):
            args.append(int(e, 16))
        elif e[0] in '0123456789':
            args.append(int(e))
        elif e in g_attrTbl:
            attrs.append(Attr(e))
        elif e in g_attrXedTbl:
            # normalise xed's spelling (e.g. 'rne-sae') to the xbyak name
            attrs.append(Attr(g_attrTbl[g_attrXedTbl.index(e)]))
        elif e in g_regTbl:
            args.append(Reg(e))
        # xed special format : xmm8+3 -- only the named register is kept
        # (presumably the first of a register block; TODO confirm)
        elif e[:-2] in g_xmmTbl and e.endswith('+3'):
            args.append(Reg(e[:-2]))
        else:
            args.append(parseMemory(e, broadcast))
    return Nmemonic(name, args, attrs)
def loadFile(name):
    """Return the lines of file *name*, skipping blank and comment lines.

    A line counts as a comment when it starts with '#' or '//'.
    """
    with open(name) as f:
        text = f.read()
    return [line for line in text.split('\n')
            if line and not line.startswith(('#', '//'))]
# xed prefixes each disassembled line with five fields, e.g.
#   XDIS 0: AVX512 AVX512EVEX 62F1E91858CB vaddpd ymm1{rne-sae}, ymm2, ymm3
# Only the instruction text after them is kept.
def removeExtraInfo(s):
    """Drop the first five whitespace-separated fields and rejoin the rest with single spaces."""
    fields = s.split()
    del fields[:5]
    return ' '.join(fields)
def run(cppText, xedText):
    """Check that each xbyak line matches the xed line at the same position.

    cppText, xedText -- paths of the two listings, read with loadFile.
    Raises Exception when the listings differ in length or when a pair of
    lines parses to different instructions.  The failure label is the 1-based
    position among the kept (non-blank, non-comment) lines.
    """
    cpp = loadFile(cppText)
    xed = loadFile(xedText)
    if len(cpp) != len(xed):
        raise Exception(f'different line {len(cpp)} {len(xed)}')
    for lineNo, (cppLine, xedLine) in enumerate(zip(cpp, xed), 1):
        fromCpp = parseNmemonic(cppLine)
        fromXed = parseNmemonic(removeExtraInfo(xedLine))
        assertEqual(fromCpp, fromXed, f'{lineNo}')
    print('run ok', len(cpp))
2024-10-09 19:14:21 -07:00
def assertEqualStr(a, b, msg=None):
    """Raise Exception unless str(a) == str(b); msg labels the failure."""
    lhs, rhs = str(a), str(b)
    if lhs == rhs:
        return
    raise Exception(f'assert fail {msg}:', lhs, rhs)
2024-10-10 17:55:14 -07:00
def assertEqual(a, b, msg=None):
    """Raise Exception when a != b, labelled with msg and showing str() of both values."""
    differs = a != b
    if differs:
        raise Exception(f'assert fail {msg}:', str(a), str(b))
2024-10-09 19:14:21 -07:00
def MemoryTest():
    """Self-test for Memory: xbyak-style rendering and size-tolerant equality."""
    expectations = [
        (Memory(0, rax), 'ptr [rax]'),
        (Memory(4, rax), 'dword [rax]'),
        (Memory(8, rax, rcx), 'qword [rax+rcx]'),
        (Memory(8, rax, rcx, 4), 'qword [rax+rcx*4]'),
        (Memory(8, None, rcx, 4), 'qword [rcx*4]'),
        (Memory(8, rax, None, 0, 5), 'qword [rax+0x5]'),
        (Memory(8, None, None, 0, 255), 'qword [0xff]'),
        (Memory(0, r8, r9, 1, 32), 'ptr [r8+r9+0x20]'),
    ]
    for mem, text in expectations:
        assertEqualStr(mem, text)
    # An explicit size compares equal to an unspecified (ptr) one.
    assertEqual(Memory(16, rax), Memory(0, rax))
2024-10-09 19:14:21 -07:00
def parseMemoryTest():
    """Self-test for parseMemory over xbyak and xed spellings of memory operands."""
    print('parseMemoryTest')
    cases = [
        ('[]', Memory()),
        ('[rax]', Memory(0, rax)),
        ('ptr[rax]', Memory(0, rax)),
        ('ptr_b[rax]', Memory(0, rax, broadcast=True)),
        ('dword[rbx]', Memory(4, rbx)),
        ('xword ptr[rcx]', Memory(16, rcx)),
        ('xmmword ptr[rcx]', Memory(16, rcx)),
        ('xword ptr[rdx*8]', Memory(16, None, rdx, 8)),
        ('[12345]', Memory(0, None, None, 0, 12345)),
        ('[0x12345]', Memory(0, None, None, 0, 0x12345)),
        ('yword [rax+rdx*4]', Memory(32, rax, rdx, 4)),
        ('zword [rax+rdx*4+123]', Memory(64, rax, rdx, 4, 123)),
    ]
    for text, want in cases:
        assertEqualStr(parseMemory(text), want)
def parseNmemonicTest():
    """Self-test for parseNmemonic.

    Each xbyak spelling and its xed counterpart must parse to an equal
    Nmemonic.  A failure is labelled with the input text.
    """
    print('parseNmemonicTest')
    tbl = [
        ('vaddpd(ymm1, ymm2, ymm3 |T_rn_sae);', Nmemonic('vaddpd', [ymm1, ymm2, ymm3], [T_rn_sae])),
        ('vaddpd ymm1{rne-sae}, ymm2, ymm3', Nmemonic('vaddpd', [ymm1, ymm2, ymm3], [T_rn_sae])),
        ('mov(rax, dword ptr [rcx + rdx * 8 ] );', Nmemonic('mov', [rax, Memory(4, rcx, rdx, 8)])),
        ('mov(rax, ptr [rcx + rdx * 8 ] );', Nmemonic('mov', [rax, Memory(0, rcx, rdx, 8)])),
        ('vcmppd(k1, ymm2, ymm3 |T_sae, 3);', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])),
        ('vcmppd k1{sae}, ymm2, ymm3, 0x3', Nmemonic('vcmppd', [k1, ymm2, ymm3, 3], [T_sae])),
        ('v4fmaddps zmm1, zmm8+3, xmmword ptr [rdx+0x40]', Nmemonic('v4fmaddps', [zmm1, zmm8, Memory(16, rdx, None, 0, 0x40)])),
        ('vp4dpwssd zmm23{k7}{z}, zmm1+3, xmmword ptr [rax+0x40]', Nmemonic('vp4dpwssd', [zmm23, zmm1, Memory(16, rax, None, 0, 0x40)], [k7, T_z])),
        ('v4fnmaddps(zmm5 | k5, zmm2, ptr [rcx + 0x80]);', Nmemonic('v4fnmaddps', [zmm5, zmm2, Memory(0, rcx, None, 0, 0x80)], [k5])),
        ('vpcompressw(zmm30 | k2 |T_z, zmm1);', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
        ('vpcompressw zmm30{k2}{z}, zmm1', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
        ('vpshldw(xmm9|k3|T_z, xmm2, ptr [rax + 0x40], 5);', Nmemonic('vpshldw', [xmm9, xmm2, Memory(0, rax, None, 0, 0x40), 5], [k3, T_z])),
        ('vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
        ('vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
        # The input uses the short spelling xm15; the parser expands it, so the
        # expected operand is the bound global xmm15 (there is no global xm15).
        ('vcmpph(k1, xm15, ptr[rax+64], 1);', Nmemonic('vcmpph', [k1, xmm15, Memory(0, rax, None, 0, 64), 1])),
    ]
    for (s, expected) in tbl:
        e = parseNmemonic(s)
        assertEqual(e, expected, s)
2024-10-09 19:14:21 -07:00
def test():
    """Run every parser self-test, announcing the start and the end."""
    print('test start')
    for selfTest in (MemoryTest, parseMemoryTest, parseNmemonicTest):
        selfTest()
    print('test end')
def main():
    """Command-line entry point.

    `<script> test` runs the self-tests; `<script> <cpp-text> <xed-text>`
    compares the two listings; any other argument list prints usage.
    """
    if len(sys.argv) == 2 and sys.argv[1] == 'test':
        test()
    elif len(sys.argv) == 3:
        run(sys.argv[1], sys.argv[2])
    else:
        # Show the invoked script name; __name__ would print '__main__' when
        # this file is run as a script.
        prog = sys.argv[0] if sys.argv else __name__
        print(f'{prog} <cpp-text> <xed-text> # compare cpp-text and xed-text generated by xed')
        print(f'{prog} test # for test')

if __name__ == '__main__':
    main()