diff --git a/apps/api/.env.example b/apps/api/.env.example index b025326..d91799a 100644 --- a/apps/api/.env.example +++ b/apps/api/.env.example @@ -27,3 +27,5 @@ SLACK_WEBHOOK_URL= # set if you'd like to send slack server health status messag POSTHOG_API_KEY= # set if you'd like to send posthog events like job logs POSTHOG_HOST= # set if you'd like to send posthog events like job logs +STRIPE_PRICE_ID_STANDARD= +STRIPE_PRICE_ID_SCALE= diff --git a/apps/api/src/controllers/auth.ts b/apps/api/src/controllers/auth.ts index 77aa52f..fb3a813 100644 --- a/apps/api/src/controllers/auth.ts +++ b/apps/api/src/controllers/auth.ts @@ -1,9 +1,9 @@ import { parseApi } from "../../src/lib/parseApi"; -import { getRateLimiter } from "../../src/services/rate-limiter"; +import { getRateLimiter, crawlRateLimit, scrapeRateLimit } from "../../src/services/rate-limiter"; import { AuthResponse, RateLimiterMode } from "../../src/types"; import { supabase_service } from "../../src/services/supabase"; import { withAuth } from "../../src/lib/withAuth"; - +import { RateLimiterRedis } from "rate-limiter-flexible"; export async function authenticateUser(req, res, mode?: RateLimiterMode) : Promise { return withAuth(supaAuthenticateUser)(req, res, mode); @@ -19,7 +19,6 @@ export async function supaAuthenticateUser( error?: string; status?: number; }> { - const authHeader = req.headers.authorization; if (!authHeader) { return { success: false, error: "Unauthorized", status: 401 }; @@ -33,13 +32,55 @@ export async function supaAuthenticateUser( }; } + const incomingIP = (req.headers["x-forwarded-for"] || + req.socket.remoteAddress) as string; + const iptoken = incomingIP + token; + + let rateLimiter: RateLimiterRedis; + let subscriptionData: { team_id: string, plan: string } | null = null; + let normalizedApi: string; + + if (token == "this_is_just_a_preview_token") { + rateLimiter = await getRateLimiter(RateLimiterMode.Preview, token); + } else { + normalizedApi = parseApi(token); + + const { data, error } = await supabase_service.rpc( + 'get_key_and_price_id', { api_key: normalizedApi }); + + if (error) { + console.error('Error fetching key and price_id:', error); + } else { + console.log('Key and Price ID:', data); + } + + if (error || !data || data.length === 0) { + return { + success: false, + error: "Unauthorized: Invalid token", + status: 401, + }; + } + + subscriptionData = { + team_id: data[0].team_id, + plan: getPlanByPriceId(data[0].price_id) + } + switch (mode) { + case RateLimiterMode.Crawl: + rateLimiter = crawlRateLimit(subscriptionData.plan); + break; + case RateLimiterMode.Scrape: + rateLimiter = scrapeRateLimit(subscriptionData.plan); + break; + // case RateLimiterMode.Search: + // rateLimiter = await searchRateLimiter(RateLimiterMode.Search, token); + // break; + } + } + try { - const incomingIP = (req.headers["x-forwarded-for"] || - req.socket.remoteAddress) as string; - const iptoken = incomingIP + token; - await getRateLimiter( - token === "this_is_just_a_preview_token" ? RateLimiterMode.Preview : mode, token - ).consume(iptoken); + rateLimiter.consume(iptoken); } catch (rateLimiterRes) { console.error(rateLimiterRes); return { @@ -66,19 +107,36 @@ export async function supaAuthenticateUser( // return { success: false, error: "Unauthorized: Invalid token", status: 401 }; } - const normalizedApi = parseApi(token); // make sure api key is valid, based on the api_keys table in supabase - const { data, error } = await supabase_service + if (!subscriptionData) { + normalizedApi = parseApi(token); + + const { data, error } = await supabase_service .from("api_keys") .select("*") .eq("key", normalizedApi); - if (error || !data || data.length === 0) { - return { - success: false, - error: "Unauthorized: Invalid token", - status: 401, - }; + + if (error || !data || data.length === 0) { + return { + success: false, + error: "Unauthorized: Invalid token", + status: 401, + }; + } + + subscriptionData = data[0]; } - return { success: true, team_id: data[0].team_id }; + return { success: true, team_id: subscriptionData.team_id }; } + +function getPlanByPriceId(price_id: string) { + switch (price_id) { + case process.env.STRIPE_PRICE_ID_STANDARD: + return 'standard'; + case process.env.STRIPE_PRICE_ID_SCALE: + return 'scale'; + default: + return 'starter'; + } +} \ No newline at end of file diff --git a/apps/api/src/services/rate-limiter.ts b/apps/api/src/services/rate-limiter.ts index 5bc9acb..c20f67a 100644 --- a/apps/api/src/services/rate-limiter.ts +++ b/apps/api/src/services/rate-limiter.ts @@ -2,18 +2,18 @@ import { RateLimiterRedis } from "rate-limiter-flexible"; import * as redis from "redis"; import { RateLimiterMode } from "../../src/types"; -const MAX_REQUESTS_PER_MINUTE_PREVIEW = 5; const MAX_CRAWLS_PER_MINUTE_STARTER = 2; const MAX_CRAWLS_PER_MINUTE_STANDARD = 4; const MAX_CRAWLS_PER_MINUTE_SCALE = 20; +const MAX_SCRAPES_PER_MINUTE_STARTER = 10; +const MAX_SCRAPES_PER_MINUTE_STANDARD = 15; +const MAX_SCRAPES_PER_MINUTE_SCALE = 30; + +const MAX_REQUESTS_PER_MINUTE_PREVIEW = 5; const MAX_REQUESTS_PER_MINUTE_ACCOUNT = 20; - const MAX_REQUESTS_PER_MINUTE_CRAWL_STATUS = 120; - - - export const redisClient = redis.createClient({ url: process.env.REDIS_URL, legacyMode: true, @@ -48,15 +48,15 @@ export const testSuiteRateLimiter = new RateLimiterRedis({ }); -export function crawlRateLimit(plan: string){ - if(plan === "standard"){ +export function crawlRateLimit (plan: string){ + if (plan === "standard"){ return new RateLimiterRedis({ storeClient: redisClient, keyPrefix: "middleware", points: MAX_CRAWLS_PER_MINUTE_STANDARD, duration: 60, // Duration in seconds }); - }else if(plan === "scale"){ + } else if (plan === "scale"){ return new RateLimiterRedis({ storeClient: redisClient, keyPrefix: "middleware", @@ -70,18 +70,38 @@ export function crawlRateLimit(plan: string){ points: MAX_CRAWLS_PER_MINUTE_STARTER, duration: 60, // Duration in seconds }); - } - - +export function scrapeRateLimit (plan: string){ + if (plan === "standard"){ + return new RateLimiterRedis({ + storeClient: redisClient, + keyPrefix: "middleware", + points: MAX_SCRAPES_PER_MINUTE_STANDARD, + duration: 60, // Duration in seconds + }); + } else if (plan === "scale"){ + return new RateLimiterRedis({ + storeClient: redisClient, + keyPrefix: "middleware", + points: MAX_SCRAPES_PER_MINUTE_SCALE, + duration: 60, // Duration in seconds + }); + } + return new RateLimiterRedis({ + storeClient: redisClient, + keyPrefix: "middleware", + points: MAX_SCRAPES_PER_MINUTE_STARTER, + duration: 60, // Duration in seconds + }); +} export function getRateLimiter(mode: RateLimiterMode, token: string){ // Special test suite case. TODO: Change this later. - if(token.includes("5089cefa58")){ + if (token.includes("5089cefa58")){ return testSuiteRateLimiter; } - switch(mode) { + switch (mode) { case RateLimiterMode.Preview: return previewRateLimiter; case RateLimiterMode.CrawlStatus: