OptunaLearner

class falcon.tabular.learners.OptunaLearner(task: str, model_class: Optional[Type] = None, n_trials: Optional[int] = None, **kwargs: Any)

OptunaLearner selects the best hyperparameters for the given model using the Optuna framework.

__init__(task: str, model_class: Optional[Type] = None, n_trials: Optional[int] = None, **kwargs: Any) → None
Parameters
  • task (str) – ‘tabular_classification’ or ‘tabular_regression’

  • model_class (Optional[Type], optional) – the class of the model to train, by default None; if None, a HistGradientBoosting model is used

  • n_trials (Optional[int], optional) – number of optimization trials, minimum 20, by default None; if None, the number of trials is chosen dynamically based on the dataset size

fit(X: ndarray[Any, dtype[float32]], y: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) → None

Fits the model by choosing the best hyperparameters and training the final model using them. For classification tasks, the dataset will be balanced by upsampling the minority class(es).

Parameters
  • X (Float32Array) – features

  • y (Float32Array) – targets
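For classification, fit balances the dataset by upsampling the minority class(es). One common way to do this is random oversampling with replacement until every class matches the size of the largest class. The sketch below assumes that strategy and uses a hypothetical helper, upsample_minority; falcon's exact procedure may differ.

```python
import numpy as np


def upsample_minority(X: np.ndarray, y: np.ndarray, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: duplicate rows of smaller classes until all classes match the largest."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    indices = []
    for cls, count in zip(classes, counts):
        cls_idx = np.flatnonzero(y == cls)
        # Keep every original row, then draw extra rows with replacement to reach the target.
        extra = rng.choice(cls_idx, size=target - count, replace=True)
        indices.append(np.concatenate([cls_idx, extra]))
    # Shuffle so classes are not grouped together.
    order = rng.permutation(np.concatenate(indices))
    return X[order], y[order]


X = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 2], dtype=np.float32)
X_bal, y_bal = upsample_minority(X, y)
print(np.unique(y_bal, return_counts=True))  # each class now has 7 rows
```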

fit_pipe(X: ndarray[Any, dtype[float32]], y: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) → None

Equivalent to .fit(X, y)

Parameters
  • X (Float32Array) – features

  • y (Float32Array) – targets

forward(X: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) → Union[ndarray[Any, dtype[float32]], ndarray[Any, dtype[int64]]]

Equivalent to .predict(X)

Parameters
  • X (Float32Array) – features

Returns

predictions

Return type

Union[Float32Array, Int64Array]
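fit_pipe and forward expose the same behaviour as fit and predict under a second name. The hypothetical MajorityClassLearner below is a toy that always predicts the most frequent label; it is not falcon code and only illustrates that delegation pattern.

```python
import numpy as np
import numpy.typing as npt


class MajorityClassLearner:
    """Hypothetical toy learner mirroring the fit/fit_pipe and predict/forward aliasing."""

    def fit(self, X: npt.NDArray[np.float32], y: npt.NDArray[np.float32]) -> None:
        # Toy "model": remember the most frequent label.
        labels, counts = np.unique(y.astype(np.int64), return_counts=True)
        self.label_ = labels[counts.argmax()]

    def fit_pipe(self, X: npt.NDArray[np.float32], y: npt.NDArray[np.float32]) -> None:
        # Same behaviour as fit, under a second name.
        self.fit(X, y)

    def predict(self, X: npt.NDArray[np.float32]) -> npt.NDArray[np.int64]:
        return np.full(X.shape[0], self.label_, dtype=np.int64)

    def forward(self, X: npt.NDArray[np.float32]) -> npt.NDArray[np.int64]:
        # Same behaviour as predict, under a second name.
        return self.predict(X)


X = np.zeros((4, 2), dtype=np.float32)
y = np.array([1, 0, 1, 1], dtype=np.float32)
learner = MajorityClassLearner()
learner.fit_pipe(X, y)
preds = learner.forward(X)  # → [1 1 1 1]
```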

get_input_type() → Type
Returns

Float32Array

Return type

Type

get_output_type() → Type
Returns

Float32Array for regression, Int64Array for classification

Return type

Type
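The signatures spell out Float32Array and Int64Array as ndarray[Any, dtype[float32]] and ndarray[Any, dtype[int64]], which correspond to numpy.typing.NDArray aliases. The sketch below assumes those aliases and adds a hypothetical output_dtype helper that maps each task to the dtype documented for get_output_type; falcon's own type definitions may differ.

```python
from typing import Type

import numpy as np
import numpy.typing as npt

# Assumed aliases matching the ndarray[Any, dtype[...]] annotations; falcon's may differ.
Float32Array = npt.NDArray[np.float32]
Int64Array = npt.NDArray[np.int64]


def output_dtype(task: str) -> Type[np.generic]:
    """Hypothetical helper: prediction dtype for each supported task."""
    if task == "tabular_classification":
        return np.int64
    if task == "tabular_regression":
        return np.float32
    raise ValueError(f"unknown task: {task!r}")


raw = np.array([0.0, 1.0, 2.0])  # e.g. float64 output from an underlying model
preds = raw.astype(output_dtype("tabular_classification"))
print(preds.dtype)  # int64
```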

predict(X: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) → Union[ndarray[Any, dtype[float32]], ndarray[Any, dtype[int64]]]
Parameters
  • X (Float32Array) – features

Returns

predictions

Return type

Union[Float32Array, Int64Array]

to_onnx() → SerializedModelRepr

Serializes the underlying model to ONNX by calling its .to_onnx() method.

Return type

SerializedModelRepr