Skip to content

Commit ddd31ba

Browse files
committed
hscontrol: use Updates() instead of Save() for partial updates
Changed UpdateUser and re-registration flows to use Updates() which only writes modified fields, preventing unintended overwrites of unchanged fields. Also updated UsePreAuthKey to use Model().Update() for single field updates and removed unused NodeSave wrapper.
1 parent 4a8dc2d commit ddd31ba

File tree

4 files changed

+146
-14
lines changed

4 files changed

+146
-14
lines changed

hscontrol/db/node.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,6 @@ func NodeSetMachineKey(
452452
}).Error
453453
}
454454

455-
// NodeSave saves a node object to the database, prefer to use a specific save method rather
456-
// than this. It is intended to be used when we are changing or.
457-
// TODO(kradalby): Remove this func, just use Save.
458-
func NodeSave(tx *gorm.DB, node *types.Node) error {
459-
return tx.Save(node).Error
460-
}
461-
462455
func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
463456
// Strip invalid DNS characters for givenName
464457
suppliedName = strings.ToLower(suppliedName)

hscontrol/db/preauth_keys.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,12 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
145145

146146
// UsePreAuthKey marks a PreAuthKey as used.
147147
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
148-
k.Used = true
149-
if err := tx.Save(k).Error; err != nil {
148+
err := tx.Model(k).Update("used", true).Error
149+
if err != nil {
150150
return fmt.Errorf("failed to update key used status in the database: %w", err)
151151
}
152152

153+
k.Used = true
153154
return nil
154155
}
155156

