Krotos Modules 3
Loading...
Searching...
No Matches
space_l2.h
Go to the documentation of this file.
1#pragma once
2#include "hnswlib.h"
3
4namespace hnswlib {
5
// Scalar squared Euclidean distance: sum over k < qty of (a[k] - b[k])^2,
// where qty is the size_t pointed to by qty_ptr. Terms are accumulated in
// index order into a float. Used directly and as the tail pass of the SIMD
// residual kernels.
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *lhs = (const float *) pVect1v;
    const float *rhs = (const float *) pVect2v;
    const size_t count = *((const size_t *) qty_ptr);

    float acc = 0;
    for (size_t k = 0; k < count; ++k) {
        const float d = lhs[k] - rhs[k];
        acc += d * d;
    }
    return acc;
}
21
#if defined(USE_AVX512)

// AVX-512 kernel (preferred when available): squared L2 distance over the
// first (qty / 16) * 16 floats, one 16-wide vector per step. The remaining
// qty % 16 elements are ignored, so callers with arbitrary qty go through
// L2SqrSIMD16ExtResiduals. Lane partial sums are combined left to right.
static float
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *a = (const float *) pVect1v;
    const float *b = (const float *) pVect2v;
    const size_t qty = *((const size_t *) qty_ptr);
    const size_t blocks = qty / 16;

    __m512 acc = _mm512_set1_ps(0);
    for (size_t blk = 0; blk < blocks; ++blk) {
        const __m512 d = _mm512_sub_ps(_mm512_loadu_ps(a + 16 * blk),
                                       _mm512_loadu_ps(b + 16 * blk));
        acc = _mm512_add_ps(acc, _mm512_mul_ps(d, d));
    }

    // _mm512_store_ps requires a 64-byte aligned destination.
    float PORTABLE_ALIGN64 lanes[16];
    _mm512_store_ps(lanes, acc);
    float res = lanes[0];
    for (int i = 1; i < 16; ++i) {
        res += lanes[i];
    }
    return res;
}
#endif
56
#if defined(USE_AVX)

// AVX kernel: squared L2 distance over the first (qty / 16) * 16 floats,
// processed as consecutive 8-wide vectors (two per 16-float chunk) in input
// order. The remaining qty % 16 elements are ignored; see
// L2SqrSIMD16ExtResiduals for arbitrary qty.
static float
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *a = (const float *) pVect1v;
    const float *b = (const float *) pVect2v;
    const size_t qty = *((const size_t *) qty_ptr);
    const size_t end = (qty / 16) * 16;

    __m256 acc = _mm256_set1_ps(0);
    for (size_t i = 0; i < end; i += 8) {
        const __m256 d = _mm256_sub_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(d, d));
    }

    // _mm256_store_ps requires a 32-byte aligned destination.
    float PORTABLE_ALIGN32 lanes[8];
    _mm256_store_ps(lanes, acc);
    float res = lanes[0];
    for (int k = 1; k < 8; ++k) {
        res += lanes[k];
    }
    return res;
}

#endif
94
#if defined(USE_SSE)

// SSE kernel: squared L2 distance over the first (qty / 16) * 16 floats,
// processed as consecutive 4-wide vectors (four per 16-float chunk) in input
// order. The remaining qty % 16 elements are ignored; see
// L2SqrSIMD16ExtResiduals for arbitrary qty.
static float
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *a = (const float *) pVect1v;
    const float *b = (const float *) pVect2v;
    const size_t qty = *((const size_t *) qty_ptr);
    const size_t end = (qty / 16) * 16;

    __m128 acc = _mm_set1_ps(0);
    for (size_t i = 0; i < end; i += 4) {
        const __m128 d = _mm_sub_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i));
        acc = _mm_add_ps(acc, _mm_mul_ps(d, d));
    }

    // _mm_store_ps requires (at least) 16-byte alignment.
    float PORTABLE_ALIGN32 lanes[4];
    _mm_store_ps(lanes, acc);
    float res = lanes[0];
    for (int k = 1; k < 4; ++k) {
        res += lanes[k];
    }
    return res;
}
#endif
145
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
// Kernel used for 16-float SIMD blocks. It starts out pointing at the SSE
// kernel. The L2Space constructor re-points it at the AVX / AVX-512 kernel
// when those were compiled in and its capability check passes.
//
// `static` (internal linkage) is required: this is a variable *definition* in a
// header. Without `static`, every translation unit that includes the header
// emits an external definition of hnswlib::L2SqrSIMD16Ext, and the program
// fails to link ("multiple definition").
// NOTE(review): the initializer names L2SqrSIMD16ExtSSE, which exists only under
// USE_SSE; this assumes USE_AVX/USE_AVX512 are never defined without USE_SSE.
static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;

