DotNet HttpClient Cache

All over our code base we see something like this (pseudo-code):

// The cache-aside pattern repeated at every call site.
var cached = _cache.Get("foo");
if (cached != null) {
    return cached;
}
// Await the call: without it we would cache the Task instead of the response.
var response = await _http.GetAsync("/");
_cache.Set("foo", response);
return response;

Wouldn't it be nice if HttpClient itself respected the Cache-Control headers of responses and hid all of that?

To do so we can use a DelegatingHandler, which acts like middleware, but for outgoing requests.

We need to build something like:

/// <summary>
/// Sketch of a caching handler: answer from <c>_cache</c> when it already holds a response,
/// otherwise call the next handler and remember its response if it may be cached.
/// </summary>
public class MyHandler : DelegatingHandler
{
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        // 1. Short-circuit on a cache hit.
        var hit = await _cache.GetAsync(request, cancellationToken);
        if (hit is not null)
        {
            return hit;
        }

        // 2. Otherwise go to the network...
        var fresh = await base.SendAsync(request, cancellationToken);

        // 3. ...and keep the answer when the server allows it.
        if (fresh.IsCacheable())
        {
            await _cache.SetAsync(fresh, cancellationToken);
        }

        return fresh;
    }
}

Notes:

  • In this example, IsCacheable returns true only for GET (and HEAD) requests whose responses carry a Cache-Control header with a max-age. This is not always the right rule, but it keeps the example simple
  • The cache time-to-live is set to the Cache-Control max-age, i.e. we respect the header: if the server asks us to cache for 5 minutes or 5 days, that is what we do
  • A distributed cache is used in the demo only as an example. Here it would probably be even better to skip serialization and deserialization and keep everything in memory, but that depends on whether a cold start is acceptable
  • Neither HttpResponseMessage nor HttpRequestMessage can be serialized and deserialized directly, which is why we create wrappers around them

HttpClient respecting the Cache-Control response header

Here is the code for the handler:

using HttpCache.Extensions;

using Microsoft.Extensions.Caching.Distributed;

namespace HttpCache;

/// <summary>
/// Outgoing-request handler that answers from an <see cref="IDistributedCache"/> when a stored
/// response exists for the request, and stores fresh responses whose Cache-Control allows it.
/// </summary>
public class CacheControlDelegatingHandler : DelegatingHandler
{
    private readonly IDistributedCache _distributedCache;

    /// <summary>Creates the handler for use in a pipeline that assigns the inner handler later (e.g. IHttpClientFactory).</summary>
    public CacheControlDelegatingHandler(IDistributedCache distributedCache) =>
        _distributedCache = distributedCache;

    /// <summary>Creates the handler wrapping an explicit <paramref name="innerHandler"/>.</summary>
    public CacheControlDelegatingHandler(HttpMessageHandler innerHandler, IDistributedCache distributedCache)
        : base(innerHandler) =>
        _distributedCache = distributedCache;

    /// <summary>Returns the cached response when present; otherwise forwards the request and caches the result if it is cacheable.</summary>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        if (await _distributedCache.GetAsync(request, cancellationToken) is { } hit)
        {
            return hit;
        }

        var fresh = await base.SendAsync(request, cancellationToken);
        if (fresh.IsCacheable())
        {
            await _distributedCache.SetAsync(fresh, cancellationToken);
        }

        return fresh;
    }
}

and all the related supporting code, which is not really important:

All related, less important implementation details

Extensions

using HttpCache.Serializers;

using Microsoft.Extensions.Caching.Distributed;

namespace HttpCache.Extensions;

