Merged
Changes from 13 commits
Commits
39 commits
4b1e3db
Fix lint issues uncovered by pycodestyle 2.11
kbattocchi Aug 2, 2023
96c8507
init
kgao Sep 1, 2023
5b08000
update class name
kgao Sep 6, 2023
7c23d5e
update method
kgao Sep 6, 2023
f0940aa
update matrix A, B calling reference
kgao Sep 15, 2023
edd9e94
cleanup wrapper class
kgao Sep 15, 2023
cfb3e34
add local test files
kgao Sep 20, 2023
0dfa1e3
Added federated learner test
kbattocchi Sep 22, 2023
d97793d
Fix aggregation logic
kbattocchi Sep 25, 2023
fab3531
update the doc
kgao Oct 4, 2023
e652b08
update doc for federated learning
kgao Oct 4, 2023
6c7e485
cleanup branch: test, file structure, docstring and linting
kgao Oct 6, 2023
909eb5c
cleanup federated learning doc
kgao Oct 17, 2023
e187f13
cleaning the doc for federated learning
kgao Oct 20, 2023
5eb64a5
Rename moments
kbattocchi Oct 28, 2023
e00cdd2
Make FederatedEstimator a CateEstimator
kbattocchi Oct 30, 2023
f943966
add arg to allow missing values in W and sometimes X (#791)
fverac Sep 29, 2023
1c0bcf8
Drop support for sklearn<1.0
kbattocchi Mar 28, 2023
99cc11a
Support direct covariance fitting in DRIV
kbattocchi Jun 21, 2023
d90972d
Enable sklearn 1.3
kbattocchi Aug 4, 2023
c46eac8
Ensure groups work with DRIV, DMLIV
kbattocchi Aug 14, 2023
53a4b2a
Update __init__.py to reflect current structure
kbattocchi Oct 6, 2023
9dbf21d
Make changes to support dowhy 0.10.1 in tests
kbattocchi Oct 10, 2023
bef9b4c
Allow newer shap, matlab, and seaborn versions
kbattocchi Oct 10, 2023
db606ee
Make minor CI improvements
kbattocchi Aug 14, 2023
de0c645
Save notebook outputs during CI
kbattocchi Oct 11, 2023
a6691ff
Remove legacy assertWarns hack
kbattocchi Oct 21, 2023
6291f4a
Scaling ortholearners using Ray (#800)
v-shaal Oct 27, 2023
7935df5
Improve FederatedEstimator docs
kbattocchi Oct 31, 2023
ef04ce7
Merge branch 'main' into kgao/federal-learning
kbattocchi Oct 31, 2023
650f28a
Fixup docs
kbattocchi Oct 31, 2023
85c8458
Fixup docs
kbattocchi Oct 31, 2023
a91d8d2
Fixup federated learning doctest
kbattocchi Oct 31, 2023
1e1f67d
Reduce featurization degree
kbattocchi Nov 1, 2023
d9444ba
Add tests for DRLearner
kbattocchi Nov 1, 2023
16b8860
Fixup tests
kbattocchi Nov 1, 2023
c04df65
Fix test failures
kbattocchi Nov 1, 2023
e4415c8
Fix test
kbattocchi Nov 1, 2023
5205791
Fix tests
kbattocchi Nov 1, 2023
353 changes: 353 additions & 0 deletions doc/spec/federated_learning.rst

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions doc/spec/spec.rst
@@ -16,6 +16,7 @@ EconML User Guide
references
faq
community
federated_learning

.. todo::
benchmark
1 change: 1 addition & 0 deletions econml/__init__.py
@@ -19,6 +19,7 @@
'tree',
'utilities',
'dowhy',
'federated_learning',
'__version__']

from ._version import __version__
4 changes: 2 additions & 2 deletions econml/_cate_estimator.py
@@ -11,8 +11,8 @@
from .inference import BootstrapInference
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params, get_feature_names_or_default,
inverse_onehot, Summary, get_input_columns, check_input_arrays, jacify_featurizer)
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference, \
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference, \
GenericModelFinalInferenceDiscrete
from ._shap import _shap_explain_cme, _shap_explain_joint_linear_model_cate
from .dowhy import DoWhyWrapper
29 changes: 29 additions & 0 deletions econml/federated_learning.py
@@ -0,0 +1,29 @@
# Copyright (c) PyWhy contributors. All rights reserved.
# Licensed under the MIT License.

import numpy as np
from econml.dml import LinearDML
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
from typing import List


class FederatedEstimator:
"""
A class for federated learning using LinearDML estimators.

Parameters
----------
estimators : list of LinearDML
List of LinearDML estimators to aggregate.

Attributes
----------
estimators : list of LinearDML
List of LinearDML estimators provided during initialization.

model_final_ : StatsModelsLinearRegression
The aggregated model obtained by aggregating models from `estimators`.
"""
def __init__(self, estimators: List[LinearDML]):
self.estimators = estimators
self.model_final_ = StatsModelsLinearRegression.aggregate([est.model_final_ for est in estimators])
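
An editorial aside, not part of the diff: a minimal usage sketch for the new wrapper under the existing LinearDML API. The make_partition helper and its data-generating process are hypothetical, and at this point in the commit history the wrapper only exposes the aggregated model_final_.

# Hypothetical sketch: each "site" fits its own LinearDML locally, then only the
# fitted final-stage models are combined -- no raw data is shared across sites.
import numpy as np
from econml.dml import LinearDML
from econml.federated_learning import FederatedEstimator

def make_partition(n, seed):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 2))
    T = X[:, 0] + rng.normal(size=n)
    Y = 2.0 * T + X[:, 1] + rng.normal(size=n)
    return Y, T, X

local_estimators = []
for seed in (0, 1):
    Y, T, X = make_partition(1000, seed)
    est = LinearDML(random_state=123)
    est.fit(Y, T, X=X)
    local_estimators.append(est)

fed = FederatedEstimator(local_estimators)
print(fed.model_final_.coef_)  # aggregated final-stage (CATE) coefficients
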
6 changes: 3 additions & 3 deletions econml/inference/_inference.py
@@ -1067,11 +1067,11 @@ def conf_int(self, alpha=0.05):
if self.stderr is None:
raise AttributeError("Only point estimates are available!")
if np.isscalar(self.point_estimate):
return _safe_norm_ppf(alpha / 2, loc=self.point_estimate, scale=self.stderr),\
return _safe_norm_ppf(alpha / 2, loc=self.point_estimate, scale=self.stderr), \
_safe_norm_ppf(1 - alpha / 2, loc=self.point_estimate, scale=self.stderr)
else:
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
for p, err in zip(self.point_estimate, self.stderr)]),\
for p, err in zip(self.point_estimate, self.stderr)]), \
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
for p, err in zip(self.point_estimate, self.stderr)])

@@ -1403,7 +1403,7 @@ def conf_int_mean(self, *, alpha=None):
_safe_norm_ppf(1 - alpha / 2, loc=mean_point, scale=stderr_mean))
else:
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
for p, err in zip(mean_point, stderr_mean)]),\
for p, err in zip(mean_point, stderr_mean)]), \
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
for p, err in zip(mean_point, stderr_mean)])

12 changes: 6 additions & 6 deletions econml/orf/_ortho_forest.py
@@ -352,12 +352,12 @@ def _pw_effect_inputs(self, X_single, stderr=False):
slice_weights_one, slice_weights_two = self._get_weights(X_single, tree_slice=slice_it)
slice_weights_list.append((slice_weights_one[mask_w1], slice_weights_two[mask_w2]))
W_none = self.W_one is None
return np.concatenate((self.Y_one[mask_w1], self.Y_two[mask_w2])),\
np.concatenate((self.T_one[mask_w1], self.T_two[mask_w2])),\
np.concatenate((self.X_one[mask_w1], self.X_two[mask_w2])),\
return np.concatenate((self.Y_one[mask_w1], self.Y_two[mask_w2])), \
np.concatenate((self.T_one[mask_w1], self.T_two[mask_w2])), \
np.concatenate((self.X_one[mask_w1], self.X_two[mask_w2])), \
np.concatenate((self.W_one[mask_w1], self.W_two[mask_w2])
) if not W_none else None,\
w_nonzero,\
) if not W_none else None, \
w_nonzero, \
split_inds, slice_weights_list

def _get_inference_options(self):
@@ -1255,7 +1255,7 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.05):
param_upper = [param + np.apply_along_axis(lambda s: norm.ppf(upper, scale=s), 0, np.sqrt(np.diag(cov_mat)))
for (param, cov_mat) in params_and_cov]
param_lower, param_upper = np.asarray(param_lower), np.asarray(param_upper)
return param_lower.reshape((-1,) + self._estimator._d_y + self._estimator._d_t),\
return param_lower.reshape((-1,) + self._estimator._d_y + self._estimator._d_t), \
param_upper.reshape((-1,) + self._estimator._d_y + self._estimator._d_t)

def const_marginal_effect_inference(self, X=None):
130 changes: 116 additions & 14 deletions econml/sklearn_extensions/linear_model.py
@@ -13,6 +13,8 @@
from sklearn.linear_model import lasso_path
"""

from __future__ import annotations # needed to allow type signature to refer to containing type

import numbers
import numpy as np
import warnings
@@ -36,9 +38,11 @@
from statsmodels.api import RLM
import statsmodels
from joblib import Parallel, delayed

from typing import List

# TODO: once we drop support for sklearn < 1.0, we can remove this


def _add_normalize(to_wrap):
"""
Add a fictitious "normalize" argument to linear model initializer signatures.
@@ -1657,7 +1661,7 @@ def coef__interval(self, alpha=0.05):
The lower and upper bounds of the confidence interval of the coefficients
"""
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
for p, err in zip(self.coef_, self.coef_stderr_)]),\
for p, err in zip(self.coef_, self.coef_stderr_)]), \
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
for p, err in zip(self.coef_, self.coef_stderr_)])

Expand All @@ -1677,15 +1681,15 @@ def intercept__interval(self, alpha=0.05):
The lower and upper bounds of the confidence interval of the intercept(s)
"""
if not self.fit_intercept:
return (0 if self._n_out == 0 else np.zeros(self._n_out)),\
return (0 if self._n_out == 0 else np.zeros(self._n_out)), \
(0 if self._n_out == 0 else np.zeros(self._n_out))

if self._n_out == 0:
return _safe_norm_ppf(alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_),\
return _safe_norm_ppf(alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_), \
_safe_norm_ppf(1 - alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_)
else:
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
for p, err in zip(self.intercept_, self.intercept_stderr_)]),\
for p, err in zip(self.intercept_, self.intercept_stderr_)]), \
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
for p, err in zip(self.intercept_, self.intercept_stderr_)])

Expand All @@ -1707,7 +1711,7 @@ def predict_interval(self, X, alpha=0.05):
The lower and upper bounds of the confidence intervals of the predicted mean outcomes
"""
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
for p, err in zip(self.predict(X), self.prediction_stderr(X))]),\
for p, err in zip(self.predict(X), self.prediction_stderr(X))]), \
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
for p, err in zip(self.predict(X), self.prediction_stderr(X))])

Expand All @@ -1730,7 +1734,6 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
def __init__(self, fit_intercept=True, cov_type="HC0"):
self.cov_type = cov_type
self.fit_intercept = fit_intercept
return

def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
"""Check dimensions and other assertions."""
@@ -1835,22 +1838,46 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
wy = y * np.sqrt(freq_weight).reshape(-1, 1)

param, _, rank, _ = np.linalg.lstsq(WX, wy, rcond=None)
n_obs = np.sum(freq_weight)
self._n_obs = n_obs

if rank < param.shape[0]:
warnings.warn("Co-variance matrix is underdetermined. Inference will be invalid!")
df = param.shape[0]

sigma_inv = np.linalg.pinv(np.matmul(WX.T, WX))
self._param = param
var_i = sample_var + (y - np.matmul(X, param))**2
n_obs = np.sum(freq_weight)
df = len(param) if self._n_out == 0 else param.shape[0]
if rank < df:
warnings.warn("Co-variance matrix is underdetermined. Inference will be invalid!")

if n_obs <= df:
warnings.warn("Number of observations <= than number of parameters. Using biased variance calculation!")
correction = 1
else:
correction = (n_obs / (n_obs - df))

# For aggregation calculations, always treat wy as an array so that einsum expressions don't need to change
# We'll collapse results back down afterwards if necessary
wy = wy.reshape(-1, 1) if y.ndim < 2 else wy
sv = sample_var.reshape(-1, 1) if y.ndim < 2 else sample_var
self.A = np.matmul(WX.T, WX)
self.B = np.matmul(WX.T, wy)

# for federation, we need to store these 5 arrays when using heteroskedasticity-robust inference
if (self.cov_type in ['HC0', 'HC1']):
# y dimension is always first in the output when present so that broadcasting works correctly
self.C = np.einsum('nw,nx,ny,ny->ywx', X, X, wy, wy)
self.D = np.einsum('nv,nw,nx,ny->yvwx', X, X, WX, wy)
self.E = np.einsum('nu,nv,nw,nx->uvwx', X, X, WX, WX)
self.sample_var = np.einsum('nw,nx,ny->ywx', WX, WX, sv)
elif (self.cov_type is None) or (self.cov_type == 'nonrobust'):
self.C = np.einsum('ny,ny->y', wy, wy)
self.D = np.einsum('nx,ny->yx', WX, wy)
self.E = np.einsum('nw,nx->wx', WX, WX)
self.sample_var = np.average(sv, weights=freq_weight, axis=0) * n_obs
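# Editorial note (not part of the diff): written out in math, with w_i the (frequency)
# weights applied above, v_i the user-supplied per-observation variance, and the outcome
# index suppressed, the robust-case arrays stored above are
#     A = sum_i w_i x_i x_i'                       B = sum_i w_i y_i x_i
#     C = sum_i w_i y_i^2 x_i x_i'                 D[v,w,x] = sum_i w_i y_i x_iv x_iw x_ix
#     E[u,v,w,x] = sum_i w_i x_iu x_iv x_iw x_ix   sample_var = sum_i w_i v_i x_i x_i'
# (the nonrobust branch keeps the analogous lower-order moments). These are sufficient
# statistics: `aggregate` below rebuilds the point estimate as pinv(sum A) @ (sum B) and
# the HC0/HC1 "meat" as
#     C - 2 * sum_v D[v,:,:] * beta_v + sum_{u,v} E[u,v,:,:] * beta_u * beta_v
#       = sum_i w_i (y_i - x_i' beta)^2 x_i x_i'
# which then enters the usual sandwich A^{-1} (meat + sample_var) A^{-1}, so federated
# aggregation never needs the raw per-site data.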

sigma_inv = np.linalg.pinv(self.A)

var_i = sample_var + (y - np.matmul(X, param))**2

self._param = param

if (self.cov_type is None) or (self.cov_type == 'nonrobust'):
if y.ndim < 2:
self._var = correction * np.average(var_i, weights=freq_weight) * sigma_inv
@@ -1879,8 +1906,83 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1.")

self._param_var = np.array(self._var)

return self

@staticmethod
def aggregate(models: List[StatsModelsLinearRegression]):
"""
Aggregate multiple models into one.

Parameters
----------
models : list of StatsModelsLinearRegression
The models to aggregate

Returns
-------
agg_model : StatsModelsLinearRegression
The aggregated model
"""
if len(models) == 0:
raise ValueError("Must aggregate at least one model!")
cov_types = set([model.cov_type for model in models])
fit_intercepts = set([model.fit_intercept for model in models])
_n_outs = set([model._n_out for model in models])
assert len(cov_types) == 1, "All models must have the same cov_type!"
assert len(fit_intercepts) == 1, "All models must have the same fit_intercept!"
assert len(_n_outs) == 1, "All models must have the same number of outcomes!"
agg_model = StatsModelsLinearRegression(cov_type=models[0].cov_type, fit_intercept=models[0].fit_intercept)

agg_model._n_out = models[0]._n_out

A = np.sum([model.A for model in models], axis=0)
B = np.sum([model.B for model in models], axis=0)
C = np.sum([model.C for model in models], axis=0)
D = np.sum([model.D for model in models], axis=0)
E = np.sum([model.E for model in models], axis=0)

sample_var = np.sum([model.sample_var for model in models], axis=0)
n_obs = np.sum([model._n_obs for model in models], axis=0)

sigma_inv = np.linalg.pinv(A)
param = sigma_inv @ B
df = np.shape(param)[0]

agg_model._param = param if agg_model._n_out > 0 else param.squeeze(1)

if n_obs <= df:
warnings.warn("Number of observations <= than number of parameters. Using biased variance calculation!")
correction = 1
elif agg_model.cov_type == 'HC0':
correction = 1
else: # both HC1 and nonrobust use the same correction factor
correction = (n_obs / (n_obs - df))

if agg_model.cov_type in ['HC0', 'HC1']:
weighted_sigma = C - 2 * np.einsum('yvwx,vy->ywx', D, param) + \
np.einsum('uvwx,uy,vy->ywx', E, param, param) + sample_var
if agg_model._n_out == 0:
agg_model._var = correction * np.matmul(sigma_inv, np.matmul(weighted_sigma.squeeze(0), sigma_inv))
else:
agg_model._var = [correction * np.matmul(sigma_inv, np.matmul(ws, sigma_inv)) for ws in weighted_sigma]

else:
assert agg_model.cov_type == 'nonrobust' or agg_model.cov_type is None
sigma = C - 2 * np.einsum('yx,xy->y', D, param) + np.einsum('wx,wy,xy->y', E, param, param)
var_i = (sample_var + sigma) / n_obs
if agg_model._n_out == 0:
agg_model._var = correction * var_i * sigma_inv
else:
agg_model._var = [correction * var * sigma_inv for var in var_i]

agg_model._param_var = np.array(agg_model._var)

(agg_model.A, agg_model.B, agg_model.C, agg_model.D, agg_model.E,
agg_model.sample_var, agg_model._n_obs) = A, B, C, D, E, sample_var, n_obs

return agg_model
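
# Editorial sanity-check sketch (not part of the diff): because the pooled moments are
# just the sums of the per-partition moments, aggregating fits on two halves of a
# dataset should reproduce a single pooled fit. The toy data below is hypothetical.
import numpy as np
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=200)

pooled = StatsModelsLinearRegression().fit(X, y)
left = StatsModelsLinearRegression().fit(X[:100], y[:100])
right = StatsModelsLinearRegression().fit(X[100:], y[100:])
combined = StatsModelsLinearRegression.aggregate([left, right])

# point estimates agree exactly; the variances are likewise rebuilt from the summed moments
assert np.allclose(pooled.coef_, combined.coef_)
assert np.allclose(pooled.intercept_, combined.intercept_)
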


class StatsModelsRLM(_StatsModelsWrapper):
"""