merge: Prevent streaming API denial-of-service (resolves #1019) (!951)

View MR for information: https://activitypub.software/TransFem-org/Sharkey/-/merge_requests/951

Closes #1019

Approved-by: dakkar <dakkar@thenautilus.net>
Approved-by: Marie <github@yuugi.dev>
This commit is contained in:
Marie 2025-03-30 10:40:56 +00:00
commit 865a9c4906
4 changed files with 280 additions and 113 deletions

View file

@ -34,13 +34,18 @@ Header meanings and usage have been devised by adapting common patterns to work
## Performance ## Performance
SkRateLimiterService makes between 1 and 4 redis transactions per rate limit check. SkRateLimiterService makes between 0 and 4 redis transactions per rate limit check.
The first call is read-only, while the others perform at least one write operation. The first call is read-only, while the others perform at least one write operation.
No calls are made if a client has already been blocked at least once, as the block status is stored in a short-term memory cache.
Two integer keys are stored per client/subject, and both expire together after the maximum duration of the limit. Two integer keys are stored per client/subject, and both expire together after the maximum duration of the limit.
While performance has not been formally tested, it's expected that SkRateLimiterService has an impact roughly on par with the legacy RateLimiterService. While performance has not been formally tested, it's expected that SkRateLimiterService has an impact roughly on par with the legacy RateLimiterService.
Redis memory usage should be notably lower due to the reduced number of keys and avoidance of set / array constructions. Redis memory usage should be notably lower due to the reduced number of keys and avoidance of set / array constructions.
If redis load does become a concern, then a dedicated node can be assigned via the `redisForRateLimit` config setting. If redis load does become a concern, then a dedicated node can be assigned via the `redisForRateLimit` config setting.
To prevent Redis DoS, SkRateLimiterService internally tracks the number of concurrent requests for each unique client/endpoint combination.
If the number of requests exceeds the limit's maximum value, then any further requests are automatically rejected.
The lockout will automatically end when the number of active requests drops to within the limit value.
## Concurrency and Multi-Node Correctness ## Concurrency and Multi-Node Correctness
To provide consistency across multi-node environments, leaky bucket is implemented with only atomic operations (`Increment`, `Decrement`, `Add`, and `Subtract`). To provide consistency across multi-node environments, leaky bucket is implemented with only atomic operations (`Increment`, `Decrement`, `Add`, and `Subtract`).
@ -54,6 +59,12 @@ Any possible conflict would have to occur within a few-milliseconds window, whic
This error does not compound, as all further operations are relative (Increment and Add). This error does not compound, as all further operations are relative (Increment and Add).
Thus, it's considered an acceptable tradeoff given the limitations imposed by Redis and ioredis. Thus, it's considered an acceptable tradeoff given the limitations imposed by Redis and ioredis.
In-process memory caches are used sparingly to avoid consistency problems.
Besides the role factor cache, there is one "important" cache which directly impacts limit calculations: the lockout cache.
This cache stores response data for blocked limits, preventing repeated calls to redis if a client ignores the 429 errors and continues making requests.
Consistency is guaranteed by only caching blocked limits (allowances are not cached), and by limiting cached data to the duration of the block.
This ensures that stale limit info is never used.
## Algorithm Pseudocode ## Algorithm Pseudocode
The Atomic Leaky Bucket algorithm is described here, in pseudocode: The Atomic Leaky Bucket algorithm is described here, in pseudocode:

View file

@ -17,10 +17,23 @@ import type { RoleService } from '@/core/RoleService.js';
// Required because MemoryKVCache doesn't support null keys. // Required because MemoryKVCache doesn't support null keys.
const defaultUserKey = ''; const defaultUserKey = '';
interface ParsedLimit {
key: string;
now: number;
bucketSize: number;
dripRate: number;
dripSize: number;
fullResetMs: number;
fullResetSec: number;
}
@Injectable() @Injectable()
export class SkRateLimiterService { export class SkRateLimiterService {
// 1-minute cache interval // 1-minute cache interval
private readonly factorCache = new MemoryKVCache<number>(1000 * 60); private readonly factorCache = new MemoryKVCache<number>(1000 * 60);
// 10-second cache interval
private readonly lockoutCache = new MemoryKVCache<number>(1000 * 10);
private readonly requestCounts = new Map<string, number>();
private readonly disabled: boolean; private readonly disabled: boolean;
constructor( constructor(
@ -58,6 +71,8 @@ export class SkRateLimiterService {
} }
const actor = typeof(actorOrUser) === 'object' ? actorOrUser.id : actorOrUser; const actor = typeof(actorOrUser) === 'object' ? actorOrUser.id : actorOrUser;
const actorKey = `@${actor}#${limit.key}`;
const userCacheKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : defaultUserKey; const userCacheKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : defaultUserKey;
const userRoleKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : null; const userRoleKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : null;
const factor = this.factorCache.get(userCacheKey) ?? await this.factorCache.fetch(userCacheKey, async () => { const factor = this.factorCache.get(userCacheKey) ?? await this.factorCache.fetch(userCacheKey, async () => {
@ -73,25 +88,81 @@ export class SkRateLimiterService {
throw new Error(`Rate limit factor is zero or negative: ${factor}`); throw new Error(`Rate limit factor is zero or negative: ${factor}`);
} }
if (isLegacyRateLimit(limit)) { const parsedLimit = this.parseLimit(limit, factor);
return await this.limitLegacy(limit, actor, factor); if (parsedLimit == null) {
} else {
return await this.limitBucket(limit, actor, factor);
}
}
private async limitLegacy(limit: Keyed<LegacyRateLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (hasMaxLimit(limit)) {
return await this.limitLegacyMinMax(limit, actor, factor);
} else if (hasMinLimit(limit)) {
return await this.limitLegacyMinOnly(limit, actor, factor);
} else {
return disabledLimitInfo; return disabledLimitInfo;
} }
// Fast-path to avoid extra redis calls for blocked clients
const lockout = this.getLockout(actorKey, parsedLimit);
if (lockout) {
return lockout;
} }
private async limitLegacyMinMax(limit: Keyed<MaxLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> { // Fast-path to avoid queuing requests that are guaranteed to fail
if (limit.duration === 0) return disabledLimitInfo; const overflow = this.incrementOverflow(actorKey, parsedLimit);
if (overflow) {
return overflow;
}
try {
const info = await this.limitBucket(parsedLimit, actor);
// Store blocked status to avoid hammering redis
if (info.blocked) {
this.lockoutCache.set(actorKey, info.resetMs);
}
return info;
} finally {
this.decrementOverflow(actorKey);
}
}
private getLockout(lockoutKey: string, limit: ParsedLimit): LimitInfo | null {
const lockoutReset = this.lockoutCache.get(lockoutKey);
if (!lockoutReset) {
// Not blocked, proceed with redis check
return null;
}
if (limit.now >= lockoutReset) {
// Block expired, clear and proceed with redis check
this.lockoutCache.delete(lockoutKey);
return null;
}
// Lockout is still active, pre-emptively reject the request
return {
blocked: true,
remaining: 0,
resetMs: limit.fullResetMs,
resetSec: limit.fullResetSec,
fullResetMs: limit.fullResetMs,
fullResetSec: limit.fullResetSec,
};
}
private parseLimit(limit: Keyed<RateLimit>, factor: number): ParsedLimit | null {
if (isLegacyRateLimit(limit)) {
return this.parseLegacyLimit(limit, factor);
} else {
return this.parseBucketLimit(limit, factor);
}
}
private parseLegacyLimit(limit: Keyed<LegacyRateLimit>, factor: number): ParsedLimit | null {
if (hasMaxLimit(limit)) {
return this.parseLegacyMinMax(limit, factor);
} else if (hasMinLimit(limit)) {
return this.parseLegacyMinOnly(limit, factor);
} else {
return null;
}
}
private parseLegacyMinMax(limit: Keyed<MaxLegacyLimit>, factor: number): ParsedLimit | null {
if (limit.duration === 0) return null;
if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`); if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`);
if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`); if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`);
@ -104,35 +175,30 @@ export class SkRateLimiterService {
// Calculate final dripRate from dripSize and duration/max // Calculate final dripRate from dripSize and duration/max
const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1); const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1);
const bucketLimit: Keyed<BucketRateLimit> = { return this.parseBucketLimit({
type: 'bucket', type: 'bucket',
key: limit.key, key: limit.key,
size: limit.max, size: limit.max,
dripRate, dripRate,
dripSize, dripSize,
}; }, factor);
return await this.limitBucket(bucketLimit, actor, factor);
} }
private async limitLegacyMinOnly(limit: Keyed<MinLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> { private parseLegacyMinOnly(limit: Keyed<MinLegacyLimit>, factor: number): ParsedLimit | null {
if (limit.minInterval === 0) return disabledLimitInfo; if (limit.minInterval === 0) return null;
if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`);
const dripRate = Math.max(Math.round(limit.minInterval), 1); const dripRate = Math.max(Math.round(limit.minInterval), 1);
const bucketLimit: Keyed<BucketRateLimit> = { return this.parseBucketLimit({
type: 'bucket', type: 'bucket',
key: limit.key, key: limit.key,
size: 1, size: 1,
dripRate, dripRate,
dripSize: 1, dripSize: 1,
}; }, factor);
return await this.limitBucket(bucketLimit, actor, factor);
} }
/** private parseBucketLimit(limit: Keyed<BucketRateLimit>, factor: number): ParsedLimit {
* Implementation of Leaky Bucket rate limiting - see SkRateLimiterService.md for details.
*/
private async limitBucket(limit: Keyed<BucketRateLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`); if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`);
if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`); if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`);
if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`); if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`);
@ -142,7 +208,27 @@ export class SkRateLimiterService {
const bucketSize = Math.max(Math.ceil(limit.size / factor), 1); const bucketSize = Math.max(Math.ceil(limit.size / factor), 1);
const dripRate = Math.ceil(limit.dripRate ?? 1000); const dripRate = Math.ceil(limit.dripRate ?? 1000);
const dripSize = Math.ceil(limit.dripSize ?? 1); const dripSize = Math.ceil(limit.dripSize ?? 1);
const expirationSec = Math.max(Math.ceil((dripRate * Math.ceil(bucketSize / dripSize)) / 1000), 1); const fullResetMs = dripRate * Math.ceil(bucketSize / dripSize);
const fullResetSec = Math.max(Math.ceil(fullResetMs / 1000), 1);
return {
key: limit.key,
now,
bucketSize,
dripRate,
dripSize,
fullResetMs,
fullResetSec,
};
}
/**
* Implementation of Leaky Bucket rate limiting - see SkRateLimiterService.md for details.
*/
private async limitBucket(limit: ParsedLimit, actor: string): Promise<LimitInfo> {
// 0 - Calculate (extracted to other function)
const { now, bucketSize, dripRate, dripSize } = limit;
const expirationSec = limit.fullResetSec;
// 1 - Read // 1 - Read
const counterKey = createLimitKey(limit, actor, 'c'); const counterKey = createLimitKey(limit, actor, 'c');
@ -262,13 +348,44 @@ export class SkRateLimiterService {
return responses; return responses;
} }
private incrementOverflow(actorKey: string, limit: ParsedLimit): LimitInfo | null {
const oldCount = this.requestCounts.get(actorKey) ?? 0;
if (oldCount >= limit.bucketSize) {
// Overflow, pre-emptively reject the request
return {
blocked: true,
remaining: 0,
resetMs: limit.fullResetMs,
resetSec: limit.fullResetSec,
fullResetMs: limit.fullResetMs,
fullResetSec: limit.fullResetSec,
};
}
// No overflow, increment and continue to redis
this.requestCounts.set(actorKey, oldCount + 1);
return null;
}
private decrementOverflow(actorKey: string): void {
const count = this.requestCounts.get(actorKey);
if (count) {
if (count > 1) {
this.requestCounts.set(actorKey, count - 1);
} else {
this.requestCounts.delete(actorKey);
}
}
}
} }
// Not correct, but good enough for the basic commands we use. // Not correct, but good enough for the basic commands we use.
type RedisResult = string | null; type RedisResult = string | null;
type RedisCommand = [command: string, ...args: unknown[]]; type RedisCommand = [command: string, ...args: unknown[]];
function createLimitKey(limit: Keyed<RateLimit>, actor: string, value: string): string { function createLimitKey(limit: ParsedLimit, actor: string, value: string): string {
return `rl_${actor}_${limit.key}_${value}`; return `rl_${actor}_${limit.key}_${value}`;
} }

View file

@ -10,7 +10,9 @@ import * as WebSocket from 'ws';
import proxyAddr from 'proxy-addr'; import proxyAddr from 'proxy-addr';
import ms from 'ms'; import ms from 'ms';
import { DI } from '@/di-symbols.js'; import { DI } from '@/di-symbols.js';
import type { UsersRepository, MiAccessToken } from '@/models/_.js'; import type { UsersRepository, MiAccessToken, MiUser } from '@/models/_.js';
import type { Config } from '@/config.js';
import type { Keyed, RateLimit } from '@/misc/rate-limit-utils.js';
import { NoteReadService } from '@/core/NoteReadService.js'; import { NoteReadService } from '@/core/NoteReadService.js';
import { NotificationService } from '@/core/NotificationService.js'; import { NotificationService } from '@/core/NotificationService.js';
import { bindThis } from '@/decorators.js'; import { bindThis } from '@/decorators.js';
@ -25,13 +27,16 @@ import { AuthenticateService, AuthenticationError } from './AuthenticateService.
import MainStreamConnection from './stream/Connection.js'; import MainStreamConnection from './stream/Connection.js';
import { ChannelsService } from './stream/ChannelsService.js'; import { ChannelsService } from './stream/ChannelsService.js';
import type * as http from 'node:http'; import type * as http from 'node:http';
import type { IEndpointMeta } from './endpoints.js';
import type { Config } from "@/config.js"; // Maximum number of simultaneous connections by client (user ID or IP address).
// Excess connections will be closed automatically.
const MAX_CONNECTIONS_PER_CLIENT = 32;
@Injectable() @Injectable()
export class StreamingApiServerService { export class StreamingApiServerService {
#wss: WebSocket.WebSocketServer; #wss: WebSocket.WebSocketServer;
#connections = new Map<WebSocket.WebSocket, number>(); #connections = new Map<WebSocket.WebSocket, number>();
#connectionsByClient = new Map<string, Set<WebSocket.WebSocket>>(); // key: IP / user ID -> value: connection
#cleanConnectionsIntervalId: NodeJS.Timeout | null = null; #cleanConnectionsIntervalId: NodeJS.Timeout | null = null;
constructor( constructor(
@ -58,17 +63,9 @@ export class StreamingApiServerService {
@bindThis @bindThis
private async rateLimitThis( private async rateLimitThis(
user: MiLocalUser | null | undefined, limitActor: MiUser | string,
requestIp: string, limit: Keyed<RateLimit>,
limit: IEndpointMeta['limit'] & { key: NonNullable<string> },
) : Promise<boolean> { ) : Promise<boolean> {
let limitActor: string | MiLocalUser;
if (user) {
limitActor = user;
} else {
limitActor = getIpHash(requestIp);
}
// Rate limit // Rate limit
const rateLimit = await this.rateLimiterService.limit(limit, limitActor); const rateLimit = await this.rateLimiterService.limit(limit, limitActor);
return rateLimit.blocked; return rateLimit.blocked;
@ -88,21 +85,6 @@ export class StreamingApiServerService {
return; return;
} }
// ServerServices sets `trustProxy: true`, which inside
// fastify/request.js ends up calling `proxyAddr` in this way,
// so we do the same
const requestIp = proxyAddr(request, () => { return true; } );
if (await this.rateLimitThis(null, requestIp, {
key: 'wsconnect',
duration: ms('5min'),
max: 32,
})) {
socket.write('HTTP/1.1 429 Rate Limit Exceeded\r\n\r\n');
socket.destroy();
return;
}
const q = new URL(request.url, `http://${request.headers.host}`).searchParams; const q = new URL(request.url, `http://${request.headers.host}`).searchParams;
let user: MiLocalUser | null = null; let user: MiLocalUser | null = null;
@ -140,15 +122,48 @@ export class StreamingApiServerService {
return; return;
} }
// ServerServices sets `trustProxy: true`, which inside fastify/request.js ends up calling `proxyAddr` in this way, so we do the same.
const requestIp = proxyAddr(request, () => true );
const limitActor = user?.id ?? getIpHash(requestIp);
if (await this.rateLimitThis(limitActor, {
key: 'wsconnect',
duration: ms('5min'),
max: 32,
})) {
socket.write('HTTP/1.1 429 Rate Limit Exceeded\r\n\r\n');
socket.destroy();
return;
}
// For performance and code simplicity, obtain and hold this reference for the lifetime of the connection.
// This should be safe because the map entry should only be deleted after *all* connections close.
let connectionsForClient = this.#connectionsByClient.get(limitActor);
if (!connectionsForClient) {
connectionsForClient = new Set();
this.#connectionsByClient.set(limitActor, connectionsForClient);
}
// Close excess connections
while (connectionsForClient.size >= MAX_CONNECTIONS_PER_CLIENT) {
// Set maintains insertion order, so first entry is the oldest.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const oldestConnection = connectionsForClient.values().next().value!;
// Technically, the close() handler should remove this entry.
// But if that ever fails, then we could enter an infinite loop.
// We manually remove the connection here just in case.
oldestConnection.close(1008, 'Disconnected - too many simultaneous connections');
connectionsForClient.delete(oldestConnection);
}
const rateLimiter = () => { const rateLimiter = () => {
// rather high limit, because when catching up at the top of a // Rather high limit because when catching up at the top of a timeline, the frontend may render many many notes.
// timeline, the frontend may render many many notes, each of // Each of which causes a message via `useNoteCapture` to ask for realtime updates of that note.
// which causes a message via `useNoteCapture` to ask for return this.rateLimitThis(limitActor, {
// realtime updates of that note type: 'bucket',
return this.rateLimitThis(user, requestIp, {
key: 'wsmessage', key: 'wsmessage',
duration: ms('2sec'), size: 4096, // Allow spikes of up to 4096
max: 4096, dripRate: 50, // Then once every 50ms (20/second rate)
}); });
}; };
@ -166,6 +181,19 @@ export class StreamingApiServerService {
await stream.init(); await stream.init();
this.#wss.handleUpgrade(request, socket, head, (ws) => { this.#wss.handleUpgrade(request, socket, head, (ws) => {
connectionsForClient.add(ws);
// Call before emit() in case it throws an error.
// We don't want to leave dangling references!
ws.once('close', () => {
connectionsForClient.delete(ws);
// Make sure we don't leak the Set objects!
if (connectionsForClient.size < 1) {
this.#connectionsByClient.delete(limitActor);
}
});
this.#wss.emit('connection', ws, request, { this.#wss.emit('connection', ws, request, {
stream, user, app, stream, user, app,
}); });

