Working with Scikit-Learn
This guide shows how to use tabben together with scikit-learn.
Load the train and test datasets
We’ll first get the train and test splits for the musk
dataset (completely unrelated to Elon Musk).
from tabben.datasets import OpenTabularDataset
train = OpenTabularDataset('./temp', 'musk') # train split by default
test = OpenTabularDataset('./temp', 'musk', split='test') # should only be used ONCE!
print(f'The {train.name} dataset is a {train.task} task with {train.num_classes} classes.')
X_fulltrain, y_fulltrain = train.numpy()
To tune hyperparameters, we need our own validation split (never the test set). We’ll do an 80-20 split, stratified on the class label so that both splits keep the same class proportions.
from sklearn.model_selection import train_test_split
X_train, X_valid, y_train, y_valid = train_test_split(
    X_fulltrain, y_fulltrain,
    train_size=0.8,
    stratify=y_fulltrain,
)
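To check what stratify actually does, here is a small sketch that uses a synthetic, imbalanced label array in place of the musk labels (the data and the random_state value are assumptions for illustration only). With stratify set, both splits keep the class ratio of the full label array.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for X_fulltrain / y_fulltrain (NOT the musk data):
# 150 samples of class 0 and 50 of class 1, i.e. 25% class 1.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = np.array([0] * 150 + [1] * 50)

X_tr, X_va, y_tr, y_va = train_test_split(
    X, y,
    train_size=0.8,
    stratify=y,
    random_state=0,  # fixed seed so the split is reproducible
)

# Because of stratify=y, both splits keep the 25% class-1 fraction.
print('train class-1 fraction:', y_tr.mean())
print('valid class-1 fraction:', y_va.mean())
```

Passing a fixed random_state, which the snippet above does, is optional but makes the validation scores reproducible across runs.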
Create and train a model
Next, we’ll create a k-nearest neighbors (k-NN) classifier and fit it on our training split.
from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier()
model.fit(X_train, y_train)
Then we evaluate it on the validation split. For classifiers, the score method returns the mean accuracy.
model.score(X_valid, y_valid)
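The validation split exists so that hyperparameters can be chosen without touching the test set. A minimal sketch of that idea, tuning n_neighbors with a plain loop, is below; it uses make_classification as a synthetic stand-in for the musk data, and the candidate values of k are arbitrary choices for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for the full training data (an assumption, not musk).
X_full, y_full = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_full, y_full, train_size=0.8, stratify=y_full, random_state=0
)

# Fit one model per candidate k and record its validation accuracy.
scores = {}
for k in [1, 3, 5, 7, 9, 15]:
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(X_train, y_train)
    scores[k] = model.score(X_valid, y_valid)

best_k = max(scores, key=scores.get)
print(scores)
print('best n_neighbors:', best_k)

# One common option: refit on all training data (train + validation)
# with the chosen k, then evaluate on the test set exactly once.
final_model = KNeighborsClassifier(n_neighbors=best_k).fit(X_full, y_full)
```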
In a larger data processing pipeline
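We may also want to wrap the model in a scikit-learn pipeline, for example to add preprocessing such as feature normalization or one-hot encoding, or to explore the effect of turning continuous attributes into binary ones. The pipeline below does the latter: StandardScaler(with_std=False) subtracts each feature’s training mean without rescaling, and Binarizer (default threshold 0.0) then maps positive values to 1 and everything else to 0, so each feature records whether it lies above its mean.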
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Binarizer, StandardScaler
pipeline = make_pipeline(
    StandardScaler(with_std=False),  # center each feature (no scaling)
    Binarizer(),                     # > 0 becomes 1, <= 0 becomes 0
    KNeighborsClassifier(),
)
pipeline.fit(X_train, y_train)
pipeline.score(X_valid, y_valid)
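To see what the two preprocessing steps do on their own, here is a small sketch that applies them to a made-up 4×2 array (the values are purely illustrative). The column means are 2.5 and 25.0, so every entry above its column mean becomes 1.0 and every other entry becomes 0.0.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Binarizer, StandardScaler

# Made-up toy data; column means are 2.5 and 25.0.
X_toy = np.array([
    [1.0, 10.0],
    [2.0, 40.0],
    [3.0, 20.0],
    [4.0, 30.0],
])

# Same preprocessing as the pipeline above, without the classifier.
preprocess = make_pipeline(StandardScaler(with_std=False), Binarizer())
X_bin = preprocess.fit_transform(X_toy)

# Each entry is 1.0 if it was above its column mean, else 0.0.
print(X_bin)
```

Swapping Binarizer() out of the full pipeline and comparing validation scores is one straightforward way to measure the effect of binarization.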
This code was last run with the following package versions (the rendered webpage omits cell output, so see the notebook for the actual versions):
from importlib.metadata import version
packages = ['scikit-learn', 'tabben']
for pkg in packages:
    print(f'{pkg}: {version(pkg)}')