Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions packages/fullstack-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ inventory = { workspace = true }
serde_json = { workspace = true }
generational-box = { workspace = true }
futures-util = { workspace = true, features = ["std"] }
tokio = { workspace = true, features = ["rt"] }
parking_lot = { workspace = true }

[features]
web = []
Expand Down
112 changes: 84 additions & 28 deletions packages/fullstack-core/src/streaming.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,56 @@
use crate::{HttpError, ServerFnError};
use axum_core::extract::FromRequestParts;
use axum_core::{extract::FromRequest, response::IntoResponse};
use dioxus_core::{try_consume_context, CapturedError};
use dioxus_signals::{ReadableExt, Signal, WritableExt};
use dioxus_core::{try_consume_context, CapturedError, ReactiveContext};
use http::StatusCode;
use http::{request::Parts, HeaderMap};
use std::{cell::RefCell, rc::Rc};
use parking_lot::RwLock;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;

tokio::task_local! {
static FULLSTACK_CONTEXT: FullstackContext;
}

/// The context provided by dioxus fullstack for server-side rendering.
///
/// This context will only be set on the server during a streaming response.
#[derive(Clone, Debug)]
/// This context will only be set on the server during the initial streaming response
/// and inside server functions.
#[derive(Clone)]
pub struct FullstackContext {
current_status: Signal<StreamingStatus>,
request_headers: Rc<RefCell<http::request::Parts>>,
response_headers: Rc<RefCell<Option<HeaderMap>>>,
route_http_status: Signal<HttpError>,
current_status: Arc<RwLock<StreamingStatus>>,
current_status_subscribers: Arc<RwLock<HashSet<ReactiveContext>>>,
request_headers: Arc<RwLock<http::request::Parts>>,
response_headers: Arc<RwLock<Option<HeaderMap>>>,
route_http_status: Arc<RwLock<HttpError>>,
route_http_status_subscribers: Arc<RwLock<HashSet<ReactiveContext>>>,
}

impl Debug for FullstackContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FullstackContext")
.field("current_status", &self.current_status)
.field("request_headers", &self.request_headers)
.field("route_http_status", &self.route_http_status)
.finish()
}
}

impl PartialEq for FullstackContext {
fn eq(&self, other: &Self) -> bool {
self.current_status == other.current_status
&& Rc::ptr_eq(&self.request_headers, &other.request_headers)
Arc::ptr_eq(&self.current_status, &other.current_status)
&& Arc::ptr_eq(&self.request_headers, &other.request_headers)
&& Arc::ptr_eq(&self.route_http_status, &other.route_http_status)
&& Arc::ptr_eq(&self.response_headers, &other.response_headers)
&& Arc::ptr_eq(
&self.current_status_subscribers,
&other.current_status_subscribers,
)
&& Arc::ptr_eq(
&self.route_http_status_subscribers,
&other.route_http_status_subscribers,
)
}
}

