Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
- TestOIDCAuthenticationWithPKCE
- TestOIDCReloginSameNodeNewUser
- TestOIDCFollowUpUrl
- TestOIDCMultipleOpenedLoginUrls
- TestOIDCReloginSameNodeSameUser
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndReloginSameUser
Expand Down
17 changes: 14 additions & 3 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}

cookieState, err := req.Cookie("state")
stateCookieName := getCookieName("state", state)
cookieState, err := req.Cookie(stateCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
return
Expand All @@ -235,8 +236,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
httpError(writer, err)
return
}
if idToken.Nonce == "" {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
return
}

nonce, err := req.Cookie("nonce")
nonceCookieName := getCookieName("nonce", idToken.Nonce)
nonce, err := req.Cookie(nonceCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
Expand Down Expand Up @@ -584,6 +590,11 @@ func renderOIDCCallbackTemplate(
return &content, nil
}

// getCookieName generates a unique cookie name based on a cookie value.
func getCookieName(baseName, value string) string {
return fmt.Sprintf("%s_%s", baseName, value[:6])
}

func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
val, err := util.GenerateRandomStringURLSafe(64)
if err != nil {
Expand All @@ -592,7 +603,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string,

c := &http.Cookie{
Path: "/oidc/callback",
Name: name,
Name: getCookieName(name, val),
Value: val,
MaxAge: int(time.Hour.Seconds()),
Secure: r.TLS != nil,
Expand Down
113 changes: 113 additions & 0 deletions integration/auth_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,119 @@ func TestOIDCFollowUpUrl(t *testing.T) {
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
}

// TestOIDCMultipleOpenedLoginUrls tests the scenario:
// - client (mostly Windows) opens multiple browser tabs with different login URLs
// - client performs auth on the first opened browser tab
//
// This test makes sure that cookies are still valid for the first browser tab.
func TestOIDCMultipleOpenedLoginUrls(t *testing.T) {
IntegrationSkip(t)

scenario, err := NewScenario(
ScenarioSpec{
OIDCUsers: []mockoidc.MockUser{
oidcMockUser("user1", true),
},
},
)

require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)

oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}

err = scenario.CreateHeadscaleEnvWithLoginURL(
nil,
hsic.WithTestName("oidcauthrelog"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
hsic.WithEmbeddedDERPServerOnly(),
)
require.NoError(t, err)

headscale, err := scenario.Headscale()
require.NoError(t, err)

listUsers, err := headscale.ListUsers()
require.NoError(t, err)
assert.Empty(t, listUsers)

ts, err := scenario.CreateTailscaleNode(
"unstable",
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
)
require.NoError(t, err)

u1, err := ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)

u2, err := ts.LoginWithURL(headscale.GetEndpoint())
require.NoError(t, err)

// make sure login URLs are different
require.NotEqual(t, u1.String(), u2.String())

loginClient, err := newLoginHTTPClient(ts.Hostname())
require.NoError(t, err)

// open the first login URL "in browser"
_, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false)
require.NoError(t, err)
// open the second login URL "in browser"
_, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false)
require.NoError(t, err)

// two valid redirects with different state/nonce params
require.NotEqual(t, redirect1.String(), redirect2.String())

// complete auth with the first opened "browser tab"
_, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true)
require.NoError(t, err)

listUsers, err = headscale.ListUsers()
require.NoError(t, err)
assert.Len(t, listUsers, 1)

wantUsers := []*v1.User{
{
Id: 1,
Name: "user1",
Email: "[email protected]",
Provider: "oidc",
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
},
}

sort.Slice(
listUsers, func(i, j int) bool {
return listUsers[i].GetId() < listUsers[j].GetId()
},
)

if diff := cmp.Diff(
wantUsers,
listUsers,
cmpopts.IgnoreUnexported(v1.User{}),
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
); diff != "" {
t.Fatalf("unexpected users: %s", diff)
}

assert.EventuallyWithT(
t, func(c *assert.CollectT) {
listNodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, 1)
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login",
)
}

// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
//
Expand Down
174 changes: 155 additions & 19 deletions integration/scenario.go
Original file line number Diff line number Diff line change
Expand Up @@ -860,47 +860,183 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}

// doLoginURL visits the given login URL and returns the body as a
// string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())
type debugJar struct {
inner *cookiejar.Jar
mu sync.RWMutex
store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie
}

func newDebugJar() (*debugJar, error) {
jar, err := cookiejar.New(nil)
if err != nil {
return nil, err
}
return &debugJar{
inner: jar,
store: make(map[string]map[string]map[string]*http.Cookie),
}, nil
}

func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
j.inner.SetCookies(u, cookies)

var err error
j.mu.Lock()
defer j.mu.Unlock()

for _, c := range cookies {
if c == nil || c.Name == "" {
continue
}
domain := c.Domain
if domain == "" {
domain = u.Hostname()
}
path := c.Path
if path == "" {
path = "/"
}
if _, ok := j.store[domain]; !ok {
j.store[domain] = make(map[string]map[string]*http.Cookie)
}
if _, ok := j.store[domain][path]; !ok {
j.store[domain][path] = make(map[string]*http.Cookie)
}
j.store[domain][path][c.Name] = copyCookie(c)
}
}

func (j *debugJar) Cookies(u *url.URL) []*http.Cookie {
return j.inner.Cookies(u)
}

func (j *debugJar) Dump(w io.Writer) {
j.mu.RLock()
defer j.mu.RUnlock()

for domain, paths := range j.store {
fmt.Fprintf(w, "Domain: %s\n", domain)
for path, byName := range paths {
fmt.Fprintf(w, " Path: %s\n", path)
for _, c := range byName {
fmt.Fprintf(
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite,
)
}
}
}
}

func copyCookie(c *http.Cookie) *http.Cookie {
cc := *c
return &cc
}

func newLoginHTTPClient(hostname string) (*http.Client, error) {
hc := &http.Client{
Transport: LoggingRoundTripper{Hostname: hostname},
}
hc.Jar, err = cookiejar.New(nil)

jar, err := newDebugJar()
if err != nil {
return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err)
}

hc.Jar = jar

return hc, nil
}

// doLoginURL visits the given login URL and returns the body as a string.
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
log.Printf("%s login url: %s\n", hostname, loginURL.String())

hc, err := newLoginHTTPClient(hostname)
if err != nil {
return "", err
}

body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true)
if err != nil {
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
return "", err
}

return body, nil
}

// doLoginURLWithClient performs the login request using the provided HTTP client.
// When followRedirects is false, it will return the first redirect without following it.
func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) (
string,
*url.URL,
error,
) {
if hc == nil {
return "", nil, fmt.Errorf("%s http client is nil", hostname)
}

if loginURL == nil {
return "", nil, fmt.Errorf("%s login url is nil", hostname)
}

log.Printf("%s logging in with url: %s", hostname, loginURL.String())
ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
if err != nil {
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
}

originalRedirect := hc.CheckRedirect
if !followRedirects {
hc.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
defer func() {
hc.CheckRedirect = originalRedirect
}()

resp, err := hc.Do(req)
if err != nil {
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err)
}
defer resp.Body.Close()

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
body := string(bodyBytes)

log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
var redirectURL *url.URL
if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest {
redirectURL, err = resp.Location()
if err != nil {
return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err)
}
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if followRedirects && resp.StatusCode != http.StatusOK {
log.Printf("body: %s", body)

return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
}

defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
log.Printf("body: %s", body)

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("%s failed to read response body: %s", hostname, err)
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
}

return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
if hc.Jar != nil {
if jar, ok := hc.Jar.(*debugJar); ok {
jar.Dump(os.Stdout)
} else {
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
}
}

return string(body), nil
return body, redirectURL, nil
}

var errParseAuthPage = errors.New("failed to parse auth page")
Expand Down
Loading