// FreeLing 4.0
00001 00002 // 00003 // Omlet - Open Machine Learning Enhanced Toolkit 00004 // 00005 // Copyright (C) 2014 TALP Research Center 00006 // Universitat Politecnica de Catalunya 00007 // 00008 // This file is part of the Omlet library 00009 // 00010 // The Omlet library is free software; you can redistribute it 00011 // and/or modify it under the terms of the GNU Affero General Public 00012 // License as published by the Free Software Foundation; either 00013 // version 3 of the License, or (at your option) any later version. 00014 // 00015 // This library is distributed in the hope that it will be useful, 00016 // but WITHOUT ANY WARRANTY; without even the implied warranty of 00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00018 // Affero General Public License for more details. 00019 // 00020 // You should have received a copy of the GNU Affero General Public 00021 // License along with this library; if not, write to the Free Software 00022 // Foundation, Inc., 51 Franklin St, 5th Floor, Boston, MA 02110-1301 USA 00023 // 00024 // contact: Lluis Padro (padro@lsi.upc.es) 00025 // TALP Research Center 00026 // despatx Omega.S112 - Campus Nord UPC 00027 // 08034 Barcelona. 
SPAIN 00028 // 00030 00031 // 00032 // Author: Xavier Carreras 00033 // 00034 00035 #ifndef _WEAKRULE 00036 #define _WEAKRULE 00037 00038 #include <iostream> 00039 #include <string> 00040 #include <vector> 00041 #include <map> 00042 #include <set> 00043 #include "freeling/tree.h" 00044 #include "freeling/safe_map.h" 00045 #include "freeling/omlet/dataset.h" 00046 00047 namespace freeling { 00048 00054 00055 class wr_params { 00056 public: 00057 int nlabels; 00058 double epsilon; 00059 00061 wr_params (int nl, double e); 00062 }; 00063 00068 00069 class weak_rule { 00070 00071 public: 00073 virtual ~weak_rule() {}; 00074 00077 virtual void classify(const example &i,double pred[]) = 0; 00078 00080 virtual void read_from_stream(std::wistream *is) = 0; 00081 virtual void write_to_stream(std::wostream *os) = 0; 00082 00084 virtual void learn(const dataset &ds, double &Z) = 0; 00085 00089 virtual double Zcalculus(const dataset &ds) const; 00090 }; 00091 00092 00093 00094 00100 00101 class wr_factory { 00102 00103 public: 00104 typedef weak_rule* (*WR_constructor)(wr_params*); 00105 static void initialize(); 00106 static bool register_weak_rule_type(const std::wstring &type, WR_constructor builder); 00107 static bool unregister_weak_rule_type(const std::wstring &type); 00108 static weak_rule* create_weak_rule(const std::wstring &type, wr_params *p); 00109 static weak_rule* create_weak_rule(const std::wstring &type, int nlabels); 00110 00111 private: 00112 // store weakrule types registered by user apps 00113 static safe_map<std::wstring, WR_constructor> wr_types; 00114 00115 }; 00116 00117 00122 00123 class mlDTree_params : public wr_params { 00124 public: 00126 int max_depth; 00127 00129 mlDTree_params (int nl, double e, int mxd); 00130 }; 00131 00135 00136 class dt_node { 00137 friend class mlDTree; 00138 //protected: 00139 public: 00140 int feature; // 0 when leaf 00141 std::vector<double> predictions; // empty when not leaf (when leaf, array of predictions, one for 
each class) 00142 00143 public: 00144 // empty constructor 00145 dt_node(); 00147 dt_node(int f); 00149 dt_node(int nl, double pr[]); 00151 dt_node(const dt_node &n); 00152 }; 00153 00158 00159 class mlDTree : public weak_rule { 00160 00161 private: 00162 // learning parameters for the specific type of weak rule 00163 mlDTree_params params; 00164 00165 // decision tree itself 00166 tree<dt_node> rule; 00167 // learning auxiliary list. 00168 std::set<int> used_features; 00169 00171 void classify (const example &i, double pred[], tree<dt_node>::iterator t); 00172 00174 void write_to_stream(tree<dt_node>::iterator t, std::wostream *os); 00175 tree<dt_node> read_dt(std::wistream *is); 00176 00178 tree<dt_node> learn (const dataset &ds, double &Z, int depth); 00179 00180 bool stopping_criterion(const dataset &ds, int depth); 00182 int best_feature(const dataset &ds, double *W); 00184 void Cprediction(int v, double *W, double result[]); 00188 double Zcalculus(double *W, int ndim); 00189 00191 mlDTree(const mlDTree &wr0); 00192 00193 public: 00194 00195 // Constructor 00196 mlDTree(mlDTree_params *p); 00197 00201 void classify(const example &i, double pred[]); 00202 00204 void write_to_stream(std::wostream *os); 00205 void read_from_stream(std::wistream *is); 00206 00208 void learn(const dataset &ds, double &Z); 00209 }; 00210 00211 } // namespace 00212 00213 #endif 00214