Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,29 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
// Fix the provider identifier for users that have a double slash in the
// provider identifier.
{
ID: "202505141324",
Migrate: func(tx *gorm.DB) error {
users, err := ListUsers(tx)
if err != nil {
return fmt.Errorf("listing users: %w", err)
}

for _, user := range users {
user.ProviderIdentifier.String = types.CleanIdentifier(user.ProviderIdentifier.String)

err := tx.Save(user).Error
if err != nil {
return fmt.Errorf("saving user: %w", err)
}
}

return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

Expand Down
113 changes: 108 additions & 5 deletions hscontrol/types/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,110 @@ type OIDCClaims struct {
Username string `json:"preferred_username,omitempty"`
}

// Identifier returns a unique identifier string combining the Iss and Sub claims.
// The format depends on whether Iss is a URL or not:
// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub")
// - For non-URLs: Joins with a slash (e.g., "oidc/sub")
// - For empty Iss: Returns just "sub"
// - For empty Sub: Returns just the Issuer
// - For both empty: Returns empty string
//
// The result is cleaned using CleanIdentifier() to ensure consistent formatting.
func (c *OIDCClaims) Identifier() string {
if strings.HasPrefix(c.Iss, "http") {
if i, err := url.JoinPath(c.Iss, c.Sub); err == nil {
return i
// Handle empty components special cases
if c.Iss == "" && c.Sub == "" {
return ""
}
if c.Iss == "" {
return CleanIdentifier(c.Sub)
}
if c.Sub == "" {
return CleanIdentifier(c.Iss)
}

// We'll use the raw values and let CleanIdentifier handle all the whitespace
issuer := c.Iss
subject := c.Sub

var result string
// Try to parse as URL to handle URL joining correctly
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
// For URLs, use proper URL path joining
if joined, err := url.JoinPath(issuer, subject); err == nil {
result = joined
}
}
return c.Iss + "/" + c.Sub

// If URL joining failed or issuer wasn't a URL, do simple string join
if result == "" {
// Default case: simple string joining with slash
issuer = strings.TrimSuffix(issuer, "/")
subject = strings.TrimPrefix(subject, "/")
result = issuer + "/" + subject
}

// Clean the result and return it
return CleanIdentifier(result)
}

// CleanIdentifier cleans a potentially malformed identifier by removing double slashes
// while preserving protocol specifications like http://. This function will:
// - Trim all whitespace from the beginning and end of the identifier
// - Remove whitespace within path segments
// - Preserve the scheme (http://, https://, etc.) for URLs
// - Remove any duplicate slashes in the path
// - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase
func CleanIdentifier(identifier string) string {
if identifier == "" {
return identifier
}

// Trim leading/trailing whitespace
identifier = strings.TrimSpace(identifier)

// Handle URLs with schemes
u, err := url.Parse(identifier)
if err == nil && u.Scheme != "" {
// Clean path by removing empty segments and whitespace within segments
parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' })
for i, part := range parts {
parts[i] = strings.TrimSpace(part)
}
// Remove empty parts after trimming
cleanParts := make([]string, 0, len(parts))
for _, part := range parts {
if part != "" {
cleanParts = append(cleanParts, part)
}
}

if len(cleanParts) == 0 {
u.Path = ""
} else {
u.Path = "/" + strings.Join(cleanParts, "/")
}
// Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme)
return u.String()
}

// Handle non-URL identifiers
parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' })
// Clean whitespace from each part
cleanParts := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
cleanParts = append(cleanParts, trimmed)
}
}
if len(cleanParts) == 0 {
return ""
}
return strings.Join(cleanParts, "/")
}

type OIDCUserInfo struct {
Expand Down Expand Up @@ -231,7 +328,13 @@ func (u *User) FromClaim(claims *OIDCClaims) {
}
}

u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
// Get provider identifier
identifier := claims.Identifier()
// Ensure provider identifier always has a leading slash for backward compatibility
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
identifier = "/" + identifier
}
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
u.DisplayName = claims.Name
u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC
Expand Down
213 changes: 213 additions & 0 deletions hscontrol/types/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
)

func TestUnmarshallOIDCClaims(t *testing.T) {
Expand Down Expand Up @@ -76,6 +77,218 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
}
}

func TestOIDCClaimsIdentifier(t *testing.T) {
tests := []struct {
name string
iss string
sub string
expected string
}{
{
name: "standard URL with trailing slash",
iss: "https://oidc.example.com/",
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
},
{
name: "standard URL without trailing slash",
iss: "https://oidc.example.com",
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
},
{
name: "standard URL with uppercase protocol",
iss: "HTTPS://oidc.example.com/",
sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
},
{
name: "standard URL with path and trailing slash",
iss: "https://login.microsoftonline.com/v2.0/",
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
},
{
name: "standard URL with path without trailing slash",
iss: "https://login.microsoftonline.com/v2.0",
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
},
{
name: "non-URL identifier with slash",
iss: "oidc",
sub: "sub",
expected: "oidc/sub",
},
{
name: "non-URL identifier with trailing slash",
iss: "oidc/",
sub: "sub",
expected: "oidc/sub",
},
{
name: "subject with slash",
iss: "oidc/",
sub: "sub/",
expected: "oidc/sub",
},
{
name: "whitespace",
iss: " oidc/ ",
sub: " sub ",
expected: "oidc/sub",
},
{
name: "newline",
iss: "\noidc/\n",
sub: "\nsub\n",
expected: "oidc/sub",
},
{
name: "tab",
iss: "\toidc/\t",
sub: "\tsub\t",
expected: "oidc/sub",
},
{
name: "empty issuer",
iss: "",
sub: "sub",
expected: "sub",
},
{
name: "empty subject",
iss: "https://oidc.example.com",
sub: "",
expected: "https://oidc.example.com",
},
{
name: "both empty",
iss: "",
sub: "",
expected: "",
},
{
name: "URL with double slash",
iss: "https://login.microsoftonline.com//v2.0",
sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
},
{
name: "FTP URL protocol",
iss: "ftp://example.com/directory",
sub: "resource",
expected: "ftp://example.com/directory/resource",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims := OIDCClaims{
Iss: tt.iss,
Sub: tt.sub,
}
result := claims.Identifier()
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
}

// Now clean the identifier and verify it's still the same
cleaned := CleanIdentifier(result)

// Double-check with cmp.Diff for better error messages
if diff := cmp.Diff(tt.expected, cleaned); diff != "" {
t.Errorf("CleanIdentifier(Identifier()) mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestCleanIdentifier(t *testing.T) {
tests := []struct {
name string
identifier string
expected string
}{
{
name: "empty identifier",
identifier: "",
expected: "",
},
{
name: "simple identifier",
identifier: "oidc/sub",
expected: "oidc/sub",
},
{
name: "double slashes in the middle",
identifier: "oidc//sub",
expected: "oidc/sub",
},
{
name: "trailing slash",
identifier: "oidc/sub/",
expected: "oidc/sub",
},
{
name: "multiple double slashes",
identifier: "oidc//sub///id//",
expected: "oidc/sub/id",
},
{
name: "HTTP URL with proper scheme",
identifier: "http://example.com/path",
expected: "http://example.com/path",
},
{
name: "HTTP URL with double slashes in path",
identifier: "http://example.com//path///resource",
expected: "http://example.com/path/resource",
},
{
name: "HTTPS URL with empty segments",
identifier: "https://example.com///path//",
expected: "https://example.com/path",
},
{
name: "URL with double slashes in domain",
identifier: "https://login.microsoftonline.com//v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
},
{
name: "FTP URL with double slashes",
identifier: "ftp://example.com//resource//",
expected: "ftp://example.com/resource",
},
{
name: "Just slashes",
identifier: "///",
expected: "",
},
{
name: "Leading slash without URL",
identifier: "/path//to///resource",
expected: "path/to/resource",
},
{
name: "Non-standard protocol",
identifier: "ldap://example.org//path//to//resource",
expected: "ldap://example.org/path/to/resource",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CleanIdentifier(tt.identifier)
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestOIDCClaimsJSONToUser(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading