# The current problem list corresponds to a sequence length parameter value
# of 40 (seq_len = 40) and a beam parameter value of 4 (beam = 4)
# batch = 1

# =============================================================================
# Encoder part
# num_encoder_stages = 6
#   2d problems - M = batch * seq_len
#   4d problems - M = seq_len, B = batch x 16
# =============================================================================

# Encoder problem set, all listed under --dt=bf16. The bias and tag options
# are restated before each group, which suggests an option stays in effect
# until it is overridden (TODO confirm against the benchdnn option semantics).
# The "*N" suffix inside each problem name matches the total call count.
--reset
--dt=bf16
# 2d problems are listed with a bf16 bias and bias mask 2.
--bia-dt=bf16 --bia_mask=2

--stag=ab --wtag=any --dtag=ab
# 6 per each encoder stage, 36 in total (6 x num_encoder_stages)
# M = batch * seq_len = 40, K = N = 1024
40x1024:1024x1024n"Transformer_lt:Encoder_MM_1*36"

# 4d problems: bias data type is switched to undef (presumably "no bias").
# B = batch x 16 = 1x16, M = seq_len = 40
--bia-dt=undef
# The weights tag abdc has its last two letters swapped relative to the abcd
# source/destination tags -- presumably a transposed weights layout (confirm).
--stag=abcd --wtag=abdc --dtag=abcd
# 1 per each encoder stage, 6 in total: 40x64 by 64x40 per batch entry
1x16x40x64:1x16x64x40n"Transformer_lt:Encoder_MM_4*6"
# Weights go back to the plain abcd tag for the next problem.
--stag=abcd --wtag=abcd --dtag=abcd
# 1 per each encoder stage, 6 in total: 40x40 by 40x64 per batch entry
1x16x40x40:1x16x40x64n"Transformer_lt:Encoder_MM_5*6"

# Back to 2d problems with the bf16 bias / mask 2 settings.
--bia-dt=bf16 --bia_mask=2
--stag=ab --wtag=any --dtag=ab
# 1 per each encoder stage, 6 in total for each of the two problems below:
# a 1024 -> 4096 problem followed by the matching 4096 -> 1024 one.
40x1024:1024x4096n"Transformer_lt:Encoder_MM_7*6"
40x4096:4096x1024n"Transformer_lt:Encoder_MM_8*6"

# =============================================================================
# Decoder part
#   2d problems - M = batch * beam
#   4d problems - M = beam, B = batch x 16
# number of calls depends on sequence length value
# =============================================================================

# Decoder problem set with fixed shapes. M = batch * beam = 4 for 2d problems
# and M = beam = 4 for 4d problems (beam = 4, batch = 1).
# NOTE(review): the per-problem comments below say "encoder", but the counts
# (x6 per sequence position) look like a per-decoder-stage count -- confirm.
--bia-dt=bf16 --bia_mask=2
--stag=ab --wtag=any --dtag=ab
# 6 per each encoder and sequence length, 1440 in total for seq_len = 40
# (6 x 6 x 40 = 1440)
4x1024:1024x1024n"Transformer_lt:Decoder_MM_1*1440"

# 4d problems: bias data type is switched to undef (presumably "no bias").
--bia-dt=undef
# The weights tag abdc has its last two letters swapped relative to the abcd
# source/destination tags -- presumably a transposed weights layout (confirm).
--stag=abcd --wtag=abdc --dtag=abcd
# 1 per each encoder and sequence length, 240 in total for seq_len = 40
# (6 x 40 = 240); N = 40 here equals seq_len
1x16x4x64:1x16x64x40n"Transformer_lt:Decoder_MM_4*240"
# Weights go back to the plain abcd tag for the next problem.
--stag=abcd --wtag=abcd --dtag=abcd
# 1 per each encoder and sequence length, 240 in total for seq_len = 40
# K = 40 here equals seq_len
1x16x4x40:1x16x40x64n"Transformer_lt:Decoder_MM_5*240"

# Back to 2d problems with the bf16 bias / mask 2 settings.
--bia-dt=bf16 --bia_mask=2
--stag=ab --wtag=any --dtag=ab
# 1 per each encoder and sequence length, 240 in total for seq_len = 40
4x1024:1024x4096n"Transformer_lt:Decoder_MM_7*240"
# 1 per each encoder and sequence length, 240 in total for seq_len = 40
4x4096:4096x1024n"Transformer_lt:Decoder_MM_8*240"
# 1 per each sequence length, 40 in total for seq_len = 40
# N = 32768 (per the problem name, the vocabulary projection)
4x1024:1024x32768n"Transformer_lt:Decoder_vocabulary*40"

# for each encoder and sequence length, there is a set of problems like
# batch x 16 x 1 x 64 : batch x 16 x 64 x sl
# batch x 16 x 1 x sl : batch x 16 x sl x 64
# for 1 <= sl <= seq_len
# instead of running the whole set of such problems, we fix sl = 20 and treat
# this problem as if it runs 6 * seq_len times

# Representative decoder problems with M = 1 and the varying dimension fixed
# at sl = 20; the "*240" in the names corresponds to 6 * seq_len = 6 * 40.
# Bias data type is undef for both problems (presumably "no bias").
--bia-dt=undef
# The weights tag abdc has its last two letters swapped relative to the abcd
# source/destination tags -- presumably a transposed weights layout (confirm).
--stag=abcd --wtag=abdc --dtag=abcd
# batch x 16 x 1 x 64 : batch x 16 x 64 x sl, with batch = 1 and sl = 20
1x16x1x64:1x16x64x20n"Transformer_lt:Decoder_MM_xx20*240"
# Weights go back to the plain abcd tag for the next problem.
--stag=abcd --wtag=abcd --dtag=abcd
# batch x 16 x 1 x sl : batch x 16 x sl x 64, with batch = 1 and sl = 20
1x16x1x20:1x16x20x64n"Transformer_lt:Decoder_MM_yy20*240"
