pktools 2.6.7
Processing Kernel for geospatial data
pkextractimg.cc
1/**********************************************************************
2pkextractimg.cc: extract pixel values from raster image using a raster sample
3Copyright (C) 2008-2016 Pieter Kempeneers
4
5This file is part of pktools
6
7pktools is free software: you can redistribute it and/or modify
8it under the terms of the GNU General Public License as published by
9the Free Software Foundation, either version 3 of the License, or
10(at your option) any later version.
11
12pktools is distributed in the hope that it will be useful,
13but WITHOUT ANY WARRANTY; without even the implied warranty of
14MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15GNU General Public License for more details.
16
17You should have received a copy of the GNU General Public License
18along with pktools. If not, see <http://www.gnu.org/licenses/>.
19***********************************************************************/
20#include <assert.h>
21#include <math.h>
22#include <stdlib.h>
23#include <sstream>
24#include <string>
25#include <algorithm>
26#include <ctime>
27#include <vector>
28#include "imageclasses/ImgReaderGdal.h"
29#include "imageclasses/ImgWriterOgr.h"
30#include "base/Optionpk.h"
31#include "algorithms/StatFactory.h"
32
33#ifndef PI
34#define PI 3.1415926535897932384626433832795
35#endif
36
37/******************************************************************************/
88using namespace std;
89
// pkextractimg entry point.
// Reads pixel values from the input raster (-i) at the pixel positions of a raster
// sample dataset (-s) and writes them as point features (fid, class label, one
// attribute per band) to an OGR vector dataset (-o).
// Overall flow: parse options -> open input raster -> resolve band list and
// attribute names -> resolve OGR field/label types -> scan the sample raster row by
// row and buffer the selected samples -> optionally subsample to an absolute number
// per class -> write the points.
// NOTE(review): this listing carries fused source line numbers and skips numbers
// 157 and 159, so the declaration of `stat` used below is not visible here.
90int main(int argc, char *argv[])
91{
// Command line options; arguments look like (short name, long name, help text[, default]) - TODO confirm against Optionpk.
92 Optionpk<string> image_opt("i", "input", "Raster input dataset containing band information");
93 Optionpk<string> sample_opt("s", "sample", "Raster dataset with features to be extracted from input data. Output will contain features with input band information included.");
94 Optionpk<string> output_opt("o", "output", "Output sample dataset");
95 Optionpk<int> class_opt("c", "class", "Class(es) to extract from input sample image. Leave empty to extract all valid data pixels from sample dataset");
96 Optionpk<float> threshold_opt("t", "threshold", "Probability threshold for selecting samples (randomly). Provide probability in percentage (>0) or absolute (<0). Use a single threshold per vector sample layer. If using raster land cover maps as a sample dataset, you can provide a threshold value for each class (e.g. -t 80 -t 60). Use value 100 to select all pixels for selected class(es)", 100);
97 Optionpk<string> ogrformat_opt("f", "f", "Output sample dataset format","SQLite");
98 Optionpk<string> ftype_opt("ft", "ftype", "Field type (only Real or Integer)", "Real");
// NOTE(review): help text says "In16" (probably meant Int16/Integer) while the default is "Integer".
99 Optionpk<string> ltype_opt("lt", "ltype", "Label type: In16 or String", "Integer");
100 Optionpk<int> band_opt("b", "band", "Band index(es) to extract (0 based). Leave empty to use all bands");
101 Optionpk<unsigned short> bstart_opt("sband", "startband", "Start band sequence number");
102 Optionpk<unsigned short> bend_opt("eband", "endband", "End band sequence number");
103 Optionpk<double> srcnodata_opt("srcnodata", "srcnodata", "Invalid value(s) for input image");
104 Optionpk<int> bndnodata_opt("bndnodata", "bndnodata", "Band in input image to check if pixel is valid (used for srcnodata)", 0);
105 Optionpk<string> fieldname_opt("bn", "bname", "For single band input data, this extra attribute name will correspond to the raster values. For multi-band input data, multiple attributes with this prefix will be added (e.g. b0, b1, b2, etc.)", "b");
106 Optionpk<string> label_opt("cn", "cname", "Name of the class label in the output vector dataset", "label");
107 Optionpk<short> down_opt("down", "down", "Down sampling factor", 1);
108 Optionpk<short> verbose_opt("v", "verbose", "Verbose mode if > 0", 0,2);
109
// Mark the advanced options as hidden; presumably these are the ones only listed with --help (see usage text below).
110 bstart_opt.setHide(1);
111 bend_opt.setHide(1);
112 bndnodata_opt.setHide(1);
113 srcnodata_opt.setHide(1);
114 fieldname_opt.setHide(1);
115 label_opt.setHide(1);
116 down_opt.setHide(1);
117
118 bool doProcess;//stop process when program was invoked with help option (-h --help)
// Parse every option from argv. Only the result of the first call is kept in doProcess.
// A std::string thrown during parsing is printed and the program exits.
119 try{
120 doProcess=image_opt.retrieveOption(argc,argv);
121 sample_opt.retrieveOption(argc,argv);
122 output_opt.retrieveOption(argc,argv);
123 class_opt.retrieveOption(argc,argv);
124 threshold_opt.retrieveOption(argc,argv);
125 ogrformat_opt.retrieveOption(argc,argv);
126 ftype_opt.retrieveOption(argc,argv);
127 ltype_opt.retrieveOption(argc,argv);
128 band_opt.retrieveOption(argc,argv);
129 bstart_opt.retrieveOption(argc,argv);
130 bend_opt.retrieveOption(argc,argv);
131 bndnodata_opt.retrieveOption(argc,argv);
132 srcnodata_opt.retrieveOption(argc,argv);
133 fieldname_opt.retrieveOption(argc,argv);
134 label_opt.retrieveOption(argc,argv);
135 down_opt.retrieveOption(argc,argv);
136 verbose_opt.retrieveOption(argc,argv);
137 }
138 catch(string predefinedString){
139 std::cout << predefinedString << std::endl;
// NOTE(review): exits with status 0 although option parsing threw.
140 exit(0);
141 }
142 if(!doProcess){
143 cout << endl;
144 cout << "Usage: pkextractimg -i input -s sample -o output" << endl;
145 cout << endl;
146 std::cout << "short option -h shows basic options only, use long option --help to show all options" << std::endl;
147 exit(0);//help was invoked, stop processing
148 }
149
150 // if(srcnodata_opt.size()){
151 // while(srcnodata_opt.size()<bndnodata_opt.size())
152 // srcnodata_opt.push_back(srcnodata_opt[0]);
153 // }
154
155 if(verbose_opt[0])
156 std::cout << class_opt << std::endl;
// NOTE(review): `stat` is not declared anywhere in the visible listing (source line 157 is
// missing here) and is not referenced again in this function.
158 stat.setNoDataValues(srcnodata_opt);
// Global bookkeeping: samples written, and valid/invalid sample totals.
160 unsigned long int nsample=0;
161 unsigned long int ntotalvalid=0;
162 unsigned long int ntotalinvalid=0;
163
// Per-class counters, keyed by class value.
164 map<int,unsigned long int> nvalid;
165 map<int,unsigned long int> ninvalid;
166 // vector<unsigned long int> nvalid(class_opt.size());
167 // vector<unsigned long int> ninvalid(class_opt.size());
168 // if(class_opt.empty()){
169 // nvalid.resize(256);
170 // ninvalid.resize(256);
171 // }
172 // for(int it=0;it<nvalid.size();++it){
173 // nvalid[it]=0;
174 // ninvalid[it]=0;
175 // }
176
// Zero the counters for every requested class and remember each class value's
// position in class_opt (used later to pick a per-class threshold).
177 map <int,short> classmap;//class->index
178 for(int iclass=0;iclass<class_opt.size();++iclass){
179 nvalid[class_opt[iclass]]=0;
180 ninvalid[class_opt[iclass]]=0;
181 classmap[class_opt[iclass]]=iclass;
182 }
183
// Input and output datasets are mandatory.
// NOTE(review): the error paths below write to cerr/cout but exit with status 0.
184 ImgReaderGdal imgReader;
185 if(image_opt.empty()){
186 std::cerr << "No image dataset provided (use option -i). Use --help for help information";
187 exit(0);
188 }
189 if(output_opt.empty()){
190 std::cerr << "No output dataset provided (use option -o). Use --help for help information";
191 exit(0);
192 }
193 try{
194 imgReader.open(image_opt[0]);
195 }
196 catch(std::string errorstring){
197 std::cout << errorstring << std::endl;
198 exit(0);
199 }
200
201 //convert start and end band options to vector of band indexes
// Start/end bands must come in pairs; when given they replace any -b band list
// with every index in each inclusive [start,end] range.
202 try{
203 if(bstart_opt.size()){
204 if(bend_opt.size()!=bstart_opt.size()){
205 string errorstring="Error: options for start and end band indexes must be provided as pairs, missing end band";
206 throw(errorstring);
207 }
208 band_opt.clear();
209 for(int ipair=0;ipair<bstart_opt.size();++ipair){
// NOTE(review): the condition rejects end<=start, but the message states the opposite
// ("end band must be smaller then start band"). It also rejects end==start.
210 if(bend_opt[ipair]<=bstart_opt[ipair]){
211 string errorstring="Error: index for end band must be smaller then start band";
212 throw(errorstring);
213 }
214 for(int iband=bstart_opt[ipair];iband<=bend_opt[ipair];++iband)
215 band_opt.push_back(iband);
216 }
217 }
218 }
219 catch(string error){
220 cerr << error << std::endl;
221 exit(1);
222 }
223
// Number of bands to extract: the explicit band list if any, otherwise all bands of the input.
224 int nband=(band_opt.size()) ? band_opt.size() : imgReader.nrOfBand();
225
// If fewer attribute names than bands were given, use the first name as a prefix
// and generate one name per band by appending the band index (e.g. b0, b1, ...).
226 if(fieldname_opt.size()<nband){
227 std::string bandString=fieldname_opt[0];
228 fieldname_opt.clear();
229 fieldname_opt.resize(nband);
230 for(int iband=0;iband<nband;++iband){
231 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
232 ostringstream fs;
233 fs << bandString << theBand;
234 fieldname_opt[iband]=fs.str();
235 }
236 }
237
238 if(verbose_opt[0])
239 std::cout << fieldname_opt << std::endl;
240
241 if(verbose_opt[0]>1)
242 std::cout << "Number of bands in input image: " << imgReader.nrOfBand() << std::endl;
243
// Translate the -ft and -lt type names to OGRFieldType by comparing them
// (via EQUAL) with the names of the first ogr_typecount OGR field types.
// NOTE(review): fieldType and labelType stay uninitialised when no name matches,
// and are then read by the switch statements below.
244 OGRFieldType fieldType;
245 OGRFieldType labelType;
246 int ogr_typecount=11;//hard coded for now!
247 if(verbose_opt[0]>1)
248 std::cout << "field and label types can be: ";
249 for(int iType = 0; iType < ogr_typecount; ++iType){
250 if(verbose_opt[0]>1)
251 std::cout << " " << OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType);
252 if( OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType) != NULL
253 && EQUAL(OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType),
254 ftype_opt[0].c_str()))
255 fieldType=(OGRFieldType) iType;
256 if( OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType) != NULL
257 && EQUAL(OGRFieldDefn::GetFieldTypeName((OGRFieldType)iType),
258 ltype_opt[0].c_str()))
259 labelType=(OGRFieldType) iType;
260 }
// Only Integer, Real, RealList and String are accepted for the band attributes.
// NOTE(review): unsupported types exit with status 0.
261 switch( fieldType ){
262 case OFTInteger:
263 case OFTReal:
264 case OFTRealList:
265 case OFTString:
266 if(verbose_opt[0]>1)
267 std::cout << std::endl << "field type is: " << OGRFieldDefn::GetFieldTypeName(fieldType) << std::endl;
268 break;
269 default:
270 cerr << "field type " << OGRFieldDefn::GetFieldTypeName(fieldType) << " not supported" << std::endl;
271 exit(0);
272 break;
273 }
// Same set of accepted types for the class label attribute.
274 switch( labelType ){
275 case OFTInteger:
276 case OFTReal:
277 case OFTRealList:
278 case OFTString:
279 if(verbose_opt[0]>1)
280 std::cout << std::endl << "label type is: " << OGRFieldDefn::GetFieldTypeName(labelType) << std::endl;
281 break;
282 default:
283 cerr << "label type " << OGRFieldDefn::GetFieldTypeName(labelType) << " not supported" << std::endl;
284 exit(0);
285 break;
286 }
287
// Progress reporting through GDALTermProgress.
// NOTE(review): pszMessage is never initialised yet is passed to pfnProgress on every call.
288 const char* pszMessage;
289 void* pProgressArg=NULL;
290 GDALProgressFunc pfnProgress=GDALTermProgress;
291 double progress=0;
// Seed rand() from the wall clock, so the random selection differs between runs.
292 srand(time(NULL));
293
294 bool sampleIsRaster=true;
295
296 ImgReaderGdal classReader;
297 ImgWriterOgr sampleWriterOgr;
298
// Try to open the sample dataset as a raster; a thrown string is taken to mean
// that the sample is not a raster.
299 if(sample_opt.size()){
300 try{
301 classReader.open(sample_opt[0]);
302 }
303 catch(string errorString){
304 //todo: sampleIsRaster will not work from GDAL 2.0!!?? (unification of driver for raster and vector datasets)
305 sampleIsRaster=false;
306 }
307 }
308 else{
309 std::cerr << "No raster sample dataset provided (use option -s filename). Use --help for help information";
310 exit(1);
311 }
312
313 if(sampleIsRaster){
// ---- Case 1: no -c option given, extract every class found in the sample raster ----
314 if(class_opt.empty()){
315 ImgWriterOgr ogrWriter;
316 assert(sample_opt.size());
// NOTE(review): classReader was already opened on the same file above; it is opened again here.
317 classReader.open(sample_opt[0]);
318 // vector<int> classBuffer(classReader.nrOfCol());
319 vector<double> classBuffer(classReader.nrOfCol());
320 Vector2d<double> imgBuffer(nband,imgReader.nrOfCol());//[band][col]
321 // vector<double> imgBuffer(nband);//[band]
322 vector<double> sample(2+nband);//x,y,band values
// Selected samples and their class values are buffered in parallel containers.
323 Vector2d<double> writeBuffer;
324 vector<int> writeBufferClass;
// NOTE(review): selectedClass, selectedBuffer and oldimgrow are not used in the visible code of this branch.
325 vector<int> selectedClass;
326 Vector2d<double> selectedBuffer;
327 double oldimgrow=-1;
328 int irow=0;
329 int icol=0;
330 if(verbose_opt[0]>1)
331 std::cout << "extracting sample from image..." << std::endl;
332 progress=0;
333 pfnProgress(progress,pszMessage,pProgressArg);
334 for(irow=0;irow<classReader.nrOfRow();++irow){
// Down sampling: only rows (and, below, columns) that are multiples of the factor are processed.
335 if(irow%down_opt[0])
336 continue;
337 classReader.readData(classBuffer,irow);
338 double x=0;//geo x coordinate
339 double y=0;//geo y coordinate
340 double iimg=0;//image x-coordinate in img image
341 double jimg=0;//image y-coordinate in img image
342
343 //find col in img
// Only the resulting row (jimg) is used here. icol is 0 for the first row and holds the
// value left by the previous column loop afterwards - presumably harmless for
// north-up (non-rotated) geotransforms; TODO confirm.
344 classReader.image2geo(icol,irow,x,y);
345 imgReader.geo2image(x,y,iimg,jimg);
346 //nearest neighbour
// Sample rows that fall outside the input image are skipped (this also skips the progress update).
347 if(static_cast<int>(jimg)<0||static_cast<int>(jimg)>=imgReader.nrOfRow())
348 continue;
// Read the matching input image row once per selected band.
349 for(int iband=0;iband<nband;++iband){
350 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
351 imgReader.readData(imgBuffer[iband],static_cast<int>(jimg),theBand);
352 }
353 for(icol=0;icol<classReader.nrOfCol();++icol){
354 if(icol%down_opt[0])
355 continue;
356 int theClass=classBuffer[icol];
357 int processClass=0;
358 bool valid=false;
// This code sits inside the class_opt.empty() branch, so the first arm is always taken
// and the else arm below is unreachable here.
359 if(class_opt.empty()){
360 valid=true;//process every class
361 processClass=theClass;
362 }
363 else{
364 for(int iclass=0;iclass<class_opt.size();++iclass){
365 if(classBuffer[icol]==class_opt[iclass]){
366 processClass=iclass;
367 theClass=class_opt[iclass];
368 valid=true;//process this class
369 break;
370 }
371 }
372 }
// Geo position of this sample pixel becomes the point coordinates.
373 classReader.image2geo(icol,irow,x,y);
374 sample[0]=x;
375 sample[1]=y;
376 if(verbose_opt[0]>1){
377 std::cout.precision(12);
378 std::cout << theClass << " " << x << " " << y << std::endl;
379 }
380 //find col in img
381 imgReader.geo2image(x,y,iimg,jimg);
382 //nearest neighbour
// The column is obtained by truncating the fractional image coordinate.
383 iimg=static_cast<int>(iimg);
384 if(static_cast<int>(iimg)<0||static_cast<int>(iimg)>=imgReader.nrOfCol())
385 continue;
386
// Nodata test: applied only to the band equal to bndnodata_opt[0]; the sample is
// invalid when that band's value equals any of the srcnodata values.
387 for(int iband=0;iband<nband&&valid;++iband){
388 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
389 if(srcnodata_opt.size()&&theBand==bndnodata_opt[0]){
390 // vector<int>::const_iterator bndit=bndnodata_opt.begin();
391 for(int inodata=0;inodata<srcnodata_opt.size()&&valid;++inodata){
392 if(imgBuffer[iband][iimg]==srcnodata_opt[inodata])
393 valid=false;
394 }
395 }
396 }
397 // oldimgrow=jimg;
398
399 if(valid){
400 for(int iband=0;iband<imgBuffer.size();++iband){
401 sample[iband+2]=imgBuffer[iband][iimg];
402 }
// NOTE(review): in this branch processClass holds the raw class value, so with more than
// one threshold this indexes threshold_opt by class value, which may be out of range.
403 float theThreshold=(threshold_opt.size()>1)?threshold_opt[processClass]:threshold_opt[0];
// Positive threshold: keep the sample with a probability of theThreshold percent.
// Non-positive thresholds keep everything here; absolute (<0) limits are applied after the scan.
404 if(theThreshold>0){//percentual value
405 double p=static_cast<double>(rand())/(RAND_MAX);
406 p*=100.0;
407 if(p>theThreshold)
408 continue;//do not select for now, go to next column
409 }
410 // else if(nvalid.size()>processClass){//absolute value
411 // if(nvalid[processClass]>=-theThreshold)
412 // continue;//do not select any more pixels for this class, go to next column to search for other classes
413 // }
414 writeBuffer.push_back(sample);
415 writeBufferClass.push_back(theClass);
416 ++ntotalvalid;
417 if(nvalid.count(theClass))
418 nvalid[theClass]+=1;
419 else
420 nvalid[theClass]=1;
421 }
422 else{
423 ++ntotalinvalid;
424 if(ninvalid.count(theClass))
425 ninvalid[theClass]+=1;
426 else
427 ninvalid[theClass]=1;
428 }
429 }
430 progress=static_cast<float>(irow+1.0)/classReader.nrOfRow();
431 pfnProgress(progress,pszMessage,pProgressArg);
432 }//irow
// NOTE(review): progress is a 0..1 fraction everywhere else; 100 looks unintended.
433 progress=100;
434 pfnProgress(progress,pszMessage,pProgressArg);
435 if(writeBuffer.size()>0){
436 assert(ntotalvalid==writeBuffer.size());
437 if(verbose_opt[0]>0)
438 std::cout << "creating image sample writer " << output_opt[0] << " with " << writeBuffer.size() << " samples (" << ntotalinvalid << " invalid)" << std::endl;
// Create the output dataset with one point layer and the fields fid, label and one per band.
439 ogrWriter.open(output_opt[0],ogrformat_opt[0]);
440 char **papszOptions=NULL;
441 ostringstream slayer;
442 slayer << "training data";
443 std::string layername=slayer.str();
444 ogrWriter.createLayer(layername, imgReader.getProjection(), wkbPoint, papszOptions);
445 std::string fieldname="fid";//number of the point
446 ogrWriter.createField(fieldname,OFTInteger);
447 map<std::string,double> pointAttributes;
448 ogrWriter.createField(label_opt[0],labelType);
449 for(int iband=0;iband<nband;++iband){
450 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
451 ogrWriter.createField(fieldname_opt[iband],fieldType);
452 }
453 progress=0;
454 pfnProgress(progress,pszMessage,pProgressArg);
455
456 map<int,short> classDone;
457 Vector2d<double> writeBufferTmp;
458 vector<int> writeBufferClassTmp;
459
// Absolute threshold: draw buffered samples at random, without replacement, moving
// them to the Tmp buffers until every class has either reached -threshold samples
// or run out of samples. Only the drawn samples are kept.
460 if(threshold_opt[0]<0){//absolute threshold
461 map<int,unsigned long int>::iterator mapit;
462 map<int,unsigned long int> ncopied;
463 for(mapit=nvalid.begin();mapit!=nvalid.end();++mapit)
464 ncopied[mapit->first]=0;
465
// NOTE(review): this diagnostic is printed regardless of the verbose level.
466 cout << "ncopied.size(): " << ncopied.size() << endl;
467 while(classDone.size()<nvalid.size()){
468 int index=rand()%writeBufferClass.size();
469 int theClass=writeBufferClass[index];
470 float theThreshold=threshold_opt[0];
471 if(threshold_opt.size()>1&&class_opt.size())
472 theThreshold=threshold_opt[classmap[theClass]];
473 theThreshold=-theThreshold;
474 if(ncopied[theClass]<theThreshold){
475 writeBufferClassTmp.push_back(*(writeBufferClass.begin()+index));
476 writeBufferTmp.push_back(*(writeBuffer.begin()+index));
477 writeBufferClass.erase(writeBufferClass.begin()+index);
478 writeBuffer.erase(writeBuffer.begin()+index);
479 ++(ncopied[theClass]);
480 }
481 else
482 classDone[theClass]=1;
// A class is also done once all of its buffered samples have been copied.
483 if(ncopied[theClass]>=nvalid[theClass]){
484 classDone[theClass]=1;
485 }
486 }
487 writeBuffer=writeBufferTmp;
488 writeBufferClass=writeBufferClassTmp;
489
490 // while(classDone.size()<nvalid.size()){
491 // int index=rand()%writeBufferClass.size();
492 // int theClass=writeBufferClass[index];
493 // float theThreshold=threshold_opt[0];
494 // if(threshold_opt.size()>1&&class_opt.size())
495 // theThreshold=threshold_opt[classmap[theClass]];
496 // theThreshold=-theThreshold;
497 // if(nvalid[theClass]>theThreshold){
498 // writeBufferClass.erase(writeBufferClass.begin()+index);
499 // writeBuffer.erase(writeBuffer.begin()+index);
500 // --(nvalid[theClass]);
501 // }
502 // else
503 // classDone[theClass]=1;
504 // }
505 }
// Write one point per buffered sample: label plus band values, with the sample index as fid.
506 for(int isample=0;isample<writeBuffer.size();++isample){
507 if(verbose_opt[0]>1)
508 std::cout << "writing sample " << isample << std::endl;
509 pointAttributes[label_opt[0]]=writeBufferClass[isample];
// Each sample holds x,y followed by the band values, hence size()-2 bands.
510 for(int iband=0;iband<writeBuffer[0].size()-2;++iband){
511 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
512 pointAttributes[fieldname_opt[iband]]=writeBuffer[isample][iband+2];
513 }
514 if(verbose_opt[0]>1)
515 std::cout << "all bands written" << std::endl;
516 ogrWriter.addPoint(writeBuffer[isample][0],writeBuffer[isample][1],pointAttributes,fieldname,isample);
517 progress=static_cast<float>(isample+1.0)/writeBuffer.size();
518 pfnProgress(progress,pszMessage,pProgressArg);
519 }
520 ogrWriter.close();
521 }
522 else{
523 std::cout << "No data found for any class " << std::endl;
524 }
525 classReader.close();
526 nsample=writeBuffer.size();
527 if(verbose_opt[0])
528 std::cout << "total number of samples written: " << nsample << std::endl;
529 }
// ---- Case 2: one or more classes requested with -c ----
530 else{//class_opt.size()!=0
// NOTE(review): this assertion fails when the first requested class value is 0 (if asserts are enabled).
531 assert(class_opt[0]);
532 // if(class_opt[0]){
// Either a single threshold for all classes or exactly one threshold per class.
533 assert(threshold_opt.size()==1||threshold_opt.size()==class_opt.size());
// NOTE(review): this local classReader shadows the one declared and opened earlier in main.
534 ImgReaderGdal classReader;
535 ImgWriterOgr ogrWriter;
536 if(verbose_opt[0]>1){
537 std::cout << "reading position from sample dataset " << std::endl;
538 std::cout << "class thresholds: " << std::endl;
539 for(int iclass=0;iclass<class_opt.size();++iclass){
540 if(threshold_opt.size()>1)
541 std::cout << class_opt[iclass] << ": " << threshold_opt[iclass] << std::endl;
542 else
543 std::cout << class_opt[iclass] << ": " << threshold_opt[0] << std::endl;
544 }
545 }
546 classReader.open(sample_opt[0]);
// Unlike the all-classes branch, the sample row is read into an int buffer here.
547 vector<int> classBuffer(classReader.nrOfCol());
548 // vector<double> classBuffer(classReader.nrOfCol());
549 Vector2d<double> imgBuffer(nband,imgReader.nrOfCol());//[band][col]
550 // vector<double> imgBuffer(nband);//[band]
551 vector<double> sample(2+nband);//x,y,band values
552 Vector2d<double> writeBuffer;
553 vector<int> writeBufferClass;
// NOTE(review): selectedClass, selectedBuffer and oldimgrow are not used in the visible code of this branch.
554 vector<int> selectedClass;
555 Vector2d<double> selectedBuffer;
556 double oldimgrow=-1;
557 int irow=0;
558 int icol=0;
559 if(verbose_opt[0]>1)
560 std::cout << "extracting sample from image..." << std::endl;
561 progress=0;
562 pfnProgress(progress,pszMessage,pProgressArg);
563 for(irow=0;irow<classReader.nrOfRow();++irow){
// Down sampling on rows, as in the all-classes branch.
564 if(irow%down_opt[0])
565 continue;
566 classReader.readData(classBuffer,irow);
567 double x=0;//geo x coordinate
568 double y=0;//geo y coordinate
569 double iimg=0;//image x-coordinate in img image
570 double jimg=0;//image y-coordinate in img image
571
572 //find col in img
// Only the row (jimg) of this conversion is used; icol carries over from the previous column loop.
573 classReader.image2geo(icol,irow,x,y);
574 imgReader.geo2image(x,y,iimg,jimg);
575 //nearest neighbour
576 if(static_cast<int>(jimg)<0||static_cast<int>(jimg)>=imgReader.nrOfRow())
577 continue;
578 for(int iband=0;iband<nband;++iband){
579 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
580 imgReader.readData(imgBuffer[iband],static_cast<int>(jimg),theBand);
581 }
582
583 for(icol=0;icol<classReader.nrOfCol();++icol){
584 if(icol%down_opt[0])
585 continue;
586 int theClass=0;
587 // double theClass=0;
// processClass stays -1 unless the pixel belongs to one of the requested classes.
588 int processClass=-1;
// This code sits in the branch where class_opt is not empty, so the first arm is unreachable here.
589 if(class_opt.empty()){//process every class
590 if(classBuffer[icol]){
591 processClass=0;
592 theClass=classBuffer[icol];
593 }
594 }
595 else{
// No break: if a class value is listed more than once, the last matching index wins.
596 for(int iclass=0;iclass<class_opt.size();++iclass){
597 if(classBuffer[icol]==class_opt[iclass]){
598 processClass=iclass;
599 theClass=class_opt[iclass];
600 }
601 }
602 }
603 if(processClass>=0){
604 // if(classBuffer[icol]==class_opt[0]){
605 classReader.image2geo(icol,irow,x,y);
606 sample[0]=x;
607 sample[1]=y;
608 if(verbose_opt[0]>1){
609 std::cout.precision(12);
610 std::cout << theClass << " " << x << " " << y << std::endl;
611 }
612 //find col in img
613 imgReader.geo2image(x,y,iimg,jimg);
614 //nearest neighbour
615 iimg=static_cast<int>(iimg);
616 if(static_cast<int>(iimg)<0||static_cast<int>(iimg)>=imgReader.nrOfCol())
617 continue;
618 bool valid=true;
619
// Nodata test on the band equal to bndnodata_opt[0] only.
620 for(int iband=0;iband<nband;++iband){
621 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
622 if(srcnodata_opt.size()&&theBand==bndnodata_opt[0]){
623 // vector<int>::const_iterator bndit=bndnodata_opt.begin();
624 for(int inodata=0;inodata<srcnodata_opt.size()&&valid;++inodata){
625 if(imgBuffer[iband][iimg]==srcnodata_opt[inodata])
626 valid=false;
627 }
628 }
629 }
630 if(valid){
631 for(int iband=0;iband<imgBuffer.size();++iband){
632 sample[iband+2]=imgBuffer[iband][iimg];
633 }
// Here processClass is an index into class_opt, matching the per-class threshold list.
634 float theThreshold=(threshold_opt.size()>1)?threshold_opt[processClass]:threshold_opt[0];
635 if(theThreshold>0){//percentual value
636 double p=static_cast<double>(rand())/(RAND_MAX);
637 p*=100.0;
638 if(p>theThreshold)
639 continue;//do not select for now, go to next column
640 }
641 // else if(nvalid.size()>processClass){//absolute value
642 // if(nvalid[processClass]>=-theThreshold)
643 // continue;//do not select any more pixels for this class, go to next column to search for other classes
644 // }
645 writeBuffer.push_back(sample);
646 writeBufferClass.push_back(theClass);
647 ++ntotalvalid;
648 if(nvalid.count(theClass))
649 nvalid[theClass]+=1;
650 else
651 nvalid[theClass]=1;
652 }
653 else{
654 ++ntotalinvalid;
655 if(ninvalid.count(theClass))
656 ninvalid[theClass]+=1;
657 else
658 ninvalid[theClass]=1;
659 }
660 }//processClass
661 }//icol
662 progress=static_cast<float>(irow+1.0)/classReader.nrOfRow();
663 pfnProgress(progress,pszMessage,pProgressArg);
664 }//irow
665 if(writeBuffer.size()>0){
666 assert(ntotalvalid==writeBuffer.size());
667 if(verbose_opt[0]>0)
668 std::cout << "creating image sample writer " << output_opt[0] << " with " << writeBuffer.size() << " samples (" << ntotalinvalid << " invalid)" << std::endl;
// Create the output dataset with one point layer and the fields fid, label and one per band.
669 ogrWriter.open(output_opt[0],ogrformat_opt[0]);
670 char **papszOptions=NULL;
671 ostringstream slayer;
672 slayer << "training data";
673 std::string layername=slayer.str();
674 ogrWriter.createLayer(layername, imgReader.getProjection(), wkbPoint, papszOptions);
675 std::string fieldname="fid";//number of the point
676 ogrWriter.createField(fieldname,OFTInteger);
677 map<std::string,double> pointAttributes;
678 ogrWriter.createField(label_opt[0],labelType);
679 for(int iband=0;iband<nband;++iband){
680 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
681 ogrWriter.createField(fieldname_opt[iband],fieldType);
682 }
// Reports the last scan progress value once more before resetting to 0 for the write phase.
683 pfnProgress(progress,pszMessage,pProgressArg);
684 progress=0;
685 pfnProgress(progress,pszMessage,pProgressArg);
686
687 map<int,short> classDone;
688 Vector2d<double> writeBufferTmp;
689 vector<int> writeBufferClassTmp;
690
// Absolute threshold: random draw without replacement, as in the all-classes branch,
// with an optional per-class limit looked up through classmap.
691 if(threshold_opt[0]<0){//absolute threshold
692 map<int,unsigned long int>::iterator mapit;
693 map<int,unsigned long int> ncopied;
694 for(mapit=nvalid.begin();mapit!=nvalid.end();++mapit)
695 ncopied[mapit->first]=0;
696
// NOTE(review): nvalid holds an entry for every requested class, including classes with no
// buffered sample. Such a class can never be drawn, so it is never added to classDone:
// the loop then either never terminates or reaches rand()%0 once the buffer is empty.
697 while(classDone.size()<nvalid.size()){
698 int index=rand()%writeBufferClass.size();
699 int theClass=writeBufferClass[index];
700 float theThreshold=threshold_opt[0];
701 if(threshold_opt.size()>1&&class_opt.size())
702 theThreshold=threshold_opt[classmap[theClass]];
703 theThreshold=-theThreshold;
704 if(ncopied[theClass]<theThreshold){
705 writeBufferClassTmp.push_back(*(writeBufferClass.begin()+index));
706 writeBufferTmp.push_back(*(writeBuffer.begin()+index));
707 writeBufferClass.erase(writeBufferClass.begin()+index);
708 writeBuffer.erase(writeBuffer.begin()+index);
709 ++(ncopied[theClass]);
710 }
711 else
712 classDone[theClass]=1;
713 if(ncopied[theClass]>=nvalid[theClass]){
714 classDone[theClass]=1;
715 }
716 }
717 writeBuffer=writeBufferTmp;
718 writeBufferClass=writeBufferClassTmp;
719 // while(classDone.size()<nvalid.size()){
720 // int index=rand()%writeBufferClass.size();
721 // int theClass=writeBufferClass[index];
722 // float theThreshold=threshold_opt[0];
723 // if(threshold_opt.size()>1&&class_opt.size())
724 // theThreshold=threshold_opt[classmap[theClass]];
725 // theThreshold=-theThreshold;
726 // if(nvalid[theClass]>theThreshold){
727 // writeBufferClass.erase(writeBufferClass.begin()+index);
728 // writeBuffer.erase(writeBuffer.begin()+index);
729 // --(nvalid[theClass]);
730 // }
731 // else
732 // classDone[theClass]=1;
733 // }
734 }
735
// Write one point per buffered sample: label plus band values, with the sample index as fid.
736 for(int isample=0;isample<writeBuffer.size();++isample){
737 pointAttributes[label_opt[0]]=writeBufferClass[isample];
738 for(int iband=0;iband<writeBuffer[0].size()-2;++iband){
739 int theBand=(band_opt.size()) ? band_opt[iband] : iband;
740 pointAttributes[fieldname_opt[iband]]=writeBuffer[isample][iband+2];
741 }
742 ogrWriter.addPoint(writeBuffer[isample][0],writeBuffer[isample][1],pointAttributes,fieldname,isample);
743 progress=static_cast<float>(isample+1.0)/writeBuffer.size();
744 pfnProgress(progress,pszMessage,pProgressArg);
745 }
746 ogrWriter.close();
747 }
748 else{
749 std::cout << "No data found for any class " << std::endl;
750 }
751 classReader.close();
752 nsample=writeBuffer.size();
753 if(verbose_opt[0]){
754 std::cout << "total number of samples written: " << nsample << std::endl;
755 if(nvalid.size()==class_opt.size()){
// NOTE(review): nvalid is keyed by class value, but it is indexed with the loop position
// iclass here (not class_opt[iclass]); operator[] also inserts missing keys.
// The counts are those taken before any absolute-threshold subsampling.
756 for(int iclass=0;iclass<class_opt.size();++iclass)
757 std::cout << "class " << class_opt[iclass] << " has " << nvalid[iclass] << " samples" << std::endl;
758 }
759 }
760 }
761 }
// Vector sample datasets are not handled by this tool.
// NOTE(review): only a message is printed; execution continues and no error status is returned.
762 else{//vector dataset
763 cerr << "Error: vector sample not supported, consider using pkextractogr" << endl;
764 }//else (vector)
765 progress=1.0;
766 pfnProgress(progress,pszMessage,pProgressArg);
767 imgReader.close();
768}
769
int nrOfRow(void) const
Get the number of rows of this dataset.
bool geo2image(double x, double y, double &i, double &j) const
Convert georeferenced coordinates (x and y) to image coordinates (column and row)
std::string getProjection(void) const
Get the projection string (deprecated, use getProjectionRef instead)
int nrOfBand(void) const
Get the number of bands of this dataset.
int nrOfCol(void) const
Get the number of columns of this dataset.
Definition: ImgRasterGdal.h:98
bool image2geo(double i, double j, double &x, double &y) const
Convert image coordinates (column and row) to georeferenced coordinates (x and y)
void readData(T &value, int col, int row, int band=0)
Read a single pixel cell value at a specific column and row for a specific band (all indices start counting from 0)
Definition: ImgReaderGdal.h:95
void close(void)
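Close the image dataset. (The extracted description, "Set the memory (in MB) to cache a number of rows in memory.", appears to belong to a different member function.)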
void open(const std::string &filename, const GDALAccess &readMode=GA_ReadOnly)
Open an image.