program avxmatrixmult;
{$mode objfpc}
{$ASMMODE INTEL}
{$OPTIMIZATION LEVEL3}
uses
SysUtils,
Math,
Windows,
Classes;
const
// Matrices are square (MATRIX_SIZE x MATRIX_SIZE) and stored row-major in flat arrays.
// The AVX kernel in AVXThreadWorker steps through each row 8 doubles at a time using
// aligned 32-byte loads, so it relies on MATRIX_SIZE being a multiple of 8.
MATRIX_SIZE = 1000;
// Number of doubles in one matrix.
MATRIX_ELEMENTS = MATRIX_SIZE * MATRIX_SIZE;
// Byte alignment of the buffers returned by AlignedArrayAlloc (one YMM register = 32 bytes).
ALIGN_SIZE = 32;
// Absolute tolerance used by CheckResults when comparing the two result matrices.
EPSILON = 1e-6;
// Number of worker threads started by AVXMatrixMultParallel.
THREAD_COUNT = 6;
type
// Declared with a single element but indexed well past it through PAlignDoubleArray;
// the real length is whatever was passed to AlignedArrayAlloc.
TAlignDoubleArray = array[0..0] of Double;
PAlignDoubleArray = ^TAlignDoubleArray;
// Work description handed to one AVXThreadWorker thread. The worker reads these
// fields from assembly via TThreadData.<field> offsets and loads start_row/end_row
// as 32-bit values.
TThreadData = record
a: PAlignDoubleArray;             // left operand, row-major
b_transposed: PAlignDoubleArray;  // transposed right operand, row-major
c: PAlignDoubleArray;             // output matrix, row-major
start_row, end_row: Integer;      // rows [start_row, end_row) of c are computed
end;
PThreadData = ^TThreadData;
var
// a, b: input matrices; b_transposed: transpose of b fed to the AVX path;
// c_avx / c_naive: products computed by the AVX and the naive multiplier.
a, b, b_transposed, c_avx, c_naive: PAlignDoubleArray;
{------------------- Memory Management --------------------}
function AlignedArrayAlloc(size: Integer): PAlignDoubleArray;
{ Allocates room for `size` doubles and returns a pointer aligned to ALIGN_SIZE bytes.
  The pointer obtained from GetMem is stored in the pointer-sized slot directly below
  the returned address so that AlignedArrayFree can pass it back to FreeMem. }
var
  rawBlock: Pointer;
  addr, misalignment: PtrUInt;
begin
  // Over-allocate: payload + worst-case alignment padding + one slot for the raw pointer.
  GetMem(rawBlock, size * SizeOf(Double) + ALIGN_SIZE - 1 + SizeOf(Pointer));

  // Skip the bookkeeping slot, then round up to the next ALIGN_SIZE boundary.
  addr := PtrUInt(rawBlock) + SizeOf(Pointer);
  misalignment := addr mod ALIGN_SIZE;
  if misalignment <> 0 then
    addr := addr + (ALIGN_SIZE - misalignment);

  // Remember where the block really starts, right below the aligned payload.
  PPointer(addr - SizeOf(Pointer))^ := rawBlock;
  Result := PAlignDoubleArray(addr);
end;
procedure AlignedArrayFree(arr: PAlignDoubleArray);
{ Releases a buffer obtained from AlignedArrayAlloc; a nil argument is ignored. }
var
  rawSlot: PPointer;
begin
  if arr = nil then
    Exit;
  // The pointer originally returned by GetMem sits in the slot just below the array.
  rawSlot := PPointer(PByte(arr) - SizeOf(Pointer));
  FreeMem(rawSlot^);
end;
{------------------- AVX Support Check --------------------}
function IsAVXSupported: Boolean;
{ Returns True when the AVX multiplication kernel of this program can run:
    - the CPU reports AVX and FMA3 (the kernel uses vfmadd231pd),
    - the OS has enabled XSAVE (OSXSAVE), and
    - XCR0 shows that the OS saves/restores both XMM and YMM state.
  Testing the AVX CPUID bit alone is not sufficient: without OS support for the
  YMM state, or on an AVX-only CPU without FMA, the kernel raises an
  illegal-instruction fault. }
const
  CPUID1_ECX_FMA     = 1 shl 12;
  CPUID1_ECX_OSXSAVE = 1 shl 27;
  CPUID1_ECX_AVX     = 1 shl 28;
  REQUIRED_FEATURES  = CPUID1_ECX_FMA or CPUID1_ECX_OSXSAVE or CPUID1_ECX_AVX;
  XCR0_SSE_AVX_STATE = $6;  // bit 1 = XMM state, bit 2 = YMM state
var
  _ecx, _xcr0: LongWord;
begin
  Result := False;

  // CPUID leaf 1: feature flags in ECX. CPUID overwrites EAX, EBX, ECX and EDX;
  // RBX is callee-saved, so it is named in the register list to make the
  // compiler preserve it.
  asm
    mov eax, 1
    xor ecx, ecx
    cpuid
    mov _ecx, ecx
  end ['rax', 'rbx', 'rcx', 'rdx'];

  // XGETBV may only be executed once OSXSAVE is known to be set.
  if (_ecx and REQUIRED_FEATURES) <> REQUIRED_FEATURES then
    Exit;

  // Read XCR0 (ECX = 0) into EDX:EAX. The opcode is emitted as raw bytes so the
  // code also assembles with assemblers that lack the xgetbv mnemonic.
  asm
    xor ecx, ecx
    db $0F, $01, $D0        // xgetbv
    mov _xcr0, eax
  end ['rax', 'rcx', 'rdx'];

  Result := (_xcr0 and XCR0_SSE_AVX_STATE) = XCR0_SSE_AVX_STATE;
end;
{------------------- Matrix Transposition -----------------}
procedure TransposeMatrix(const src, dst: PAlignDoubleArray);
{ Writes the transpose of the MATRIX_SIZE x MATRIX_SIZE row-major matrix `src`
  into `dst`: dst[col][row] := src[row][col]. }
var
  row, col: Integer;
  srcIdx, dstIdx: Integer;
begin
  srcIdx := 0;                      // walks src sequentially, row by row
  for row := 0 to MATRIX_SIZE - 1 do
  begin
    dstIdx := row;                  // walks down column `row` of dst
    for col := 0 to MATRIX_SIZE - 1 do
    begin
      dst^[dstIdx] := src^[srcIdx];
      Inc(srcIdx);
      Inc(dstIdx, MATRIX_SIZE);
    end;
  end;
end;
{------------------- Thread Worker ------------------------}
function AVXThreadWorker(Data: Pointer): PtrInt;
{ Thread entry point: computes rows [start_row, end_row) of C = A * B, where B is
  supplied transposed so that every C[i][j] is a dot product of two contiguous rows.

  Data   : pointer to a TThreadData record (a, b_transposed, c, start_row, end_row).
  Result : always 0.

  Requirements (not checked here):
    - CPU and OS support AVX and FMA3 (vfmadd231pd is used).
    - a and b_transposed are 32-byte aligned and MATRIX_SIZE is a multiple of 8:
      the inner loop consumes 8 doubles per pass with aligned 32-byte loads and
      has no tail handling.

  Register roles inside the asm block:
    r14 = a, r13 = b_transposed, r12 = c
    edi = i (current row), esi = end_row, ebx = j (current column), ecx = k
    r8  = @A[i][0], r9 = @C[i][0], r11 = @Bt[j][0], rax = scratch
    ymm0/ymm1 = two independent accumulators, ymm2..ymm5 = loaded operands

  Only ymm0..ymm5 are used because xmm6..xmm15 are callee-saved in the Microsoft
  x64 calling convention. The callee-saved integer registers used here are
  pushed on entry and popped on exit. }
