python-package/lightgbm/callback.py: 39 changes (20 additions & 19 deletions)
@@ -2,14 +2,15 @@
"""Callbacks library."""
import collections
from operator import gt, lt
+from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


class EarlyStopException(Exception):
"""Exception of early stopping."""

-    def __init__(self, best_iteration, best_score):
+    def __init__(self, best_iteration: int, best_score: float) -> None:
"""Create early stopping exception.

Parameters
@@ -35,7 +36,7 @@ def __init__(self, best_iteration, best_score):
"evaluation_result_list"])


-def _format_eval_result(value, show_stdv=True):
+def _format_eval_result(value: list, show_stdv: bool = True) -> str:
"""Format metric string."""
if len(value) == 4:
return '%s\'s %s: %g' % (value[0], value[1], value[2])
@@ -48,7 +49,7 @@ def _format_eval_result(value, show_stdv=True):
raise ValueError("Wrong metric value")
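Not part of the diff: a quick check of the 4-element branch above, assuming the (data_name, eval_name, result, is_higher_better) tuple layout used throughout this file.

# Expected output of the 4-element branch shown above.
assert _format_eval_result(('valid', 'l2', 0.25, False)) == "valid's l2: 0.25"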


-def print_evaluation(period=1, show_stdv=True):
+def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
"""Create a callback that prints the evaluation results.

Parameters
@@ -63,15 +64,15 @@ def print_evaluation(period=1, show_stdv=True):
callback : function
The callback that prints the evaluation results every ``period`` iteration(s).
"""
-    def _callback(env):
+    def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info('[%d]\t%s' % (env.iteration + 1, result))
-    _callback.order = 10
+    _callback.order = 10  # type: ignore
return _callback
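Not part of the diff: a minimal usage sketch for print_evaluation. It assumes the public lightgbm.train() and lightgbm.Dataset interfaces at the version this PR targets; the synthetic data is illustrative only.

import lightgbm as lgb
import numpy as np

X, y = np.random.rand(100, 5), np.random.rand(100)
train_set = lgb.Dataset(X, y)
valid_set = lgb.Dataset(X, y, reference=train_set)

lgb.train(
    {'objective': 'regression'},
    train_set,
    num_boost_round=10,
    valid_sets=[valid_set],
    callbacks=[lgb.print_evaluation(period=2)],  # log every 2nd iteration
)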


-def record_evaluation(eval_result):
+def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
"""Create a callback that records the evaluation history into ``eval_result``.

Parameters
@@ -88,21 +89,21 @@ def record_evaluation(eval_result):
raise TypeError('eval_result should be a dictionary')
eval_result.clear()

-    def _init(env):
+    def _init(env: CallbackEnv) -> None:
for data_name, eval_name, _, _ in env.evaluation_result_list:
eval_result.setdefault(data_name, collections.OrderedDict())
eval_result[data_name].setdefault(eval_name, [])

-    def _callback(env):
+    def _callback(env: CallbackEnv) -> None:
if not eval_result:
_init(env)
for data_name, eval_name, result, _ in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result)
-    _callback.order = 20
+    _callback.order = 20  # type: ignore
return _callback
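Not part of the diff: a usage sketch for record_evaluation under the same assumptions as the previous example; the dictionary is filled in place, keyed by validation-set name and then by metric name.

import lightgbm as lgb
import numpy as np

X, y = np.random.rand(100, 5), np.random.rand(100)
train_set = lgb.Dataset(X, y)
valid_set = lgb.Dataset(X, y, reference=train_set)

history = {}
lgb.train(
    {'objective': 'regression', 'metric': 'l2'},
    train_set,
    num_boost_round=10,
    valid_sets=[valid_set],
    valid_names=['valid'],
    callbacks=[lgb.record_evaluation(history)],
)
# history now has the shape {'valid': {'l2': [one value per iteration]}}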


-def reset_parameter(**kwargs):
+def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
"""Create a callback that resets the parameter after the first iteration.

.. note::
@@ -123,7 +124,7 @@ def reset_parameter(**kwargs):
callback : function
The callback that resets the parameter after the first iteration.
"""
-    def _callback(env):
+    def _callback(env: CallbackEnv) -> None:
new_parameters = {}
for key, value in kwargs.items():
if isinstance(value, list):
@@ -138,12 +139,12 @@ def _callback(env):
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)
-    _callback.before_iteration = True
-    _callback.order = 10
+    _callback.before_iteration = True  # type: ignore
+    _callback.order = 10  # type: ignore
return _callback
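Not part of the diff: a usage sketch for reset_parameter under the same assumptions; as the code above shows, a callable value receives the iteration index offset by begin_iteration.

import lightgbm as lgb
import numpy as np

X, y = np.random.rand(100, 5), np.random.rand(100)
train_set = lgb.Dataset(X, y)

lgb.train(
    {'objective': 'regression', 'learning_rate': 0.1},
    train_set,
    num_boost_round=10,
    # either a list with one value per boosting round, or a function of the iteration index
    callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.9 ** i))],
)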


-def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
"""Create a callback that activates early stopping.

Activates early stopping.
@@ -170,12 +171,12 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
"""
best_score = []
best_iter = []
-    best_score_list = []
+    best_score_list: list = []
cmp_op = []
enabled = [True]
first_metric = ['']

-    def _init(env):
+    def _init(env: CallbackEnv) -> None:
enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
if not enabled[0]:
@@ -200,7 +201,7 @@ def _init(env):
best_score.append(float('inf'))
cmp_op.append(lt)

-    def _final_iteration_check(env, eval_name_splitted, i):
+    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1:
if verbose:
_log_info('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
@@ -209,7 +210,7 @@ def _final_iteration_check(env, eval_name_splitted, i):
_log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
raise EarlyStopException(best_iter[i], best_score_list[i])

-    def _callback(env):
+    def _callback(env: CallbackEnv) -> None:
if not cmp_op:
_init(env)
if not enabled[0]:
@@ -236,5 +237,5 @@ def _callback(env):
_log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
raise EarlyStopException(best_iter[i], best_score_list[i])
_final_iteration_check(env, eval_name_splitted, i)
-    _callback.order = 30
+    _callback.order = 30  # type: ignore
return _callback
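Not part of the diff: a usage sketch for early_stopping under the same assumptions; the Booster's best_iteration is expected to be set by lightgbm.train once this callback raises EarlyStopException.

import lightgbm as lgb
import numpy as np

X, y = np.random.rand(200, 5), np.random.rand(200)
train_set = lgb.Dataset(X[:150], y[:150])
valid_set = lgb.Dataset(X[150:], y[150:], reference=train_set)

booster = lgb.train(
    {'objective': 'regression', 'metric': 'l2'},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)
print(booster.best_iteration)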