diff --git a/.golangci.yaml b/.golangci.yaml index 3898f1c..31a4beb 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,7 +1,6 @@ linters: presets: - bugs - - unused - import - module @@ -14,6 +13,7 @@ linters: disable: - scopelint + - noctx linters-settings: lll: @@ -22,3 +22,7 @@ linters-settings: min-complexity: 10 nestif: min-complexity: 3 + errcheck: + exclude-functions: + - "(*github.com/gin-gonic/gin.Context).Error" + - "(*github.com/gin-gonic/gin.Context).AbortWithError" diff --git a/go.mod b/go.mod index 5b0a27c..36fe1e4 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,14 @@ module code.thetadev.de/TSGRain/SEBRAUC go 1.16 require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gofiber/fiber/v2 v2.21.0 - github.com/gofiber/websocket/v2 v2.0.12 + code.thetadev.de/TSGRain/ginzip v0.1.1 + github.com/fortytw2/leaktest v1.3.0 + github.com/gin-contrib/cors v1.3.1 + github.com/gin-gonic/gin v1.7.7 github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.4.2 github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect + golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 4188720..c2354c4 100644 --- a/go.sum +++ b/go.sum @@ -1,49 +1,92 @@ -github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E= -github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= +code.thetadev.de/TSGRain/ginzip v0.1.1 h1:+X0L6qumEZiKYSLmM+Q0LqKVHsKvdcg4CVzsEpvM7fk= +code.thetadev.de/TSGRain/ginzip v0.1.1/go.mod h1:BH7VkvpP83vPRyMQ8rLIjKycQwGzF+/mFV0BKzg+BuA= +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fasthttp/websocket v1.4.3-rc.9 h1:CWJH0vONrOatdKXZgkgbFKWllijD9aY50C5KfbSDcWk= -github.com/fasthttp/websocket v1.4.3-rc.9/go.mod h1:eXL2zqDbexYJxaCw8/PQlm7VcMK6uoGvwbYbTdt4dFo= -github.com/gofiber/fiber/v2 v2.20.1/go.mod h1:/LdZHMUXZvTTo7gU4+b1hclqCAdoQphNQ9bi9gutPyI= -github.com/gofiber/fiber/v2 v2.21.0 h1:tdRNrgqWqcHWBwE3o51oAleEVsil4Ro02zd2vMEuP4Q= -github.com/gofiber/fiber/v2 v2.21.0/go.mod h1:MR1usVH3JHYRyQwMe2eZXRSZHRX38fkV+A7CPB+DlDQ= -github.com/gofiber/websocket/v2 v2.0.12 h1:jKwTrXiOut9UGOGEzFTAD6gq+/78mM3NcrI05VbxjAU= -github.com/gofiber/websocket/v2 v2.0.12/go.mod h1:lQRy0u5ACJfiez/e/bhGeYvM0/M940Y3NFw14U3/otI= -github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= +github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA= +github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= +github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs= +github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s= -github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4 h1:ocK/D6lCgLji37Z2so4xhMl46se1ntReQQCUIU4BWI8= -github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4/go.mod h1:oejLrk1Y/5zOF+c/aHtXqn3TFlzzbAgPWg8zBiAHDas= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.29.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= -github.com/valyala/fasthttp v1.30.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= -github.com/valyala/fasthttp v1.31.0 h1:lrauRLII19afgCs2fnWRJ4M5IkV0lo2FqA61uGkNBfE= -github.com/valyala/fasthttp v1.31.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= -github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= +gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/rauc/rauc.go b/src/rauc/rauc.go index 3e021fe..696cb25 100644 --- a/src/rauc/rauc.go +++ b/src/rauc/rauc.go @@ -18,7 +18,7 @@ var ( ) type Rauc struct { - broadcast chan string + bc broadcaster status RaucStatus runningMtx sync.Mutex } @@ -31,19 +31,23 @@ type RaucStatus struct { Log string `json:"log"` } -func NewRauc(broadcast chan string) *Rauc { +type broadcaster interface { + Broadcast(msg []byte) +} + +func NewRauc(bc broadcaster) *Rauc { r := &Rauc{ - broadcast: broadcast, + bc: bc, } - r.broadcast <- r.GetStatusJson() + r.bc.Broadcast(r.GetStatusJson()) return r } func (r *Rauc) completed(updateFile string) { r.status.Installing = false - r.broadcast <- r.GetStatusJson() + r.bc.Broadcast(r.GetStatusJson()) _ = os.Remove(updateFile) } @@ -68,7 +72,7 @@ func (r *Rauc) RunRauc(updateFile string) error { r.status = RaucStatus{ Installing: true, } - r.broadcast <- r.GetStatusJson() + r.bc.Broadcast(r.GetStatusJson()) cmd := util.CommandFromString(fmt.Sprintf("%s install %s", util.RaucCmd, updateFile)) @@ -100,7 +104,7 @@ func (r *Rauc) RunRauc(updateFile string) error { } if hasUpdate { - r.broadcast <- r.GetStatusJson() + r.bc.Broadcast(r.GetStatusJson()) } } }() @@ -126,7 +130,7 @@ func (r *Rauc) GetStatus() RaucStatus { return r.status } -func (r *Rauc) GetStatusJson() string { +func (r *Rauc) GetStatusJson() []byte { statusJson, _ := json.Marshal(r.status) - return string(statusJson) + return statusJson } diff --git a/src/server/hub.go b/src/server/hub.go deleted file mode 100644 index 77e1dfd..0000000 --- a/src/server/hub.go +++ /dev/null @@ -1,98 +0,0 @@ -package server - -import ( - "log" - "sync" - - "github.com/gofiber/websocket/v2" -) - -type hubClient struct{} - -type MessageHub struct { - Broadcast chan string - - clients map[*websocket.Conn]hubClient - register chan *websocket.Conn - unregister chan *websocket.Conn - lastMessage string - - running bool - runningMtx sync.Mutex -} - -func NewHub() *MessageHub { - return &MessageHub{ - clients: make(map[*websocket.Conn]hubClient), - register: make(chan *websocket.Conn), - Broadcast: make(chan string, 5), - unregister: make(chan *websocket.Conn), - } -} - -func (hub *MessageHub) sendMessage(conn *websocket.Conn, message string) { - if err := conn.WriteMessage( - websocket.TextMessage, []byte(message)); err != nil { - log.Println("write error:", err) - - _ = conn.WriteMessage(websocket.CloseMessage, []byte{}) - _ = conn.Close() - delete(hub.clients, conn) - } -} - -func (hub *MessageHub) Run() { - hub.runningMtx.Lock() - isRunning := hub.running - hub.running = true - hub.runningMtx.Unlock() - - if isRunning { - return - } - - for { - select { - case conn := <-hub.register: - hub.clients[conn] = hubClient{} - log.Println("connection registered") - - case message := <-hub.Broadcast: - log.Println("message received:", message) - hub.lastMessage = message - - // Send the message to all clients - for conn := range hub.clients { - hub.sendMessage(conn, message) - } - - case conn := <-hub.unregister: - // Remove the client from the hub - delete(hub.clients, conn) - - log.Println("connection unregistered") - } - } -} - -func (hub *MessageHub) Handler(conn *websocket.Conn) { - // When the function returns, unregister the client and close the connection - defer func() { - hub.unregister <- conn - conn.Close() - }() - - // Register the client - hub.register <- conn - - if hub.lastMessage != "" { - hub.sendMessage(conn, hub.lastMessage) - } - - for { - _, _, err := conn.ReadMessage() - if err != nil { - return // Calls the deferred function, i.e. closes the connection on error - } - } -} diff --git a/src/server/server.go b/src/server/server.go index 46522a3..a6f684a 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -3,27 +3,23 @@ package server import ( "errors" "fmt" - "net/http" "strings" "time" "code.thetadev.de/TSGRain/SEBRAUC/src/rauc" + "code.thetadev.de/TSGRain/SEBRAUC/src/server/stream" "code.thetadev.de/TSGRain/SEBRAUC/src/sysinfo" "code.thetadev.de/TSGRain/SEBRAUC/src/util" "code.thetadev.de/TSGRain/SEBRAUC/ui" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/compress" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/filesystem" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/websocket/v2" + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" "github.com/google/uuid" ) type SEBRAUCServer struct { address string raucUpdater *rauc.Rauc - hub *MessageHub + streamer *stream.API tmpdir string } @@ -33,9 +29,9 @@ type statusMessage struct { } func NewServer(address string) *SEBRAUCServer { - hub := NewHub() + streamer := stream.New(10*time.Second, 1*time.Second, []string{}) - raucUpdater := rauc.NewRauc(hub.Broadcast) + raucUpdater := rauc.NewRauc(streamer) tmpdir, err := util.GetTmpdir() if err != nil { @@ -45,127 +41,100 @@ func NewServer(address string) *SEBRAUCServer { return &SEBRAUCServer{ address: address, raucUpdater: raucUpdater, - hub: hub, - tmpdir: tmpdir, + // hub: hub, + streamer: streamer, + tmpdir: tmpdir, } } func (srv *SEBRAUCServer) Run() error { - app := fiber.New(fiber.Config{ - AppName: "SEBRAUC", - BodyLimit: 1024 * 1024 * 1024, - ErrorHandler: errorHandler, - DisableStartupMessage: true, - }) + router := gin.Default() - app.Use(logger.New()) - - app.Use(compress.New(compress.Config{ - Next: func(c *fiber.Ctx) bool { - return strings.HasPrefix(c.Path(), "/api") - }, - })) - - // just for testing - app.Use("/api", cors.New()) - - app.Use("/api/ws", func(c *fiber.Ctx) error { - // IsWebSocketUpgrade returns true if the client - // requested upgrade to the WebSocket protocol. - if websocket.IsWebSocketUpgrade(c) { - c.Locals("allowed", true) - return c.Next() - } - return fiber.ErrUpgradeRequired - }) - - app.Use("/", filesystem.New(filesystem.Config{ - Root: http.FS(ui.Assets), - PathPrefix: ui.AssetsDir, - MaxAge: 7200, - })) + // only for testing + router.Use(cors.Default()) // ROUTES - app.Get("/api/ws", websocket.New(srv.hub.Handler)) - app.Post("/api/update", srv.controllerUpdate) - app.Get("/api/status", srv.controllerStatus) - app.Get("/api/info", srv.controllerInfo) - app.Post("/api/reboot", srv.controllerReboot) + router.GET("/api/ws", srv.streamer.Handle) + router.GET("/api/status", srv.controllerStatus) + router.GET("/api/info", srv.controllerInfo) - // Start messaging hub - go srv.hub.Run() + router.POST("/api/update", srv.controllerUpdate) + router.POST("/api/reboot", srv.controllerReboot) - return app.Listen(srv.address) + // router.StaticFS("/", ui.GetFS()) + ui.Register(router) + + return router.Run(srv.address) } -func (srv *SEBRAUCServer) controllerUpdate(c *fiber.Ctx) error { +func (srv *SEBRAUCServer) controllerUpdate(c *gin.Context) { file, err := c.FormFile("updateFile") if err != nil { - return err + c.Error(err) + return } uid, err := uuid.NewRandom() if err != nil { - return err + c.Error(err) + return } updateFile := fmt.Sprintf("%s/update_%s.raucb", srv.tmpdir, uid.String()) - err = c.SaveFile(file, updateFile) + err = c.SaveUploadedFile(file, updateFile) if err != nil { - return err + c.Error(err) + return } err = srv.raucUpdater.RunRauc(updateFile) if err == nil { writeStatus(c, true, "Update started") } else if errors.Is(err, util.ErrAlreadyRunning) { - return fiber.NewError(fiber.StatusConflict, "already running") + c.AbortWithError(409, errors.New("already running")) } else { - return err + c.Error(err) + return } - return nil } -func (srv *SEBRAUCServer) controllerStatus(c *fiber.Ctx) error { - c.Context().SetStatusCode(200) - _ = c.JSON(srv.raucUpdater.GetStatus()) - return nil +func (srv *SEBRAUCServer) controllerStatus(c *gin.Context) { + c.JSON(200, srv.raucUpdater.GetStatus()) } -func (srv *SEBRAUCServer) controllerInfo(c *fiber.Ctx) error { +func (srv *SEBRAUCServer) controllerInfo(c *gin.Context) { info, err := sysinfo.GetSysinfo() if err != nil { - return err + c.Error(err) + } else { + c.JSON(200, info) } - - c.Context().SetStatusCode(200) - _ = c.JSON(info) - return nil } -func (srv *SEBRAUCServer) controllerReboot(c *fiber.Ctx) error { +func (srv *SEBRAUCServer) controllerReboot(c *gin.Context) { go util.Reboot(5 * time.Second) writeStatus(c, true, "System is rebooting") - return nil } -func errorHandler(c *fiber.Ctx, err error) error { +func errorHandler(c *gin.Context, err error) error { // API error handling - if strings.HasPrefix(c.Path(), "/api") { + if strings.HasPrefix(c.FullPath(), "/api") { writeStatus(c, false, err.Error()) } return err } -func writeStatus(c *fiber.Ctx, success bool, msg string) { - _ = c.JSON(statusMessage{ +func writeStatus(c *gin.Context, success bool, msg string) { + status := 200 + + if !success { + status = 500 + } + + c.JSON(status, statusMessage{ Success: success, Msg: msg, }) - - if success { - c.Context().SetStatusCode(200) - } } diff --git a/src/server/stream/client.go b/src/server/stream/client.go new file mode 100644 index 0000000..c00249f --- /dev/null +++ b/src/server/stream/client.go @@ -0,0 +1,119 @@ +package stream + +import ( + "errors" + "fmt" + "time" + + "github.com/gorilla/websocket" +) + +const ( + writeWait = 2 * time.Second +) + +var ping = func(conn *websocket.Conn) error { + return conn.WriteMessage(websocket.PingMessage, nil) +} + +var writeBytes = func(conn *websocket.Conn, data []byte) error { + return conn.WriteMessage(websocket.TextMessage, data) +} + +type client struct { + conn *websocket.Conn + onClose func(*client) + write chan []byte + id uint + once once +} + +func newClient(conn *websocket.Conn, id uint, onClose func(*client)) *client { + return &client{ + conn: conn, + write: make(chan []byte, 1), + id: id, + onClose: onClose, + } +} + +// Close closes the connection. +func (c *client) Close() { + c.once.Do(func() { + c.conn.Close() + close(c.write) + }) +} + +// NotifyClose closes the connection and notifies that the connection was closed. +func (c *client) NotifyClose() { + c.once.Do(func() { + c.conn.Close() + close(c.write) + c.onClose(c) + }) +} + +// startWriteHandler starts listening on the client connection. +// As we do not need anything from the client, +// we ignore incoming messages. Leaves the loop on errors. +func (c *client) startReading(pongWait time.Duration) { + defer c.NotifyClose() + c.conn.SetReadLimit(64) + _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(appData string) error { + _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + for { + if _, _, err := c.conn.NextReader(); err != nil { + printWebSocketError("ReadError", err) + return + } + } +} + +// startWriteHandler starts the write loop. The method has the following tasks: +// * ping the client in the interval provided as parameter +// * write messages send by the channel to the client +// * on errors exit the loop. +func (c *client) startWriteHandler(pingPeriod time.Duration) { + pingTicker := time.NewTicker(pingPeriod) + defer func() { + c.NotifyClose() + pingTicker.Stop() + }() + + for { + select { + case message, ok := <-c.write: + if !ok { + return + } + + _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := writeBytes(c.conn, message); err != nil { + printWebSocketError("WriteError", err) + return + } + case <-pingTicker.C: + _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ping(c.conn); err != nil { + printWebSocketError("PingError", err) + return + } + } + } +} + +func printWebSocketError(prefix string, err error) { + var closeError *websocket.CloseError + ok := errors.As(err, &closeError) + + if ok && closeError != nil && (closeError.Code == 1000 || closeError.Code == 1001) { + // normal closure + return + } + + fmt.Println("WebSocket:", prefix, err) +} diff --git a/src/server/stream/hub.go b/src/server/stream/hub.go new file mode 100644 index 0000000..11541cc --- /dev/null +++ b/src/server/stream/hub.go @@ -0,0 +1 @@ +package stream diff --git a/src/server/stream/hub_test.go b/src/server/stream/hub_test.go new file mode 100644 index 0000000..11541cc --- /dev/null +++ b/src/server/stream/hub_test.go @@ -0,0 +1 @@ +package stream diff --git a/src/server/stream/once.go b/src/server/stream/once.go new file mode 100644 index 0000000..2df2523 --- /dev/null +++ b/src/server/stream/once.go @@ -0,0 +1,38 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stream + +import ( + "sync" + "sync/atomic" +) + +// Modified version of sync.Once +// (https://github.com/golang/go/blob/master/src/sync/once.go) +// This version unlocks the mutex early and therefore doesn't +// hold the lock while executing func f(). +type once struct { + m sync.Mutex + done uint32 +} + +func (o *once) Do(f func()) { + if atomic.LoadUint32(&o.done) == 1 { + return + } + if o.mayExecute() { + f() + } +} + +func (o *once) mayExecute() bool { + o.m.Lock() + defer o.m.Unlock() + if o.done == 0 { + atomic.StoreUint32(&o.done, 1) + return true + } + return false +} diff --git a/src/server/stream/once_test.go b/src/server/stream/once_test.go new file mode 100644 index 0000000..53ec08d --- /dev/null +++ b/src/server/stream/once_test.go @@ -0,0 +1,43 @@ +package stream + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Execute(t *testing.T) { + executeOnce := once{} + execution := make(chan struct{}) + fExecute := func() { + execution <- struct{}{} + } + go executeOnce.Do(fExecute) + go executeOnce.Do(fExecute) + + select { + case <-execution: + // expected + case <-time.After(100 * time.Millisecond): + t.Fatal("fExecute should be executed once") + } + + select { + case <-execution: + t.Fatal("should only execute once") + case <-time.After(100 * time.Millisecond): + // expected + } + + assert.False(t, executeOnce.mayExecute()) + + go executeOnce.Do(fExecute) + + select { + case <-execution: + t.Fatal("should only execute once") + case <-time.After(100 * time.Millisecond): + // expected + } +} diff --git a/src/server/stream/stream.go b/src/server/stream/stream.go new file mode 100644 index 0000000..6133eeb --- /dev/null +++ b/src/server/stream/stream.go @@ -0,0 +1,187 @@ +package stream + +import ( + "net/http" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "code.thetadev.de/TSGRain/SEBRAUC/src/util" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +// The API provides a handler for a WebSocket stream API. +type API struct { + clients map[uint]*client + lock sync.RWMutex + pingPeriod time.Duration + pongTimeout time.Duration + upgrader *websocket.Upgrader + counter *util.Counter +} + +// New creates a new instance of API. +// pingPeriod: is the interval, in which is server sends the a ping to the client. +// pongTimeout: is the duration after the connection will be terminated, +// when the client does not respond with the pong command. +func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string) *API { + return &API{ + clients: make(map[uint]*client), + pingPeriod: pingPeriod, + pongTimeout: pingPeriod + pongTimeout, + upgrader: newUpgrader(allowedWebSocketOrigins), + counter: &util.Counter{}, + } +} + +// NotifyDeletedUser closes existing connections for the given user. +func (a *API) NotifyDeletedClient(userID uint) error { + a.lock.Lock() + defer a.lock.Unlock() + if client, ok := a.clients[userID]; ok { + client.Close() + delete(a.clients, userID) + } + return nil +} + +// Notify notifies the clients with the given userID that a new messages was created. +func (a *API) Notify(userID uint, msg []byte) { + a.lock.RLock() + defer a.lock.RUnlock() + if client, ok := a.clients[userID]; ok { + client.write <- msg + } +} + +func (a *API) Broadcast(msg []byte) { + a.lock.RLock() + defer a.lock.RUnlock() + for _, client := range a.clients { + client.write <- msg + } +} + +func (a *API) remove(remove *client) { + a.lock.Lock() + defer a.lock.Unlock() + delete(a.clients, remove.id) +} + +func (a *API) register(client *client) { + a.lock.Lock() + defer a.lock.Unlock() + a.clients[client.id] = client +} + +// Handle handles incoming requests. +// First it upgrades the protocol to the WebSocket protocol and then starts listening +// for read and writes. +// swagger:operation GET /stream message streamMessages +// +// Websocket, return newly created messages. +// +// --- +// schema: ws, wss +// produces: [application/json] +// security: [clientTokenHeader: [], clientTokenQuery: [], basicAuth: []] +// responses: +// 200: +// description: Ok +// schema: +// $ref: "#/definitions/Message" +// 400: +// description: Bad Request +// schema: +// $ref: "#/definitions/Error" +// 401: +// description: Unauthorized +// schema: +// $ref: "#/definitions/Error" +// 403: +// description: Forbidden +// schema: +// $ref: "#/definitions/Error" +// 500: +// description: Server Error +// schema: +// $ref: "#/definitions/Error" +func (a *API) Handle(ctx *gin.Context) { + conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil) + if err != nil { + ctx.Error(err) + return + } + + client := newClient(conn, a.counter.Increment(), a.remove) + a.register(client) + go client.startReading(a.pongTimeout) + go client.startWriteHandler(a.pingPeriod) +} + +// Close closes all client connections and stops answering new connections. +func (a *API) Close() { + a.lock.Lock() + defer a.lock.Unlock() + + for _, client := range a.clients { + client.Close() + } + for k := range a.clients { + delete(a.clients, k) + } +} + +func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool { + origin := r.Header.Get("origin") + if origin == "" { + return true + } + + u, err := url.Parse(origin) + if err != nil { + return false + } + + if strings.EqualFold(u.Host, r.Host) { + return true + } + + for _, allowedOrigin := range allowedOrigins { + if allowedOrigin.Match([]byte(strings.ToLower(u.Hostname()))) { + return true + } + } + + return false +} + +func newUpgrader(allowedWebSocketOrigins []string) *websocket.Upgrader { + // compiledAllowedOrigins := compileAllowedWebSocketOrigins(allowedWebSocketOrigins) + return &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + /* + TODO: implement this + if mode.IsDev() { + return true + } + return isAllowedOrigin(r, compiledAllowedOrigins) + */ + return true + }, + } +} + +func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp { + var compiledAllowedOrigins []*regexp.Regexp + for _, origin := range allowedOrigins { + compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin)) + } + + return compiledAllowedOrigins +} diff --git a/src/server/stream/stream_test.go b/src/server/stream/stream_test.go new file mode 100644 index 0000000..fd779b1 --- /dev/null +++ b/src/server/stream/stream_test.go @@ -0,0 +1,424 @@ +package stream + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/fortytw2/leaktest" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestFailureOnNormalHttpRequest(t *testing.T) { + // mode.Set(mode.TestDev) + + defer leaktest.Check(t)() + + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + resp, err := http.Get(server.URL) + assert.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + resp.Body.Close() +} + +func TestWriteMessageFails(t *testing.T) { + // mode.Set(mode.TestDev) + oldWrite := writeBytes + // try emulate an write error, mostly this should kill the ReadMessage + // goroutine first but you'll never know. + writeBytes = func(conn *websocket.Conn, data []byte) error { + return errors.New("asd") + } + defer func() { + writeBytes = oldWrite + }() + defer leaktest.Check(t)() + + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + user := testClient(t, wsURL) + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + client := getClient(api, 1) + assert.NotNil(t, client) + + api.Notify(1, []byte("HI")) + user.expectNoMessage() +} + +func TestWritePingFails(t *testing.T) { + // mode.Set(mode.TestDev) + oldPing := ping + // try emulate an write error, mostly this should kill the ReadMessage + // gorouting first but you'll never know. + ping = func(conn *websocket.Conn) error { + return errors.New("asd") + } + defer func() { + ping = oldPing + }() + + defer leaktest.CheckTimeout(t, 10*time.Second)() + + server, api := bootTestServer() + defer api.Close() + defer server.Close() + + wsURL := wsURL(server.URL) + user := testClient(t, wsURL) + defer user.conn.Close() + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + client := getClient(api, 1) + + assert.NotNil(t, client) + + time.Sleep(api.pingPeriod) // waiting for ping + + api.Notify(1, []byte("HI")) + user.expectNoMessage() +} + +func TestPing(t *testing.T) { + // mode.Set(mode.TestDev) + + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + + user := createClient(t, wsURL) + defer user.conn.Close() + + ping := make(chan bool) + oldPingHandler := user.conn.PingHandler() + user.conn.SetPingHandler(func(appData string) error { + err := oldPingHandler(appData) + ping <- true + return err + }) + + startReading(user) + + expectNoMessage(user) + + select { + case <-time.After(2 * time.Second): + assert.Fail(t, "Expected ping but there was one :(") + case <-ping: + // expected + } + + expectNoMessage(user) + api.Notify(1, []byte("HI")) + user.expectMessage([]byte("HI")) +} + +func TestCloseClientOnNotReading(t *testing.T) { + // mode.Set(mode.TestDev) + + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + assert.Nil(t, err) + resp.Body.Close() + defer ws.Close() + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + assert.NotNil(t, getClient(api, 1)) + + time.Sleep(api.pingPeriod + api.pongTimeout) + + assert.Nil(t, getClient(api, 1)) +} + +func TestMessageDirectlyAfterConnect(t *testing.T) { + // mode.Set(mode.Prod) + defer leaktest.Check(t)() + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + + user := testClient(t, wsURL) + defer user.conn.Close() + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + api.Notify(1, []byte("msg")) + user.expectMessage([]byte("msg")) +} + +func TestDeleteClientShouldCloseConnection(t *testing.T) { + // mode.Set(mode.Prod) + defer leaktest.Check(t)() + server, api := bootTestServer() + defer server.Close() + defer api.Close() + + wsURL := wsURL(server.URL) + + user := testClient(t, wsURL) + defer user.conn.Close() + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + api.Notify(1, []byte("HI")) + user.expectMessage([]byte("HI")) + + assert.Nil(t, api.NotifyDeletedClient(1)) + + api.Notify(1, []byte("HI")) + user.expectNoMessage() +} + +func TestNotify(t *testing.T) { + // mode.Set(mode.TestDev) + + defer leaktest.Check(t)() + server, api := bootTestServer() + defer server.Close() + + wsURL := wsURL(server.URL) + + client1 := testClient(t, wsURL) + defer client1.conn.Close() + + client2 := testClient(t, wsURL) + defer client2.conn.Close() + + client3 := testClient(t, wsURL) + defer client3.conn.Close() + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + + api.Notify(1, []byte("msg")) + expectMessage([]byte("msg"), client1) + expectNoMessage(client2) + expectNoMessage(client3) + + assert.Nil(t, api.NotifyDeletedClient(1)) + + api.Notify(1, []byte("msg")) + expectNoMessage(client1) + expectNoMessage(client2) + expectNoMessage(client3) + + api.Notify(2, []byte("msg")) + expectNoMessage(client1) + expectMessage([]byte("msg"), client2) + expectNoMessage(client3) + + api.Notify(3, []byte("msg")) + expectNoMessage(client1) + expectNoMessage(client2) + expectMessage([]byte("msg"), client3) + + api.Close() +} + +func TestBroadcast(t *testing.T) { + defer leaktest.Check(t)() + server, api := bootTestServer() + defer server.Close() + + wsURL := wsURL(server.URL) + + client1 := testClient(t, wsURL) + defer client1.conn.Close() + + client2 := testClient(t, wsURL) + defer client2.conn.Close() + + client3 := testClient(t, wsURL) + defer client3.conn.Close() + + // the server may take some time to register the client + time.Sleep(100 * time.Millisecond) + + testMsg1 := []byte("hello1") + api.Broadcast(testMsg1) + expectMessage(testMsg1, client1, client2, client3) + + assert.Nil(t, api.NotifyDeletedClient(1)) + + testMsg2 := []byte("hello2") + api.Broadcast(testMsg2) + expectNoMessage(client1) + expectMessage(testMsg2, client2, client3) +} + +func Test_sameOrigin_returnsTrue(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://example.com") + actual := isAllowedOrigin(req, nil) + assert.True(t, actual) +} + +func Test_sameOrigin_returnsTrue_withCustomPort(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com:8080/stream", nil) + req.Header.Set("Origin", "http://example.com:8080") + actual := isAllowedOrigin(req, nil) + assert.True(t, actual) +} + +func Test_isAllowedOrigin_withoutAllowedOrigins_failsWhenNotSameOrigin(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://gorify.example.com") + actual := isAllowedOrigin(req, nil) + assert.False(t, actual) +} + +func Test_isAllowedOriginMatching(t *testing.T) { + // mode.Set(mode.Prod) + compiledAllowedOrigins := compileAllowedWebSocketOrigins( + []string{"go.{4}\\.example\\.com", "go\\.example\\.com"}, + ) + + req := httptest.NewRequest("GET", "http://example.me/stream", nil) + req.Header.Set("Origin", "http://gorify.example.com") + assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) + + req.Header.Set("Origin", "http://go.example.com") + assert.True(t, isAllowedOrigin(req, compiledAllowedOrigins)) + + req.Header.Set("Origin", "http://hello.example.com") + assert.False(t, isAllowedOrigin(req, compiledAllowedOrigins)) +} + +func Test_emptyOrigin_returnsTrue(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + actual := isAllowedOrigin(req, nil) + assert.True(t, actual) +} + +func Test_otherOrigin_returnsFalse(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http://otherexample.de") + actual := isAllowedOrigin(req, nil) + assert.False(t, actual) +} + +func Test_invalidOrigin_returnsFalse(t *testing.T) { + // mode.Set(mode.Prod) + req := httptest.NewRequest("GET", "http://example.com/stream", nil) + req.Header.Set("Origin", "http\\://otherexample.de") + actual := isAllowedOrigin(req, nil) + assert.False(t, actual) +} + +func Test_compileAllowedWebSocketOrigins(t *testing.T) { + assert.Equal(t, 0, len(compileAllowedWebSocketOrigins([]string{}))) + assert.Equal(t, 3, len(compileAllowedWebSocketOrigins([]string{"^.*$", "", "abc"}))) +} + +func getClient(api *API, user uint) *client { + api.lock.RLock() + defer api.lock.RUnlock() + + return api.clients[user] +} + +func testClient(t *testing.T, url string) *testingClient { + client := createClient(t, url) + startReading(client) + return client +} + +func startReading(client *testingClient) { + go func() { + for { + _, payload, err := client.conn.ReadMessage() + if err != nil { + return + } + + client.readMessage <- payload + } + }() +} + +func createClient(t *testing.T, url string) *testingClient { + ws, resp, err := websocket.DefaultDialer.Dial(url, nil) + assert.Nil(t, err) + resp.Body.Close() + + readMessages := make(chan []byte) + + return &testingClient{conn: ws, readMessage: readMessages, t: t} +} + +type testingClient struct { + conn *websocket.Conn + readMessage chan []byte + t *testing.T +} + +func (c *testingClient) expectMessage(expected []byte) { + select { + case <-time.After(50 * time.Millisecond): + assert.Fail(c.t, "Expected message but none was send :(") + case actual := <-c.readMessage: + assert.Equal(c.t, expected, actual) + } +} + +func expectMessage(expected []byte, clients ...*testingClient) { + for _, client := range clients { + client.expectMessage(expected) + } +} + +func expectNoMessage(clients ...*testingClient) { + for _, client := range clients { + client.expectNoMessage() + } +} + +func (c *testingClient) expectNoMessage() { + select { + case <-time.After(50 * time.Millisecond): + // no message == as expected + case msg := <-c.readMessage: + assert.Fail(c.t, "Expected NO message but there was one :(", fmt.Sprint(msg)) + } +} + +func bootTestServer() (*httptest.Server, *API) { + r := gin.New() + // ping every 500 ms, and the client has 500 ms to respond + api := New(500*time.Millisecond, 500*time.Millisecond, []string{}) + + r.GET("/", api.Handle) + server := httptest.NewServer(r) + return server, api +} + +func wsURL(httpURL string) string { + return "ws" + strings.TrimPrefix(httpURL, "http") +} diff --git a/src/util/counter.go b/src/util/counter.go new file mode 100644 index 0000000..9e264ee --- /dev/null +++ b/src/util/counter.go @@ -0,0 +1,30 @@ +package util + +import "sync" + +type Counter struct { + count uint + mutex sync.RWMutex +} + +func (c *Counter) Get() uint { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.count +} + +func (c *Counter) Reset() { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.count = 0 +} + +func (c *Counter) Increment() uint { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.count++ + return c.count +} diff --git a/src/util/counter_test.go b/src/util/counter_test.go new file mode 100644 index 0000000..bfc3e9d --- /dev/null +++ b/src/util/counter_test.go @@ -0,0 +1,30 @@ +package util + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + counter := Counter{} + + var wg sync.WaitGroup + + incrementer := func() { + for i := 0; i < 50; i++ { + counter.Increment() + } + wg.Done() + } + + for i := 0; i < 100; i++ { + wg.Add(1) + go incrementer() + } + + wg.Wait() + + assert.EqualValues(t, 5000, counter.Get()) +} diff --git a/ui/index.html b/ui/index.html index cf876c6..426ad0f 100644 --- a/ui/index.html +++ b/ui/index.html @@ -7,7 +7,11 @@ SEBRAUC +
+ diff --git a/ui/src/components/Updater/UpdaterCard.tsx b/ui/src/components/Updater/UpdaterCard.tsx index 6873897..9a38b95 100644 --- a/ui/src/components/Updater/UpdaterCard.tsx +++ b/ui/src/components/Updater/UpdaterCard.tsx @@ -7,8 +7,9 @@ import Icon from "../Icon/Icon" import "./Updater.scss" import Alert from "./Alert" import Reboot from "./Reboot" -import {sebraucApi, wsUrl} from "../../util/apiUrls" +import {sebraucApi} from "../../util/apiUrls" import colors from "../../util/colors" +import WebsocketClient from "../../util/websocket" class UploadStatus { uploading = false @@ -52,19 +53,19 @@ type State = { export default class UpdaterCard extends Component { private dropzoneRef = createRef() - private conn: WebSocket | undefined + private ws: WebsocketClient constructor(props?: Props | undefined, context?: any) { super(props, context) + this.ws = new WebsocketClient(this.onWsStatusUpdate, this.onWsMessage) + this.state = { uploadStatus: new UploadStatus(false), uploadFilename: "", raucStatus: new RaucStatus(), - wsConnected: false, + wsConnected: this.ws.api().isConnected(), } - - this.connectWebsocket() } private buttonClick = () => { @@ -105,31 +106,16 @@ export default class UpdaterCard extends Component { this.dropzoneRef.current?.reset() } - private connectWebsocket = () => { - if (window.WebSocket) { - this.conn = new WebSocket(wsUrl) - this.conn.onopen = () => { - this.setState({wsConnected: true}) - console.log("WS connected") - } - this.conn.onclose = () => { - this.setState({wsConnected: false}) - console.log("WS connection closed") - window.setTimeout(this.connectWebsocket, 3000) - } - this.conn.onmessage = (evt) => { - var messages = evt.data.split("\n") - for (var i = 0; i < messages.length; i++) { - this.setState({ - raucStatus: Object.assign( - new RaucStatus(), - JSON.parse(messages[i]) - ), - }) - } - } - } else { - console.log("Your browser does not support WebSockets") + private onWsStatusUpdate = (wsConnected: boolean) => { + this.setState({wsConnected: wsConnected}) + } + + private onWsMessage = (evt: MessageEvent) => { + var messages = evt.data.split("\n") + for (var i = 0; i < messages.length; i++) { + this.setState({ + raucStatus: Object.assign(new RaucStatus(), JSON.parse(messages[i])), + }) } } @@ -167,6 +153,10 @@ export default class UpdaterCard extends Component { return 0 } + componentWillUnmount() { + this.ws.destroy() + } + render() { const acceptUploads = this.acceptUploads() const circleColor = this.circleColor() diff --git a/ui/src/components/app.tsx b/ui/src/components/app.tsx index b068bc3..c2bd70c 100644 --- a/ui/src/components/app.tsx +++ b/ui/src/components/app.tsx @@ -1,14 +1,14 @@ import {Component} from "preact" import UpdaterView from "./Updater/UpdaterView" import logo from "../assets/logo.svg" -import {version} from "../util/version" +import {getConfig} from "../util/config" export default class App extends Component { render() { return (
SEBRAUC - {version} + {getConfig().version}
) diff --git a/ui/src/util/config.ts b/ui/src/util/config.ts new file mode 100644 index 0000000..1e0e124 --- /dev/null +++ b/ui/src/util/config.ts @@ -0,0 +1,23 @@ +export interface Config { + version: string +} + +// eslint-disable-next-line @typescript-eslint/no-unused-vars +declare global { + interface Window { + config?: any + } +} + +function isConfig(object: any): object is Config { + return typeof object === "object" && "version" in object +} + +export function getConfig(): Config { + if (isConfig(window.config)) { + return window.config + } + return { + version: "dev", + } +} diff --git a/ui/src/util/version.ts b/ui/src/util/version.ts deleted file mode 100644 index 9991968..0000000 --- a/ui/src/util/version.ts +++ /dev/null @@ -1,7 +0,0 @@ -let version = import.meta.env.VITE_VERSION - -if (version === undefined) { - version = "unknown" -} - -export {version} diff --git a/ui/src/util/websocket.ts b/ui/src/util/websocket.ts new file mode 100644 index 0000000..8d75673 --- /dev/null +++ b/ui/src/util/websocket.ts @@ -0,0 +1,92 @@ +import {wsUrl} from "./apiUrls" + +class WebsocketAPI { + private static ws: WebsocketAPI | undefined + + private conn: WebSocket | undefined + private wsConnected: boolean + + private clients: Set + + private constructor() { + this.clients = new Set() + this.wsConnected = false + + if (window.WebSocket) { + this.connect() + } else { + console.log("Your browser does not support WebSockets") + } + } + + private setStatus(wsConnected: boolean) { + if (wsConnected !== this.wsConnected) { + this.wsConnected = wsConnected + this.clients.forEach((client) => { + client.statusCallback(this.wsConnected) + }) + } + } + + private connect() { + this.conn = new WebSocket(wsUrl) + this.conn.onopen = () => { + this.setStatus(true) + console.log("WS connected") + } + this.conn.onclose = () => { + this.setStatus(false) + console.log("WS connection closed") + window.setTimeout(() => this.connect(), 3000) + } + this.conn.onmessage = (evt) => { + this.clients.forEach((client) => { + client.msgCallback(evt) + }) + } + } + + static Get(): WebsocketAPI { + if (this.ws === undefined) { + this.ws = new WebsocketAPI() + } + return this.ws + } + + isConnected(): boolean { + return this.wsConnected + } + + addClient(client: WebsocketClient) { + console.log("added client", client) + this.clients.add(client) + } + + removeClient(client: WebsocketClient) { + console.log("removed client", client) + this.clients.delete(client) + } +} + +export default class WebsocketClient { + statusCallback: (wsConnected: boolean) => void + msgCallback: (evt: MessageEvent) => void + + constructor( + statusCallback: (wsConnected: boolean) => void, + msgCallback: (evt: MessageEvent) => void + ) { + this.statusCallback = statusCallback + this.msgCallback = msgCallback + + this.api().addClient(this) + } + + api(): WebsocketAPI { + return WebsocketAPI.Get() + } + + destroy() { + this.api().removeClient(this) + } +} diff --git a/ui/ui.go b/ui/ui.go index 1d28f93..a10fda6 100644 --- a/ui/ui.go +++ b/ui/ui.go @@ -1,10 +1,64 @@ package ui import ( + "bytes" "embed" + "encoding/json" + "io/fs" + "net/http" + + "code.thetadev.de/TSGRain/SEBRAUC/src/util" + "code.thetadev.de/TSGRain/ginzip" + "github.com/gin-gonic/gin" ) -const AssetsDir = "dist" +const distDir = "dist" //go:embed dist/** -var Assets embed.FS +var assets embed.FS + +type uiConfig struct { + Version string `json:"version"` +} + +func subFS(fsys fs.FS, dir string) fs.FS { + sub, err := fs.Sub(fsys, dir) + if err != nil { + panic(err) + } + return sub +} + +func distFS() fs.FS { + return subFS(assets, distDir) +} + +func Register(r *gin.Engine) { + indexHandler := getIndexHandler() + + ui := r.Group("/", ginzip.New(ginzip.DefaultOptions())) + + ui.GET("/", indexHandler) + ui.GET("/index.html", indexHandler) + + ui.StaticFS("/assets", http.FS(subFS(distFS(), "assets"))) +} + +func getIndexHandler() gin.HandlerFunc { + content, err := fs.ReadFile(distFS(), "index.html") + if err != nil { + panic(err) + } + + uiConfigBytes, err := json.Marshal(uiConfig{ + Version: util.Version(), + }) + if err != nil { + panic(err) + } + content = bytes.ReplaceAll(content, []byte("\"%CONFIG%\""), uiConfigBytes) + + return func(c *gin.Context) { + c.Data(200, "text/html", content) + } +}