pktools 2.6.7
Processing Kernel for geospatial data
ConfusionMatrix.h
1/**********************************************************************
2ConfusionMatrix.h: class for (classification accuracy) confusion matrix
3Copyright (C) 2008-2012 Pieter Kempeneers
4
5This file is part of pktools
6
7pktools is free software: you can redistribute it and/or modify
8it under the terms of the GNU General Public License as published by
9the Free Software Foundation, either version 3 of the License, or
10(at your option) any later version.
11
12pktools is distributed in the hope that it will be useful,
13but WITHOUT ANY WARRANTY; without even the implied warranty of
14MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15GNU General Public License for more details.
16
17You should have received a copy of the GNU General Public License
18along with pktools. If not, see <http://www.gnu.org/licenses/>.
19***********************************************************************/
20#ifndef _CONFUSIONMATRIX_H_
21#define _CONFUSIONMATRIX_H_
22
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>
#include "base/Vector2d.h"
#include "base/Optionpk.h"
27
28namespace confusionmatrix
29{
/// Output formats selectable through ConfusionMatrix::setFormat.
/// NOTE(review): HTML is declared, but getFormat(std::string) never returns it
/// and operator<< falls into its default (ASCII) branch for it.
enum CM_FORMAT {
  ASCII = 0,  ///< tab-separated plain text
  LATEX = 1,  ///< standalone LaTeX document containing a tabular
  HTML = 2    ///< reserved; rendered like ASCII
};
31
33public:
35 ConfusionMatrix(short nclass);
36 ConfusionMatrix(const std::vector<std::string>& classNames);
38 ConfusionMatrix& operator=(const ConfusionMatrix& cm);
39 short size() const {return m_results.size();};
40 void resize(short nclass);
41 void setClassNames(const std::vector<std::string>& classNames, bool doSort=false);
42 void pushBackClassName(const std::string& className, bool doSort=false);
43 void setResults(const Vector2d<double>& theResults);
44 void setResult(const std::string& theRef, const std::string& theClass, double theResult);
45 void incrementResult(const std::string& theRef, const std::string& theClass, double theIncrement);
46 void clearResults();
47 double nReference(const std::string& theRef) const;
48 double nReference() const;
49 double nClassified(const std::string& theRef) const;
50 int nClasses() const {return m_classes.size();};
51 std::string getClass(int iclass) const {assert(iclass>=0);assert(iclass<m_classes.size());return m_classes[iclass];};
52 int getClassIndex(std::string className) const {
53 int index=0;
54 for(index=0;index<m_classes.size();++index){
55 if(m_classes[index]==className)
56 break;
57 }
58 if(index>=m_classes.size())
59 index=-1;
60 return index;
61 // int index=distance(m_classes.begin(),find(m_classes.begin(),m_classes.end(),className));
62 // assert(index>=0);
63 // if(index<m_results.size())
64 // return(index);
65 // else
66 // return(-1);
67 }
68 std::vector<std::string> getClassNames() const {return m_classes;};
70 double pa(const std::string& theClass, double* se95=NULL) const;
71 double ua(const std::string& theClass, double* se95=NULL) const;
72 double oa(double* se95=NULL) const;
73 int pa_pct(const std::string& theClass, double* se95=NULL) const;
74 int ua_pct(const std::string& theClass, double* se95=NULL) const;
75 int oa_pct(double* se95=NULL) const;
76 double kappa() const;
77 ConfusionMatrix& operator*=(double weight);
78 ConfusionMatrix operator*(double weight);
79 ConfusionMatrix& operator+=(const ConfusionMatrix &cm);
80 ConfusionMatrix operator+(const ConfusionMatrix &cm){
81 return ConfusionMatrix(*this)+=cm;
82 }
83 void sortClassNames();
84
85 void reportSE95(bool doReport) {m_se95=doReport;};
86 void setFormat(const CM_FORMAT& theFormat) {m_format=theFormat;};
87 void setFormat(const std::string theFormat) {m_format=getFormat(theFormat);};
88 CM_FORMAT getFormat() const {return m_format;};
89
90 static const CM_FORMAT getFormat(const std::string theFormat){
91 if(theFormat=="ascii") return(ASCII);
92 else if(theFormat=="latex") return(LATEX);
93 else{
94 std::string errorString="Format not supported: ";
95 errorString+=theFormat;
96 errorString+=" use ascii or latex";
97 throw(errorString);
98 }
99 };
100
101 friend std::ostream& operator<<(std::ostream& os, const ConfusionMatrix &cm){
102 std::ostringstream streamLine;
103 /* streamosclass << iclass; */
104 /* m_classes[iclass]=osclass.str(); */
105
106 std::string fieldSeparator=" ";
107 std::string lineSeparator="";
108 std::string mathMode="";
109 switch(cm.getFormat()){
110 case(LATEX):
111 fieldSeparator=" & ";
112 lineSeparator="\\\\";
113 mathMode="$";
114 break;
115 case(ASCII):
116 default:
117 fieldSeparator="\t";
118 lineSeparator="";
119 mathMode="";
120 break;
121 }
122
123 double se95_ua=0;
124 double se95_pa=0;
125 double se95_oa=0;
126 double dua=0;
127 double dpa=0;
128 double doa=0;
129
130 doa = cm.oa(&se95_oa);
131
132 if(cm.getFormat()==LATEX){
133 os << "\\documentclass{article}" << std::endl;
134 os << "\\begin{document}" << std::endl;
135 }
136 os << "Kappa = " << mathMode << cm.kappa() << mathMode ;
137 os << ", Overall Acc. = " << mathMode << 100.0*cm.oa() << mathMode ;
138 if(cm.m_se95)
139 os << " (" << mathMode << se95_oa << mathMode << ")";
140 os << std::endl;
141 os << std::endl;
142 if(cm.getFormat()==LATEX){
143 os << "\\begin{tabular}{@{}l";
144 for(int iclass=0;iclass<cm.nClasses();++iclass)
145 os << "l";
146 os << "}" << std::endl;
147 os << "\\hline" << std::endl;
148 }
149
150 os << "Class";
151 for(int iclass=0;iclass<cm.nClasses();++iclass)
152 os << fieldSeparator << cm.m_classes[iclass];
153 os << lineSeparator << std::endl;
154 if(cm.getFormat()==LATEX)
155 os << "\\hline" << std::endl;
156 assert(cm.m_classes.size()==cm.m_results.size());
157 for(int irow=0;irow<cm.m_results.size();++irow){
158 os << cm.m_classes[irow];
159 for(int icol=0;icol<cm.m_results[irow].size();++icol)
160 os << fieldSeparator << cm.m_results[irow][icol];
161 os << lineSeparator<< std::endl;
162 }
163 if(cm.getFormat()==LATEX){
164 os << "\\hline" << std::endl;
165 }
166 else
167 os << std::endl;
168
169 os << "User' Acc.";
170 for(int iclass=0;iclass<cm.nClasses();++iclass){
171 dua=cm.ua_pct(cm.m_classes[iclass],&se95_ua);
172 os << fieldSeparator << dua;
173 if(cm.m_se95)
174 os << " (" << se95_ua << ")";
175 }
176 os << lineSeparator<< std::endl;
177 os << "Prod. Acc.";
178 for(int iclass=0;iclass<cm.nClasses();++iclass){
179 dpa=cm.pa_pct(cm.m_classes[iclass],&se95_ua);
180 os << fieldSeparator << dpa;
181 if(cm.m_se95)
182 os << " (" << se95_pa << ")";
183 }
184 os << lineSeparator<< std::endl;
185 if(cm.getFormat()==LATEX){
186 os << "\\end{tabular}" << std::endl;
187 os << "\\end{document}" << std::endl;
188 }
189 return os;
190 };
191private:
192 std::vector<std::string> m_classes;
193 Vector2d<double> m_results;
194 CM_FORMAT m_format;
195 bool m_se95;
196};
197}
198#endif /* _CONFUSIONMATRIX_H_ */