.NET query provider

or: how to create your own Entity Framework-like library

Many thanks to these tutorial notes: https://github.com/muhamad/iqueryable/blob/main/tut/02-Where_and_reusable_Expression_tree_visitor.md

The overall idea is to have a piece of code that recognizes the incoming expression tree and transforms it into the actual query that needs to be sent to the storage.

So whenever you have an IQueryable<T> and write something like users.Where(u => u.Name.StartsWith("A")), nothing actually happens underneath until you ask to materialize the results (for example by enumerating them with foreach).

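When you do, the expression u => u.Name.StartsWith("A") is converted to something like select * from users where name like 'A%' (the minimal translator shown below only handles comparisons such as ==, <, >, not method calls like StartsWith).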

For this to work we need to implement the IQueryProvider method Execute<TResult>(Expression expression),

which our IQueryable<T> implementation calls when it is enumerated.

Let's try to mimic Entity Framework.

Generally, an EF Core context looks something like this:

/// <summary>
/// Shape of a typical EF Core context: each table is exposed as a DbSet property.
/// </summary>
public class ApplicationContext : DbContext
{
    // ...
    public DbSet<Customers> Customers { get; set; }
}

Under the hood, that DbSet implements IQueryable<T>, so our own minimal version looks like this:

/// <summary>
/// Queryable entry point (our equivalent of EF's DbSet): it only holds an expression tree and
/// hands it to <see cref="Provider"/> when enumerated.
/// </summary>
/// <typeparam name="T">Row type of the query.</typeparam>
public class DbSet<T> : IQueryable<T>
{
    /// <summary>Creates a root query whose expression is a constant referring to this set.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    public DbSet(IQueryProvider provider)
    {
        Provider = provider ?? throw new ArgumentNullException(nameof(provider));
        Expression = Expression.Constant(this);
    }

    /// <summary>Creates a composed query (used by the provider's CreateQuery for Where, etc.).</summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> or <paramref name="expression"/> is null.</exception>
    public DbSet(IQueryProvider provider, Expression expression)
    {
        Provider = provider ?? throw new ArgumentNullException(nameof(provider));
        Expression = expression ?? throw new ArgumentNullException(nameof(expression));
    }

    /// <summary>Materializes the query by asking the provider to execute <see cref="Expression"/> as a sequence.</summary>
    public IEnumerator<T> GetEnumerator() => Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();

    // Untyped consumers (data binding, non-generic IEnumerable code) must see the same rows instead of an exception.
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    public Type ElementType { get; } = typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }
}

In most cases this class will be the same for any provider.

Then we need the provider itself; in most cases it will look something like this:

public class DbContext : IQueryProvider
{
    // ...

    /// <summary>
    /// Called by LINQ when a query is materialized (e.g. from <c>DbSet&lt;T&gt;.GetEnumerator</c>):
    /// translates the expression tree to SQL and prints it.
    /// </summary>
    /// <remarks>
    /// Running the SQL against the storage and materializing <typeparamref name="TResult"/> is the
    /// storage-specific part; the Dapper-based version further below shows one way to do it.
    /// </remarks>
    public TResult Execute<TResult>(Expression expression)
    {
        var sql = new QueryTranslator().Translate(expression); // key piece here, look below
        Console.WriteLine(sql);
        // A method returning TResult must return or throw; fail explicitly until a storage backend is plugged in.
        throw new NotImplementedException("Execute the translated SQL against the storage and materialize the result");
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) => new DbSet<TElement>(this, expression);

    public IQueryable CreateQuery(Expression expression) { throw new NotImplementedException(); }
    public object? Execute(Expression expression) { throw new NotImplementedException(); }
}

The final, most complicated and most important piece is the translator. Here is an example from the notes I mentioned at the very beginning:

/// <summary>
/// The main meat of provider is here, taken from:
/// https://github.com/muhamad/iqueryable/blob/main/tut/02-Where_and_reusable_Expression_tree_visitor.md
/// </summary>
/// <remarks>
/// Translates <c>Queryable.Where</c> calls over a table (an <see cref="IQueryable"/> constant) into SQL text.
/// Predicates may use column references (<c>x.Column</c>), literals, captured variables, comparison operators,
/// <c>&amp;&amp;</c>, <c>||</c>, <c>&amp;</c>, <c>|</c> and <c>!</c>. Values are inlined as SQL literals
/// (strings are quote-escaped), not sent as command parameters. Anything else throws
/// <see cref="NotSupportedException"/>.
/// </remarks>
public class QueryTranslator : ExpressionVisitor
{
    private StringBuilder _sb = null!;

    /// <summary>Translates <paramref name="expression"/> into a SQL query string.</summary>
    /// <param name="expression">Expression tree of a query, e.g. <c>DbSet&lt;T&gt;.Expression</c>.</param>
    /// <returns>The generated SQL.</returns>
    /// <exception cref="NotSupportedException">The tree contains a method, operator, member or constant that cannot be translated.</exception>
    public string Translate(Expression expression)
    {
        _sb = new StringBuilder();
        Visit(expression);
        return _sb.ToString();
    }

    // Where's predicate arrives wrapped in Quote node(s); unwrap them to reach the lambda.
    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    // True for a literal null, possibly behind compiler-inserted conversions such as `(int?)null`.
    private static bool IsNullConstant(Expression e)
    {
        while (e.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e is ConstantExpression { Value: null };
    }

    // True when the member chain is rooted at a constant (captured locals are fields of a closure object
    // held in a ConstantExpression) or at a static member, i.e. it can be evaluated on the client.
    private static bool IsClientEvaluable(MemberExpression node)
    {
        Expression? e = node;
        while (e is MemberExpression member)
        {
            e = member.Expression;
        }
        return e is null or ConstantExpression;
    }

    protected override Expression VisitMethodCall(MethodCallExpression m)
    {
        if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where")
        {
            var lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
            if (lambda.Parameters.Count != 1)
            {
                // Where((x, index) => ...) has no SQL counterpart in this translator.
                throw new NotSupportedException("The indexed 'Where' overload is not supported");
            }
            _sb.Append("SELECT * FROM (");
            Visit(m.Arguments[0]);
            _sb.Append(") AS T WHERE ");
            Visit(lambda.Body);
            return m;
        }
        throw new NotSupportedException($"The method '{m.Method.Name}' is not supported");
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
        switch (u.NodeType)
        {
            case ExpressionType.Not when u.Operand.Type == typeof(bool) || u.Operand.Type == typeof(bool?):
                _sb.Append("NOT (");
                Visit(u.Operand);
                _sb.Append(')');
                return u;

            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
                // Conversions the compiler inserts (e.g. int -> double when comparing with 1.5) have no SQL text of their own.
                Visit(u.Operand);
                return u;

            default:
                // The base visitor would emit only the operand (e.g. `-x` as `x`), silently producing wrong SQL.
                throw new NotSupportedException($"The unary operator '{u.NodeType}' is not supported");
        }
    }

    protected override Expression VisitBinary(BinaryExpression b)
    {
        // In SQL `x = NULL` is never true, so comparisons against a null literal become IS [NOT] NULL.
        if (b.NodeType is ExpressionType.Equal or ExpressionType.NotEqual)
        {
            Expression? operand = IsNullConstant(b.Right) ? b.Left : IsNullConstant(b.Left) ? b.Right : null;
            if (operand != null)
            {
                _sb.Append('(');
                Visit(operand);
                _sb.Append(b.NodeType == ExpressionType.Equal ? " IS NULL)" : " IS NOT NULL)");
                return b;
            }
        }

        // Resolve the operator first so an unsupported one fails before any partial SQL is written.
        var op = b.NodeType switch
        {
            // C# `&&`/`||` compile to AndAlso/OrElse; `&`/`|` on bools compile to And/Or.
            ExpressionType.AndAlso or ExpressionType.And => " AND ",
            ExpressionType.OrElse or ExpressionType.Or => " OR ",
            ExpressionType.Equal => " = ",
            ExpressionType.NotEqual => " <> ",
            ExpressionType.LessThan => " < ",
            ExpressionType.LessThanOrEqual => " <= ",
            ExpressionType.GreaterThan => " > ",
            ExpressionType.GreaterThanOrEqual => " >= ",
            _ => throw new NotSupportedException($"The binary operator '{b.NodeType}' is not supported"),
        };

        _sb.Append('(');
        Visit(b.Left);
        _sb.Append(op);
        Visit(b.Right);
        _sb.Append(')');
        return b;
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        switch (c.Value)
        {
            case IQueryable q:
                // assume constant nodes within IQueryable<T> are table references, e.g. T is a table name
                _sb.Append($"SELECT * FROM {q.ElementType.Name}");
                break;
            case null:
                _sb.Append("NULL");
                break;
            default:
                AppendLiteral(c.Value);
                break;
        }
        return c;
    }

    // Writes a non-null scalar as a SQL literal.
    private void AppendLiteral(object value)
    {
        switch (Type.GetTypeCode(value.GetType()))
        {
            case TypeCode.Boolean:
                _sb.Append((bool)value ? 1 : 0);
                break;

            case TypeCode.String:
            case TypeCode.Char:
                // Double embedded quotes so a value such as O'Neil cannot terminate the literal (SQL injection).
                _sb.Append('\'').Append(value.ToString()!.Replace("'", "''")).Append('\'');
                break;

            case TypeCode.DateTime:
                // NOTE(review): quoted ISO-like text, the format Microsoft.Data.Sqlite uses for DateTime; confirm for other stores.
                _sb.Append(FormattableString.Invariant($"'{(DateTime)value:yyyy-MM-dd HH:mm:ss.FFFFFFF}'"));
                break;

            case TypeCode.Object:
                throw new NotSupportedException($"The constant for '{value}' is not supported");

            default:
                // Invariant culture: 1.5 must not become "1,5" on e.g. a de-DE machine.
                _sb.Append(FormattableString.Invariant($"{value}"));
                break;
        }
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is { NodeType: ExpressionType.Parameter })
        {
            // x.Column -> Column
            _sb.Append(node.Member.Name);
            return node;
        }
        if (IsClientEvaluable(node))
        {
            // Captured variable or static member: evaluate it now and emit the value as a literal.
            var value = Expression.Lambda(node).Compile().DynamicInvoke();
            Visit(Expression.Constant(value, node.Type));
            return node;
        }
        throw new NotSupportedException($"The member '{node.Member.Name}' is not supported");
    }
}

Now it is time to wire everything up:

using System.Collections;
using System.Data;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using Dapper;
using Microsoft.Data.Sqlite;

/*
dotnet add package Microsoft.Data.Sqlite
dotnet add package Dapper
*/

using var connection = new SqliteConnection("Data Source=:memory:");

// Open explicitly: Dapper opens/closes a closed connection around every call, and an in-memory
// database is wiped out whenever its connection closes.
connection.Open();

// Seed a throw-away sample database.
connection.Execute("CREATE TABLE Customers ( CustomerId INTEGER PRIMARY KEY, ContactName NVARCHAR, City NVARCHAR )");
connection.Execute("INSERT INTO Customers (CustomerId, ContactName, City) VALUES (1, 'Michael', 'Kiev'), (2, 'Kira', 'Kiev')");

// The main demo: LINQ over our own provider, no Entity Framework involved.
var provider = new DbContext(connection);
var customers = new DbSet<Customers>(provider);
var customersInKiev = customers.Where(c => c.City == "Kiev"); // deferred: nothing runs yet
foreach (var customer in customersInKiev)
{
    Console.WriteLine(customer.ContactName);
}

/// <summary>
/// One row of the Customers table. The class name doubles as the table name in the generated SQL,
/// and the property names match the column names.
/// </summary>
public class Customers
{
    /// <summary>Primary key column.</summary>
    public int CustomerId { get; init; }

    /// <summary>Contact name; an empty string when not set.</summary>
    public string ContactName { get; init; } = string.Empty;

    /// <summary>City; an empty string when not set.</summary>
    public string City { get; init; } = string.Empty;
}

/// <summary>
/// Query provider: translates LINQ expression trees into SQL with <see cref="QueryTranslator"/>
/// and executes the SQL through Dapper on the supplied connection.
/// </summary>
public class DbContext : IQueryProvider
{
    private readonly IDbConnection _connection;

    // Resolved once at start: Dapper's generic `Query<T>` extension with 7 parameters (connection, sql, ...),
    // i.e. the overload called as `_connection.Query<T>(sql, null, null, true, null, CommandType.Text)` below.
    private static readonly MethodInfo QueryMethod = typeof(SqlMapper)
        .GetMethods(BindingFlags.Public | BindingFlags.Static)
        .FirstOrDefault(m => m.Name == "Query"
                             && m.IsGenericMethod
                             && m.GetGenericArguments().Length == 1
                             && m.GetParameters() is { Length: 7 } parameters
                             && parameters[1].ParameterType == typeof(string))
        // A missing overload means an incompatible Dapper version, not a null argument.
        ?? throw new MissingMethodException(nameof(SqlMapper), "Query<T>");

    /// <param name="connection">Connection the generated SQL is executed on.</param>
    /// <exception cref="ArgumentNullException"><paramref name="connection"/> is null.</exception>
    public DbContext(IDbConnection connection)
    {
        _connection = connection ?? throw new ArgumentNullException(nameof(connection));
    }

    /// <summary>
    /// Translates <paramref name="expression"/> to SQL, writes it to the console and runs it via Dapper.
    /// </summary>
    /// <typeparam name="TResult">
    /// Sequence type requested by the caller; <c>DbSet&lt;T&gt;.GetEnumerator</c> asks for <c>IEnumerable&lt;T&gt;</c>,
    /// and its single type argument is the row type Dapper materializes.
    /// </typeparam>
    /// <exception cref="NotSupportedException">
    /// The expression cannot be translated, or <typeparamref name="TResult"/> does not have exactly one type argument.
    /// </exception>
    public TResult Execute<TResult>(Expression expression)
    {
        var sql = new QueryTranslator().Translate(expression);
        Console.WriteLine(sql);

        var typeArguments = typeof(TResult).GenericTypeArguments;
        if (typeArguments.Length != 1)
        {
            // A non-generic TResult would otherwise fail with an opaque IndexOutOfRangeException.
            throw new NotSupportedException($"Result type '{typeof(TResult)}' is not supported; expected a sequence such as IEnumerable<T>");
        }

        // Dynamically call `_connection.Query<T>(sql, null, null, true, null, CommandType.Text)`.
        // DoNotWrapExceptions lets Dapper/database errors reach the caller as-is, not as TargetInvocationException.
        var arguments = new object?[] { _connection, sql, null, null, true, null, CommandType.Text };
        return (TResult)QueryMethod
            .MakeGenericMethod(typeArguments[0])
            .Invoke(null, BindingFlags.DoNotWrapExceptions, null, arguments, null)!;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) => new DbSet<TElement>(this, expression);

    public IQueryable CreateQuery(Expression expression) { throw new NotImplementedException(); }
    public object? Execute(Expression expression) { throw new NotImplementedException(); }
}

/// <summary>
/// Queryable entry point (our equivalent of EF's DbSet): it only holds an expression tree and
/// hands it to <see cref="Provider"/> when enumerated.
/// </summary>
/// <typeparam name="T">Row type of the query.</typeparam>
public class DbSet<T> : IQueryable<T>
{
    /// <summary>Creates a root query whose expression is a constant referring to this set.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    public DbSet(IQueryProvider provider)
    {
        Provider = provider ?? throw new ArgumentNullException(nameof(provider));
        Expression = Expression.Constant(this);
    }

    /// <summary>Creates a composed query (used by the provider's CreateQuery for Where, etc.).</summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> or <paramref name="expression"/> is null.</exception>
    public DbSet(IQueryProvider provider, Expression expression)
    {
        Provider = provider ?? throw new ArgumentNullException(nameof(provider));
        Expression = expression ?? throw new ArgumentNullException(nameof(expression));
    }

    /// <summary>Materializes the query by asking the provider to execute <see cref="Expression"/> as a sequence.</summary>
    public IEnumerator<T> GetEnumerator() => Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();

    // Untyped consumers (data binding, non-generic IEnumerable code) must see the same rows instead of an exception.
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    public Type ElementType { get; } = typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }
}

/// <summary>
/// The main meat of provider is here, taken from:
/// https://github.com/muhamad/iqueryable/blob/main/tut/02-Where_and_reusable_Expression_tree_visitor.md
/// </summary>
/// <remarks>
/// Translates <c>Queryable.Where</c> calls over a table (an <see cref="IQueryable"/> constant) into SQL text.
/// Predicates may use column references (<c>x.Column</c>), literals, captured variables, comparison operators,
/// <c>&amp;&amp;</c>, <c>||</c>, <c>&amp;</c>, <c>|</c> and <c>!</c>. Values are inlined as SQL literals
/// (strings are quote-escaped), not sent as command parameters. Anything else throws
/// <see cref="NotSupportedException"/>.
/// </remarks>
public class QueryTranslator : ExpressionVisitor
{
    private StringBuilder _sb = null!;

    /// <summary>Translates <paramref name="expression"/> into a SQL query string.</summary>
    /// <param name="expression">Expression tree of a query, e.g. <c>DbSet&lt;T&gt;.Expression</c>.</param>
    /// <returns>The generated SQL.</returns>
    /// <exception cref="NotSupportedException">The tree contains a method, operator, member or constant that cannot be translated.</exception>
    public string Translate(Expression expression)
    {
        _sb = new StringBuilder();
        Visit(expression);
        return _sb.ToString();
    }

    // Where's predicate arrives wrapped in Quote node(s); unwrap them to reach the lambda.
    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    // True for a literal null, possibly behind compiler-inserted conversions such as `(int?)null`.
    private static bool IsNullConstant(Expression e)
    {
        while (e.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e is ConstantExpression { Value: null };
    }

    // True when the member chain is rooted at a constant (captured locals are fields of a closure object
    // held in a ConstantExpression) or at a static member, i.e. it can be evaluated on the client.
    private static bool IsClientEvaluable(MemberExpression node)
    {
        Expression? e = node;
        while (e is MemberExpression member)
        {
            e = member.Expression;
        }
        return e is null or ConstantExpression;
    }

    protected override Expression VisitMethodCall(MethodCallExpression m)
    {
        if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where")
        {
            var lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
            if (lambda.Parameters.Count != 1)
            {
                // Where((x, index) => ...) has no SQL counterpart in this translator.
                throw new NotSupportedException("The indexed 'Where' overload is not supported");
            }
            _sb.Append("SELECT * FROM (");
            Visit(m.Arguments[0]);
            _sb.Append(") AS T WHERE ");
            Visit(lambda.Body);
            return m;
        }
        throw new NotSupportedException($"The method '{m.Method.Name}' is not supported");
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
        switch (u.NodeType)
        {
            case ExpressionType.Not when u.Operand.Type == typeof(bool) || u.Operand.Type == typeof(bool?):
                _sb.Append("NOT (");
                Visit(u.Operand);
                _sb.Append(')');
                return u;

            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
                // Conversions the compiler inserts (e.g. int -> double when comparing with 1.5) have no SQL text of their own.
                Visit(u.Operand);
                return u;

            default:
                // The base visitor would emit only the operand (e.g. `-x` as `x`), silently producing wrong SQL.
                throw new NotSupportedException($"The unary operator '{u.NodeType}' is not supported");
        }
    }

    protected override Expression VisitBinary(BinaryExpression b)
    {
        // In SQL `x = NULL` is never true, so comparisons against a null literal become IS [NOT] NULL.
        if (b.NodeType is ExpressionType.Equal or ExpressionType.NotEqual)
        {
            Expression? operand = IsNullConstant(b.Right) ? b.Left : IsNullConstant(b.Left) ? b.Right : null;
            if (operand != null)
            {
                _sb.Append('(');
                Visit(operand);
                _sb.Append(b.NodeType == ExpressionType.Equal ? " IS NULL)" : " IS NOT NULL)");
                return b;
            }
        }

        // Resolve the operator first so an unsupported one fails before any partial SQL is written.
        var op = b.NodeType switch
        {
            // C# `&&`/`||` compile to AndAlso/OrElse; `&`/`|` on bools compile to And/Or.
            ExpressionType.AndAlso or ExpressionType.And => " AND ",
            ExpressionType.OrElse or ExpressionType.Or => " OR ",
            ExpressionType.Equal => " = ",
            ExpressionType.NotEqual => " <> ",
            ExpressionType.LessThan => " < ",
            ExpressionType.LessThanOrEqual => " <= ",
            ExpressionType.GreaterThan => " > ",
            ExpressionType.GreaterThanOrEqual => " >= ",
            _ => throw new NotSupportedException($"The binary operator '{b.NodeType}' is not supported"),
        };

        _sb.Append('(');
        Visit(b.Left);
        _sb.Append(op);
        Visit(b.Right);
        _sb.Append(')');
        return b;
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        switch (c.Value)
        {
            case IQueryable q:
                // assume constant nodes within IQueryable<T> are table references, e.g. T is a table name
                _sb.Append($"SELECT * FROM {q.ElementType.Name}");
                break;
            case null:
                _sb.Append("NULL");
                break;
            default:
                AppendLiteral(c.Value);
                break;
        }
        return c;
    }

    // Writes a non-null scalar as a SQL literal.
    private void AppendLiteral(object value)
    {
        switch (Type.GetTypeCode(value.GetType()))
        {
            case TypeCode.Boolean:
                _sb.Append((bool)value ? 1 : 0);
                break;

            case TypeCode.String:
            case TypeCode.Char:
                // Double embedded quotes so a value such as O'Neil cannot terminate the literal (SQL injection).
                _sb.Append('\'').Append(value.ToString()!.Replace("'", "''")).Append('\'');
                break;

            case TypeCode.DateTime:
                // NOTE(review): quoted ISO-like text, the format Microsoft.Data.Sqlite uses for DateTime; confirm for other stores.
                _sb.Append(FormattableString.Invariant($"'{(DateTime)value:yyyy-MM-dd HH:mm:ss.FFFFFFF}'"));
                break;

            case TypeCode.Object:
                throw new NotSupportedException($"The constant for '{value}' is not supported");

            default:
                // Invariant culture: 1.5 must not become "1,5" on e.g. a de-DE machine.
                _sb.Append(FormattableString.Invariant($"{value}"));
                break;
        }
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is { NodeType: ExpressionType.Parameter })
        {
            // x.Column -> Column
            _sb.Append(node.Member.Name);
            return node;
        }
        if (IsClientEvaluable(node))
        {
            // Captured variable or static member: evaluate it now and emit the value as a literal.
            var value = Expression.Lambda(node).Compile().DynamicInvoke();
            Visit(Expression.Constant(value, node.Type));
            return node;
        }
        throw new NotSupportedException($"The member '{node.Member.Name}' is not supported");
    }
}

In this example we create a sample in-memory SQLite database and query it with our custom query provider.

The idea behind a query provider is awesome: imagine implementing such a provider for Elasticsearch, Firestore, MongoDB, etc. — you could swap the implementation without touching the calling code at all, which is something most ORMs cannot offer.

Unfortunately, the translation piece is really hard and error-prone, which is why we do not see many such providers.