program NTTMontgomery;
{$mode objfpc}{$H+}
{$ASSERTIONS ON}
uses
SysUtils;
type
{ Fixed-width unsigned aliases used throughout the Montgomery / NTT code. }
u32 = LongWord; { 32-bit unsigned }
u64 = QWord; { 64-bit unsigned }
pu32 = ^u32; { pointer to u32; indexed like an array (p[i]) by the routines below }
{ ====== Montgomery Arithmetic ====== }
type
{ Precomputed constants for Montgomery arithmetic with R = 2^32.
  Filled in by montgomery_init. }
Montgomery = record
modulus: u32; { mod (odd, below 2^30 -- checked by montgomery_init) }
mod2: u32; { 2 * mod; the bound used by montgomery_shrink2 }
n_inv: u32; { n_inv * mod == -1 (mod 2^32); used by the reduction step }
r: u32; { 2^32 % mod; also the Montgomery representation of 1 }
r2: u32; { (2^32)^2 % mod; a Montgomery product with r2 converts a value into Montgomery form }
end;
{ Returns the smaller of two unsigned 32-bit values. }
function min_u32(a, b: u32): u32; inline;
begin
  Result := b;
  if a < b then
    Result := a;
end;
{ Conditional subtraction of mod: maps val in [0, 2 * mod) to [0, mod).
  An explicit comparison is used rather than min(val, val - mod). The min form
  relies on val - mod wrapping around when val < mod. On 64-bit FPC that
  difference is a negative Int64 that is then narrowed into a u32, which can
  raise a runtime error when compiled with range/overflow checks ($R+ / $Q+).
  With default settings both forms give the same result. }
function montgomery_shrink(const mt: Montgomery; val: u32): u32; inline;
begin
  if val >= mt.modulus then
    Result := val - mt.modulus
  else
    Result := val;
end;
{ Conditional subtraction of 2 * mod: maps val in [0, 4 * mod) to [0, 2 * mod).
  An explicit comparison is used rather than min(val, val - mod2). The min form
  relies on val - mod2 wrapping around when val < mod2, which can raise a
  runtime error when compiled with range/overflow checks ($R+ / $Q+).
  With default settings both forms give the same result. }
function montgomery_shrink2(const mt: Montgomery; val: u32): u32; inline;
begin
  if val >= mt.mod2 then
    Result := val - mt.mod2
  else
    Result := val;
end;
{ Montgomery reduction (REDC) followed by one conditional subtraction.
  Returns a value congruent to val * 2^-32 (mod modulus). For
  val < modulus * 2^32 the result lies in [0, modulus). }
function montgomery_reduce_strict(const mt: Montgomery; val: u64): u32; inline;
var
  q: u32;
  t: u64;
begin
  { q makes val + q * modulus a multiple of 2^32, because n_inv * modulus == -1 (mod 2^32). }
  q := u32(val * mt.n_inv);
  t := val + u64(q) * u64(mt.modulus);
  Result := montgomery_shrink(mt, u32(t shr 32));
end;
{ Montgomery reduction (REDC) without the final correction step.
  Returns a value congruent to val * 2^-32 (mod modulus). For
  val < modulus * 2^32 the result lies in [0, 2 * modulus). }
function montgomery_reduce_nonstrict(const mt: Montgomery; val: u64): u32; inline;
var
  q: u32;
begin
  { q makes val + q * modulus a multiple of 2^32, so the shift below is exact. }
  q := u32(val * mt.n_inv);
  Result := u32((val + u64(q) * u64(mt.modulus)) shr 32);
end;
{ Montgomery product: a * b * 2^-32 (mod modulus), reduced into [0, modulus)
  when a * b < modulus * 2^32. }
function montgomery_mul_strict(const mt: Montgomery; a, b: u32): u32; inline;
var
  prod: u64;
begin
  prod := u64(a) * u64(b);
  Result := montgomery_reduce_strict(mt, prod);
end;
{ Montgomery product without final correction: a * b * 2^-32 (mod modulus),
  in [0, 2 * modulus) when a * b < modulus * 2^32. }
function montgomery_mul_nonstrict(const mt: Montgomery; a, b: u32): u32; inline;
var
  prod: u64;
begin
  prod := u64(a) * u64(b);
  Result := montgomery_reduce_nonstrict(mt, prod);
