Skip to content

Commit 5cd15c3

Browse files
bobelevkradalby
authored andcommitted
fix: make state cookies valid when client uses multiple login URLs
On Windows, if the user clicks the Tailscale icon in the system tray, it opens a login URL in the browser. When the login URL is opened, `state/nonce` cookies are set for that particular URL. If the user clicks the icon again, a new login URL is opened in the browser, and new cookies are set. If the user proceeds with auth in the first tab, the redirect results in a "state did not match" error. This patch ensures that each opened login URL sets an individual cookie that remains valid on the `/oidc/callback` page. `TestOIDCMultipleOpenedLoginUrls` illustrates and tests this behavior.
1 parent 2024219 commit 5cd15c3

File tree

4 files changed

+283
-22
lines changed

4 files changed

+283
-22
lines changed

.github/workflows/test-integration.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jobs:
3838
- TestOIDCAuthenticationWithPKCE
3939
- TestOIDCReloginSameNodeNewUser
4040
- TestOIDCFollowUpUrl
41+
- TestOIDCMultipleOpenedLoginUrls
4142
- TestOIDCReloginSameNodeSameUser
4243
- TestAuthWebFlowAuthenticationPingAll
4344
- TestAuthWebFlowLogoutAndReloginSameUser

hscontrol/oidc.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
213213
return
214214
}
215215

216-
cookieState, err := req.Cookie("state")
216+
stateCookieName := getCookieName("state", state)
217+
cookieState, err := req.Cookie(stateCookieName)
217218
if err != nil {
218219
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
219220
return
@@ -235,8 +236,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
235236
httpError(writer, err)
236237
return
237238
}
239+
if idToken.Nonce == "" {
240+
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
241+
return
242+
}
238243

239-
nonce, err := req.Cookie("nonce")
244+
nonceCookieName := getCookieName("nonce", idToken.Nonce)
245+
nonce, err := req.Cookie(nonceCookieName)
240246
if err != nil {
241247
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
242248
return
@@ -584,6 +590,11 @@ func renderOIDCCallbackTemplate(
584590
return &content, nil
585591
}
586592

593+
// getCookieName generates a unique cookie name based on a cookie value.
594+
func getCookieName(baseName, value string) string {
595+
return fmt.Sprintf("%s_%s", baseName, value[:6])
596+
}
597+
587598
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
588599
val, err := util.GenerateRandomStringURLSafe(64)
589600
if err != nil {
@@ -592,7 +603,7 @@ func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string,
592603

593604
c := &http.Cookie{
594605
Path: "/oidc/callback",
595-
Name: name,
606+
Name: getCookieName(name, val),
596607
Value: val,
597608
MaxAge: int(time.Hour.Seconds()),
598609
Secure: r.TLS != nil,

integration/auth_oidc_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,119 @@ func TestOIDCFollowUpUrl(t *testing.T) {
953953
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login")
954954
}
955955

956+
// TestOIDCMultipleOpenedLoginUrls tests the scenario:
957+
// - client (mostly Windows) opens multiple browser tabs with different login URLs
958+
// - client performs auth on the first opened browser tab
959+
//
960+
// This test makes sure that cookies are still valid for the first browser tab.
961+
func TestOIDCMultipleOpenedLoginUrls(t *testing.T) {
962+
IntegrationSkip(t)
963+
964+
scenario, err := NewScenario(
965+
ScenarioSpec{
966+
OIDCUsers: []mockoidc.MockUser{
967+
oidcMockUser("user1", true),
968+
},
969+
},
970+
)
971+
972+
require.NoError(t, err)
973+
defer scenario.ShutdownAssertNoPanics(t)
974+
975+
oidcMap := map[string]string{
976+
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
977+
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
978+
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
979+
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
980+
}
981+
982+
err = scenario.CreateHeadscaleEnvWithLoginURL(
983+
nil,
984+
hsic.WithTestName("oidcauthrelog"),
985+
hsic.WithConfigEnv(oidcMap),
986+
hsic.WithTLS(),
987+
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
988+
hsic.WithEmbeddedDERPServerOnly(),
989+
)
990+
require.NoError(t, err)
991+
992+
headscale, err := scenario.Headscale()
993+
require.NoError(t, err)
994+
995+
listUsers, err := headscale.ListUsers()
996+
require.NoError(t, err)
997+
assert.Empty(t, listUsers)
998+
999+
ts, err := scenario.CreateTailscaleNode(
1000+
"unstable",
1001+
tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]),
1002+
)
1003+
require.NoError(t, err)
1004+
1005+
u1, err := ts.LoginWithURL(headscale.GetEndpoint())
1006+
require.NoError(t, err)
1007+
1008+
u2, err := ts.LoginWithURL(headscale.GetEndpoint())
1009+
require.NoError(t, err)
1010+
1011+
// make sure login URLs are different
1012+
require.NotEqual(t, u1.String(), u2.String())
1013+
1014+
loginClient, err := newLoginHTTPClient(ts.Hostname())
1015+
require.NoError(t, err)
1016+
1017+
// open the first login URL "in browser"
1018+
_, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false)
1019+
require.NoError(t, err)
1020+
// open the second login URL "in browser"
1021+
_, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false)
1022+
require.NoError(t, err)
1023+
1024+
// two valid redirects with different state/nonce params
1025+
require.NotEqual(t, redirect1.String(), redirect2.String())
1026+
1027+
// complete auth with the first opened "browser tab"
1028+
_, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true)
1029+
require.NoError(t, err)
1030+
1031+
listUsers, err = headscale.ListUsers()
1032+
require.NoError(t, err)
1033+
assert.Len(t, listUsers, 1)
1034+
1035+
wantUsers := []*v1.User{
1036+
{
1037+
Id: 1,
1038+
Name: "user1",
1039+
1040+
Provider: "oidc",
1041+
ProviderId: scenario.mockOIDC.Issuer() + "/user1",
1042+
},
1043+
}
1044+
1045+
sort.Slice(
1046+
listUsers, func(i, j int) bool {
1047+
return listUsers[i].GetId() < listUsers[j].GetId()
1048+
},
1049+
)
1050+
1051+
if diff := cmp.Diff(
1052+
wantUsers,
1053+
listUsers,
1054+
cmpopts.IgnoreUnexported(v1.User{}),
1055+
cmpopts.IgnoreFields(v1.User{}, "CreatedAt"),
1056+
); diff != "" {
1057+
t.Fatalf("unexpected users: %s", diff)
1058+
}
1059+
1060+
assert.EventuallyWithT(
1061+
t, func(c *assert.CollectT) {
1062+
listNodes, err := headscale.ListNodes()
1063+
assert.NoError(c, err)
1064+
assert.Len(c, listNodes, 1)
1065+
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login",
1066+
)
1067+
}
1068+
9561069
// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client
9571070
// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user.
9581071
//