/// <summary>
/// Stores and retrieves serialized <see cref="HttpResponseMessage"/> instances in an
/// <see cref="IDistributedCache"/>, keyed by the originating request.
/// </summary>
internal static class DistributedCacheExtensions
{
    /// <summary>
    /// Looks up a stored response for <paramref name="request"/>.
    /// </summary>
    /// <param name="distributedCache">Cache holding serialized responses.</param>
    /// <param name="request">Outgoing request; non-cacheable requests are never looked up.</param>
    /// <param name="token">Cancellation token forwarded to the cache.</param>
    /// <returns>
    /// The deserialized cached response, or <c>null</c> when the request is not cacheable,
    /// has no cache key, or nothing is stored under its key.
    /// </returns>
    public static async Task<HttpResponseMessage?> GetAsync(this IDistributedCache distributedCache, HttpRequestMessage request, CancellationToken token = default)
    {
        if (!request.IsCacheable() || request.GetCacheKey() is not { } key)
        {
            return null;
        }

        var bytes = await distributedCache.GetAsync(key, token);
        return bytes is null ? null : HttpResponseMessageSerializer.Deserialize(bytes);
    }

    /// <summary>
    /// Stores <paramref name="response"/> under its request's cache key. The entry expires
    /// after the response's Cache-Control max-age.
    /// </summary>
    /// <param name="distributedCache">Cache to write to.</param>
    /// <param name="response">Response to store; it must reference its request and carry a Cache-Control max-age.</param>
    /// <param name="token">Cancellation token forwarded to the cache.</param>
    /// <exception cref="ArgumentNullException">The response has no request message, no Cache-Control header, or no max-age.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The originating request is not cacheable, or max-age is not positive.</exception>
    public static async Task SetAsync(this IDistributedCache distributedCache, HttpResponseMessage response, CancellationToken token = default)
    {
        // ParamName is "response" in every guard: each problem is a property of that argument (CA2208).
        var request = response.RequestMessage
            ?? throw new ArgumentNullException(nameof(response), "unexpected usage, request message was null");
        var cacheControl = response.Headers.CacheControl
            ?? throw new ArgumentNullException(nameof(response), "unexpected usage, response has no cache control");
        var maxAge = cacheControl.MaxAge
            ?? throw new ArgumentNullException(nameof(response), "unexpected usage, response cache control has no max age");

        if (!request.IsCacheable() || request.GetCacheKey() is not { } key)
        {
            throw new ArgumentOutOfRangeException(nameof(response), "unexpected usage, given response is not cacheable");
        }

        // DistributedCacheEntryOptions rejects non-positive relative expirations anyway;
        // fail here with a message that points at the real cause.
        if (maxAge <= TimeSpan.Zero)
        {
            throw new ArgumentOutOfRangeException(nameof(response), "unexpected usage, response cache control max age must be positive");
        }

        var bytes = HttpResponseMessageSerializer.Serialize(response);
        var options = new DistributedCacheEntryOptions
        {
            // Respect the server: keep the entry exactly as long as max-age says.
            AbsoluteExpirationRelativeToNow = maxAge
        };

        await distributedCache.SetAsync(key, bytes, options, token);
    }
}

namespace HttpCache.Extensions;

/// <summary>
/// Request-side rules for the HTTP cache: which requests may be cached and under which key.
/// </summary>
internal static class HttpRequestMessageExtensions
{
    /// <summary>
    /// A request is cacheable when it has a URI and its method is GET or HEAD.
    /// </summary>
    public static bool IsCacheable(this HttpRequestMessage request) =>
        request.RequestUri is not null
        && (request.Method == HttpMethod.Get || request.Method == HttpMethod.Head);

    /// <summary>
    /// Builds the cache key "<c>{METHOD} {uri}</c>" for a cacheable request, or returns <c>null</c>.
    /// Only the method and URI take part in the key; request headers do not.
    /// </summary>
    public static string? GetCacheKey(this HttpRequestMessage request) =>
        request.IsCacheable()
            ? $"{request.Method} {request.RequestUri}"
            : null;
}

namespace HttpCache.Extensions;

