如何在GraphQL中根据查询参数解析特定的DbContext?

人气:620 发布:2022-10-16 标签: dependency-injection c# entity-framework-core graphql

问题描述

我想使用GraphQL和Entity Framework Core来查询多个数据库。每个数据库都对应一个被许可方(licensee),因此所有查询都会带有一个查询参数licenseeId。现在,我需要让DI在服务请求DbContext时(例如通过构造函数参数或服务定位器),根据licenseeId解析出对应的DbContext。这真的可行吗?

以下是目前实现的相关部分:

存储库类

/// <summary>
/// Repository that queries through a <c>MyDbContext</c>.
/// </summary>
public class MyRepository
{
    /// <summary>
    /// The context this repository queries. It is settable so that a caller can
    /// replace it with a licensee-specific context after construction.
    /// </summary>
    public MyDbContext DbContext { get; set; }

    /// <summary>
    /// Creates the repository around the context supplied by DI.
    /// </summary>
    /// <param name="dbContext">Initial context; may later be replaced via <see cref="DbContext"/>.</param>
    public MyRepository(MyDbContext dbContext)
    {
        // Store the injected context; without this the property would stay null.
        DbContext = dbContext;
    }
}

查询类

// GraphQL root query type (question version). Exposes an "items" field that takes a
// required "licenseeId" argument and returns the repository's items for that licensee.
// NOTE(review): illustrative code that does not compile as written: `objectGraph` is not
// declared here (presumably `this` was meant), the lambda parameter is `context` but the
// body uses `resolveFieldContext`, and `var dbContext = ...;` is a placeholder.
public class MainQuery : ObjectGraphType
{
    public MainQuery()
    {
        objectGraph.FieldAsync<ListGraphType<MyModel>>("items",
            arguments: new QueryArguments(
                new QueryArgument<NonNullGraphType<GuidGraphType>> { Name = "licenseeId" }
            ),
            resolve: async context => {
                var licenseeId = resolveFieldContext.GetArgument<Guid>("licenseeId");

                // *1, create dbContext based on licenseeId manually via factory
                var dbContext = ...;

                // Resolve the repository from DI; its context is overwritten below.
                var repository = resolveFieldContext.ResolveServices.GetRequiredService<CucumberRepository>();

                // *2, assign context manually
                repository.DbContext = dbContext;

                return await repository.GetAllAsync();
            });
    }
}

如您所见,我目前需要先通过工厂手动创建DbContext(*1),再将该实例赋给存储库的属性(*2)。

我想在这里使用纯DI。我的想法是在Startup中通过服务工厂以某种方式实现:

// Proposed idea (pseudo-code): build MyDbContext's options from the licenseeId of the
// current GraphQL query, so that repositories can simply take MyDbContext as a dependency.
// NOTE(review): `GetRequiredService<?>()` is a placeholder — which service could expose the
// current query's licenseeId at this point is exactly the open question.
services.AddDbContext<MyDbContext>((serviceProvider, dbContextOptionsBuilder) => {
  var query = serviceProvider.GetRequiredService<?>();
  
  // One database per licensee: the catalog name is suffixed with the licensee id.
  var connectionString = $"...;Catalog=MyDatabase_{query.GetLicenseeId()}";

  dbContextOptionsBuilder.UseSqlServer(connectionString, ...);
});

这样我就可以像下面这样定义存储库类:

/// <summary>
/// Repository that receives its licensee-specific <c>MyDbContext</c> directly from DI.
/// </summary>
public class MyRepository
{
    // The injected context; kept so repository methods can query through it.
    private readonly MyDbContext _dbContext;

    /// <summary>
    /// Creates the repository around the context resolved by DI.
    /// </summary>
    /// <param name="dbContext">Context connected to the current licensee's database.</param>
    public MyRepository(MyDbContext dbContext)
    {
        _dbContext = dbContext;
    }
}

这样在resolve回调中,我只需编写:

// Desired resolver once MyDbContext is resolved per licensee by DI: no manual context
// wiring, just resolve the repository and query it.
// NOTE(review): pseudo-code — `...` stands for the field name/arguments, and the lambda
// parameter is `context` while the body uses `resolveFieldContext`.
objectGraph.FieldAsync<ListGraphType<MyModel>>(...,
    resolve: async context => {
        var repository = resolveFieldContext.ResolveServices.GetRequiredService<MyRepository>();

        return await repository.GetAllAsync();
    });

