SIMDを使用して列ごとの最大値を最適化する

Aug 15 2020

私はコードにかなりの時間を費やしたこの関数を持っており、可能であれば、vectorization-SIMD-compiler組み込み関数によってそれを最適化したいと思います。

基本的に、列の行列の最大値と最大値の位置を見つけて、それらを格納します。

  • val_ptr:入力行列:column-major(Fortranスタイル)n_rows-by-n_cols(通常はn_rows >> n_cols)
  • opt_pos_ptr:最大値の位置を格納する長さn_rowsのintベクトル。ゼロで埋められたエントリ。
  • max_ptr:最大値を格納する長さn_rowsのfloatベクトル。val_ptrの最初の列のコピーで埋められたエントリ
  • 関数は並列ループで呼び出されます
  • メモリ領域は重複しないことが保証されています
  • max_ptrを入力する必要はありません。現在は、簿記とメモリ割り当ての回避に使用されています。
  • 私はWindows10でMSVC、C ++ 17を使用しています。最新のIntelCPUを実行するためのものです。

テンプレートタイプがfloatまたはdoubleであることが意図されているコード:

template <typename eT>
find_max(const int n_cols, 
         const int n_rows, 
         const eT* val_ptr,
         int* opt_pos_ptr,
         eT* max_ptr){
    for (int col = 1; col < n_cols; ++col)
    {
        //Getting the pointer to the beginning of the column
        const auto* value_col = val_ptr + col * n_rows;
        //Looping over the rows
        for (int row = 0; row < n_rows; ++row)
        {
            //If the value is larger than the current maximum, we replace and we store its positions
            if (value_col[row] > max_ptr[row])
            {
                max_ptr[row] = value_col[row];
                opt_pos_ptr[row] = col;
            }
        }
    }
}

私がこれまでに試したこと:

  • 内側のループでOpenMP並列forを使用しようとしましたが、現在の使用法よりも少し大きい非常に大きな行でのみ何かが発生します。
  • 内側のループにifがあると、#pragma omp simdが機能しなくなり、それなしでは書き直すことができませんでした。

回答

3 AndreySemashev Aug 15 2020 at 21:55

投稿したコードサンプルに基づくと、垂直方向の最大値を計算したいようです。つまり、この場合、「列」は水平方向です。C / C ++では、要素の水平シーケンス(つまり、2つの隣接する要素がメモリ内の1つの要素の距離を持つ)は通常、行および垂直(2つの隣接する要素がメモリ内の行サイズの距離を持つ)-列と呼ばれます。以下の私の答えでは、行が水平で列が垂直である従来の用語を使用します。

また、簡潔にするために、マトリックス要素の1つの可能なタイプに焦点を当てます- float。基本的な考え方は、についても同じですdoubleが、主な違いは、ベクトルあたりの要素数と_ps/_pd組み込み関数の選択です。double最後にのバージョンを提供します。


_mm_max_ps/を使用して、複数の列の垂直方向の最大値を並列に計算できるという考え方です_mm_max_pd。見つかった最大値の位置も記録するために、前の最大値を現在の要素と比較できます。比較の結果はマスクであり、要素はすべて1であり、最大値が更新されます。そのマスクを使用して、更新する必要のある位置を選択することもできます。

以下のアルゴリズムは、列に等しい最大要素が複数ある場合、どの最大要素の位置が記録されるかは重要ではないと想定していることに注意する必要があります。また、マトリックスには比較に影響を与えるNaN値が含まれていないと思います。これについては後で詳しく説明します。