/// <summary>
/// Response-side rules for the HTTP cache.
/// </summary>
internal static class HttpResponseMessageExtensions
{
    /// <summary>
    /// Decides whether <paramref name="response"/> may be stored by the cache-control handler.
    /// </summary>
    /// <remarks>
    /// Deliberately minimal. The originating request must be cacheable. The response must carry a
    /// positive Cache-Control max-age and must not say no-store. Other directives (private,
    /// no-cache, s-maxage, Vary, ...) are not interpreted.
    /// </remarks>
    public static bool IsCacheable(this HttpResponseMessage response)
    {
        if (response.RequestMessage is null || !response.RequestMessage.IsCacheable())
        {
            return false;
        }

        var cacheControl = response.Headers.CacheControl;

        // no-store forbids keeping any copy of the response, whatever max-age says (RFC 9111 §5.2.2.5).
        if (cacheControl is null || cacheControl.NoStore)
        {
            return false;
        }

        // A non-positive max-age means "already stale". Storing it would also make
        // DistributedCacheEntryOptions throw on the relative expiration.
        return cacheControl.MaxAge is { } maxAge && maxAge > TimeSpan.Zero;
    }
}

Models

namespace HttpCache.Models;

/// <summary>
/// Serializable snapshot of an <see cref="HttpRequestMessage"/>, which itself cannot be serialized.
/// </summary>
internal record CacheableHttpRequestMessage
{
    public HttpMethod Method { get; init; } = HttpMethod.Get;
    public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; init; } = Array.Empty<KeyValuePair<string, IEnumerable<string>>>();
    public IDictionary<string, object?> Options { get; init; } = new Dictionary<string, object?>();
    public Version Version { get; init; } = new();
    public Uri? RequestUri { get; init; }
    public HttpVersionPolicy VersionPolicy { get; init; }

    /// <summary>
    /// Rebuilds a new <see cref="HttpRequestMessage"/> from this snapshot.
    /// </summary>
    /// <returns>A request with the snapshot's method, URI, version, version policy, headers and options.</returns>
    public HttpRequestMessage Convert()
    {
        var request = new HttpRequestMessage
        {
            Method = Method,
            RequestUri = RequestUri,
            Version = Version,
            VersionPolicy = VersionPolicy,
        };

        foreach (var (name, values) in Headers)
        {
            // Replay values verbatim. Add() re-validates and throws FormatException for values the
            // original request carried unvalidated, which would make the cache entry unreadable.
            // Names that are not valid request headers are skipped rather than thrown on.
            request.Headers.TryAddWithoutValidation(name, values);
        }

        foreach (var (name, value) in Options)
        {
            request.Options.Set(new HttpRequestOptionsKey<object?>(name), value);
        }

        return request;
    }
}
using System.Net;

namespace HttpCache.Models;

/// <summary>
/// Plain data mirror of an <see cref="HttpResponseMessage"/> used as the cache payload, because
/// the message itself cannot be serialized directly. The converters that fill and read it are
/// not part of this listing.
/// </summary>
internal record CacheableHttpResponseMessage
{
    /// <summary>Status code of the captured response.</summary>
    public HttpStatusCode StatusCode { get; init; }
    /// <summary>Response headers as name/values pairs; empty by default.</summary>
    public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; init; } = Array.Empty<KeyValuePair<string, IEnumerable<string>>>();
    /// <summary>HTTP version of the captured response; defaults to <c>new Version()</c> (0.0) when not set.</summary>
    public Version Version { get; init; } = new();
    /// <summary>Reason phrase of the captured response, if any.</summary>
    public string? ReasonPhrase { get; init; }
    /// <summary>Trailing headers as name/values pairs; empty by default.</summary>
    public IEnumerable<KeyValuePair<string, IEnumerable<string>>> TrailingHeaders { get; init; } = Array.Empty<KeyValuePair<string, IEnumerable<string>>>();
    
    /// <summary>Body bytes; presumably the buffered response content (filled by the converter, not shown). Empty by default.</summary>
    public byte[] ContentBytes { get; init; } = Array.Empty<byte>();
    /// <summary>Headers of the response content, kept separately from <see cref="Headers"/>; empty by default.</summary>
    public IEnumerable<KeyValuePair<string, IEnumerable<string>>> ContentHeaders { get; init; } = Array.Empty<KeyValuePair<string, IEnumerable<string>>>();
    
    /// <summary>Snapshot of the request that produced this response, or <c>null</c> if none was captured.</summary>
    public CacheableHttpRequestMessage? RequestMessage { get; init; }
}

There are also converters and serializers, but they are really tedious, so they are omitted here.

With this in place we can test that it works:

using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;

using FluentAssertions;

using HttpCache;

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;

using Xunit;

namespace Tests;

public class CacheControlTests
{
    /// <summary>
    /// A response with Cache-Control max-age=5 is served from the cache on the next call.
    /// It is fetched from the server again once max-age has elapsed.
    /// </summary>
    [Fact]
    public async Task ShouldRespectResponseCacheControl()
    {
        // Arrange
        var maxAge = TimeSpan.FromSeconds(5);
        var counter = 0;
        using var server = new WebHostBuilder().UseKestrel(o => o.Listen(IPAddress.Loopback, 0)).Configure(app =>
        {
            app.Run(async context =>
            {
                counter += 1;
                context.Response.Headers.CacheControl = new CacheControlHeaderValue { MaxAge = maxAge }.ToString();
                await context.Response.WriteAsync("Hello World");
            });
        }).Build();
        await server.StartAsync();
        var uri = new Uri(server.ServerFeatures.Get<IServerAddressesFeature>()!.Addresses.First());

        var services = new ServiceCollection();
        services.AddDistributedMemoryCache();
        // AddHttpMessageHandler requires the handler to be registered as transient. The factory
        // builds a fresh handler chain whenever the previous one expires, and a shared instance
        // cannot be re-chained.
        services.AddTransient<CacheControlDelegatingHandler>();
        services
            .AddHttpClient("demo", c => c.BaseAddress = uri)
            .AddHttpMessageHandler<CacheControlDelegatingHandler>();
        using var provider = services.BuildServiceProvider();
        var factory = provider.GetRequiredService<IHttpClientFactory>();
        var httpClient = factory.CreateClient("demo");

        // Act & Assert
        var response = await httpClient.GetAsync("/");
        response.Should().BeSuccessful();
        counter.Should().Be(1);

        response = await httpClient.GetAsync("/");
        response.Should().BeSuccessful();
        counter.Should().Be(1, "second request served from cache");

        // A small margin past max-age keeps the test off the exact expiry boundary
        // (timer and cache clock are not the same clock).
        await Task.Delay(maxAge + TimeSpan.FromMilliseconds(500));

        response = await httpClient.GetAsync("/");
        response.Should().BeSuccessful();
        counter.Should().Be(2, "after configured 5sec max age of cache we made another request");
    }
}

HttpClient respecting the ETag response header

We can go even further: how about entity tags?

Behind the scenes this is more about saving traffic, but it is still applicable here.

Sample implementation

using System.Net;

using HttpCache.Extensions;

using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Net.Http.Headers;

namespace HttpCache;

/// <summary>
/// Outgoing-request handler that revalidates cached responses with their ETag.
/// It sends If-None-Match and, on a 304 Not Modified, returns the cached response instead.
/// </summary>
public class EntityTagDelegatingHandler: DelegatingHandler
{
    // Lifetime stamped on stored entries. The server decides freshness on every call via
    // If-None-Match/304, so the entry only has to outlive practical use.
    // Not TimeSpan.MaxValue: max-age is delta-seconds, written and parsed as an int number of
    // seconds, and TimeSpan.MaxValue overflows that cast. The result is platform-dependent and may
    // be a negative value the header parser rejects. int.MaxValue seconds is the largest that round-trips.
    private static readonly TimeSpan EntryLifetime = TimeSpan.FromSeconds(int.MaxValue);

    private readonly IDistributedCache _distributedCache;

    /// <summary>Creates the handler for use in a pipeline that assigns the inner handler later (e.g. IHttpClientFactory).</summary>
    public EntityTagDelegatingHandler(IDistributedCache distributedCache)
    {
        _distributedCache = distributedCache;
    }

