Intel AVX:256-Bit-Version des Punktprodukts für Gleitkommavariablen mit doppelter Genauigkeit

Intel AVX:256-Bit-Version des Punktprodukts für Gleitkommavariablen mit doppelter Genauigkeit


Die Intel Advanced Vector Extensions (AVX) bieten kein Punktprodukt in der 256-Bit-Version (YMM-Register) für Gleitkommavariablen mit doppelter Genauigkeit . Das "Warum?" Frage wurden in einem anderen Forum (hier) und auf Stack Overflow (hier) sehr kurz behandelt. Aber ich stehe vor der Frage, wie ich diese fehlende Anweisung auf effiziente Weise durch andere AVX-Anweisungen ersetzen kann?


Das Punktprodukt in der 256-Bit-Version existiert für Gleitkommavariablen mit einfacher Genauigkeit (Referenz hier):


 __m256 _mm256_dp_ps(__m256 m1, __m256 m2, const int mask);

Die Idee ist, ein effizientes Äquivalent für diese fehlende Anweisung zu finden:


 __m256d _mm256_dp_pd(__m256d m1, __m256d m2, const int mask);

Genauer gesagt, der Code, den ich von __m128 transformieren möchte (vier Floats) bis __m256d (4 Doppelte) verwenden Sie die folgenden Anweisungen:


   __m128 val0 = ...; // Four float values
__m128 val1 = ...; //
__m128 val2 = ...; //
__m128 val3 = ...; //
__m128 val4 = ...; //
__m128 res = _mm_or_ps( _mm_dp_ps(val1, val0, 0xF1),
_mm_or_ps( _mm_dp_ps(val2, val0, 0xF2),
_mm_or_ps( _mm_dp_ps(val3, val0, 0xF4),
_mm_dp_ps(val4, val0, 0xF8) )));

Das Ergebnis dieses Codes ist ein _m128 Vektor aus vier Gleitkommazahlen, die die Ergebnisse der Skalarprodukte zwischen val1 enthalten und val0 , val2 und val0 , val3 und val0 , val4 und val0 .


Vielleicht kann das Hinweise für die Vorschläge geben?


Antworten:


Ich würde eine 4*doppelte Multiplikation verwenden, dann eine hadd (was leider nur 2*2 Floats in der oberen und unteren Hälfte hinzufügt), extrahiere die obere Hälfte (ein Shuffle sollte gleich funktionieren, vielleicht schneller) und füge es der unteren Hälfte hinzu.


Das Ergebnis befindet sich in den niedrigen 64 Bit von dotproduct .


__m256d xy = _mm256_mul_pd( x, y );
__m256d temp = _mm256_hadd_pd( xy, xy );
__m128d hi128 = _mm256_extractf128_pd( temp, 1 );
__m128d dotproduct = _mm_add_pd( (__m128d)temp, hi128 );

Bearbeiten:

Nach einer Idee von Norbert P. habe ich diese Version erweitert, um 4 Punktprodukte auf einmal zu machen.


__m256d xy0 = _mm256_mul_pd( x[0], y[0] );
__m256d xy1 = _mm256_mul_pd( x[1], y[1] );
__m256d xy2 = _mm256_mul_pd( x[2], y[2] );
__m256d xy3 = _mm256_mul_pd( x[3], y[3] );
// low to high: xy00+xy01 xy10+xy11 xy02+xy03 xy12+xy13
__m256d temp01 = _mm256_hadd_pd( xy0, xy1 );
// low to high: xy20+xy21 xy30+xy31 xy22+xy23 xy32+xy33
__m256d temp23 = _mm256_hadd_pd( xy2, xy3 );
// low to high: xy02+xy03 xy12+xy13 xy20+xy21 xy30+xy31
__m256d swapped = _mm256_permute2f128_pd( temp01, temp23, 0x21 );
// low to high: xy00+xy01 xy10+xy11 xy22+xy23 xy32+xy33
__m256d blended = _mm256_blend_pd(temp01, temp23, 0b1100);
__m256d dotproduct = _mm256_add_pd( swapped, blended );

Einige Code-Antworten


 __m256 _mm256_dp_ps(__m256 m1, __m256 m2, const int mask);
 __m256d _mm256_dp_pd(__m256d m1, __m256d m2, const int mask);
   __m128 val0 = ...;
// Four float values __m128 val1 = ...;
// __m128 val2 = ...;
// __m128 val3 = ...;
// __m128 val4 = ...;
//
__m128 res = _mm_or_ps( _mm_dp_ps(val1, val0, 0xF1),
_mm_or_ps( _mm_dp_ps(val2, val0, 0xF2),
_mm_or_ps( _mm_dp_ps(val3, val0, 0xF4), _mm_dp_ps(val4, val0, 0xF8) )));
__m256d xy = _mm256_mul_pd( x, y );
__m256d temp = _mm256_hadd_pd( xy, xy );
__m128d hi128 = _mm256_extractf128_pd( temp, 1 );
__m128d dotproduct = _mm_add_pd( (__m128d)temp, hi128 );
__m256d xy0 = _mm256_mul_pd( x[0], y[0] );
__m256d xy1 = _mm256_mul_pd( x[1], y[1] );
__m256d xy2 = _mm256_mul_pd( x[2], y[2] );
__m256d xy3 = _mm256_mul_pd( x[3], y[3] );
// low to high: xy00+xy01 xy10+xy11 xy02+xy03 xy12+xy13 __m256d temp01 = _mm256_hadd_pd( xy0, xy1 );
// low to high: xy20+xy21 xy30+xy31 xy22+xy23 xy32+xy33 __m256d temp23 = _mm256_hadd_pd( xy2, xy3 );
// low to high: xy02+xy03 xy12+xy13 xy20+xy21 xy30+xy31 __m256d swapped = _mm256_permute2f128_pd( temp01, temp23, 0x21 );
// low to high: xy00+xy01 xy10+xy11 xy22+xy23 xy32+xy33 __m256d blended = _mm256_blend_pd(temp01, temp23, 0b1100);
__m256d dotproduct = _mm256_add_pd( swapped, blended );
__m256d xy = _mm256_mul_pd( x, y );
__m256d zw = _mm256_mul_pd( z, w );
__m256d temp = _mm256_hadd_pd( xy, zw );
__m128d hi128 = _mm256_extractf128_pd( temp, 1 );
__m128d dotproduct = _mm_add_pd( (__m128d)temp, hi128 );
// both elements = dot(x,y) __m128d dot1(__m256d x, __m256d y) {
__m256d xy = _mm256_mul_pd(x, y);
__m128d xylow = _mm256_castps256_pd128(xy);
// (__m128d)cast isn't portable
__m128d xyhigh = _mm256_extractf128_pd(xy, 1);
__m128d sum1 = _mm_add_pd(xylow, xyhigh);
__m128d swapped = _mm_shuffle_pd(sum1, sum1, 0b01);
// or unpackhi
__m128d dotproduct = _mm_add_pd(sum1, swapped);
return dotproduct;
}
/*  Norbert's version, for an Intel CPU:
__m256d temp = _mm256_hadd_pd( xy, zw );
// 2 shuffle + 1 add
__m128d hi128 = _mm256_extractf128_pd( temp, 1 );
// 1 shuffle (lane crossing, higher latency)
__m128d dotproduct = _mm_add_pd( (__m128d)temp, hi128 );
// 1 add
// 3 shuffle + 2 add */