Commit 0fd8ceec authored by Julius Nehring-Wirxel
Browse files

Added special cases and cleanup.

parent 129c095b
......@@ -6,6 +6,30 @@
namespace
{
/// Returns the C++ source text of a helper `imul(u64 lhs, u64 rhs, u64* high)`.
///
/// The emitted helper reinterprets both operands as i64, multiplies them to a
/// 128 bit product, returns one 64 bit half and stores the other half through
/// `high`:
///  * under `_MSC_VER` it forwards to the `_mul128` intrinsic,
///  * otherwise it multiplies in `__int128` and copies the two words out with
///    memcpy (v[0] is returned, v[1] goes to `*high`; this is low/high only on
///    a little-endian target - NOTE(review): confirm no big-endian targets).
[[maybe_unused]] std::string imul()
{
    return "inline u64 imul(u64 lhs, u64 rhs, u64* high)\n"
           "{\n"
           "#ifdef _MSC_VER\n"
           // _mul128 takes and returns signed 64 bit values and wants a pointer to a
           // signed 64 bit high word. The former `&(i64(high))` converted the pointer
           // to an integer and took the address of that temporary, which is ill-formed.
           " return u64(_mul128(i64(lhs), i64(rhs), reinterpret_cast<i64*>(high)));\n"
           "#else\n"
           " __int128 res_i = __int128(i64(lhs)) * i64(rhs);\n"
           " u64 v[2] = {};\n"
           " memcpy(&v, &res_i, sizeof(__int128));\n"
           " *high = v[1];\n"
           " return v[0];\n"
           "#endif\n"
           "}\n";
}
/// Builds the source text of the unsigned helper `mul(u64 lhs, u64 rhs, u64* high)`,
/// whose emitted body is a single forwarding call to `_mulx_u64`.
[[maybe_unused]] std::string mul()
{
    std::string code;
    code += "inline u64 mul(u64 lhs, u64 rhs, u64* high)\n";
    code += "{\n";
    code += " return _mulx_u64(lhs, rhs, high);\n";
    code += "}\n";
    return code;
}
std::string generate_mul(int w_r, int w_a, int w_b)
{
auto const out_type = "u" + std::to_string(w_r * 64);
......@@ -59,7 +83,7 @@ std::string generate_mul(int w_r, int w_a, int w_b)
}
std::string result;
result += "template<>\n";
result += "template <>\n";
result += "inline " + out_type + " mul(" + a_type + " lhs, " + b_type + " rhs)\n{\n";
result += " " + out_type + " res;\n";
......@@ -106,8 +130,73 @@ std::string generate_mul(int w_r, int w_a, int w_b)
return result;
}
/// Special case: returns the source text of the specialization
/// `i128 imul(i64 lhs, i128 rhs)` (selected by generate_imul for
/// w_r == 2, w_a == 1, w_b == 2).
///
/// Emitted code:
///  * under `_MSC_VER` the 64 bit operand is converted to i128 and the call is
///    delegated to the i128 x i128 `imul`,
///  * otherwise both operands are brought into a native 128 bit integer (rhs
///    via memcpy, which assumes i128 and __int128 share size and layout -
///    TODO confirm), multiplied, and the product is copied back into an i128.
///
/// The native product is computed in `unsigned __int128`: the low 128 bits are
/// the same as for a signed multiply, but wrap-around is well defined, whereas
/// a signed `__int128` product that does not fit is undefined behaviour.
std::string imul_128_64_128()
{
    return "template <>\n"
           "inline i128 imul(i64 lhs, i128 rhs)\n"
           "{\n"
           "#ifdef _MSC_VER\n"
           " return imul(i128(lhs), rhs);\n"
           "#else\n"
           // sign-extend first, then switch to the unsigned type (conversion is modular)
           " unsigned __int128 l = static_cast<unsigned __int128>(__int128(lhs));\n"
           " unsigned __int128 r;\n"
           " memcpy(&r, &rhs, sizeof(__int128));\n"
           " unsigned __int128 inres = l * r;\n"
           " i128 res;\n"
           " memcpy(&res, &inres, sizeof(__int128));\n"
           " return res;\n"
           "#endif\n"
           "}\n";
}
/// Special case: returns the source text of the specialization
/// `i128 imul(i128 lhs, i64 rhs)` (selected by generate_imul for
/// w_r == 2, w_a == 2, w_b == 1).
///
/// Emitted code:
///  * under `_MSC_VER` the 64 bit operand is converted to i128 and the call is
///    delegated to the i128 x i128 `imul`,
///  * otherwise both operands are brought into a native 128 bit integer (lhs
///    via memcpy, which assumes i128 and __int128 share size and layout -
///    TODO confirm), multiplied, and the product is copied back into an i128.
///
/// The native product is computed in `unsigned __int128`: the low 128 bits are
/// the same as for a signed multiply, but wrap-around is well defined, whereas
/// a signed `__int128` product that does not fit is undefined behaviour.
std::string imul_128_128_64()
{
    return "template <>\n"
           "inline i128 imul(i128 lhs, i64 rhs)\n"
           "{\n"
           "#ifdef _MSC_VER\n"
           " return imul(lhs, i128(rhs));\n"
           "#else\n"
           " unsigned __int128 l;\n"
           // sign-extend first, then switch to the unsigned type (conversion is modular)
           " unsigned __int128 r = static_cast<unsigned __int128>(__int128(rhs));\n"
           " memcpy(&l, &lhs, sizeof(__int128));\n"
           " unsigned __int128 inres = l * r;\n"
           " i128 res;\n"
           " memcpy(&res, &inres, sizeof(__int128));\n"
           " return res;\n"
           "#endif\n"
           "}\n";
}
/// Special case: yields the source text of the specialization
/// `i128 imul(i64 lhs, i64 rhs)` (selected by generate_imul for
/// w_r == 2, w_a == 1, w_b == 1).
///
/// Under `_MSC_VER` the emitted code widens both operands to i128 and delegates;
/// otherwise it multiplies in `__int128` and memcpy's the product into an i128.
std::string imul_128_64_64()
{
    // one entry per emitted line; the newline is appended while joining
    static char const* const emitted_lines[] = {
        "template <>",
        "inline i128 imul(i64 lhs, i64 rhs)",
        "{",
        "#ifdef _MSC_VER",
        " return imul(i128(lhs), i128(rhs));",
        "#else",
        " __int128 l = lhs;",
        " __int128 r = rhs;",
        " __int128 inres = l * r;",
        " i128 res;",
        " memcpy(&res, &inres, sizeof(__int128));",
        " return res;",
        "#endif",
        "}",
    };

    std::string code;
    for (char const* text : emitted_lines)
    {
        code += text;
        code += '\n';
    }
    return code;
}
std::string generate_imul(int w_r, int w_a, int w_b)
{
// special cases
if (w_r == 2 && w_a == 1 && w_b == 1)
return imul_128_64_64();
if (w_r == 2 && w_a == 2 && w_b == 1)
return imul_128_128_64();
if (w_r == 2 && w_a == 1 && w_b == 2)
return imul_128_64_128();
auto const out_type = "i" + std::to_string(w_r * 64);
auto const a_type = "i" + std::to_string(w_a * 64);
auto const b_type = "i" + std::to_string(w_b * 64);
......@@ -172,19 +261,19 @@ std::string generate_imul(int w_r, int w_a, int w_b)
result += " }\n";
};
result += "template<>\n";
result += "template <>\n";
result += "inline " + out_type + " imul(" + a_type + " lhs, " + b_type + " rhs)\n{\n";
// result += "fixed_int<" + std::to_string(w_r) + "> imul(" + a_type + (w_a > 1 ? " const&" : "") + " lhs, " + b_type + (w_b > 1 ? " const&" : "") + " rhs)\n{\n";
result += " fixed_int<" + std::to_string(w_r) + "> res;\n";
result += " " + out_type + " res;\n";
// conditional inversion
if (w_a != w_b || w_a != w_r)
{
result += " u64 s_l = u64(i64(" + lhs_of(w_a - 1) + ") >> 63); // 0 iff > 0, -1 otherwise\n";
result += " u64 s_r = u64(i64(" + rhs_of(w_b - 1) + ") >> 63); // 0 iff > 0, -1 otherwise\n";
result += " u64 s_res = s_l ^ s_r;\n";
conditional_invert(lhs_of, "s_l", w_a);
conditional_invert(rhs_of, "s_r", w_b);
result += " u64 s_l = u64(i64(" + lhs_of(w_a - 1) + ") >> 63); // 0 iff > 0, -1 otherwise\n";
result += " u64 s_r = u64(i64(" + rhs_of(w_b - 1) + ") >> 63); // 0 iff > 0, -1 otherwise\n";
result += " u64 s_res = s_l ^ s_r;\n";
conditional_invert(lhs_of, "s_l", w_a);
conditional_invert(rhs_of, "s_r", w_b);
}
for (auto i = 0u; i < vars_l.size(); ++i)
......@@ -225,7 +314,7 @@ std::string generate_imul(int w_r, int w_a, int w_b)
result += ";\n";
if (w_a != w_b || w_a != w_r)
conditional_invert([](int i) { return "res.d[" + std::to_string(i) + "]"; }, "s_res", w_r);
conditional_invert([](int i) { return "res.d[" + std::to_string(i) + "]"; }, "s_res", w_r);
result += " return res;\n";
result += "}\n";
......@@ -250,16 +339,16 @@ void generate_mul_file()
file << "#include <typed-geometry/feature/fixed_int.hh>\n\n";
file << "namespace tg::detail\n{\n";
file << "template <int w_res, class T0, class T1>\n";
file << "fixed_uint<w_res> mul(T0 const& lhs, T1 const& rhs);\n\n";
for (auto r = 2; r <= 4; ++r)
for (auto j = 1; j <= r; ++j)
for (auto i = 1; i <= r; ++i)
if (i + j >= r)
{
file << generate_mul(r, i, j);
file << "\n";
}
{
file << generate_mul(r, i, j);
file << "\n";
}
file << "} // namespace tg::detail";
}
......@@ -276,25 +365,24 @@ void generate_imul_file()
file << "#include <intrin.h>\n";
file << "#else\n";
file << "#include <x86intrin.h>\n";
file << "#include <cstring>\n";
file << "#endif\n\n";
file << "#include <typed-geometry/feature/fixed_int.hh>\n\n";
file << "namespace tg::detail\n{\n";
file << "template <int w_res, class T0, class T1>\n";
file << "fixed_int<w_res> imul(T0 const& lhs, T1 const& rhs);\n\n";
for (auto r = 2; r <= 4; ++r)
for (auto j = 1; j <= r; ++j)
for (auto i = 1; i <= r; ++i)
if (i + j >= r)
{
file << generate_imul(r, i, j);
file << "\n";
}
{
file << generate_imul(r, i, j);
file << "\n";
}
file << "} // namespace tg::detail";
}
} // namespace
TEST_CASE("tg generate multiplications")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment