27 #ifndef MPBLOCKS_CUDANN_KERNELS_SO3_CU_HPP_
28 #define MPBLOCKS_CUDANN_KERNELS_SO3_CU_HPP_
38 namespace linalg = cuda::linalg2;
40 template<
typename Scalar >
47 dq = fmaxf(-1,fminf(dq,1));
51 template<
typename Scalar >
58 Scalar arg = 2*dot*dot - 1;
59 arg = fmaxf(-0.999999999999999999999999999f,
60 fminf(arg, 0.9999999999999999999999999f));
64 template<
typename Scalar,
bool Pseudo >
76 template<
typename Scalar>
89 template<
bool Pseudo,
typename Scalar,
unsigned int NDim>
95 unsigned int pitchOut,
99 using namespace linalg;
106 int idx = blockId * N + threadId;
121 set<0>(q0) = q.
data[0];
122 set<1>(q0) = q.
data[1];
123 set<2>(q0) = q.
data[2];
124 set<3>(q0) = q.
data[3];
126 set<0>(q1) = g_in[0*pitchIn + idx];
128 set<1>(q1) = g_in[1*pitchIn + idx];
130 set<2>(q1) = g_in[2*pitchIn + idx];
132 set<3>(q1) = g_in[3*pitchIn + idx];
138 g_out[0*pitchOut + idx] = d;
140 g_out[1*pitchOut + idx] = idx;
153 template<
bool Pseudo,
typename Scalar,
unsigned int NDim>
159 using namespace linalg;
167 int constraintIdx = idx % 81;
168 int sign = idx > 80 ? -1 : 1;
179 med[idx] = Scalar(0.5)*( query.
min[idx] + query.
max[idx] );
181 if( childIdx & (0x01 << idx ) )
184 max[idx] = query.
max[idx];
188 min[idx] = query.
min[idx];
205 set<0>(q) = query.
point[0];
206 set<1>(q) = query.
point[1];
207 set<2>(q) = query.
point[2];
208 set<3>(q) = query.
point[3];
211 int cnst = constraintIdx;;
212 char spec_0 = cnst % 3; cnst /= 3;
213 char spec_1 = cnst % 3; cnst /= 3;
214 char spec_2 = cnst % 3; cnst /= 3;
215 char spec_3 = cnst % 3; cnst /= 3;
222 den -= max[0]*max[0];
223 else if( spec_0 ==
MIN )
224 den -= min[0]*min[0];
226 num += get<0>(q)*get<0>(q);
229 den -= max[1]*max[1];
230 else if( spec_1 ==
MIN )
231 den -= min[1]*min[1];
233 num += get<1>(q)*get<1>(q);
236 den -= max[2]*max[2];
237 else if( spec_2 ==
MIN )
238 den -= min[2]*min[2];
240 num += get<2>(q)*get<2>(q);
243 den -= max[3]*max[3];
244 else if( spec_3 ==
MIN )
245 den -= min[3]*min[3];
247 num += get<3>(q)*get<3>(q);
249 bool feasible =
true;
254 Scalar lambda2 = num / (4*den);
255 Scalar lambda = std::sqrt(lambda2);
259 lambda = -Scalar(1.0)/(Scalar(2.0)*sign);
267 else if( spec_0 ==
MIN )
270 set<0>(
x) = -get<0>(q) / (2*sign*lambda);
272 if( get<0>(x) > max[0] || get<0>(
x) < min[0] )
277 else if( spec_1 ==
MIN )
280 set<1>(
x) = -get<1>(q) / (2*sign*lambda);
282 if( get<1>(x) > max[1] || get<1>(
x) < min[1] )
287 else if( spec_2 ==
MIN )
290 set<2>(
x) = -get<2>(q) / (2*sign*lambda);
292 if( get<2>(x) > max[2] || get<2>(
x) < min[2] )
297 else if( spec_3 ==
MIN )
300 set<3>(
x) = -get<3>(q) / (2*sign*lambda);
302 if( get<3>(x) > max[3] || get<3>(
x) < min[3] )
319 Scalar dist_i = dist[idx];
320 Scalar dist_j = 1000;
322 for(
int j=1; j < 6; j++)
324 Scalar dist_j = dist[idx + j*32];
325 if( dist_j < dist_i )
333 for(
int j=16; j > 0; j /= 2 )
335 dist_j = dist[idx + j];
336 if( dist_j < dist_i )
342 g_out[childIdx] = dist_i;
__device__ Scalar so3_pseudo_distance(const linalg::Matrix< Scalar, 4, 1 > &q0, const linalg::Matrix< Scalar, 4, 1 > &q1)
static __device__ Scalar compute(const linalg::Matrix< Scalar, 4, 1 > &q0, const linalg::Matrix< Scalar, 4, 1 > &q1)
static __device__ Scalar compute(const linalg::Matrix< Scalar, 4, 1 > &q0, const linalg::Matrix< Scalar, 4, 1 > &q1)
__device__ __host__ Scalar dot(const RValue< Scalar, ROWS, 1, ExpA > &A, const RValue< Scalar, ROWS, 1, ExpB > &B)
compute the DOT
__global__ void so3_distance(QueryPoint< Scalar, NDim > query, Scalar *g_in, unsigned int pitchIn, Scalar *g_out, unsigned int pitchOut, unsigned int n)