end;
{ Computes b^e (mod modulus) by square-and-multiply.
  input_in_space:  b is already in Montgomery form; otherwise it is converted first.
  output_in_space: the result is returned in Montgomery form; otherwise as an ordinary value.
  The result is reduced into [0, modulus). }
function montgomery_power(const mt: Montgomery; b, e: u32; input_in_space, output_in_space: Boolean): u32;
var
  acc, base, rest: u32;
begin
  if input_in_space then
    base := b
  else
    base := montgomery_mul_nonstrict(mt, b, mt.r2);
  { Starting from R (Montgomery 1) keeps the product in Montgomery form.
    Starting from the plain integer 1 cancels one factor of R per multiply,
    so the result comes out as an ordinary value. }
  if output_in_space then
    acc := mt.r
  else
    acc := 1;
  rest := e;
  while rest <> 0 do
  begin
    if Odd(rest) then
      acc := montgomery_mul_nonstrict(mt, acc, base);
    base := montgomery_mul_nonstrict(mt, base, base);
    rest := rest shr 1;
  end;
  Result := montgomery_shrink(mt, acc);
end;
{ Fills mt with the Montgomery constants for modulus.
  modulus must be odd and below 2^30; both conditions are asserted. }
procedure montgomery_init(var mt: Montgomery; modulus: u32);
var
  iter: Integer;
  inv: u32;
begin
  Assert(modulus mod 2 = 1);
  Assert(modulus < (u32(1) shl 30));
  { Newton/Hensel lifting: each pass doubles the number of correct low bits,
    so five passes from 1 give inv * modulus == -1 (mod 2^32). }
  inv := 1;
  for iter := 1 to 5 do
    inv := inv * (2 + inv * modulus);
  mt.modulus := modulus;
  mt.mod2 := 2 * modulus;
  mt.n_inv := inv;
  mt.r := (u64(1) shl 32) mod modulus;
  mt.r2 := u64(mt.r) * mt.r mod modulus;
end;
{ ====== NTT ====== }
type
{ State for number-theoretic transforms modulo a fixed modulus.
  Set up by ntt_init and released by ntt_free. }
NTT = record
modulus: u32; { the modulus passed to ntt_init }
pr_root: u32; { first candidate (searching upward from 2) accepted by ntt_find_pr_root; ordinary, not Montgomery form }
mt: Montgomery; { Montgomery constants for modulus }
wd: ^u32; { forward twiddle steps in Montgomery form; lg_max entries, indices 0..lg_max-2 filled by ntt_init, the last left 0 }
wrd: ^u32; { inverse twiddle steps: wrd[i] = wd[i]^(modulus-2), i.e. the inverse of wd[i] when modulus is prime }
lg_max: Integer; { Maximum lg value (size of wd and wrd arrays) }
end;
{ Number of trailing zero bits in x (like __builtin_ctz); returns 32 for x = 0. }
function count_trailing_zeros(x: u32): Integer; inline;
var
  n: Integer;
begin
  if x = 0 then
    Exit(32);
  n := 0;
  while not Odd(x) do
  begin
    x := x shr 1;
    n := n + 1;
  end;
  Result := n;
end;
{ Returns True when candidate^((modulus - 1) div q) <> 1 (mod modulus) for
  every q in factors[0 .. num_factors - 1]. Stops at the first q that fails. }
function check_all_factors(const mt: Montgomery; candidate, modulus: u32; factors: pu32; num_factors: Integer): Boolean;
var
  idx: Integer;
begin
  Result := True;
  idx := 0;
  while Result and (idx < num_factors) do
  begin
    if montgomery_power(mt, candidate, (modulus - 1) div factors[idx], False, False) = 1 then
      Result := False;
    Inc(idx);
  end;
end;
{ Returns the first g >= 2 that passes check_all_factors against the distinct
  prime factors of modulus - 1. For a prime modulus this is a primitive root.
  Asserts and returns 0 if no candidate passes. }
function ntt_find_pr_root(modulus: u32; const mt: Montgomery): u32;
var
  primes: array[0..31] of u32;
  count: Integer;
  rest, d, g: u32;
