AND Private Git Repository - predictops.git/blobdiff - predictops/learn/learning.py
Logo AND Algorithmique Numérique Distribuée

Private GIT Repository
Reducing the computation time and adding holidays features
[predictops.git] / predictops / learn / learning.py
index 416450047e8cddda860da6be9296e11c17e0004c..9a5860afaed8657890140c6dbffa79073bbe6787 100644 (file)
@@ -15,8 +15,6 @@ class Learning:
         df = X
         df['cible'] = y
 
         df = X
         df['cible'] = y
 
-        print(df.head())
-
         train_val_set, test_set = train_test_split(df, test_size = 0.2, random_state = 42)
         train_set, val_set = train_test_split(train_val_set, test_size = 0.2, random_state = 42)
 
         train_val_set, test_set = train_test_split(df, test_size = 0.2, random_state = 42)
         train_set, val_set = train_test_split(train_val_set, test_size = 0.2, random_state = 42)
 
@@ -30,12 +28,13 @@ class Learning:
 
 
         if self._config['MODEL']['method'] == 'xgboost':
 
 
         if self._config['MODEL']['method'] == 'xgboost':
-            xgb_reg = xgboost.XGBRegressor(learning_rate = 0.01,
-                                                   max_depth = 10,
-                                                   random_state=42,
-                                                   n_estimators = 173,
-                                                   n_jobs=-1,
-                                                   objective = 'count:poisson')
+
+            xgb_reg = xgboost.XGBRegressor(learning_rate = self._config['HYPERPARAMETERS'].getfloat('learning_rate'),
+                                           max_depth     = self._config['HYPERPARAMETERS'].getint('max_depth'),
+                                           random_state  = self._config['HYPERPARAMETERS'].getint('random_state'),
+                                           n_estimators  = self._config['HYPERPARAMETERS'].getint('n_estimators'),
+                                           n_jobs        = self._config['HYPERPARAMETERS'].getint('n_jobs'),
+                                           objective     = 'count:poisson')
 
             xgb_reg.fit(X_train, y_train,
                         eval_set=[(X_val, y_val)],
 
             xgb_reg.fit(X_train, y_train,
                         eval_set=[(X_val, y_val)],