这是个好主意吗?

推荐答案

我提出了以下方法(假设licenseId为int类型)。请注意,缓存DbContextOptions很重要,因为EF Core会基于该对象缓存LINQ查询。

/// <summary>
/// Supplies the <see cref="DbContextOptions"/> for the database that belongs to a license.
/// </summary>
public interface ILicenseOptionFactory
{
    /// <summary>Returns the context options targeting the database of <paramref name="licenseId"/>.</summary>
    /// <param name="licenseId">Identifier of the license whose database should be used.</param>
    DbContextOptions GetOptions(int licenseId);
}

/// <summary>
/// Builds and caches one <see cref="DbContextOptions"/> instance per license.
/// </summary>
/// <remarks>
/// Caching matters because EF Core caches compiled LINQ queries based on the options
/// object, so reusing the same instance per license keeps that cache effective.
/// Register this class as a singleton so the cache is shared.
/// </remarks>
public class LicenseOptionFactory : ILicenseOptionFactory
{
    private readonly ConcurrentDictionary<int, DbContextOptions> _options = new ConcurrentDictionary<int, DbContextOptions>();

    /// <summary>Returns the cached options for <paramref name="licenseId"/>, creating them on first use.</summary>
    /// <param name="licenseId">Identifier of the license whose database should be targeted.</param>
    /// <exception cref="ArgumentOutOfRangeException">No connection string is known for <paramref name="licenseId"/>.</exception>
    public DbContextOptions GetOptions(int licenseId) =>
        _options.GetOrAdd(licenseId, CreateOptions);

    // Builds options for one license; only invoked on a cache miss.
    private static DbContextOptions CreateOptions(int licenseId)
    {
        // Any other way of looking up the connection string for a license can go here.
        string connectionString = licenseId switch
        {
            0 => "connectionString0",
            1 => "connectionString1",
            _ => throw new ArgumentOutOfRangeException(nameof(licenseId), licenseId, $"Invalid licenseId: {licenseId}"),
        };

        // NOTE(review): the non-generic builder produces options for DbContext; a context whose
        // constructor only accepts DbContextOptions<TContext> will not accept these options.
        return new DbContextOptionsBuilder().UseSqlServer(connectionString).Options;
    }
}
/// <summary>
/// Hands out <typeparamref name="TContext"/> instances connected to a specific license's
/// database. The factory itself is disposable (synchronously and asynchronously).
/// </summary>
/// <typeparam name="TContext">The EF Core context type to create.</typeparam>
public interface ILicenseConnectionFactory<TContext> : IAsyncDisposable, IDisposable
    where TContext : DbContext
{
    /// <summary>Returns the context connected to the database of <paramref name="licenseId"/>.</summary>
    /// <param name="licenseId">Identifier of the license whose database should be used.</param>
    TContext GetContext(int licenseId);
}

/// <summary>
/// Creates at most one <typeparamref name="TContext"/> per license and owns those contexts:
/// every context it created is disposed when this factory is disposed.
/// </summary>
/// <remarks>
/// Registered as scoped, so each DI scope gets its own set of contexts. The cache therefore
/// must be an instance field: a static cache would share DbContext instances (which are not
/// thread-safe) across all scopes, and one scope's disposal would dispose contexts that other
/// scopes are still using.
/// NOTE(review): this class is not synchronized; if resolvers in one scope run in parallel,
/// GetContext and the returned contexts must not be used concurrently.
/// </remarks>
/// <typeparam name="TContext">
/// Context type; assumed to have a public constructor accepting the options returned by
/// <see cref="ILicenseOptionFactory.GetOptions"/>.
/// </typeparam>
public class LicenseConnectionFactory<TContext> : ILicenseConnectionFactory<TContext> where TContext : DbContext
{
    private readonly ILicenseOptionFactory _optionFactory;
    private readonly Dictionary<int, TContext> _contexts = new Dictionary<int, TContext>();
    private bool _disposed;

    /// <param name="optionFactory">Source of the per-license context options.</param>
    /// <exception cref="ArgumentNullException"><paramref name="optionFactory"/> is null.</exception>
    public LicenseConnectionFactory(ILicenseOptionFactory optionFactory)
    {
        _optionFactory = optionFactory ?? throw new ArgumentNullException(nameof(optionFactory));
    }

