ws

Log

Files

Refs

LICENSE

ws.c (11719B)

     1 #include "ws.h"
     2 #include "libsheepyObject.h"
     3 // from libsheepy.h:
     4 // internal
     5 // u8, u16, u32, u64
     6 // randomUrandomOpen, randomWord, randomUrandomClose
     7 // bCatS, eqS
     8 // logE
     9 // range
    10 // MIN, MIN3
    11 //
    12 // from libsheepyObject.h:
    13 // findG, indexOfG
    14 // setG
    15 
    16 #include <arpa/inet.h>
    17 
    18 /* enable/disable logging */
    19 #undef pLog
    20 #define pLog(...)
    21 
    22 #define BASE64_ENCODED_SIZE(len) (len+2)/3*4+1
    23 
    24 const char *BASE64_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    25 
    26 void _base64_encode_triple(u8 triple[3], char result[4]) {
    27     u32 tripleValue;
    28 
    29     tripleValue = triple[0];
    30     tripleValue *= 256;
    31     tripleValue += triple[1];
    32     tripleValue *= 256;
    33     tripleValue += triple[2];
    34 
    35     range(i,4) {
    36         result[3-i] = BASE64_CHARS[tripleValue%64];
    37         tripleValue /= 64;
    38     }
    39 }
    40 
    41 bool base64_encode(u8 *source, size_t sourcelen, char *target, size_t targetlen) {
    42 
    43     if ((sourcelen+2)/3*4 > targetlen-1)
    44        return false;
    45 
    46     while (sourcelen >= 3) {
    47         _base64_encode_triple(source, target);
    48         sourcelen -= 3;
    49         source += 3;
    50         target += 4;
    51     }
    52 
    53     if (sourcelen > 0) {
    54         u8 temp[3];
    55         memset(temp, 0, sizeof(temp));
    56         memcpy(temp, source, sourcelen);
    57         _base64_encode_triple(temp, target);
    58         target[3] = '=';
    59         if (sourcelen == 1)
    60             target[2] = '=';
    61 
    62         target += 4;
    63     }
    64 
    65     target[0] = 0;
    66 
    67     return true;
    68 }
    69 
/* Incremental SHA-1 context. */
struct sha1 {
  u64 len;    /* processed message length, in bytes (sha1Update adds byte counts; pad() turns it into bits) */
  u32 h[5];   /* hash state */
  u8 buf[64]; /* message block buffer: the first (len % 64) bytes hold not-yet-processed input */
};

/* SHA-1 digest size in bytes (five 32-bit state words) */
#define SHA1_DIGEST_LEN 20

/* reset state: zero the length and load the initial hash words */
internal void sha1Init(struct sha1 *ctx);

/* process message: feed `len` bytes from `m`; can be called repeatedly */
internal void sha1Update(struct sha1 *ctx, const void *m, u64 len);

/* get message digest, written big-endian into md
 * the context is ruined after this call (len and buf are overwritten by the
 * padding); keep a copy if more than one digest is needed.
 * part of the message might be left in ctx->buf; zero it if secrecy is needed
 */
internal void sha1Final(struct sha1 *ctx, u8 md[SHA1_DIGEST_LEN]);
    89 
    90 internal u32 rol(u32 n, u8 k) { return (n << k) | (n >> (32-k)); }
    91 
/* SHA-1 round helpers.
 * F0: "choose" (bits of b select c or d), F1/F3: parity, F2: majority.
 * G0..G3 perform one round with the matching F and round constant:
 *   e += rol(a,5) + F(b,c,d) + W[i] + K;  b = rol(b,30)
 * NOTE(review): G* expand to TWO statements and arguments are not
 * parenthesized. Use them only as a full statement inside braces, with plain
 * identifiers (plus the i++ index, which is evaluated exactly once). */
