Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 8 additions & 20 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,23 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}

db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
c.String(http.StatusInternalServerError, ":(")
return
}

var m Machine
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Println("New Machine!")
m = Machine{
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
}
if err := db.Create(&m).Error; err != nil {
if err := h.db.Create(&m).Error; err != nil {
log.Printf("Could not create row: %s", err)
return
}
}

if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, db, mKey, req, m)
h.handleAuthKey(c, h.db, mKey, req, m)
return
}

Expand Down Expand Up @@ -138,7 +131,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
log.Printf("[%s] We have the OldNodeKey in the database. This is a key refresh", m.Name)
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
db.Save(&m)
h.db.Save(&m)

resp.AuthURL = ""
resp.User = *m.Namespace.toUser()
Expand Down Expand Up @@ -204,13 +197,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
return
}

db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return
}
var m Machine
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString())
return
}
Expand All @@ -234,7 +222,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
db.Save(&m)
h.db.Save(&m)

pollData := make(chan []byte, 1)
update := make(chan []byte, 1)
Expand Down Expand Up @@ -303,7 +291,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
}
now := time.Now().UTC()
m.LastSeen = &now
db.Save(&m)
h.db.Save(&m)
return true

case <-update:
Expand All @@ -322,7 +310,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Printf("[%s] The client has closed the connection", m.Name)
now := time.Now().UTC()
m.LastSeen = &now
db.Save(&m)
h.db.Save(&m)
h.pollMu.Lock()
cancelKeepAlive <- []byte{}
delete(h.clientsPolling, m.ID)
Expand Down
11 changes: 4 additions & 7 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/gin-gonic/gin"
"golang.org/x/crypto/acme/autocert"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
Expand Down Expand Up @@ -43,6 +44,7 @@ type Config struct {
// Headscale represents the base app of the service
type Headscale struct {
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
Expand Down Expand Up @@ -87,6 +89,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
if err != nil {
return nil, err
}

h.clientsPolling = make(map[uint64]chan []byte)
return &h, nil
}
Expand All @@ -107,12 +110,6 @@ func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) {
}

func (h *Headscale) expireEphemeralNodesWorker() {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return
}

namespaces, err := h.ListNamespaces()
if err != nil {
log.Printf("Error listing namespaces: %s", err)
Expand All @@ -127,7 +124,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
for _, m := range *machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Printf("[%s] Ephemeral client removed from database\n", m.Name)
err = db.Unscoped().Delete(m).Error
err = h.db.Unscoped().Delete(m).Error
if err != nil {
log.Printf("[%s] 🤮 Cannot delete ephemeral machine from the database: %s", m.Name, err)
}
Expand Down
5 changes: 5 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,9 @@ func (s *Suite) ResetDB(c *check.C) {
if err != nil {
c.Fatal(err)
}
db, err := h.openDB()
if err != nil {
c.Fatal(err)
}
h.db = db
}
11 changes: 3 additions & 8 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package headscale

import (
"errors"
"log"

"gorm.io/gorm"
"tailscale.com/types/wgkey"
Expand All @@ -18,13 +17,9 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
if err != nil {
return nil, err
}
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}

m := Machine{}
if result := db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("Machine not found")
}

Expand All @@ -40,6 +35,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "cli"
db.Save(&m)
h.db.Save(&m)
return &m, nil
}
7 changes: 1 addition & 6 deletions cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)

db, err := h.db()
if err != nil {
c.Fatal(err)
}

m := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
Expand All @@ -21,7 +16,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
Name: "testmachine",
NamespaceID: n.ID,
}
db.Save(&m)
h.db.Save(&m)

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
Expand Down
23 changes: 9 additions & 14 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ type KV struct {
}

func (h *Headscale) initDB() error {
db, err := h.db()
db, err := h.openDB()
if err != nil {
return err
}
h.db = db

if h.dbType == "postgres" {
db.Exec("create extension if not exists \"uuid-ossp\";")
}
Expand All @@ -45,7 +47,7 @@ func (h *Headscale) initDB() error {
return err
}

func (h *Headscale) db() (*gorm.DB, error) {
func (h *Headscale) openDB() (*gorm.DB, error) {
var db *gorm.DB
var err error
switch h.dbType {
Expand All @@ -69,12 +71,8 @@ func (h *Headscale) db() (*gorm.DB, error) {
}

func (h *Headscale) getValue(key string) (string, error) {
db, err := h.db()
if err != nil {
return "", err
}
var row KV
if result := db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", errors.New("not found")
}
return row.Value, nil
Expand All @@ -85,16 +83,13 @@ func (h *Headscale) setValue(key string, value string) error {
Key: key,
Value: value,
}
db, err := h.db()
if err != nil {
return err
}
_, err = h.getValue(key)

_, err := h.getValue(key)
if err == nil {
db.Model(&kv).Where("key = ?", key).Update("value", value)
h.db.Model(&kv).Where("key = ?", key).Update("value", value)
return nil
}

db.Create(kv)
h.db.Create(kv)
return nil
}
7 changes: 1 addition & 6 deletions machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,9 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
}

func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}

machines := []Machine{}
if err = db.Where("namespace_id = ? AND machine_key <> ? AND registered",
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
log.Printf("Error accessing db: %s", err)
return nil, err
Expand Down
7 changes: 1 addition & 6 deletions machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)

db, err := h.db()
if err != nil {
c.Fatal(err)
}

_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)

Expand All @@ -30,7 +25,7 @@ func (s *Suite) TestGetMachine(c *check.C) {
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
db.Save(&m)
h.db.Save(&m)

m1, err := h.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
Expand Down
47 changes: 7 additions & 40 deletions namespaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,12 @@ type Namespace struct {
// CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}

n := Namespace{}
if err := db.Where("name = ?", name).First(&n).Error; err == nil {
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
return nil, errorNamespaceExists
}
n.Name = name
if err := db.Create(&n).Error; err != nil {
if err := h.db.Create(&n).Error; err != nil {
log.Printf("Could not create row: %s", err)
return nil, err
}
Expand All @@ -46,12 +40,6 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return err
}

n, err := h.GetNamespace(name)
if err != nil {
return errorNamespaceNotFound
Expand All @@ -65,7 +53,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
return errorNamespaceNotEmpty
}

if result := db.Unscoped().Delete(&n); result.Error != nil {
if result := h.db.Unscoped().Delete(&n); result.Error != nil {
return err
}

Expand All @@ -74,28 +62,17 @@ func (h *Headscale) DestroyNamespace(name string) error {

// GetNamespace fetches a namespace by name
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}

n := Namespace{}
if result := db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errorNamespaceNotFound
}
return &n, nil
}

// ListNamespaces gets all the existing namespaces
func (h *Headscale) ListNamespaces() (*[]Namespace, error) {
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}
namespaces := []Namespace{}
if err := db.Find(&namespaces).Error; err != nil {
if err := h.db.Find(&namespaces).Error; err != nil {
return nil, err
}
return &namespaces, nil
Expand All @@ -107,14 +84,9 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
if err != nil {
return nil, err
}
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return nil, err
}

machines := []Machine{}
if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
if err := h.db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
return nil, err
}
return &machines, nil
Expand All @@ -126,13 +98,8 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
if err != nil {
return err
}
db, err := h.db()
if err != nil {
log.Printf("Cannot open DB: %s", err)
return err
}
m.NamespaceID = n.ID
db.Save(&m)
h.db.Save(&m)
return nil
}

Expand Down
Loading