]> AND Private Git Repository - predictops.git/blob - predictops/learn/learning.py
Logo AND Algorithmique Numérique Distribuée

Private GIT Repository
Reducing the computation time and adding holidays features
[predictops.git] / predictops / learn / learning.py
1 from configparser import ConfigParser
2 from math import sqrt
3 from sklearn.metrics import mean_squared_error, mean_absolute_error
4 from sklearn.model_selection import train_test_split
5
6 import xgboost
7
class Learning:
    """Train a regressor described by a configuration file and report its test error.

    All the work happens in the constructor: the configuration is read, the
    data is split into train / validation / test subsets, and, when the
    configured method is ``xgboost``, a model is fitted and its RMSE and MAE
    on the test subset are printed.
    """

    def __init__(self, config_file=None, X=None, y=None):
        """Read the configuration, split the data, then fit and evaluate.

        Args:
            config_file: Path of an INI file readable by ``ConfigParser``.
                The keys read here are ``method`` in section ``[MODEL]`` and
                ``learning_rate``, ``max_depth``, ``random_state``,
                ``n_estimators`` and ``n_jobs`` in section
                ``[HYPERPARAMETERS]``.
            X: Feature table. It must support ``.copy()``, column assignment
                and ``.drop(..., axis=1)`` (presumably a pandas DataFrame;
                confirm against the caller). It is not modified.
            y: Target values, assigned as a temporary ``'cible'`` column next
                to the features so that both are split together.
        """
        self._config = ConfigParser()
        self._config.read(config_file)

        # Work on a copy: adding the temporary target column directly to X
        # would silently alter the caller's feature table.
        df = X.copy()
        df['cible'] = y

        # First hold out 20% for testing, then 20% of the remainder for
        # validation (used below as the early-stopping evaluation set).
        train_val_set, test_set = train_test_split(df, test_size=0.2, random_state=42)
        train_set, val_set = train_test_split(train_val_set, test_size=0.2, random_state=42)

        X_test = test_set.drop('cible', axis=1)
        y_test = test_set['cible'].copy()

        X_train = train_set.drop('cible', axis=1)
        y_train = train_set['cible'].copy()
        X_val = val_set.drop('cible', axis=1)
        y_val = val_set['cible'].copy()

        # Only the 'xgboost' method is handled; any other value trains nothing.
        if self._config['MODEL']['method'] == 'xgboost':
            hyperparameters = self._config['HYPERPARAMETERS']

            xgb_reg = xgboost.XGBRegressor(
                learning_rate=hyperparameters.getfloat('learning_rate'),
                max_depth=hyperparameters.getint('max_depth'),
                random_state=hyperparameters.getint('random_state'),
                n_estimators=hyperparameters.getint('n_estimators'),
                n_jobs=hyperparameters.getint('n_jobs'),
                objective='count:poisson')

            # NOTE(review): recent xgboost releases expect
            # early_stopping_rounds on the estimator constructor rather than
            # on fit(); confirm against the installed version before upgrading.
            xgb_reg.fit(X_train, y_train,
                        eval_set=[(X_val, y_val)],
                        early_stopping_rounds=10)

            y_test_pred = xgb_reg.predict(X_test)
            # scikit-learn metrics take (y_true, y_pred); both metrics are
            # symmetric, so the printed RMSE and MAE are unchanged.
            print(sqrt(mean_squared_error(y_test, y_test_pred)),
                  mean_absolute_error(y_test, y_test_pred))