hscontrol/db/user_update_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package db
2+
3+
import (
4+
"database/sql"
5+
"testing"
6+
7+
"github.com/juanfont/headscale/hscontrol/types"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
"gorm.io/gorm"
11+
)
12+
13+
// TestUserUpdatePreservesUnchangedFields verifies that updating a user
14+
// preserves fields that aren't modified. This test validates the fix
15+
// for using Updates() instead of Save() in UpdateUser-like operations.
16+
func TestUserUpdatePreservesUnchangedFields(t *testing.T) {
17+
database := dbForTest(t)
18+
19+
// Create a user with all fields set
20+
initialUser := types.User{
21+
Name: "testuser",
22+
DisplayName: "Test User Display",
23+
24+
ProviderIdentifier: sql.NullString{
25+
String: "provider-123",
26+
Valid: true,
27+
},
28+
}
29+
30+
createdUser, err := database.CreateUser(initialUser)
31+
require.NoError(t, err)
32+
require.NotNil(t, createdUser)
33+
34+
// Verify initial state
35+
assert.Equal(t, "testuser", createdUser.Name)
36+
assert.Equal(t, "Test User Display", createdUser.DisplayName)
37+
assert.Equal(t, "[email protected]", createdUser.Email)
38+
assert.True(t, createdUser.ProviderIdentifier.Valid)
39+
assert.Equal(t, "provider-123", createdUser.ProviderIdentifier.String)
40+
41+
// Simulate what UpdateUser does: load user, modify one field, save
42+
_, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) {
43+
user, err := GetUserByID(tx, types.UserID(createdUser.ID))
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
// Modify ONLY DisplayName
49+
user.DisplayName = "Updated Display Name"
50+
51+
// This is the line being tested - currently uses Save() which writes ALL fields, potentially overwriting unchanged ones
52+
err = tx.Save(user).Error
53+
if err != nil {
54+
return nil, err
55+
}
56+
57+
return user, nil
58+
})
59+
require.NoError(t, err)
60+
61+
// Read user back from database
62+
updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) {
63+
return GetUserByID(rx, types.UserID(createdUser.ID))
64+
})
65+
require.NoError(t, err)
66+
67+
// Verify that DisplayName was updated
68+
assert.Equal(t, "Updated Display Name", updatedUser.DisplayName)
69+
70+
// CRITICAL: Verify that other fields were NOT overwritten
71+
// With Save(), these assertions should pass because the user object
72+
// was loaded from DB and has all fields populated.
73+
// But if Updates() is used, these will also pass (and it's safer).
74+
assert.Equal(t, "testuser", updatedUser.Name, "Name should be preserved")
75+
assert.Equal(t, "[email protected]", updatedUser.Email, "Email should be preserved")
76+
assert.True(t, updatedUser.ProviderIdentifier.Valid, "ProviderIdentifier should be preserved")
77+
assert.Equal(t, "provider-123", updatedUser.ProviderIdentifier.String, "ProviderIdentifier value should be preserved")
78+
}
79+
80+
// TestUserUpdateWithUpdatesMethod tests that using Updates() instead of Save()
81+
// works correctly and only updates modified fields.
82+
func TestUserUpdateWithUpdatesMethod(t *testing.T) {
83+
database := dbForTest(t)
84+
85+
// Create a user
86+
initialUser := types.User{
87+
Name: "testuser",
88+
DisplayName: "Original Display",
89+
90+
ProviderIdentifier: sql.NullString{
91+
String: "provider-abc",
92+
Valid: true,
93+
},
94+
}
95+
96+
createdUser, err := database.CreateUser(initialUser)
97+
require.NoError(t, err)
98+
99+
// Update using Updates() method
100+
_, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) {
101+
user, err := GetUserByID(tx, types.UserID(createdUser.ID))
102+
if err != nil {
103+
return nil, err
104+
}
105+
106+
// Modify multiple fields
107+
user.DisplayName = "New Display"
108+
user.Email = "[email protected]"
109+
110+
// Use Updates() instead of Save()
111+
err = tx.Updates(user).Error
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
return user, nil
117+
})
118+
require.NoError(t, err)
119+
120+
// Verify changes
121+
updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) {
122+
return GetUserByID(rx, types.UserID(createdUser.ID))
123+
})
124+
require.NoError(t, err)
125+
126+
// Verify updated fields
127+
assert.Equal(t, "New Display", updatedUser.DisplayName)
128+
assert.Equal(t, "[email protected]", updatedUser.Email)
129+
130+
// Verify preserved fields
131+
assert.Equal(t, "testuser", updatedUser.Name)
132+
assert.True(t, updatedUser.ProviderIdentifier.Valid)
133+
assert.Equal(t, "provider-abc", updatedUser.ProviderIdentifier.String)
134+
}

hscontrol/state/state.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,9 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
300300
return nil, err
301301
}
302302

303-
if err := tx.Save(user).Error; err != nil {
303+
// Use Updates() to only update modified fields, preserving unchanged values.
304+
err = tx.Updates(user).Error
305+
if err != nil {
304306
return nil, fmt.Errorf("updating user: %w", err)
305307
}
306308

@@ -1191,9 +1193,10 @@ func (s *State) HandleNodeFromAuthPath(
11911193
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
11921194
}
11931195

1194-
// Use the node from UpdateNode to save to database
11951196
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
1196-
if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil {
1197+
// Use Updates() to preserve fields not modified by UpdateNode.
1198+
err := tx.Updates(updatedNodeView.AsStruct()).Error
1199+
if err != nil {
11971200
return nil, fmt.Errorf("failed to save node: %w", err)
11981201
}
11991202
return nil, nil
@@ -1410,9 +1413,10 @@ func (s *State) HandleNodeFromPreAuthKey(
14101413
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
14111414
}
14121415

1413-
// Use the node from UpdateNode to save to database
14141416
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
1415-
if err := tx.Save(updatedNodeView.AsStruct()).Error; err != nil {
1417+
// Use Updates() to preserve fields not modified by UpdateNode.
1418+
err := tx.Updates(updatedNodeView.AsStruct()).Error
1419+
if err != nil {
14161420
return nil, fmt.Errorf("failed to save node: %w", err)
14171421
}
14181422

0 commit comments

Comments
 (0)