// Squared L2 distance for any qty. The SIMD kernel handles the longest prefix
// whose length is a multiple of 16; the scalar L2Sqr handles the 0-15 leftover
// elements. The two partial sums are added at the end.
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t qty = *((const size_t *) qty_ptr);
    size_t qty16 = qty >> 4 << 4;   // round down to a multiple of 16
    const float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);

    const float *pVect1 = (const float *) pVect1v + qty16;
    const float *pVect2 = (const float *) pVect2v + qty16;
    size_t qty_left = qty - qty16;
    const float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
    return res + res_tail;
}
#endif
162
163
#if defined(USE_SSE)
// SSE kernel: squared L2 distance over the first (qty / 4) * 4 floats, one
// 4-wide vector per step. Up to 3 trailing elements are ignored; see
// L2SqrSIMD4ExtResiduals for arbitrary qty.
static float
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *a = (const float *) pVect1v;
    const float *b = (const float *) pVect2v;
    const size_t qty = *((const size_t *) qty_ptr);
    const size_t end = (qty / 4) * 4;

    __m128 acc = _mm_set1_ps(0);
    for (size_t i = 0; i < end; i += 4) {
        const __m128 d = _mm_sub_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i));
        acc = _mm_add_ps(acc, _mm_mul_ps(d, d));
    }

    // _mm_store_ps requires (at least) 16-byte alignment.
    float PORTABLE_ALIGN32 lanes[4];
    _mm_store_ps(lanes, acc);
    float res = lanes[0];
    for (int k = 1; k < 4; ++k) {
        res += lanes[k];
    }
    return res;
}