void find_max(const int n_cols, 
         const int n_rows, 
         const float* val_ptr,
         int* opt_pos_ptr,
         float* max_ptr){
    const __m128i mm_one = _mm_set1_epi32(1);

    // Pre-compute the number of rows that can be processed in full vector width.
    // In a 128-bit vector there are 4 floats or 2 doubles
    int tail_size = n_rows & 3;
    int n_rows_aligned = n_rows - tail_size;
    int row = 0;
    for (; row < n_rows_aligned; row += 4)
    {
        const auto* col_ptr = val_ptr + row;
        __m128 mm_max = _mm_loadu_ps(col_ptr);
        __m128i mm_max_pos = _mm_setzero_si128();
        __m128i mm_pos = mm_one;
        col_ptr += n_rows;
        for (int col = 1; col < n_cols; ++col)
        {
            __m128 mm_value = _mm_loadu_ps(col_ptr);

            // See if this value is greater than the old maximum
            __m128 mm_mask = _mm_cmplt_ps(mm_max, mm_value);
            // If it is, save its position
            mm_max_pos = _mm_blendv_epi8(mm_max_pos, mm_pos, _mm_castps_si128(mm_mask));

            // Compute the maximum
            mm_max = _mm_max_ps(mm_value, mm_max);

            mm_pos = _mm_add_epi32(mm_pos, mm_one);
            col_ptr += n_rows;
        }

        // Store the results
        _mm_storeu_ps(max_ptr + row, mm_max);
        _mm_storeu_si128(reinterpret_cast< __m128i* >(opt_pos_ptr + row), mm_max_pos);
    }

    // Process tail serially
    for (; row < n_rows; ++row)
    {
        const auto* col_ptr = val_ptr + row;
        auto max = *col_ptr;
        int max_pos = 0;
        col_ptr += n_rows;
        for (int col = 1; col < n_cols; ++col)
        {
            auto value = *col_ptr;
            if (value > max)
            {
                max = value;
                max_pos = col;
            }

            col_ptr += n_rows;
        }

        max_ptr[row] = max;
        opt_pos_ptr[row] = max_pos;
    }
}

上記のコードでは、組み込み関数をブレンドするため、SSE4.1が必要です。これらを_mm_and_si128/ _ps_mm_andnot_si128/ _ps_mm_or_si128/の組み合わせに置き換えることができます。_psその場合、要件はSSE2に引き下げられます。参照してください。インテルの組み込み関数ガイドを、彼らが必要とする命令セット拡張を含む、特定の組み込み関数の詳細については、のために。


NaN値に関する注意。行列にNaNを含めることができる場合、_mm_cmplt_psテストは常にfalseを返します。については_mm_max_ps、一般的に何が返されるかはわかりません。maxpsオペランドのどちらかがNaNである場合は、その命令のオペランドを配置することにより、あなたはどちらかの動作を実現することができます戻って本来の翻訳物は、その二(ソース)というオペランド命令。ただし、_mm_max_ps組み込み関数のどの引数が命令のどのオペランドを表すかは文書化されておらず、コンパイラがさまざまな場合にさまざまな関連付けを使用する可能性さえあります。詳細については、この回答を参照してください。

正しい動作を保証するために。インラインアセンブラを使用して、maxpsオペランドの正しい順序を強制することができます。残念ながら、これは、使用していると言ったx86-64ターゲットのMSVCのオプションではないため、代わり_mm_cmplt_psに、次のような2番目のブレンドに結果を再利用できます。

// Compute the maximum
mm_max = _mm_blendv_ps(mm_max, mm_value, mm_mask);

これにより、結果の最大値のNaNが抑制されます。代わりにNaNを保持したい場合は、2番目の比較を使用してNaNを検出できます。

// Detect NaNs
__m128 mm_nan_mask = _mm_cmpunord_ps(mm_value, mm_value);

// Compute the maximum
mm_max = _mm_blendv_ps(mm_max, mm_value, _mm_or_ps(mm_mask, mm_nan_mask));

より広いベクトル(__m256または__m512)を使用し、外側のループを小さな係数で展開して、内側のループのすべての反復で少なくともキャッシュラインに相当する行データがロードされるようにすると、上記のアルゴリズムのパフォーマンスをさらに向上させることができます。


の実装例を次に示しdoubleます。ここで注意すべき重要な点は、doubleベクトルごとに2つの要素しかないため、ベクトルごとに4つの位置があるため、外側のループを展開doubleして一度に2つのベクトルを処理し、2つのマスクを圧縮して32ビット位置をブレンドするための以前の最大値。

