diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py index e54c6a845..ab353e9fc 100644 --- a/econml/dml/causal_forest.py +++ b/econml/dml/causal_forest.py @@ -794,6 +794,7 @@ def tune(self, Y, T, *, X=None, W=None, est.inference = False scorer = RScorer(model_y=est.model_y, model_t=est.model_t, + discrete_outcome=est.discrete_outcome, discrete_treatment=est.discrete_treatment, categories=est.categories, cv=est.cv, mc_iters=est.mc_iters, mc_agg=est.mc_agg, random_state=est.random_state) diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 5e52f4cda..9e6806192 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -1300,3 +1300,22 @@ def test_treatment_names(self): expected_prefix = str(new_treatment_name[0]) if new_treatment_name is not None else t_name assert (est.cate_treatment_names(new_treatment_name) == [ expected_prefix + postfix for postfix in postfixes]) + + def test_causal_forest_tune_with_discrete_outcome_and_treatment(self): + np.random.seed(1234) + n = 1000 + treatment = np.repeat([0, 1], n // 2) + covariate = np.resize([0, 1], n) + outcome = ((treatment == 1) & (covariate == 1)).astype(int) + X = covariate.reshape(-1, 1) + Y = outcome + T = treatment + + est = CausalForestDML( + model_y=GradientBoostingClassifier(), + model_t=GradientBoostingClassifier(), + discrete_outcome=True, + discrete_treatment=True + ) + est.tune(Y=Y, T=T, X=X) + est.fit(Y=Y, T=T, X=X)