diff --git a/packages/stores-macro/src/derive.rs b/packages/stores-macro/src/derive.rs index d30c757e37..aa79b21747 100644 --- a/packages/stores-macro/src/derive.rs +++ b/packages/stores-macro/src/derive.rs @@ -60,7 +60,7 @@ fn derive_store_struct( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - let (extension_impl_generics, extension_generics, extension_where_clause) = + let (extension_impl_generics, extension_ty_generics, extension_where_clause) = extension_generics.split_for_impl(); // We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed struct @@ -85,7 +85,7 @@ fn derive_store_struct( let definition = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where Self: ::std::marker::Copy; + ) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy; }; definitions.push(definition); let field_names = fields @@ -108,7 +108,7 @@ fn derive_store_struct( let implementation = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where Self: ::std::marker::Copy { + ) -> #transposed_name #extension_ty_generics where Self: ::std::marker::Copy { // Convert each field into the corresponding store #( let #field_names = self.#field_names(); @@ -119,17 +119,15 @@ fn derive_store_struct( implementations.push(implementation); // Generate the transposed struct definition - let transposed_struct = match &structure.fields { - Fields::Named(_) => { - quote! { #visibility struct #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_fields),*} } - } - Fields::Unnamed(_) => { - quote! { #visibility struct #transposed_name #extension_impl_generics (#(#transposed_fields),*) #extension_where_clause; } - } - Fields::Unit => { - quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause;} - } - }; + let transposed_struct = transposed_struct( + visibility, + struct_name, + &transposed_name, + structure, + generics, + &extension_generics, + &transposed_fields, + ); // Expand to the extension trait and its implementation for the store alongside the transposed struct Ok(quote! { @@ -139,7 +137,7 @@ fn derive_store_struct( )* } - impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause { + impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#struct_name #ty_generics, __Lens> #extension_where_clause { #( #implementations )* @@ -149,6 +147,72 @@ fn derive_store_struct( }) } +fn field_type_generic(field: &Field, generics: &syn::Generics) -> bool { + generics.type_params().any(|param| { + matches!(&field.ty, syn::Type::Path(type_path) if type_path.path.is_ident(¶m.ident)) + }) +} + +fn transposed_struct( + visibility: &syn::Visibility, + struct_name: &Ident, + transposed_name: &Ident, + structure: &DataStruct, + generics: &syn::Generics, + extension_generics: &syn::Generics, + transposed_fields: &[TokenStream2], +) -> TokenStream2 { + let (extension_impl_generics, _, extension_where_clause) = extension_generics.split_for_impl(); + // Only use a type alias if: + // - There are no bounds on the type generics + // - All fields are generic types + let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty()) + && structure + .fields + .iter() + .all(|field| field_type_generic(field, generics)); + if use_type_alias { + let generics = transpose_generics(struct_name, generics); + return quote! {#visibility type #transposed_name #extension_impl_generics = #struct_name #generics;}; + } + match &structure.fields { + Fields::Named(fields) => { + let fields = fields.named.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + let ident = &f.ident; + let colon = f.colon_token.as_ref(); + quote! { #vis #ident #colon #t } + }); + quote! { + #visibility struct #transposed_name #extension_impl_generics #extension_where_clause { + #( + #fields + ),* + } + } + } + Fields::Unnamed(fields) => { + let fields = fields.unnamed.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + quote! { #vis #t } + }); + quote! { + #visibility struct #transposed_name #extension_impl_generics ( + #( + #fields + ),* + ) + #extension_where_clause; + } + } + Fields::Unit => { + quote! {#visibility struct #transposed_name #extension_impl_generics #extension_where_clause} + } + } +} + fn generate_field_methods( field_index: usize, field: &syn::Field, @@ -158,9 +222,7 @@ fn generate_field_methods( definitions: &mut Vec, implementations: &mut Vec, ) { - let vis = &field.vis; let field_name = &field.ident; - let colon = field.colon_token.as_ref(); // When we map the field, we need to use either the field name for named fields or the index for unnamed fields. let field_accessor = field_name.as_ref().map_or_else( @@ -171,7 +233,7 @@ fn generate_field_methods( let field_type = &field.ty; let store_type = mapped_type(struct_name, ty_generics, field_type); - transposed_fields.push(quote! { #vis #field_name #colon #store_type }); + transposed_fields.push(store_type.clone()); // Each field gets its own reactive scope within the child based on the field's index let ordinal = LitInt::new(&field_index.to_string(), field.span()); @@ -218,7 +280,7 @@ fn derive_store_enum( let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); - let (extension_impl_generics, extension_generics, extension_where_clause) = + let (extension_impl_generics, extension_ty_generics, extension_where_clause) = extension_generics.split_for_impl(); // We collect the definitions and implementations for the extension trait methods along with the types of the fields in the transposed enum @@ -249,14 +311,11 @@ fn derive_store_enum( let mut transposed_field_selectors = Vec::new(); let fields = &variant.fields; for (i, field) in fields.iter().enumerate() { - let vis = &field.vis; - let field_name = &field.ident; - let colon = field.colon_token.as_ref(); let field_type = &field.ty; let store_type = mapped_type(enum_name, &ty_generics, field_type); // Push the field for the transposed enum - transposed_fields.push(quote! { #vis #field_name #colon #store_type }); + transposed_fields.push(store_type.clone()); // Generate the code to get Store from the enum let select_field = select_enum_variant_field( @@ -321,11 +380,31 @@ fn derive_store_enum( // Push the type definition of the variant to the transposed enum let transposed_variant = match &fields { - Fields::Named(_) => { - quote! { #variant_name {#(#transposed_fields),*} } + Fields::Named(named) => { + let fields = named.named.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + let ident = &f.ident; + let colon = f.colon_token.as_ref(); + quote! { #vis #ident #colon #t } + }); + quote! { #variant_name { + #( + #fields + ),* + } } } - Fields::Unnamed(_) => { - quote! { #variant_name (#(#transposed_fields),*) } + Fields::Unnamed(unnamed) => { + let fields = unnamed.unnamed.iter(); + let fields = fields.zip(transposed_fields.iter()).map(|(f, t)| { + let vis = &f.vis; + quote! { #vis #t } + }); + quote! { #variant_name ( + #( + #fields + ),* + ) } } Fields::Unit => { quote! { #variant_name } @@ -337,13 +416,13 @@ fn derive_store_enum( let definition = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy; + ) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy; }; definitions.push(definition); let implementation = quote! { fn transpose( self, - ) -> #transposed_name #extension_generics where #readable_bounds, Self: ::std::marker::Copy { + ) -> #transposed_name #extension_ty_generics where #readable_bounds, Self: ::std::marker::Copy { // We only do a shallow read of the store to get the current variant. We only need to rerun // this match when the variant changes, not when the fields change self.selector().track_shallow(); @@ -358,7 +437,23 @@ fn derive_store_enum( }; implementations.push(implementation); - let transposed_enum = quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} }; + // Only use a type alias if: + // - There are no bounds on the type generics + // - All fields are generic types + let use_type_alias = generics.type_params().all(|param| param.bounds.is_empty()) + && structure + .variants + .iter() + .flat_map(|variant| variant.fields.iter()) + .all(|field| field_type_generic(field, generics)); + + let transposed_enum = if use_type_alias { + let generics = transpose_generics(enum_name, generics); + + quote! {#visibility type #transposed_name #extension_generics = #enum_name #generics;} + } else { + quote! { #visibility enum #transposed_name #extension_impl_generics #extension_where_clause {#(#transposed_variants),*} } + }; // Expand to the extension trait and its implementation for the store alongside the transposed enum Ok(quote! { @@ -368,7 +463,7 @@ fn derive_store_enum( )* } - impl #extension_impl_generics #extension_trait_name #extension_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause { + impl #extension_impl_generics #extension_trait_name #extension_ty_generics for dioxus_stores::Store<#enum_name #ty_generics, __Lens> #extension_where_clause { #( #implementations )* @@ -491,3 +586,31 @@ fn mapped_type( let write_type = quote! { dioxus_stores::macro_helpers::dioxus_signals::MappedMutSignal<#field_type, __Lens, fn(&#item #ty_generics) -> &#field_type, fn(&mut #item #ty_generics) -> &mut #field_type> }; quote! { dioxus_stores::Store<#field_type, #write_type> } } + +/// Take the generics from the original type with only generic fields into the generics for the transposed type +fn transpose_generics(name: &Ident, generics: &syn::Generics) -> TokenStream2 { + let (_, ty_generics, _) = generics.split_for_impl(); + let mut transposed_generics = generics.clone(); + let mut generics = Vec::new(); + for gen in transposed_generics.params.iter_mut() { + match gen { + // Map type generics into Store> + syn::GenericParam::Type(type_param) => { + let ident = &type_param.ident; + let ty = mapped_type(name, &ty_generics, &parse_quote!(#ident)); + generics.push(ty); + } + // Forward const and lifetime generics as-is + syn::GenericParam::Const(const_param) => { + let ident = &const_param.ident; + generics.push(quote! { #ident }); + } + syn::GenericParam::Lifetime(lt_param) => { + let ident = <_param.lifetime; + generics.push(quote! { #ident }); + } + } + } + + quote!(<#(#generics),*> ) +} diff --git a/packages/stores/tests/marco.rs b/packages/stores/tests/marco.rs index e972c43bd6..d177b0023e 100644 --- a/packages/stores/tests/marco.rs +++ b/packages/stores/tests/marco.rs @@ -138,6 +138,20 @@ mod macro_tests { store.check(); } + fn derive_generic_struct_transposed_passthrough() { + #[derive(Store)] + struct Item { + contents: T, + } + + let mut store = use_store(|| Item::<0, _> { + contents: "Learn about stores".to_string(), + }); + + let Item { contents } = store.transpose(); + let contents: String = contents(); + } + fn derive_tuple() { #[derive(Store, PartialEq, Clone, Debug)] struct Item(bool, String); @@ -342,4 +356,28 @@ mod macro_tests { } } } + + fn derive_generic_enum_transpose_passthrough() { + #[derive(Store, PartialEq, Clone, Debug)] + #[non_exhaustive] + enum Enum { + Foo, + Bar(T), + BarFoo { foo: T }, + } + + let mut store = use_store(|| Enum::<0, _>::Bar("Hello".to_string())); + + let transposed = store.transpose(); + use Enum::*; + match transposed { + Enum::Foo => {} + Bar(bar) => { + let bar: String = bar(); + } + BarFoo { foo } => { + let foo: String = foo(); + } + } + } }