rand[om]

med ∩ ml

Extending SQLite with Rust

SQLite has a powerful extension mechanism: loadable extensions.

Being an in-process database, SQLite also has other extension mechanisms, such as application-defined functions (UDFs for short). But UDFs have some shortcomings:

  • They’re local to a single SQLite connection; they are not shared with every process connected to the DB.
  • They have to be defined in your program. That means you need to have the function available in the same scope as your application.

This is where loadable extensions come in. Loadable extensions can be written in any programming language that can be compiled to a shared library/DLL. You can then share the compiled object and load it from any application or programming language. In this post, we’ll see how we can use Rust to write an SQLite loadable extension.
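
To make the first shortcoming concrete, here is a minimal sketch using Python’s built-in sqlite3 module (Python rather than Rust, because the standard library is enough to run it). The function name times_two and the file name demo.db are made up for this illustration. A UDF registered on one connection is not visible from a second connection to the same database file:

```python
import os
import sqlite3
import tempfile


def udf_is_connection_local():
    """Register a UDF on one connection and try to call it from another."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.db")  # hypothetical file name
        conn1 = sqlite3.connect(path)
        conn2 = sqlite3.connect(path)
        try:
            # `times_two` is a hypothetical UDF, registered on conn1 only.
            conn1.create_function("times_two", 1, lambda x: x * 2)
            result = conn1.execute("SELECT times_two(21)").fetchone()[0]

            try:
                conn2.execute("SELECT times_two(21)")
                error = ""
            except sqlite3.OperationalError as exc:
                # conn2 never registered the function, so SQLite rejects the call.
                error = str(exc)
        finally:
            conn1.close()
            conn2.close()
    return result, error
```

Calling udf_is_connection_local() returns the value computed on the first connection together with the error message raised on the second one.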

Intro

This post is a simplified version of some techniques I learned from phiresky/sqlite-zstd, an SQLite extension that enables zstd compression in SQLite. I highly recommend checking it out if you want to look at more advanced examples than the ones in this post.

Cargo.toml

We’ll need a few external dependencies. The most important details to note here are:

  • We use this rusqlite fork from Genomicsplc. For more context on why this is needed, see #910.
    • UPDATE: I’ve created a new fork which merges the changes in that PR with the current latest rusqlite release (v0.27.0) here.
  • The crate-type is set to ["cdylib"]. This tells the Rust compiler that we are building a shared library.
[package]
name = "sqlite-regex"
version = "0.1.0"
edition = "2021"

[features]

default = []
build_extension = ["rusqlite/bundled", "rusqlite/functions"]

[lib]
crate-type = ["cdylib"]

[dependencies]
regex = "1.5.4"
log = "0.4.14"
env_logger = "0.9.0"
anyhow = "1.0.54"


[dependencies.rusqlite]
package = "rusqlite"
git = "https://github.com/litements/rusqlite/"
branch = "loadable-extensions-release-2"
default-features = false
features = ["loadable_extension", "vtab", "functions", "bundled"]

We need some basic imports and functions. The function ah will convert an anyhow error to a rusqlite error. init_logging will set up env_logger.

#![allow(clippy::missing_safety_doc)]

use crate::ffi::loadable_extension_init;
use anyhow::Context as ACtxt;
use log::LevelFilter;
use regex::bytes::Regex;
use rusqlite::ffi;
use rusqlite::functions::{Context, FunctionFlags};
use rusqlite::types::{ToSqlOutput, Value, ValueRef};
use rusqlite::Connection;
use std::os::raw::c_int;

fn ah(e: anyhow::Error) -> rusqlite::Error {
    rusqlite::Error::UserFunctionError(format!("{:?}", e).into())
}

fn init_logging(default_level: LevelFilter) {
    let lib_log_env = "SQLITE_REGEX_LOG";
    if std::env::var(lib_log_env).is_err() {
        std::env::set_var(lib_log_env, format!("{}", default_level))
    }

    let logger_env = env_logger::Env::new().filter(lib_log_env);

    env_logger::try_init_from_env(logger_env).ok();
}

Extension entry point

When you try to load an extension in SQLite, it needs an entry point function. According to the docs of the sqlite3_load_extension C function, if an entry point is not provided, SQLite first tries sqlite3_extension_init and then guesses a name based on the filename, keeping only its ASCII letters (lowercased) and dropping a leading lib. If we called our compiled extension regex_ext.{so,dll,dylib}, it would look for an entry point called sqlite3_regexext_init (note that the underscore is dropped). If you need more flexibility, there’s also an SQL function to load extensions, and it lets you specify the entry point. With that, you can do something like:

SELECT load_extension('path/to/loadable/extension/regex_ext.[extension]', 'sqlite3_regex_init');

Now it will try to find a function called sqlite3_regex_init as an entry point, instead of the guessed name.
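
The naming rule documented for sqlite3_load_extension can be sketched in a few lines of Python. This is only an illustration of the documented rule, not SQLite’s actual code; the helper name guess_entry_point is made up, and the sketch ignores the first attempt with sqlite3_extension_init:

```python
def guess_entry_point(path: str) -> str:
    """Approximate the fallback entry point name SQLite derives from a file name."""
    # Keep the text after the last "/" and before the first following ".".
    stem = path.rsplit("/", 1)[-1].split(".", 1)[0]
    # Drop a leading "lib", as the docs describe.
    if stem.startswith("lib"):
        stem = stem[3:]
    # Keep only ASCII letters, lowercased; digits and underscores are dropped.
    letters = "".join(ch.lower() for ch in stem if ch.isascii() and ch.isalpha())
    return f"sqlite3_{letters}_init"
```

Under this rule, a library built as libsqlite_regex.dylib would be probed for sqlite3_sqliteregex_init, which is why passing the entry point explicitly is the more convenient option for this crate.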

Let’s write our entry point function:

#[no_mangle]
pub unsafe extern "C" fn sqlite3_regex_init(
    db: *mut ffi::sqlite3,
    _pz_err_msg: &mut &mut std::os::raw::c_char,
    p_api: *mut ffi::sqlite3_api_routines,
) -> c_int {

    loadable_extension_init(p_api);

    match init(db) {
        Ok(()) => {
            log::info!("[regex-extension] init ok");
            ffi::SQLITE_OK
        }

        Err(e) => {
            log::error!("[regex-extension] init error: {:?}", e);
            ffi::SQLITE_ERROR
        }
    }
}

Some details of this function:

  • It needs #[no_mangle]: The Rust compiler mangles symbol names differently than native code linkers expect. As such, any function that Rust exports to be used outside of Rust needs to be told not to be mangled by the compiler [source].
  • The function is unsafe, since it works with raw pointers received from SQLite.
  • The function returns ffi::SQLITE_OK. In a few paragraphs, we will see how we can change this return code to make the extension persistent across connections.
  • The extension is loaded in the init() function, which we’ll go through right now.

The init function

The main thing to consider here is that the function will receive a *mut ffi::sqlite3, which is a raw sqlite3 handle. We can use it from rusqlite by wrapping it with Connection::from_handle. Notice that this function is unsafe too.

fn init(db_handle: *mut ffi::sqlite3) -> anyhow::Result<()> {
    let db = unsafe { rusqlite::Connection::from_handle(db_handle)? };
    load(&db)?;
    Ok(())
}

db is now a “normal” rusqlite::Connection object.

The load function will initialize the logger we defined before and load the rust functions to SQLite:

fn load(c: &Connection) -> anyhow::Result<()> {
    load_with_loglevel(c, LevelFilter::Info)
}

fn load_with_loglevel(c: &Connection, default_log_level: LevelFilter) -> anyhow::Result<()> {
    init_logging(default_log_level);
    add_functions(c)
}

Loading rust functions to SQLite

Thanks to rusqlite and the flexibility of SQLite, this is straightforward:

fn add_functions(c: &Connection) -> anyhow::Result<()> {
    let deterministic = FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_UTF8;

    c.create_scalar_function("regex_extract", 2, deterministic, |ctx: &Context| {
        regex_extract(ctx).map_err(ah)
    })?;

    c.create_scalar_function("regex_extract", 3, deterministic, |ctx: &Context| {
        regex_extract(ctx).map_err(ah)
    })?;

    Ok(())
}

We are using the rusqlite method create_scalar_function. If you read the SQLite docs, you’ll see that sqlite3_create_function() receives 8 parameters. The first one, db, is implicit in our Rust code, since create_scalar_function is a method on the Connection object and db is already available through &self. rusqlite also takes care of the user-data pointer and of the xStep and xFinal callbacks, which are only used by aggregate functions. That leaves 4 parameters in the Rust method.

The first parameter is the name we want to use to register the function in SQLite. If we pass the value "regex_extract", we will be able to use this function as regex_extract() in our SQL queries.

The second parameter is the number of arguments that the function accepts.

Registering multiple versions of the same function

From the SQLite docs: It is common for an application to invoke sqlite3_create_function() multiple times for the same SQL function. For example, if an SQL function can take either 2 or 3 arguments, then sqlite3_create_function() would be invoked once for the 2-argument version and a second time for the 3-argument version. The underlying implementation (the callbacks) can be different for both variants.

In rust, this translates to multiple calls to Connection::create_scalar_function. Then we need to make sure our rust function knows how to handle the varying number of arguments.
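
The same idea can be tried quickly from Python’s sqlite3 module, whose create_function also takes the argument count as its second parameter. This is a sketch only; count_args is a hypothetical function used just to show that each arity is registered separately:

```python
import sqlite3


def count_args(*args):
    """Hypothetical UDF body: report how many arguments SQLite passed in."""
    return len(args)


def register_overloads(conn):
    # One registration per supported arity, mirroring the two
    # create_scalar_function calls in add_functions.
    conn.create_function("count_args", 2, count_args)
    conn.create_function("count_args", 3, count_args)


conn = sqlite3.connect(":memory:")
register_overloads(conn)
two = conn.execute("SELECT count_args('a', 'b')").fetchone()[0]
three = conn.execute("SELECT count_args('a', 'b', 'c')").fetchone()[0]

# An arity that was never registered is rejected by SQLite itself.
try:
    conn.execute("SELECT count_args('a')")
    one_arg_error = None
except sqlite3.OperationalError as exc:
    one_arg_error = str(exc)
```

The one-argument call fails before the Python callback is ever invoked, so the callback only has to deal with the arities that were registered.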

The third parameter includes our function flags. In our case, we defined them as a separate variable:

let deterministic = FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_UTF8;

You can add or remove flags here. The FunctionFlags::SQLITE_INNOCUOUS flag may be useful too.
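
One place where the deterministic flag makes a visible difference is expression indexes: SQLite refuses to index the result of a function that is not marked deterministic. The following sketch shows this with Python’s sqlite3 module, assuming Python 3.8+ and an SQLite build recent enough to support both the flag and expression indexes. The function shout and the table t are made up for the example:

```python
import sqlite3


def can_index_with(deterministic: bool) -> bool:
    """Return True if an index on shout(name) can be created."""
    conn = sqlite3.connect(":memory:")
    try:
        # `shout` is a hypothetical UDF; only the flag differs between runs.
        conn.create_function(
            "shout", 1, lambda s: s.upper(), deterministic=deterministic
        )
        conn.execute("CREATE TABLE t (name TEXT)")
        try:
            conn.execute("CREATE INDEX idx_shout ON t (shout(name))")
            return True
        except sqlite3.OperationalError:
            # SQLite rejects non-deterministic functions in index expressions.
            return False
    finally:
        conn.close()
```

Only mark a function as deterministic when it really returns the same output for the same input, as regex_extract does.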

The fourth parameter is the function we will register. The function will be passed the Context object. Here we are using a closure which then calls our main rust function. We also call .map_err(ah) to convert anyhow errors to rusqlite errors. Wrapping the function in a new closure may be useful if we want to use the function in other parts of our rust program, keeping the anyhow errors, and mapping them to rusqlite errors only in this context. We could also use the closure to split the rusqlite::functions::Context object (more on this in the next section).

Our rust function

Finally, this is the rust function we will be using. The function receives a raw rusqlite::functions::Context as input and splits it inside the function’s body.

fn regex_extract<'a>(ctx: &Context) -> anyhow::Result<ToSqlOutput<'a>> {
    let arg_pat = 0;
    let arg_input_data = 1;
    let arg_cap_group = 2;

    let empty_return = Ok(ToSqlOutput::Owned(Value::Null));

    let pattern = match ctx.get_raw(arg_pat) {
        ValueRef::Text(t) => t,
        e => anyhow::bail!("regex pattern must be text, got {}", e.data_type()),
    };

    let re = Regex::new(std::str::from_utf8(pattern)?)?;

    let input_value = match ctx.get_raw(arg_input_data) {
        ValueRef::Text(t) => t,
        ValueRef::Null => return empty_return,
        e => anyhow::bail!("regex expects text as input, got {}", e.data_type()),
    };

    let cap_group: usize = if ctx.len() <= arg_cap_group {
        // no capture group, use default
        0
    } else {
        ctx.get(arg_cap_group).context("capture group")?
    };

    // let mut caploc = re.capture_locations();
    // re.captures_read(&mut caploc, input_value);
    if let Some(cap) = re.captures(input_value) {
        match cap.get(cap_group) {
            None => empty_return,
            // String::from_utf8_lossy
            Some(t) => {
                let value = String::from_utf8_lossy(t.as_bytes());
                return Ok(ToSqlOutput::Owned(Value::Text(value.to_string())));
            }
        }
    } else {
        empty_return
    }
}

Let’s break down the function. First, we defined the meaning of each rusqlite::functions::Context index.

let arg_pat = 0;
let arg_input_data = 1;
let arg_cap_group = 2;

Define an empty return (no pattern match found):

let empty_return = Ok(ToSqlOutput::Owned(Value::Null));

Check that the pattern passed to the function is of type TEXT and create a regex from it:

let pattern = match ctx.get_raw(arg_pat) {
    ValueRef::Text(t) => t,
    e => anyhow::bail!("regex pattern must be text, got {}", e.data_type()),
};

let re = Regex::new(std::str::from_utf8(pattern)?)?;

Handle empty inputs and return Null early:

let input_value = match ctx.get_raw(arg_input_data) {
    ValueRef::Text(t) => t,
    ValueRef::Null => return empty_return,
    e => anyhow::bail!("regex expects text as input, got {}", e.data_type()),
};

As we saw above, the same SQLite function can be registered with a different number of parameters, so we need to handle both cases inside our function:

let cap_group: usize = if ctx.len() <= arg_cap_group {
    // no capture group, use default
    0
} else {
    ctx.get(arg_cap_group).context("capture group")?
};

Finally, we run the regex and return the text matched by the requested capture group:

if let Some(cap) = re.captures(input_value) {
    match cap.get(cap_group) {
        None => empty_return,
        // String::from_utf8_lossy
        Some(t) => {
            let value = String::from_utf8_lossy(t.as_bytes());
            return Ok(ToSqlOutput::Owned(Value::Text(value.to_string())));
        }
    }
} else {
    empty_return
}
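
As a cross-check of the control flow, here is a rough reference model of the same steps in Python. It is only a sketch: it uses Python’s re module, whose syntax and matching semantics differ from the Rust regex crate in several details, and the name regex_extract_reference is made up. It mirrors the NULL handling, the default capture group and the “no match” case:

```python
import re


def regex_extract_reference(pattern, text, group=0):
    """Approximate model of regex_extract; returns None where SQL gets NULL."""
    if not isinstance(pattern, str):
        raise TypeError("regex pattern must be text")
    if text is None:
        return None  # NULL input: return NULL early
    if not isinstance(text, str):
        raise TypeError("regex expects text as input")
    if group < 0:
        raise ValueError("capture group must not be negative")

    match = re.search(pattern, text)
    # No match, or a capture group index the pattern does not have: NULL.
    if match is None or group > match.re.groups:
        return None
    # Group 0 is the whole match; a group that did not participate is None.
    return match.group(group)
```

For the simple patterns used in the tests later in this post, this model and the extension should agree, which makes it a handy way to generate expected values.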

Testing our new extension

First, we need to build the project.

cargo build --release

Now we can load the extension in the SQLite REPL with the following statement (the file extension may differ depending on your OS; I’m using macOS):

SELECT load_extension('target/release/libsqlite_regex.dylib', 'sqlite3_regex_init');

Our logger will output a message like:

[2022-05-14T23:22:09Z INFO  sqlite_regex] [regex-extension] init ok
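
Since the library file name depends on the OS, a test script that should run on several platforms can compute it instead of hard-coding it. This is a small sketch; the helper name cdylib_filename is made up, and it assumes the usual naming for Cargo cdylib targets (hyphens in the crate name become underscores, and the lib prefix is not used on Windows):

```python
import sys


def cdylib_filename(crate_name: str, platform: str = sys.platform) -> str:
    """Guess the shared library file name Cargo produces for a cdylib crate."""
    lib = crate_name.replace("-", "_")
    if platform == "win32":
        return f"{lib}.dll"
    if platform == "darwin":
        return f"lib{lib}.dylib"
    # Linux and most other Unix-like systems.
    return f"lib{lib}.so"
```

With the crate from this post, "target/release/" + cdylib_filename("sqlite-regex") gives the path to pass to load_extension on the current platform.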

We will write some tests and keep them in a Python script.

#!/usr/bin/env python3

import sqlite3

conn = sqlite3.connect("test.db", isolation_level=None)

print(f"Loading SQLite extension in connection: {conn}")
conn.enable_load_extension(True)
conn.execute(
    "SELECT load_extension('target/release/libsqlite_regex.dylib', 'sqlite3_regex_init');"
)

conn.enable_load_extension(False)

print("Running tests...")

print("Testing pattern 'x(ab)' WITHOUT capture group")
row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa')").fetchone()
assert row[0] == "xab", row[0]

print("Testing pattern 'x(ab)' WITH capture group = 1")
row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa', 1)").fetchone()
assert row[0] == "ab", row[0]

print("Testing pattern 'x(ab)' WITH capture group = 0")
row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa', 0)").fetchone()
assert row[0] == "xab", row[0]

print("Testing pattern 'g(oog)+le' WITHOUT capture group")
row = conn.execute("SELECT regex_extract('g(oog)+le', 'googoogoogle')").fetchone()
assert row[0] == "googoogoogle", row[0]

print("Testing pattern 'g(oog)+le' WITH capture group = 1")
row = conn.execute("SELECT regex_extract('g(oog)+le', 'googoogoogle', 1)").fetchone()
assert row[0] == "oog", row[0]

print("Testing pattern '[Cc]at' WITHOUT capture group")
row = conn.execute("SELECT regex_extract('[Cc]at', 'cat')").fetchone()
assert row[0] == "cat", row[0]

print("Testing pattern '[Cc]at' WITHOUT capture group, expecting empty return")
row = conn.execute("SELECT regex_extract('[Cc]at', 'hello')").fetchone()
assert row[0] is None, row[0]

It seems to work! We just need to fix one more problem: if we open a new DB connection, the functions will not be available. We can test this by adding the following at the end of our testing script, which fails with an error:

conn2 = sqlite3.connect("test.db", isolation_level=None)
print(f"Testing connection 2: {conn2}")
row = conn2.execute("SELECT regex_extract('x(ab)', 'xxabaa')").fetchone()
assert row[0] == "xab", row[0]
# sqlite3.OperationalError: no such function: regex_extract

Loading the extensions persistently

The main “problem” with loadable extensions is that they’re only available on the connection that loaded them. Luckily for us, we can make extensions persistent. In our sqlite3_regex_init function, we need to return ffi::SQLITE_OK_LOAD_PERMANENTLY instead of ffi::SQLITE_OK. We also have to call ffi::sqlite3_auto_extension with a function pointer that will run the extension initialization. By doing this, the initialization function will be executed on each new connection in this process, even if the original connection that loaded the extension has since been closed. Our initialization function is the same as before, but now it has the name sqlite3_regex_init_internal:

#[no_mangle]
pub unsafe extern "C" fn sqlite3_regex_init_internal(
    db: *mut ffi::sqlite3,
    _pz_err_msg: &mut &mut std::os::raw::c_char,
    p_api: *mut ffi::sqlite3_api_routines,
) -> c_int {
    loadable_extension_init(p_api);

    match init(db) {
        Ok(()) => {
            log::info!("[regex-extension] init ok");
            ffi::SQLITE_OK_LOAD_PERMANENTLY // <== Changed here!
        }

        Err(e) => {
            log::error!("[regex-extension] init error: {:?}", e);
            ffi::SQLITE_ERROR
        }
    }
}

And we define a new entry point function that will call the function above:

#[no_mangle]
pub unsafe extern "C" fn sqlite3_regex_init(
    db: *mut ffi::sqlite3,
    _pz_err_msg: &mut &mut std::os::raw::c_char,
    p_api: *mut ffi::sqlite3_api_routines,
) -> c_int {
    loadable_extension_init(p_api);

    // Create pointer to initialization function
    let ptr = sqlite3_regex_init_internal
        as unsafe extern "C" fn(
            *mut ffi::sqlite3,
            &mut &mut std::os::raw::c_char,
            *mut ffi::sqlite3_api_routines,
        ) -> c_int;

    // Pass the pointer to sqlite3_auto_extension()
    ffi::sqlite3_auto_extension(Some(std::mem::transmute(ptr)));

    match init(db) {
        Ok(()) => {
            log::info!("[regex-extension] init ok");
            ffi::SQLITE_OK_LOAD_PERMANENTLY
        }

        Err(e) => {
            log::error!("[regex-extension] init error: {:?}", e);
            ffi::SQLITE_ERROR
        }
    }
}

The important part is sqlite3_auto_extension(Some(std::mem::transmute(ptr)));, which will make SQLite run the function pointer from ptr on each new connection. If we compile and re-run our test file, we will see this output:

Loading SQLite extension in connection: <sqlite3.Connection object at 0x1034b8300>
[2022-05-15T00:48:33Z INFO  sqlite_regex] [regex-extension] init ok
Running tests...
Testing pattern 'x(ab)' WITHOUT capture group
Testing pattern 'x(ab)' WITH capture group = 1
Testing pattern 'x(ab)' WITH capture group = 0
Testing pattern 'g(oog)+le' WITHOUT capture group
Testing pattern 'g(oog)+le' WITH capture group = 1
Testing pattern '[Cc]at' WITHOUT capture group
Testing pattern '[Cc]at' WITHOUT capture group, expecting empty return
[2022-05-15T00:48:33Z INFO  sqlite_regex] [regex-extension] init ok
Testing connection 2: <sqlite3.Connection object at 0x103563120>
All tests passed

We can see a second logging message [regex-extension] init ok without having called the initialization function a second time. This is because SQLite has automatically called it for us on the new connection.

Note: I’m certain there must be a way to avoid having to run init both in sqlite3_regex_init and sqlite3_regex_init_internal, but my knowledge of C functions and pointers is very limited, and I haven’t managed to make it work better than this.

Final note

Persisting extensions may be useful if we are writing a VFS. But for extensions like the one we created, and considering the nature of SQLite (an in-process DB), I think it’s safer and more convenient to just load the extension on each new connection and forget about sqlite3_auto_extension() (that is, ignore all the changes we made in the “Loading the extensions persistently” section).

You can find the code from this blog post here.