diff --git a/api/src/app.module.ts b/api/src/app.module.ts index e27e2c7..ffa951e 100644 --- a/api/src/app.module.ts +++ b/api/src/app.module.ts @@ -3,8 +3,9 @@ import { MongooseModule } from '@nestjs/mongoose' import { GatewayModule } from './gateway/gateway.module' import { AuthModule } from './auth/auth.module' import { UsersModule } from './users/users.module' -import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler' +import { ThrottlerModule } from '@nestjs/throttler' import { APP_GUARD } from '@nestjs/core/constants' +import { ThrottlerByIpGuard } from './auth/guards/throttle-by-ip.guard' @Module({ imports: [ @@ -23,7 +24,7 @@ import { APP_GUARD } from '@nestjs/core/constants' providers: [ { provide: APP_GUARD, - useClass: ThrottlerGuard, + useClass: ThrottlerByIpGuard, }, ], }) diff --git a/api/src/auth/guards/throttle-by-ip.guard.ts b/api/src/auth/guards/throttle-by-ip.guard.ts new file mode 100644 index 0000000..c1134a4 --- /dev/null +++ b/api/src/auth/guards/throttle-by-ip.guard.ts @@ -0,0 +1,19 @@ +import { Injectable } from '@nestjs/common' +import { ThrottlerGuard } from '@nestjs/throttler' + +@Injectable() +export class ThrottlerByIpGuard extends ThrottlerGuard { + protected async getTracker(req: Record): Promise { + return this.extractIP(req) + } + + private extractIP(req: Record): string { + if (req.headers['x-forwarded-for']) { + return req.headers['x-forwarded-for'] + } else if (req.ips.length) { + return req.ips[0] + } else { + return req.ip + } + } +}