diff --git a/Cargo.lock b/Cargo.lock index 0bfce99b8d..86eac0ebfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5850,9 +5850,11 @@ dependencies = [ "generational-box", "http 1.3.1", "inventory", + "parking_lot", "serde", "serde_json", "thiserror 2.0.17", + "tokio", "tracing", ] diff --git a/packages/fullstack-core/Cargo.toml b/packages/fullstack-core/Cargo.toml index 9fc6fae631..1947bacde8 100644 --- a/packages/fullstack-core/Cargo.toml +++ b/packages/fullstack-core/Cargo.toml @@ -29,6 +29,8 @@ inventory = { workspace = true } serde_json = { workspace = true } generational-box = { workspace = true } futures-util = { workspace = true, features = ["std"] } +tokio = { workspace = true, features = ["rt"] } +parking_lot = { workspace = true } [features] web = [] diff --git a/packages/fullstack-core/src/streaming.rs b/packages/fullstack-core/src/streaming.rs index 4689f5b88e..481db186e1 100644 --- a/packages/fullstack-core/src/streaming.rs +++ b/packages/fullstack-core/src/streaming.rs @@ -1,26 +1,52 @@ use crate::{HttpError, ServerFnError}; -use axum_core::{extract::FromRequest, response::IntoResponse}; -use dioxus_core::{try_consume_context, CapturedError}; -use dioxus_signals::{ReadableExt, Signal, WritableExt}; +use axum_core::extract::FromRequest; +use axum_core::response::IntoResponse; +use dioxus_core::{CapturedError, ReactiveContext}; use http::StatusCode; use http::{request::Parts, HeaderMap}; -use std::{cell::RefCell, rc::Rc}; +use parking_lot::RwLock; +use std::collections::HashSet; +use std::fmt::Debug; +use std::sync::Arc; + +tokio::task_local! { + static FULLSTACK_CONTEXT: FullstackContext; +} /// The context provided by dioxus fullstack for server-side rendering. /// -/// This context will only be set on the server during a streaming response. +/// This context will only be set on the server during the initial streaming response +/// and inside server functions. #[derive(Clone, Debug)] pub struct FullstackContext { - current_status: Signal, - request_headers: Rc>, - response_headers: Rc>>, - route_http_status: Signal, + // We expose the lock for request headers directly so it needs to be in a separate lock + request_headers: Arc>, + // The rest of the fields are only held internally, so we can group them together + lock: Arc>, +} + +pub struct FullstackContextInner { + current_status: StreamingStatus, + current_status_subscribers: HashSet, + response_headers: Option, + route_http_status: HttpError, + route_http_status_subscribers: HashSet, +} + +impl Debug for FullstackContextInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FullstackContextInner") + .field("current_status", &self.current_status) + .field("response_headers", &self.response_headers) + .field("route_http_status", &self.route_http_status) + .finish() + } } impl PartialEq for FullstackContext { fn eq(&self, other: &Self) -> bool { - self.current_status == other.current_status - && Rc::ptr_eq(&self.request_headers, &other.request_headers) + Arc::ptr_eq(&self.lock, &other.lock) + && Arc::ptr_eq(&self.request_headers, &other.request_headers) } } @@ -29,13 +55,18 @@ impl FullstackContext { /// provide this context for you. pub fn new(parts: Parts) -> Self { Self { - current_status: Signal::new(StreamingStatus::RenderingInitialChunk), - request_headers: Rc::new(RefCell::new(parts)), - route_http_status: Signal::new(HttpError { - status: http::StatusCode::OK, - message: None, - }), - response_headers: Rc::new(RefCell::new(Some(HeaderMap::new()))), + request_headers: RwLock::new(parts).into(), + lock: RwLock::new(FullstackContextInner { + current_status: StreamingStatus::RenderingInitialChunk, + current_status_subscribers: Default::default(), + route_http_status: HttpError { + status: http::StatusCode::OK, + message: None, + }, + route_http_status_subscribers: Default::default(), + response_headers: Some(HeaderMap::new()), + }) + .into(), } } @@ -45,30 +76,50 @@ impl FullstackContext { /// /// Once this method has been called, the http response parts can no longer be modified. pub fn commit_initial_chunk(&mut self) { - self.current_status - .set(StreamingStatus::InitialChunkCommitted); + let mut lock = self.lock.write(); + lock.current_status = StreamingStatus::InitialChunkCommitted; + // The key type is mutable, but the hash is stable through mutations because we hash by pointer + #[allow(clippy::mutable_key_type)] + let subscribers = std::mem::take(&mut lock.current_status_subscribers); + for subscriber in subscribers { + subscriber.mark_dirty(); + } } /// Get the current status of the streaming response. This method is reactive and will cause /// the current reactive context to rerun when the status changes. pub fn streaming_state(&self) -> StreamingStatus { - *self.current_status.read() + let mut lock = self.lock.write(); + // Register the current reactive context as a subscriber to changes in the streaming status + if let Some(ctx) = ReactiveContext::current() { + lock.current_status_subscribers.insert(ctx); + } + lock.current_status } /// Access the http request parts mutably. This will allow you to modify headers and other parts of the request. - pub fn parts_mut(&self) -> std::cell::RefMut<'_, http::request::Parts> { - self.request_headers.borrow_mut() + pub fn parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> { + self.request_headers.write() } - /// Extract an axum extractor from the current request. This will always use an empty body for the request, - /// since it's assumed that rendering the app is done under a `GET` request. + /// Run a future within the scope of this FullstackContext. + pub async fn scope(self, fut: F) -> R + where + F: std::future::Future, + { + FULLSTACK_CONTEXT.scope(self, fut).await + } + + /// Extract an axum extractor from the current request. + /// + /// The body of the request is always empty when using this method, as the body can only be consumed once in the server + /// function extractors. pub async fn extract, M>() -> Result { let this = Self::current() .ok_or_else(|| ServerFnError::new("No FullstackContext found".to_string()))?; - let parts = this.request_headers.borrow_mut().clone(); - let request = - axum_core::extract::Request::from_parts(parts, axum_core::body::Body::empty()); + let parts = this.request_headers.read().clone(); + let request = axum_core::extract::Request::from_parts(parts, Default::default()); match T::from_request(request, &()).await { Ok(res) => Ok(res), Err(err) => { @@ -79,8 +130,14 @@ impl FullstackContext { } /// Get the current `FullstackContext` if it exists. This will return `None` if called on the client - /// or outside of a streaming response on the server. + /// or outside of a streaming response on the server or server function. pub fn current() -> Option { + // Try to get the context from the task local (for server functions) + if let Ok(context) = FULLSTACK_CONTEXT.try_get() { + return Some(context); + } + + // Otherwise, try to get it from the dioxus runtime context (for streaming SSR) if let Some(rt) = dioxus_core::Runtime::try_current() { let id = rt.try_current_scope_id()?; if let Some(ctx) = rt.consume_context::(id) { @@ -94,11 +151,23 @@ impl FullstackContext { /// Get the current HTTP status for the route. This will default to 200 OK, but can be modified /// by calling `FullstackContext::commit_error_status` with an error. pub fn current_http_status(&self) -> HttpError { - self.route_http_status.read().clone() + let mut lock = self.lock.write(); + // Register the current reactive context as a subscriber to changes in the http status + if let Some(ctx) = ReactiveContext::current() { + lock.route_http_status_subscribers.insert(ctx); + } + lock.route_http_status.clone() } pub fn set_current_http_status(&mut self, status: HttpError) { - self.route_http_status.set(status); + let mut lock = self.lock.write(); + lock.route_http_status = status; + // The key type is mutable, but the hash is stable through mutations because we hash by pointer + #[allow(clippy::mutable_key_type)] + let subscribers = std::mem::take(&mut lock.route_http_status_subscribers); + for subscriber in subscribers { + subscriber.mark_dirty(); + } } /// Add a header to the response. This will be sent to the client when the response is committed. @@ -107,7 +176,8 @@ impl FullstackContext { key: impl Into, value: impl Into, ) { - if let Some(headers) = self.response_headers.borrow_mut().as_mut() { + let mut lock = self.lock.write(); + if let Some(headers) = lock.response_headers.as_mut() { headers.insert(key.into(), value.into()); } } @@ -115,7 +185,8 @@ impl FullstackContext { /// Take the response headers out of the context. This will leave the context without any headers, /// so it should only be called once when the response is being committed. pub fn take_response_headers(&self) -> Option { - self.response_headers.borrow_mut().take() + let mut lock = self.lock.write(); + lock.response_headers.take() } /// Set the current HTTP status for the route. This will be used when committing the response @@ -177,11 +248,18 @@ pub enum StreamingStatus { /// ``` pub fn commit_initial_chunk() { crate::history::finalize_route(); - if let Some(mut streaming) = try_consume_context::() { + if let Some(mut streaming) = FullstackContext::current() { streaming.commit_initial_chunk(); } } +/// Extract an axum extractor from the current request. +#[deprecated(note = "Use FullstackContext::extract instead", since = "0.7.0")] +pub fn extract, M>( +) -> impl std::future::Future> { + FullstackContext::extract::() +} + /// Get the current status of the streaming response. This method is reactive and will cause /// the current reactive context to rerun when the status changes. /// @@ -205,7 +283,7 @@ pub fn commit_initial_chunk() { /// } /// ``` pub fn current_status() -> StreamingStatus { - if let Some(streaming) = try_consume_context::() { + if let Some(streaming) = FullstackContext::current() { streaming.streaming_state() } else { StreamingStatus::InitialChunkCommitted diff --git a/packages/fullstack-server/src/serverfn.rs b/packages/fullstack-server/src/serverfn.rs index abfc306be6..fe9f8a5795 100644 --- a/packages/fullstack-server/src/serverfn.rs +++ b/packages/fullstack-server/src/serverfn.rs @@ -2,7 +2,7 @@ use axum::body::Body; use axum::routing::MethodRouter; use axum::Router; use dashmap::DashMap; -use dioxus_fullstack_core::DioxusServerState; +use dioxus_fullstack_core::{DioxusServerState, FullstackContext}; use http::Method; use std::{marker::PhantomData, sync::LazyLock}; @@ -76,9 +76,39 @@ impl ServerFunction { // } // } + async fn server_context_middleware( + request: axum::extract::Request, + next: axum::middleware::Next, + ) -> axum::response::Response { + let (parts, body) = request.into_parts(); + let server_context = FullstackContext::new(parts.clone()); + let request = axum::extract::Request::from_parts(parts, body); + + server_context + .scope(async move { + // Run the next middleware / handler inside the server context + let mut response = next.run(request).await; + + let server_context = FullstackContext::current().expect( + "Server context should be available inside the server context scope", + ); + + // Get the extra response headers set during the handler and add them to the response + let headers = server_context.take_response_headers(); + if let Some(headers) = headers { + response.headers_mut().extend(headers); + } + + response + }) + .await + } + router.route( self.path(), - ((self.handler)()).with_state(DioxusServerState {}), + ((self.handler)()) + .with_state(DioxusServerState {}) + .layer(axum::middleware::from_fn(server_context_middleware)), ) } } diff --git a/packages/fullstack/src/payloads/cbor.rs b/packages/fullstack/src/payloads/cbor.rs index 62d0b50d16..bd013ba583 100644 --- a/packages/fullstack/src/payloads/cbor.rs +++ b/packages/fullstack/src/payloads/cbor.rs @@ -21,7 +21,7 @@ use serde::{de::DeserializeOwned, Serialize}; /// *last* if there are multiple extractors in a handler. /// See ["the order of extractors"][order-of-extractors] /// -/// [order-of-extractors]: crate::extract#the-order-of-extractors +/// [order-of-extractors]: mod@crate::extract#the-order-of-extractors #[must_use] pub struct Cbor(pub T); diff --git a/packages/fullstack/src/payloads/multipart.rs b/packages/fullstack/src/payloads/multipart.rs index 97ee837013..018d9aa578 100644 --- a/packages/fullstack/src/payloads/multipart.rs +++ b/packages/fullstack/src/payloads/multipart.rs @@ -32,7 +32,7 @@ use axum::extract::multipart::{Field, MultipartError}; /// `Multipart` extractor must be *last* if there are multiple extractors in a handler. /// See ["the order of extractors"][order-of-extractors] /// -/// [order-of-extractors]: crate::extract#the-order-of-extractors +/// [order-of-extractors]: mod@crate::extract#the-order-of-extractors /// /// # Large Files /// diff --git a/packages/fullstack/src/payloads/postcard.rs b/packages/fullstack/src/payloads/postcard.rs index cbe8d9e6a9..70f25a7f7c 100644 --- a/packages/fullstack/src/payloads/postcard.rs +++ b/packages/fullstack/src/payloads/postcard.rs @@ -23,7 +23,7 @@ use std::future::Future; /// *last* if there are multiple extractors in a handler. /// See ["the order of extractors"][order-of-extractors] /// -/// [order-of-extractors]: crate::extract#the-order-of-extractors +/// [order-of-extractors]: mod@crate::extract#the-order-of-extractors pub struct Postcard(pub T); #[derive(thiserror::Error, Debug)] diff --git a/packages/playwright-tests/fullstack/src/main.rs b/packages/playwright-tests/fullstack/src/main.rs index b36fc8e028..f4619d8c42 100644 --- a/packages/playwright-tests/fullstack/src/main.rs +++ b/packages/playwright-tests/fullstack/src/main.rs @@ -9,12 +9,19 @@ use dioxus::fullstack::{commit_initial_chunk, Websocket}; use dioxus::{fullstack::WebSocketOptions, prelude::*}; fn main() { - dioxus::LaunchBuilder::new() - .with_cfg(server_only! { - dioxus::server::ServeConfig::builder().enable_out_of_order_streaming() - }) - .with_context(1234u32) - .launch(app); + #[cfg(feature = "server")] + dioxus::serve(|| async move { + use dioxus::server::axum::{self, Extension}; + + let cfg = dioxus::server::ServeConfig::builder().enable_out_of_order_streaming(); + let router = axum::Router::new() + .serve_dioxus_application(cfg, app) + .layer(Extension(1234u32)); + + Ok(router) + }); + #[cfg(not(feature = "server"))] + launch(app); } fn app() -> Element { @@ -74,13 +81,12 @@ fn DefaultServerFnCodec() -> Element { #[cfg(feature = "server")] async fn assert_server_context_provided() { - // todo!("replace server context....") - // use dioxus::server::{extract, FromContext}; - // let FromContext(i): FromContext = extract().await.unwrap(); - // assert_eq!(i, 1234u32); + use dioxus::{fullstack::FullstackContext, server::axum::Extension}; + // Just make sure the server context is provided + let Extension(id): Extension = FullstackContext::extract().await.unwrap(); + assert_eq!(id, 1234u32); } -// #[server(PostServerData)] #[server] async fn post_server_data(data: String) -> ServerFnResult { assert_server_context_provided().await; @@ -89,7 +95,6 @@ async fn post_server_data(data: String) -> ServerFnResult { Ok(()) } -// #[server(GetServerData)] #[server] async fn get_server_data() -> ServerFnResult { assert_server_context_provided().await;