# Floating-point graphs.
# Every case below starts with --reset and carries --dt=f32,bf16,f16, so each
# line is self-contained and does not rely on options set by an earlier line.
# Some cases additionally pass --expected-n-partitions=0.
# NOTE(review): per the benchdnn graph driver documentation a value of 0 is
# expected to disable the partition-count check -- confirm before relying on it.
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MHA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/JAX-MQA-inf-fp32.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA_backward-Bert_large-train-fp32-bs4.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

# int8 graphs (the list also contains one int4 compressed-KV case).
# No --dt option is given in this section; each case only resets the option
# state and names its graph file.
# NOTE(review): the data types presumably come from the JSON graph itself when
# --dt is omitted -- confirm against the benchdnn graph driver documentation.
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/MHA-LLaMa-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json
--reset --expected-n-partitions=0 --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json
--reset --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
--reset --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-v-int8-gs32.json
--reset --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# Re-written graphs: graph files from the sections above, re-run with input
# shapes (--in-shapes) and/or op attributes (--op-attrs) overridden.
# As used below, --in-shapes is a '+'-joined list of ID:VALUE entries, where
# VALUE is a shape such as 4x16x32x256 or a letter string such as acbd, and
# --op-attrs is a '+'-joined list of ID:ATTR:VALUE entries.
# Some lines give several ','-separated --in-shapes settings for the same ID.
# NOTE(review): ',' presumably makes benchdnn run the case once per setting,
# and letter strings are presumably layout tags -- confirm in the driver docs.
--reset --dt=f32,bf16,f16 --in-shapes=4:4x16x32x256+5:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=3:4x32x32x128+4:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:20x16x384x64+4:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=3:10x16x384x64+4:10x1x64x384+0:10x1x384x64+1:10x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=4:56x12x128x64+5:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=0:56x8x1024x80+1:56x8x77x80+2:56x8x77x80 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=5:20x117x48x128+6:20x1x128x117+19:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-fp32-bs1.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=2514:32x16x512x64+2518:32x16x512x64+2543:32x1x512x512+2547:32x16x512x512+2525:32x16x512x64 --op-attrs=4837:shape:16384x1024 --case=complex_fusion/mha/MHA_forward-Bert_large-train-fp32-bs4.json
--reset --expected-n-partitions=0 --dt=f32,bf16,f16 --in-shapes=2514:32x16x512x64+2518:32x16x512x64+2591:32x16x512x512+2545:32x16x512x512+2547:32x16x512x512+2525:32x16x512x64+2548:32x16x512x512+5178:16384x1024 --op-attrs=7392:shape:32x512x16x64 --case=complex_fusion/mha/MHA_backward-Bert_large-train-fp32-bs4.json
--reset --dt=f32,bf16,f16 --in-shapes=0:20x16x384x64+1:20x16x384x64+8:20x16x384x64+5:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --in-shapes=5:1x1x384x384,5:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:2x16x384x64+1:2x16x384x64+5:2x1x1x384+8:2x16x384x64  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd  --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=5:384,5:1x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --op-attrs=34107656704:group_shape:1x1x1x32+34107654464:transpose_b:1 --in-shapes=0:1x32x32x128+1:1x32x32x4+2:1x32x32x4 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:qtype:per_channel*axis:3 --in-shapes=1:32+2:1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --dt=f32,bf16,f16 --op-attrs=2:axis:-2+3:axis:-1 --in-shapes=0:1x32x128x64+1:1x32x128x64+11:1x32x128x64 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

# Re-written int8 graphs: int8 graph files re-run with --in-shapes (and, in the
# last case, --op-attrs) overridden. No --dt option is given in this section.
# Every case starts with --reset, so the overrides apply to that case only.
--reset --in-shapes=5:4x16x32x256+4:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:4x32x32x128+3:4x32x128x33+0:4x32x33x128+1:4x1x32x33 --case=complex_fusion/mha/MHA-LLaMa-inf-int8-bs1.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/MHA-bert_large-inf-int8-bs1.json
--reset --in-shapes=5:56x12x128x64+4:56x12x64x128+0:56x12x128x64+1:56x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --in-shapes=2:1x1x1x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:20x117x48x128+3:20x1x128x117+0:20x1x117x128 --case=complex_fusion/mha/MHA-starcoder-inf-int8-bs1.json
--reset --expected-n-partitions=0 --in-shapes=4:32x16x384x64+3:32x16x64x384+0:32x16x384x64+1:32x1x1x384 --case=complex_fusion/mha/dynamic_quantized_mha-Bert_large-inf-int8-bs1-fake.json
--reset --in-shapes=4:20x16x384x64+3:20x16x64x384+0:20x16x384x64+1:20x1x1x384 --case=complex_fusion/mha/sdpa-plain-wo-scale-int8-bs1.json
--reset --in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1 --op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96 --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# phi3-mini-4k-instruct
# Two compressed-KV graphs, written one option per line. Each case:
#   - resets the option state and sets tensors 0, 2, 6 and 8 to s8;
#   - lists six ','-separated --in-shapes settings, using the sizes 384, 385,
#     512, 513, 1024 and 1025 (the value is split across lines with a
#     trailing '\');
#   - sets group_shape through --op-attrs, using 96 as the group extent.
# In the 385 / 513 / 1025 settings tensor 3 is 1x32x1x96 rather than
# 1x32xNx96. The first case also overrides tensor 5 (1x1xNxN); the second
# case does not list tensor 5.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+5:1x1x384x384+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+5:1x1x385x385+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+5:1x1x512x512+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+5:1x1x513x513+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+5:1x1x1024x1024+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+5:1x1x1025x1025+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+99:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+8:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# llama-2-7b-chat
# Same two compressed-KV graphs and the same six-size sweep (384, 385, 512,
# 513, 1024, 1025) as the phi3 section, with 128 instead of 96 in the shapes
# and in group_shape.
# Each case starts with --reset and restates the s8 --dt override for tensors
# 0, 2, 6 and 8, so it no longer depends on option state left behind by
# whichever case happens to precede it in this file.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+5:1x1x384x384+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+5:1x1x385x385+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+5:1x1x512x512+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+5:1x1x513x513+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+5:1x1x1024x1024+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+5:1x1x1025x1025+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
            0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
            0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
            0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
            0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
            0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+8:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json

# Tensor IDs in this graph: 0 = key, 2 = key zero points, 6 = value,
# 8 = value zero points. Override all four to the s8 data type.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# Same s8 override with the group size changed to 128 through group_shape.
# The scale and zero-point tensors change shape with it, hence the
# --in-shapes entries.
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--in-shapes=1:1x32x1x32+2:1x32x1x32+7:1x32x32x1+8:1x32x32x1
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
