You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
236 lines
8.2 KiB
236 lines
8.2 KiB
// MIT License
//
// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "cmdparser.hpp"
#include "example_utils.hpp"
#include "rocblas_utils.hpp"

#include <rocblas/rocblas.h>

#include <hip/hip_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

int main(const int argc, const char** argv) |
|
{ |
|
// Parse user inputs. |
|
cli::Parser parser(argc, argv); |
|
parser.set_optional<float>("a", "alpha", 1.f, "Alpha scalar"); |
|
parser.set_optional<float>("b", "beta", 1.f, "Beta scalar"); |
|
parser.set_optional<int>("c", "count", 3, "Batch count"); |
|
parser.set_optional<int>("m", "m", 5, "Number of rows of matrices A_i and C_i"); |
|
parser.set_optional<int>("n", "n", 5, "Number of columns of matrices B_i and C_i"); |
|
parser.set_optional<int>("k", "k", 5, "Number of columns of matrix A_i and rows of B_i"); |
|
parser.run_and_exit_if_error(); |
|
|
|
// Set sizes of matrices. |
|
const rocblas_int m = parser.get<int>("m"); |
|
const rocblas_int n = parser.get<int>("n"); |
|
const rocblas_int k = parser.get<int>("k"); |
|
|
|
// Set batch counter. |
|
const rocblas_int batch_count = parser.get<int>("c"); |
|
|
|
// Check input values validity. |
|
if(m <= 0) |
|
{ |
|
std::cout << "Value of 'm' should be greater than 0" << std::endl; |
|
return error_exit_code; |
|
} |
|
|
|
if(n <= 0) |
|
{ |
|
std::cout << "Value of 'n' should be greater than 0" << std::endl; |
|
return error_exit_code; |
|
} |
|
|
|
if(k <= 0) |
|
{ |
|
std::cout << "Value of 'k' should be greater than 0" << std::endl; |
|
return error_exit_code; |
|
} |
|
|
|
if(batch_count <= 0) |
|
{ |
|
std::cout << "Value of 'c' should be greater than 0" << std::endl; |
|
return error_exit_code; |
|
} |
|
|
|
// Set scalar values used for multiplication. |
|
const rocblas_float h_alpha = parser.get<float>("a"); |
|
const rocblas_float h_beta = parser.get<float>("b"); |
|
|
|
// Set GEMM operation as identity operation: $X' = X$. |
|
const rocblas_operation trans_a = rocblas_operation_none; |
|
const rocblas_operation trans_b = rocblas_operation_none; |
|
|
|
rocblas_int lda, ldb, ldc; |
|
int stride1_a, stride2_a, stride1_b, stride2_b; |
|
rocblas_stride stride_a, stride_b, stride_c; |
|
|
|
// Set up matrix dimension variables. |
|
if(trans_a == rocblas_operation_none) |
|
{ |
|
lda = m; |
|
stride_a = rocblas_stride(k) * lda; |
|
stride1_a = 1; |
|
stride2_a = lda; |
|
} |
|
else |
|
{ |
|
lda = k; |
|
stride_a = rocblas_stride(m) * lda; |
|
stride1_a = lda; |
|
stride2_a = 1; |
|
} |
|
if(trans_b == rocblas_operation_none) |
|
{ |
|
ldb = k; |
|
stride_b = rocblas_stride(n) * ldb; |
|
stride1_b = 1; |
|
stride2_b = ldb; |
|
} |
|
else |
|
{ |
|
ldb = n; |
|
stride_b = rocblas_stride(k) * ldb; |
|
stride1_b = ldb; |
|
stride2_b = 1; |
|
} |
|
ldc = m; |
|
stride_c = rocblas_stride(n) * ldc; |
|
|
|
// Get maximum of batch count. |
|
rocblas_int count_max = std::max(batch_count, 1); |
|
|
|
// Get vector sizes. |
|
size_t size_a = size_t(stride_a) * count_max; |
|
size_t size_b = size_t(stride_b) * count_max; |
|
size_t size_c = size_t(stride_c) * count_max; |
|
|
|
// Allocate host data. |
|
std::vector<float> h_a(size_a, 1); |
|
std::vector<float> h_b(size_b); |
|
std::vector<float> h_c(size_c, 1); |
|
std::vector<float> h_gold(size_c); |
|
|
|
// Set B_i matrix. |
|
for(rocblas_int i = 0; i < batch_count; ++i) |
|
{ |
|
generate_identity_matrix(h_b.data() + i * stride_b, k, n, ldb); |
|
} |
|
|
|
// Initialize gold standard matrix. |
|
h_gold = h_c; |
|
|
|
// Calculate gold standard on CPU. |
|
for(rocblas_int i = 0; i < batch_count; ++i) |
|
{ |
|
multiply_matrices<float>(h_alpha, |
|
h_beta, |
|
m, |
|
n, |
|
k, |
|
h_a.data() + i * stride_a, |
|
stride1_a, |
|
stride2_a, |
|
h_b.data() + i * stride_b, |
|
stride1_b, |
|
stride2_b, |
|
h_gold.data() + i * stride_c, |
|
ldc); |
|
} |
|
|
|
// Allocate device memory. |
|
float* d_a{}; |
|
float* d_b{}; |
|
float* d_c{}; |
|
HIP_CHECK(hipMalloc(&d_a, size_a * sizeof(float))); |
|
HIP_CHECK(hipMalloc(&d_b, size_b * sizeof(float))); |
|
HIP_CHECK(hipMalloc(&d_c, size_c * sizeof(float))); |
|
|
|
// Copy data from CPU to device. |
|
HIP_CHECK(hipMemcpy(d_a, |
|
static_cast<void*>(h_a.data()), |
|
sizeof(float) * size_a, |
|
hipMemcpyHostToDevice)); |
|
HIP_CHECK(hipMemcpy(d_b, |
|
static_cast<void*>(h_b.data()), |
|
sizeof(float) * size_b, |
|
hipMemcpyHostToDevice)); |
|
HIP_CHECK(hipMemcpy(d_c, |
|
static_cast<void*>(h_c.data()), |
|
sizeof(float) * size_c, |
|
hipMemcpyHostToDevice)); |
|
|
|
// Use rocBLAS API. |
|
rocblas_handle handle; |
|
ROCBLAS_CHECK(rocblas_create_handle(&handle)); |
|
|
|
// Enable passing alpha and beta parameters from pointer to host memory. |
|
ROCBLAS_CHECK(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); |
|
|
|
// Asynchronous matrix multiplication calculation on device. |
|
ROCBLAS_CHECK(rocblas_sgemm_strided_batched(handle, |
|
trans_a, |
|
trans_b, |
|
m, |
|
n, |
|
k, |
|
&h_alpha, |
|
d_a, |
|
lda, |
|
stride_a, |
|
d_b, |
|
ldb, |
|
stride_b, |
|
&h_beta, |
|
d_c, |
|
ldc, |
|
stride_c, |
|
batch_count)); |
|
|
|
// Fetch device memory results, automatically blocks until results are ready. |
|
HIP_CHECK(hipMemcpy(h_c.data(), d_c, sizeof(float) * size_c, hipMemcpyDeviceToHost)); |
|
|
|
// Destroy the rocBLAS handle. |
|
ROCBLAS_CHECK(rocblas_destroy_handle(handle)); |
|
|
|
// Free device memory as it is no longer required. |
|
HIP_CHECK(hipFree(d_a)); |
|
HIP_CHECK(hipFree(d_b)); |
|
HIP_CHECK(hipFree(d_c)); |
|
|
|
// Check the relative error between output generated by the rocBLAS API and the CPU. |
|
const float eps = 10.f * std::numeric_limits<float>::epsilon(); |
|
unsigned int errors = 0; |
|
for(rocblas_int i = 0; i < ldc; ++i) |
|
{ |
|
errors += std::fabs(h_c[i] - h_gold[i]) > eps; |
|
} |
|
return report_validation_result(errors); |
|
}
|
|
|