From cd46d7bd203b69e6d163fd19e38600d9feae6e56 Mon Sep 17 00:00:00 2001 From: jacqueline Date: Tue, 21 Nov 2023 16:20:01 +1100 Subject: [PATCH] Make lua db iterators async --- src/database/database.cpp | 58 +++++++++++++++++++------------ src/database/include/database.hpp | 14 ++++++-- src/lua/lua_database.cpp | 27 +++++++++----- 3 files changed, 66 insertions(+), 33 deletions(-) diff --git a/src/database/database.cpp b/src/database/database.cpp index 0967eb95..88ae7bbe 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -858,7 +858,7 @@ auto IndexRecord::Expand(std::size_t page_size) const } Iterator::Iterator(std::weak_ptr db, const IndexInfo& idx) - : db_(db), prev_pos_(), current_pos_() { + : db_(db), pos_mutex_(), current_pos_(), prev_pos_() { std::string prefix = EncodeIndexPrefix( IndexKey::Header{.id = idx.id, .depth = 0, .components_hash = 0}); current_pos_ = Continuation{.prefix = {prefix.data(), prefix.size()}, @@ -869,36 +869,50 @@ Iterator::Iterator(std::weak_ptr db, const IndexInfo& idx) } Iterator::Iterator(std::weak_ptr db, const Continuation& c) - : db_(db), prev_pos_(), current_pos_(c) {} + : db_(db), pos_mutex_(), current_pos_(c), prev_pos_() {} -auto Iterator::Prev() -> std::optional { - if (!prev_pos_) { - return {}; - } +auto Iterator::Next(Callback cb) -> void { auto db = db_.lock(); if (!db) { - return {}; + InvokeNull(cb); + return; } - std::unique_ptr> res{ - db->GetPage(&*prev_pos_).get()}; - prev_pos_ = res->prev_page(); - current_pos_ = prev_pos_; - return *res->values()[0]; + db->worker_task_->Dispatch([=]() { + std::lock_guard lock{pos_mutex_}; + if (!current_pos_) { + InvokeNull(cb); + return; + } + std::unique_ptr> res{ + db->dbGetPage(*current_pos_)}; + prev_pos_ = current_pos_; + current_pos_ = res->next_page(); + std::invoke(cb, *res->values()[0]); + }); } -auto Iterator::Next() -> std::optional { - if (!current_pos_) { - return {}; - } +auto Iterator::Prev(Callback cb) -> void { auto db = db_.lock(); if (!db) { - return {}; + InvokeNull(cb); + return; } - std::unique_ptr> res{ - db->GetPage(&*current_pos_).get()}; - prev_pos_ = current_pos_; - current_pos_ = res->next_page(); - return *res->values()[0]; + db->worker_task_->Dispatch([=]() { + std::lock_guard lock{pos_mutex_}; + if (!prev_pos_) { + InvokeNull(cb); + return; + } + std::unique_ptr> res{ + db->dbGetPage(*current_pos_)}; + current_pos_ = prev_pos_; + prev_pos_ = res->prev_page(); + std::invoke(cb, *res->values()[0]); + }); +} + +auto Iterator::InvokeNull(Callback cb) -> void { + std::invoke(cb, std::optional{}); } } // namespace database diff --git a/src/database/include/database.hpp b/src/database/include/database.hpp index 972871db..63014bed 100644 --- a/src/database/include/database.hpp +++ b/src/database/include/database.hpp @@ -129,6 +129,8 @@ class Database { Database& operator=(const Database&) = delete; private: + friend class Iterator; + // Owned. Dumb pointers because destruction needs to be done in an explicit // order. leveldb::DB* db_; @@ -191,13 +193,19 @@ class Iterator { Iterator(std::weak_ptr, const IndexInfo&); Iterator(std::weak_ptr, const Continuation&); - auto Prev() -> std::optional; - auto Next() -> std::optional; + using Callback = std::function)>; + + auto Next(Callback) -> void; + auto Prev(Callback) -> void; private: + auto InvokeNull(Callback) -> void; + std::weak_ptr db_; - std::optional prev_pos_; + + std::mutex pos_mutex_; std::optional current_pos_; + std::optional prev_pos_; }; } // namespace database diff --git a/src/lua/lua_database.cpp b/src/lua/lua_database.cpp index 545dcd31..d8ae86f6 100644 --- a/src/lua/lua_database.cpp +++ b/src/lua/lua_database.cpp @@ -56,17 +56,28 @@ static const struct luaL_Reg kDatabaseFuncs[] = {{"indexes", indexes}, {NULL, NULL}}; static auto db_iterate(lua_State* state) -> int { + luaL_checktype(state, 1, LUA_TFUNCTION); + int callback_ref = luaL_ref(state, LUA_REGISTRYINDEX); + database::Iterator* it = *reinterpret_cast( lua_touserdata(state, lua_upvalueindex(1))); - auto res = it->Next(); - if (res) { - database::IndexRecord** record = reinterpret_cast( - lua_newuserdata(state, sizeof(uintptr_t))); - *record = new database::IndexRecord(*res); - luaL_setmetatable(state, kDbRecordMetatable); - return 1; - } + it->Next([=](std::optional res) { + events::Ui().RunOnTask([=]() { + lua_rawgeti(state, LUA_REGISTRYINDEX, callback_ref); + if (res) { + database::IndexRecord** record = + reinterpret_cast( + lua_newuserdata(state, sizeof(uintptr_t))); + *record = new database::IndexRecord(*res); + luaL_setmetatable(state, kDbRecordMetatable); + } else { + lua_pushnil(state); + } + lua_call(state, 1, 0); + luaL_unref(state, LUA_REGISTRYINDEX, callback_ref); + }); + }); return 0; }