28#include "imageclasses/ImgReaderGdal.h"
29#include "imageclasses/ImgWriterOgr.h"
30#include "base/Optionpk.h"
31#include "algorithms/StatFactory.h"
34#define PI 3.1415926535897932384626433832795
90int main(
int argc,
char *argv[])
92 Optionpk<string> image_opt(
"i",
"input",
"Raster input dataset containing band information");
93 Optionpk<string> sample_opt(
"s",
"sample",
"Raster dataset with features to be extracted from input data. Output will contain features with input band information included.");
95 Optionpk<int> class_opt(
"c",
"class",
"Class(es) to extract from input sample image. Leave empty to extract all valid data pixels from sample dataset");
96 Optionpk<float> threshold_opt(
"t",
"threshold",
"Probability threshold for selecting samples (randomly). Provide probability in percentage (>0) or absolute (<0). Use a single threshold per vector sample layer. If using raster land cover maps as a sample dataset, you can provide a threshold value for each class (e.g. -t 80 -t 60). Use value 100 to select all pixels for selected class(es)", 100);
97 Optionpk<string> ogrformat_opt(
"f",
"f",
"Output sample dataset format",
"SQLite");
98 Optionpk<string> ftype_opt(
"ft",
"ftype",
"Field type (only Real or Integer)",
"Real");
99 Optionpk<string> ltype_opt(
"lt",
"ltype",
"Label type: In16 or String",
"Integer");
100 Optionpk<int> band_opt(
"b",
"band",
"Band index(es) to extract (0 based). Leave empty to use all bands");
103 Optionpk<double> srcnodata_opt(
"srcnodata",
"srcnodata",
"Invalid value(s) for input image");
104 Optionpk<int> bndnodata_opt(
"bndnodata",
"bndnodata",
"Band in input image to check if pixel is valid (used for srcnodata)", 0);
105 Optionpk<string> fieldname_opt(
"bn",
"bname",
"For single band input data, this extra attribute name will correspond to the raster values. For multi-band input data, multiple attributes with this prefix will be added (e.g. b0, b1, b2, etc.)",
"b");
106 Optionpk<string> label_opt(
"cn",
"cname",
"Name of the class label in the output vector dataset",
"label");
108 Optionpk<short> verbose_opt(
"v",
"verbose",
"Verbose mode if > 0", 0,2);
110 bstart_opt.setHide(1);
112 bndnodata_opt.setHide(1);
113 srcnodata_opt.setHide(1);
114 fieldname_opt.setHide(1);
115 label_opt.setHide(1);
120 doProcess=image_opt.retrieveOption(argc,argv);
121 sample_opt.retrieveOption(argc,argv);
122 output_opt.retrieveOption(argc,argv);
123 class_opt.retrieveOption(argc,argv);
124 threshold_opt.retrieveOption(argc,argv);
125 ogrformat_opt.retrieveOption(argc,argv);
126 ftype_opt.retrieveOption(argc,argv);
127 ltype_opt.retrieveOption(argc,argv);
128 band_opt.retrieveOption(argc,argv);
129 bstart_opt.retrieveOption(argc,argv);
130 bend_opt.retrieveOption(argc,argv);
131 bndnodata_opt.retrieveOption(argc,argv);
132 srcnodata_opt.retrieveOption(argc,argv);
133 fieldname_opt.retrieveOption(argc,argv);
134 label_opt.retrieveOption(argc,argv);
135 down_opt.retrieveOption(argc,argv);
136 verbose_opt.retrieveOption(argc,argv);
138 catch(
string predefinedString){
139 std::cout << predefinedString << std::endl;
144 cout <<
"Usage: pkextractimg -i input -s sample -o output" << endl;
146 std::cout <<
"short option -h shows basic options only, use long option --help to show all options" << std::endl;
156 std::cout << class_opt << std::endl;
158 stat.setNoDataValues(srcnodata_opt);
160 unsigned long int nsample=0;
161 unsigned long int ntotalvalid=0;
162 unsigned long int ntotalinvalid=0;
164 map<int,unsigned long int> nvalid;
165 map<int,unsigned long int> ninvalid;
177 map <int,short> classmap;
178 for(
int iclass=0;iclass<class_opt.size();++iclass){
179 nvalid[class_opt[iclass]]=0;
180 ninvalid[class_opt[iclass]]=0;
181 classmap[class_opt[iclass]]=iclass;
185 if(image_opt.empty()){
186 std::cerr <<
"No image dataset provided (use option -i). Use --help for help information";
189 if(output_opt.empty()){
190 std::cerr <<
"No output dataset provided (use option -o). Use --help for help information";
194 imgReader.
open(image_opt[0]);
196 catch(std::string errorstring){
197 std::cout << errorstring << std::endl;
203 if(bstart_opt.size()){
204 if(bend_opt.size()!=bstart_opt.size()){
205 string errorstring=
"Error: options for start and end band indexes must be provided as pairs, missing end band";
209 for(
int ipair=0;ipair<bstart_opt.size();++ipair){
210 if(bend_opt[ipair]<=bstart_opt[ipair]){
211 string errorstring=
"Error: index for end band must be smaller then start band";
214 for(
int iband=bstart_opt[ipair];iband<=bend_opt[ipair];++iband)
215 band_opt.push_back(iband);
220 cerr << error << std::endl;
224 int nband=(band_opt.size()) ? band_opt.size() : imgReader.
nrOfBand();
226 if(fieldname_opt.size()<nband){
227 std::string bandString=fieldname_opt[0];
228 fieldname_opt.clear();
229 fieldname_opt.resize(nband);
230 for(
int iband=0;iband<nband;++iband){
231 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
233 fs << bandString << theBand;
234 fieldname_opt[iband]=fs.str();
239 std::cout << fieldname_opt << std::endl;
242 std::cout <<
"Number of bands in input image: " << imgReader.
nrOfBand() << std::endl;
244 OGRFieldType fieldType;
245 OGRFieldType labelType;
246 int ogr_typecount=11;
248 std::cout <<
"field and label types can be: ";
249 for(
int iType = 0; iType < ogr_typecount; ++iType){
251 std::cout <<
" " << OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType);
252 if( OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType) != NULL
253 && EQUAL(OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType),
254 ftype_opt[0].c_str()))
255 fieldType=(OGRFieldType) iType;
256 if( OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType) != NULL
257 && EQUAL(OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType),
258 ltype_opt[0].c_str()))
259 labelType=(OGRFieldType) iType;
267 std::cout << std::endl <<
"field type is: " << OGRFieldDefn::GetFieldTypeName(fieldType) << std::endl;
270 cerr <<
"field type " << OGRFieldDefn::GetFieldTypeName(fieldType) <<
" not supported" << std::endl;
280 std::cout << std::endl <<
"label type is: " << OGRFieldDefn::GetFieldTypeName(labelType) << std::endl;
283 cerr <<
"label type " << OGRFieldDefn::GetFieldTypeName(labelType) <<
" not supported" << std::endl;
288 const char* pszMessage;
289 void* pProgressArg=NULL;
290 GDALProgressFunc pfnProgress=GDALTermProgress;
294 bool sampleIsRaster=
true;
299 if(sample_opt.size()){
301 classReader.
open(sample_opt[0]);
303 catch(
string errorString){
305 sampleIsRaster=
false;
309 std::cerr <<
"No raster sample dataset provided (use option -s filename). Use --help for help information";
314 if(class_opt.empty()){
316 assert(sample_opt.size());
317 classReader.
open(sample_opt[0]);
319 vector<double> classBuffer(classReader.
nrOfCol());
322 vector<double> sample(2+nband);
324 vector<int> writeBufferClass;
325 vector<int> selectedClass;
331 std::cout <<
"extracting sample from image..." << std::endl;
333 pfnProgress(progress,pszMessage,pProgressArg);
334 for(irow=0;irow<classReader.
nrOfRow();++irow){
337 classReader.
readData(classBuffer,irow);
347 if(
static_cast<int>(jimg)<0||
static_cast<int>(jimg)>=imgReader.
nrOfRow())
349 for(
int iband=0;iband<nband;++iband){
350 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
351 imgReader.
readData(imgBuffer[iband],
static_cast<int>(jimg),theBand);
353 for(icol=0;icol<classReader.
nrOfCol();++icol){
356 int theClass=classBuffer[icol];
359 if(class_opt.empty()){
361 processClass=theClass;
364 for(
int iclass=0;iclass<class_opt.size();++iclass){
365 if(classBuffer[icol]==class_opt[iclass]){
367 theClass=class_opt[iclass];
376 if(verbose_opt[0]>1){
377 std::cout.precision(12);
378 std::cout << theClass <<
" " << x <<
" " << y << std::endl;
383 iimg=
static_cast<int>(iimg);
384 if(
static_cast<int>(iimg)<0||
static_cast<int>(iimg)>=imgReader.
nrOfCol())
387 for(
int iband=0;iband<nband&&valid;++iband){
388 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
389 if(srcnodata_opt.size()&&theBand==bndnodata_opt[0]){
391 for(
int inodata=0;inodata<srcnodata_opt.size()&&valid;++inodata){
392 if(imgBuffer[iband][iimg]==srcnodata_opt[inodata])
400 for(
int iband=0;iband<imgBuffer.size();++iband){
401 sample[iband+2]=imgBuffer[iband][iimg];
403 float theThreshold=(threshold_opt.size()>1)?threshold_opt[processClass]:threshold_opt[0];
405 double p=
static_cast<double>(rand())/(RAND_MAX);
414 writeBuffer.push_back(sample);
415 writeBufferClass.push_back(theClass);
417 if(nvalid.count(theClass))
424 if(ninvalid.count(theClass))
425 ninvalid[theClass]+=1;
427 ninvalid[theClass]=1;
430 progress=
static_cast<float>(irow+1.0)/classReader.
nrOfRow();
431 pfnProgress(progress,pszMessage,pProgressArg);
434 pfnProgress(progress,pszMessage,pProgressArg);
435 if(writeBuffer.size()>0){
436 assert(ntotalvalid==writeBuffer.size());
438 std::cout <<
"creating image sample writer " << output_opt[0] <<
" with " << writeBuffer.size() <<
" samples (" << ntotalinvalid <<
" invalid)" << std::endl;
439 ogrWriter.open(output_opt[0],ogrformat_opt[0]);
440 char **papszOptions=NULL;
441 ostringstream slayer;
442 slayer <<
"training data";
443 std::string layername=slayer.str();
444 ogrWriter.createLayer(layername, imgReader.
getProjection(), wkbPoint, papszOptions);
445 std::string fieldname=
"fid";
446 ogrWriter.createField(fieldname,OFTInteger);
447 map<std::string,double> pointAttributes;
448 ogrWriter.createField(label_opt[0],labelType);
449 for(
int iband=0;iband<nband;++iband){
450 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
451 ogrWriter.createField(fieldname_opt[iband],fieldType);
454 pfnProgress(progress,pszMessage,pProgressArg);
456 map<int,short> classDone;
458 vector<int> writeBufferClassTmp;
460 if(threshold_opt[0]<0){
461 map<int,unsigned long int>::iterator mapit;
462 map<int,unsigned long int> ncopied;
463 for(mapit=nvalid.begin();mapit!=nvalid.end();++mapit)
464 ncopied[mapit->first]=0;
466 cout <<
"ncopied.size(): " << ncopied.size() << endl;
467 while(classDone.size()<nvalid.size()){
468 int index=rand()%writeBufferClass.size();
469 int theClass=writeBufferClass[index];
470 float theThreshold=threshold_opt[0];
471 if(threshold_opt.size()>1&&class_opt.size())
472 theThreshold=threshold_opt[classmap[theClass]];
473 theThreshold=-theThreshold;
474 if(ncopied[theClass]<theThreshold){
475 writeBufferClassTmp.push_back(*(writeBufferClass.begin()+index));
476 writeBufferTmp.push_back(*(writeBuffer.begin()+index));
477 writeBufferClass.erase(writeBufferClass.begin()+index);
478 writeBuffer.erase(writeBuffer.begin()+index);
479 ++(ncopied[theClass]);
482 classDone[theClass]=1;
483 if(ncopied[theClass]>=nvalid[theClass]){
484 classDone[theClass]=1;
487 writeBuffer=writeBufferTmp;
488 writeBufferClass=writeBufferClassTmp;
506 for(
int isample=0;isample<writeBuffer.size();++isample){
508 std::cout <<
"writing sample " << isample << std::endl;
509 pointAttributes[label_opt[0]]=writeBufferClass[isample];
510 for(
int iband=0;iband<writeBuffer[0].size()-2;++iband){
511 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
512 pointAttributes[fieldname_opt[iband]]=writeBuffer[isample][iband+2];
515 std::cout <<
"all bands written" << std::endl;
516 ogrWriter.addPoint(writeBuffer[isample][0],writeBuffer[isample][1],pointAttributes,fieldname,isample);
517 progress=
static_cast<float>(isample+1.0)/writeBuffer.size();
518 pfnProgress(progress,pszMessage,pProgressArg);
523 std::cout <<
"No data found for any class " << std::endl;
526 nsample=writeBuffer.size();
528 std::cout <<
"total number of samples written: " << nsample << std::endl;
531 assert(class_opt[0]);
533 assert(threshold_opt.size()==1||threshold_opt.size()==class_opt.size());
536 if(verbose_opt[0]>1){
537 std::cout <<
"reading position from sample dataset " << std::endl;
538 std::cout <<
"class thresholds: " << std::endl;
539 for(
int iclass=0;iclass<class_opt.size();++iclass){
540 if(threshold_opt.size()>1)
541 std::cout << class_opt[iclass] <<
": " << threshold_opt[iclass] << std::endl;
543 std::cout << class_opt[iclass] <<
": " << threshold_opt[0] << std::endl;
546 classReader.
open(sample_opt[0]);
547 vector<int> classBuffer(classReader.
nrOfCol());
551 vector<double> sample(2+nband);
553 vector<int> writeBufferClass;
554 vector<int> selectedClass;
560 std::cout <<
"extracting sample from image..." << std::endl;
562 pfnProgress(progress,pszMessage,pProgressArg);
563 for(irow=0;irow<classReader.
nrOfRow();++irow){
566 classReader.
readData(classBuffer,irow);
576 if(
static_cast<int>(jimg)<0||
static_cast<int>(jimg)>=imgReader.
nrOfRow())
578 for(
int iband=0;iband<nband;++iband){
579 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
580 imgReader.
readData(imgBuffer[iband],
static_cast<int>(jimg),theBand);
583 for(icol=0;icol<classReader.
nrOfCol();++icol){
589 if(class_opt.empty()){
590 if(classBuffer[icol]){
592 theClass=classBuffer[icol];
596 for(
int iclass=0;iclass<class_opt.size();++iclass){
597 if(classBuffer[icol]==class_opt[iclass]){
599 theClass=class_opt[iclass];
608 if(verbose_opt[0]>1){
609 std::cout.precision(12);
610 std::cout << theClass <<
" " << x <<
" " << y << std::endl;
615 iimg=
static_cast<int>(iimg);
616 if(
static_cast<int>(iimg)<0||
static_cast<int>(iimg)>=imgReader.
nrOfCol())
620 for(
int iband=0;iband<nband;++iband){
621 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
622 if(srcnodata_opt.size()&&theBand==bndnodata_opt[0]){
624 for(
int inodata=0;inodata<srcnodata_opt.size()&&valid;++inodata){
625 if(imgBuffer[iband][iimg]==srcnodata_opt[inodata])
631 for(
int iband=0;iband<imgBuffer.size();++iband){
632 sample[iband+2]=imgBuffer[iband][iimg];
634 float theThreshold=(threshold_opt.size()>1)?threshold_opt[processClass]:threshold_opt[0];
636 double p=
static_cast<double>(rand())/(RAND_MAX);
645 writeBuffer.push_back(sample);
646 writeBufferClass.push_back(theClass);
648 if(nvalid.count(theClass))
655 if(ninvalid.count(theClass))
656 ninvalid[theClass]+=1;
658 ninvalid[theClass]=1;
662 progress=
static_cast<float>(irow+1.0)/classReader.
nrOfRow();
663 pfnProgress(progress,pszMessage,pProgressArg);
665 if(writeBuffer.size()>0){
666 assert(ntotalvalid==writeBuffer.size());
668 std::cout <<
"creating image sample writer " << output_opt[0] <<
" with " << writeBuffer.size() <<
" samples (" << ntotalinvalid <<
" invalid)" << std::endl;
669 ogrWriter.open(output_opt[0],ogrformat_opt[0]);
670 char **papszOptions=NULL;
671 ostringstream slayer;
672 slayer <<
"training data";
673 std::string layername=slayer.str();
674 ogrWriter.createLayer(layername, imgReader.
getProjection(), wkbPoint, papszOptions);
675 std::string fieldname=
"fid";
676 ogrWriter.createField(fieldname,OFTInteger);
677 map<std::string,double> pointAttributes;
678 ogrWriter.createField(label_opt[0],labelType);
679 for(
int iband=0;iband<nband;++iband){
680 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
681 ogrWriter.createField(fieldname_opt[iband],fieldType);
683 pfnProgress(progress,pszMessage,pProgressArg);
685 pfnProgress(progress,pszMessage,pProgressArg);
687 map<int,short> classDone;
689 vector<int> writeBufferClassTmp;
691 if(threshold_opt[0]<0){
692 map<int,unsigned long int>::iterator mapit;
693 map<int,unsigned long int> ncopied;
694 for(mapit=nvalid.begin();mapit!=nvalid.end();++mapit)
695 ncopied[mapit->first]=0;
697 while(classDone.size()<nvalid.size()){
698 int index=rand()%writeBufferClass.size();
699 int theClass=writeBufferClass[index];
700 float theThreshold=threshold_opt[0];
701 if(threshold_opt.size()>1&&class_opt.size())
702 theThreshold=threshold_opt[classmap[theClass]];
703 theThreshold=-theThreshold;
704 if(ncopied[theClass]<theThreshold){
705 writeBufferClassTmp.push_back(*(writeBufferClass.begin()+index));
706 writeBufferTmp.push_back(*(writeBuffer.begin()+index));
707 writeBufferClass.erase(writeBufferClass.begin()+index);
708 writeBuffer.erase(writeBuffer.begin()+index);
709 ++(ncopied[theClass]);
712 classDone[theClass]=1;
713 if(ncopied[theClass]>=nvalid[theClass]){
714 classDone[theClass]=1;
717 writeBuffer=writeBufferTmp;
718 writeBufferClass=writeBufferClassTmp;
736 for(
int isample=0;isample<writeBuffer.size();++isample){
737 pointAttributes[label_opt[0]]=writeBufferClass[isample];
738 for(
int iband=0;iband<writeBuffer[0].size()-2;++iband){
739 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
740 pointAttributes[fieldname_opt[iband]]=writeBuffer[isample][iband+2];
742 ogrWriter.addPoint(writeBuffer[isample][0],writeBuffer[isample][1],pointAttributes,fieldname,isample);
743 progress=
static_cast<float>(isample+1.0)/writeBuffer.size();
744 pfnProgress(progress,pszMessage,pProgressArg);
749 std::cout <<
"No data found for any class " << std::endl;
752 nsample=writeBuffer.size();
754 std::cout <<
"total number of samples written: " << nsample << std::endl;
755 if(nvalid.size()==class_opt.size()){
756 for(
int iclass=0;iclass<class_opt.size();++iclass)
757 std::cout <<
"class " << class_opt[iclass] <<
" has " << nvalid[iclass] <<
" samples" << std::endl;
763 cerr <<
"Error: vector sample not supported, consider using pkextractogr" << endl;
766 pfnProgress(progress,pszMessage,pProgressArg);
int nrOfRow(void) const
Get the number of rows of this dataset.
bool geo2image(double x, double y, double &i, double &j) const
Convert georeferenced coordinates (x and y) to image coordinates (column and row)
std::string getProjection(void) const
Get the projection string (deprecated, use getProjectionRef instead)
int nrOfBand(void) const
Get the number of bands of this dataset.
int nrOfCol(void) const
Get the number of columns of this dataset.
bool image2geo(double i, double j, double &x, double &y) const
Convert image coordinates (column and row) to georeferenced coordinates (x and y)
void readData(T &value, int col, int row, int band=0)
Read a single pixel cell value at a specific column and row for a specific band (all indices start counting from 0)
void close(void)
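Close the dataset. (Note: the description extracted for this member reads "Set the memory (in MB) to cache a number of rows in memory.", which does not match the signature of close() and presumably belongs to a different member; verify against the class header.)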
void open(const std::string &filename, const GDALAccess &readMode=GA_ReadOnly)
Open an image.