// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bytes" "crypto/rand" "fmt" "io" "log" "net" "net/http" "net/http/httptest" "net/url" "reflect" "runtime" "strings" "sync" "testing" "time" ) var serverAddr string var once sync.Once func echoServer(ws *Conn) { defer ws.Close() io.Copy(ws, ws) } type Count struct { S string N int } func countServer(ws *Conn) { defer ws.Close() for { var count Count err := JSON.Receive(ws, &count) if err != nil { return } count.N++ count.S = strings.Repeat(count.S, count.N) err = JSON.Send(ws, count) if err != nil { return } } } type testCtrlAndDataHandler struct { hybiFrameHandler } func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) { h.hybiFrameHandler.conn.wio.Lock() defer h.hybiFrameHandler.conn.wio.Unlock() w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame) if err != nil { return 0, err } n, err := w.Write(b) w.Close() return n, err } func ctrlAndDataServer(ws *Conn) { defer ws.Close() h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} ws.frameHandler = h go func() { for i := 0; ; i++ { var b []byte if i%2 != 0 { // with or without payload b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i)) } if _, err := h.WritePing(b); err != nil { break } if _, err := h.WritePong(b); err != nil { // unsolicited pong break } time.Sleep(10 * time.Millisecond) } }() b := make([]byte, 128) for { n, err := ws.Read(b) if err != nil { break } if _, err := ws.Write(b[:n]); err != nil { break } } } func subProtocolHandshake(config *Config, req *http.Request) error { for _, proto := range config.Protocol { if proto == "chat" { config.Protocol = []string{proto} return nil } } return ErrBadWebSocketProtocol } func subProtoServer(ws *Conn) { for _, proto := range ws.Config().Protocol { io.WriteString(ws, proto) } } func startServer() { http.Handle("/echo", Handler(echoServer)) http.Handle("/count", Handler(countServer)) http.Handle("/ctrldata", Handler(ctrlAndDataServer)) subproto := Server{ Handshake: subProtocolHandshake, Handler: Handler(subProtoServer), } http.Handle("/subproto", subproto) server := httptest.NewServer(nil) serverAddr = server.Listener.Addr().String() log.Print("Test WebSocket server listening on ", serverAddr) } func newConfig(t *testing.T, path string) *Config { config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") return config } func TestEcho(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var actual_msg = make([]byte, 512) n, err := conn.Read(actual_msg) if err != nil { t.Errorf("Read: %v", err) } actual_msg = actual_msg[0:n] if !bytes.Equal(msg, actual_msg) { t.Errorf("Echo: expected %q got %q", msg, actual_msg) } conn.Close() } func TestAddr(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } ra := conn.RemoteAddr().String() if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { t.Errorf("Bad remote addr: %v", ra) } la := conn.LocalAddr().String() if !strings.HasPrefix(la, "http://") { t.Errorf("Bad local addr: %v", la) } conn.Close() } func TestCount(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/count"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } var count Count count.S = "hello" if err := JSON.Send(conn, count); err != nil { t.Errorf("Write: %v", err) } if err := JSON.Receive(conn, &count); err != nil { t.Errorf("Read: %v", err) } if count.N != 1 { t.Errorf("count: expected %d got %d", 1, count.N) } if count.S != "hello" { t.Errorf("count: expected %q got %q", "hello", count.S) } if err := JSON.Send(conn, count); err != nil { t.Errorf("Write: %v", err) } if err := JSON.Receive(conn, &count); err != nil { t.Errorf("Read: %v", err) } if count.N != 2 { t.Errorf("count: expected %d got %d", 2, count.N) } if count.S != "hellohello" { t.Errorf("count: expected %q got %q", "hellohello", count.S) } conn.Close() } func TestWithQuery(t *testing.T) { once.Do(startServer) client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } config := newConfig(t, "/echo") config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) if err != nil { t.Fatal("location url", err) } ws, err := NewClient(config, client) if err != nil { t.Errorf("WebSocket handshake: %v", err) return } ws.Close() } func testWithProtocol(t *testing.T, subproto []string) (string, error) { once.Do(startServer) client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } config := newConfig(t, "/subproto") config.Protocol = subproto ws, err := NewClient(config, client) if err != nil { return "", err } msg := make([]byte, 16) n, err := ws.Read(msg) if err != nil { return "", err } ws.Close() return string(msg[:n]), nil } func TestWithProtocol(t *testing.T) { proto, err := testWithProtocol(t, []string{"chat"}) if err != nil { t.Errorf("SubProto: unexpected error: %v", err) } if proto != "chat" { t.Errorf("SubProto: expected %q, got %q", "chat", proto) } } func TestWithTwoProtocol(t *testing.T) { proto, err := testWithProtocol(t, []string{"test", "chat"}) if err != nil { t.Errorf("SubProto: unexpected error: %v", err) } if proto != "chat" { t.Errorf("SubProto: expected %q, got %q", "chat", proto) } } func TestWithBadProtocol(t *testing.T) { _, err := testWithProtocol(t, []string{"test"}) if err != ErrBadStatus { t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err) } } func TestHTTP(t *testing.T) { once.Do(startServer) // If the client did not send a handshake that matches the protocol // specification, the server MUST return an HTTP response with an // appropriate error code (such as 400 Bad Request) resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) if err != nil { t.Errorf("Get: error %#v", err) return } if resp == nil { t.Error("Get: resp is null") return } if resp.StatusCode != http.StatusBadRequest { t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) } } func TestTrailingSpaces(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=955 // The last runs of this create keys with trailing spaces that should not be // generated by the client. once.Do(startServer) config := newConfig(t, "/echo") for i := 0; i < 30; i++ { // body ws, err := DialConfig(config) if err != nil { t.Errorf("Dial #%d failed: %v", i, err) break } ws.Close() } } func TestDialConfigBadVersion(t *testing.T) { once.Do(startServer) config := newConfig(t, "/echo") config.Version = 1234 _, err := DialConfig(config) if dialerr, ok := err.(*DialError); ok { if dialerr.Err != ErrBadProtocolVersion { t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err) } } } func TestDialConfigWithDialer(t *testing.T) { once.Do(startServer) config := newConfig(t, "/echo") config.Dialer = &net.Dialer{ Deadline: time.Now().Add(-time.Minute), } _, err := DialConfig(config) dialerr, ok := err.(*DialError) if !ok { t.Fatalf("DialError expected, got %#v", err) } neterr, ok := dialerr.Err.(*net.OpError) if !ok { t.Fatalf("net.OpError error expected, got %#v", dialerr.Err) } if !neterr.Timeout() { t.Fatalf("expected timeout error, got %#v", neterr) } } func TestSmallBuffer(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=1145 // Read should be able to handle reading a fragment of a frame. once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var small_msg = make([]byte, 8) n, err := conn.Read(small_msg) if err != nil { t.Errorf("Read: %v", err) } if !bytes.Equal(msg[:len(small_msg)], small_msg) { t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) } var second_msg = make([]byte, len(msg)) n, err = conn.Read(second_msg) if err != nil { t.Errorf("Read: %v", err) } second_msg = second_msg[0:n] if !bytes.Equal(msg[len(small_msg):], second_msg) { t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) } conn.Close() } var parseAuthorityTests = []struct { in *url.URL out string }{ { &url.URL{ Scheme: "ws", Host: "www.google.com", }, "www.google.com:80", }, { &url.URL{ Scheme: "wss", Host: "www.google.com", }, "www.google.com:443", }, { &url.URL{ Scheme: "ws", Host: "www.google.com:80", }, "www.google.com:80", }, { &url.URL{ Scheme: "wss", Host: "www.google.com:443", }, "www.google.com:443", }, // some invalid ones for parseAuthority. parseAuthority doesn't // concern itself with the scheme unless it actually knows about it { &url.URL{ Scheme: "http", Host: "www.google.com", }, "www.google.com", }, { &url.URL{ Scheme: "http", Host: "www.google.com:80", }, "www.google.com:80", }, { &url.URL{ Scheme: "asdf", Host: "127.0.0.1", }, "127.0.0.1", }, { &url.URL{ Scheme: "asdf", Host: "www.google.com", }, "www.google.com", }, } func TestParseAuthority(t *testing.T) { for _, tt := range parseAuthorityTests { out := parseAuthority(tt.in) if out != tt.out { t.Errorf("got %v; want %v", out, tt.out) } } } type closerConn struct { net.Conn closed int // count of the number of times Close was called } func (c *closerConn) Close() error { c.closed++ return c.Conn.Close() } func TestClose(t *testing.T) { if runtime.GOOS == "plan9" { t.Skip("see golang.org/issue/11454") } once.Do(startServer) conn, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } cc := closerConn{Conn: conn} client, err := NewClient(newConfig(t, "/echo"), &cc) if err != nil { t.Fatalf("WebSocket handshake: %v", err) } // set the deadline to ten minutes ago, which will have expired by the time // client.Close sends the close status frame. conn.SetDeadline(time.Now().Add(-10 * time.Minute)) if err := client.Close(); err == nil { t.Errorf("ws.Close(): expected error, got %v", err) } if cc.closed < 1 { t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed) } } var originTests = []struct { req *http.Request origin *url.URL }{ { req: &http.Request{ Header: http.Header{ "Origin": []string{"http://www.example.com"}, }, }, origin: &url.URL{ Scheme: "http", Host: "www.example.com", }, }, { req: &http.Request{}, }, } func TestOrigin(t *testing.T) { conf := newConfig(t, "/echo") conf.Version = ProtocolVersionHybi13 for i, tt := range originTests { origin, err := Origin(conf, tt.req) if err != nil { t.Error(err) continue } if !reflect.DeepEqual(origin, tt.origin) { t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin) continue } } } func TestCtrlAndData(t *testing.T) { once.Do(startServer) c, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal(err) } ws, err := NewClient(newConfig(t, "/ctrldata"), c) if err != nil { t.Fatal(err) } defer ws.Close() h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} ws.frameHandler = h b := make([]byte, 128) for i := 0; i < 2; i++ { data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i)) if _, err := ws.Write(data); err != nil { t.Fatalf("#%d: %v", i, err) } var ctrl []byte if i%2 != 0 { // with or without payload ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i)) } if _, err := h.WritePing(ctrl); err != nil { t.Fatalf("#%d: %v", i, err) } n, err := ws.Read(b) if err != nil { t.Fatalf("#%d: %v", i, err) } if !bytes.Equal(b[:n], data) { t.Fatalf("#%d: got %v; want %v", i, b[:n], data) } } } func TestCodec_ReceiveLimited(t *testing.T) { const limit = 2048 var payloads [][]byte for _, size := range []int{ 1024, 2048, 4096, // receive of this message would be interrupted due to limit 2048, // this one is to make sure next receive recovers discarding leftovers } { b := make([]byte, size) rand.Read(b) payloads = append(payloads, b) } handlerDone := make(chan struct{}) limitedHandler := func(ws *Conn) { defer close(handlerDone) ws.MaxPayloadBytes = limit defer ws.Close() for i, p := range payloads { t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit) var recv []byte err := Message.Receive(ws, &recv) switch err { case nil: case ErrFrameTooLarge: if len(p) <= limit { t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit) } continue default: t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err) } if len(recv) > limit { t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit) } if !bytes.Equal(p, recv) { t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p) } } } server := httptest.NewServer(Handler(limitedHandler)) defer server.CloseClientConnections() defer server.Close() addr := server.Listener.Addr().String() ws, err := Dial("ws://"+addr+"/", "", "http://localhost/") if err != nil { t.Fatal(err) } defer ws.Close() for i, p := range payloads { if err := Message.Send(ws, p); err != nil { t.Fatalf("payload #%d (size %d): %v", i, len(p), err) } } <-handlerDone }