#define F0(b,c,d) (d ^ (b & (c ^ d)))
#define F1(b,c,d) (b ^ c ^ d)
#define F2(b,c,d) ((b & c) | (d & (b | c)))
#define F3(b,c,d) (b ^ c ^ d)
#define G0(a,b,c,d,e,i) e += rol(a,5)+F0(b,c,d)+W[i]+0x5A827999; b = rol(b,30)
#define G1(a,b,c,d,e,i) e += rol(a,5)+F1(b,c,d)+W[i]+0x6ED9EBA1; b = rol(b,30)
#define G2(a,b,c,d,e,i) e += rol(a,5)+F2(b,c,d)+W[i]+0x8F1BBCDC; b = rol(b,30)
#define G3(a,b,c,d,e,i) e += rol(a,5)+F3(b,c,d)+W[i]+0xCA62C1D6; b = rol(b,30)

/* Run the SHA-1 compression function on one 64-byte block `buf` and add the
 * result into s->h. Does not touch s->len or s->buf. */
internal void processblock(struct sha1 *s, const u8 *buf) {
  u32 W[80], a, b, c, d, e;
  u8 i;

  /* message schedule: first 16 words are the block read big-endian */
  for (i = 0; i < 16; i++) {
    W[i] = (u32)buf[4*i]<<24;
    W[i] |= (u32)buf[4*i+1]<<16;
    W[i] |= (u32)buf[4*i+2]<<8;
    W[i] |= buf[4*i+3];
  }
  /* remaining 64 words expanded from earlier ones */
  for (; i < 80; i++)
    W[i] = rol(W[i-3] ^ W[i-8] ^ W[i-14] ^ W[i-16], 1);
  a = s->h[0];
  b = s->h[1];
  c = s->h[2];
  d = s->h[3];
  e = s->h[4];
  /* 80 rounds in four groups of 20. Each loop body is unrolled 5x and the
   * argument order rotates, so the a..e roles shift without copying values. */
  for (i = 0; i < 20; ) {
    G0(a,b,c,d,e,i++);
    G0(e,a,b,c,d,i++);
    G0(d,e,a,b,c,i++);
    G0(c,d,e,a,b,i++);
    G0(b,c,d,e,a,i++);
  }
  while (i < 40) {
    G1(a,b,c,d,e,i++);
    G1(e,a,b,c,d,i++);
    G1(d,e,a,b,c,i++);
    G1(c,d,e,a,b,i++);
    G1(b,c,d,e,a,i++);
  }
  while (i < 60) {
    G2(a,b,c,d,e,i++);
    G2(e,a,b,c,d,i++);
    G2(d,e,a,b,c,i++);
    G2(c,d,e,a,b,i++);
    G2(b,c,d,e,a,i++);
  }
  while (i < 80) {
    G3(a,b,c,d,e,i++);
    G3(e,a,b,c,d,i++);
    G3(d,e,a,b,c,i++);
    G3(c,d,e,a,b,i++);
    G3(b,c,d,e,a,i++);
  }
  /* feed-forward: add this block's result into the running state */
  s->h[0] += a;
  s->h[1] += b;
  s->h[2] += c;
  s->h[3] += d;
  s->h[4] += e;
}
   152 
   153 internal void pad(struct sha1 *s) {
   154   u8 r = s->len % 64;
   155 
   156   s->buf[r++] = 0x80;
   157   if (r > 56) {
   158     memset(s->buf + r, 0, 64 - r);
   159     r = 0;
   160     processblock(s, s->buf);
   161   }
   162   memset(s->buf + r, 0, 56 - r);
   163   s->len *= 8;
   164   s->buf[56] = s->len >> 56;
   165   s->buf[57] = s->len >> 48;
   166   s->buf[58] = s->len >> 40;
   167   s->buf[59] = s->len >> 32;
   168   s->buf[60] = s->len >> 24;
   169   s->buf[61] = s->len >> 16;
   170   s->buf[62] = s->len >> 8;
   171   s->buf[63] = s->len;
   172   processblock(s, s->buf);
   173 }
   174 
   175 internal void sha1Init(struct sha1 *s) {
   176   s->len = 0;
   177   s->h[0] = 0x67452301;
   178   s->h[1] = 0xEFCDAB89;
   179   s->h[2] = 0x98BADCFE;
   180   s->h[3] = 0x10325476;
   181   s->h[4] = 0xC3D2E1F0;
   182 }
   183 
   184 internal void sha1Final(struct sha1 *s,  u8 md[SHA1_DIGEST_LEN]) {
   185 
   186   pad(s);
   187   range(i,5) {
   188     md[4*i] = s->h[i] >> 24;
   189     md[4*i+1] = s->h[i] >> 16;
   190     md[4*i+2] = s->h[i] >> 8;
   191     md[4*i+3] = s->h[i];
   192   }
   193 }
   194 
   195 internal void sha1Update(struct sha1 *s,  const void *m, u64 len) {
   196   const u8 *p = m;
   197   u8 r = s->len % 64;
   198 
   199   s->len += len;
   200   if (r) {
   201     if (len < 64 - r) {
   202       memcpy(s->buf + r, p, len);
   203       return;
   204     }
   205     memcpy(s->buf + r, p, 64 - r);
   206     len -= 64 - r;
   207     p += 64 - r;
   208     processblock(s, s->buf);
   209   }
   210   for (; len >= 64; len -= 64, p += 64)
   211     processblock(s, p);
   212   memcpy(s->buf, p, len);
   213 }
   214 
