Author Topic: [SOLVED] Need help converting a C++ math program  (Read 4528 times)

jamie

Re: Need help converting a C++ math program
« Reply #60 on: December 10, 2025, 11:37:23 pm »
You may want to check the values in the given test run.

For example, in C:

int(1e9 + 7) gives you 1000000007.

In Pascal:

Int(1e9 + 7) gives you 1000000000, i.e. the 7 at the end is missing.

I get failures because of this.

Who is correct?

P.S.
  I just tried that in Google and it agrees with the C compiler, so FPC is generating the wrong result.


Jamie
« Last Edit: December 10, 2025, 11:42:48 pm by jamie »
The only true wisdom is knowing you know nothing

bytebites

Re: Need help converting a C++ math program
« Reply #61 on: December 11, 2025, 12:10:59 am »
This code gives the same result as Pascal when n_inv and mod are uint64.

Code: C
#include <iostream>
#include <cstdint>
using namespace std;

int main() {

  uint64_t val=0x6F1BF78C;
  uint32_t n_inv=0x3B7FFFFF;
  uint32_t mod=0x3b800001;

  uint32_t res = val + uint32_t(val) * n_inv * uint64_t(mod) >> 32;

  cout << std::hex << res << endl;

  return 0;
}

speter

Re: Need help converting a C++ math program
« Reply #62 on: December 11, 2025, 12:24:25 am »
Quote from: jamie
In Pascal

Int(1e9 + 7) gives you 1000000000 etc. it's missing the 7 at the end.

Code: Pascal
  1. var
  2.   a : integer;
  3. begin
  4.   a := int(1e9 +7);
  5.   memo1.append(format('a = %d',[a]));

I get an error (line 4 above):
Quote
Error: Incompatible types got "Double" expected "LongInt".

Changing the code to:
Code: Pascal
var
  a : integer;
begin
  a := trunc(1e9) +7;
  memo1.append(format('a = %d',[a]));
yields the correct result.

cheers
S.
I climbed mighty mountains, and saw that they were actually tiny foothills. :)

MathMan

Re: Need help converting a C++ math program
« Reply #63 on: December 11, 2025, 12:27:49 am »
Quote from: bytebites
This code gives the same result as Pascal when n_inv and mod are uint64.

Code: C
#include <iostream>
#include <cstdint>
using namespace std;

int main() {

  uint64_t val=0x6F1BF78C;
  uint32_t n_inv=0x3B7FFFFF;
  uint32_t mod=0x3b800001;

  uint32_t res = val + uint32_t(val) * n_inv * uint64_t(mod) >> 32;

  cout << std::hex << res << endl;

  return 0;
}

Interesting. Now, if you can explain why the version I have is giving a different result, and what I have to change in the Pascal version to also get $23fffff8 as a result, I'm a happy man. And of course eternally grateful ...

jamie

Re: Need help converting a C++ math program
« Reply #64 on: December 11, 2025, 12:40:38 am »
I also fixed it by doing this:

Trunc(Double(1e9)+7);

I don't know why I need to do this; it works correctly with the C compiler.
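
A possible explanation, offered as a sketch rather than a definitive diagnosis: if the compiler treats the literal 1e9 as single precision (this is an assumption about how FPC types the constant here), then 1e9 + 7 is rounded to the nearest representable single, and at that magnitude neighbouring singles are 64 apart. The C++ snippet below reproduces both results; the function names are made up for illustration.

```cpp
// Sketch: why 1e9 + 7 can lose the 7 in single precision.
// trunc_as_single / trunc_as_double are hypothetical names for illustration.

// Add in 32-bit float: a float has a 24-bit significand, so near 1e9
// representable values are 64 apart and 1e9 + 7 rounds back to 1e9.
constexpr long long trunc_as_single(float base, float addend) {
    float sum = base + addend;           // rounded to float precision
    return static_cast<long long>(sum);  // truncate toward zero
}

// Add in 64-bit double: a 53-bit significand holds 1000000007 exactly.
constexpr long long trunc_as_double(double base, double addend) {
    double sum = base + addend;
    return static_cast<long long>(sum);
}

// Compile-time checks (C++14 or later).
static_assert(trunc_as_single(1e9f, 7.0f) == 1000000000LL, "single drops the 7");
static_assert(trunc_as_double(1e9, 7.0) == 1000000007LL, "double keeps the 7");
```

This would be consistent with the workaround above: forcing the literal to Double before adding 7 keeps the sum exact.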

Jamie
The only true wisdom is knowing you know nothing

MathMan

Re: Need help converting a C++ math program
« Reply #65 on: December 11, 2025, 02:47:26 am »
Quote from: MathMan
I'm getting closer. I installed a C compiler and went on a debugging tour de force.

The issue is the reduce function. It looks as follows in C and Pascal:

Code: C
    template <bool strict = true>
    u32 reduce(u64 val) const {
        u32 res = val + u32(val) * n_inv * u64(mod) >> 32;
        if constexpr (strict)
            res = shrink(res);
        return res;
    }

Code: Pascal
function reduce(
  const val: u64;
  const strict: boolean=TRUE
):u32;inline;

begin
  Result := ( val + u32( val )*mt.n_inv*u64( mt.mod1 ) ) shr 32;
  if( strict ) then
    Result := shrink( Result );
end;

If the functions are entered with strict=false, val=$000000006F1BF78C, n_inv=$3B7FFFFF and mod=$3b800001, then the results are:

Code: Pascal
Pascal: $19D30000
C: $23fffff8

I already tried several modifications to the Pascal version, but I can't see how the C result could be generated  :o

Quoting myself here, as this is so funny. I solved it, and the correct translation to Pascal is:

Code: Pascal
Result := ( val + u32( val*mt.n_inv )*u64( mt.mod1 ) ) shr 32;

What's happening here is that multiplication binds more tightly than addition, and the multiplications are evaluated from left to right in C. The first product, u32( val )*n_inv, has two 32-bit operands, so it is computed in 32 bits, which in effect truncates the complete product to u32! This is then multiplied by mod, where the cast u64( mod ) widens the final product to 64 bits again; val is added to that, and the sum is shifted right.
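
The difference can be reproduced in C++ alone. The sketch below uses the values from this post: reduce_c mirrors the C++ expression, reduce_wide mimics a translation in which the first product is carried out in 64 bits (which is what the Pascal result suggests happened), and reduce_fixed spells out the truncation explicitly. The function names are illustrative and not taken from the original program.

```cpp
#include <cstdint>

using u32 = std::uint32_t;
using u64 = std::uint64_t;

// As in the C++ source: u32(val) * n_inv is a u32 * u32 product, so it wraps
// modulo 2^32 before u64(mod) widens the next multiplication.
// The outer parentheses are only for readability; + already binds tighter
// than >> in C++.
constexpr u32 reduce_c(u64 val, u32 n_inv, u32 mod) {
    return static_cast<u32>((val + u32(val) * n_inv * u64(mod)) >> 32);
}

// Same expression with the first product carried out in 64 bits,
// i.e. without the implicit wrap to 32 bits.
constexpr u32 reduce_wide(u64 val, u32 n_inv, u32 mod) {
    return static_cast<u32>((val + u64(u32(val)) * n_inv * u64(mod)) >> 32);
}

// Explicit truncation of the first product, as in the corrected Pascal line.
constexpr u32 reduce_fixed(u64 val, u32 n_inv, u32 mod) {
    return static_cast<u32>((val + u32(val * n_inv) * u64(mod)) >> 32);
}

// Compile-time checks against the two results quoted in the post.
static_assert(reduce_c(0x6F1BF78Cu, 0x3B7FFFFFu, 0x3B800001u) == 0x23FFFFF8u, "C++ result");
static_assert(reduce_wide(0x6F1BF78Cu, 0x3B7FFFFFu, 0x3B800001u) == 0x19D30000u, "untruncated result");
static_assert(reduce_fixed(0x6F1BF78Cu, 0x3B7FFFFFu, 0x3B800001u) == 0x23FFFFF8u, "explicit truncation");
```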

Yippee,
MathMan

srvaldez

Re: [SOLVED] Need help converting a C++ math program
« Reply #66 on: December 11, 2025, 12:08:44 pm »
Good day Mr. MathMan  :)
I applied your solution to my FreeBasic version and now it works in 64-bit as well  :D
I know that this is the Lazarus/FPC forum, but I think that posting the snippet involving your solution can be instructive.
This is what I had before the fix in function montgomery_reduce_strict and in function montgomery_reduce_nonstrict:
dim res as u32 = (value + ((cast(u32, value) * mt->n_inv) * cast(u64, mt->mod1))) shr 32

The fix:

dim res as u32 = (value + ((cast(u32, value * mt->n_inv)) * cast(u64, mt->mod1))) shr 32


srvaldez

Re: [SOLVED] Need help converting a C++ math program
« Reply #67 on: December 11, 2025, 12:49:30 pm »
In C++, the statement
 u32 res = val + u32(val) * n_inv * u64(mod) >> 32;
can be made unambiguous:
 u32 res = (val + u32(val) * n_inv * u64(mod)) >> 32;
But it also seems that the cast u32(val) extends to the expression that follows, as if it were written u32(val * mt->n_inv).
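
One way to read this observation: the cast itself does not reach over the product. Rather, multiplying two 32-bit unsigned values in C++ (assuming a platform where int is 32 bits wide, so no promotion to a wider type takes place) gives a 32-bit unsigned result that wraps modulo 2^32, and the low 32 bits of a product depend only on the low 32 bits of its factors. A small compile-time check, with illustrative names:

```cpp
#include <cstdint>

using u32 = std::uint32_t;
using u64 = std::uint64_t;

// Truncate first, then multiply in 32 bits (what the C++ expression does).
constexpr u32 narrow_then_mul(u64 val, u32 n_inv) {
    return u32(val) * n_inv;
}

// Multiply in 64 bits, then truncate (what the fixed translations spell out).
constexpr u32 mul_then_narrow(u64 val, u32 n_inv) {
    return u32(val * n_inv);
}

// Both orders agree, here for the values from the thread and for an
// extreme case where every bit of both operands is set.
static_assert(narrow_then_mul(0x6F1BF78Cu, 0x3B7FFFFFu) ==
              mul_then_narrow(0x6F1BF78Cu, 0x3B7FFFFFu), "thread values");
static_assert(narrow_then_mul(0xFFFFFFFFFFFFFFFFull, 0xFFFFFFFFu) ==
              mul_then_narrow(0xFFFFFFFFFFFFFFFFull, 0xFFFFFFFFu), "all bits set");
```

In a language that evaluates the 32-bit product in a wider type, the truncation has to be written out, which is what both fixes do.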

srvaldez

Re: [SOLVED] Need help converting a C++ math program
« Reply #68 on: December 11, 2025, 01:52:33 pm »
I had an FPC translation of the C++ NTT code, but it would crash after printing one line. I applied the fix to the code and now it works.
Code: Pascal
program NTTMontgomery;

{$mode objfpc}{$H+}
{$ASSERTIONS ON}

uses
  SysUtils;

type
  u32 = LongWord;
  u64 = QWord;
  pu32 = ^u32;
{ ====== Montgomery Arithmetic ====== }

type
  Montgomery = record
    modulus: u32;    { mod }
    mod2: u32;       { 2 * mod }
    n_inv: u32;      { n_inv * mod == -1 (mod 2^32) }
    r: u32;          { 2^32 % mod }
    r2: u32;         { (2^32)^2 % mod }
  end;

function min_u32(a, b: u32): u32; inline;
begin
  if a < b then
    Result := a
  else
    Result := b;
end;

function montgomery_shrink(const mt: Montgomery; val: u32): u32; inline;
begin
  Result := min_u32(val, val - mt.modulus);
end;

function montgomery_shrink2(const mt: Montgomery; val: u32): u32; inline;
begin
  Result := min_u32(val, val - mt.mod2);
end;

function montgomery_reduce_strict(const mt: Montgomery; val: u64): u32; inline;
var
  res: u32;
begin
  res := (val + u32(val * mt.n_inv) * u64(mt.modulus)) shr 32;
  res := montgomery_shrink(mt, res);
  Result := res;
end;

function montgomery_reduce_nonstrict(const mt: Montgomery; val: u64): u32; inline;
var
  res: u32;
begin
  res := (val + u32(val * mt.n_inv) * u64(mt.modulus)) shr 32;
  Result := res;
end;

function montgomery_mul_strict(const mt: Montgomery; a, b: u32): u32; inline;
begin
  Result := montgomery_reduce_strict(mt, u64(a) * b);
end;

function montgomery_mul_nonstrict(const mt: Montgomery; a, b: u32): u32; inline;
begin
  Result := montgomery_reduce_nonstrict(mt, u64(a) * b);
end;

{ Montgomery power with template-like behavior controlled by function parameters }
function montgomery_power(const mt: Montgomery; b, e: u32; input_in_space, output_in_space: Boolean): u32;
var
  r: u32;
  b_temp: u32;
  e_temp: u32;
begin
  b_temp := b;
  e_temp := e;

  if not input_in_space then
    b_temp := montgomery_mul_nonstrict(mt, b_temp, mt.r2);

  if output_in_space then
    r := mt.r
  else
    r := 1;

  while e_temp > 0 do
  begin
    if (e_temp and 1) <> 0 then
      r := montgomery_mul_nonstrict(mt, r, b_temp);
    b_temp := montgomery_mul_nonstrict(mt, b_temp, b_temp);
    e_temp := e_temp shr 1;
  end;

  Result := montgomery_shrink(mt, r);
end;

procedure montgomery_init(var mt: Montgomery; modulus: u32);
var
  i: Integer;
begin
  Assert(modulus mod 2 = 1);
  Assert(modulus < (u32(1) shl 30));

  mt.modulus := modulus;
  mt.mod2 := 2 * modulus;
  mt.n_inv := 1;
  for i := 0 to 4 do
  begin
    mt.n_inv := mt.n_inv * (2 + mt.n_inv * modulus);
  end;
  mt.r := (u64(1) shl 32) mod modulus;
  mt.r2 := u64(mt.r) * mt.r mod modulus;
end;

{ ====== NTT ====== }

type
  NTT = record
    modulus: u32;
    pr_root: u32;
    mt: Montgomery;
    wd: ^u32;
    wrd: ^u32;
    lg_max: Integer;  { Maximum lg value (size of wd and wrd arrays) }
  end;

{ Helper function to count trailing zeros (equivalent to __builtin_ctz) }
function count_trailing_zeros(x: u32): Integer; inline;
var
  count: Integer;
begin
  if x = 0 then
  begin
    Result := 32;
    Exit;
  end;
  count := 0;
  while (x and 1) = 0 do
  begin
    x := x shr 1;
    Inc(count);
  end;
  Result := count;
end;

{ Helper function to check if all factors satisfy the condition }
function check_all_factors(const mt: Montgomery; candidate, modulus: u32; factors: pu32; num_factors: Integer): Boolean;
var
  i: Integer;
begin
  for i := 0 to num_factors - 1 do
  begin
    if montgomery_power(mt, candidate, (modulus - 1) div factors[i], False, False) = 1 then
    begin
      Result := False;
      Exit;
    end;
  end;
  Result := True;
end;

function ntt_find_pr_root(modulus: u32; const mt: Montgomery): u32;
var
  factors: array[0..31] of u32;
  num_factors: Integer;
  n: u32;
  i: u32;
begin
  num_factors := 0;
  n := modulus - 1;

  i := 2;
  while u64(i) * i <= n do
  begin
    if n mod i = 0 then
    begin
      factors[num_factors] := i;
      Inc(num_factors);
      repeat
        n := n div i;
      until n mod i <> 0;
    end;
    Inc(i);
  end;
  if n > 1 then
  begin
    factors[num_factors] := n;
    Inc(num_factors);
  end;

  for i := 2 to modulus - 1 do
  begin
    if check_all_factors(mt, i, modulus, @factors[0], num_factors) then
    begin
      Result := i;
      Exit;
    end;
  end;

  Assert(False, 'primitive root not found');
  Result := 0;
end;

procedure ntt_init(var ntt: NTT; modulus: u32);
var
  mt: ^Montgomery;
  lg: Integer;
  i: Integer;
  a, b, f: u32;
begin
  ntt.modulus := modulus;
  montgomery_init(ntt.mt, modulus);

  mt := @ntt.mt;
  ntt.pr_root := ntt_find_pr_root(modulus, mt^);

  lg := count_trailing_zeros(modulus - 1);
  ntt.lg_max := lg;
  GetMem(ntt.wd, lg * SizeOf(u32));
  GetMem(ntt.wrd, lg * SizeOf(u32));
  FillChar(ntt.wd^, lg * SizeOf(u32), 0);
  FillChar(ntt.wrd^, lg * SizeOf(u32), 0);

  for i := 0 to lg - 2 do
  begin
    a := montgomery_power(mt^, ntt.pr_root, (modulus - 1) shr (i + 2), False, True);
    b := montgomery_power(mt^, ntt.pr_root, ((modulus - 1) shr (i + 1)) * ((1 shl i) - 1), False, True);
    f := montgomery_mul_strict(mt^, a, montgomery_power(mt^, b, modulus - 2, True, True));
    ntt.wd[i] := f;
    ntt.wrd[i] := montgomery_power(mt^, f, modulus - 2, True, True);
  end;
end;

procedure ntt_free(var ntt: NTT);
begin
  FreeMem(ntt.wd);
  FreeMem(ntt.wrd);
  ntt.wd := nil;
  ntt.wrd := nil;
end;

{ Butterfly operation - non-transposed, trivial }
procedure butterfly_x2_forward_trivial(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  a1, b1: u32;
begin
  a1 := montgomery_shrink2(mt, a);
  b1 := montgomery_shrink2(mt, b);
  a := a1 + b1;
  b := a1 + mt.mod2 - b1;
end;

{ Butterfly operation - non-transposed, non-trivial }
procedure butterfly_x2_forward(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  a1, b1: u32;
begin
  a1 := montgomery_shrink2(mt, a);
  b1 := montgomery_mul_nonstrict(mt, b, w);
  a := a1 + b1;
  b := a1 + mt.mod2 - b1;
end;

{ Butterfly operation - transposed, trivial }
procedure butterfly_x2_inverse_trivial(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  a2, b2: u32;
begin
  a2 := montgomery_shrink2(mt, a + b);
  b2 := montgomery_shrink2(mt, a + mt.mod2 - b);
  a := a2;
  b := b2;
end;

{ Butterfly operation - transposed, non-trivial }
procedure butterfly_x2_inverse(var a, b: u32; w: u32; const mt: Montgomery); inline;
var
  a2, b2: u32;
begin
  a2 := montgomery_shrink2(mt, a + b);
  b2 := montgomery_mul_nonstrict(mt, a + mt.mod2 - b, w);
  a := a2;
  b := b2;
end;

{ Transform auxiliary function - forward, trivial }
procedure transform_aux_forward_trivial(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  mt: ^Montgomery;
  j: Integer;
begin
  mt := @ntt.mt;
  for j := 0 to (1 shl k) - 1 do
  begin
    butterfly_x2_forward_trivial(data[i + j], data[i + (1 shl k) + j], wi, mt^);
  end;
  wi := montgomery_mul_strict(mt^, wi, ntt.wd[count_trailing_zeros(not i shr (k + 1))]);
end;

{ Transform auxiliary function - forward, non-trivial }
procedure transform_aux_forward(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  mt: ^Montgomery;
  j: Integer;
begin
  mt := @ntt.mt;
  for j := 0 to (1 shl k) - 1 do
  begin
    butterfly_x2_forward(data[i + j], data[i + (1 shl k) + j], wi, mt^);
  end;
  wi := montgomery_mul_strict(mt^, wi, ntt.wd[count_trailing_zeros(not i shr (k + 1))]);
end;

{ Transform auxiliary function - inverse, trivial }
procedure transform_aux_inverse_trivial(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  mt: ^Montgomery;
  j: Integer;
begin
  mt := @ntt.mt;
  for j := 0 to (1 shl k) - 1 do
  begin
    butterfly_x2_inverse_trivial(data[i + j], data[i + (1 shl k) + j], wi, mt^);
  end;
  wi := montgomery_mul_strict(mt^, wi, ntt.wrd[count_trailing_zeros(not i shr (k + 1))]);
end;

{ Transform auxiliary function - inverse, non-trivial }
procedure transform_aux_inverse(const ntt: NTT; k, i: Integer; data: pu32; var wi: u32);
var
  mt: ^Montgomery;
  j: Integer;
begin
  mt := @ntt.mt;
  for j := 0 to (1 shl k) - 1 do
  begin
    butterfly_x2_inverse(data[i + j], data[i + (1 shl k) + j], wi, mt^);
  end;
  wi := montgomery_mul_strict(mt^, wi, ntt.wrd[count_trailing_zeros(not i shr (k + 1))]);
end;

{ input in [0, 4 * mod) }
{ output in [0, 4 * mod) }
procedure ntt_transform_forward(const ntt: NTT; lg: Integer; data: pu32);
var
  mt: ^Montgomery;
  k, i: Integer;
  wi: u32;
begin
  mt := @ntt.mt;
  for k := lg - 1 downto 0 do
  begin
    wi := mt^.r;
    transform_aux_forward_trivial(ntt, k, 0, data, wi);
    i := 1 shl (k + 1);
    while i < (1 shl lg) do
    begin
      transform_aux_forward(ntt, k, i, data, wi);
      Inc(i, 1 shl (k + 1));
    end;
  end;
end;

{ input in [0, 2 * mod) }
{ output in [0, mod) }
procedure ntt_transform_inverse(const ntt: NTT; lg: Integer; data: pu32);
var
  mt: ^Montgomery;
  k, i: Integer;
  wi, f: u32;
begin
  mt := @ntt.mt;
  for k := 0 to lg - 1 do
  begin
    wi := mt^.r;
    transform_aux_inverse_trivial(ntt, k, 0, data, wi);
    i := 1 shl (k + 1);
    while i < (1 shl lg) do
    begin
      transform_aux_inverse(ntt, k, i, data, wi);
      Inc(i, 1 shl (k + 1));
    end;
  end;

  f := montgomery_power(mt^, (ntt.modulus + 1) shr 1, lg, False, True);
  for i := 0 to (1 shl lg) - 1 do
  begin
    data[i] := montgomery_mul_strict(mt^, data[i], f);
  end;
end;

{ input in [0, 2 * mod) }
{ output in [0, mod) }
{ with multiplication by scale factor }
procedure ntt_transform_inverse_scaled(const ntt: NTT; lg: Integer; data: pu32; sc: u32);
var
  mt: ^Montgomery;
  k, i: Integer;
  wi, f: u32;
begin
  mt := @ntt.mt;
  for k := 0 to lg - 1 do
  begin
    wi := mt^.r;
    transform_aux_inverse_trivial(ntt, k, 0, data, wi);
    i := 1 shl (k + 1);
    while i < (1 shl lg) do
    begin
      transform_aux_inverse(ntt, k, i, data, wi);
      Inc(i, 1 shl (k + 1));
    end;
  end;

  f := montgomery_power(mt^, (ntt.modulus + 1) shr 1, lg, False, True);
  f := montgomery_mul_strict(mt^, f, montgomery_mul_nonstrict(mt^, mt^.r2, sc));
  for i := 0 to (1 shl lg) - 1 do
  begin
    data[i] := montgomery_mul_strict(mt^, data[i], f);
  end;
end;

procedure ntt_convolve_cyclic(const ntt: NTT; lg: Integer; a, b: pu32);
var
  mt: ^Montgomery;
  i: Integer;
begin
  ntt_transform_forward(ntt, lg, a);
  ntt_transform_forward(ntt, lg, b);
  mt := @ntt.mt;
  for i := 0 to (1 shl lg) - 1 do
  begin
    a[i] := montgomery_mul_nonstrict(mt^, montgomery_shrink2(mt^, a[i]), montgomery_shrink2(mt^, b[i]));
  end;
  ntt_transform_inverse_scaled(ntt, lg, a, mt^.r);
end;

{ ====== Helper functions for main ====== }

{ Simple random number generator state (replacing std::mt19937) }
var
  rnd_state: LongWord = 1;

procedure rnd_seed(seed: LongWord);
begin
  rnd_state := seed;
end;

function rnd_next: u32;
begin
  { Simple LCG }
  rnd_state := rnd_state * 1664525 + 1013904223;
  Result := rnd_state;
end;

{ Helper to find maximum of two integers }
function max_int(a, b: Integer): Integer; inline;
begin
  if a > b then
    Result := a
  else
    Result := b;
end;

{ Helper to find minimum of two integers }
function min_int(a, b: Integer): Integer; inline;
begin
  if a < b then
    Result := a
  else
    Result := b;
end;

{ Shuffle function (Fisher-Yates shuffle) }
procedure shuffle(arr: PInteger; start, finish: Integer);
var
  i, j: Integer;
  temp: Integer;
begin
  for i := finish - 1 downto start + 1 do
  begin
    j := start + (rnd_next mod (i - start + 1));
    temp := arr[i];
    arr[i] := arr[j];
    arr[j] := temp;
  end;
end;

{ ====== main ====== }

var
  prime_list: array[0..8] of u32 = (
    3,
    5,
    17,
    257,
    65537,
    1000000007,
    1000000009,
    1000001329,
    998244353
  );
  prime_list_size: Integer;
  p_idx: Integer;
  modulus: u32;
  ntt2: NTT;
  lg: Integer;
  beg: QWord;
  num_iterations: Integer;
  k: Integer;
  a, b, a2, b2: pu32;
  i, j: Integer;
  all: PInteger;
  all_size: Integer;
  swap_idx: Integer;
  temp: Integer;
  test_size: Integer;
  idx: Integer;
  val: u32;
  elapsed: Double;

begin
  prime_list_size := Length(prime_list);

  for p_idx := 0 to prime_list_size - 1 do
  begin
    modulus := prime_list[p_idx];
    WriteLn(StdErr, 'testing mod: ', modulus);

    rnd_seed(5489);  { Default seed for mt19937 }
    ntt_init(ntt2, modulus);

    lg := 0;
    while (modulus - 1) mod (1 shl lg) = 0 do
    begin
      beg := GetTickCount64;
      num_iterations := max_int(1, 1000000 div (1 shl lg));

      for k := 0 to num_iterations - 1 do
      begin
        { Generate random vectors a and b }
        GetMem(a, (1 shl lg) * SizeOf(u32));
        GetMem(b, (1 shl lg) * SizeOf(u32));
        for i := 0 to (1 shl lg) - 1 do
        begin
          a[i] := rnd_next mod modulus;
          b[i] := rnd_next mod modulus;
        end;

        { Create copies a2 and b2 }
        GetMem(a2, (1 shl lg) * SizeOf(u32));
        GetMem(b2, (1 shl lg) * SizeOf(u32));
        Move(a^, a2^, (1 shl lg) * SizeOf(u32));
        Move(b^, b2^, (1 shl lg) * SizeOf(u32));

        ntt_convolve_cyclic(ntt2, lg, a2, b2);

        { Create and initialize the 'all' array }
        all_size := 1 shl lg;
        GetMem(all, all_size * SizeOf(Integer));
        for i := 0 to all_size - 1 do
        begin
          all[i] := i;
        end;

        { Swap last element with element at index min(size-1, 1) }
        swap_idx := min_int(all_size - 1, 1);
        temp := all[all_size - 1];
        all[all_size - 1] := all[swap_idx];
        all[swap_idx] := temp;

        { Shuffle from index min(size, 2) to end }
        shuffle(all, min_int(all_size, 2), all_size);

        { Resize to min(30, size) }
        test_size := min_int(30, all_size);

        { Verify results }
        for idx := 0 to test_size - 1 do
        begin
          i := all[idx];
          val := 0;
          for j := 0 to (1 shl lg) - 1 do
          begin
            val := val + u64(a[j]) * b[((1 shl lg) + i - j) mod (1 shl lg)] mod modulus;
            val := val mod modulus;
          end;
          Assert(a2[i] = val);
        end;

        FreeMem(a);
        FreeMem(b);
        FreeMem(a2);
        FreeMem(b2);
        FreeMem(all);
      end;

      elapsed := (GetTickCount64 - beg) / 1000.0;
      if lg < 10 then
        Write(StdErr, lg, '  ')
      else
        Write(StdErr, lg, ' ');
      WriteLn(StdErr, Format('%.5fs  OK', [elapsed]));

      Inc(lg);
    end;
    WriteLn(StdErr);

    ntt_free(ntt2);
  end;
end.
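
The n_inv loop in montgomery_init above can be checked in isolation. The following C++ sketch (the function name is illustrative; it assumes 32-bit unsigned wraparound, as in the original C++) repeats the same five steps and verifies the invariant stated in the record comment, n_inv * mod == -1 (mod 2^32).

```cpp
#include <cstdint>

using u32 = std::uint32_t;

// Five Newton-style steps. If n_inv * mod == -1 holds modulo 2^k before a
// step, it holds modulo 2^(2k) afterwards, so the number of correct low bits
// goes 1 -> 2 -> 4 -> 8 -> 16 -> 32. Requires an odd modulus.
constexpr u32 compute_n_inv(u32 mod) {
    u32 n_inv = 1;
    for (int i = 0; i < 5; ++i) {
        n_inv *= 2u + n_inv * mod;  // all arithmetic wraps modulo 2^32
    }
    return n_inv;
}

// 998244353 is $3b800001, the modulus used in the reduce example earlier
// in the thread; the expected n_inv is the value quoted there.
static_assert(compute_n_inv(998244353u) == 0x3B7FFFFFu, "n_inv from the thread");
static_assert(compute_n_inv(998244353u) * 998244353u == 0xFFFFFFFFu, "-1 mod 2^32");
static_assert(compute_n_inv(1000000007u) * 1000000007u == 0xFFFFFFFFu, "-1 mod 2^32");
```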

 