integration/scenario.go

Lines changed: 155 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -860,47 +860,183 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
860860
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
861861
}
862862

863-
// doLoginURL visits the given login URL and returns the body as a
864-
// string.
865-
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
866-
log.Printf("%s login url: %s\n", hostname, loginURL.String())
863+
type debugJar struct {
864+
inner *cookiejar.Jar
865+
mu sync.RWMutex
866+
store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie
867+
}
868+
869+
func newDebugJar() (*debugJar, error) {
870+
jar, err := cookiejar.New(nil)
871+
if err != nil {
872+
return nil, err
873+
}
874+
return &debugJar{
875+
inner: jar,
876+
store: make(map[string]map[string]map[string]*http.Cookie),
877+
}, nil
878+
}
879+
880+
func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
881+
j.inner.SetCookies(u, cookies)
867882

868-
var err error
883+
j.mu.Lock()
884+
defer j.mu.Unlock()
885+
886+
for _, c := range cookies {
887+
if c == nil || c.Name == "" {
888+
continue
889+
}
890+
domain := c.Domain
891+
if domain == "" {
892+
domain = u.Hostname()
893+
}
894+
path := c.Path
895+
if path == "" {
896+
path = "/"
897+
}
898+
if _, ok := j.store[domain]; !ok {
899+
j.store[domain] = make(map[string]map[string]*http.Cookie)
900+
}
901+
if _, ok := j.store[domain][path]; !ok {
902+
j.store[domain][path] = make(map[string]*http.Cookie)
903+
}
904+
j.store[domain][path][c.Name] = copyCookie(c)
905+
}
906+
}
907+
908+
func (j *debugJar) Cookies(u *url.URL) []*http.Cookie {
909+
return j.inner.Cookies(u)
910+
}
911+
912+
func (j *debugJar) Dump(w io.Writer) {
913+
j.mu.RLock()
914+
defer j.mu.RUnlock()
915+
916+
for domain, paths := range j.store {
917+
fmt.Fprintf(w, "Domain: %s\n", domain)
918+
for path, byName := range paths {
919+
fmt.Fprintf(w, " Path: %s\n", path)
920+
for _, c := range byName {
921+
fmt.Fprintf(
922+
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
923+
c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite,
924+
)
925+
}
926+
}
927+
}
928+
}
929+
930+
func copyCookie(c *http.Cookie) *http.Cookie {
931+
cc := *c
932+
return &cc
933+
}
934+
935+
func newLoginHTTPClient(hostname string) (*http.Client, error) {
869936
hc := &http.Client{
870937
Transport: LoggingRoundTripper{Hostname: hostname},
871938
}
872-
hc.Jar, err = cookiejar.New(nil)
939+
940+
jar, err := newDebugJar()
941+
if err != nil {
942+
return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err)
943+
}
944+
945+
hc.Jar = jar
946+
947+
return hc, nil
948+
}
949+
950+
// doLoginURL visits the given login URL and returns the body as a string.
951+
func doLoginURL(hostname string, loginURL *url.URL) (string, error) {
952+
log.Printf("%s login url: %s\n", hostname, loginURL.String())
953+
954+
hc, err := newLoginHTTPClient(hostname)
955+
if err != nil {
956+
return "", err
957+
}
958+
959+
body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true)
873960
if err != nil {
874-
return "", fmt.Errorf("%s failed to create cookiejar : %w", hostname, err)
961+
return "", err
962+
}
963+
964+
return body, nil
965+
}
966+
967+
// doLoginURLWithClient performs the login request using the provided HTTP client.
968+
// When followRedirects is false, it will return the first redirect without following it.
969+
func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) (
970+
string,
971+
*url.URL,
972+
error,
973+
) {
974+
if hc == nil {
975+
return "", nil, fmt.Errorf("%s http client is nil", hostname)
976+
}
977+
978+
if loginURL == nil {
979+
return "", nil, fmt.Errorf("%s login url is nil", hostname)
875980
}
876981

877982
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
878983
ctx := context.Background()
879-
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
984+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
985+
if err != nil {
986+
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
987+
}
988+
989+
originalRedirect := hc.CheckRedirect
990+
if !followRedirects {
991+
hc.CheckRedirect = func(req *http.Request, via []*http.Request) error {
992+
return http.ErrUseLastResponse
993+
}
994+
}
995+
defer func() {
996+
hc.CheckRedirect = originalRedirect
997+
}()
998+
880999
resp, err := hc.Do(req)
8811000
if err != nil {
882-
return "", fmt.Errorf("%s failed to send http request: %w", hostname, err)
1001+
return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err)
1002+
}
1003+
defer resp.Body.Close()
1004+
1005+
bodyBytes, err := io.ReadAll(resp.Body)
1006+
if err != nil {
1007+
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
8831008
}
1009+
body := string(bodyBytes)
8841010