/* Random nonce behind the Sec-WebSocket-Key header (16 bytes), refreshed by
 * wsHandshake. */
internal u64 keyRnd[2];
/* base64 of keyRnd, NUL-terminated. Written by wsHandshake and read by
 * wsHanskakeCheck. NOTE(review): this is file-level shared state, so only one
 * handshake can be in flight at a time. */
internal char clientKey[BASE64_ENCODED_SIZE(sizeof(keyRnd))];

/* Client opening-handshake request template (two %s: host name, client key).
 * wsHandshakeSize derives the request size from this text, so any code that
 * formats the request must use exactly the same template. */
internal const char handshake[] =
    "GET /connect HTTP/1.1\r\n"
    "Host: %s\r\n"
    "Upgrade: websocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key: %s\r\n"
    "Sec-WebSocket-Version: 13\r\n"
    "\r\n";

/* Fixed GUID that RFC 6455 appends to the client key before SHA-1 hashing
 * when computing Sec-WebSocket-Accept. */
internal const char serverResponse[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
   228 
   229 size_t wsHandshakeSize(char *hostname) {
   230   return (sizeof(handshake) -1 -4) + BASE64_ENCODED_SIZE(sizeof(keyRnd))-1 + strlen(hostname) +1;
   231 }
   232 
   233 int wsHandshake(char *frame, size_t size, char *hostname) {
   234 
   235   randomUrandomOpen();
   236   keyRnd[0] = randomWord();
   237   keyRnd[1] = randomWord();
   238   randomUrandomClose();
   239 
   240   base64_encode((u8*)keyRnd, sizeof(keyRnd), clientKey, sizeof(clientKey));
   241 
   242   snprintf(frame, size, "GET /connect HTTP/1.1\r\n"
   243     "Host: %s\r\n"
   244     "Upgrade: websocket\r\n"
   245     "Connection: Upgrade\r\n"
   246     "Sec-WebSocket-Key: %s\r\n"
   247     "Sec-WebSocket-Version: 13\r\n"
   248     "\r\n", hostname, clientKey);
   249   return 0;
   250 }
   251 
   252 bool wsHanskakeCheck(char *frame, size_t size) {
   253   char *srvRes;
   254 
   255   if (srvRes = findG(frame, "sec-websocket-accept: ")) {
   256     srvRes += 22;
   257     int idx = indexOfG(srvRes, "\r\n");
   258 
   259     if ((srvRes - frame + idx) > size)
   260       return false;
   261 
   262     setG(srvRes, idx, 0);
   263 
   264     // createAcceptKey
   265     char s[128];
   266     bCatS(s, clientKey, serverResponse);
   267 
   268     struct sha1 shc;
   269     sha1Init(&shc);
   270     sha1Update(&shc, s, BASE64_ENCODED_SIZE(sizeof(keyRnd)) -1 + sizeof(serverResponse) -1);
   271     u8 digest[SHA1_DIGEST_LEN];
   272     sha1Final(&shc, digest);
   273 
   274     char base64Digest[800];
   275     base64_encode(digest, sizeof(digest), base64Digest, sizeof(base64Digest));
   276     if (!eqS(srvRes, base64Digest)) {
   277       logE("Bad server response!");
   278       return false;
   279     }
   280   }
   281   else {
   282     return false;
   283   }
   284   return true;
   285 }
   286 
   287 #define WS_FINAL_FRAME  1 << 7
   288 #define WS_OP_MASK  0xF
   289 #define WS_MASK 1 << 7
   290 #define WS_HEADER_SIZE 2
   291 #define WS_MASK_SIZE 4
   292 #define WS_LEN_MASK 0x7F
   293 
   294 #define WS_PAYLOAD_EXTEND_1 126
   295 #define WS_PAYLOAD_EXTEND_2 127
   296 
   297 #define WS_LEN_SIZE_1 2
   298 #define WS_LEN_SIZE_2 8
   299 
   300 // 1 if opcode is control frame opcode, otherwise 0
   301 #define isCtlFrame(opcode) ((opcode >> 3) & 1)
   302 
   303 ssize_t wsMaskSize(wsOpt op, size_t pSize) {
   304   if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1;
   305   if (pSize < WS_PAYLOAD_EXTEND_1)                    return pSize + WS_HEADER_SIZE + WS_MASK_SIZE;
   306   if (pSize < 1 << 16)                                return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_1;
   307   return pSize + WS_HEADER_SIZE + WS_MASK_SIZE + WS_LEN_SIZE_2;
   308 }
   309 
   310 ssize_t wsNoMaskSize(wsOpt op, size_t pSize) {
   311   if (isCtlFrame(op) && pSize >= WS_PAYLOAD_EXTEND_1) return -1;
   312   if (pSize < WS_PAYLOAD_EXTEND_1)                    return pSize + WS_HEADER_SIZE;
   313   if (pSize < 1 << 16)                                return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_1;
   314   return pSize + WS_HEADER_SIZE + WS_LEN_SIZE_2;
   315 }
   316 
   317 ssize_t wsControlMaskSize(void) {
   318   return WS_HEADER_SIZE + WS_MASK_SIZE;
   319 }
   320 
   321 ssize_t wsControlNoMaskSize(void) {
   322   return WS_HEADER_SIZE;
   323 }
   324 
   325 bool wsMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) {
   326   #define WS_STEP1\
   327   ssize_t frameSize = wsMaskSize(op, pSize);\
   328   \
   329   if (frameSize == -1 || frameSize > size) return false;\
   330   \
   331   zeroBuf(frame, frameSize);\
   332   \
   333   if (final) frame[0] = WS_FINAL_FRAME | op;\
   334   else       frame[0] = op;
   335   WS_STEP1;
   336 
   337   u8 wsPLen, maskOffset;
   338   if (pSize < WS_PAYLOAD_EXTEND_1) {
   339     wsPLen     = pSize;
   340     maskOffset = WS_HEADER_SIZE;
   341   }
   342   else if (pSize < 1 << 16) {
   343     wsPLen     = WS_PAYLOAD_EXTEND_1;
   344     maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1;
   345     u16 *len   = (u16*)&frame[2];
   346     *len       = htons((u16)pSize);
   347   }
   348   else {
   349     wsPLen     = WS_PAYLOAD_EXTEND_2;
   350     maskOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2;
   351     u64 *len   = (u64*)&frame[2];
   352     *len       = htobe64(pSize);
   353 
   354   }
   355   frame[1] = WS_MASK | wsPLen;
   356   // mask key 4 bytes, random = always 0
   357 
   358   u8 payloadOffset = maskOffset + WS_MASK_SIZE;
   359 
   360   strncpy(frame+payloadOffset, payload, pSize);
   361 
   362   range(i, pSize) {
   363     frame[i + payloadOffset] ^= frame[maskOffset + (i % 4)];
   364   }
   365 
   366   return true;
   367 }
   368 
   369 bool wsNoMask(char *frame, size_t size, wsOpt op, char *payload, size_t pSize, bool final) {
   370   WS_STEP1;
   371 
   372   u8 wsPLen, payloadOffset;
   373   if (pSize < WS_PAYLOAD_EXTEND_1) {
   374     wsPLen        = pSize;
   375     payloadOffset = WS_HEADER_SIZE;
   376   }
   377   else if (pSize < 1 << 16) {
   378     wsPLen        = WS_PAYLOAD_EXTEND_1;
   379     payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_1;
   380     u16 *len   = (u16*)&frame[2];
   381     *len       = htons((u16)pSize);
   382   }
   383   else {
   384     wsPLen        = WS_PAYLOAD_EXTEND_2;
   385     payloadOffset = WS_HEADER_SIZE + WS_LEN_SIZE_2;
   386     u64 *len   = (u64*)&frame[2];
   387     *len       = htobe64(pSize);
   388 
   389   }
   390   frame[1] = wsPLen;
   391 
   392   strncpy(frame+payloadOffset, payload, pSize);
   393 
   394   return true;
   395 }
   396 
   397 #define genControlFrame(SIZE_FUNC) \
   398   ssize_t frameSize = SIZE_FUNC(op, 0);\
   399   zeroBuf(frame, frameSize);\
   400   frame[0] = WS_FINAL_FRAME | op
   401 
   402 bool wsControlMask(char *frame, wsOpt op) {
   403 
   404   if (!isCtlFrame(op)) return false;
   405   genControlFrame(wsMaskSize);
   406   frame[1] = WS_MASK;
   407 
   408   return true;
   409 }
   410 
   411 bool wsControlNoMask(char *frame, wsOpt op) {
   412 
   413   if (!isCtlFrame(op)) return false;
   414   genControlFrame(wsNoMaskSize);
   415 
   416   return true;
   417 }
   418 
   419 
   420 size_t wsDecodeSize(char *frame) {
   421   u8 len = frame[1] & WS_LEN_MASK;
   422   switch(len) {
   423     case WS_PAYLOAD_EXTEND_1:;
   424       u16 len1 = *(u16*)&frame[2];
   425       return ntohs(len1);
   426     case WS_PAYLOAD_EXTEND_2:;
   427       u64 len2 = *(u64*)&frame[2];
   428       return be64toh(len2);
   429     default:
   430       return len;
   431   }
   432 }
   433 
   434 void wsDecode(char *data, size_t size, char *frame, size_t fSize) {
   435 
   436   size_t sz     = MIN3(wsDecodeSize(frame), size, fSize);
   437   char *payload = wsDecodePayOffset(frame);
   438   strncpy(data, payload, sz);
   439 
   440   if (wsIsMasked(frame)) {
   441     char *mask = payload -4;
   442     range(i, sz) {
   443       data[i] ^= *(mask + (i % 4));
   444     }
   445   }
   446 }
   447 
   448 char *wsDecodePayOffset(char *frame) {
   449   char *r;
   450   r = frame + WS_HEADER_SIZE;
   451   u8 len = frame[1] & WS_LEN_MASK;
   452   if (len == WS_PAYLOAD_EXTEND_1) {
   453     r += WS_LEN_SIZE_1;
   454   }
   455   else if (len == WS_PAYLOAD_EXTEND_2) {
   456     r += WS_LEN_SIZE_2;
   457   }
   458   if (wsIsMasked(frame)) r += WS_MASK_SIZE;
   459   return r;
   460 }
   461 
   462 char *wsDecodeInPlace(char *frame, size_t size) {
   463 
   464   char *r = wsDecodePayOffset(frame);
   465   if (wsIsMasked(frame)) {
   466     char *mask = r -4;
   467     size_t sz  = MIN(wsDecodeSize(frame), size);
   468     range(i, sz) {
   469       *(r+i) ^= *(mask + (i % 4));
   470     }
   471   }
   472   return r;
   473 }
   474 
   475 bool wsIsFinal(char *frame) {
   476   return (frame[0] & WS_FINAL_FRAME) == WS_FINAL_FRAME;
   477 }
   478 
   479 bool wsIsMasked(char *frame) {
   480   return (frame[1] & WS_MASK) == WS_MASK;
   481 }
   482 
   483 wsOpt wsDecodeOp(char *frame) {
   484   return (wsOpt) frame[0] & WS_OP_MASK;
   485 }
   486 
/*
 * Return whether the version string supplied by the caller equals
 * LIBSHEEPY_VERSION as seen when this file was compiled (compared with eqG).
 * Presumably used to detect a header/library version mismatch; confirm with callers.
 */
bool checkLibsheepyVersionWs(const char *currentLibsheepyVersion) {
  return eqG(currentLibsheepyVersion, LIBSHEEPY_VERSION);
}
   490