begin
  asm
    push rbx
    push r12
    push r13
    push r14
    push r15
    push rdi
    push rsi

    mov r15, Data
    mov r14, [r15 + TThreadData.a]
    mov r13, [r15 + TThreadData.b_transposed]
    mov r12, [r15 + TThreadData.c]
    mov edi, [r15 + TThreadData.start_row]
    mov esi, [r15 + TThreadData.end_row]

  @outer_loop:
    cmp edi, esi
    jge @exit

    // Row pointers for A[i] and C[i]; invariant across the j loop.
    mov eax, edi
    imul eax, MATRIX_SIZE
    shl rax, 3                        // rax = i * MATRIX_SIZE * 8 (byte offset)
    lea r8, [r14 + rax]               // r8 = @A[i][0]
    lea r9, [r12 + rax]               // r9 = @C[i][0]

    xor ebx, ebx                      // j = 0
  @middle_loop:
    // Row pointer for B_transposed[j].
    mov eax, ebx
    imul eax, MATRIX_SIZE
    shl rax, 3                        // rax = j * MATRIX_SIZE * 8 (byte offset)
    lea r11, [r13 + rax]              // r11 = @Bt[j][0]

    vxorpd ymm0, ymm0, ymm0           // accumulator for elements k .. k+3
    vxorpd ymm1, ymm1, ymm1           // accumulator for elements k+4 .. k+7
    xor ecx, ecx                      // k = 0
  @inner_loop:
    vmovapd ymm2, [r8 + rcx * 8]          // A[i][k .. k+3]
    vmovapd ymm3, [r11 + rcx * 8]         // Bt[j][k .. k+3]
    vfmadd231pd ymm0, ymm2, ymm3
    vmovapd ymm4, [r8 + rcx * 8 + 32]     // A[i][k+4 .. k+7]
    vmovapd ymm5, [r11 + rcx * 8 + 32]    // Bt[j][k+4 .. k+7]
    vfmadd231pd ymm1, ymm4, ymm5
    add ecx, 8                        // 8 doubles consumed per pass
    cmp ecx, MATRIX_SIZE
    jl @inner_loop

    // Horizontal sum of the 8 partial sums into the low lane of xmm0.
    vaddpd ymm0, ymm0, ymm1
    vextractf128 xmm1, ymm0, 1
    vaddpd xmm0, xmm0, xmm1
    vhaddpd xmm0, xmm0, xmm0

    vmovsd [r9 + rbx * 8], xmm0       // C[i][j] = dot product

    inc ebx
    cmp ebx, MATRIX_SIZE
    jl @middle_loop

    inc edi
    jmp @outer_loop

  @exit:
    vzeroupper                        // avoid AVX/SSE transition penalties in the caller
    pop rsi
    pop rdi
    pop r15
    pop r14
    pop r13
    pop r12
    pop rbx
  end;
  // The return value must be set from Pascal: a value left in EAX by the asm
  // block is not the function result of a non-assembler routine.
  Result := 0;
end;
{------------------- Parallel AVX Multiplication ----------}
procedure AVXMatrixMultParallel(a, b_transposed, c: PAlignDoubleArray);
{ Computes c = a * b, given b already transposed, by splitting the rows of c into
  THREAD_COUNT contiguous bands and running AVXThreadWorker on each band.

  a, b_transposed : MATRIX_SIZE x MATRIX_SIZE row-major inputs.
  c               : MATRIX_SIZE x MATRIX_SIZE row-major output.

  The last band also takes the rows left over when MATRIX_SIZE is not divisible
  by THREAD_COUNT. If a thread cannot be started, its band is computed on the
  calling thread so that c is always completely filled. The procedure returns
  only after every worker has finished and its thread handle has been released. }
var
  Threads: array[0..THREAD_COUNT - 1] of TThreadID;
  Data: array[0..THREAD_COUNT - 1] of TThreadData;
  i, rows_per_thread: Integer;
begin
  rows_per_thread := MATRIX_SIZE div THREAD_COUNT;

  for i := 0 to THREAD_COUNT - 1 do
  begin
    Data[i].a := a;
    Data[i].b_transposed := b_transposed;
    Data[i].c := c;
    Data[i].start_row := i * rows_per_thread;
    if i = THREAD_COUNT - 1 then
      Data[i].end_row := MATRIX_SIZE           // last band absorbs the remainder
    else
      Data[i].end_row := (i + 1) * rows_per_thread;

    Threads[i] := BeginThread(@AVXThreadWorker, @Data[i]);
    // BeginThread returns 0 on failure: do this band's work synchronously
    // instead of silently leaving part of c uncomputed.
    if Threads[i] = TThreadID(0) then
      AVXThreadWorker(@Data[i]);
  end;

  for i := 0 to THREAD_COUNT - 1 do
    if Threads[i] <> TThreadID(0) then
    begin
      WaitForThreadTerminate(Threads[i], 0);   // timeout 0 = wait without time limit
      CloseThread(Threads[i]);                 // release the thread handle
    end;
