Krotos Modules 3
Loading...
Searching...
No Matches
AudioDataset.cpp
Go to the documentation of this file.
1#include <chrono>
2
3namespace krotos
4{
5
6AudioDataset::AudioDataset() : juce::Thread("AudioDatasetThread")
7{
8 // sort here as using std::binary_search later
9 std::sort(CatIDSet.begin(), CatIDSet.end());
10
11 startThread(); // set a priority?
12}
13
14AudioDataset::~AudioDataset() { stopThread(4000); }
15
16std::vector<File> AudioDataset::findValidFiles() const
17{
18 std::vector<File> paths;
19
20 AudioFormatManager formatManager;
21 formatManager.registerBasicFormats();
22
23 const auto root = getFactoryAssetDirectory();
24 for (const auto& path : File(root).findChildFiles(2, true, "*.wav", juce::File::FollowSymlinks::no))
25 {
26 // extract UCS CatID from filename
27 const auto filename = path.getFileName().toStdString();
28 const auto key = extractCatID(filename);
29
30 // check if we have a valid CatID tag
31 const auto success = std::binary_search(CatIDSet.begin(), CatIDSet.end(), key);
32 if (success)
33 {
34 paths.push_back(path);
35 }
36 }
37 return paths;
38}
39
40AudioBuffer<float> AudioDataset::resampleAudioBuffer(const AudioBuffer<float>& buffer, double sampleRate,
41 double targetRate)
42{
43 if ((int)sampleRate == (int)targetRate)
44 return buffer;
45
46 const double ratio = sampleRate / targetRate;
47 AudioBuffer<float> resampledBuffer;
48 resampledBuffer.setSize(1, (int)(buffer.getNumSamples() / ratio));
49 auto inputData = buffer.getReadPointer(0);
50 auto outputData = resampledBuffer.getWritePointer(0);
51 LagrangeInterpolator resampler;
52 resampler.process(ratio, inputData, outputData, resampledBuffer.getNumSamples());
53 return resampledBuffer;
54}
55
57{
58 if (runAnalysis())
59 {
60 const auto tick = std::chrono::steady_clock::now();
61
62 AudioFormatManager formatManager;
63 formatManager.registerBasicFormats();
64 AudioEmbedding model;
65
66 const auto paths = findValidFiles();
67 const auto file_count = paths.size();
68
69 for (const auto& path : paths)
70 {
71 // extract UCS CatID from filename
72 const auto filename = path.getFileName().toStdString();
73 const auto key = extractCatID(filename);
74
75 std::unique_ptr<juce::AudioFormatReader> reader(formatManager.createReaderFor(path));
76 if (reader.get() != nullptr)
77 {
78 // load up to 3 seconds of audio
79 const auto numSamples =
80 std::min(static_cast<int>(reader->lengthInSamples), static_cast<int>(3 * reader->sampleRate));
81 const auto duration = static_cast<float>(reader->lengthInSamples) / reader->sampleRate;
82 int64 readerStartSample = 0;
83
84 // To-Do: investigate suitable duration threshold
85 if (duration > 10.f)
86 {
87 // analyse the start of short Foley files,
88 // but shift to centre of long files to avoid fade in
89 readerStartSample = static_cast<int64>(reader->lengthInSamples / 2) - (numSamples / 2);
90 }
91
92 // opportunity to exit before audio analysis
93 if (threadShouldExit())
94 break;
95
96 // read audio segment into buffer
97 juce::AudioBuffer<float> buffer(1, numSamples);
98 reader->read(&buffer, 0, numSamples, readerStartSample, true, false);
99
100 // resample to target samplerate of 48 kHz
101 buffer = resampleAudioBuffer(buffer, reader->sampleRate, 48000.0);
102
103 // extract the audio embedding
104 const auto embedding = model.forward(buffer);
105
106 // store the path and embedding together
107 const auto data = std::make_pair(path.getFullPathName().toStdString(), embedding);
108 m_dataset[key].push_back(data);
109 }
110
111 // another chance to exit
112 if (threadShouldExit())
113 break;
114 }
115
116 const auto tock = std::chrono::steady_clock::now();
117 const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(tock - tick).count();
118 DBG("analysis time: " << elapsed << " ms");
119 DBG("file count: " << file_count);
120
121 const auto success = writeCache();
122 if (success)
123 {
124 const auto outfile = getDatasetFile().getFullPathName().toStdString();
125 DBG("analysis cached to: " << outfile);
126 }
127 else
128 {
129 DBG("caching analysis failed");
130 }
131 }
132}
133
135{
136 const auto loaded_paths = getPathsSet();
137
138 const auto paths = findValidFiles();
139 std::set<std::string> available_paths;
140 for (const auto& path : paths)
141 available_paths.insert(path.getFullPathName().toStdString());
142
143 return (available_paths != loaded_paths);
144}
145
147{
148 if (cacheExists())
149 {
150 DBG("loading previous analysis from cache");
151 bool success = readCache();
152 if (success)
153 {
154 // check if factory assets changed
155 if (assetsChanged())
156 {
157 // factory assets changed since we last cached analysis
158 // rerun the analysis
159 m_dataset.clear();
160 success = deleteCache();
161 if (success)
162 {
163 DBG("rerunning analysis...");
164 return true;
165 }
166 else
167 {
168 // uh-oh
169 return false;
170 }
171 }
172 }
173 else
174 {
175 // cache not loaded
176 // rerun the analysis
177 m_dataset.clear();
178 success = deleteCache();
179 if (success)
180 {
181 DBG("rerunning analysis...");
182 return true;
183 }
184 else
185 {
186 // uh-oh
187 return false;
188 }
189 }
190 }
191 else
192 {
193 DBG("running analysis...");
194 return true;
195 }
196
197 return false;
198}
199
200std::set<std::string> AudioDataset::getPathsSet() const
201{
202 std::set<std::string> paths;
203 for (const auto& key : m_dataset)
204 {
205 for (const auto& element : key.second)
206 {
207 const auto path = element.first;
208 paths.insert(path);
209 }
210 }
211 return paths;
212}
213
214std::string AudioDataset::fixKeyErrors(std::string key) const
215{
216 // workarounds for some of the factory assets
217 if (key == "FOLEYFeet")
218 {
219 key = "FOLYFeet";
220 }
221 else if (key.find("CMPTKey") != std::string::npos)
222 {
223 key = "CMPTKey";
224 }
225 return key;
226}
227
229{
230 const auto file = getDatasetFile();
231
232 DynamicObject* dataset = new DynamicObject();
233 Array<var> rows;
234
235 for (const auto& category : m_dataset)
236 {
237 const auto key = category.first;
238 const auto results = category.second;
239
240 for (const auto& element : results)
241 {
242 DynamicObject* metadata = new DynamicObject();
243 metadata->setProperty("CatID", var(key));
244
245 const auto path = element.first;
246 const auto embedding = element.second;
247 metadata->setProperty("Path", var(path));
248
249 Array<var> features;
250 for (auto value : embedding)
251 features.add(value);
252 metadata->setProperty("Embedding", features);
253
254 rows.add(metadata);
255 metadata = nullptr;
256 }
257 }
258 dataset->setProperty("Version", var(m_version));
259 dataset->setProperty("Dataset", rows);
260
261 FileOutputStream stream(file);
262 if (stream.openedOk())
263 {
264 // overwrite an existing file
265 stream.setPosition(0);
266 stream.truncate();
267 JSON::writeToStream(stream, dataset);
268 return true;
269 }
270
271 return false;
272}
273
275{
276 const auto file = getDatasetFile();
277 assert(file.existsAsFile());
278
279 m_dataset.clear();
280
281 auto json = JSON::parse(file);
282 auto version = json.getProperty(Identifier("Version"), 0).toString().toStdString();
283 if (version != m_version)
284 {
285 // treat read cache as fail to trigger running updated analysis
286 return 0;
287 }
288
289 var result = json.getProperty(Identifier("Dataset"), 0);
290 for (int i = 0; i < result.size(); ++i)
291 {
292 // Get the data from the cache file
293 auto CatID = result[i].getProperty(Identifier("CatID"), 0).toString().toStdString();
294 auto Path = result[i].getProperty(Identifier("Path"), 0).toString().toStdString();
295 var array = result[i].getProperty(Identifier("Embedding"), 0);
296 std::vector<float> embedding;
297 for (int j = 0; j < array.size(); ++j)
298 embedding.push_back(array[j]);
299
300 auto metadata = std::make_pair(Path, embedding);
301 m_dataset[CatID].push_back(metadata);
302 }
303
304 return 1;
305}
306
307bool AudioDataset::cacheExists() { return getDatasetFile().existsAsFile(); }
308
309bool AudioDataset::deleteCache() { return getDatasetFile().deleteFile(); }
310
311std::string AudioDataset::extractCatID(std::string filename) const
312{
313 std::string delimiter = "_";
314 std::string key = filename.substr(0, filename.find(delimiter));
315 key = fixKeyErrors(key);
316 return key;
317}
318
319std::vector<std::pair<std::string, float>> AudioDataset::sample(std::string path, std::size_t k, bool descending) const
320{
321 assert(k > 0);
322
323 std::vector<std::pair<std::string, float>> query_results;
324
325 auto filename = juce::File(path).getFileName().toStdString();
326 const auto key = extractCatID(filename);
327 const auto success = std::binary_search(CatIDSet.begin(), CatIDSet.end(), key);
328 if (success)
329 {
330 if (m_dataset.count(key) > 0)
331 {
332 //-------------------------
333 //-------------------------
334
335 const auto results = m_dataset.at(key);
336 if (results.size() == 1)
337 {
338 // the query path is the only result
339 query_results.push_back(std::make_pair(results.at(0).first, 1.f));
340 return query_results;
341 }
342
343 //-------------------------
344 //-------------------------
345
346 const auto it = std::find_if(results.begin(), results.end(),
347 [path](const std::pair<std::string, std::vector<float>>& element) {
348 return element.first == path;
349 });
350 if (it == results.end())
351 {
352 // valid UCS CatID but path not found in dataset
353 // file dragged in from outside factory assets?
354 query_results.push_back(std::make_pair(path, 1.f));
355 return query_results;
356 }
357 const auto query_index = std::distance(results.begin(), it);
358 const auto query_features = results.at(query_index).second;
359
360 //-------------------------
361 //-------------------------
362
363 const int scale = descending ? 1 : -1;
364 std::vector<float> scores(results.size(), 0.f);
365 for (std::size_t i = 0; i < results.size(); ++i)
366 {
367 const auto features = results.at(i).second;
368 const auto score = std::inner_product(features.begin(), features.end(), query_features.begin(), 0.f);
369 scores.at(i) = scale * score;
370 }
371
372 //-------------------------
373 //-------------------------
374
375 // find top-k matches based on cosine similarity score
376 std::priority_queue<std::pair<float, std::size_t>, std::vector<std::pair<float, std::size_t>>,
377 std::greater<std::pair<float, std::size_t>>>
378 q;
379
380 for (std::size_t i = 0; i < scores.size(); ++i)
381 {
382 if (q.size() < k) // if ((q.size() < k) && (scores[i] >= threshold))
383 {
384 q.push(std::make_pair(scores.at(i), i));
385 }
386 else if (q.top().first < scores.at(i))
387 {
388 q.pop();
389 q.push(std::make_pair(scores.at(i), i));
390 }
391 }
392 assert(q.size() >= 1);
393
394 //-------------------------
395 //-------------------------
396
397 while (!q.empty())
398 {
399 // smallest element at top
400 const auto result = q.top();
401 const auto score = scale * result.first;
402 query_results.push_back(std::make_pair(results.at(result.second).first, score));
403 q.pop();
404 }
405
406 return query_results;
407 }
408 }
409
410 // invalid CatID or not in dataset
411 query_results.push_back(std::make_pair(path, 1.f));
412 return query_results;
413}
414
416{
417#ifdef JUCE_MAC
418 return File::getSpecialLocation(File::commonApplicationDataDirectory).getFullPathName() +
419 "/Application Support/Krotos/" + JucePlugin_Name + "/Factory Assets/";
420#else
421 return File("C:\\ProgramData\\Krotos\\" JucePlugin_Name "\\Factory Assets\\");
422#endif
423}
424
426{
427#ifdef JUCE_MAC
428 return File::getSpecialLocation(File::commonApplicationDataDirectory).getFullPathName() +
429 "/Application Support/Krotos/" + JucePlugin_Name + "/ttpResources/embeddings.json";
430#else
431 return File("C:\\ProgramData\\Krotos\\" JucePlugin_Name "\\ttpResources\\embeddings.json");
432#endif
433}
434
435} // namespace krotos
std::string extractCatID(std::string filename) const
Definition AudioDataset.cpp:311
bool writeCache()
Definition AudioDataset.cpp:228
File getFactoryAssetDirectory() const
Definition AudioDataset.cpp:415
std::map< std::string, std::vector< std::pair< std::string, std::vector< float > > > > m_dataset
Definition AudioDataset.h:96
std::vector< std::pair< std::string, float > > sample(std::string query, std::size_t k=3, bool descending=true) const
Definition AudioDataset.cpp:319
void run() override
Definition AudioDataset.cpp:56
std::string fixKeyErrors(std::string key) const
Definition AudioDataset.cpp:214
AudioBuffer< float > resampleAudioBuffer(const AudioBuffer< float > &buffer, double sampleRate, double targetRate)
Definition AudioDataset.cpp:40
std::array< std::string, 753 > CatIDSet
Definition AudioDataset.h:102
~AudioDataset()
Definition AudioDataset.cpp:14
bool cacheExists()
Definition AudioDataset.cpp:307
std::set< std::string > getPathsSet() const
Definition AudioDataset.cpp:200
AudioDataset()
Definition AudioDataset.cpp:6
File getDatasetFile() const
Definition AudioDataset.cpp:425
bool readCache()
Definition AudioDataset.cpp:274
std::vector< File > findValidFiles() const
Definition AudioDataset.cpp:16
std::string m_version
Definition AudioDataset.h:95
bool assetsChanged()
Definition AudioDataset.cpp:134
bool deleteCache()
Definition AudioDataset.cpp:309
bool runAnalysis()
Definition AudioDataset.cpp:146
Definition AudioEmbedding.h:6
std::vector< float > forward(const AudioSampleBuffer &buffer)
Definition AudioEmbedding.cpp:26
Definition AirAbsorptionFilter.cpp:2