Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- TestOIDCAuthenticationPingAll
- TestOIDCExpireNodesBasedOnTokenExpiry
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand
Expand Down
12 changes: 12 additions & 0 deletions config-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,18 @@ unix_socket_permission: "0770"
# allowed_users:
# - [email protected]
#
# # Optional: PKCE (Proof Key for Code Exchange) configuration
# # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
# # by preventing authorization code interception attacks
# # See https://datatracker.ietf.org/doc/html/rfc7636
# pkce:
# # Enable or disable PKCE support (default: false)
# enabled: false
# # PKCE method to use:
# # - plain: Use plain code verifier
# # - S256: Use SHA256 hashed code verifier (default, recommended)
# method: S256
#
# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
# # by taking the username from the legacy user and matching it with the username
# # provided by the OIDC. This is useful when migrating from legacy users to OIDC
Expand Down
12 changes: 12 additions & 0 deletions docs/ref/oidc.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ oidc:
allowed_users:
- [email protected]

# Optional: PKCE (Proof Key for Code Exchange) configuration
# PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
# by preventing authorization code interception attacks
# See https://datatracker.ietf.org/doc/html/rfc7636
pkce:
# Enable or disable PKCE support (default: false)
enabled: false
# PKCE method to use:
# - plain: Use plain code verifier
# - S256: Use SHA256 hashed code verifier (default, recommended)
method: S256

# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# This will transform `[email protected]` to the user `first-name.last-name`
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
Expand Down
68 changes: 53 additions & 15 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ import (
)

const (
randomByteSize = 16
randomByteSize = 16
defaultOAuthOptionsCount = 3
)

var (
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
errOIDCAllowedDomains = errors.New(
"authenticated principal does not match any allowed domain",
)
Expand All @@ -47,11 +49,17 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
)

// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier *string
}

type AuthProviderOIDC struct {
serverURL string
cfg *types.OIDCConfig
db *db.HSDatabase
registrationCache *zcache.Cache[string, key.MachinePublic]
registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
polMan policy.PolicyManager
Expand Down Expand Up @@ -87,7 +95,7 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}

registrationCache := zcache.New[string, key.MachinePublic](
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down Expand Up @@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler(

stateStr := hex.EncodeToString(randomBlob)[:32]

// place the node key into the state cache, so it can be retrieved later
a.registrationCache.Set(
stateStr,
machineKey,
)
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
MachineKey: machineKey,
}

// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
// Add PKCE verification if enabled
if a.cfg.PKCE.Enabled {
verifier := oauth2.GenerateVerifier()
registrationInfo.Verifier = &verifier

extras = append(extras, oauth2.AccessTypeOffline)

switch a.cfg.PKCE.Method {
case types.PKCEMethodS256:
extras = append(extras, oauth2.S256ChallengeOption(verifier))
case types.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
}
}

// Add any extra parameters from configuration
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}

// Cache the registration info
a.registrationCache.Set(stateStr, registrationInfo)

authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)

Expand Down Expand Up @@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}

idToken, err := a.extractIDToken(req.Context(), code)
idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
return
Expand Down Expand Up @@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest(
func (a *AuthProviderOIDC) extractIDToken(
ctx context.Context,
code string,
state string,
) (*oidc.IDToken, error) {
oauth2Token, err := a.oauth2Config.Exchange(ctx, code)
var exchangeOpts []oauth2.AuthCodeOption

if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, errNoOIDCRegistrationInfo
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
}
}

oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err)
}
Expand Down Expand Up @@ -394,7 +432,7 @@ func validateOIDCAllowedUsers(
// cache. If the machine key is found, it will try retrieve the
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
machineKey, ok := a.registrationCache.Get(state)
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, nil
}
Expand All @@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := a.db.GetNodeByMachineKey(machineKey)
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)

return node, &machineKey
return node, &regInfo.MachineKey
}

// reauthenticateNode updates the node expiry in the database
Expand Down
28 changes: 28 additions & 0 deletions hscontrol/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ import (
const (
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
maxDuration time.Duration = 1<<63 - 1
PKCEMethodPlain string = "plain"
PKCEMethodS256 string = "S256"
)

var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
)

type IPAllocationStrategy string
Expand Down Expand Up @@ -162,6 +165,11 @@ type LetsEncryptConfig struct {
ChallengeType string
}

type PKCEConfig struct {
Enabled bool
Method string
}

type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool
Issuer string
Expand All @@ -176,6 +184,7 @@ type OIDCConfig struct {
Expiry time.Duration
UseExpiryFromToken bool
MapLegacyUsers bool
PKCE PKCEConfig
}

type DERPConfig struct {
Expand Down Expand Up @@ -226,6 +235,13 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int
}

func validatePKCEMethod(method string) error {
if method != PKCEMethodPlain && method != PKCEMethodS256 {
return errInvalidPKCEMethod
}
return nil
}

// LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options.
Expand Down Expand Up @@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.map_legacy_users", true)
viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256")

viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false)
Expand Down Expand Up @@ -340,6 +358,12 @@ func validateServerConfig() error {
// after #2170 is cleaned up
// depr.fatal("oidc.strip_email_domain")

if viper.GetBool("oidc.enabled") {
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
return err
}
}

depr.Log()

for _, removed := range []string{
Expand Down Expand Up @@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) {
// after #2170 is cleaned up
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
PKCE: PKCEConfig{
Enabled: viper.GetBool("oidc.pkce.enabled"),
Method: viper.GetString("oidc.pkce.method"),
},
},

LogTail: logTailConfig,
Expand Down
Loading