﻿using Microsoft.AspNetCore.Http;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Performance.DtoModels;
using Performance.DtoModels.AppSettings;
using Performance.Infrastructure;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;

namespace Performance.Api
{
    /// <summary>
    /// Middleware that throttles requests whose path starts with one of
    /// <c>RateLimitingConfig.Endpoints</c>. Hits are counted per client IP, HTTP method and
    /// request path inside a fixed window of <c>RateLimitingConfig.Period</c> seconds. Once the
    /// count reaches the configured limit, further requests in that window are answered directly
    /// with an <see cref="ApiResponse"/> whose state is <c>ResponseType.TooManyRequests</c>
    /// instead of being passed down the pipeline.
    /// </summary>
    public class RequestRateLimitingMiddleware
    {
        private const string ForwardedForHeader = "X-Forwarded-For";

        // Maximum number of requests allowed per window; 1 when no configuration is supplied.
        private readonly int Limit = 1;
        private readonly ILogger logger;
        private readonly RequestDelegate next;
        private readonly IMemoryCache requestStore;
        private readonly RateLimitingConfig options;

        /// <summary>
        /// Creates the middleware.
        /// </summary>
        /// <param name="logger">Logger for this middleware.</param>
        /// <param name="next">Next delegate in the request pipeline.</param>
        /// <param name="requestStore">Cache holding the per-key hit counters.</param>
        /// <param name="httpContextAccessor">
        /// Kept for constructor/DI compatibility; the current request's <see cref="HttpContext"/>
        /// passed to <see cref="Invoke"/> is used instead.
        /// </param>
        /// <param name="options">Rate limiting configuration; may be null, in which case nothing is limited.</param>
        public RequestRateLimitingMiddleware(
            ILogger<RequestRateLimitingMiddleware> logger,
            RequestDelegate next,
            IMemoryCache requestStore,
            IHttpContextAccessor httpContextAccessor,
            IOptions<RateLimitingConfig> options)
        {
            this.logger = logger;
            this.next = next;
            this.requestStore = requestStore;
            // Null-check before dereferencing: reading options.Value first made the check useless.
            this.options = options?.Value;
            if (this.options != null)
                Limit = this.options.Limit;
        }

        /// <summary>
        /// Counts the request against its client/method/path key and either forwards it to the
        /// next middleware or short-circuits it with a "too many requests" <see cref="ApiResponse"/>.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            if (context.Response.HasStarted || !IsRateLimited(context.Request.Path.ToString()))
            {
                await next(context);
                return;
            }

            var requestKey = $"{ResolveClientIp(context)}-{context.Request.Method}-{context.Request.Path}";

            if (!requestStore.TryGetValue(requestKey, out RequestCounter counter) || counter == null)
            {
                // First hit of a new window: the window ends Period seconds from now.
                // DateTimeOffset.Now (local offset) keeps the RetryAfter header in the same format as before.
                await ProcessRequest(context, requestKey, 0, DateTimeOffset.Now.AddSeconds(options.Period));
            }
            else if (counter.Count < Limit)
            {
                // Reuse the stored expiry so allowed hits do not keep pushing the window back.
                await ProcessRequest(context, requestKey, counter.Count, counter.ExpiresAt);
            }
            else
            {
                await WriteTooManyRequestsAsync(context, counter.ExpiresAt);
            }
        }

        /// <summary>
        /// Returns true when <paramref name="path"/> starts with one of the configured endpoints.
        /// </summary>
        private bool IsRateLimited(string path)
        {
            var endpoints = options?.Endpoints;
            // ASP.NET Core routing matches paths case-insensitively, so a case-sensitive prefix
            // test would let clients bypass the limiter by changing the path's case.
            return endpoints != null
                && endpoints.Any(t => t != null && path.StartsWith(t, StringComparison.OrdinalIgnoreCase));
        }

        /// <summary>
        /// Determines the client IP used in the rate limiting key: the first valid address in
        /// X-Forwarded-For if present, otherwise the connection's remote address.
        /// </summary>
        private static string ResolveClientIp(HttpContext context)
        {
            // SECURITY NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
            // overwrites it; a client can send a different value per request to get a fresh counter.
            // Consider ForwardedHeadersMiddleware with KnownProxies/KnownNetworks instead.
            if (context.Request.Headers.TryGetValue(ForwardedForHeader, out var forwarded))
            {
                var first = forwarded.ToString()
                    .Split(',', StringSplitOptions.RemoveEmptyEntries)
                    .Select(part => part.Trim())
                    .FirstOrDefault();

                // A malformed header must not turn into a 500; fall back to the socket address.
                if (!string.IsNullOrEmpty(first) && IPAddress.TryParse(first, out var parsed))
                    return parsed.ToString();
            }

            // RemoteIpAddress can be null (e.g. in-memory test servers); avoid a NullReferenceException.
            return context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
        }

        /// <summary>
        /// Records one more hit for <paramref name="requestKey"/>, adds the rate limit headers and
        /// forwards the request down the pipeline.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="requestKey">Cache key for the client/method/path combination.</param>
        /// <param name="hitCount">Number of hits already recorded in the current window.</param>
        /// <param name="expiresAt">End of the current window.</param>
        private async Task ProcessRequest(HttpContext context, string requestKey, int hitCount, DateTimeOffset expiresAt)
        {
            hitCount++;
            requestStore.Set(
                requestKey,
                new RequestCounter(hitCount, expiresAt),
                new MemoryCacheEntryOptions { AbsoluteExpiration = expiresAt });
            // X-RateLimit-Limit: maximum number of requests allowed in one window.
            context.Response.Headers["X-RateLimit-Limit"] = Limit.ToString();
            // X-RateLimit-Remaining: number of requests still allowed in the current window.
            context.Response.Headers["X-RateLimit-Remaining"] = (Limit - hitCount).ToString();
            await next(context);
        }

        /// <summary>
        /// Writes the "too many requests" response without calling the next middleware.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="retryAfter">Time at which the current window ends.</param>
        private static async Task WriteTooManyRequestsAsync(HttpContext context, DateTimeOffset retryAfter)
        {
            // X-RateLimit-RetryAfter: the time after which the client can access the endpoint normally again.
            context.Response.Headers["X-RateLimit-RetryAfter"] = retryAfter.ToString();
            // Throttling is reported through ApiResponse.State; the HTTP status is deliberately left at 200
            // (presumably because clients read the ApiResponse envelope — confirm with API consumers).
            context.Response.StatusCode = StatusCodes.Status200OK;
            context.Response.ContentType = "application/json; charset=utf-8";
            var response = new ApiResponse
            {
                State = ResponseType.TooManyRequests,
                Message = "您的操作正在处理，请稍等片刻！"
            };
            await context.Response.WriteAsync(JsonHelper.Serialize(response));
        }

        /// <summary>
        /// Immutable cache entry: hit count in the current window and the window's end.
        /// </summary>
        private sealed class RequestCounter
        {
            public RequestCounter(int count, DateTimeOffset expiresAt)
            {
                Count = count;
                ExpiresAt = expiresAt;
            }

            public int Count { get; }

            public DateTimeOffset ExpiresAt { get; }
        }
    }
}
