You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
206 lines
4.9 KiB
206 lines
4.9 KiB
// MIT License |
|
// |
|
// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. |
|
// |
|
// Permission is hereby granted, free of charge, to any person obtaining a copy |
|
// of this software and associated documentation files (the "Software"), to deal |
|
// in the Software without restriction, including without limitation the rights |
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
// copies of the Software, and to permit persons to whom the Software is |
|
// furnished to do so, subject to the following conditions: |
|
// |
|
// The above copyright notice and this permission notice shall be included in all |
|
// copies or substantial portions of the Software. |
|
// |
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
// SOFTWARE. |
|
|
|
#pragma once |
|
|
|
// FUNC_QUALIFIER marks the helpers in this header as callable from both host and
// device code. It expands to `__host__ __device__` only when `__device__` is already
// a macro at this point -- presumably defined by a previously included HIP/CUDA
// runtime header (NOTE(review): this depends on include order in GPU translation
// units). In plain host C++ it expands to nothing.
#if defined(__device__)
#define FUNC_QUALIFIER __host__ __device__
#else
#define FUNC_QUALIFIER
#endif
|
|
|
#include <ranges> // std::ranges::size |
|
#include <utility> // std::forward |
|
|
|
namespace tmp |
|
{ |
|
template<int Init, typename Pred, typename Step, typename F> |
|
FUNC_QUALIFIER void static_for(F&& f) |
|
{ |
|
if constexpr(Pred{}.template operator()<Init>()) |
|
{ |
|
f.template operator()<Init>(); |
|
static_for<Step{}.template operator()<Init>(), Pred, Step, F>(std::forward<F>(f)); |
|
} |
|
} |
|
|
|
template<int J> |
|
struct constant |
|
{ |
|
template<int> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
return J; |
|
} |
|
}; |
|
template<int J> |
|
struct less_than |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I < J; |
|
} |
|
}; |
|
template<int J> |
|
struct greater_than |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I > J; |
|
} |
|
}; |
|
template<int J> |
|
struct less_than_eq |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I <= J; |
|
} |
|
}; |
|
template<int J> |
|
struct greater_than_eq |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I >= J; |
|
} |
|
}; |
|
template<int J> |
|
struct equal |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I == J; |
|
} |
|
}; |
|
template<int J> |
|
struct not_equal |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr bool operator()() |
|
{ |
|
return I != J; |
|
} |
|
}; |
|
template<int J = 1> |
|
struct increment |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
return I + J; |
|
} |
|
}; |
|
template<int J = 1> |
|
struct decrement |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
return I - J; |
|
} |
|
}; |
|
template<int J> |
|
struct divide |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
return I / J; |
|
} |
|
}; |
|
template<int J> |
|
struct multiply |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
return I * J; |
|
} |
|
}; |
|
|
|
template<int J> |
|
struct divide_ceil |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
static_assert(std::is_signed_v<decltype(I)> && std::is_signed_v<decltype(J)>, |
|
"Invalid type."); |
|
return (I + J - 1) / J; |
|
} |
|
}; |
|
|
|
template<typename Pred, typename TruePath, typename FalsePath> |
|
struct select |
|
{ |
|
template<int I> |
|
FUNC_QUALIFIER constexpr int operator()() |
|
{ |
|
if constexpr(Pred{}.template operator()<I>()) |
|
return TruePath{}.template operator()<I>(); |
|
else |
|
return FalsePath{}.template operator()<I>(); |
|
} |
|
}; |
|
|
|
template<auto InputRange, int Index, int N, typename F> |
|
FUNC_QUALIFIER void static_range_for_impl(F&& f) |
|
{ |
|
if constexpr(Index != N) |
|
{ |
|
f.template operator()<InputRange[Index]>(); |
|
static_range_for_impl<InputRange, Index + 1, N, F>(std::forward<F>(f)); |
|
} |
|
} |
|
|
|
template<auto InputRange, int Index, int N, typename S, typename F> |
|
FUNC_QUALIFIER void static_switch_impl(S&& s, F&& f) |
|
{ |
|
if constexpr(Index != N) |
|
{ |
|
if(s == InputRange[Index]) |
|
f.template operator()<InputRange[Index]>(); |
|
else |
|
static_switch_impl<InputRange, Index + 1, N, S, F>(std::forward<S>(s), |
|
std::forward<F>(f)); |
|
} |
|
} |
|
|
|
template<auto SizedRange, typename F> |
|
FUNC_QUALIFIER void static_range_for(F&& f) |
|
{ |
|
static_range_for_impl<SizedRange, 0, std::ranges::size(SizedRange), F>(std::forward<F>(f)); |
|
} |
|
|
|
template<auto SizedRange, typename S, typename F> |
|
FUNC_QUALIFIER void static_switch(S&& s, F&& f) |
|
{ |
|
static_switch_impl<SizedRange, 0, std::ranges::size(SizedRange), S, F>(std::forward<S>(s), |
|
std::forward<F>(f)); |
|
} |
|
} // namespace tmp
|
|
|