Skip to content
26 changes: 22 additions & 4 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,22 @@ def _check_sample_weight(sample_weight, X, dtype=None):
_LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
SKLEARN_INSTALLED = False
_LGBMModelBase = object
_LGBMClassifierBase = object
_LGBMRegressorBase = object

class _LGBMModelBase: # type: ignore
"""Dummy class for sklearn.base.BaseEstimator."""

pass

class _LGBMClassifierBase: # type: ignore
"""Dummy class for sklearn.base.ClassifierMixin."""

pass

class _LGBMRegressorBase: # type: ignore
"""Dummy class for sklearn.base.RegressorMixin."""

pass

_LGBMLabelEncoder = None
LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
Expand All @@ -118,11 +131,16 @@ def _check_sample_weight(sample_weight, X, dtype=None):
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False

delayed = None
Client = object
default_client = None
wait = None

class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""

pass

class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""

Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def feature_name_(self):
return self._Booster.feature_name()


class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""

def fit(self, X, y,
Expand All @@ -830,7 +830,7 @@ def fit(self, X, y,
+ _base_doc[_base_doc.find('eval_metric :'):])


class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier."""

def fit(self, X, y,
Expand Down