View file

@ -23,6 +23,8 @@ import type { EventEmitter } from 'events';
import type Channel from './channel.js'; import type Channel from './channel.js';
const MAX_CHANNELS_PER_CONNECTION = 32; const MAX_CHANNELS_PER_CONNECTION = 32;
const MAX_SUBSCRIPTIONS_PER_CONNECTION = 512;
const MAX_CACHED_NOTES_PER_CONNECTION = 64;
/** /**
* Main stream connection * Main stream connection
@ -31,12 +33,11 @@ const MAX_CHANNELS_PER_CONNECTION = 32;
export default class Connection { export default class Connection {
public user?: MiUser; public user?: MiUser;
public token?: MiAccessToken; public token?: MiAccessToken;
private rateLimiter?: () => Promise<boolean>;
private wsConnection: WebSocket.WebSocket; private wsConnection: WebSocket.WebSocket;
public subscriber: StreamEventEmitter; public subscriber: StreamEventEmitter;
private channels: Channel[] = []; private channels = new Map<string, Channel>();
private subscribingNotes: Partial<Record<string, number>> = {}; private subscribingNotes = new Map<string, number>();
private cachedNotes: Packed<'Note'>[] = []; private cachedNotes = new Map<string, Packed<'Note'>>();
public userProfile: MiUserProfile | null = null; public userProfile: MiUserProfile | null = null;
public following: Record<string, Pick<MiFollowing, 'withReplies'> | undefined> = {}; public following: Record<string, Pick<MiFollowing, 'withReplies'> | undefined> = {};
public followingChannels: Set<string> = new Set(); public followingChannels: Set<string> = new Set();
@ -45,7 +46,6 @@ export default class Connection {
public userIdsWhoMeMutingRenotes: Set<string> = new Set(); public userIdsWhoMeMutingRenotes: Set<string> = new Set();
public userMutedInstances: Set<string> = new Set(); public userMutedInstances: Set<string> = new Set();
private fetchIntervalId: NodeJS.Timeout | null = null; private fetchIntervalId: NodeJS.Timeout | null = null;
private activeRateLimitRequests = 0;
private closingConnection = false; private closingConnection = false;
private logger: Logger; private logger: Logger;
@ -60,11 +60,10 @@ export default class Connection {
user: MiUser | null | undefined, user: MiUser | null | undefined,
token: MiAccessToken | null | undefined, token: MiAccessToken | null | undefined,
private ip: string, private ip: string,
rateLimiter: () => Promise<boolean>, private readonly rateLimiter: () => Promise<boolean>,
) { ) {
if (user) this.user = user; if (user) this.user = user;
if (token) this.token = token; if (token) this.token = token;
if (rateLimiter) this.rateLimiter = rateLimiter;
this.logger = loggerService.getLogger('streaming', 'coral'); this.logger = loggerService.getLogger('streaming', 'coral');
} }
@ -121,26 +120,14 @@ export default class Connection {
if (this.closingConnection) return; if (this.closingConnection) return;
if (this.rateLimiter) { // The rate limit is very high, so we can safely disconnect any client that hits it.
// this 4096 should match the `max` of the `rateLimiter`, see if (await this.rateLimiter()) {
// StreamingApiServerService this.logger.warn(`Closing a connection from ${this.ip} (user=${this.user?.id}}) due to an excessive influx of messages.`);
if (this.activeRateLimitRequests <= 4096) {
this.activeRateLimitRequests++;
const shouldRateLimit = await this.rateLimiter();
this.activeRateLimitRequests--;
if (shouldRateLimit) return;
if (this.closingConnection) return;
} else {
let connectionInfo = `IP ${this.ip}`;
if (this.user) connectionInfo += `, user ID ${this.user.id}`;
this.logger.warn(`Closing a connection (${connectionInfo}) due to an excessive influx of messages.`);
this.closingConnection = true; this.closingConnection = true;
this.wsConnection.close(1008, 'Please stop spamming the streaming API.'); this.wsConnection.close(1008, 'Disconnected - too many requests');
return; return;
} }
}
try { try {
obj = JSON.parse(data.toString()); obj = JSON.parse(data.toString());
@ -172,15 +159,13 @@ export default class Connection {
@bindThis @bindThis
public cacheNote(note: Packed<'Note'>) { public cacheNote(note: Packed<'Note'>) {
const add = (note: Packed<'Note'>) => { const add = (note: Packed<'Note'>) => {
const existIndex = this.cachedNotes.findIndex(n => n.id === note.id); this.cachedNotes.set(note.id, note);
if (existIndex > -1) {
this.cachedNotes[existIndex] = note;
return;
}
this.cachedNotes.unshift(note); while (this.cachedNotes.size > MAX_CACHED_NOTES_PER_CONNECTION) {
if (this.cachedNotes.length > 32) { // Map maintains insertion order, so first key is always the oldest
this.cachedNotes.splice(32); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const oldestKey = this.cachedNotes.keys().next().value!;
this.cachedNotes.delete(oldestKey);
} }
}; };
@ -192,9 +177,9 @@ export default class Connection {
@bindThis @bindThis
private readNote(body: JsonValue | undefined) { private readNote(body: JsonValue | undefined) {
if (!isJsonObject(body)) return; if (!isJsonObject(body)) return;
const id = body.id; const id = body.id as string;
const note = this.cachedNotes.find(n => n.id === id); const note = this.cachedNotes.get(id);
if (note == null) return; if (note == null) return;
if (this.user && (note.userId !== this.user.id)) { if (this.user && (note.userId !== this.user.id)) {
@ -215,9 +200,19 @@ export default class Connection {
if (!isJsonObject(payload)) return; if (!isJsonObject(payload)) return;
if (!payload.id || typeof payload.id !== 'string') return; if (!payload.id || typeof payload.id !== 'string') return;
const current = this.subscribingNotes[payload.id] ?? 0; const current = this.subscribingNotes.get(payload.id) ?? 0;
const updated = current + 1; const updated = current + 1;
this.subscribingNotes[payload.id] = updated; this.subscribingNotes.set(payload.id, updated);
// Limit the number of distinct notes that can be subscribed to.
while (this.subscribingNotes.size > MAX_SUBSCRIPTIONS_PER_CONNECTION) {
// Map maintains insertion order, so first key is always the oldest
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const oldestKey = this.subscribingNotes.keys().next().value!;
this.subscribingNotes.delete(oldestKey);
this.subscriber.off(`noteStream:${oldestKey}`, this.onNoteStreamMessage);
}
if (updated === 1) { if (updated === 1) {
this.subscriber.on(`noteStream:${payload.id}`, this.onNoteStreamMessage); this.subscriber.on(`noteStream:${payload.id}`, this.onNoteStreamMessage);
@ -232,12 +227,12 @@ export default class Connection {
if (!isJsonObject(payload)) return; if (!isJsonObject(payload)) return;
if (!payload.id || typeof payload.id !== 'string') return; if (!payload.id || typeof payload.id !== 'string') return;
const current = this.subscribingNotes[payload.id]; const current = this.subscribingNotes.get(payload.id);
if (current == null) return; if (current == null) return;
const updated = current - 1; const updated = current - 1;
this.subscribingNotes[payload.id] = updated; this.subscribingNotes.set(payload.id, updated);
if (updated <= 0) { if (updated <= 0) {
delete this.subscribingNotes[payload.id]; this.subscribingNotes.delete(payload.id);
this.subscriber.off(`noteStream:${payload.id}`, this.onNoteStreamMessage); this.subscriber.off(`noteStream:${payload.id}`, this.onNoteStreamMessage);
} }
} }
@ -304,7 +299,11 @@ export default class Connection {
*/ */
@bindThis @bindThis
public connectChannel(id: string, params: JsonObject | undefined, channel: string, pong = false) { public connectChannel(id: string, params: JsonObject | undefined, channel: string, pong = false) {
if (this.channels.length >= MAX_CHANNELS_PER_CONNECTION) { if (this.channels.has(id)) {
this.disconnectChannel(id);
}
if (this.channels.size >= MAX_CHANNELS_PER_CONNECTION) {
return; return;
} }
@ -320,12 +319,16 @@ export default class Connection {
} }
// 共有可能チャンネルに接続しようとしていて、かつそのチャンネルに既に接続していたら無意味なので無視 // 共有可能チャンネルに接続しようとしていて、かつそのチャンネルに既に接続していたら無意味なので無視
if (channelService.shouldShare && this.channels.some(c => c.chName === channel)) { if (channelService.shouldShare) {
for (const c of this.channels.values()) {
if (c.chName === channel) {
return; return;
} }
}
}
const ch: Channel = channelService.create(id, this); const ch: Channel = channelService.create(id, this);
this.channels.push(ch); this.channels.set(ch.id, ch);
ch.init(params ?? {}); ch.init(params ?? {});
if (pong) { if (pong) {
@ -341,11 +344,11 @@ export default class Connection {
*/ */
@bindThis @bindThis
public disconnectChannel(id: string) { public disconnectChannel(id: string) {
const channel = this.channels.find(c => c.id === id); const channel = this.channels.get(id);
if (channel) { if (channel) {
if (channel.dispose) channel.dispose(); if (channel.dispose) channel.dispose();
this.channels = this.channels.filter(c => c.id !== id); this.channels.delete(id);
} }
} }
@ -360,7 +363,7 @@ export default class Connection {
if (typeof data.type !== 'string') return; if (typeof data.type !== 'string') return;
if (typeof data.body === 'undefined') return; if (typeof data.body === 'undefined') return;
const channel = this.channels.find(c => c.id === data.id); const channel = this.channels.get(data.id);
if (channel != null && channel.onMessage != null) { if (channel != null && channel.onMessage != null) {
channel.onMessage(data.type, data.body); channel.onMessage(data.type, data.body);
} }
@ -372,8 +375,16 @@ export default class Connection {
@bindThis @bindThis
public dispose() { public dispose() {
if (this.fetchIntervalId) clearInterval(this.fetchIntervalId); if (this.fetchIntervalId) clearInterval(this.fetchIntervalId);
for (const c of this.channels.filter(c => c.dispose)) { for (const c of this.channels.values()) {
if (c.dispose) c.dispose(); if (c.dispose) c.dispose();
} }
for (const k of this.subscribingNotes.keys()) {
this.subscriber.off(`noteStream:${k}`, this.onNoteStreamMessage);
}
this.fetchIntervalId = null;
this.channels.clear();
this.subscribingNotes.clear();
this.cachedNotes.clear();
} }
} }