Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions homeassistant/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ class Platform(StrEnum):
# Contains one string or a list of strings, each being an floor id
ATTR_FLOOR_ID: Final = "floor_id"

# Contains one string or a list of strings, each being an label id
ATTR_LABEL_ID: Final = "label_id"

# String with a friendly name for the entity
ATTR_FRIENDLY_NAME: Final = "friendly_name"

Expand Down
7 changes: 7 additions & 0 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ABOVE,
CONF_ALIAS,
CONF_ATTRIBUTE,
Expand Down Expand Up @@ -1220,6 +1221,9 @@ def platform_only_config_schema(domain: str) -> Callable[[dict], dict]:
vol.Optional(ATTR_FLOOR_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
vol.Optional(ATTR_LABEL_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
}

TARGET_SERVICE_FIELDS = {
Expand All @@ -1240,6 +1244,9 @@ def platform_only_config_schema(domain: str) -> Callable[[dict], dict]:
vol.Optional(ATTR_FLOOR_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
vol.Optional(ATTR_LABEL_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
}


Expand Down
50 changes: 45 additions & 5 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_FLOOR_ID,
ATTR_LABEL_ID,
CONF_ENTITY_ID,
CONF_SERVICE,
CONF_SERVICE_DATA,
Expand Down Expand Up @@ -55,6 +56,7 @@
device_registry,
entity_registry,
floor_registry,
label_registry,
template,
translation,
)
Expand Down Expand Up @@ -196,7 +198,7 @@ class ServiceParams(TypedDict):
class ServiceTargetSelector:
"""Class to hold a target selector for a service."""

__slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids")
__slots__ = ("entity_ids", "device_ids", "area_ids", "floor_ids", "label_ids")

def __init__(self, service_call: ServiceCall) -> None:
"""Extract ids from service call data."""
Expand All @@ -205,6 +207,7 @@ def __init__(self, service_call: ServiceCall) -> None:
device_ids: str | list | None = service_call_data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call_data.get(ATTR_AREA_ID)
floor_ids: str | list | None = service_call_data.get(ATTR_FLOOR_ID)
label_ids: str | list | None = service_call_data.get(ATTR_LABEL_ID)

self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
Expand All @@ -216,12 +219,19 @@ def __init__(self, service_call: ServiceCall) -> None:
self.floor_ids = (
set(cv.ensure_list(floor_ids)) if _has_match(floor_ids) else set()
)
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)

@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(
self.entity_ids or self.device_ids or self.area_ids or self.floor_ids
self.entity_ids
or self.device_ids
or self.area_ids
or self.floor_ids
or self.label_ids
)


Expand All @@ -232,14 +242,15 @@ class SelectedEntities:
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)

# Entities that were referenced via device/area/floor ID.
# Entities that were referenced via device/area/floor/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)

# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_floors: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)

# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
Expand All @@ -253,6 +264,7 @@ def log_missing(self, missing_entities: set[str]) -> None:
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
Expand Down Expand Up @@ -467,7 +479,7 @@ def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:


@bind_hass
def async_extract_referenced_entity_ids(
def async_extract_referenced_entity_ids( # noqa: C901
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
Expand All @@ -483,13 +495,19 @@ def async_extract_referenced_entity_ids(

selected.referenced.update(entity_ids)

if not selector.device_ids and not selector.area_ids and not selector.floor_ids:
if (
not selector.device_ids
and not selector.area_ids
and not selector.floor_ids
and not selector.label_ids
):
return selected

ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
floor_reg = floor_registry.async_get(hass)
label_reg = label_registry.async_get(hass)

for floor_id in selector.floor_ids:
if floor_id not in floor_reg.floors:
Expand All @@ -503,6 +521,28 @@ def async_extract_referenced_entity_ids(
if device_id not in dev_reg.devices:
selected.missing_devices.add(device_id)

for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)

# Find areas, devices & entities for targeted labels
if selector.label_ids:
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
selected.referenced_areas.add(area_entry.id)

for device_entry in dev_reg.devices.values():
if device_entry.labels.intersection(selector.label_ids):
selected.referenced_devices.add(device_entry.id)

for entity_entry in ent_reg.entities.values():
if (
entity_entry.entity_category is None
and entity_entry.hidden_by is None
and entity_entry.labels.intersection(selector.label_ids)
):
selected.indirectly_referenced.add(entity_entry.entity_id)

# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
Expand Down
Loading