Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 154 additions & 31 deletions packages/stores-macro/src/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn derive_store_struct(

let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
let (extension_impl_generics, extension_generics, extension_where_clause) =
let (extension_impl_generics, extension_ty_generics, extension_where_clause) =
extension_generics.split_for_impl();

// We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed struct
Expand All @@ -85,7 +85,7 @@ fn derive_store_struct(
let definition = quote! {
fn transpose(
self,
) -> #transposed_name #extension_generics where Self: ::std::marker::Copy;
) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy;
};
definitions.push(definition);
let field_names = fields
Expand All @@ -108,7 +108,7 @@ fn derive_store_struct(
let implementation = quote! {
fn transpose(
self,
) -> #transposed_name #extension_generics where Self: ::std::marker::Copy {
) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy {
// Convert each field into the corresponding store
#(
let #field_names = self.#field_names();
Expand All @@ -119,17 +119,15 @@ fn derive_store_struct(
implementations.push(implementation);

// Generate the transposed struct definition
let transposed_struct = match &structure.fields {
Fields::Named(_) => {
quote! { #visibility struct #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_fields),*} }
}
Fields::Unnamed(_) => {
quote! { #visibility struct #transposed_name #extension_impl_generics (#(#transposed_fields),*) #extension_where_clause; }
}
Fields::Unit => {
quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause;}
}
};
let transposed_struct = transposed_struct(
visibility,
struct_name,
&transposed_name,
structure,
generics,
&extension_generics,
&transposed_fields,
);

// Expand to the extension trait and its implementation for the store alongside the transposed struct
Ok(quote! {
Expand All @@ -139,7 +137,7 @@ fn derive_store_struct(
)*
}

impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause {
impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause {
#(
#implementations
)*
Expand All @@ -149,6 +147,72 @@ fn derive_store_struct(
})
}

fn field_type_generic(field: &Field, generics: &syn::Generics) -> bool {
generics.type_params().any(|param| {
matches!(&field.ty, syn::Type::Path(type_path) if type_path.path.is_ident(&param.ident))
})
}

fn transposed_struct(
visibility: &syn::Visibility,
struct_name: &Ident,
transposed_name: &Ident,
structure: &DataStruct,
generics: &syn::Generics,
extension_generics: &syn::Generics,
transposed_fields: &[TokenStream2],
) -> TokenStream2 {
let (extension_impl_generics, _, extension_where_clause) = extension_generics.split_for_impl();
// Only use a type alias if:
// - There are no bounds on the type generics
// - All fields are generic types
let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty())
&& structure
.fields
.iter()
.all(|field| field_type_generic(field, generics));
if use_type_alias {
let generics = transpose_generics(struct_name, generics);
return quote! {#visibility type #transposed_name #extension_impl_generics = #struct_name #generics;};
}
match &structure.fields {
Fields::Named(fields) => {
let fields = fields.named.iter();
let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| {
let vis = &f.vis;
let ident = &f.ident;
let colon = f.colon_token.as_ref();
quote! { #vis #ident #colon #t }
});
quote! {
#visibility struct #transposed_name #extension_impl_generics #extension_where_clause {
#(
#fields
),*
}
}
}
Fields::Unnamed(fields) => {
let fields = fields.unnamed.iter();
let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| {
let vis = &f.vis;
quote! { #vis #t }
});
quote! {
#visibility struct #transposed_name #extension_impl_generics (
#(
#fields
),*
)
#extension_where_clause;
}
}
Fields::Unit => {
quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause}
}
}
}