begin
  { Distinct prime factors of modulus - 1 by trial division. }
  count := 0;
  rest := modulus - 1;
  d := 2;
  while u64(d) * d <= rest do
  begin
    if rest mod d = 0 then
    begin
      primes[count] := d;
      Inc(count);
      while rest mod d = 0 do
        rest := rest div d;
    end;
    Inc(d);
  end;
  { Whatever is left above 1 is itself a prime factor. }
  if rest > 1 then
  begin
    primes[count] := rest;
    Inc(count);
  end;
  for g := 2 to modulus - 1 do
    if check_all_factors(mt, g, modulus, @primes[0], count) then
      Exit(g);
  Assert(False, 'primitive root not found');
  Result := 0;
end;
{ Initializes ntt for the given modulus.
  Sets up the Montgomery constants, finds pr_root, and allocates and fills the
  twiddle tables wd/wrd with lg_max = count_trailing_zeros(modulus - 1) entries
  each. Release with ntt_free. }
procedure ntt_init(var ntt: NTT; modulus: u32);
var
  lg, level: Integer;
  a, b, f: u32;
begin
  ntt.modulus := modulus;
  montgomery_init(ntt.mt, modulus);
  ntt.pr_root := ntt_find_pr_root(modulus, ntt.mt);
  lg := count_trailing_zeros(modulus - 1);
  ntt.lg_max := lg;
  { AllocMem returns zero-filled memory, so the unused last slot reads as 0. }
  ntt.wd := AllocMem(lg * SizeOf(u32));
  ntt.wrd := AllocMem(lg * SizeOf(u32));
  for level := 0 to lg - 2 do
  begin
    a := montgomery_power(ntt.mt, ntt.pr_root, (modulus - 1) shr (level + 2), False, True);
    b := montgomery_power(ntt.mt, ntt.pr_root, ((modulus - 1) shr (level + 1)) * ((1 shl level) - 1), False, True);
    { f = a / b; the inverse of b is taken as b^(modulus - 2). }
    f := montgomery_mul_strict(ntt.mt, a, montgomery_power(ntt.mt, b, modulus - 2, True, True));
    ntt.wd[level] := f;
    ntt.wrd[level] := montgomery_power(ntt.mt, f, modulus - 2, True, True);
  end;
end;
{ Releases the twiddle tables allocated by ntt_init and resets both pointers to nil. }
procedure ntt_free(var ntt: NTT);
begin
  FreeMem(ntt.wd);
  ntt.wd := nil;
  FreeMem(ntt.wrd);
  ntt.wrd := nil;
end;
{ Forward (non-transposed) butterfly with twiddle 1; w is ignored.
  With u = shrink2(a) and v = shrink2(b): (a, b) <- (u + v, u - v + 2 * mod).
  Inputs in [0, 4 * mod) give outputs in [0, 4 * mod). }
