ml_algo 13.3.7
Machine learning algorithms written in native Dart
Machine learning algorithms with Dart #
What is ml_algo for? #
The main purpose of the library is to give developers who are interested in both the Dart language and data science a native Dart implementation of machine learning algorithms. The library targets the Dart VM, so for the smoothest experience, please do not use it in a browser.
The library's content #
- Model selection
  - CrossValidator. A factory that creates cross-validator instances. Cross-validation lets researchers compare different hyperparameter settings of machine learning algorithms by assessing prediction quality on different parts of a dataset.
- Classification algorithms
  - LogisticRegressor. A class that performs linear binary classification. This kind of classifier works best when the classes are (approximately) linearly separable.
  - SoftmaxRegressor. A class that performs linear multiclass classification. Like LogisticRegressor, it works best when the classes are (approximately) linearly separable.
  - DecisionTreeClassifier. A class that performs classification using decision trees. It can handle data with non-linear patterns.
  - KnnClassifier. A class that performs classification using the k nearest neighbours algorithm. It makes a prediction based on the k observations closest to the given one.
- Regression algorithms
  - LinearRegressor. A class that finds a linear pattern in training data and predicts real numbers based on that pattern.
  - KnnRegressor. A class that makes a prediction for each new observation based on the k closest observations from the training data. It can capture non-linear patterns in the data.
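The following minimal pure-Dart sketch makes the k nearest neighbours idea concrete. It ranks training observations by Euclidean distance to the query and takes a majority vote among the first k. It does not use ml_algo and is not the library's implementation. The helper names `euclidean` and `knnClassify`, the Euclidean metric, and the tie-breaking rule are all assumptions made for illustration.

```dart
import 'dart:math' as math;

/// Euclidean distance between two feature vectors of equal length.
double euclidean(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    final d = a[i] - b[i];
    sum += d * d;
  }
  return math.sqrt(sum);
}

/// Hypothetical helper: predicts a label for [query] by majority vote among
/// the [k] training rows closest to it. Ties go to the label that is first
/// encountered in distance order.
int knnClassify(List<List<double>> features, List<int> labels,
    List<double> query, int k) {
  // Sort training row indices by distance to the query, nearest first.
  final indices = List<int>.generate(features.length, (i) => i)
    ..sort((i, j) => euclidean(features[i], query)
        .compareTo(euclidean(features[j], query)));

  // Count each label among the k nearest rows. A Dart map literal keeps
  // keys in first-insertion order.
  final votes = <int, int>{};
  for (final i in indices.take(k)) {
    votes.update(labels[i], (count) => count + 1, ifAbsent: () => 1);
  }

  // Pick the label with the most votes.
  var bestLabel = votes.keys.first;
  var bestCount = 0;
  for (final entry in votes.entries) {
    if (entry.value > bestCount) {
      bestLabel = entry.key;
      bestCount = entry.value;
    }
  }
  return bestLabel;
}

void main() {
  final features = [
    [1.0, 1.0], [1.2, 0.8], [0.9, 1.1], // cluster labelled 0
    [5.0, 5.0], [5.2, 4.9], [4.8, 5.1], // cluster labelled 1
  ];
  final labels = [0, 0, 0, 1, 1, 1];
  print(knnClassify(features, labels, [1.1, 1.0], 3)); // prints 0
  print(knnClassify(features, labels, [4.9, 5.0], 3)); // prints 1
}
```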
Examples #
Logistic regression #
Let's classify records from a well-known dataset, the Pima Indians Diabetes Database, using a logistic regressor.
Import all the necessary packages. First, make sure the ml_preprocessing and ml_dataframe packages are in your dependencies:
dependencies:
  ml_dataframe: ^0.0.11
  ml_preprocessing: ^5.0.1
We need these packages to parse the raw data before using it further. For more details, please visit the ml_preprocessing repository page.
import 'dart:async';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';
Download the dataset from the Pima Indians Diabetes Database and read it (of course, you should provide the correct path to your downloaded file):
final samples = await fromCsv('datasets/pima_indians_diabetes_database.csv', headerExists: true);
The data in this file consists of 768 records with 8 features each. The 9th column is the label column: it contains either 0 or 1 in each row. This column is our target, meaning we should predict a class label for each observation. The column's name is class variable (0 or 1). Let's store it:
final targetColumnName = 'class variable (0 or 1)';
Then we create an instance of the CrossValidator class to assess the hyperparameters of our model. We pass the training data (our samples variable), a list of target column names (in our case, just the name stored in the targetColumnName variable) and the number of folds to the CrossValidator.KFold constructor:
final validator = CrossValidator.KFold(samples, [targetColumnName], numberOfFolds: 5);
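The sketch below gives a rough picture of what k-fold cross-validation does. It splits row indices into contiguous folds, uses each fold once as the test set with the remaining rows as the training set, and averages the per-fold scores. This is a standalone, assumption-based sketch rather than CrossValidator's code. The names `kFoldIndices` and `crossValidate` are hypothetical, and the sketch does not shuffle rows; a real implementation may or may not shuffle.

```dart
/// Hypothetical helper: splits row indices 0..numberOfRows-1 into
/// [numberOfFolds] contiguous, non-overlapping folds. The first
/// (numberOfRows % numberOfFolds) folds get one extra row.
List<List<int>> kFoldIndices(int numberOfRows, int numberOfFolds) {
  final baseSize = numberOfRows ~/ numberOfFolds;
  final remainder = numberOfRows % numberOfFolds;
  final folds = <List<int>>[];
  var start = 0;
  for (var fold = 0; fold < numberOfFolds; fold++) {
    final size = baseSize + (fold < remainder ? 1 : 0);
    folds.add(List<int>.generate(size, (i) => start + i));
    start += size;
  }
  return folds;
}

/// Hypothetical helper: each fold serves once as the test set while all
/// other rows form the training set; returns the mean of the fold scores.
double crossValidate(int numberOfRows, int numberOfFolds,
    double Function(List<int> trainRows, List<int> testRows) scoreFold) {
  final folds = kFoldIndices(numberOfRows, numberOfFolds);
  var total = 0.0;
  for (final testRows in folds) {
    final testSet = testRows.toSet();
    final trainRows = [
      for (var i = 0; i < numberOfRows; i++)
        if (!testSet.contains(i)) i
    ];
    total += scoreFold(trainRows, testRows);
  }
  return total / folds.length;
}

void main() {
  // 768 rows (the size of the Pima dataset) split into 5 folds.
  print(kFoldIndices(768, 5).map((fold) => fold.length).toList());
  // prints [154, 154, 154, 153, 153]

  // A dummy scorer that reports the training-set size of each fold.
  print(crossValidate(10, 5, (trainRows, testRows) => trainRows.length.toDouble()));
  // prints 8.0
}
```

In the real call above, the callback passed to `evaluate` plays the role of `scoreFold`: it builds a model on the training part of each fold, and the chosen metric scores it on the held-out part.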
Everything is set, so we can run our classification.
Evaluate the model using the accuracy metric:
final accuracy = validator.evaluate((samples, targetNames) =>
    LogisticRegressor(
      samples,
      targetNames[0], // remember, we provided a list of just a single name
      optimizerType: LinearOptimizerType.gradient,
      initialLearningRate: .8,
      iterationsLimit: 500,
      batchSize: samples.rows.length,
      fitIntercept: true,
      interceptScale: .1,
      learningRateType: LearningRateType.constant
    ), MetricType.accuracy);
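The parameters passed to LogisticRegressor above map onto plain gradient-based training. The following pure-Dart sketch assumes one plausible reading of them:

- full-batch updates (batchSize equal to the number of rows);
- a constant learning rate;
- a fixed iteration limit;
- an intercept implemented as an extra constant column filled with interceptScale.

It is a hypothetical illustration, not ml_algo's optimizer. `fitLogistic`, `predictLogistic`, `withIntercept` and `dot` are made-up names.

```dart
import 'dart:math' as math;

/// Logistic (sigmoid) function mapping a real score to a probability.
double sigmoid(double z) => 1 / (1 + math.exp(-z));

double dot(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    sum += a[i] * b[i];
  }
  return sum;
}

/// Assumption: the intercept is a constant first column equal to
/// [interceptScale], so the first coefficient acts as the bias.
List<double> withIntercept(
        List<double> row, bool fitIntercept, double interceptScale) =>
    [if (fitIntercept) interceptScale, ...row];

/// Hypothetical full-batch gradient ascent on the mean log-likelihood.
List<double> fitLogistic(
  List<List<double>> x,
  List<double> y, {
  double learningRate = 0.8,
  int iterationsLimit = 500,
  bool fitIntercept = true,
  double interceptScale = 0.1,
}) {
  final rows = [
    for (final row in x) withIntercept(row, fitIntercept, interceptScale)
  ];
  final coefficients = List<double>.filled(rows.first.length, 0.0);
  for (var iteration = 0; iteration < iterationsLimit; iteration++) {
    // Every row contributes to every step, i.e. batchSize == rows.length.
    final gradient = List<double>.filled(coefficients.length, 0.0);
    for (var i = 0; i < rows.length; i++) {
      final error = y[i] - sigmoid(dot(coefficients, rows[i]));
      for (var j = 0; j < gradient.length; j++) {
        gradient[j] += error * rows[i][j];
      }
    }
    // Constant learning rate: the step size is the same on every iteration.
    for (var j = 0; j < coefficients.length; j++) {
      coefficients[j] += learningRate * gradient[j] / rows.length;
    }
  }
  return coefficients;
}

/// Predicts class 1 when the estimated probability is at least 0.5.
int predictLogistic(List<double> coefficients, List<double> row,
        {bool fitIntercept = true, double interceptScale = 0.1}) =>
    sigmoid(dot(coefficients, withIntercept(row, fitIntercept, interceptScale))) >= 0.5
        ? 1
        : 0;

void main() {
  // Toy 1-D data: negative values are class 0, positive values are class 1.
  final x = [[-2.0], [-1.0], [-0.5], [0.5], [1.0], [2.0]];
  final y = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
  final coefficients = fitLogistic(x, y);
  print(predictLogistic(coefficients, [-1.5])); // prints 0
  print(predictLogistic(coefficients, [1.5])); // prints 1
}
```

Ascending the mean log-likelihood is equivalent to descending the mean cross-entropy loss, so this update has the same direction as ordinary gradient descent on that loss.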
Let's print the score:
print('accuracy on classification: ${accuracy.toStringAsFixed(2)}');
We will see something like this:
accuracy on classification: 0.77
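Accuracy is the share of observations whose predicted class matches the true class. Here is a minimal sketch of the metric on a single test fold. It uses a hypothetical `accuracy` function, not ml_algo's API.

```dart
/// Hypothetical helper: fraction of predictions that equal the true labels.
double accuracy(List<int> actual, List<int> predicted) {
  if (actual.isEmpty || actual.length != predicted.length) {
    throw ArgumentError('label lists must be non-empty and of equal length');
  }
  var correct = 0;
  for (var i = 0; i < actual.length; i++) {
    if (actual[i] == predicted[i]) correct++;
  }
  return correct / actual.length;
}

void main() {
  final actual = [1, 0, 1, 1, 0];
  final predicted = [1, 0, 0, 1, 0];
  // 4 of 5 predictions are correct.
  print('accuracy: ${accuracy(actual, predicted).toStringAsFixed(2)}'); // prints accuracy: 0.80
}
```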
All the code above together:
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';

Future<void> main() async {
  final samples = await fromCsv('datasets/pima_indians_diabetes_database.csv', headerExists: true);
  final targetColumnName = 'class variable (0 or 1)';
  final validator = CrossValidator.KFold(samples, [targetColumnName], numberOfFolds: 5);
  final accuracy = validator.evaluate((samples, targetNames) =>
      LogisticRegressor(
        samples,
        targetNames[0], // remember, we provide a list of just a single name
        optimizerType: LinearOptimizerType.gradient,
        initialLearningRate: .8,
        iterationsLimit: 500,
        batchSize: samples.rows.length, // use every row of the current training fold
        fitIntercept: true,
        interceptScale: .1,
        learningRateType: LearningRateType.constant
      ), MetricType.accuracy);

  print('accuracy on classification: ${accuracy.toStringAsFixed(2)}');
}
K nearest neighbour regression #
Let's make some predictions with a well-known non-parametric regression algorithm, k nearest neighbours. We'll use a classic dataset: Boston Housing.
As usual, import all the necessary packages:
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';
Then, download and read the data:
final samples = await fromCsv('lib/_datasets/housing.csv',
headerExists: false,
fieldDelimiter: ' ',
);
As you can see, the dataset is headless, which means there is no descriptive header line at the beginning of the file. So we can use the autogenerated header to specify which column holds our target values:
print(samples.header);
It will output the following:
(col_0, col_1, col_2, col_3, col_4, col_5, col_6, col_7, col_8, col_9, col_10, col_11, col_12, col_13)
Our target is col_13. Let's store it:
final targetColumnName = 'col_13';
Let's create a cross-validator instance:
final validator = CrossValidator.KFold(samples, [targetColumnName], numberOfFolds: 5);
Let the k parameter be equal to 4.
Assess a KNN regressor with the chosen k value using the MAPE metric:
final error = validator.evaluate((samples, targetNames) =>
KnnRegressor(samples, targetNames[0], 4), MetricType.mape);
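Conceptually, the KnnRegressor call above predicts a value by looking at the k = 4 closest training observations. Below is a minimal pure-Dart sketch of that idea. It assumes Euclidean distance and an unweighted mean of the neighbours' targets; ml_algo's actual distance metric and weighting may differ. `knnRegress` and `euclidean` are hypothetical names.

```dart
import 'dart:math' as math;

/// Euclidean distance between two feature vectors of equal length.
double euclidean(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    final d = a[i] - b[i];
    sum += d * d;
  }
  return math.sqrt(sum);
}

/// Hypothetical helper: predicts a real value for [query] as the plain mean
/// of the targets of its [k] nearest training rows.
double knnRegress(List<List<double>> features, List<double> targets,
    List<double> query, int k) {
  // Sort training row indices by distance to the query, nearest first.
  final indices = List<int>.generate(features.length, (i) => i)
    ..sort((i, j) => euclidean(features[i], query)
        .compareTo(euclidean(features[j], query)));
  final nearest = indices.take(k);
  final sum = nearest.map((i) => targets[i]).reduce((a, b) => a + b);
  return sum / nearest.length;
}

void main() {
  final features = [[1.0], [2.0], [3.0], [4.0], [10.0]];
  final targets = [10.0, 20.0, 30.0, 40.0, 100.0];
  // With k = 4, the nearest points to 2.5 are 2.0, 3.0, 1.0 and 4.0.
  print(knnRegress(features, targets, [2.5], 4)); // prints 25.0
}
```

Because the prediction is a local average, it can follow non-linear patterns that a single straight line would miss.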
Let's print the error:
print('MAPE error on k-fold validation: ${error.toStringAsFixed(2)}%'); // it yields approx. 6.18
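MAPE (mean absolute percentage error) averages the absolute relative error of each prediction: MAPE = 100 / n * sum(|y_i - y_hat_i| / |y_i|). The sketch below computes it in percent, assuming the score is reported that way, as the % suffix in the print statement above suggests. `mape` is a hypothetical helper, not ml_algo's implementation.

```dart
/// Hypothetical helper: mean absolute percentage error, in percent.
double mape(List<double> actual, List<double> predicted) {
  if (actual.isEmpty || actual.length != predicted.length) {
    throw ArgumentError('lists must be non-empty and of equal length');
  }
  var sum = 0.0;
  for (var i = 0; i < actual.length; i++) {
    // MAPE is undefined when an actual value is zero.
    if (actual[i] == 0) {
      throw ArgumentError('MAPE is undefined for a zero actual value');
    }
    sum += ((actual[i] - predicted[i]) / actual[i]).abs();
  }
  return 100 * sum / actual.length;
}

void main() {
  final actual = [100.0, 200.0];
  final predicted = [90.0, 220.0];
  // Each prediction is off by 10%, so the mean is 10%.
  print('MAPE: ${mape(actual, predicted).toStringAsFixed(2)}%'); // prints MAPE: 10.00%
}
```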
Contacts #
If you have questions, feel free to write me on