885-
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
1011+
var redirectURL *url.URL
1012+
if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest {
1013+
redirectURL, err = resp.Location()
1014+
if err != nil {
1015+
return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err)
1016+
}
1017+
}
8861018

887-
if resp.StatusCode != http.StatusOK {
888-
body, _ := io.ReadAll(resp.Body)
1019+
if followRedirects && resp.StatusCode != http.StatusOK {
8891020
log.Printf("body: %s", body)
8901021

891-
return "", fmt.Errorf("%s response code of login request was %w", hostname, err)
1022+
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
8921023
}
8931024

894-
defer resp.Body.Close()
1025+
if resp.StatusCode >= http.StatusBadRequest {
1026+
log.Printf("body: %s", body)
8951027

896-
body, err := io.ReadAll(resp.Body)
897-
if err != nil {
898-
log.Printf("%s failed to read response body: %s", hostname, err)
1028+
return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode)
1029+
}
8991030

900-
return "", fmt.Errorf("%s failed to read response body: %w", hostname, err)
1031+
if hc.Jar != nil {
1032+
if jar, ok := hc.Jar.(*debugJar); ok {
1033+
jar.Dump(os.Stdout)
1034+
} else {
1035+
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
1036+
}
9011037
}
9021038

903-
return string(body), nil
1039+
return body, redirectURL, nil
9041040
}
9051041

9061042
var errParseAuthPage = errors.New("failed to parse auth page")

0 commit comments

Comments
 (0)