end;
{------------------- Naive Matrix Multiplication ---------}
procedure NaiveMatrixMult(a, b, c: PAlignDoubleArray);
{ Reference implementation: c = a * b using the textbook triple loop on
  MATRIX_SIZE x MATRIX_SIZE row-major matrices. Unlike the AVX path, b is used
  as-is, not transposed. }
var
  row, col, inner: Integer;
  rowBase, bIdx: Integer;
  acc: Double;
begin
  for row := 0 to MATRIX_SIZE - 1 do
  begin
    rowBase := row * MATRIX_SIZE;      // index of a[row][0] and c[row][0]
    for col := 0 to MATRIX_SIZE - 1 do
    begin
      acc := 0;
      bIdx := col;                     // index of b[0][col]; steps down the column
      for inner := 0 to MATRIX_SIZE - 1 do
      begin
        acc := acc + a^[rowBase + inner] * b^[bIdx];
        Inc(bIdx, MATRIX_SIZE);
      end;
      c^[rowBase + col] := acc;
    end;
  end;
end;
{------------------- Results Validation -------------------}
procedure CheckResults(avx, naive: PAlignDoubleArray);
{ Compares the two result matrices element by element in row-major order.
  Prints the first element whose absolute difference exceeds EPSILON and returns;
  prints 'Validation passed!' when no such element exists. }
var
  idx: Integer;
begin
  for idx := 0 to MATRIX_ELEMENTS - 1 do
    if Abs(avx^[idx] - naive^[idx]) > EPSILON then
    begin
      // Recover (row, column) from the flat index for the report.
      Writeln(Format('Error [%d,%d]: AVX=%.5f Naive=%.5f',
        [idx div MATRIX_SIZE, idx mod MATRIX_SIZE, avx^[idx], naive^[idx]]));
      Exit;
    end;
  Writeln('Validation passed!');
end;
{------------------- Main Program -------------------------}
{ Benchmark driver: fills two random matrices, multiplies them with the naive and
  the threaded AVX implementation, prints both timings and validates the results. }
var
  idx: Integer;
  TickStart, TickEnd: Int64;
begin
  // Refuse to run when the feature check fails.
  if not IsAVXSupported then
  begin
    Writeln('AVX not supported!');
    Halt(1);
  end;

  a := AlignedArrayAlloc(MATRIX_ELEMENTS);
  b := AlignedArrayAlloc(MATRIX_ELEMENTS);
  b_transposed := AlignedArrayAlloc(MATRIX_ELEMENTS);
  c_avx := AlignedArrayAlloc(MATRIX_ELEMENTS);
  c_naive := AlignedArrayAlloc(MATRIX_ELEMENTS);
  try
    // Fill both inputs with values in [0, 100), alternating a and b per element.
    Randomize;
    for idx := 0 to MATRIX_ELEMENTS - 1 do
    begin
      a^[idx] := Random * 100;
      b^[idx] := Random * 100;
    end;
    TransposeMatrix(b, b_transposed);

    // Reference run.
    TickStart := GetTickCount64;
    NaiveMatrixMult(a, b, c_naive);
    TickEnd := GetTickCount64;
    Writeln('FPC Naive Execution Time, ms: ', TickEnd - TickStart);

    // Threaded AVX run; the transposition above is not part of the timing.
    TickStart := GetTickCount64;
    AVXMatrixMultParallel(a, b_transposed, c_avx);
    TickEnd := GetTickCount64;
    Writeln('FPC AVX Execution Time, ms: ', TickEnd - TickStart);

    CheckResults(c_avx, c_naive);
  finally
    AlignedArrayFree(a);
    AlignedArrayFree(b);
    AlignedArrayFree(b_transposed);
    AlignedArrayFree(c_avx);
    AlignedArrayFree(c_naive);
  end;

  // Keep the console window open until Enter is pressed.
  Readln;
end.