Expand All @@ -29,13 +59,15 @@ impl FullstackContext {
/// provide this context for you.
pub fn new(parts: Parts) -> Self {
Self {
current_status: Signal::new(StreamingStatus::RenderingInitialChunk),
request_headers: Rc::new(RefCell::new(parts)),
route_http_status: Signal::new(HttpError {
current_status: Arc::new(RwLock::new(StreamingStatus::RenderingInitialChunk)),
current_status_subscribers: Default::default(),
request_headers: RwLock::new(parts).into(),
route_http_status: Arc::new(RwLock::new(HttpError {
status: http::StatusCode::OK,
message: None,
}),
response_headers: Rc::new(RefCell::new(Some(HeaderMap::new()))),
})),
route_http_status_subscribers: Default::default(),
response_headers: RwLock::new(Some(HeaderMap::new())).into(),
}
}

Expand All @@ -45,8 +77,11 @@ impl FullstackContext {
///
/// Once this method has been called, the http response parts can no longer be modified.
pub fn commit_initial_chunk(&mut self) {
self.current_status
.set(StreamingStatus::InitialChunkCommitted);
*self.current_status.write() = StreamingStatus::InitialChunkCommitted;
let subscribers = std::mem::take(&mut *self.current_status_subscribers.write());
for subscriber in subscribers {
subscriber.mark_dirty();
}
}

/// Get the current status of the streaming response. This method is reactive and will cause
Expand All @@ -56,17 +91,24 @@ impl FullstackContext {
}

/// Access the http request parts mutably. This will allow you to modify headers and other parts of the request.
pub fn parts_mut(&self) -> std::cell::RefMut<'_, http::request::Parts> {
self.request_headers.borrow_mut()
pub fn parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
self.request_headers.write()
}

/// Extract an axum extractor from the current request. This will always use an empty body for the request,
/// since it's assumed that rendering the app is done under a `GET` request.
pub async fn extract<T: FromRequest<(), M>, M>() -> Result<T, ServerFnError> {
/// Run a future within the scope of this FullstackContext.
pub async fn scope<F, R>(self, fut: F) -> R
where
F: std::future::Future<Output = R>,
{
FULLSTACK_CONTEXT.scope(self, fut).await
}

/// Extract an axum extractor from the current request.
pub async fn extract<T: FromRequestParts<()>>() -> Result<T, ServerFnError> {
let this = Self::current()
.ok_or_else(|| ServerFnError::new("No FullstackContext found".to_string()))?;

let parts = this.request_headers.borrow_mut().clone();
let parts = this.request_headers.read().clone();
let request =
axum_core::extract::Request::from_parts(parts, axum_core::body::Body::empty());
match T::from_request(request, &()).await {
Expand All @@ -79,8 +121,12 @@ impl FullstackContext {
}

/// Get the current `FullstackContext` if it exists. This will return `None` if called on the client
/// or outside of a streaming response on the server.
/// or outside of a streaming response on the server or server function.
pub fn current() -> Option<Self> {
if let Ok(context) = FULLSTACK_CONTEXT.try_get() {
return Some(context);
}

if let Some(rt) = dioxus_core::Runtime::try_current() {
let id = rt.try_current_scope_id()?;
if let Some(ctx) = rt.consume_context::<FullstackContext>(id) {
Expand All @@ -98,7 +144,11 @@ impl FullstackContext {
}

pub fn set_current_http_status(&mut self, status: HttpError) {
self.route_http_status.set(status);
*self.route_http_status.write() = status;
let subscribers = std::mem::take(&mut *self.route_http_status_subscribers.write());
for subscriber in subscribers {
subscriber.mark_dirty();
}
}

/// Add a header to the response. This will be sent to the client when the response is committed.
Expand All @@ -107,15 +157,15 @@ impl FullstackContext {
key: impl Into<http::header::HeaderName>,
value: impl Into<http::header::HeaderValue>,
) {
if let Some(headers) = self.response_headers.borrow_mut().as_mut() {
if let Some(headers) = self.response_headers.write().as_mut() {
headers.insert(key.into(), value.into());
}
}

/// Take the response headers out of the context. This will leave the context without any headers,
/// so it should only be called once when the response is being committed.
pub fn take_response_headers(&self) -> Option<HeaderMap> {
self.response_headers.borrow_mut().take()
self.response_headers.write().take()
}

/// Set the current HTTP status for the route. This will be used when committing the response
Expand Down Expand Up @@ -182,6 +232,12 @@ pub fn commit_initial_chunk() {
}
}

/// Extract an axum extractor from the current request.
pub fn extract<T: FromRequestParts<()>>(
) -> impl std::future::Future<Output = Result<T, ServerFnError>> {
FullstackContext::extract::<T>()
}

/// Get the current status of the streaming response. This method is reactive and will cause
/// the current reactive context to rerun when the status changes.
///
Expand Down
35 changes: 33 additions & 2 deletions packages/fullstack-server/src/serverfn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use axum::body::Body;
use axum::routing::MethodRouter;
use axum::Router;
use dashmap::DashMap;
use dioxus_fullstack_core::DioxusServerState;
use dioxus_fullstack_core::{DioxusServerState, FullstackContext};
use dioxus_signals::SyncStorage;
use http::Method;
use std::{marker::PhantomData, sync::LazyLock};

Expand Down Expand Up @@ -76,9 +77,39 @@ impl ServerFunction {
// }
// }

async fn server_context_middleware(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
let (parts, body) = request.into_parts();
let server_context = FullstackContext::new(parts.clone());
let request = axum::extract::Request::from_parts(parts, body);

server_context
.scope(async move {
// Run the next middleware / handler inside the server context
let mut response = next.run(request).await;

let server_context = FullstackContext::current().expect(
"Server context should be available inside the server context scope",
);

// Get the extra response headers set during the handler and add them to the response
let headers = server_context.take_response_headers();
if let Some(headers) = headers {
response.headers_mut().extend(headers);
}

response
})
.await
}

router.route(
self.path(),
((self.handler)()).with_state(DioxusServerState {}),
((self.handler)())
.with_state(DioxusServerState {})
.layer(axum::middleware::from_fn(server_context_middleware)),
)
}
}
Expand Down
28 changes: 16 additions & 12 deletions packages/playwright-tests/fullstack/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ use dioxus::fullstack::{commit_initial_chunk, Websocket};
use dioxus::{fullstack::WebSocketOptions, prelude::*};

fn main() {
dioxus::LaunchBuilder::new()
.with_cfg(server_only! {
dioxus::server::ServeConfig::builder().enable_out_of_order_streaming()
})
.with_context(1234u32)
.launch(app);
#[cfg(feature = "server")]
dioxus::serve(|| async move {
use dioxus::server::axum::{self, Extension};

let cfg = dioxus::server::ServeConfig::builder().enable_out_of_order_streaming();
let router = axum::Router::new()
.serve_dioxus_application(cfg, app)
.layer(Extension(1234u32));

Ok(router)
});
launch(app);
}

fn app() -> Element {
Expand Down Expand Up @@ -74,13 +80,12 @@ fn DefaultServerFnCodec() -> Element {

#[cfg(feature = "server")]
async fn assert_server_context_provided() {
// todo!("replace server context....")
// use dioxus::server::{extract, FromContext};
// let FromContext(i): FromContext<u32> = extract().await.unwrap();
// assert_eq!(i, 1234u32);
use dioxus::{fullstack::extract, server::axum::Extension};
// Just make sure the server context is provided
let Extension(id): Extension<u32> = extract().await.unwrap();
assert_eq!(id, 1234u32);
}

// #[server(PostServerData)]
#[server]
async fn post_server_data(data: String) -> ServerFnResult {
assert_server_context_provided().await;
Expand All @@ -89,7 +94,6 @@ async fn post_server_data(data: String) -> ServerFnResult {
Ok(())
}

// #[server(GetServerData)]
#[server]
async fn get_server_data() -> ServerFnResult<String> {
assert_server_context_provided().await;
Expand Down