fn generate_field_methods(
field_index: usize,
field: &syn::Field,
Expand All @@ -158,9 +222,7 @@ fn generate_field_methods(
definitions: &mut Vec<TokenStream2>,
implementations: &mut Vec<TokenStream2>,
) {
let vis = &field.vis;
let field_name = &field.ident;
let colon = field.colon_token.as_ref();

// When we map the field, we need to use either the field name for named fields or the index for unnamed fields.
let field_accessor = field_name.as_ref().map_or_else(
Expand All @@ -171,7 +233,7 @@ fn generate_field_methods(
let field_type = &field.ty;
let store_type = mapped_type(struct_name, ty_generics, field_type);

transposed_fields.push(quote! { #vis #field_name #colon #store_type });
transposed_fields.push(store_type.clone());

// Each field gets its own reactive scope within the child based on the field's index
let ordinal = LitInt::new(&field_index.to_string(), field.span());
Expand Down Expand Up @@ -218,7 +280,7 @@ fn derive_store_enum(

let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
let (extension_impl_generics, extension_generics, extension_where_clause) =
let (extension_impl_generics, extension_ty_generics, extension_where_clause) =
extension_generics.split_for_impl();

// We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed enum
Expand Down Expand Up @@ -249,14 +311,11 @@ fn derive_store_enum(
let mut transposed_field_selectors = Vec::new();
let fields = &variant.fields;
for (i, field) in fields.iter().enumerate() {
let vis = &field.vis;
let field_name = &field.ident;
let colon = field.colon_token.as_ref();
let field_type = &field.ty;
let store_type = mapped_type(enum_name, &ty_generics, field_type);

// Push the field for the transposed enum
transposed_fields.push(quote! { #vis #field_name #colon #store_type });
transposed_fields.push(store_type.clone());

// Generate the code to get Store<Field, W> from the enum
let select_field = select_enum_variant_field(
Expand Down Expand Up @@ -321,11 +380,31 @@ fn derive_store_enum(

// Push the type definition of the variant to the transposed enum
let transposed_variant = match &fields {
Fields::Named(_) => {
quote! { #variant_name {#(#transposed_fields),*} }
Fields::Named(named) => {
let fields = named.named.iter();
let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| {
let vis = &f.vis;
let ident = &f.ident;
let colon = f.colon_token.as_ref();
quote! { #vis #ident #colon #t }
});
quote! { #variant_name {
#(
#fields
),*
} }
}
Fields::Unnamed(_) => {
quote! { #variant_name (#(#transposed_fields),*) }
Fields::Unnamed(unnamed) => {
let fields = unnamed.unnamed.iter();
let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| {
let vis = &f.vis;
quote! { #vis #t }
});
quote! { #variant_name (
#(
#fields
),*
) }
}
Fields::Unit => {
quote! { #variant_name }
Expand All @@ -337,13 +416,13 @@ fn derive_store_enum(
let definition = quote! {
fn transpose(
self,
) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy;
) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy;
};
definitions.push(definition);
let implementation = quote! {
fn transpose(
self,
) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy {
) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy {
// We only do a shallow read of the store to get the current variant. We only need to rerun
// this match when the variant changes, not when the fields change
self.selector().track_shallow();
Expand All @@ -358,7 +437,23 @@ fn derive_store_enum(
};
implementations.push(implementation);

let transposed_enum = quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} };
// Only use a type alias if:
// - There are no bounds on the type generics
// - All fields are generic types
let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty())
&& structure
.variants
.iter()
.flat_map(|variant| variant.fields.iter())
.all(|field| field_type_generic(field, generics));

let transposed_enum = if use_type_alias {
let generics = transpose_generics(enum_name, generics);

quote! {#visibility type #transposed_name #extension_generics = #enum_name #generics;}
} else {
quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} }
};

// Expand to the extension trait and its implementation for the store alongside the transposed enum
Ok(quote! {
Expand All @@ -368,7 +463,7 @@ fn derive_store_enum(
)*
}

impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause {
impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause {
#(
#implementations
)*
Expand Down Expand Up @@ -491,3 +586,31 @@ fn mapped_type(
let write_type = quote! { dioxus_stores::macro_helpers::dioxus_signals::MappedMutSignal<#field_type, __Lens, fn(&#item #ty_generics) -> &#field_type, fn(&mut #item #ty_generics) -> &mut #field_type> };
quote! { dioxus_stores::Store<#field_type, #write_type> }
}

/// Take the generics from the original type with only generic fields into the generics for the transposed type
fn transpose_generics(name: &Ident, generics: &syn::Generics) -> TokenStream2 {
let (_, ty_generics, _) = generics.split_for_impl();
let mut transposed_generics = generics.clone();
let mut generics = Vec::new();
for gen in transposed_generics.params.iter_mut() {
match gen {
// Map type generics into Store<Type, MappedMutSignal<...>>
syn::GenericParam::Type(type_param) => {
let ident = &type_param.ident;
let ty = mapped_type(name, &ty_generics, &parse_quote!(#ident));
generics.push(ty);
}
// Forward const and lifetime generics as-is
syn::GenericParam::Const(const_param) => {
let ident = &const_param.ident;
generics.push(quote! { #ident });
}
syn::GenericParam::Lifetime(lt_param) => {
let ident = &lt_param.lifetime;
generics.push(quote! { #ident });
}
}
}

quote!(<#(#generics),*> )
}
38 changes: 38 additions & 0 deletions packages/stores/tests/marco.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,20 @@ mod macro_tests {
store.check();
}

fn derive_generic_struct_transposed_passthrough() {
#[derive(Store)]
struct Item<const COUNT: usize, T> {
contents: T,
}

let mut store = use_store(|| Item::<0, _> {
contents: "Learn about stores".to_string(),
});

let Item { contents } = store.transpose();
let contents: String = contents();
}

fn derive_tuple() {
#[derive(Store, PartialEq, Clone, Debug)]
struct Item(bool, String);
Expand Down Expand Up @@ -342,4 +356,28 @@ mod macro_tests {
}
}
}

fn derive_generic_enum_transpose_passthrough() {
#[derive(Store, PartialEq, Clone, Debug)]
#[non_exhaustive]
enum Enum<const COUNT: usize, T> {
Foo,
Bar(T),
BarFoo { foo: T },
}

let mut store = use_store(|| Enum::<0, _>::Bar("Hello".to_string()));

let transposed = store.transpose();
use Enum::*;
match transposed {
Enum::Foo => {}
Bar(bar) => {
let bar: String = bar();
}
BarFoo { foo } => {
let foo: String = foo();
}
}
}
}
Loading