OptunaLearner
- class falcon.tabular.learners.OptunaLearner(task: str, model_class: Optional[Type] = None, n_trials: Optional[int] = None, **kwargs: Any)
OptunaLearner selects the best hyperparameters for the given model using the Optuna framework.
- __init__(task: str, model_class: Optional[Type] = None, n_trials: Optional[int] = None, **kwargs: Any) None
- Parameters
task (str) – ‘tabular_classification’ or ‘tabular_regression’
model_class (Optional[Type], optional) – the class of the model to train, by default None; if None, a HistGradientBoosting model is used
n_trials (Optional[int], optional) – number of optimization trials, minimum 20, by default None; if None, the number of trials is chosen dynamically based on the dataset size
- fit(X: ndarray[Any, dtype[float32]], y: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) None
Fits the model by choosing the best hyperparameters and training the final model using them. For classification tasks, the dataset will be balanced by upsampling the minority class(es).
- Parameters
X (Float32Array) – features
y (Float32Array) – targets
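The balancing step described for fit can be illustrated with a minimal NumPy sketch. It assumes that upsampling means resampling each minority class with replacement until it matches the size of the largest class. The helper name `upsample_minority` and the fixed seed are hypothetical, and the library's actual implementation may differ.

```python
import numpy as np


def upsample_minority(X, y, seed=0):
    """Hypothetical helper: resample each class with replacement up to the majority size."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    idx_parts = []
    for cls, count in zip(classes, counts):
        cls_idx = np.flatnonzero(y == cls)
        # Keep every original row, then draw the missing rows with replacement.
        extra = rng.choice(cls_idx, size=target - count, replace=True)
        idx_parts.append(np.concatenate([cls_idx, extra]))
    idx = np.concatenate(idx_parts)
    return X[idx], y[idx]


# Toy data: four samples of class 0 and two samples of class 1.
X = np.arange(12, dtype=np.float32).reshape(6, 2)
y = np.array([0, 0, 0, 0, 1, 1], dtype=np.float32)
X_bal, y_bal = upsample_minority(X, y)
```

After the call, both classes contribute the same number of rows, and the original rows are all retained.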
- fit_pipe(X: ndarray[Any, dtype[float32]], y: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) None
Equivalent to .fit(X, y)
- Parameters
X (Float32Array) – features
y (Float32Array) – targets
- forward(X: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) Union[ndarray[Any, dtype[float32]], ndarray[Any, dtype[int64]]]
Equivalent to .predict(X)
- Parameters
X (Float32Array) – features
- Returns
predictions
- Return type
Union[Float32Array, Int64Array]
- get_input_type() Type
- Returns
Float32Array
- Return type
Type
- get_output_type() Type
- Returns
Float32Array for regression, Int64Array for classification
- Return type
Type
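The type conventions reported by get_input_type() and get_output_type() can be mirrored with plain NumPy dtypes. In the sketch below, `prepare_features` and `expected_output_dtype` are hypothetical helpers and not part of the library. They assume that Float32Array and Int64Array correspond to NumPy arrays of dtype float32 and int64, as the method signatures suggest.

```python
import numpy as np


def prepare_features(X):
    """Hypothetical helper: convert array-like features to a float32 array."""
    return np.asarray(X, dtype=np.float32)


def expected_output_dtype(task):
    """Hypothetical helper: prediction dtype for each documented task string."""
    if task == "tabular_regression":
        return np.dtype(np.float32)
    if task == "tabular_classification":
        return np.dtype(np.int64)
    raise ValueError(f"unknown task: {task!r}")


# Integer input is cast to float32 before it would be passed to fit or predict.
X = prepare_features([[1, 2], [3, 4]])
```

Casting up front avoids passing float64 arrays, NumPy's default floating-point dtype, to methods annotated for float32 input.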
- predict(X: ndarray[Any, dtype[float32]], *args: Any, **kwargs: Any) Union[ndarray[Any, dtype[float32]], ndarray[Any, dtype[int64]]]
- Parameters
X (Float32Array) – features
- Returns
predictions
- Return type
Union[Float32Array, Int64Array]
- to_onnx() SerializedModelRepr
Serializes the underlying model to ONNX by calling its .to_onnx() method.
- Return type
SerializedModelRepr