Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hscontrol/policy/v2/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func TestParsing(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol, err := policyFromBytes([]byte(tt.acl))
pol, err := unmarshalPolicy([]byte(tt.acl))
if tt.wantErr && err == nil {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

Expand Down
4 changes: 2 additions & 2 deletions hscontrol/policy/v2/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type PolicyManager struct {
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
policy, err := unmarshalPolicy(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
Expand Down Expand Up @@ -131,7 +131,7 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
return false, nil
}

pol, err := policyFromBytes(polB)
pol, err := unmarshalPolicy(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
Expand Down
56 changes: 50 additions & 6 deletions hscontrol/policy/v2/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ Please check the format and try again.`, vs)
type AliasEnc struct{ Alias }

func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
ptr, err := unmarshalPointer[Alias](
ptr, err := unmarshalPointer(
b,
parseAlias,
)
Expand Down Expand Up @@ -639,7 +639,7 @@ Please check the format and try again.`, s)
type AutoApproverEnc struct{ AutoApprover }

func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
ptr, err := unmarshalPointer[AutoApprover](
ptr, err := unmarshalPointer(
b,
parseAutoApprover,
)
Expand All @@ -659,7 +659,7 @@ type Owner interface {
type OwnerEnc struct{ Owner }

func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
ptr, err := unmarshalPointer[Owner](
ptr, err := unmarshalPointer(
b,
parseOwner,
)
Expand Down Expand Up @@ -769,6 +769,11 @@ func (h *Hosts) UnmarshalJSON(b []byte) error {
return nil
}

func (h Hosts) exist(name Host) bool {
_, ok := h[name]
return ok
}

// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
type TagOwners map[Tag]Owners

Expand Down Expand Up @@ -902,6 +907,39 @@ type Policy struct {
SSHs []SSH `json:"ssh"`
}

// validate reports if there are any errors in a policy after
// the unmarshaling process.
// It runs through all rules and checks if there are any inconsistencies
// in the policy that needs to be addressed before it can be used.
func (p *Policy) validate() error {
if p == nil {
panic("passed nil policy")
}

// All errors are collected and presented to the user,
// when adding more validation, please add to the list of errors.
var errs []error

for _, acl := range p.ACLs {
for _, src := range acl.Sources {
switch src.(type) {
case *Host:
h := src.(*Host)
if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
}
}
}
}

if len(errs) > 0 {
return multierr.New(errs...)
}

p.validated = true
return nil
}

// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"` // TODO(kradalby): add strict type
Expand Down Expand Up @@ -986,7 +1024,10 @@ func (u SSHUser) String() string {
return string(u)
}

func policyFromBytes(b []byte) (*Policy, error) {
// unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct.
// In addition to unmarshalling, it will also validate the policy.
// This is the only entrypoint of reading a policy from a file or other source.
func unmarshalPolicy(b []byte) (*Policy, error) {
if b == nil || len(b) == 0 {
return nil, nil
}
Expand All @@ -1000,11 +1041,14 @@ func policyFromBytes(b []byte) (*Policy, error) {
ast.Standardize()
acl := ast.Pack()

err = json.Unmarshal(acl, &policy)
if err != nil {
if err = json.Unmarshal(acl, &policy); err != nil {
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
}

if err := policy.validate(); err != nil {
return nil, err
}

return &policy, nil
}

Expand Down
61 changes: 60 additions & 1 deletion hscontrol/policy/v2/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,65 @@ func TestUnmarshalPolicy(t *testing.T) {
`,
wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`,
},
{
name: "undefined-hostname-errors-2490",
input: `
{
"acls": [
{
"action": "accept",
"src": [
"user1"
],
"dst": [
"user1:*"
]
}
]
}
`,
wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`,
},
{
name: "defined-hostname-does-not-err-2490",
input: `
{
"hosts": {
"user1": "100.100.100.100",
},
"acls": [
{
"action": "accept",
"src": [
"user1"
],
"dst": [
"user1:*"
]
}
]
}
`,
want: &Policy{
Hosts: Hosts{
"user1": Prefix(mp("100.100.100.100/32")),
},
ACLs: []ACL{
{
Action: "accept",
Sources: Aliases{
hp("user1"),
},
Destinations: []AliasWithPorts{
{
Alias: hp("user1"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
},
},
},
}

cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
Expand All @@ -370,7 +429,7 @@ func TestUnmarshalPolicy(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := policyFromBytes([]byte(tt.input))
policy, err := unmarshalPolicy([]byte(tt.input))
if tt.wantErr == "" {
if err != nil {
t.Fatalf("got %v; want no error", err)
Expand Down
Loading