    /// <summary>Creates the handler wrapping an explicit <paramref name="innerHandler"/>.</summary>
    public EntityTagDelegatingHandler(HttpMessageHandler innerHandler, IDistributedCache distributedCache) : base(innerHandler)
    {
        _distributedCache = distributedCache;
    }

    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        var cached = await _distributedCache.GetAsync(request, cancellationToken);

        if (cached?.Headers.ETag is { } cachedTag)
        {
            // Typed collection: no string round-trip through the header parser.
            request.Headers.IfNoneMatch.Add(cachedTag);
        }

        var response = await base.SendAsync(request, cancellationToken);

        if (cached != null && response.StatusCode == HttpStatusCode.NotModified)
        {
            // The 304 carries no body; release it and answer from the cache instead.
            response.Dispose();
            return cached;
        }

        // Only successful responses to cacheable (GET/HEAD) requests are stored. SetAsync throws
        // for anything else, e.g. a PUT/POST that returns a fresh ETag.
        if (response.Headers.ETag != null
            && response.IsSuccessStatusCode
            && response.RequestMessage?.IsCacheable() == true)
        {
            // SetAsync takes the entry lifetime from Cache-Control max-age. Note this also changes
            // the header the caller sees on the returned response.
            response.Headers.CacheControl ??= new();
            response.Headers.CacheControl.MaxAge = EntryLifetime;
            await _distributedCache.SetAsync(response, cancellationToken);
        }

        return response;
    }
}

And its tests:

using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;

using FluentAssertions;

using HttpCache;

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;

using Xunit;

namespace Tests;

public class EntityTagTests
{
    /// <summary>
    /// The second request is revalidated with If-None-Match. The server answers 304 without a body,
    /// and the handler still returns the cached body to the caller.
    /// </summary>
    [Fact]
    public async Task ShouldRespectEntityTag()
    {
        // Arrange
        var counter = 0;
        var responsesWithoutBody = 0;
        using var server = new WebHostBuilder().UseKestrel(o => o.Listen(IPAddress.Loopback, 0)).Configure(app =>
        {
            app.Run(async context =>
            {
                counter += 1;
                if (context.Request.Headers.IfNoneMatch.Contains("\"test\""))
                {
                    context.Response.StatusCode = StatusCodes.Status304NotModified;
                    responsesWithoutBody += 1;
                    return;
                }
                // Assignment rather than IHeaderDictionary.Add, which throws if the header already exists.
                context.Response.Headers.ETag = new EntityTagHeaderValue("\"test\"").ToString();
                await context.Response.WriteAsync("Hello World");
            });
        }).Build();
        await server.StartAsync();
        var uri = new Uri(server.ServerFeatures.Get<IServerAddressesFeature>()!.Addresses.First());

        var services = new ServiceCollection();
        services.AddDistributedMemoryCache();
        // AddHttpMessageHandler requires the handler to be registered as transient. The factory
        // builds a fresh handler chain whenever the previous one expires, and a shared instance
        // cannot be re-chained.
        services.AddTransient<EntityTagDelegatingHandler>();
        services
            .AddHttpClient("demo", c => c.BaseAddress = uri)
            .AddHttpMessageHandler<EntityTagDelegatingHandler>();
        using var provider = services.BuildServiceProvider();
        var factory = provider.GetRequiredService<IHttpClientFactory>();
        var httpClient = factory.CreateClient("demo");

        // Act & Assert
        var response = await httpClient.GetAsync("/");
        response.Should().BeSuccessful();
        counter.Should().Be(1);
        response.Headers.ETag.Should().NotBeNull();
        response.Headers.ETag!.Tag.Should().Be("\"test\"");

        response = await httpClient.GetAsync("/");
        response.Should().BeSuccessful();
        counter.Should().Be(2, "request still made");
        responsesWithoutBody.Should().Be(1, "but we did not serve body");
        var str = await response.Content.ReadAsStringAsync();
        str.Should().Be("Hello World", "response is taken from cache even so our backend did not return anything");
    }
}