    /// <summary>Returns the cached context for <paramref name="licenseId"/>, creating it on first use.</summary>
    /// <exception cref="ObjectDisposedException">The factory has already been disposed.</exception>
    public TContext GetContext(int licenseId)
    {
        // Creating a context after disposal would leak it, because nothing would dispose it later.
        if (_disposed)
            throw new ObjectDisposedException(GetType().Name);

        if (_contexts.TryGetValue(licenseId, out var ctx))
            return ctx;

        var options = _optionFactory.GetOptions(licenseId);
        ctx = (TContext)Activator.CreateInstance(typeof(TContext), options);
        _contexts.Add(licenseId, ctx);
        return ctx;
    }

    /// <summary>Disposes every context this factory created. Safe to call more than once.</summary>
    public void Dispose()
    {
        if (_disposed)
            return;
        _disposed = true;

        foreach (var dbContext in _contexts.Values)
            dbContext.Dispose();

        _contexts.Clear();
    }

    /// <summary>Asynchronously disposes every context this factory created. Safe to call more than once.</summary>
    public async ValueTask DisposeAsync()
    {
        if (_disposed)
            return;
        _disposed = true;

        foreach (var dbContext in _contexts.Values)
            await dbContext.DisposeAsync();

        _contexts.Clear();
    }
}

注册示例。请注意,这些服务分别需要注册为Singleton和Scoped:

// The options cache is shared application-wide (singleton); the context factory, and the
// contexts it creates and disposes, live per scope.
var serviceCollection = new ServiceCollection();

serviceCollection
    .AddSingleton<ILicenseOptionFactory, LicenseOptionFactory>()
    .AddScoped<ILicenseConnectionFactory<MyDbContext>, LicenseConnectionFactory<MyDbContext>>();

存储库示例(不过最好完全去掉这层抽象):

/// <summary>
/// Example repository whose <see cref="DbContext"/> is selected per license at runtime.
/// </summary>
/// <remarks>
/// Call <see cref="SetLicenseId"/> before using <see cref="DbContext"/>. The repository does
/// not dispose the context; that is left to the injected factory.
/// </remarks>
public class MyRepository
{
    private readonly ILicenseConnectionFactory<MyDbContext> _factory;
    private MyDbContext _dbContext;

    /// <summary>The context for the license selected via <see cref="SetLicenseId"/>.</summary>
    /// <exception cref="InvalidOperationException"><see cref="SetLicenseId"/> has not been called yet.</exception>
    public MyDbContext DbContext =>
        _dbContext ?? throw new InvalidOperationException("Repository is not initialized.");

    /// <param name="factory">Factory that provides one context per license.</param>
    /// <exception cref="ArgumentNullException"><paramref name="factory"/> is null.</exception>
    public MyRepository(ILicenseConnectionFactory<MyDbContext> factory)
    {
        _factory = factory ?? throw new ArgumentNullException(nameof(factory));
    }

    /// <summary>Selects the license whose database this repository queries.</summary>
    /// <param name="licnseId">The license identifier (parameter name kept for caller compatibility).</param>
    public void SetLicenseId(int licnseId)
    {
        _dbContext = _factory.GetContext(licnseId);
    }
}

最终用法如下。我不确定resolveFieldContext能否通过DI获取——如果可以,就能简化存储库的初始化,无需再调用SetLicenseId:

/// <summary>
/// Root GraphQL query. The "items" field selects the licensee database from its
/// <c>licenseeId</c> argument and returns that database's items.
/// </summary>
public class MainQuery : ObjectGraphType
{
    public MainQuery()
    {
        FieldAsync<ListGraphType<MyModel>>("items",
            arguments: new QueryArguments(
                // SetLicenseId takes an int license id, so the argument is an Int rather than a Guid.
                new QueryArgument<NonNullGraphType<IntGraphType>> { Name = "licenseeId" }
            ),
            resolve: async context => {
                var licenseeId = context.GetArgument<int>("licenseeId");

                // Resolve the repository from the request's DI scope, then point it at the
                // licensee's database; the scoped factory creates and caches the DbContext.
                var repository = context.RequestServices.GetRequiredService<MyRepository>();
                repository.SetLicenseId(licenseeId);

                // NOTE(review): GetAllAsync is assumed to exist on the real repository.
                return await repository.GetAllAsync();
            });
    }
}

678