ws.c (11719B)
1 #include "ws.h" 2 #include "libsheepyObject.h" 3 // from libsheepy.h: 4 // internal 5 // u8, u16, u32, u64 6 // randomUrandomOpen, randomWord, randomUrandomClose 7 // bCatS, eqS 8 // logE 9 // range 10 // MIN, MIN3 11 // 12 // from libsheepyObject.h: 13 // findG, indexOfG 14 // setG 15 16 #include <arpa/inet.h> 17 18 /* enable/disable logging */ 19 #undef pLog 20 #define pLog(...) 21 22 #define BASE64_ENCODED_SIZE(len) (len+2)/3*4+1 23 24 const char *BASE64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; 25 26 void _base64_encode_triple(u8 triple[3], char result[4]) { 27 u32 tripleValue; 28 29 tripleValue = triple[0]; 30 tripleValue *= 256; 31 tripleValue += triple[1]; 32 tripleValue *= 256; 33 tripleValue += triple[2]; 34 35 range(i,4) { 36 result[3-i] = BASE64_CHARS[tripleValue%64]; 37 tripleValue /= 64; 38 } 39 } 40 41 bool base64_encode(u8 *source, size_t sourcelen, char *target, size_t targetlen) { 42 43 if ((sourcelen+2)/3*4 > targetlen-1) 44 return false; 45 46 while (sourcelen >= 3) { 47 _base64_encode_triple(source, target); 48 sourcelen -= 3; 49 source += 3; 50 target += 4; 51 } 52 53 if (sourcelen > 0) { 54 u8 temp[3]; 55 memset(temp, 0, sizeof(temp)); 56 memcpy(temp, source, sourcelen); 57 _base64_encode_triple(temp, target); 58 target[3] = '='; 59 if (sourcelen == 1) 60 target[2] = '='; 61 62 target += 4; 63 } 64 65 target[0] = 0; 66 67 return true; 68 } 69 70 struct sha1 { 71 u64 len; /* processed message length */ 72 u32 h[5]; /* hash state */ 73 u8 buf[64]; /* message block buffer */ 74 }; 75 76 #define SHA1_DIGEST_LEN 20 77 78 /* reset state */ 79 internal void sha1Init(struct sha1 *ctx); 80 81 /* process message */ 82 internal void sha1Update(struct sha1 *ctx, const void *m, u64 len); 83 84 /* get message digest 85 * state is ruined after sum, keep a copy if multiple sum is needed 86 * part of the message might be left in s, zero it if secrecy is needed 87 */ 88 internal void sha1Final(struct sha1 *ctx, u8 md[SHA1_DIGEST_LEN]); 89 90 internal u32 rol(u32 n, u8 k) { return (n << k) | (n >> (32-k)); } 91 92 #define F0(b,c,d) (d ^ (b & (c ^ d))) 93 #define F1(b,c,d) (b ^ c ^ d) 94 #define F2(b,c,d) ((b & c) | (d & (b | c))) 95 #define F3(b,c,d) (b ^ c ^ d) 96 #define G0(a,b,c,d,e,i) e += rol(a,5)+F0(b,c,d)+W[i]+0x5A827999; b = rol(b,30) 97 #define G1(a,b,c,d,e,i) e += rol(a,5)+F1(b,c,d)+W[i]+0x6ED9EBA1; b = rol(b,30) 98 #define G2(a,b,c,d,e,i) e += rol(a,5)+F2(b,c,d)+W[i]+0x8F1BBCDC; b = rol(b,30) 99 #define G3(a,b,c,d,e,i) e += rol(a,5)+F3(b,c,d)+W[i]+0xCA62C1D6; b = rol(b,30) 100 101 internal void processblock(struct sha1 *s, const u8 *buf) { 102 u32 W[80], a, b, c, d, e; 103 u8 i; 104 105 for (i = 0; i < 16; i++) { 106 W[i] = (u32)buf[4*i]<<24; 107 W[i] |= (u32)buf[4*i+1]<<16; 108 W[i] |= (u32)buf[4*i+2]<<8; 109 W[i] |= buf[4*i+3]; 110 } 111 for (; i < 80; i++) 112 W[i] = rol(W[i-3] ^ W[i-8] ^ W[i-14] ^ W[i-16], 1); 113 a = s->h[0]; 114 b = s->h[1]; 115 c = s->h[2]; 116 d = s->h[3]; 117 e = s->h[4]; 118 for (i = 0; i < 20; ) { 119 G0(a,b,c,d,e,i++); 120 G0(e,a,b,c,d,i++); 121 G0(d,e,a,b,c,i++); 122 G0(c,d,e,a,b,i++); 123 G0(b,c,d,e,a,i++); 124 } 125 while (i < 40) { 126 G1(a,b,c,d,e,i++); 127 G1(e,a,b,c,d,i++); 128 G1(d,e,a,b,c,i++); 129 G1(c,d,e,a,b,i++); 130 G1(b,c,d,e,a,i++); 131 } 132 while (i < 60) { 133 G2(a,b,c,d,e,i++); 134 G2(e,a,b,c,d,i++); 135 G2(d,e,a,b,c,i++); 136 G2(c,d,e,a,b,i++); 137 G2(b,c,d,e,a,i++); 138 } 139 while (i < 80) { 140 G3(a,b,c,d,e,i++); 141 G3(e,a,b,c,d,i++); 142 G3(d,e,a,b,c,i++); 143 G3(c,d,e,a,b,i++); 144 G3(b,c,d,e,a,i++); 145 } 146 s->h[0] += a; 147 s->h[1] += b; 148 s->h[2] += c; 149 s->h[3] += d; 150 s->h[4] += e; 151 } 152 153 internal void pad(struct sha1 *s) { 154 u8 r = s->len % 64; 155 156 s->buf[r++] = 0x80; 157 if (r > 56) { 158 memset(s->buf + r, 0, 64 - r); 159 r = 0; 160 processblock(s, s->buf); 161 } 162 memset(s->buf + r, 0, 56 - r); 163 s->len *= 8; 164 s->buf[56] = s->len >> 56; 165 s->buf[57] = s->len >> 48; 166 s->buf[58] = s->len >> 40; 167 s->buf[59] = s->len >> 32; 168 s->buf[60] = s->len >> 24; 169 s->buf[61] = s->len >> 16; 170 s->buf[62] = s->len >> 8; 171 s->buf[63] = s->len; 172 processblock(s, s->buf); 173 } 174 175 internal void sha1Init(struct sha1 *s) { 176 s->len = 0; 177 s->h[0] = 0x67452301; 178 s->h[1] = 0xEFCDAB89; 179 s->h[2] = 0x98BADCFE; 180 s->h[3] = 0x10325476; 181 s->h[4] = 0xC3D2E1F0; 182 } 183 184 internal void sha1Final(struct sha1 *s, u8 md[SHA1_DIGEST_LEN]) { 185 186 pad(s); 187 range(i,5) { 188 md[4*i] = s->h[i] >> 24; 189 md[4*i+1] = s->h[i] >> 16; 190 md[4*i+2] = s->h[i] >> 8; 191 md[4*i+3] = s->h[i]; 192 } 193 } 194 195 internal void sha1Update(struct sha1 *s, const void *m, u64 len) { 196 const u8 *p = m; 197 u8 r = s->len % 64; 198 199 s->len += len; 200 if (r) { 201 if (len < 64 - r) { 202 memcpy(s->buf + r, p, len); 203 return; 204 } 205 memcpy(s->buf + r, p, 64 - r); 206 len -= 64 - r; 207 p += 64 - r; 208 processblock(s, s->buf); 209 } 210 for (; len >= 64; len -= 64, p += 64) 211 processblock(s, p); 212 memcpy(s->buf, p, len); 213 } 214 215 internal u64 keyRnd[2]; 216 internal char clientKey[BASE64_ENCODED_SIZE(sizeof(keyRnd))]; 217 218 internal const char handshake[] = 219 "GET /connect HTTP/1.1\r\n" 220 "Host: %s\r\n" 221 "Upgrade: websocket\r\n" 222 "Connection: Upgrade\r\n" 223 "Sec-WebSocket-Key: %s\r\n" 224 "Sec-WebSocket-Version: 13\r\n" 225 "\r\n"; 226 227 internal const char serverResponse[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 228 229 size_t wsHandshakeSize(char *hostname) { 230 return (sizeof(handshake) -1 -4) + BASE64_ENCODED_SIZE(sizeof(keyRnd))-1 + strlen(hostname) +1; 231 } 232 233 int wsHandshake(char *frame, size_t size, char *hostname) { 234 235 randomUrandomOpen(); 236 keyRnd[0] = randomWord(); 237 keyRnd[1] = randomWord(); 238 randomUrandomClose(); 239 240 base64_encode((u8*)keyRnd, sizeof(keyRnd), clientKey, sizeof(clientKey)); 241 242 snprintf(frame, size, "GET /connect HTTP/1.1\r\n" 243 "Host: %s\r\n" 244 "Upgrade: websocket\r\n" 245 "Connection: Upgrade\r\n" 246 "Sec-WebSocket-Key: %s\r\n" 247 "Sec-WebSocket-Version: 13\r\n" 248 "\r\n", hostname, clientKey); 249 return 0; 250 } 251 252 bool wsHanskakeCheck(char *frame, size_t size) { 253 char *srvRes; 254 255 if (srvRes = findG(frame, "sec-websocket-accept: ")) { 256 srvRes += 22; 257 int idx = indexOfG(srvRes, "\r\n"); 258 259 if ((srvRes - frame + idx) > size) 260 return false; 261 262 setG(srvRes, idx, 0); 263 264 // createAcceptKey 265 char s[128]; 266 bCatS(s, clientKey, serverResponse); 267 268 struct sha1 shc; 269 sha1Init(&shc); 270 sha1Update(&shc, s, BASE64_ENCODED_SIZE(sizeof(keyRnd)) -1 + sizeof(serverResponse) -1); 271 u8 digest[SHA1_DIGEST_LEN]; 272 sha1Final(&shc, digest); 273 274 char base64Digest[800]; 275 base64_encode(digest, sizeof(digest), base64Digest, sizeof(base64Digest)); 276 if (!eqS(srvRes, base64Digest)) { 277 logE("Bad server response!"); 278 return false; 279 } 280 } 281 else { 282 return false; 283 } 284 return true; 285 } 286 287 #define WS_FINAL_FRAME 1 << 7 288 #define WS_OP_MASK 0xF 289 #define WS_MASK 1 << 7 290 #define WS_HEADER_SIZE 2 291 #define WS_MASK_SIZE 4 292 #define WS_LEN_MASK 0x7F 293 294 #define WS_PAYLOAD_EXTEND_1 126 295 #define WS_PAYLOAD_EXTEND_2 127 296 297 #define WS_LEN_SIZE_1 2 298 #define WS_LEN_SIZE_2 8 299 300 // 1 if opcode is control frame opcode, otherwise 0 301 #define isCtlFrame(opcode) ((opcode >> 3) & 1) 302 303 ssize_t wsMaskSize(wsOpt op, size_t pSize) { 304 if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1; 305 if (pSize < WS_PAYLOAD_EXTEND_1) return pSize + WS_HEADER_SIZE + WS_MASK_SIZE; 306 if (pSize < 1 << 16) return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_1; 307 return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_2; 308 } 309 310 ssize_t wsNoMaskSize(wsOpt op, size_t pSize) { 311 if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1; 312 if (pSize < WS_PAYLOAD_EXTEND_1) return pSize + WS_HEADER_SIZE; 313 if (pSize < 1 << 16) return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_1; 314 return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_2; 315 } 316 317 ssize_t wsControlMaskSize(void) { 318 return WS_HEADER_SIZE + WS_MASK_SIZE; 319 } 320 321 ssize_t wsControlNoMaskSize(void) { 322 return WS_HEADER_SIZE; 323 } 324 325 bool wsMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) { 326 #define WS_STEP1\ 327 ssize_t frameSize = wsMaskSize(op, pSize);\ 328 \ 329 if (frameSize == -1 || frameSize > size) return false;\ 330 \ 331 zeroBuf(frame, frameSize);\ 332 \ 333 if (final) frame[0] = WS_FINAL_FRAME | op;\ 334 else frame[0] = op; 335 WS_STEP1; 336 337 u8 wsPLen, maskOffset; 338 if (pSize < WS_PAYLOAD_EXTEND_1) { 339 wsPLen = pSize; 340 maskOffset = WS_HEADER_SIZE; 341 } 342 else if (pSize < 1 << 16) { 343 wsPLen = WS_PAYLOAD_EXTEND_1; 344 maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1; 345 u16 *len = (u16*)&frame[2]; 346 *len = htons((u16)pSize); 347 } 348 else { 349 wsPLen = WS_PAYLOAD_EXTEND_2; 350 maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2; 351 u64 *len = (u64*)&frame[2]; 352 *len = htobe64(pSize); 353 354 } 355 frame[1] = WS_MASK | wsPLen; 356 // mask key 4 bytes, random = always 0 357 358 u8 payloadOffset = maskOffset + WS_MASK_SIZE; 359 360 strncpy(frame+payloadOffset, payload, pSize); 361 362 range(i, pSize) { 363 frame[i + payloadOffset] ^= frame[maskOffset + (i % 4)]; 364 } 365 366 return true; 367 } 368 369 bool wsNoMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) { 370 WS_STEP1; 371 372 u8 wsPLen, payloadOffset; 373 if (pSize < WS_PAYLOAD_EXTEND_1) { 374 wsPLen = pSize; 375 payloadOffset = WS_HEADER_SIZE; 376 } 377 else if (pSize < 1 << 16) { 378 wsPLen = WS_PAYLOAD_EXTEND_1; 379 payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1; 380 u16 *len = (u16*)&frame[2]; 381 *len = htons((u16)pSize); 382 } 383 else { 384 wsPLen = WS_PAYLOAD_EXTEND_2; 385 payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2; 386 u64 *len = (u64*)&frame[2]; 387 *len = htobe64(pSize); 388 389 } 390 frame[1] = wsPLen; 391 392 strncpy(frame+payloadOffset, payload, pSize); 393 394 return true; 395 } 396 397 #define genControlFrame(SIZE_FUNC) \ 398 ssize_t frameSize = SIZE_FUNC(op, 0);\ 399 zeroBuf(frame, frameSize);\ 400 frame[0] = WS_FINAL_FRAME | op 401 402 bool wsControlMask(char *frame, wsOpt op) { 403 404 if (!isCtlFrame(op)) return false; 405 genControlFrame(wsMaskSize); 406 frame[1] = WS_MASK; 407 408 return true; 409 } 410 411 bool wsControlNoMask(char *frame, wsOpt op) { 412 413 if (!isCtlFrame(op)) return false; 414 genControlFrame(wsNoMaskSize); 415 416 return true; 417 } 418 419 420 size_t wsDecodeSize(char *frame) { 421 u8 len = frame[1] & WS_LEN_MASK; 422 switch(len) { 423 case WS_PAYLOAD_EXTEND_1:; 424 u16 len1 = *(u16*)&frame[2]; 425 return ntohs(len1); 426 case WS_PAYLOAD_EXTEND_2:; 427 u64 len2 = *(u64*)&frame[2]; 428 return be64toh(len2); 429 default: 430 return len; 431 } 432 } 433 434 void wsDecode(char *data, size_t size, char *frame, size_t fSize) { 435 436 size_t sz = MIN3(wsDecodeSize(frame), size, fSize); 437 char *payload = wsDecodePayOffset(frame); 438 strncpy(data, payload, sz); 439 440 if (wsIsMasked(frame)) { 441 char *mask = payload -4; 442 range(i, sz) { 443 data[i] ^= *(mask + (i % 4)); 444 } 445 } 446 } 447 448 char *wsDecodePayOffset(char *frame) { 449 char *r; 450 r = frame + WS_HEADER_SIZE; 451 u8 len = frame[1] & WS_LEN_MASK; 452 if (len == WS_PAYLOAD_EXTEND_1) { 453 r += WS_LEN_SIZE_1; 454 } 455 else if (len == WS_PAYLOAD_EXTEND_2) { 456 r += WS_LEN_SIZE_2; 457 } 458 if (wsIsMasked(frame)) r += WS_MASK_SIZE; 459 return r; 460 } 461 462 char *wsDecodeInPlace(char *frame, size_t size) { 463 464 char *r = wsDecodePayOffset(frame); 465 if (wsIsMasked(frame)) { 466 char *mask = r -4; 467 size_t sz = MIN(wsDecodeSize(frame), size); 468 range(i, sz) { 469 *(r+i) ^= *(mask + (i % 4)); 470 } 471 } 472 return r; 473 } 474 475 bool wsIsFinal(char *frame) { 476 return (frame[0] & WS_FINAL_FRAME) == WS_FINAL_FRAME; 477 } 478 479 bool wsIsMasked(char *frame) { 480 return (frame[1] & WS_MASK) == WS_MASK; 481 } 482 483 wsOpt wsDecodeOp(char *frame) { 484 return (wsOpt) frame[0] & WS_OP_MASK; 485 } 486 487 bool checkLibsheepyVersionWs(const char *currentLibsheepyVersion) { 488 return eqG(currentLibsheepyVersion, LIBSHEEPY_VERSION); 489 } 490