void find_max(const int n_cols, 
         const int n_rows, 
         const double* val_ptr,
         int* opt_pos_ptr,
         double* max_ptr){
    const __m128i mm_one = _mm_set1_epi32(1);

    // Pre-compute the number of rows that can be processed in full vector width.
    // In a 128-bit vector there are 2 doubles, but we want to process
    // two vectors at a time.
    int tail_size = n_rows & 3;
    int n_rows_aligned = n_rows - tail_size;
    int row = 0;
    for (; row < n_rows_aligned; row += 4)
    {
        const auto* col_ptr = val_ptr + row;
        __m128d mm_max1 = _mm_loadu_pd(col_ptr);
        __m128d mm_max2 = _mm_loadu_pd(col_ptr + 2);
        __m128i mm_max_pos = _mm_setzero_si128();
        __m128i mm_pos = mm_one;
        col_ptr += n_rows;
        for (int col = 1; col < n_cols; ++col)
        {
            __m128d mm_value1 = _mm_loadu_pd(col_ptr);
            __m128d mm_value2 = _mm_loadu_pd(col_ptr + 2);

            // See if this value is greater than the old maximum
            __m128d mm_mask1 = _mm_cmplt_pd(mm_max1, mm_value1);
            __m128d mm_mask2 = _mm_cmplt_pd(mm_max2, mm_value2);
            // Compress the 2 masks into one
            __m128i mm_mask = _mm_packs_epi32(
                _mm_castpd_si128(mm_mask1), _mm_castpd_si128(mm_mask2));
            // If it is, save its position
            mm_max_pos = _mm_blendv_epi8(mm_max_pos, mm_pos, mm_mask);

            // Compute the maximum
            mm_max1 = _mm_max_pd(mm_value1, mm_max1);
            mm_max2 = _mm_max_pd(mm_value2, mm_max2);

            mm_pos = _mm_add_epi32(mm_pos, mm_one);
            col_ptr += n_rows;
        }

        // Store the results
        _mm_storeu_pd(max_ptr + row, mm_max1);
        _mm_storeu_pd(max_ptr + row + 2, mm_max2);
        _mm_storeu_si128(reinterpret_cast< __m128i* >(opt_pos_ptr + row), mm_max_pos);
    }

    // Process 2 doubles at once
    if (tail_size >= 2)
    {
        const auto* col_ptr = val_ptr + row;
        __m128d mm_max1 = _mm_loadu_pd(col_ptr);
        __m128i mm_max_pos = _mm_setzero_si128();
        __m128i mm_pos = mm_one;
        col_ptr += n_rows;
        for (int col = 1; col < n_cols; ++col)
        {
            __m128d mm_value1 = _mm_loadu_pd(col_ptr);

            // See if this value is greater than the old maximum
            __m128d mm_mask1 = _mm_cmplt_pd(mm_max1, mm_value1);
            // Compress the mask. The upper half doesn't matter.
            __m128i mm_mask = _mm_packs_epi32(
                _mm_castpd_si128(mm_mask1), _mm_castpd_si128(mm_mask1));
            // If it is, save its position
            mm_max_pos = _mm_blendv_epi8(mm_max_pos, mm_pos, mm_mask);

            // Compute the maximum
            mm_max1 = _mm_max_pd(mm_value1, mm_max1);

            mm_pos = _mm_add_epi32(mm_pos, mm_one);
            col_ptr += n_rows;
        }

        // Store the results
        _mm_storeu_pd(max_ptr + row, mm_max1);
        // Only store the lower two positions
        _mm_storel_epi64(reinterpret_cast< __m128i* >(opt_pos_ptr + row), mm_max_pos);

        row += 2;
    }

    // Process tail serially
    for (; row < n_rows; ++row)
    {
        const auto* col_ptr = val_ptr + row;
        auto max = *col_ptr;
        int max_pos = 0;
        col_ptr += n_rows;
        for (int col = 1; col < n_cols; ++col)
        {
            auto value = *col_ptr;
            if (value > max)
            {
                max = value;
                max_pos = col;
            }

            col_ptr += n_rows;
        }

        max_ptr[row] = max;
        opt_pos_ptr[row] = max_pos;
    }
}