You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

217 lines
8.5 KiB

// MIT License
//
// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "cmdparser.hpp"
#include "example_utils.hpp"
#include "hipblas_utils.hpp"
#include "hipsolver_utils.hpp"
#include <hipblas/hipblas.h>
#include <hipsolver/hipsolver.h>
#include <hip/hip_runtime.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>
/// Computes the singular value decomposition A = U * S * V_H of an m x n column-major matrix
/// with hipSOLVER's gesvd, then validates the factorization with hipBLAS by checking that
/// U * S * V_H - A is (numerically) the zero matrix.
///
/// Optional command-line arguments: "m" (rows of A, default 3) and "n" (columns of A, default 2).
/// Returns error_exit_code when m or n is not positive; otherwise returns whatever
/// report_validation_result yields for the number of mismatching elements.
int main(const int argc, char* argv[])
{
    // Parse user inputs.
    cli::Parser parser(argc, argv);
    parser.set_optional<int>("m", "m", 3, "Number of rows of input matrix A");
    parser.set_optional<int>("n", "n", 2, "Number of columns of input matrix A");
    parser.run_and_exit_if_error();

    // Get input matrix rows (m) and columns (n).
    const int m = parser.get<int>("m");
    if(m <= 0)
    {
        std::cout << "Value of 'm' should be greater than 0" << std::endl;
        return error_exit_code;
    }

    const int n = parser.get<int>("n");
    if(n <= 0)
    {
        std::cout << "Value of 'n' should be greater than 0" << std::endl;
        return error_exit_code;
    }

    // Number of singular values, and also the inner dimension of the product U * (S * V_H).
    const int k = std::min(m, n);

    // Initialize leading dimensions of input matrix A and output singular vector matrices.
    const int lda = m;
    const int ldu = m;
    const int ldv = n;
    const int ldw = k; // W = S * V_H only has k non-zero rows, so it is stored as k x n.

    // Define input and output matrices' sizes. The products are evaluated in std::size_t so
    // that large m and n cannot overflow the int multiplication.
    const std::size_t size_A   = static_cast<std::size_t>(lda) * n;
    const std::size_t size_U   = static_cast<std::size_t>(ldu) * m;
    const std::size_t size_V_H = static_cast<std::size_t>(ldv) * n;
    const std::size_t size_S   = static_cast<std::size_t>(k);
    const std::size_t size_W   = static_cast<std::size_t>(ldw) * n;

    // Initialize input matrix with sequence 1, 2, 3, ... .
    std::vector<double> A(size_A);
    std::iota(A.begin(), A.end(), 1.0);

    // We want to obtain the decomposition A = U * S * V_H. Initialize the right-hand matrices:
    // - U is an m x m unitary matrix, whose columns are the "left singular vectors".
    // - S is an m x n diagonal matrix, whose diagonal values are the "singular values". We store
    //   a vector with min(m,n) values instead of the whole matrix.
    // - V_H is an n x n unitary matrix, whose rows are the "right singular vectors".
    std::vector<double> U(size_U, 0);
    std::vector<double> S(size_S, 0);
    std::vector<double> V_H(size_V_H, 0);

    // Convergence information for the BDSQR algorithm used internally, it specifies how many
    // superdiagonals of the intermediate bidiagonal form did not converge to zero.
    int* d_bdsqr_info{};
    int  bdsqr_info{};

    // Allocate device memory for the matrices needed and copy input matrix A from host to device.
    double* d_A{};
    double* d_S{};
    double* d_U{};
    double* d_V_H{};
    double* d_W{}; // W = S * V_H (k x n), for solution checking.
    HIP_CHECK(hipMalloc(&d_A, sizeof(double) * size_A));
    HIP_CHECK(hipMalloc(&d_S, sizeof(double) * size_S));
    HIP_CHECK(hipMalloc(&d_U, sizeof(double) * size_U));
    HIP_CHECK(hipMalloc(&d_V_H, sizeof(double) * size_V_H));
    HIP_CHECK(hipMalloc(&d_W, sizeof(double) * size_W));
    HIP_CHECK(hipMalloc(&d_bdsqr_info, sizeof(int)));
    HIP_CHECK(hipMemcpy(d_A, A.data(), sizeof(double) * size_A, hipMemcpyHostToDevice));

    // Define how left and right singular vectors are calculated and stored.
    const signed char left_svect = 'A'; // All m columns of U (left singular vectors).
    const signed char right_svect = 'A'; // All n rows of V_H (right singular vectors).

    // Use the hipSOLVER API to create a handle.
    hipsolverHandle_t hipsolver_handle;
    HIPSOLVER_CHECK(hipsolverCreate(&hipsolver_handle));

    // Working space variables.
    int     lwork  = 0; /*Size of working space*/
    double* d_work = nullptr; /*Working space*/
    // Unconverged superdiagonal elements of an upper bidiagonal matrix. Not allocated in this
    // example, so a null pointer is handed to the solver.
    double* d_rwork = nullptr;

    // Query working space.
    HIPSOLVER_CHECK(
        hipsolverDgesvd_bufferSize(hipsolver_handle, left_svect, right_svect, m, n, &lwork));
    // NOTE(review): whether lwork counts bytes or elements of double appears to depend on the
    // hipSOLVER backend -- confirm against the hipSOLVER documentation. Sizing the buffer as
    // lwork doubles is large enough under either interpretation.
    HIP_CHECK(hipMalloc(&d_work, sizeof(double) * static_cast<std::size_t>(lwork)));

    // Compute the singular values (vector S) and singular vectors (matrices U and V_H) of A.
    HIPSOLVER_CHECK(hipsolverDgesvd(hipsolver_handle,
                                    left_svect,
                                    right_svect,
                                    m,
                                    n,
                                    d_A,
                                    lda,
                                    d_S,
                                    d_U,
                                    ldu,
                                    d_V_H,
                                    ldv,
                                    d_work,
                                    lwork,
                                    d_rwork,
                                    d_bdsqr_info));

    // Copy device output data to host.
    HIP_CHECK(hipMemcpy(U.data(), d_U, sizeof(double) * size_U, hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(S.data(), d_S, sizeof(double) * size_S, hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(V_H.data(), d_V_H, sizeof(double) * size_V_H, hipMemcpyDeviceToHost));
    HIP_CHECK(hipMemcpy(&bdsqr_info, d_bdsqr_info, sizeof(int), hipMemcpyDeviceToHost));

    // Print trace message for BDSQR.
    if(bdsqr_info == 0)
    {
        std::cout << "Internal BDSQR converges." << std::endl;
    }
    else if(bdsqr_info > 0)
    {
        std::cout << "Internal BDSQR does not converge (" << bdsqr_info
                  << " elements did not converge to 0)." << std::endl;
    }

    // Check the solution using the hipBLAS API.
    // Create a handle and enable passing scalar parameters from a pointer to host memory.
    hipblasHandle_t hipblas_handle;
    HIPBLAS_CHECK(hipblasCreate(&hipblas_handle));
    HIPBLAS_CHECK(hipblasSetPointerMode(hipblas_handle, HIPBLAS_POINTER_MODE_HOST));

    // Validate the result by seeing if U * S * V_H - A is the zero matrix.
    const double eps         = 1.0e5 * std::numeric_limits<double>::epsilon();
    const double h_one       = 1;
    const double h_minus_one = -1;

    // Firstly, compute W = S * V_H. Only the first k = min(m, n) rows of V_H are scaled by a
    // singular value (the remaining rows of S are zero), so W is the k x n matrix
    // diag(S) * V_H(0:k, :). Using k here keeps every access inside d_S, d_V_H and d_W for any
    // combination of m and n.
    HIPBLAS_CHECK(
        hipblasDdgmm(hipblas_handle, HIPBLAS_SIDE_LEFT, k, n, d_V_H, ldv, d_S, 1, d_W, ldw));

    // Secondly, make A = U * W - A. gesvd overwrote d_A, so restore the input first.
    HIP_CHECK(hipMemcpy(d_A, A.data(), sizeof(double) * size_A, hipMemcpyHostToDevice));
    HIPBLAS_CHECK(hipblasDgemm(hipblas_handle,
                               HIPBLAS_OP_N,
                               HIPBLAS_OP_N,
                               m /*rows of A*/,
                               n /*cols of A*/,
                               k /*cols of U used, rows of W*/,
                               &h_one,
                               d_U,
                               ldu,
                               d_W,
                               ldw,
                               &h_minus_one,
                               d_A,
                               lda));

    // Copy the result back to the host.
    HIP_CHECK(hipMemcpy(A.data(), d_A, sizeof(double) * size_A, hipMemcpyDeviceToHost));

    // Lastly, check if A is 0. Because lda == m the matrix is stored contiguously, so every
    // element of the vector is an element of the matrix.
    const unsigned int errors = static_cast<unsigned int>(
        std::count_if(A.begin(), A.end(), [eps](const double value) {
            return std::fabs(value) > eps;
        }));

    // Free resources.
    HIP_CHECK(hipFree(d_A));
    HIP_CHECK(hipFree(d_U));
    HIP_CHECK(hipFree(d_V_H));
    HIP_CHECK(hipFree(d_S));
    HIP_CHECK(hipFree(d_W));
    HIP_CHECK(hipFree(d_work));
    HIP_CHECK(hipFree(d_rwork));
    HIP_CHECK(hipFree(d_bdsqr_info));
    HIPBLAS_CHECK(hipblasDestroy(hipblas_handle));
    HIPSOLVER_CHECK(hipsolverDestroy(hipsolver_handle));

    // Print validation result.
    return report_validation_result(errors);
}