Krotos Modules 3
Loading...
Searching...
No Matches
space_ip.h
Go to the documentation of this file.
1#pragma once
2#include "hnswlib.h"
3
4namespace hnswlib {
5
// Scalar dot product of two float vectors.
//
// pVect1, pVect2: pointers to the vectors; both are read as arrays of float.
// qty_ptr:        pointer to a size_t holding the number of elements to read.
// Returns the sum of pVect1[i] * pVect2[i] for i in [0, *qty_ptr), accumulated
// left to right in single precision; 0 when the count is 0.
static float
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
    const float *lhs = (const float *) pVect1;
    const float *rhs = (const float *) pVect2;
    const size_t qty = *((const size_t *) qty_ptr);

    float res = 0;
    // The counter has the same type as the bound: an `unsigned` counter may be
    // narrower than size_t and would then wrap around before reaching a count
    // above UINT_MAX, so the loop would never terminate.
    for (size_t i = 0; i < qty; i++) {
        res += lhs[i] * rhs[i];
    }
    return res;
}
16
17 static float
18 InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
19 return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
20 }
21
22#if defined(USE_AVX)
23
24// Favor using AVX if available.
25 static float
26 InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
27 float PORTABLE_ALIGN32 TmpRes[8];
28 float *pVect1 = (float *) pVect1v;
29 float *pVect2 = (float *) pVect2v;
30 size_t qty = *((size_t *) qty_ptr);
31
32 size_t qty16 = qty / 16;
33 size_t qty4 = qty / 4;
34
35 const float *pEnd1 = pVect1 + 16 * qty16;
36 const float *pEnd2 = pVect1 + 4 * qty4;
37
38 __m256 sum256 = _mm256_set1_ps(0);
39
40 while (pVect1 < pEnd1) {
41 //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
42
43 __m256 v1 = _mm256_loadu_ps(pVect1);
44 pVect1 += 8;
45 __m256 v2 = _mm256_loadu_ps(pVect2);
46 pVect2 += 8;
47 sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
48
49 v1 = _mm256_loadu_ps(pVect1);
50 pVect1 += 8;
51 v2 = _mm256_loadu_ps(pVect2);
52 pVect2 += 8;
53 sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
54 }
55
56 __m128 v1, v2;
57 __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
58
59 while (pVect1 < pEnd2) {
60 v1 = _mm_loadu_ps(pVect1);
61 pVect1 += 4;
62 v2 = _mm_loadu_ps(pVect2);
63 pVect2 += 4;
64 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
65 }
66
67 _mm_store_ps(TmpRes, sum_prod);
68 float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];;
69 return sum;
70 }
71
72 static float
73 InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
74 return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
75 }
76
77#endif
78
79#if defined(USE_SSE)
80
81 static float
82 InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
83 float PORTABLE_ALIGN32 TmpRes[8];
84 float *pVect1 = (float *) pVect1v;
85 float *pVect2 = (float *) pVect2v;
86 size_t qty = *((size_t *) qty_ptr);
87
88 size_t qty16 = qty / 16;
89 size_t qty4 = qty / 4;
90
91 const float *pEnd1 = pVect1 + 16 * qty16;
92 const float *pEnd2 = pVect1 + 4 * qty4;
93
94 __m128 v1, v2;
95 __m128 sum_prod = _mm_set1_ps(0);
96
97 while (pVect1 < pEnd1) {
98 v1 = _mm_loadu_ps(pVect1);
99 pVect1 += 4;
100 v2 = _mm_loadu_ps(pVect2);
101 pVect2 += 4;
102 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
103
104 v1 = _mm_loadu_ps(pVect1);
105 pVect1 += 4;
106 v2 = _mm_loadu_ps(pVect2);
107 pVect2 += 4;
108 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
109
110 v1 = _mm_loadu_ps(pVect1);
111 pVect1 += 4;
112 v2 = _mm_loadu_ps(pVect2);
113 pVect2 += 4;
114 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
115
116 v1 = _mm_loadu_ps(pVect1);
117 pVect1 += 4;
118 v2 = _mm_loadu_ps(pVect2);
119 pVect2 += 4;
120 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
121 }
122
123 while (pVect1 < pEnd2) {
124 v1 = _mm_loadu_ps(pVect1);
125 pVect1 += 4;
126 v2 = _mm_loadu_ps(pVect2);
127 pVect2 += 4;
128 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
129 }
130
131 _mm_store_ps(TmpRes, sum_prod);
132 float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
133
134 return sum;
135 }
136
137 static float
138 InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
139 return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
140 }
141
142#endif
143
144
145#if defined(USE_AVX512)
146
147 static float
148 InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
149 float PORTABLE_ALIGN64 TmpRes[16];
150 float *pVect1 = (float *) pVect1v;
151 float *pVect2 = (float *) pVect2v;
152 size_t qty = *((size_t *) qty_ptr);
153
154 size_t qty16 = qty / 16;
155
156
157 const float *pEnd1 = pVect1 + 16 * qty16;
158
159 __m512 sum512 = _mm512_set1_ps(0);
160
161 while (pVect1 < pEnd1) {
162 //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
163
164 __m512 v1 = _mm512_loadu_ps(pVect1);
165 pVect1 += 16;
166 __m512 v2 = _mm512_loadu_ps(pVect2);
167 pVect2 += 16;
168 sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
169 }
170
171 _mm512_store_ps(TmpRes, sum512);
172 float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
173
174 return sum;
175 }
176
177 static float
178 InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
179 return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
180 }
181
182#endif
183
184#if defined(USE_AVX)
185
186 static float
187 InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
188 float PORTABLE_ALIGN32 TmpRes[8];
189 float *pVect1 = (float *) pVect1v;
190 float *pVect2 = (float *) pVect2v;
191 size_t qty = *((size_t *) qty_ptr);
192
193 size_t qty16 = qty / 16;
194
195
196 const float *pEnd1 = pVect1 + 16 * qty16;
197
198 __m256 sum256 = _mm256_set1_ps(0);
199
200 while (pVect1 < pEnd1) {
201 //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
202
203 __m256 v1 = _mm256_loadu_ps(pVect1);
204 pVect1 += 8;
205 __m256 v2 = _mm256_loadu_ps(pVect2);
206 pVect2 += 8;
207 sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
208
209 v1 = _mm256_loadu_ps(pVect1);
210 pVect1 += 8;
211 v2 = _mm256_loadu_ps(pVect2);
212 pVect2 += 8;
213 sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
214 }
215
216 _mm256_store_ps(TmpRes, sum256);
217 float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
218
219 return sum;
220 }
221
222 static float
223 InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
224 return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
225 }
226
227#endif
228
229#if defined(USE_SSE)
230
231 static float
232 InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
233 float PORTABLE_ALIGN32 TmpRes[8];
234 float *pVect1 = (float *) pVect1v;
235 float *pVect2 = (float *) pVect2v;
236 size_t qty = *((size_t *) qty_ptr);
237
238 size_t qty16 = qty / 16;
239
240 const float *pEnd1 = pVect1 + 16 * qty16;
241
242 __m128 v1, v2;
243 __m128 sum_prod = _mm_set1_ps(0);
244
245 while (pVect1 < pEnd1) {
246 v1 = _mm_loadu_ps(pVect1);
247 pVect1 += 4;
248 v2 = _mm_loadu_ps(pVect2);
249 pVect2 += 4;
250 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
251
252 v1 = _mm_loadu_ps(pVect1);
253 pVect1 += 4;
254 v2 = _mm_loadu_ps(pVect2);
255 pVect2 += 4;
256 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
257
258 v1 = _mm_loadu_ps(pVect1);
259 pVect1 += 4;
260 v2 = _mm_loadu_ps(pVect2);
261 pVect2 += 4;
262 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
263
264 v1 = _mm_loadu_ps(pVect1);
265 pVect1 += 4;
266 v2 = _mm_loadu_ps(pVect2);
267 pVect2 += 4;
268 sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
269 }
270 _mm_store_ps(TmpRes, sum_prod);
271 float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
272
273 return sum;
274 }
275
276 static float
277 InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
278 return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
279 }
280
281#endif
282
283#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
284 DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
285 DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
286 DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
287 DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
288
289 static float
290 InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
291 size_t qty = *((size_t *) qty_ptr);
292 size_t qty16 = qty >> 4 << 4;
293 float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
294 float *pVect1 = (float *) pVect1v + qty16;
295 float *pVect2 = (float *) pVect2v + qty16;
296
297 size_t qty_left = qty - qty16;
298 float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
299 return 1.0f - (res + res_tail);
300 }
301
302 static float
303 InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
304 size_t qty = *((size_t *) qty_ptr);
305 size_t qty4 = qty >> 2 << 2;
306
307 float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
308 size_t qty_left = qty - qty4;
309
310 float *pVect1 = (float *) pVect1v + qty4;
311 float *pVect2 = (float *) pVect2v + qty4;
312 float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
313
314 return 1.0f - (res + res_tail);
315 }
316#endif
317
318 class InnerProductSpace : public SpaceInterface<float> {
319
322 size_t dim_;
323 public:
324 InnerProductSpace(size_t dim) {
326 #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
327 #if defined(USE_AVX512)
328 if (AVX512Capable()) {
329 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
330 InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
331 } else if (AVXCapable()) {
332 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
333 InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
334 }
335 #elif defined(USE_AVX)
336 if (AVXCapable()) {
337 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
338 InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
339 }
340 #endif
341 #if defined(USE_AVX)
342 if (AVXCapable()) {
343 InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
344 InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
345 }
346 #endif
347
348 if (dim % 16 == 0)
349 fstdistfunc_ = InnerProductDistanceSIMD16Ext;
350 else if (dim % 4 == 0)
351 fstdistfunc_ = InnerProductDistanceSIMD4Ext;
352 else if (dim > 16)
353 fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
354 else if (dim > 4)
355 fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
356 #endif
357 dim_ = dim;
358 data_size_ = dim * sizeof(float);
359 }
360
361 size_t get_data_size() {
362 return data_size_;
363 }
364
368
370 return &dim_;
371 }
372
374 };
375
376}
Definition space_ip.h:318
DISTFUNC< float > fstdistfunc_
Definition space_ip.h:320
size_t dim_
Definition space_ip.h:322
size_t get_data_size()
Definition space_ip.h:361
void * get_dist_func_param()
Definition space_ip.h:369
InnerProductSpace(size_t dim)
Definition space_ip.h:324
~InnerProductSpace()
Definition space_ip.h:373
size_t data_size_
Definition space_ip.h:321
DISTFUNC< float > get_dist_func()
Definition space_ip.h:365
Definition hnswlib.h:142
Definition bruteforce.h:7
MTYPE(*)(const void *, const void *, const void *) DISTFUNC
Definition hnswlib.h:138
static float InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr)
Definition space_ip.h:7
static float InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr)
Definition space_ip.h:18