Merged

39 commits
4b1e3db
Fix lint issues uncovered by pycodestyle 2.11
kbattocchi Aug 2, 2023
96c8507
init
kgao Sep 1, 2023
5b08000
update class name
kgao Sep 6, 2023
7c23d5e
update method
kgao Sep 6, 2023
f0940aa
update matrix A, B calling reference
kgao Sep 15, 2023
edd9e94
cleanup wrapper class
kgao Sep 15, 2023
cfb3e34
add local test files
kgao Sep 20, 2023
0dfa1e3
Added federated learner test
kbattocchi Sep 22, 2023
d97793d
Fix aggregation logic
kbattocchi Sep 25, 2023
fab3531
update the doc
kgao Oct 4, 2023
e652b08
update doc for federated learning
kgao Oct 4, 2023
6c7e485
cleanup branch: test, file structure, docstring and linting
kgao Oct 6, 2023
909eb5c
cleanup federated learning doc
kgao Oct 17, 2023
e187f13
cleaning the doc for federated learning
kgao Oct 20, 2023
5eb64a5
Rename moments
kbattocchi Oct 28, 2023
e00cdd2
Make FederatedEstimator a CateEstimator
kbattocchi Oct 30, 2023
f943966
add arg to allow missing values in W and sometimes X (#791)
fverac Sep 29, 2023
1c0bcf8
Drop support for sklearn<1.0
kbattocchi Mar 28, 2023
99cc11a
Support direct covariance fitting in DRIV
kbattocchi Jun 21, 2023
d90972d
Enable sklearn 1.3
kbattocchi Aug 4, 2023
c46eac8
Ensure groups work with DRIV, DMLIV
kbattocchi Aug 14, 2023
53a4b2a
Update __init__.py to reflect current structure
kbattocchi Oct 6, 2023
9dbf21d
Make changes to support dowhy 0.10.1 in tests
kbattocchi Oct 10, 2023
bef9b4c
Allow newer shap, matlab, and seaborn versions
kbattocchi Oct 10, 2023
db606ee
Make minor CI improvements
kbattocchi Aug 14, 2023
de0c645
Save notebook outputs during CI
kbattocchi Oct 11, 2023
a6691ff
Remove legacy assertWarns hack
kbattocchi Oct 21, 2023
6291f4a
Scaling ortholearners using Ray (#800)
v-shaal Oct 27, 2023
7935df5
Improve FederatedEstimator docs
kbattocchi Oct 31, 2023
ef04ce7
Merge branch 'main' into kgao/federal-learning
kbattocchi Oct 31, 2023
650f28a
Fixup docs
kbattocchi Oct 31, 2023
85c8458
Fixup docs
kbattocchi Oct 31, 2023
a91d8d2
Fixup federated learning doctest
kbattocchi Oct 31, 2023
1e1f67d
Reduce featurization degree
kbattocchi Nov 1, 2023
d9444ba
Add tests for DRLearner
kbattocchi Nov 1, 2023
16b8860
Fixup tests
kbattocchi Nov 1, 2023
c04df65
Fix test failures
kbattocchi Nov 1, 2023
e4415c8
Fix test
kbattocchi Nov 1, 2023
5205791
Fix tests
kbattocchi Nov 1, 2023
10 changes: 10 additions & 0 deletions doc/reference.rst
@@ -249,6 +249,16 @@ Inference Methods
    econml.inference.LinearModelFinalInferenceDiscrete
    econml.inference.StatsModelsInferenceDiscrete

.. _federated_api:

Federated Estimation
--------------------

.. autosummary::
    :toctree: _autosummary

    econml.federated_learning.FederatedEstimator

.. _solutions_api:

Solutions
145 changes: 145 additions & 0 deletions doc/spec/federated_learning.rst
@@ -0,0 +1,145 @@
Federated Learning in EconML
==============================================
.. contents::
    :local:
    :depth: 2

Overview
--------

Federated learning in the EconML library allows models to be trained on separate data sets and then combined
into a single CATE model, without ever needing to collect all of the training data on a single machine.

Motivation for Incorporating Federated Learning into the EconML Library
-----------------------------------------------------------------------

1. **Large data sets**: With data sets that are so large that they cannot fit onto a single machine, federated
learning allows you to partition the data, train an individual causal model on each partition, and combine the models
into a single model afterwards.

2. **Privacy Preservation**: Federated learning enables organizations to build machine learning models without
centralizing or sharing sensitive data. This may be important for complying with data privacy regulations, since it
keeps data localized and reduces exposure to compliance risk.

Federated Learning with EconML
------------------------------

Introducing the `FederatedEstimator`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We provide the `FederatedEstimator` class to aggregate individual estimators that have
been trained on different subsets of data. The individual estimators must all be of the same type,
which must currently be one of LinearDML, LinearDRLearner, LinearDRIV, or LinearIntentToTreatDRIV.

Unlike other estimators, you should not call `fit` on an instance of `FederatedEstimator`; instead,
you should train your individual estimators separately and then pass the already trained models to the `FederatedEstimator`
initializer. The `FederatedEstimator` will then aggregate the individual estimators into a single model.


Example Usage
~~~~~~~~~~~~~

.. testsetup::

    import numpy as np
    from econml.federated_learning import FederatedEstimator
    from econml.dml import LinearDML
    n = 1000
    (X, y, t) = (np.random.normal(size=(n,) + s) for s in [(3,), (), ()])

.. testcode::

    # Create individual LinearDML estimators
    num_partitions = 3
    estimators = []
    for i in range(num_partitions):
        est = LinearDML(random_state=123)
        # Get the data for this partition
        X_part, y_part, t_part = (arr[i::num_partitions] for arr in (X, y, t))

        # In practice, each estimator could be trained in a distributed fashion,
        # e.g. by using Spark
        est.fit(Y=y_part, T=t_part, X=X_part)
        estimators.append(est)

    # Create a FederatedEstimator by providing a list of estimators
    federated_estimator = FederatedEstimator(estimators)

    # The federated estimator can now be used like a typical CATE estimator
    cme = federated_estimator.const_marginal_effect(X)
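
Since the aggregated final model retains full analytical inference information, interval and inference
methods work just as they do on other linear EconML estimators. A brief sketch (continuing the example above):

.. code-block:: python

    # Confidence intervals from the aggregated final model
    lb, ub = federated_estimator.const_marginal_effect_interval(X, alpha=0.05)

    # Full inference results, including standard errors and p-values
    inference_results = federated_estimator.const_marginal_effect_inference(X)
    summary = inference_results.summary_frame()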



Theory
------

Many estimators solve a moment equation

.. math::

    \E[\psi(D; \theta; \eta)] = 0

where :math:`D` is the data, :math:`\theta` is the parameter, and :math:`\eta` is the nuisance parameter. Often, the moment is linear in the parameter, so that it can be rewritten as

.. math::

    \E[\psi_a(D; \eta)\theta + \psi_b(D; \eta)] = 0

In this case, solving the equation using the empirical expectations gives

.. math::
    :nowrap:

    \begin{align*}
    \hat{\theta} &= -\E_n[\psi_a(D;\hat{\eta})]^{-1} \E_n[\psi_b(D;\hat{\eta})] \\
    \sqrt{N}(\hat{\theta}-\theta) &\sim \mathcal{N}\left(0, \E_n[\psi_a(D;\hat{\eta})]^{-1} \E_n[\psi(D;\hat{\theta};\hat{\eta}) \psi(D;\hat{\theta};\hat{\eta})^\top] \E_n[\psi_a(D;\hat{\eta})^\top]^{-1}\right)
    \end{align*}

The center term in the variance calculation can be expanded out:

.. math::
    :nowrap:

    \begin{align*}
    \E_n[\psi(D;\hat\theta;\hat\eta) \psi(D;\hat\theta;\hat\eta)^\top] &= \E_n[(\psi_b(D;\hat\eta)+\psi_a(D;\hat\eta)\hat\theta) (\psi_b(D;\hat\eta)+\psi_a(D;\hat\eta)\hat\theta)^\top] \\
    &= \E_n[\psi_b(D;\hat\eta) \psi_b(D;\hat\eta)^\top] + \E_n[\psi_a(D;\hat\eta)\hat\theta\psi_b(D;\hat\eta)^\top] \\
    &\quad + \E_n[\psi_b(D;\hat\eta) \hat\theta^\top \psi_a(D;\hat\eta)^\top] + \E_n[\psi_a(D;\hat\eta) \hat\theta\hat\theta^\top\psi_a(D;\hat\eta)^\top]
    \end{align*}

Some of these terms involve products where :math:`\hat\theta` appears in an interior position, but these can equivalently be computed by taking the outer product of the matrices on either side and then contracting with :math:`\hat\theta` afterwards. Thus, we can distribute the computation of the following quantities:

.. math::
    :nowrap:

    \begin{align*}
    & \E_n[\psi_a(D;\hat\eta)] \\
    & \E_n[\psi_b(D;\hat\eta)] \\
    & \E_n[\psi_b(D;\hat\eta) \psi_b(D;\hat\eta)^\top] \\
    & \E_n[\psi_b(D;\hat\eta) \otimes \psi_a(D;\hat\eta)] \\
    & \E_n[\psi_a(D;\hat\eta) \otimes \psi_a(D;\hat\eta)]
    \end{align*}

We can then aggregate these distributed estimates, use the first two to calculate :math:`\hat\theta`, and then use that with the rest to calculate the analytical variance.
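
Concretely, writing :math:`A = \E_n[\psi_a(D;\hat\eta)]`, :math:`b = \E_n[\psi_b(D;\hat\eta)]`, and
:math:`M = \E_n[\psi(D;\hat\theta;\hat\eta)\psi(D;\hat\theta;\hat\eta)^\top]` for the matrix assembled from the last
three moments via the expansion above, the aggregated point estimate and its sandwich variance are

.. math::
    :nowrap:

    \begin{align*}
    \hat\theta &= -A^{-1} b \\
    \widehat{\mathrm{Var}}(\hat\theta) &= \frac{1}{N} A^{-1} M \left(A^{-1}\right)^\top
    \end{align*}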

As an example, for linear regression of :math:`y` on :math:`X`, the moment is :math:`\E[X^\top (X\theta - y)] = 0`, so we have

.. math::

    \psi_a(D;\eta) = X^\top X \\
    \psi_b(D;\eta) = -X^\top y

The additional moments we need to distribute are therefore

.. math::
    :nowrap:

    \begin{align*}
    & \E_n[\psi_b \psi_b^\top] = \E_n[X^\top y y^\top X] = \E_n[X \otimes X y^2] \\
    & \E_n[\psi_b \otimes \psi_a] = -\E_n[X^\top y \otimes X^\top X] = -\E_n[X \otimes X \otimes X y] \\
    & \E_n[\psi_a \otimes \psi_a] = \E_n[X^\top X \otimes X^\top X] = \E_n[X \otimes X \otimes X \otimes X]
    \end{align*}

Thus, at the cost of storing these three extra moments, we can distribute the computation of linear regression and recover exactly the same
result we would have gotten by doing this computation on the full data set.
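
For intuition, the point-estimate half of this claim can be checked with a minimal sketch in plain numpy
(an illustration of the idea, not the library's implementation):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(900, 3))
    y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=900)

    # Each partition computes only its local moment sums; no raw data is shared
    partitions = np.array_split(np.arange(900), 3)
    A = sum(X[idx].T @ X[idx] for idx in partitions)  # aggregates the X'X moment
    b = sum(X[idx].T @ y[idx] for idx in partitions)  # aggregates the X'y moment

    # Solving the aggregated normal equations reproduces pooled OLS exactly
    theta_federated = np.linalg.solve(A, b)
    theta_pooled = np.linalg.lstsq(X, y, rcond=None)[0]
    assert np.allclose(theta_federated, theta_pooled)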

In the context of federated CATE estimation, note that the nuisance models are each trained on only a subset of the data.
The aggregated final linear model is therefore exactly the model that would be computed locally from those same nuisance
estimates, but the nuisance estimates themselves would differ if they had been computed on all of the data. In practice,
this should not be a significant issue as long as the nuisance estimators converge at a reasonable rate; for example, if
the first-stage models are accurate enough for the final estimate to converge at a rate of :math:`O(1/\sqrt{n})`,
then splitting the data into :math:`k` partitions should only increase the variance by a factor of :math:`\sqrt{k}`.
1 change: 1 addition & 0 deletions doc/spec/spec.rst
@@ -13,6 +13,7 @@ EconML User Guide
   estimation_dynamic
   inference
   interpretability
   federated_learning
   references
   faq
   community
1 change: 1 addition & 0 deletions econml/__init__.py
@@ -20,6 +20,7 @@
           'tree',
           'dowhy',
           'utilities',
           'federated_learning',
           '__version__']

from ._version import __version__
6 changes: 6 additions & 0 deletions econml/_cate_estimator.py
@@ -4,6 +4,7 @@
"""Base classes for all CATE estimators."""

import abc
import inspect
import numpy as np
from functools import wraps
from copy import deepcopy
@@ -329,6 +330,11 @@ def _use_inference_method(self, name, *args, **kwargs):
    def _defer_to_inference(m):
        @wraps(m)
        def call(self, *args, **kwargs):
            # apply defaults before calling inference method
            bound_args = inspect.signature(m).bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            args = bound_args.args[1:]  # remove self
            kwargs = bound_args.kwargs
            return self._use_inference_method(m.__name__, *args, **kwargs)
        return call
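
A minimal standalone sketch of this pattern (illustrative names, not library API): binding the call to the
wrapped method's signature makes the method's declared defaults explicit before they are forwarded.

    import inspect
    from functools import wraps

    def defer(m):
        """Forward a call to a handler, with the method's declared defaults filled in."""
        @wraps(m)
        def call(self, *args, **kwargs):
            bound = inspect.signature(m).bind(self, *args, **kwargs)
            bound.apply_defaults()  # makes defaults like alpha=0.05 explicit
            return (m.__name__, bound.args[1:], bound.kwargs)
        return call

    class Demo:
        @defer
        def effect_interval(self, X=None, *, alpha=0.05):
            pass

    print(Demo().effect_interval())  # ('effect_interval', (None,), {'alpha': 0.05})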

96 changes: 96 additions & 0 deletions econml/federated_learning.py
@@ -0,0 +1,96 @@
# Copyright (c) PyWhy contributors. All rights reserved.
# Licensed under the MIT License.

import numpy as np
from sklearn import clone

from econml.utilities import check_input_arrays
from ._cate_estimator import (LinearCateEstimator, TreatmentExpansionMixin,
                              StatsModelsCateEstimatorMixin, StatsModelsCateEstimatorDiscreteMixin)
from .dml import LinearDML
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete
from .sklearn_extensions.linear_model import StatsModelsLinearRegression
from typing import List

# TODO: This could be extended to also work with our sparse and 2SLS estimators,
# if we add an aggregate method to them
# Remember to update the docs if this changes


class FederatedEstimator(TreatmentExpansionMixin, LinearCateEstimator):
    """
    A class for federated learning using LinearDML, LinearDRIV, and LinearDRLearner estimators.

    Parameters
    ----------
    estimators : list of LinearDML, LinearDRIV, or LinearDRLearner
        List of estimators to aggregate (all of the same type).

    Attributes
    ----------
    estimators : list of LinearDML, LinearDRIV, or LinearDRLearner
        List of estimators provided during initialization.

    model_final_ : StatsModelsLinearRegression
        The aggregated model obtained by aggregating models from `estimators`.

    fitted_models_final : list of StatsModelsLinearRegression
        The list of fitted models obtained by aggregating models from `estimators`.
    """

    def __init__(self, estimators: List[LinearDML]):
        self.estimators = estimators
        dummy_est = clone(self.estimators[0], safe=False)  # used to extract various attributes later
        infs = [est._inference for est in self.estimators]
        assert (
            all(isinstance(inf, StatsModelsInference) for inf in infs) or
            all(isinstance(inf, StatsModelsInferenceDiscrete) for inf in infs)
        ), "All estimators must use either StatsModelsInference or StatsModelsInferenceDiscrete"
        cov_types = set(inf.cov_type for inf in infs)
        assert len(cov_types) == 1, f"All estimators must use the same covariance type, got {cov_types}"
        if isinstance(infs[0], StatsModelsInference):
            inf = StatsModelsInference(cov_type=cov_types.pop())
            cate_est_type = StatsModelsCateEstimatorMixin
            self.model_final_ = StatsModelsLinearRegression.aggregate([est.model_final_ for est in self.estimators])
            inf.model_final = self.model_final_
            inf.bias_part_of_coef = dummy_est.bias_part_of_coef
        else:
            inf = StatsModelsInferenceDiscrete(cov_type=cov_types.pop())
            cate_est_type = StatsModelsCateEstimatorDiscreteMixin
            self.fitted_models_final = [
                StatsModelsLinearRegression.aggregate(models)
                for models in zip(*[est.fitted_models_final for est in self.estimators],
                                  strict=True)]
            inf.fitted_models_final = self.fitted_models_final

        # mix in the appropriate inference class
        self.__class__ = type("FederatedEstimator", (FederatedEstimator, cate_est_type), {})

        # assign all of the attributes from the dummy estimator that would normally be assigned during fitting
        # TODO: This seems hacky; is there a better abstraction to maintain these?
        #       This should also include bias_part_of_coef, model_final_, and fitted_models_final above
        inf.featurizer = dummy_est.featurizer_ if hasattr(dummy_est, 'featurizer_') else None
        inf._est = self
        self._d_t = inf._d_t = dummy_est._d_t
        self._d_y = inf._d_y = dummy_est._d_y
        self.d_t = inf.d_t = inf._d_t[0] if inf._d_t else 1
        self.d_y = inf.d_y = inf._d_y[0] if inf._d_y else 1
        self._d_t_in = inf._d_t_in = dummy_est._d_t_in
        self.fit_cate_intercept_ = inf.fit_cate_intercept = dummy_est.fit_cate_intercept
        self._inference = inf

        # Assign treatment expansion attributes
        self.transformer = dummy_est.transformer

    # Methods needed to implement the LinearCateEstimator interface

    def const_marginal_effect(self, X=None):
        X, = check_input_arrays(X)
        return self._inference.const_marginal_effect_inference(X).point_estimate

    def fit(self, *args, **kwargs):
        raise NotImplementedError("FederatedEstimator does not support fit")

    # Methods needed to implement the LinearFinalModelCateEstimatorMixin
    @property
    def bias_part_of_coef(self):
        return self._inference.bias_part_of_coef