Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 3 additions & 18 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,32 +84,17 @@ var registerWebAPITemplate = template.Must(
`))

// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register/:nkey.
// Listens in /register/:nodeKey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.

func (h *Headscale) RegisterWebAPI(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]

if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

return
}
nodeKeyStr, ok := vars["nodeKey"]

// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
Expand Down
8 changes: 7 additions & 1 deletion app.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
"sync"
"syscall"
"time"
re "regexp"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
httpu "github.com/juanfont/headscale/http_utils"
"github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -447,7 +449,6 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
Expand All @@ -465,6 +466,11 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler)
}

regRouter := router.PathPrefix("/register").Subrouter()
//regRouter.Use(h.MachineKeySanitizeMiddleware)
regRouter.Use(httpu.CharWhitelistMiddlewareGenerator(re.MustCompile("[a-fA-F0-9]+"), "nodeKey", "invalid registration characters"))
regRouter.HandleFunc("/{nodeKey}", h.RegisterWebAPI).Methods(http.MethodGet)

apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(h.httpAuthenticationMiddleware)
apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP)
Expand Down
42 changes: 42 additions & 0 deletions http_utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package http_utils

import (
"net/http"
re "regexp"

"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
)

// CharWhitelistMiddlewareGenerator is an attempt to make it easier to add character whitelist checks to gorilla mux routes.
// Usage pattern is Route().Use(httpu.CharWhitelistMiddlewareGenerator(re.MustCompile(`re_str`), `keyName`, `logStr`))
// re_str: regular expression to compile and evaluate against
// keyName: name of the key that gorilla was told to capture during URL parsing
// logStr: message to print to log in the event of a whitelist failure

func CharWhitelistMiddlewareGenerator(matchExp *re.Regexp, keyName string, logStr string) mux.MiddlewareFunc {
return mux.MiddlewareFunc(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
toValidate := vars[keyName]

if !matchExp.Match([]byte(toValidate)) {
// Characters that are outside of the required set have been supplied, do not serve content.
log.Warn().Str("WhitelistValidateFail", toValidate).Msg("Failed whitelist validation: " + logStr)

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
} else {
// Allow processing of content to continue.
next.ServeHTTP(writer, req)
}
})
})
}
3 changes: 0 additions & 3 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -65,8 +64,6 @@ const (
ZstdCompression = "zstd"
)

var NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")

func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
}
Expand Down