diff --git a/apps/api/.env.example b/apps/api/.env.example
index 55271ec..3bdfcf1 100644
--- a/apps/api/.env.example
+++ b/apps/api/.env.example
@@ -17,6 +17,8 @@ SUPABASE_SERVICE_TOKEN=

 # Other Optionals
 TEST_API_KEY= # use if you've set up authentication and want to test with a real API key
+RATE_LIMIT_TEST_API_KEY_SCRAPE= # set if you'd like to test the scraping rate limit
+RATE_LIMIT_TEST_API_KEY_CRAWL= # set if you'd like to test the crawling rate limit
 SCRAPING_BEE_API_KEY= #Set if you'd like to use scraping Be to handle JS blocking
 OPENAI_API_KEY= # add for LLM dependednt features (image alt generation, etc.)
 BULL_AUTH_KEY= #
@@ -27,3 +29,5 @@ SLACK_WEBHOOK_URL= # set if you'd like to send slack server health status messag

 POSTHOG_API_KEY= # set if you'd like to send posthog events like job logs
 POSTHOG_HOST= # set if you'd like to send posthog events like job logs
+STRIPE_PRICE_ID_STANDARD=
+STRIPE_PRICE_ID_SCALE=
diff --git a/apps/api/src/__tests__/e2e_withAuth/index.test.ts b/apps/api/src/__tests__/e2e_withAuth/index.test.ts
index abe5c58..e9082ca 100644
--- a/apps/api/src/__tests__/e2e_withAuth/index.test.ts
+++ b/apps/api/src/__tests__/e2e_withAuth/index.test.ts
@@ -955,4 +955,65 @@ describe("E2E Tests for API Routes", () => {
       expect(response.body).toHaveProperty("isProduction");
     });
   });
+
+  describe("Rate Limiter", () => {
+    it("should return 429 when rate limit is exceeded for preview token", async () => {
+      for (let i = 0; i < 5; i++) {
+        const response = await request(TEST_URL)
+          .post("/v0/scrape")
+          .set("Authorization", `Bearer this_is_just_a_preview_token`)
+          .set("Content-Type", "application/json")
+          .send({ url: "https://www.scrapethissite.com" });
+
+        expect(response.statusCode).toBe(200);
+      }
+      const response = await request(TEST_URL)
+        .post("/v0/scrape")
+        .set("Authorization", `Bearer this_is_just_a_preview_token`)
+        .set("Content-Type", "application/json")
+        .send({ url: "https://www.scrapethissite.com" });
+
+      expect(response.statusCode).toBe(429);
+    }, 60000);
+  });
+
+  // it("should return 429 when rate limit is exceeded for API key", async () => {
+  //   for (let i = 0; i < parseInt(process.env.RATE_LIMIT_TEST_API_KEY_SCRAPE); i++) {
+  //     const response = await request(TEST_URL)
+  //       .post("/v0/scrape")
+  //       .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
+  //       .set("Content-Type", "application/json")
+  //       .send({ url: "https://www.scrapethissite.com" });
+
+  //     expect(response.statusCode).toBe(200);
+  //   }
+
+  //   const response = await request(TEST_URL)
+  //     .post("/v0/scrape")
+  //     .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
+  //     .set("Content-Type", "application/json")
+  //     .send({ url: "https://www.scrapethissite.com" });
+
+  //   expect(response.statusCode).toBe(429);
+  // }, 60000);
+
+  // it("should return 429 when rate limit is exceeded for API key", async () => {
+  //   for (let i = 0; i < parseInt(process.env.RATE_LIMIT_TEST_API_KEY_CRAWL); i++) {
+  //     const response = await request(TEST_URL)
+  //       .post("/v0/crawl")
+  //       .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
+  //       .set("Content-Type", "application/json")
+  //       .send({ url: "https://www.scrapethissite.com" });
+
+  //     expect(response.statusCode).toBe(200);
+  //   }
+
+  //   const response = await request(TEST_URL)
+  //     .post("/v0/crawl")
+  //     .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
+  //     .set("Content-Type", "application/json")
+  //     .send({ url: "https://www.scrapethissite.com" });
+
+  //   expect(response.statusCode).toBe(429);
+  // }, 60000);
 });
diff --git a/apps/api/src/controllers/auth.ts b/apps/api/src/controllers/auth.ts
index 77aa52f..4009d69 100644
--- a/apps/api/src/controllers/auth.ts
+++ b/apps/api/src/controllers/auth.ts
@@ -1,9 +1,9 @@
 import { parseApi } from "../../src/lib/parseApi";
-import { getRateLimiter } from "../../src/services/rate-limiter";
+import { getRateLimiter, } from "../../src/services/rate-limiter";
 import { AuthResponse, RateLimiterMode } from "../../src/types";
 import { supabase_service } from "../../src/services/supabase";
 import { withAuth } from "../../src/lib/withAuth";
-
+import { RateLimiterRedis } from "rate-limiter-flexible";

 export async function authenticateUser(req, res, mode?: RateLimiterMode) : Promise<AuthResponse> {
   return withAuth(supaAuthenticateUser)(req, res, mode);
@@ -19,7 +19,6 @@ export async function supaAuthenticateUser(
   error?: string;
   status?: number;
 }> {
-
   const authHeader = req.headers.authorization;
   if (!authHeader) {
     return { success: false, error: "Unauthorized", status: 401 };
@@ -33,13 +32,85 @@ export async function supaAuthenticateUser(
     };
   }

+  const incomingIP = (req.headers["x-forwarded-for"] ||
+    req.socket.remoteAddress) as string;
+  const iptoken = incomingIP + token;
+
+  let rateLimiter: RateLimiterRedis;
+  let subscriptionData: { team_id: string, plan: string } | null = null;
+  let normalizedApi: string;
+
+  if (token == "this_is_just_a_preview_token") {
+    rateLimiter = getRateLimiter(RateLimiterMode.Preview, token);
+  } else {
+    normalizedApi = parseApi(token);
+
+    const { data, error } = await supabase_service.rpc(
+      'get_key_and_price_id_2', { api_key: normalizedApi }
+    );
+    // get_key_and_price_id_2 rpc definition:
+    // create or replace function get_key_and_price_id_2(api_key uuid)
+    // returns table(key uuid, team_id uuid, price_id text) as $$
+    // begin
+    //   if api_key is null then
+    //     return query
+    //     select null::uuid as key, null::uuid as team_id, null::text as price_id;
+    //   end if;
+
+    //   return query
+    //   select ak.key, ak.team_id, s.price_id
+    //   from api_keys ak
+    //   left join subscriptions s on ak.team_id = s.team_id
+    //   where ak.key = api_key;
+    // end;
+    // $$ language plpgsql;
+
+    if (error) {
+      console.error('Error fetching key and price_id:', error);
+    } else {
+      // console.log('Key and Price ID:', data);
+    }
+
+    if (error || !data || data.length === 0) {
+      return {
+        success: false,
+        error: "Unauthorized: Invalid token",
+        status: 401,
+      };
+    }
+
+
+    subscriptionData = {
+      team_id: data[0].team_id,
+      plan: getPlanByPriceId(data[0].price_id)
+    }
+    switch (mode) {
+      case RateLimiterMode.Crawl:
+        rateLimiter = getRateLimiter(RateLimiterMode.Crawl, token, subscriptionData.plan);
+        break;
+      case RateLimiterMode.Scrape:
+        rateLimiter = getRateLimiter(RateLimiterMode.Scrape, token, subscriptionData.plan);
+        break;
+      case RateLimiterMode.CrawlStatus:
+        rateLimiter = getRateLimiter(RateLimiterMode.CrawlStatus, token);
+        break;
+      case RateLimiterMode.Search:
+        rateLimiter = getRateLimiter(RateLimiterMode.Search, token);
+        break;
+      case RateLimiterMode.Preview:
+        rateLimiter = getRateLimiter(RateLimiterMode.Preview, token);
+        break;
+      default:
+        rateLimiter = getRateLimiter(RateLimiterMode.Crawl, token);
+        break;
+      // case RateLimiterMode.Search:
+      //   rateLimiter = await searchRateLimiter(RateLimiterMode.Search, token);
+      //   break;
+    }
+  }
+
   try {
-    const incomingIP = (req.headers["x-forwarded-for"] ||
-      req.socket.remoteAddress) as string;
-    const iptoken = incomingIP + token;
-    await getRateLimiter(
-      token === "this_is_just_a_preview_token" ? RateLimiterMode.Preview : mode, token
-    ).consume(iptoken);
+    await rateLimiter.consume(iptoken);
   } catch (rateLimiterRes) {
     console.error(rateLimiterRes);
     return {
@@ -66,19 +137,36 @@ export async function supaAuthenticateUser(
     // return { success: false, error: "Unauthorized: Invalid token", status: 401 };
   }

-  const normalizedApi = parseApi(token);
   // make sure api key is valid, based on the api_keys table in supabase
-  const { data, error } = await supabase_service
+  if (!subscriptionData) {
+    normalizedApi = parseApi(token);
+
+    const { data, error } = await supabase_service
     .from("api_keys")
     .select("*")
     .eq("key", normalizedApi);
-  if (error || !data || data.length === 0) {
-    return {
-      success: false,
-      error: "Unauthorized: Invalid token",
-      status: 401,
-    };
+
+    if (error || !data || data.length === 0) {
+      return {
+        success: false,
+        error: "Unauthorized: Invalid token",
+        status: 401,
+      };
+    }
+
+    subscriptionData = data[0];
   }

-  return { success: true, team_id: data[0].team_id };
+  return { success: true, team_id: subscriptionData.team_id };
 }
+
+function getPlanByPriceId(price_id: string) {
+  switch (price_id) {
+    case process.env.STRIPE_PRICE_ID_STANDARD:
+      return 'standard';
+    case process.env.STRIPE_PRICE_ID_SCALE:
+      return 'scale';
+    default:
+      return 'starter';
+  }
+}
\ No newline at end of file
diff --git a/apps/api/src/services/rate-limiter.ts b/apps/api/src/services/rate-limiter.ts
index 34c243b..29b14f8 100644
--- a/apps/api/src/services/rate-limiter.ts
+++ b/apps/api/src/services/rate-limiter.ts
@@ -2,17 +2,21 @@ import { RateLimiterRedis } from "rate-limiter-flexible";
 import * as redis from "redis";
 import { RateLimiterMode } from "../../src/types";

-const MAX_REQUESTS_PER_MINUTE_PREVIEW = 5;
-const MAX_CRAWLS_PER_MINUTE_STARTER = 2;
-const MAX_CRAWLS_PER_MINUTE_STANDARD = 4;
+const MAX_CRAWLS_PER_MINUTE_STARTER = 3;
+const MAX_CRAWLS_PER_MINUTE_STANDARD = 5;
 const MAX_CRAWLS_PER_MINUTE_SCALE = 20;

+const MAX_SCRAPES_PER_MINUTE_STARTER = 20;
+const MAX_SCRAPES_PER_MINUTE_STANDARD = 30;
+const MAX_SCRAPES_PER_MINUTE_SCALE = 50;
+
+const MAX_SEARCHES_PER_MINUTE_STARTER = 20;
+const MAX_SEARCHES_PER_MINUTE_STANDARD = 30;
+const MAX_SEARCHES_PER_MINUTE_SCALE = 50;
+
+const MAX_REQUESTS_PER_MINUTE_PREVIEW = 5;
 const MAX_REQUESTS_PER_MINUTE_ACCOUNT = 20;
-
-const MAX_REQUESTS_PER_MINUTE_CRAWL_STATUS = 120;
-
-
-
+const MAX_REQUESTS_PER_MINUTE_CRAWL_STATUS = 150;

 export const redisClient = redis.createClient({
   url: process.env.REDIS_URL,
@@ -21,71 +25,109 @@ export const redisClient = redis.createClient({

 export const previewRateLimiter = new RateLimiterRedis({
   storeClient: redisClient,
-  keyPrefix: "middleware",
+  keyPrefix: "preview",
   points: MAX_REQUESTS_PER_MINUTE_PREVIEW,
   duration: 60, // Duration in seconds
 });

 export const serverRateLimiter = new RateLimiterRedis({
   storeClient: redisClient,
-  keyPrefix: "middleware",
+  keyPrefix: "server",
   points: MAX_REQUESTS_PER_MINUTE_ACCOUNT,
   duration: 60, // Duration in seconds
 });

 export const crawlStatusRateLimiter = new RateLimiterRedis({
   storeClient: redisClient,
-  keyPrefix: "middleware",
+  keyPrefix: "crawl-status",
   points: MAX_REQUESTS_PER_MINUTE_CRAWL_STATUS,
   duration: 60, // Duration in seconds
 });

 export const testSuiteRateLimiter = new RateLimiterRedis({
   storeClient: redisClient,
-  keyPrefix: "middleware",
-  points: 100000,
+  keyPrefix: "test-suite",
+  points: 10000,
   duration: 60, // Duration in seconds
 });

-export function crawlRateLimit(plan: string){
-  if(plan === "standard"){
-    return new RateLimiterRedis({
-      storeClient: redisClient,
-      keyPrefix: "middleware",
-      points: MAX_CRAWLS_PER_MINUTE_STANDARD,
-      duration: 60, // Duration in seconds
-    });
-  }else if(plan === "scale"){
-    return new RateLimiterRedis({
-      storeClient: redisClient,
-      keyPrefix: "middleware",
-      points: MAX_CRAWLS_PER_MINUTE_SCALE,
-      duration: 60, // Duration in seconds
-    });
-  }
-  return new RateLimiterRedis({
-    storeClient: redisClient,
-    keyPrefix: "middleware",
-    points: MAX_CRAWLS_PER_MINUTE_STARTER,
-    duration: 60, // Duration in seconds
-  });
-
-}
-
-
-
-
-export function getRateLimiter(mode: RateLimiterMode, token: string){
+export function getRateLimiter(mode: RateLimiterMode, token: string, plan?: string){
   // Special test suite case. TODO: Change this later.
-  if(token.includes("5089cefa58")){
+  if (token.includes("5089cefa58")){
     return testSuiteRateLimiter;
   }
-  switch(mode) {
+  switch (mode) {
     case RateLimiterMode.Preview:
       return previewRateLimiter;
     case RateLimiterMode.CrawlStatus:
       return crawlStatusRateLimiter;
+    case RateLimiterMode.Crawl:
+      if (plan === "standard"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "crawl-standard",
+          points: MAX_CRAWLS_PER_MINUTE_STANDARD,
+          duration: 60, // Duration in seconds
+        });
+      } else if (plan === "scale"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "crawl-scale",
+          points: MAX_CRAWLS_PER_MINUTE_SCALE,
+          duration: 60, // Duration in seconds
+        });
+      }
+      return new RateLimiterRedis({
+        storeClient: redisClient,
+        keyPrefix: "crawl-starter",
+        points: MAX_CRAWLS_PER_MINUTE_STARTER,
+        duration: 60, // Duration in seconds
+      });
+    case RateLimiterMode.Scrape:
+      if (plan === "standard"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "scrape-standard",
+          points: MAX_SCRAPES_PER_MINUTE_STANDARD,
+          duration: 60, // Duration in seconds
+        });
+      } else if (plan === "scale"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "scrape-scale",
+          points: MAX_SCRAPES_PER_MINUTE_SCALE,
+          duration: 60, // Duration in seconds
+        });
+      }
+      return new RateLimiterRedis({
+        storeClient: redisClient,
+        keyPrefix: "scrape-starter",
+        points: MAX_SCRAPES_PER_MINUTE_STARTER,
+        duration: 60, // Duration in seconds
+      });
+    case RateLimiterMode.Search:
+      if (plan === "standard"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "search-standard",
+          points: MAX_SEARCHES_PER_MINUTE_STANDARD,
+          duration: 60, // Duration in seconds
+        });
+      } else if (plan === "scale"){
+        return new RateLimiterRedis({
+          storeClient: redisClient,
+          keyPrefix: "search-scale",
+          points: MAX_SEARCHES_PER_MINUTE_SCALE,
+          duration: 60, // Duration in seconds
+        });
+      }
+      return new RateLimiterRedis({
+        storeClient: redisClient,
+        keyPrefix: "search-starter",
+        points: MAX_SEARCHES_PER_MINUTE_STARTER,
+        duration: 60, // Duration in seconds
+      });
     default:
       return serverRateLimiter;
   }
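
With these changes, the limiter applied to a request depends on both the endpoint mode and the caller's plan, and the plan is derived from the Stripe price id returned by the get_key_and_price_id_2 RPC. The sketch below is a minimal illustration of that flow, assuming the getRateLimiter(mode, token, plan) signature and the price-id-to-plan mapping from this diff; the enforceScrapeLimit helper and its parameters are hypothetical and only mirror what auth.ts now does inline.

```ts
// Sketch only: mirrors the plan-aware rate-limit step added to supaAuthenticateUser.
import { getRateLimiter } from "../../src/services/rate-limiter";
import { RateLimiterMode } from "../../src/types";

// Hypothetical helper; auth.ts performs these steps inline.
async function enforceScrapeLimit(token: string, ip: string, priceId: string) {
  // Same fallback behavior as getPlanByPriceId: unknown price ids map to "starter".
  const plan =
    priceId === process.env.STRIPE_PRICE_ID_STANDARD ? "standard"
    : priceId === process.env.STRIPE_PRICE_ID_SCALE ? "scale"
    : "starter";

  // Scrape mode picks the scrape-<plan> limiter (20/30/50 requests per minute).
  const limiter = getRateLimiter(RateLimiterMode.Scrape, token, plan);

  try {
    // Keys combine IP and token, matching the `iptoken` built in auth.ts.
    await limiter.consume(ip + token);
    return { success: true as const };
  } catch (rateLimiterRes) {
    // rate-limiter-flexible rejects the promise once the per-minute points are spent.
    return { success: false as const, status: 429, error: "Rate limit exceeded" };
  }
}
```

A preview token skips the plan lookup entirely and always gets the shared preview limiter (5 requests per minute), which is what the new e2e test exercises: five successful /v0/scrape calls followed by a sixth that must return 429.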
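
One design note on getRateLimiter: the Crawl, Scrape, and Search branches construct a fresh RateLimiterRedis on every call, while the preview, server, crawl-status, and test-suite limiters are module-level singletons. Because the keyPrefix and points are fixed per mode and plan, the per-plan limiters could also be built once at module load, roughly along the lines of the sketch below. This is only a possible follow-up, not something this diff does; the planLimiters map and its mode-plan key format are hypothetical, and the snippet assumes it sits in rate-limiter.ts next to the exported redisClient.

```ts
// Hypothetical alternative (not in this diff): pre-build one limiter per mode/plan pair.
import { RateLimiterRedis } from "rate-limiter-flexible";

// Same per-minute budgets as the constants above.
const planPoints: Record<string, Record<string, number>> = {
  crawl:  { starter: 3,  standard: 5,  scale: 20 },
  scrape: { starter: 20, standard: 30, scale: 50 },
  search: { starter: 20, standard: 30, scale: 50 },
};

const planLimiters: Record<string, RateLimiterRedis> = {};
for (const [mode, plans] of Object.entries(planPoints)) {
  for (const [plan, points] of Object.entries(plans)) {
    planLimiters[`${mode}-${plan}`] = new RateLimiterRedis({
      storeClient: redisClient, // the client exported earlier in rate-limiter.ts
      keyPrefix: `${mode}-${plan}`, // same prefixes as the switch above
      points,
      duration: 60, // per minute, matching the existing limiters
    });
  }
}

// getRateLimiter could then return planLimiters[`${mode}-${plan ?? "starter"}`]
// for Crawl, Scrape, and Search instead of constructing a new instance per call.
```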