procedure butterfly_x2_forward_trivial(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  u, v: u32;
begin
  u := montgomery_shrink2(mt, a);
  v := montgomery_shrink2(mt, b);
  a := u + v;
  b := u + mt.mod2 - v;
end;
{ Forward (non-transposed) butterfly with twiddle w (Montgomery form).
  With u = shrink2(a) and v = b * w (Montgomery product): (a, b) <- (u + v, u - v + 2 * mod). }
procedure butterfly_x2_forward(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  u, v: u32;
begin
  v := montgomery_mul_nonstrict(mt, b, w);
  u := montgomery_shrink2(mt, a);
  a := u + v;
  b := u + mt.mod2 - v;
end;
{ Inverse (transposed) butterfly with twiddle 1; w is ignored.
  (a, b) <- (shrink2(a + b), shrink2(a - b + 2 * mod)). }
procedure butterfly_x2_inverse_trivial(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  sum, diff: u32;
begin
  sum := a + b;
  diff := a + mt.mod2 - b;
  a := montgomery_shrink2(mt, sum);
  b := montgomery_shrink2(mt, diff);
end;
{ Inverse (transposed) butterfly with twiddle w (Montgomery form).
  (a, b) <- (shrink2(a + b), (a - b + 2 * mod) * w as a Montgomery product). }
procedure butterfly_x2_inverse(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  sum, diff: u32;
begin
  sum := a + b;
  diff := a + mt.mod2 - b;
  a := montgomery_shrink2(mt, sum);
  b := montgomery_mul_nonstrict(mt, diff, w);
end;
{ Forward pass over one block (first block, twiddle 1).
  Applies butterfly_x2_forward_trivial to the 2^k pairs
  (data[i + j], data[i + 2^k + j]), then advances the running twiddle wi by
  ntt.wd[t], where t = number of trailing one bits of (i shr (k + 1)). }
procedure transform_aux_forward_trivial(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  half, j: Integer;
  p0, p1: pu32;
begin
  half := 1 shl k;
  p0 := @data[i];
  p1 := @data[i + half];
  for j := 0 to half - 1 do
    butterfly_x2_forward_trivial(p0[j], p1[j], wi, ntt.mt);
  wi := montgomery_mul_strict(ntt.mt, wi, ntt.wd[count_trailing_zeros(not i shr (k + 1))]);
end;
{ Forward pass over one block with twiddle wi.
  Applies butterfly_x2_forward to the 2^k pairs (data[i + j], data[i + 2^k + j]),
  then advances wi by ntt.wd[t], where t = number of trailing one bits of (i shr (k + 1)). }
procedure transform_aux_forward(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  half, j: Integer;
  p0, p1: pu32;
begin
  half := 1 shl k;
  p0 := @data[i];
  p1 := @data[i + half];
  for j := 0 to half - 1 do
    butterfly_x2_forward(p0[j], p1[j], wi, ntt.mt);
  wi := montgomery_mul_strict(ntt.mt, wi, ntt.wd[count_trailing_zeros(not i shr (k + 1))]);
end;
{ Inverse pass over one block (first block, twiddle 1).
  Applies butterfly_x2_inverse_trivial to the 2^k pairs (data[i + j], data[i + 2^k + j]),
  then advances wi by ntt.wrd[t], where t = number of trailing one bits of (i shr (k + 1)). }
procedure transform_aux_inverse_trivial(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  half, j: Integer;
  p0, p1: pu32;
begin
  half := 1 shl k;
  p0 := @data[i];
  p1 := @data[i + half];
  for j := 0 to half - 1 do
    butterfly_x2_inverse_trivial(p0[j], p1[j], wi, ntt.mt);
  wi := montgomery_mul_strict(ntt.mt, wi, ntt.wrd[count_trailing_zeros(not i shr (k + 1))]);
end;
{ Inverse pass over one block with twiddle wi.
  Applies butterfly_x2_inverse to the 2^k pairs (data[i + j], data[i + 2^k + j]),
  then advances wi by ntt.wrd[t], where t = number of trailing one bits of (i shr (k + 1)). }
procedure transform_aux_inverse(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  half, j: Integer;
  p0, p1: pu32;
begin
  half := 1 shl k;
  p0 := @data[i];
  p1 := @data[i + half];
  for j := 0 to half - 1 do
    butterfly_x2_inverse(p0[j], p1[j], wi, ntt.mt);
  wi := montgomery_mul_strict(ntt.mt, wi, ntt.wrd[count_trailing_zeros(not i shr (k + 1))]);
end;
{ In-place forward NTT of the 2^lg values starting at data.
  input in [0, 4 * mod); output in [0, 4 * mod).
  Requires 0 <= lg <= ntt.lg_max. For larger lg the twiddle index
  (up to lg - 1) would read past the end of ntt.wd, so this is asserted. }
procedure ntt_transform_forward(const ntt: NTT; lg: Integer; data: pu32);
var
  k, i, span, n: Integer;
  wi: u32;
begin
  Assert((lg >= 0) and (lg <= ntt.lg_max), 'ntt_transform_forward: lg out of range');
  n := 1 shl lg;
  for k := lg - 1 downto 0 do
  begin
    span := 1 shl (k + 1);
    { Each level restarts the running twiddle at R, the Montgomery form of 1. }
    wi := ntt.mt.r;
    transform_aux_forward_trivial(ntt, k, 0, data, wi);
    i := span;
    while i < n do
    begin
      transform_aux_forward(ntt, k, i, data, wi);
      Inc(i, span);
    end;
  end;
end;
{ In-place inverse NTT of the 2^lg values starting at data, including the
  division by 2^lg.
  input in [0, 2 * mod); output in [0, mod).
  Requires 0 <= lg <= ntt.lg_max. For larger lg the twiddle index
  (up to lg - 1) would read past the end of ntt.wrd, so this is asserted. }
procedure ntt_transform_inverse(const ntt: NTT; lg: Integer; data: pu32);
var
  k, i, span, n: Integer;
  wi, f: u32;
begin
  Assert((lg >= 0) and (lg <= ntt.lg_max), 'ntt_transform_inverse: lg out of range');
  n := 1 shl lg;
  for k := 0 to lg - 1 do
  begin
    span := 1 shl (k + 1);
    { Each level restarts the running twiddle at R, the Montgomery form of 1. }
    wi := ntt.mt.r;
    transform_aux_inverse_trivial(ntt, k, 0, data, wi);
    i := span;
    while i < n do
    begin
      transform_aux_inverse(ntt, k, i, data, wi);
      Inc(i, span);
    end;
  end;
  { (mod + 1) / 2 is the inverse of 2, so f is 2^-lg in Montgomery form. }
  f := montgomery_power(ntt.mt, (ntt.modulus + 1) shr 1, lg, False, True);
  for i := 0 to n - 1 do
    data[i] := montgomery_mul_strict(ntt.mt, data[i], f);
end;
{ In-place inverse NTT of the 2^lg values starting at data. The result is
  divided by 2^lg and multiplied by sc (sc is taken as an ordinary,
  non-Montgomery residue).
  input in [0, 2 * mod); output in [0, mod).
  Requires 0 <= lg <= ntt.lg_max. For larger lg the twiddle index
  (up to lg - 1) would read past the end of ntt.wrd, so this is asserted. }
procedure ntt_transform_inverse_scaled(const ntt: NTT; lg: Integer; data: pu32; sc: u32);
var
  k, i, span, n: Integer;
  wi, f: u32;
begin
  Assert((lg >= 0) and (lg <= ntt.lg_max), 'ntt_transform_inverse_scaled: lg out of range');
  n := 1 shl lg;
  for k := 0 to lg - 1 do
  begin
    span := 1 shl (k + 1);
    { Each level restarts the running twiddle at R, the Montgomery form of 1. }
    wi := ntt.mt.r;
    transform_aux_inverse_trivial(ntt, k, 0, data, wi);
    i := span;
    while i < n do
    begin
      transform_aux_inverse(ntt, k, i, data, wi);
      Inc(i, span);
    end;
  end;
  { f = 2^-lg * sc in Montgomery form. (mod + 1) / 2 is the inverse of 2, and
    a Montgomery product with r2 lifts sc into Montgomery form. }
  f := montgomery_power(ntt.mt, (ntt.modulus + 1) shr 1, lg, False, True);
  f := montgomery_mul_strict(ntt.mt, f, montgomery_mul_nonstrict(ntt.mt, ntt.mt.r2, sc));
  for i := 0 to n - 1 do
    data[i] := montgomery_mul_strict(ntt.mt, data[i], f);
end;
{ Cyclic convolution of length 2^lg: on return, a holds a * b (mod modulus)
  with values in [0, mod). b is overwritten with its forward transform. }
procedure ntt_convolve_cyclic(const ntt: NTT; lg: Integer; a, b: pu32);
var
  idx, n: Integer;
begin
  ntt_transform_forward(ntt, lg, a);
  ntt_transform_forward(ntt, lg, b);
  n := 1 shl lg;
  { The pointwise Montgomery product leaves an extra factor 2^-32 ... }
  for idx := 0 to n - 1 do
    a[idx] := montgomery_mul_nonstrict(ntt.mt,
      montgomery_shrink2(ntt.mt, a[idx]),
      montgomery_shrink2(ntt.mt, b[idx]));
  { ... which scaling by r (= 2^32 mod modulus) cancels. }
  ntt_transform_inverse_scaled(ntt, lg, a, ntt.mt.r);
end;
{ ====== Helper functions for main ====== }
{ State of the simple pseudo-random generator used by the test driver
  (stands in for C++ std::mt19937; the produced sequence is different). }
var
rnd_state: LongWord = 1;
{ Resets the generator so the same sequence of rnd_next values repeats. }
procedure rnd_seed(seed: LongWord);
begin
  rnd_state := seed;
end;
{ Advances the 32-bit linear congruential generator
  (multiplier 1664525, increment 1013904223, wrapping at 2^32)
  and returns the new state. }
function rnd_next: u32;
var
  next_state: LongWord;
begin
  next_state := rnd_state * 1664525 + 1013904223;
  rnd_state := next_state;
  Result := next_state;
end;
{ Returns the larger of two Integers. }
function max_int(a, b: Integer): Integer; inline;
begin
  Result := a;
  if b > a then
    Result := b;
end;
{ Returns the smaller of two Integers. }
function min_int(a, b: Integer): Integer; inline;
begin
  Result := a;
  if b < a then
    Result := b;
end;
{ Fisher-Yates shuffle of arr[start .. finish - 1] in place, drawing from rnd_next. }
procedure shuffle(arr: PInteger; start, finish: Integer);
var
  pos, pick, held: Integer;
begin
  pos := finish - 1;
  while pos > start do
  begin
    pick := start + (rnd_next mod (pos - start + 1));
    held := arr[pos];
    arr[pos] := arr[pick];
    arr[pick] := held;
    Dec(pos);
  end;
end;
{ ====== main ====== }
var
  prime_list: array[0..8] of u32 = (
    3,
    5,
    17,
    257,
    65537,
    1000000007,
    1000000009,
    1000001329,
    998244353
  );
  prime_list_size: Integer;
  p_idx: Integer;
  modulus: u32;
  ntt2: NTT;
  lg: Integer;
  beg: QWord;
  num_iterations: Integer;
  k: Integer;
  a, b, a2, b2: pu32;
  i, j: Integer;
  all: PInteger;
  all_size: Integer;
  swap_idx: Integer;
  temp: Integer;
  test_size: Integer;
  idx: Integer;
  val: u32;
  elapsed: Double;
  n: Integer;
{ Self-test: for each modulus and every power-of-two length 2^lg dividing
  modulus - 1, convolve random vectors with the NTT and compare up to 30
  output entries against a direct O(n) sum. Progress and timings go to stderr. }
begin
  prime_list_size := Length(prime_list);
  for p_idx := 0 to prime_list_size - 1 do
  begin
    modulus := prime_list[p_idx];
    WriteLn(StdErr, 'testing mod: ', modulus);
    rnd_seed(5489); { Default seed for mt19937 }
    ntt_init(ntt2, modulus);
    lg := 0;
    while (modulus - 1) mod (1 shl lg) = 0 do
    begin
      beg := GetTickCount64;
      n := 1 shl lg;
      num_iterations := max_int(1, 1000000 div n);
      { Buffer sizes depend only on lg, so allocate once per length rather
        than once per iteration (up to 10^6 iterations for small lg). }
      all_size := n;
      GetMem(a, n * SizeOf(u32));
      GetMem(b, n * SizeOf(u32));
      GetMem(a2, n * SizeOf(u32));
      GetMem(b2, n * SizeOf(u32));
      GetMem(all, all_size * SizeOf(Integer));
      for k := 0 to num_iterations - 1 do
      begin
        { Generate random vectors a and b }
        for i := 0 to n - 1 do
        begin
          a[i] := rnd_next mod modulus;
          b[i] := rnd_next mod modulus;
        end;
        { Convolve copies so a and b stay intact for the reference check. }
        Move(a^, a2^, n * SizeOf(u32));
        Move(b^, b2^, n * SizeOf(u32));
        ntt_convolve_cyclic(ntt2, lg, a2, b2);
        { Choose indices to verify: 0 first, then the last index, then a
          random permutation of the remaining ones. }
        for i := 0 to all_size - 1 do
          all[i] := i;
        swap_idx := min_int(all_size - 1, 1);
        temp := all[all_size - 1];
        all[all_size - 1] := all[swap_idx];
        all[swap_idx] := temp;
        shuffle(all, min_int(all_size, 2), all_size);
        test_size := min_int(30, all_size);
        { Verify results against the direct cyclic sum for each chosen index. }
        for idx := 0 to test_size - 1 do
        begin
          i := all[idx];
          val := 0;
          for j := 0 to n - 1 do
          begin
            val := val + u64(a[j]) * b[(n + i - j) mod n] mod modulus;
            val := val mod modulus;
          end;
          Assert(a2[i] = val);
        end;
      end;
      FreeMem(a);
      FreeMem(b);
      FreeMem(a2);
      FreeMem(b2);
      FreeMem(all);
      elapsed := (GetTickCount64 - beg) / 1000.0;
      { Pad single-digit lg with an extra space so the timing column lines up
        (the two branches were previously identical). }
      if lg < 10 then
        Write(StdErr, lg, '  ')
      else
        Write(StdErr, lg, ' ');
      WriteLn(StdErr, Format('%.5fs OK', [elapsed]));
      Inc(lg);
    end;
    WriteLn(StdErr);
    ntt_free(ntt2);
  end;
end.