diff --git a/.golangci.yaml b/.golangci.yaml index 31a4beb..3898f1c 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,6 +1,7 @@ linters: presets: - bugs + - unused - import - module @@ -13,7 +14,6 @@ linters: disable: - scopelint - - noctx linters-settings: lll: @@ -22,7 +22,3 @@ linters-settings: min-complexity: 10 nestif: min-complexity: 3 - errcheck: - exclude-functions: - - "(*github.com/gin-gonic/gin.Context).Error" - - "(*github.com/gin-gonic/gin.Context).AbortWithError" diff --git a/go.mod b/go.mod index 36fe1e4..5b0a27c 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,10 @@ module code.thetadev.de/TSGRain/SEBRAUC go 1.16 require ( - code.thetadev.de/TSGRain/ginzip v0.1.1 - github.com/fortytw2/leaktest v1.3.0 - github.com/gin-contrib/cors v1.3.1 - github.com/gin-gonic/gin v1.7.7 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gofiber/fiber/v2 v2.21.0 + github.com/gofiber/websocket/v2 v2.0.12 github.com/google/uuid v1.3.0 - github.com/gorilla/websocket v1.4.2 github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect - golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index c2354c4..4188720 100644 --- a/go.sum +++ b/go.sum @@ -1,92 +1,49 @@ -code.thetadev.de/TSGRain/ginzip v0.1.1 h1:+X0L6qumEZiKYSLmM+Q0LqKVHsKvdcg4CVzsEpvM7fk= -code.thetadev.de/TSGRain/ginzip v0.1.1/go.mod h1:BH7VkvpP83vPRyMQ8rLIjKycQwGzF+/mFV0BKzg+BuA= -github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= -github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E= +github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= -github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA= -github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk= -github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= -github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= -github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= -github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= -github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= -github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/fasthttp/websocket v1.4.3-rc.9 h1:CWJH0vONrOatdKXZgkgbFKWllijD9aY50C5KfbSDcWk= +github.com/fasthttp/websocket v1.4.3-rc.9/go.mod h1:eXL2zqDbexYJxaCw8/PQlm7VcMK6uoGvwbYbTdt4dFo= +github.com/gofiber/fiber/v2 v2.20.1/go.mod h1:/LdZHMUXZvTTo7gU4+b1hclqCAdoQphNQ9bi9gutPyI= +github.com/gofiber/fiber/v2 v2.21.0 h1:tdRNrgqWqcHWBwE3o51oAleEVsil4Ro02zd2vMEuP4Q= +github.com/gofiber/fiber/v2 v2.21.0/go.mod h1:MR1usVH3JHYRyQwMe2eZXRSZHRX38fkV+A7CPB+DlDQ= +github.com/gofiber/websocket/v2 v2.0.12 h1:jKwTrXiOut9UGOGEzFTAD6gq+/78mM3NcrI05VbxjAU= +github.com/gofiber/websocket/v2 v2.0.12/go.mod h1:lQRy0u5ACJfiez/e/bhGeYvM0/M940Y3NFw14U3/otI= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= -github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s= +github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4 h1:ocK/D6lCgLji37Z2so4xhMl46se1ntReQQCUIU4BWI8= +github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4/go.mod h1:oejLrk1Y/5zOF+c/aHtXqn3TFlzzbAgPWg8zBiAHDas= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= -github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= -github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.29.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= +github.com/valyala/fasthttp v1.30.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= +github.com/valyala/fasthttp v1.31.0 h1:lrauRLII19afgCs2fnWRJ4M5IkV0lo2FqA61uGkNBfE= +github.com/valyala/fasthttp v1.31.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= -gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/rauc/rauc.go b/src/rauc/rauc.go index 696cb25..3e021fe 100644 --- a/src/rauc/rauc.go +++ b/src/rauc/rauc.go @@ -18,7 +18,7 @@ var ( ) type Rauc struct { - bc broadcaster + broadcast chan string status RaucStatus runningMtx sync.Mutex } @@ -31,23 +31,19 @@ type RaucStatus struct { Log string `json:"log"` } -type broadcaster interface { - Broadcast(msg []byte) -} - -func NewRauc(bc broadcaster) *Rauc { +func NewRauc(broadcast chan string) *Rauc { r := &Rauc{ - bc: bc, + broadcast: broadcast, } - r.bc.Broadcast(r.GetStatusJson()) + r.broadcast <- r.GetStatusJson() return r } func (r *Rauc) completed(updateFile string) { r.status.Installing = false - r.bc.Broadcast(r.GetStatusJson()) + r.broadcast <- r.GetStatusJson() _ = os.Remove(updateFile) } @@ -72,7 +68,7 @@ func (r *Rauc) RunRauc(updateFile string) error { r.status = RaucStatus{ Installing: true, } - r.bc.Broadcast(r.GetStatusJson()) + r.broadcast <- r.GetStatusJson() cmd := util.CommandFromString(fmt.Sprintf("%s install %s", util.RaucCmd, updateFile)) @@ -104,7 +100,7 @@ func (r *Rauc) RunRauc(updateFile string) error { } if hasUpdate { - r.bc.Broadcast(r.GetStatusJson()) + r.broadcast <- r.GetStatusJson() } } }() @@ -130,7 +126,7 @@ func (r *Rauc) GetStatus() RaucStatus { return r.status } -func (r *Rauc) GetStatusJson() []byte { +func (r *Rauc) GetStatusJson() string { statusJson, _ := json.Marshal(r.status) - return statusJson + return string(statusJson) } diff --git a/src/server/hub.go b/src/server/hub.go new file mode 100644 index 0000000..77e1dfd --- /dev/null +++ b/src/server/hub.go @@ -0,0 +1,98 @@ +package server + +import ( + "log" + "sync" + + "github.com/gofiber/websocket/v2" +) + +type hubClient struct{} + +type MessageHub struct { + Broadcast chan string + + clients map[*websocket.Conn]hubClient + register chan *websocket.Conn + unregister chan *websocket.Conn + lastMessage string + + running bool + runningMtx sync.Mutex +} + +func NewHub() *MessageHub { + return &MessageHub{ + clients: make(map[*websocket.Conn]hubClient), + register: make(chan *websocket.Conn), + Broadcast: make(chan string, 5), + unregister: make(chan *websocket.Conn), + } +} + +func (hub *MessageHub) sendMessage(conn *websocket.Conn, message string) { + if err := conn.WriteMessage( + websocket.TextMessage, []byte(message)); err != nil { + log.Println("write error:", err) + + _ = conn.WriteMessage(websocket.CloseMessage, []byte{}) + _ = conn.Close() + delete(hub.clients, conn) + } +} + +func (hub *MessageHub) Run() { + hub.runningMtx.Lock() + isRunning := hub.running + hub.running = true + hub.runningMtx.Unlock() + + if isRunning { + return + } + + for { + select { + case conn := <-hub.register: + hub.clients[conn] = hubClient{} + log.Println("connection registered") + + case message := <-hub.Broadcast: + log.Println("message received:", message) + hub.lastMessage = message + + // Send the message to all clients + for conn := range hub.clients { + hub.sendMessage(conn, message) + } + + case conn := <-hub.unregister: + // Remove the client from the hub + delete(hub.clients, conn) + + log.Println("connection unregistered") + } + } +} + +func (hub *MessageHub) Handler(conn *websocket.Conn) { + // When the function returns, unregister the client and close the connection + defer func() { + hub.unregister <- conn + conn.Close() + }() + + // Register the client + hub.register <- conn + + if hub.lastMessage != "" { + hub.sendMessage(conn, hub.lastMessage) + } + + for { + _, _, err := conn.ReadMessage() + if err != nil { + return // Calls the deferred function, i.e. closes the connection on error + } + } +} diff --git a/src/server/server.go b/src/server/server.go index a6f684a..46522a3 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -3,23 +3,27 @@ package server import ( "errors" "fmt" + "net/http" "strings" "time" "code.thetadev.de/TSGRain/SEBRAUC/src/rauc" - "code.thetadev.de/TSGRain/SEBRAUC/src/server/stream" "code.thetadev.de/TSGRain/SEBRAUC/src/sysinfo" "code.thetadev.de/TSGRain/SEBRAUC/src/util" "code.thetadev.de/TSGRain/SEBRAUC/ui" - "github.com/gin-contrib/cors" - "github.com/gin-gonic/gin" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/compress" + "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/filesystem" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/websocket/v2" "github.com/google/uuid" ) type SEBRAUCServer struct { address string raucUpdater *rauc.Rauc - streamer *stream.API + hub *MessageHub tmpdir string } @@ -29,9 +33,9 @@ type statusMessage struct { } func NewServer(address string) *SEBRAUCServer { - streamer := stream.New(10*time.Second, 1*time.Second, []string{}) + hub := NewHub() - raucUpdater := rauc.NewRauc(streamer) + raucUpdater := rauc.NewRauc(hub.Broadcast) tmpdir, err := util.GetTmpdir() if err != nil { @@ -41,100 +45,127 @@ func NewServer(address string) *SEBRAUCServer { return &SEBRAUCServer{ address: address, raucUpdater: raucUpdater, - // hub: hub, - streamer: streamer, - tmpdir: tmpdir, + hub: hub, + tmpdir: tmpdir, } } func (srv *SEBRAUCServer) Run() error { - router := gin.Default() + app := fiber.New(fiber.Config{ + AppName: "SEBRAUC", + BodyLimit: 1024 * 1024 * 1024, + ErrorHandler: errorHandler, + DisableStartupMessage: true, + }) - // only for testing - router.Use(cors.Default()) + app.Use(logger.New()) + + app.Use(compress.New(compress.Config{ + Next: func(c *fiber.Ctx) bool { + return strings.HasPrefix(c.Path(), "/api") + }, + })) + + // just for testing + app.Use("/api", cors.New()) + + app.Use("/api/ws", func(c *fiber.Ctx) error { + // IsWebSocketUpgrade returns true if the client + // requested upgrade to the WebSocket protocol. + if websocket.IsWebSocketUpgrade(c) { + c.Locals("allowed", true) + return c.Next() + } + return fiber.ErrUpgradeRequired + }) + + app.Use("/", filesystem.New(filesystem.Config{ + Root: http.FS(ui.Assets), + PathPrefix: ui.AssetsDir, + MaxAge: 7200, + })) // ROUTES - router.GET("/api/ws", srv.streamer.Handle) - router.GET("/api/status", srv.controllerStatus) - router.GET("/api/info", srv.controllerInfo) + app.Get("/api/ws", websocket.New(srv.hub.Handler)) + app.Post("/api/update", srv.controllerUpdate) + app.Get("/api/status", srv.controllerStatus) + app.Get("/api/info", srv.controllerInfo) + app.Post("/api/reboot", srv.controllerReboot) - router.POST("/api/update", srv.controllerUpdate) - router.POST("/api/reboot", srv.controllerReboot) + // Start messaging hub + go srv.hub.Run() - // router.StaticFS("/", ui.GetFS()) - ui.Register(router) - - return router.Run(srv.address) + return app.Listen(srv.address) } -func (srv *SEBRAUCServer) controllerUpdate(c *gin.Context) { +func (srv *SEBRAUCServer) controllerUpdate(c *fiber.Ctx) error { file, err := c.FormFile("updateFile") if err != nil { - c.Error(err) - return + return err } uid, err := uuid.NewRandom() if err != nil { - c.Error(err) - return + return err } updateFile := fmt.Sprintf("%s/update_%s.raucb", srv.tmpdir, uid.String()) - err = c.SaveUploadedFile(file, updateFile) + err = c.SaveFile(file, updateFile) if err != nil { - c.Error(err) - return + return err } err = srv.raucUpdater.RunRauc(updateFile) if err == nil { writeStatus(c, true, "Update started") } else if errors.Is(err, util.ErrAlreadyRunning) { - c.AbortWithError(409, errors.New("already running")) + return fiber.NewError(fiber.StatusConflict, "already running") } else { - c.Error(err) - return + return err } + return nil } -func (srv *SEBRAUCServer) controllerStatus(c *gin.Context) { - c.JSON(200, srv.raucUpdater.GetStatus()) +func (srv *SEBRAUCServer) controllerStatus(c *fiber.Ctx) error { + c.Context().SetStatusCode(200) + _ = c.JSON(srv.raucUpdater.GetStatus()) + return nil } -func (srv *SEBRAUCServer) controllerInfo(c *gin.Context) { +func (srv *SEBRAUCServer) controllerInfo(c *fiber.Ctx) error { info, err := sysinfo.GetSysinfo() if err != nil { - c.Error(err) - } else { - c.JSON(200, info) + return err } + + c.Context().SetStatusCode(200) + _ = c.JSON(info) + return nil } -func (srv *SEBRAUCServer) controllerReboot(c *gin.Context) { +func (srv *SEBRAUCServer) controllerReboot(c *fiber.Ctx) error { go util.Reboot(5 * time.Second) writeStatus(c, true, "System is rebooting") + return nil } -func errorHandler(c *gin.Context, err error) error { +func errorHandler(c *fiber.Ctx, err error) error { // API error handling - if strings.HasPrefix(c.FullPath(), "/api") { + if strings.HasPrefix(c.Path(), "/api") { writeStatus(c, false, err.Error()) } return err } -func writeStatus(c *gin.Context, success bool, msg string) { - status := 200 - - if !success { - status = 500 - } - - c.JSON(status, statusMessage{ +func writeStatus(c *fiber.Ctx, success bool, msg string) { + _ = c.JSON(statusMessage{ Success: success, Msg: msg, }) + + if success { + c.Context().SetStatusCode(200) + } } diff --git a/src/server/stream/client.go b/src/server/stream/client.go deleted file mode 100644 index c00249f..0000000 --- a/src/server/stream/client.go +++ /dev/null @@ -1,119 +0,0 @@ -package stream - -import ( - "errors" - "fmt" - "time" - - "github.com/gorilla/websocket" -) - -const ( - writeWait = 2 * time.Second -) - -var ping = func(conn *websocket.Conn) error { - return conn.WriteMessage(websocket.PingMessage, nil) -} - -var writeBytes = func(conn *websocket.Conn, data []byte) error { - return conn.WriteMessage(websocket.TextMessage, data) -} - -type client struct { - conn *websocket.Conn - onClose func(*client) - write chan []byte - id uint - once once -} - -func newClient(conn *websocket.Conn, id uint, onClose func(*client)) *client { - return &client{ - conn: conn, - write: make(chan []byte, 1), - id: id, - onClose: onClose, - } -} - -// Close closes the connection. -func (c *client) Close() { - c.once.Do(func() { - c.conn.Close() - close(c.write) - }) -} - -// NotifyClose closes the connection and notifies that the connection was closed. -func (c *client) NotifyClose() { - c.once.Do(func() { - c.conn.Close() - close(c.write) - c.onClose(c) - }) -} - -// startWriteHandler starts listening on the client connection. -// As we do not need anything from the client, -// we ignore incoming messages. Leaves the loop on errors. -func (c *client) startReading(pongWait time.Duration) { - defer c.NotifyClose() - c.conn.SetReadLimit(64) - _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) - c.conn.SetPongHandler(func(appData string) error { - _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) - return nil - }) - for { - if _, _, err := c.conn.NextReader(); err != nil { - printWebSocketError("ReadError", err) - return - } - } -} - -// startWriteHandler starts the write loop. The method has the following tasks: -// * ping the client in the interval provided as parameter -// * write messages send by the channel to the client -// * on errors exit the loop. -func (c *client) startWriteHandler(pingPeriod time.Duration) { - pingTicker := time.NewTicker(pingPeriod) - defer func() { - c.NotifyClose() - pingTicker.Stop() - }() - - for { - select { - case message, ok := <-c.write: - if !ok { - return - } - - _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := writeBytes(c.conn, message); err != nil { - printWebSocketError("WriteError", err) - return - } - case <-pingTicker.C: - _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := ping(c.conn); err != nil { - printWebSocketError("PingError", err) - return - } - } - } -} - -func printWebSocketError(prefix string, err error) { - var closeError *websocket.CloseError - ok := errors.As(err, &closeError) - - if ok && closeError != nil && (closeError.Code == 1000 || closeError.Code == 1001) { - // normal closure - return - } - - fmt.Println("WebSocket:", prefix, err) -} diff --git a/src/server/stream/hub.go b/src/server/stream/hub.go deleted file mode 100644 index 11541cc..0000000 --- a/src/server/stream/hub.go +++ /dev/null @@ -1 +0,0 @@ -package stream diff --git a/src/server/stream/hub_test.go b/src/server/stream/hub_test.go deleted file mode 100644 index 11541cc..0000000 --- a/src/server/stream/hub_test.go +++ /dev/null @@ -1 +0,0 @@ -package stream diff --git a/src/server/stream/once.go b/src/server/stream/once.go deleted file mode 100644 index 2df2523..0000000 --- a/src/server/stream/once.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stream - -import ( - "sync" - "sync/atomic" -) - -// Modified version of sync.Once -// (https://github.com/golang/go/blob/master/src/sync/once.go) -// This version unlocks the mutex early and therefore doesn't -// hold the lock while executing func f(). -type once struct { - m sync.Mutex - done uint32 -} - -func (o *once) Do(f func()) { - if atomic.LoadUint32(&o.done) == 1 { - return - } - if o.mayExecute() { - f() - } -} - -func (o *once) mayExecute() bool { - o.m.Lock() - defer o.m.Unlock() - if o.done == 0 { - atomic.StoreUint32(&o.done, 1) - return true - } - return false -} diff --git a/src/server/stream/once_test.go b/src/server/stream/once_test.go deleted file mode 100644 index 53ec08d..0000000 --- a/src/server/stream/once_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package stream - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func Test_Execute(t *testing.T) { - executeOnce := once{} - execution := make(chan struct{}) - fExecute := func() { - execution <- struct{}{} - } - go executeOnce.Do(fExecute) - go executeOnce.Do(fExecute) - - select { - case <-execution: - // expected - case <-time.After(100 * time.Millisecond): - t.Fatal("fExecute should be executed once") - } - - select { - case <-execution: - t.Fatal("should only execute once") - case <-time.After(100 * time.Millisecond): - // expected - } - - assert.False(t, executeOnce.mayExecute()) - - go executeOnce.Do(fExecute) - - select { - case <-execution: - t.Fatal("should only execute once") - case <-time.After(100 * time.Millisecond): - // expected - } -} diff --git a/src/server/stream/stream.go b/src/server/stream/stream.go deleted file mode 100644 index 6133eeb..0000000 --- a/src/server/stream/stream.go +++ /dev/null @@ -1,187 +0,0 @@ -package stream - -import ( - "net/http" - "net/url" - "regexp" - "strings" - "sync" - "time" - - "code.thetadev.de/TSGRain/SEBRAUC/src/util" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" -) - -// The API provides a handler for a WebSocket stream API. -type API struct { - clients map[uint]*client - lock sync.RWMutex - pingPeriod time.Duration - pongTimeout time.Duration - upgrader *websocket.Upgrader - counter *util.Counter -} - -// New creates a new instance of API. -// pingPeriod: is the interval, in which is server sends the a ping to the client. -// pongTimeout: is the duration after the connection will be terminated, -// when the client does not respond with the pong command. -func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string) *API { - return &API{ - clients: make(map[uint]*client), - pingPeriod: pingPeriod, - pongTimeout: pingPeriod + pongTimeout, - upgrader: newUpgrader(allowedWebSocketOrigins), - counter: &util.Counter{}, - } -} - -// NotifyDeletedUser closes existing connections for the given user. -func (a *API) NotifyDeletedClient(userID uint) error { - a.lock.Lock() - defer a.lock.Unlock() - if client, ok := a.clients[userID]; ok { - client.Close() - delete(a.clients, userID) - } - return nil -} - -// Notify notifies the clients with the given userID that a new messages was created. -func (a *API) Notify(userID uint, msg []byte) { - a.lock.RLock() - defer a.lock.RUnlock() - if client, ok := a.clients[userID]; ok { - client.write <- msg - } -} - -func (a *API) Broadcast(msg []byte) { - a.lock.RLock() - defer a.lock.RUnlock() - for _, client := range a.clients { - client.write <- msg - } -} - -func (a *API) remove(remove *client) { - a.lock.Lock() - defer a.lock.Unlock() - delete(a.clients, remove.id) -} - -func (a *API) register(client *client) { - a.lock.Lock() - defer a.lock.Unlock() - a.clients[client.id] = client -} - -// Handle handles incoming requests. -// First it upgrades the protocol to the WebSocket protocol and then starts listening -// for read and writes. -// swagger:operation GET /stream message streamMessages -// -// Websocket, return newly created messages. -// -// --- -// schema: ws, wss -// produces: [application/json] -// security: [clientTokenHeader: [], clientTokenQuery: [], basicAuth: []] -// responses: -// 200: -// description: Ok -// schema: -// $ref: "#/definitions/Message" -// 400: -// description: Bad Request -// schema: -// $ref: "#/definitions/Error" -// 401: -// description: Unauthorized -// schema: -// $ref: "#/definitions/Error" -// 403: -// description: Forbidden -// schema: -// $ref: "#/definitions/Error" -// 500: -// description: Server Error -// schema: -// $ref: "#/definitions/Error" -func (a *API) Handle(ctx *gin.Context) { - conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil) - if err != nil { - ctx.Error(err) - return - } - - client := newClient(conn, a.counter.Increment(), a.remove) - a.register(client) - go client.startReading(a.pongTimeout) - go client.startWriteHandler(a.pingPeriod) -} - -// Close closes all client connections and stops answering new connections. -func (a *API) Close() { - a.lock.Lock() - defer a.lock.Unlock() - - for _, client := range a.clients { - client.Close() - } - for k := range a.clients { - delete(a.clients, k) - } -} - -func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool { - origin := r.Header.Get("origin") - if origin == "" { - return true - } - - u, err := url.Parse(origin) - if err != nil { - return false - } - - if strings.EqualFold(u.Host, r.Host) { - return true - } - - for _, allowedOrigin := range allowedOrigins { - if allowedOrigin.Match([]byte(strings.ToLower(u.Hostname()))) { - return true - } - } - - return false -} - -func newUpgrader(allowedWebSocketOrigins []string) *websocket.Upgrader { - // compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins) - return &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - /* - TODO: implement this - if mode.IsDev() { - return true - } - return isAllowedOrigin(r, compiledAllowedOrigins) - */ - return true - }, - } -} - -func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp { - var compiledAllowedOrigins []*regexp.Regexp - for _, origin := range allowedOrigins { - compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin)) - } - - return compiledAllowedOrigins -} diff --git a/src/server/stream/stream_test.go b/src/server/stream/stream_test.go deleted file mode 100644 index fd779b1..0000000 --- a/src/server/stream/stream_test.go +++ /dev/null @@ -1,424 +0,0 @@ -package stream - -import ( - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/fortytw2/leaktest" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" -) - -func TestFailureOnNormalHttpRequest(t *testing.T) { - // mode.Set(mode.TestDev) - - defer leaktest.Check(t)() - - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - resp, err := http.Get(server.URL) - assert.Nil(t, err) - assert.Equal(t, 400, resp.StatusCode) - resp.Body.Close() -} - -func TestWriteMessageFails(t *testing.T) { - // mode.Set(mode.TestDev) - oldWrite := writeBytes - // try emulate an write error, mostly this should kill the ReadMessage - // goroutine first but you'll never know. - writeBytes = func(conn *websocket.Conn, data []byte) error { - return errors.New("asd") - } - defer func() { - writeBytes = oldWrite - }() - defer leaktest.Check(t)() - - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - wsURL := wsURL(server.URL) - user := testClient(t, wsURL) - - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - client := getClient(api, 1) - assert.NotNil(t, client) - - api.Notify(1, []byte("HI")) - user.expectNoMessage() -} - -func TestWritePingFails(t *testing.T) { - // mode.Set(mode.TestDev) - oldPing := ping - // try emulate an write error, mostly this should kill the ReadMessage - // gorouting first but you'll never know. - ping = func(conn *websocket.Conn) error { - return errors.New("asd") - } - defer func() { - ping = oldPing - }() - - defer leaktest.CheckTimeout(t, 10*time.Second)() - - server, api := bootTestServer() - defer api.Close() - defer server.Close() - - wsURL := wsURL(server.URL) - user := testClient(t, wsURL) - defer user.conn.Close() - - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - client := getClient(api, 1) - - assert.NotNil(t, client) - - time.Sleep(api.pingPeriod) // waiting for ping - - api.Notify(1, []byte("HI")) - user.expectNoMessage() -} - -func TestPing(t *testing.T) { - // mode.Set(mode.TestDev) - - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - wsURL := wsURL(server.URL) - - user := createClient(t, wsURL) - defer user.conn.Close() - - ping := make(chan bool) - oldPingHandler := user.conn.PingHandler() - user.conn.SetPingHandler(func(appData string) error { - err := oldPingHandler(appData) - ping <- true - return err - }) - - startReading(user) - - expectNoMessage(user) - - select { - case <-time.After(2 * time.Second): - assert.Fail(t, "Expected ping but there was one :(") - case <-ping: - // expected - } - - expectNoMessage(user) - api.Notify(1, []byte("HI")) - user.expectMessage([]byte("HI")) -} - -func TestCloseClientOnNotReading(t *testing.T) { - // mode.Set(mode.TestDev) - - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - wsURL := wsURL(server.URL) - - ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) - assert.Nil(t, err) - resp.Body.Close() - defer ws.Close() - - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - assert.NotNil(t, getClient(api, 1)) - - time.Sleep(api.pingPeriod + api.pongTimeout) - - assert.Nil(t, getClient(api, 1)) -} - -func TestMessageDirectlyAfterConnect(t *testing.T) { - // mode.Set(mode.Prod) - defer leaktest.Check(t)() - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - wsURL := wsURL(server.URL) - - user := testClient(t, wsURL) - defer user.conn.Close() - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - api.Notify(1, []byte("msg")) - user.expectMessage([]byte("msg")) -} - -func TestDeleteClientShouldCloseConnection(t *testing.T) { - // mode.Set(mode.Prod) - defer leaktest.Check(t)() - server, api := bootTestServer() - defer server.Close() - defer api.Close() - - wsURL := wsURL(server.URL) - - user := testClient(t, wsURL) - defer user.conn.Close() - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - api.Notify(1, []byte("HI")) - user.expectMessage([]byte("HI")) - - assert.Nil(t, api.NotifyDeletedClient(1)) - - api.Notify(1, []byte("HI")) - user.expectNoMessage() -} - -func TestNotify(t *testing.T) { - // mode.Set(mode.TestDev) - - defer leaktest.Check(t)() - server, api := bootTestServer() - defer server.Close() - - wsURL := wsURL(server.URL) - - client1 := testClient(t, wsURL) - defer client1.conn.Close() - - client2 := testClient(t, wsURL) - defer client2.conn.Close() - - client3 := testClient(t, wsURL) - defer client3.conn.Close() - - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - - api.Notify(1, []byte("msg")) - expectMessage([]byte("msg"), client1) - expectNoMessage(client2) - expectNoMessage(client3) - - assert.Nil(t, api.NotifyDeletedClient(1)) - - api.Notify(1, []byte("msg")) - expectNoMessage(client1) - expectNoMessage(client2) - expectNoMessage(client3) - - api.Notify(2, []byte("msg")) - expectNoMessage(client1) - expectMessage([]byte("msg"), client2) - expectNoMessage(client3) - - api.Notify(3, []byte("msg")) - expectNoMessage(client1) - expectNoMessage(client2) - expectMessage([]byte("msg"), client3) - - api.Close() -} - -func TestBroadcast(t *testing.T) { - defer leaktest.Check(t)() - server, api := bootTestServer() - defer server.Close() - - wsURL := wsURL(server.URL) - - client1 := testClient(t, wsURL) - defer client1.conn.Close() - - client2 := testClient(t, wsURL) - defer client2.conn.Close() - - client3 := testClient(t, wsURL) - defer client3.conn.Close() - - // the server may take some time to register the client - time.Sleep(100 * time.Millisecond) - - testMsg1 := []byte("hello1") - api.Broadcast(testMsg1) - expectMessage(testMsg1, client1, client2, client3) - - assert.Nil(t, api.NotifyDeletedClient(1)) - - testMsg2 := []byte("hello2") - api.Broadcast(testMsg2) - expectNoMessage(client1) - expectMessage(testMsg2, client2, client3) -} - -func Test_sameOrigin_returnsTrue(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com/stream", nil) - req.Header.Set("Origin", "http://example.com") - actual := isAllowedOrigin(req, nil) - assert.True(t, actual) -} - -func Test_sameOrigin_returnsTrue_withCustomPort(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com:8080/stream", nil) - req.Header.Set("Origin", "http://example.com:8080") - actual := isAllowedOrigin(req, nil) - assert.True(t, actual) -} - -func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com/stream", nil) - req.Header.Set("Origin", "http://gorify.example.com") - actual := isAllowedOrigin(req, nil) - assert.False(t, actual) -} - -func Test_isAllowedOriginMatching(t *testing.T) { - // mode.Set(mode.Prod) - compiledAllowedOrigins := compileAllowedWebSocketOrigins( - []string{"go.{4}\\.example\\.com", "go\\.example\\.com"}, - ) - - req := httptest.NewRequest("GET", "http://example.me/stream", nil) - req.Header.Set("Origin", "http://gorify.example.com") - assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) - - req.Header.Set("Origin", "http://go.example.com") - assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) - - req.Header.Set("Origin", "http://hello.example.com") - assert.False(t, isAllowedOrigin(req, compiledAllowedOrigins)) -} - -func Test_emptyOrigin_returnsTrue(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com/stream", nil) - actual := isAllowedOrigin(req, nil) - assert.True(t, actual) -} - -func Test_otherOrigin_returnsFalse(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com/stream", nil) - req.Header.Set("Origin", "http://otherexample.de") - actual := isAllowedOrigin(req, nil) - assert.False(t, actual) -} - -func Test_invalidOrigin_returnsFalse(t *testing.T) { - // mode.Set(mode.Prod) - req := httptest.NewRequest("GET", "http://example.com/stream", nil) - req.Header.Set("Origin", "http\\://otherexample.de") - actual := isAllowedOrigin(req, nil) - assert.False(t, actual) -} - -func Test_compileAllowedWebSocketOrigins(t *testing.T) { - assert.Equal(t, 0, len(compileAllowedWebSocketOrigins([]string{}))) - assert.Equal(t, 3, len(compileAllowedWebSocketOrigins([]string{"^.*$", "", "abc"}))) -} - -func getClient(api *API, user uint) *client { - api.lock.RLock() - defer api.lock.RUnlock() - - return api.clients[user] -} - -func testClient(t *testing.T, url string) *testingClient { - client := createClient(t, url) - startReading(client) - return client -} - -func startReading(client *testingClient) { - go func() { - for { - _, payload, err := client.conn.ReadMessage() - if err != nil { - return - } - - client.readMessage <- payload - } - }() -} - -func createClient(t *testing.T, url string) *testingClient { - ws, resp, err := websocket.DefaultDialer.Dial(url, nil) - assert.Nil(t, err) - resp.Body.Close() - - readMessages := make(chan []byte) - - return &testingClient{conn: ws, readMessage: readMessages, t: t} -} - -type testingClient struct { - conn *websocket.Conn - readMessage chan []byte - t *testing.T -} - -func (c *testingClient) expectMessage(expected []byte) { - select { - case <-time.After(50 * time.Millisecond): - assert.Fail(c.t, "Expected message but none was send :(") - case actual := <-c.readMessage: - assert.Equal(c.t, expected, actual) - } -} - -func expectMessage(expected []byte, clients ...*testingClient) { - for _, client := range clients { - client.expectMessage(expected) - } -} - -func expectNoMessage(clients ...*testingClient) { - for _, client := range clients { - client.expectNoMessage() - } -} - -func (c *testingClient) expectNoMessage() { - select { - case <-time.After(50 * time.Millisecond): - // no message == as expected - case msg := <-c.readMessage: - assert.Fail(c.t, "Expected NO message but there was one :(", fmt.Sprint(msg)) - } -} - -func bootTestServer() (*httptest.Server, *API) { - r := gin.New() - // ping every 500 ms, and the client has 500 ms to respond - api := New(500*time.Millisecond, 500*time.Millisecond, []string{}) - - r.GET("/", api.Handle) - server := httptest.NewServer(r) - return server, api -} - -func wsURL(httpURL string) string { - return "ws" + strings.TrimPrefix(httpURL, "http") -} diff --git a/src/util/counter.go b/src/util/counter.go deleted file mode 100644 index 9e264ee..0000000 --- a/src/util/counter.go +++ /dev/null @@ -1,30 +0,0 @@ -package util - -import "sync" - -type Counter struct { - count uint - mutex sync.RWMutex -} - -func (c *Counter) Get() uint { - c.mutex.RLock() - defer c.mutex.RUnlock() - - return c.count -} - -func (c *Counter) Reset() { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.count = 0 -} - -func (c *Counter) Increment() uint { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.count++ - return c.count -} diff --git a/src/util/counter_test.go b/src/util/counter_test.go deleted file mode 100644 index bfc3e9d..0000000 --- a/src/util/counter_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package util - -import ( - "sync" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCounter(t *testing.T) { - counter := Counter{} - - var wg sync.WaitGroup - - incrementer := func() { - for i := 0; i < 50; i++ { - counter.Increment() - } - wg.Done() - } - - for i := 0; i < 100; i++ { - wg.Add(1) - go incrementer() - } - - wg.Wait() - - assert.EqualValues(t, 5000, counter.Get()) -} diff --git a/ui/index.html b/ui/index.html index 426ad0f..cf876c6 100644 --- a/ui/index.html +++ b/ui/index.html @@ -7,11 +7,7 @@