// Squared L2 distance for any qty. SIMD covers the multiple-of-4 prefix and
// the scalar L2Sqr covers the 0-3 remaining elements.
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const size_t qty = *((const size_t *) qty_ptr);
    size_t head = (qty / 4) * 4;
    size_t tail = qty - head;

    const float head_sum = L2SqrSIMD4Ext(pVect1v, pVect2v, &head);
    const float tail_sum = L2Sqr((const float *) pVect1v + head,
                                 (const float *) pVect2v + head, &tail);
    return head_sum + tail_sum;
}
#endif
207
// Squared-Euclidean (L2) space over float vectors of a fixed dimension.
// The constructor chooses the distance kernel stored in `fstdistfunc_`. The
// choice depends on which SIMD paths were compiled in and on `dim`. The
// constructor also records the byte size of one vector.
// NOTE(review): this listing omits original source lines 210-211 (members
// `fstdistfunc_` and `data_size_`, per the index), 215, 244-248
// (`get_dist_func()` and the header of `get_dist_func_param()`) and 252
// (`~L2Space()`). The comments below describe only the visible lines.
208 class L2Space : public SpaceInterface<float> {
209
212 size_t dim_;
213 public:
// Builds the space for `dim`-dimensional float vectors.
// NOTE(review): source line 215 is not shown. Presumably it installs the
// scalar L2Sqr as the default kernel, which the SIMD branch below may
// override (needed for dims 1-3 and for non-SIMD builds) — confirm.
214 L2Space(size_t dim) {
// Upgrade the shared 16-float kernel pointer (initially the SSE kernel) to the
// AVX-512 or AVX kernel when compiled in and the capability check succeeds.
// AVX512Capable()/AVXCapable() are presumably runtime CPU-feature probes
// defined in hnswlib.h — confirm.
// NOTE(review): this writes a namespace-scope pointer from a constructor, so
// constructing spaces from several threads at once races on it.
216 #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
217 #if defined(USE_AVX512)
218 if (AVX512Capable())
219 L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
220 else if (AVXCapable())
221 L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
222 #elif defined(USE_AVX)
223 if (AVXCapable())
224 L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
225 #endif
226
// Pick the kernel by dimension:
//   dim % 16 == 0       -> pure 16-wide SIMD kernel
//   dim % 4 == 0        -> pure 4-wide SSE kernel
//   dim > 16 (other)    -> 16-wide SIMD prefix + scalar tail
//   4 < dim < 16 (other)-> 4-wide SIMD prefix + scalar tail
// dims 1-3 match none of these and keep whatever kernel was set earlier.
// The kernel pointer's value is copied here, so later changes to
// L2SqrSIMD16Ext do not affect this space when dim % 16 == 0.
227 if (dim % 16 == 0)
228 fstdistfunc_ = L2SqrSIMD16Ext;
229 else if (dim % 4 == 0)
230 fstdistfunc_ = L2SqrSIMD4Ext;
231 else if (dim > 16)
232 fstdistfunc_ = L2SqrSIMD16ExtResiduals;
233 else if (dim > 4)
234 fstdistfunc_ = L2SqrSIMD4ExtResiduals;
235 #endif
236 dim_ = dim;
// One stored vector is `dim` floats.
237 data_size_ = dim * sizeof(float);
238 }
239
// Returns the byte size of one vector (dim * sizeof(float)).
240 size_t get_data_size() {
241 return data_size_;
242 }
243
247
// NOTE(review): the function header for this body (source line 248,
// `void * get_dist_func_param()` per the index) is missing from the listing.
// It returns the address of dim_. Presumably callers pass this to the distance
// kernels as their `qty_ptr` argument — confirm.
249 return &dim_;
250 }
251
253 };
254
// Integer squared L2 distance over byte vectors, restricted to whole groups of
// four bytes: only the first (qty / 4) * 4 components are visited, and up to 3
// trailing bytes are ignored. Intended for dimensions that are multiples of 4.
// Each term is at most 255^2 = 65025 and is accumulated in an int, so
// extremely long vectors (more than ~33,000 bytes) could overflow.
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
    const unsigned char *lhs = (const unsigned char *) pVect1;
    const unsigned char *rhs = (const unsigned char *) pVect2;
    const size_t limit = (*((const size_t *) qty_ptr) / 4) * 4;

    int res = 0;
    for (size_t k = 0; k < limit; ++k) {
        const int d = (int) lhs[k] - (int) rhs[k];
        res += d * d;
    }
    return res;
}
281
// Integer squared L2 distance over byte vectors of length qty (read from
// qty_ptr). Each component difference is computed in int, squared, and summed.
// Handles any length, including ones that are not multiples of 4.
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
    size_t remaining = *((const size_t *) qty_ptr);
    const unsigned char *p = (const unsigned char *) pVect1;
    const unsigned char *q = (const unsigned char *) pVect2;

    int total = 0;
    while (remaining > 0) {
        const int diff = (int) *p - (int) *q;
        total += diff * diff;
        ++p;
        ++q;
        --remaining;
    }
    return total;
}
296
// L2 space over byte vectors (one `unsigned char` per component) with integer
// distances.
// NOTE(review): this listing omits original source lines 299-300 (members
// `fstdistfunc_` and `data_size_`, per the index), 305 and 308 (the if/else
// bodies), 318-322 (`get_dist_func()` and the header of
// `get_dist_func_param()`) and 326 (`~L2SpaceI()`). The comments below
// describe only the visible lines.
297 class L2SpaceI : public SpaceInterface<int> {
298
301 size_t dim_;
302 public:
// Builds the space for `dim`-byte vectors. The kernel choice depends on
// whether dim is a multiple of 4.
303 L2SpaceI(size_t dim) {
304 if(dim % 4 == 0) {
// NOTE(review): body (source line 305) not shown. Presumably it selects the
// group-of-4 kernel L2SqrI4x, which ignores trailing bytes — confirm.
306 }
307 else {
// NOTE(review): body (source line 308) not shown. Presumably it selects the
// general byte kernel L2SqrI — confirm.
309 }
310 dim_ = dim;
// One byte per component.
311 data_size_ = dim * sizeof(unsigned char);
312 }
313
// Returns the byte size of one vector (dim * sizeof(unsigned char)).
314 size_t get_data_size() {
315 return data_size_;
316 }
317
321
// NOTE(review): the function header for this body (source line 322,
// `void * get_dist_func_param()` per the index) is missing from the listing.
// It returns the address of dim_.
323 return &dim_;
324 }
325
327 };
328
329
330}
class L2Space
Definition space_l2.h:208
DISTFUNC< float > get_dist_func()
Definition space_l2.h:244
size_t dim_
Definition space_l2.h:212
size_t get_data_size()
Definition space_l2.h:240
L2Space(size_t dim)
Definition space_l2.h:214
DISTFUNC< float > fstdistfunc_
Definition space_l2.h:210
size_t data_size_
Definition space_l2.h:211
void * get_dist_func_param()
Definition space_l2.h:248
~L2Space()
Definition space_l2.h:252
class L2SpaceI
Definition space_l2.h:297
DISTFUNC< int > get_dist_func()
Definition space_l2.h:318
~L2SpaceI()
Definition space_l2.h:326
DISTFUNC< int > fstdistfunc_
Definition space_l2.h:299
size_t dim_
Definition space_l2.h:301
L2SpaceI(size_t dim)
Definition space_l2.h:303
void * get_dist_func_param()
Definition space_l2.h:322
size_t data_size_
Definition space_l2.h:300
size_t get_data_size()
Definition space_l2.h:314
Definition hnswlib.h:142
Definition bruteforce.h:7
static float L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr)
Definition space_l2.h:7
static int L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr)
Definition space_l2.h:282
static int L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr)
Definition space_l2.h:256
MTYPE(*)(const void *, const void *, const void *) DISTFUNC
Definition hnswlib.h:138