// Copyright 2013-2015 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bufio" "errors" "net" "net/http" "net/url" "strings" "time" ) const ( originHeader = "Origin" protocolHeader = "Sec-Websocket-Protocol" websocketVersion = "13" ) // HandshakeError describes an error with the handshake from the peer. type HandshakeError struct { message string } func (e HandshakeError) Error() string { return e.message } // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. type Upgrader struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer // size is zero, then a default value of 4096 is used. The I/O buffer sizes // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the server's supported protocols in order of // preference. If this field is set, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol // requested by the client. Subprotocols []string // Error specifies the function for generating HTTP error responses. If Error // is nil, then http.Error is used to generate the HTTP response. Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin is nil, the host in the Origin header must not be set or // must match the host of the request. CheckOrigin func(r *http.Request) bool // EnableCompression specify if the server should attempt to negotiate per // message compression (RFC 7692). Setting this value to true does not // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { err := HandshakeError{reason} if u.Error != nil { u.Error(w, r, status, err) } else { w.Header().Set("Sec-Websocket-Version", "13") http.Error(w, http.StatusText(status), status) } return nil, err } // checkSameOrigin returns true if the origin is not set or is equal to the request host. func checkSameOrigin(r *http.Request) bool { origin := r.Header[originHeader] if len(origin) == 0 { return true } return checkSameOriginFromHeaderAndHost(origin[0], r.Host) } func checkSameOriginFromHeaderAndHost(origin, reqHost string) bool { if len(origin) == 0 { return true } u, err := url.Parse(origin) if err != nil { return false } return u.Host == reqHost } func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { if u.Subprotocols != nil { return matchSubprotocol(Subprotocols(r), u.Subprotocols) } else if responseHeader != nil { return responseHeader.Get(protocolHeader) } return "" } func matchSubprotocol(clientProtocols, serverProtocols []string) string { for _, serverProtocol := range serverProtocols { for _, clientProtocol := range clientProtocols { if clientProtocol == serverProtocol { return clientProtocol } } } return "" } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // application negotiated subprotocol (Sec-Websocket-Protocol). // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { if r.Method != "GET" { return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") } if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported") } if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } if !tokenListContainsValue(r.Header, "Connection", "upgrade") { return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'") } if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'") } checkOrigin := u.CheckOrigin if checkOrigin == nil { checkOrigin = checkSameOrigin } if !checkOrigin(r) { return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed") } challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank") } subprotocol := u.selectSubprotocol(r, responseHeader) // Negotiate PMCE var compress bool if u.EnableCompression { for _, ext := range parseExtensions(r.Header) { if ext[""] != "permessage-deflate" { continue } compress = true break } } var ( netConn net.Conn br *bufio.Reader err error ) h, ok := w.(http.Hijacker) if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } var rw *bufio.ReadWriter netConn, rw, err = h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } br = rw.Reader if br.Buffered() > 0 { netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.subprotocol = subprotocol if compress { c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover } p := c.writeBuf[:0] p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) p = append(p, "\r\n"...) if c.subprotocol != "" { p = append(p, "Sec-Websocket-Protocol: "...) p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } if compress { p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) } for k, vs := range responseHeader { if k == protocolHeader { continue } for _, v := range vs { p = append(p, k...) p = append(p, ": "...) for i := 0; i < len(v); i++ { b := v[i] if b <= 31 { // prevent response splitting. b = ' ' } p = append(p, b) } p = append(p, "\r\n"...) } } p = append(p, "\r\n"...) // Clear deadlines set by HTTP server. netConn.SetDeadline(time.Time{}) if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } if _, err = netConn.Write(p); err != nil { netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) } return c, nil } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // This function is deprecated, use websocket.Upgrader instead. // // The application is responsible for checking the request origin before // calling Upgrade. An example implementation of the same origin policy is: // // if req.Header.Get("Origin") != "http://"+req.Host { // http.Error(w, "Origin not allowed", 403) // return // } // // If the endpoint supports subprotocols, then the application is responsible // for negotiating the protocol used on the connection. Use the Subprotocols() // function to get the subprotocols requested by the client. Use the // Sec-Websocket-Protocol response header to specify the subprotocol selected // by the application. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // negotiated subprotocol (Sec-Websocket-Protocol). // // The connection buffers IO to the underlying network connection. The // readBufSize and writeBufSize parameters specify the size of the buffers to // use. Messages can be larger than the buffers. // // If the request is not a valid WebSocket handshake, then Upgrade returns an // error of type HandshakeError. Applications should handle this error by // replying to the client with an HTTP error response. func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { // don't return errors to maintain backwards compatibility } u.CheckOrigin = func(r *http.Request) bool { // allow all connections by default return true } return u.Upgrade(w, r, responseHeader) } // Subprotocols returns the subprotocols requested by the client in the // Sec-Websocket-Protocol header. func Subprotocols(r *http.Request) []string { return subprotocolsFromHeader(r.Header.Get(protocolHeader)) } func subprotocolsFromHeader(header string) []string { h := strings.TrimSpace(header) if h == "" { return nil } protocols := strings.Split(h, ",") for i := range protocols { protocols[i] = strings.TrimSpace(protocols[i]) } return protocols } // IsWebSocketUpgrade returns true if the client requested upgrade to the // WebSocket protocol. func IsWebSocketUpgrade(r *http.Request) bool { return tokenListContainsValue(r.Header, "Connection", "upgrade") && tokenListContainsValue(r.Header, "Upgrade", "websocket") }