
Merge pull request #3943 from troelsy:4.x

Add Otsu's method to cv::cuda::threshold #3943

I implemented Otsu's method in CUDA for a separate project and want to add it to cv::cuda::threshold.

I have made an effort to use existing OpenCV functions in my code, but I had some trouble with `ThresholdTypes` and `cv::cuda::calcHist`. I couldn't figure out how to include `precomp.hpp` to get the definition of `ThresholdTypes`. For `cv::cuda::calcHist` I tried adding `opencv_cudaimgproc`, but it creates a circular dependency on `cudaarithm`. I have included a simple implementation of `calcHist` so the code runs, but I would like input on how to use `cv::cuda::calcHist` instead.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [ ] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake
Committed by Troels Ynddal 3 weeks ago · commit 6329974d4d
1. modules/cudaarithm/include/opencv2/cudaarithm.hpp (10 changed lines)
2. modules/cudaarithm/src/cuda/threshold.cu (222 changed lines)
3. modules/cudaarithm/test/test_element_operations.cpp (50 changed lines)
4. modules/cudev/include/opencv2/cudev/util/saturate_cast.hpp (2 changed lines)
5. modules/cudev/include/opencv2/cudev/warp/shuffle.hpp (10 changed lines)

### modules/cudaarithm/include/opencv2/cudaarithm.hpp (10 changed lines)

@@ -546,12 +546,16 @@ static inline void scaleAdd(InputArray src1, double alpha, InputArray src2, Outp

```cpp
/** @brief Applies a fixed-level threshold to each array element.

The special value cv::THRESH_OTSU may be combined with one of the other types. In this case, the function
determines the optimal threshold value using Otsu's method and uses it instead of the specified threshold.
The function returns the computed threshold value in addition to the thresholded matrix.

Otsu's method is implemented only for 8-bit matrices.

@param src Source array (single-channel).
@param dst Destination array with the same size and type as src.
@param thresh Threshold value.
@param maxval Maximum value to use with THRESH_BINARY and THRESH_BINARY_INV threshold types.
@param type Threshold type. For details, see threshold. The THRESH_TRIANGLE threshold type is not supported.
@param stream Stream for the asynchronous version.
@sa threshold
```
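For illustration, a minimal usage sketch of the new flag (the wrapper function and variable names here are ours, not part of the patch):

```cpp
#include <opencv2/cudaarithm.hpp>
#include <opencv2/imgproc.hpp> // for cv::ThresholdTypes

void binarizeOtsu(const cv::cuda::GpuMat& src /* CV_8UC1 */, cv::cuda::GpuMat& dst)
{
    // With THRESH_OTSU the thresh argument is ignored; the function computes
    // the Otsu threshold on the GPU and returns it.
    double otsu = cv::cuda::threshold(src, dst, 0 /*thresh*/, 255 /*maxval*/,
                                      cv::THRESH_BINARY | cv::THRESH_OTSU);
    (void) otsu;
}
```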

### modules/cudaarithm/src/cuda/threshold.cu (222 changed lines)

@@ -95,12 +95,232 @@

```cpp
__global__ void otsu_sums(uint *histogram, uint *threshold_sums, unsigned long long *sums)
{
    const uint32_t n_bins = 256;

    __shared__ uint shared_memory_ts[n_bins];
    __shared__ unsigned long long shared_memory_s[n_bins];

    int bin_idx = threadIdx.x;
    int threshold = blockIdx.x;

    uint threshold_sum_above = 0;
    unsigned long long sum_above = 0;

    if (bin_idx >= threshold)
    {
        uint value = histogram[bin_idx];
        threshold_sum_above = value;
        // Widen before the multiply so very large histograms cannot overflow 32 bits
        sum_above = (unsigned long long) value * bin_idx;
    }

    blockReduce<n_bins>(shared_memory_ts, threshold_sum_above, bin_idx, plus<uint>());
    blockReduce<n_bins>(shared_memory_s, sum_above, bin_idx, plus<unsigned long long>());

    if (bin_idx == 0)
    {
        threshold_sums[threshold] = threshold_sum_above;
        sums[threshold] = sum_above;
    }
}
```
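For reference, each block handles one candidate threshold $t$ and reduces the histogram $h$ into the pixel count and intensity mass at or above $t$:

$$
N_{\ge t} = \sum_{b=t}^{255} h(b), \qquad S_{\ge t} = \sum_{b=t}^{255} b\,h(b)
$$

At $t = 0$ these are the grand totals $N$ and $S$, which is why the next kernel reads index 0 of both arrays.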
```cpp
__global__ void
otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned long long *sums)
{
    const uint32_t n_bins = 256;

    __shared__ signed long long shared_memory_a[n_bins];
    __shared__ signed long long shared_memory_b[n_bins];

    int bin_idx = threadIdx.x;
    int threshold = blockIdx.x;

    // threshold_sums[0] and sums[0] cover the whole histogram, so they hold the totals
    uint n_samples = threshold_sums[0];
    uint n_samples_above = threshold_sums[threshold];
    uint n_samples_below = n_samples - n_samples_above;

    unsigned long long total_sum = sums[0];
    unsigned long long sum_above = sums[threshold];
    unsigned long long sum_below = total_sum - sum_above;

    float threshold_variance_above_f32 = 0;
    float threshold_variance_below_f32 = 0;
    if (bin_idx >= threshold)
    {
        float mean = (float) sum_above / n_samples_above;
        float sigma = bin_idx - mean;
        threshold_variance_above_f32 = sigma * sigma;
    }
    else
    {
        float mean = (float) sum_below / n_samples_below;
        float sigma = bin_idx - mean;
        threshold_variance_below_f32 = sigma * sigma;
    }

    uint bin_count = histogram[bin_idx];
    signed long long threshold_variance_above_i64 = (signed long long) (threshold_variance_above_f32 * bin_count);
    signed long long threshold_variance_below_i64 = (signed long long) (threshold_variance_below_f32 * bin_count);

    blockReduce<n_bins>(shared_memory_a, threshold_variance_above_i64, bin_idx, plus<signed long long>());
    blockReduce<n_bins>(shared_memory_b, threshold_variance_below_i64, bin_idx, plus<signed long long>());

    if (bin_idx == 0)
    {
        variance[threshold] = make_float2(threshold_variance_above_i64, threshold_variance_below_i64);
    }
}

__global__ void
otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
{
    const uint32_t n_thresholds = 256;

    __shared__ float shared_memory[n_thresholds / WARP_SIZE];

    int threshold = threadIdx.x;

    uint n_samples = threshold_sums[0];
    uint n_samples_above = threshold_sums[threshold];
    uint n_samples_below = n_samples - n_samples_above;

    float threshold_mean_above = (float) n_samples_above / n_samples;
    float threshold_mean_below = (float) n_samples_below / n_samples;

    float2 variances = variance[threshold];
    float variance_above = variances.x / n_samples_above;
    float variance_below = variances.y / n_samples_below;

    float above = threshold_mean_above * variance_above;
    float below = threshold_mean_below * variance_below;
    float score = above + below;
    float original_score = score;

    blockReduce<n_thresholds>(shared_memory, score, threshold, minimum<float>());

    if (threshold == 0)
    {
        shared_memory[0] = score;
    }
    __syncthreads();

    score = shared_memory[0];

    // We found the minimum score, but we need to find the threshold. If we find the thread
    // with the minimum score, we know which threshold it is
    if (original_score == score)
    {
        *otsu_threshold = threshold - 1;
    }
}
```
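Taken together, the three kernels minimize Otsu's within-class variance. As a sketch in the usual notation, with the class weights and variances built from the sums above, each thread scores

$$
\sigma_w^2(t) = \frac{N_{\ge t}}{N}\,\sigma_{\ge t}^2 + \frac{N_{<t}}{N}\,\sigma_{<t}^2,
\qquad
\sigma_{\ge t}^2 = \frac{1}{N_{\ge t}} \sum_{b=t}^{255} h(b)\,\big(b - \mu_{\ge t}\big)^2,
$$

and the thread holding the minimum stores $t - 1$, matching cv::threshold's convention that a pixel is "above" only when strictly greater than the returned value.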
```cpp
void compute_otsu(uint *histogram, uint *otsu_threshold, Stream &stream)
{
    const uint n_bins = 256;
    const uint n_thresholds = 256;

    cudaStream_t cuda_stream = StreamAccessor::getStream(stream);

    dim3 block_all(n_bins);
    dim3 grid_all(n_thresholds);
    dim3 block_score(n_thresholds);
    dim3 grid_score(1);

    BufferPool pool(stream);
    GpuMat gpu_threshold_sums(1, n_bins, CV_32SC1, pool.getAllocator());
    // CV_64FC1 is used here only as 64-bit storage; the buffer holds unsigned long long sums
    GpuMat gpu_sums(1, n_bins, CV_64FC1, pool.getAllocator());
    GpuMat gpu_variances(1, n_bins, CV_32FC2, pool.getAllocator());

    otsu_sums<<<grid_all, block_all, 0, cuda_stream>>>(
        histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
    otsu_variance<<<grid_all, block_all, 0, cuda_stream>>>(
        gpu_variances.ptr<float2>(), histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
    otsu_score<<<grid_score, block_score, 0, cuda_stream>>>(
        otsu_threshold, gpu_threshold_sums.ptr<uint>(), gpu_variances.ptr<float2>());
}
```
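For debugging, a host-side reference of what compute_otsu produces can be handy. This is a hypothetical helper (not part of the patch) that mirrors the three kernels on a plain 256-bin histogram; it is O(256²) but trivially correct:

```cpp
#include <limits>

// Hypothetical CPU mirror of otsu_sums / otsu_variance / otsu_score: for every
// candidate threshold t, build the class weights and variances from the
// histogram and keep the t that minimizes the within-class variance.
static int otsuReferenceCPU(const unsigned hist[256])
{
    unsigned long long n = 0, sum = 0;
    for (int b = 0; b < 256; ++b) { n += hist[b]; sum += (unsigned long long) b * hist[b]; }

    double best_score = std::numeric_limits<double>::max();
    int best_threshold = 0;
    for (int t = 0; t < 256; ++t)
    {
        unsigned long long n_above = 0, sum_above = 0;
        for (int b = t; b < 256; ++b) { n_above += hist[b]; sum_above += (unsigned long long) b * hist[b]; }
        const unsigned long long n_below = n - n_above, sum_below = sum - sum_above;
        if (n_above == 0 || n_below == 0) continue; // avoid the 0/0 the kernels hit at the extremes

        const double mean_above = (double) sum_above / n_above;
        const double mean_below = (double) sum_below / n_below;
        double var_above = 0, var_below = 0;
        for (int b = t; b < 256; ++b) var_above += (b - mean_above) * (b - mean_above) * hist[b];
        for (int b = 0; b < t; ++b)  var_below += (b - mean_below) * (b - mean_below) * hist[b];
        var_above /= n_above;
        var_below /= n_below;

        // Same score as otsu_score: weighted within-class variance
        const double score = (double) n_above / n * var_above + (double) n_below / n * var_below;
        if (score < best_score) { best_score = score; best_threshold = t - 1; }
    }
    return best_threshold;
}
```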
```cpp
// TODO: Replace this with cv::cuda::calcHist
template <uint n_bins>
__global__ void histogram_kernel(
    uint *histogram, const uint8_t *image, uint width,
    uint height, uint pitch)
{
    __shared__ uint local_histogram[n_bins];

    uint x = blockIdx.x * blockDim.x + threadIdx.x;
    uint y = blockIdx.y * blockDim.y + threadIdx.y;
    uint tid = threadIdx.y * blockDim.x + threadIdx.x;

    if (tid < n_bins)
    {
        local_histogram[tid] = 0;
    }
    __syncthreads();

    if (x < width && y < height)
    {
        uint8_t value = image[y * pitch + x];
        atomicInc(&local_histogram[value], 0xFFFFFFFF);
    }
    __syncthreads();

    if (tid < n_bins)
    {
        cv::cudev::atomicAdd(&histogram[tid], local_histogram[tid]);
    }
}

// TODO: Replace this with cv::cuda::calcHist
void calcHist(
    const GpuMat src, GpuMat histogram, Stream stream)
{
    const uint n_bins = 256;

    cudaStream_t cuda_stream = StreamAccessor::getStream(stream);

    dim3 block(128, 4, 1);
    dim3 grid = dim3(divUp(src.cols, block.x), divUp(src.rows, block.y), 1);

    CV_CUDEV_SAFE_CALL(cudaMemsetAsync(histogram.ptr<uint>(), 0, n_bins * sizeof(uint), cuda_stream));
    histogram_kernel<n_bins>
        <<<grid, block, 0, cuda_stream>>>(
            histogram.ptr<uint>(), src.ptr<uint8_t>(), (uint) src.cols, (uint) src.rows, (uint) src.step);
}
```
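If the circular dependency were sorted out, the two TODOs could presumably be satisfied by the existing `cv::cuda::calcHist` from cudaimgproc, which produces an equivalent 256-bin CV_32SC1 histogram for an 8-bit single-channel image. A minimal sketch, assuming `opencv_cudaimgproc` could be linked:

```cpp
#include <opencv2/cudaimgproc.hpp>

// Sketch only: cv::cuda::calcHist writes a 1x256 CV_32SC1 histogram.
// compute_otsu only reads the raw uint pointer, so the row/column
// orientation of the buffer does not matter.
GpuMat gpu_histogram(1, 256, CV_32SC1, pool.getAllocator());
cv::cuda::calcHist(src, gpu_histogram, stream);
```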
```cpp
double cv::cuda::threshold(InputArray _src, OutputArray _dst, double thresh, double maxVal, int type, Stream &stream)
{
    GpuMat src = getInputMat(_src, stream);

    const int depth = src.depth();

    // cv::THRESH_OTSU, redefined locally because ThresholdTypes is not visible here (see PR description)
    const int THRESH_OTSU = 8;
    if ((type & THRESH_OTSU) == THRESH_OTSU)
    {
        CV_Assert(depth == CV_8U);
        CV_Assert(src.channels() == 1);

        BufferPool pool(stream);

        // Find the threshold using Otsu and then run the normal thresholding algorithm
        GpuMat gpu_histogram(256, 1, CV_32SC1, pool.getAllocator());
        calcHist(src, gpu_histogram, stream);

        GpuMat gpu_otsu_threshold(1, 1, CV_32SC1, pool.getAllocator());
        compute_otsu(gpu_histogram.ptr<uint>(), gpu_otsu_threshold.ptr<uint>(), stream);

        cv::Mat mat_otsu_threshold;
        gpu_otsu_threshold.download(mat_otsu_threshold, stream);
        stream.waitForCompletion();

        // Overwrite the threshold value with the Otsu value and remove the Otsu flag from the type
        type = type & ~THRESH_OTSU;
        thresh = (double) mat_otsu_threshold.at<int>(0);
    }

    CV_Assert( depth <= CV_64F );
    CV_Assert( type <= 4 /*THRESH_TOZERO_INV*/ );
```

### modules/cudaarithm/test/test_element_operations.cpp (50 changed lines)

@@ -2529,7 +2529,7 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, AddWeighted, testing::Combine(

```cpp
///////////////////////////////////////////////////////////////////////////////////////////////////////
// Threshold

CV_ENUM(ThreshOp, cv::THRESH_BINARY, cv::THRESH_BINARY_INV, cv::THRESH_TRUNC, cv::THRESH_TOZERO, cv::THRESH_TOZERO_INV, cv::THRESH_OTSU)
#define ALL_THRESH_OPS testing::Values(ThreshOp(cv::THRESH_BINARY), ThreshOp(cv::THRESH_BINARY_INV), ThreshOp(cv::THRESH_TRUNC), ThreshOp(cv::THRESH_TOZERO), ThreshOp(cv::THRESH_TOZERO_INV))

PARAM_TEST_CASE(Threshold, cv::cuda::DeviceInfo, cv::Size, MatType, Channels, ThreshOp, UseRoi)
```
@@ -2577,6 +2577,54 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, Threshold, testing::Combine(

```cpp
    ALL_THRESH_OPS,
    WHOLE_SUBMAT));

///////////////////////////////////////////////////////////////////////////////////////////////////////
// ThresholdOtsu

PARAM_TEST_CASE(ThresholdOtsu, cv::cuda::DeviceInfo, cv::Size, MatType, Channels, ThreshOp, UseRoi)
{
    cv::cuda::DeviceInfo devInfo;
    cv::Size size;
    int type;
    int channel;
    int threshOp;
    bool useRoi;

    virtual void SetUp()
    {
        devInfo = GET_PARAM(0);
        size = GET_PARAM(1);
        type = GET_PARAM(2);
        channel = GET_PARAM(3);
        threshOp = GET_PARAM(4) | cv::THRESH_OTSU;
        useRoi = GET_PARAM(5);

        cv::cuda::setDevice(devInfo.deviceID());
    }
};

CUDA_TEST_P(ThresholdOtsu, Accuracy)
{
    cv::Mat src = randomMat(size, CV_MAKE_TYPE(type, channel));

    cv::cuda::GpuMat dst = createMat(src.size(), src.type(), useRoi);
    double otsu_gpu = cv::cuda::threshold(loadMat(src, useRoi), dst, 0, 255, threshOp);

    cv::Mat dst_gold;
    double otsu_cpu = cv::threshold(src, dst_gold, 0, 255, threshOp);

    ASSERT_DOUBLE_EQ(otsu_gpu, otsu_cpu);
    EXPECT_MAT_NEAR(dst_gold, dst, 0.0);
}

INSTANTIATE_TEST_CASE_P(CUDA_Arithm, ThresholdOtsu, testing::Combine(
    ALL_DEVICES,
    DIFFERENT_SIZES,
    testing::Values(MatDepth(CV_8U)),
    testing::Values(Channels(1)),
    ALL_THRESH_OPS,
    WHOLE_SUBMAT));
```
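Assuming the standard gtest-based runner that OpenCV builds per module, these cases can be exercised in isolation with something like `opencv_test_cudaarithm --gtest_filter='*ThresholdOtsu*'`.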

### modules/cudev/include/opencv2/cudev/util/saturate_cast.hpp (2 changed lines)

@@ -62,6 +62,8 @@ template <typename T> __device__ __forceinline__ T saturate_cast(ushort v) { ret

```cpp
template <typename T> __device__ __forceinline__ T saturate_cast(short v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(uint v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(int v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(signed long long v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(unsigned long long v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(float v) { return T(v); }
template <typename T> __device__ __forceinline__ T saturate_cast(double v) { return T(v); }
```
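The two new overloads appear to be needed so the 64-bit accumulators introduced in threshold.cu can pass through cudev's templated utilities, which route values through `saturate_cast`; without them those instantiations would not compile.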

### modules/cudev/include/opencv2/cudev/warp/shuffle.hpp (10 changed lines)

@@ -332,6 +332,16 @@ __device__ __forceinline__ uint shfl_down(uint val, uint delta, int width = warp

```cpp
    return (uint) __shfl_down((int) val, delta, width);
}

__device__ __forceinline__ signed long long shfl_down(signed long long val, uint delta, int width = warpSize)
{
    return __shfl_down(val, delta, width);
}

__device__ __forceinline__ unsigned long long shfl_down(unsigned long long val, uint delta, int width = warpSize)
{
    return (unsigned long long) __shfl_down(val, delta, width);
}

__device__ __forceinline__ float shfl_down(float val, uint delta, int width = warpSize)
{
    return __shfl_down(val, delta, width);
}
```
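These 64-bit `shfl_down` overloads are what allow `blockReduce` to perform its warp-level shuffles on the `signed long long` and `unsigned long long` sums used by the new kernels. (On CUDA 9 and newer the legacy `__shfl_down` intrinsic is deprecated in favor of `__shfl_down_sync`; presumably these overloads follow whatever version guards the rest of the header already uses.)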
