Fern-ベースの点分類器,平面物体検出器
OpenCV2.*で実装されたと噂の「Fern-ベースの点分類器,平面物体検出器」のサンプルコードを試してみました。
はじめにがんばる分、SURFに比べてかなり速い印象。
おためし
上:探した画像 下:特徴点を対応付けたもの
元の論文
概念は研究室のミーティングで聞いたことがあるような気がしてきた。
論文これからよみます。読めるかわからないけど。
- Mustafa Özuysal, Michael Calonder, Vincent Lepetit, Pascal Fua, "Fast Keypoint Recognition Using Random Ferns," IEEE Transactions on Pattern Analysis and Machine Intelligence, 15 Jan. 2009. (PDF)
- Vincent Lepetit, Pascal Fua,“Towards Recognizing Feature Points Using Classification Trees,” Technical Report IC/2004/74, EPFL, 2004. (PDF)
サンプルコードの動画対応版(C++)
「OpenCV2.1/samples/c/find_obj_ferns.cpp」をもとに以下の機能を追加しています。
後で余裕があれば解説を付けたいです。
- カメラからの入力
- 見つかった四角形が凸包かどうかを調べる
- 半回転のdetectにしか対応していなかったものを、1回転対応に(コメント欄にてご指摘いただいた後、修正。とおりすがりさんありがとうございました。)
#include <cv.h>
#include <cvaux.h>
#include "highgui.h"
#include <algorithm>
#include <iostream>
#include <vector>

using namespace cv;

// Returns true when the quadrilateral in `p` is convex.
// A homography of a planar object should map the model corners to a
// convex quadrilateral, so this filters out degenerate detections.
// (Changed from a vector<Point2f>* parameter to a const reference.)
static bool checkConvexHull(const vector<Point2f>& p)
{
    Mat points(p, true); // copy the vector into a Mat
    return isContourConvex(points);
}

int main(int argc, char** argv)
{
    // Camera input (feature added relative to the original OpenCV sample).
    VideoCapture cap(0);
    if (!cap.isOpened()) {
        return -1;
    }

    // Display windows.
    namedWindow("Object", 1);
    namedWindow("Image", 1);
    namedWindow("Object Correspondence", 1);

    // Load the model (object) image in grayscale.
    const char* object_filename = "../data/tv_500_t.png";
    Mat object = imread(object_filename, CV_LOAD_IMAGE_GRAYSCALE);
    if (!object.data) {
        fprintf(stderr, "Can not load %s \n"
                "Usage: find_obj [<object_filename> ]\n", object_filename);
        exit(-1);
    }

    // Keypoint detector operating on an image pyramid.
    Size patchSize(32, 32);
    LDetector ldetector(7, 20, 2, 2000, patchSize.width, 2);
    ldetector.setVerbose(true);
    PlanarObjectDetector detector;

    vector<Mat> objpyr;
    int blurKSize = 3;
    double sigma = 0;
    GaussianBlur(object, object, Size(blurKSize, blurKSize), sigma, sigma);
    buildPyramid(object, objpyr, ldetector.nOctaves - 1);

    vector<KeyPoint> objKeypoints;
    // Full-rotation patch generator. The commented-out variant below only
    // covered half a rotation (see the blog comment thread for the fix).
    //PatchGenerator gen(0,256,5,true,0.8,1.2,-CV_PI/2,CV_PI/2,-CV_PI/2,CV_PI/2);
    PatchGenerator gen(0, 256, 5, true, 0.8, 1.2, -CV_PI, CV_PI, -CV_PI, CV_PI);

    // Try to load a previously trained model; otherwise train and save one.
    string model_filename = format("%s_model.xml.gz", object_filename);
    printf("Trying to load %s ...\n", model_filename.c_str());
    FileStorage fs(model_filename, FileStorage::READ);
    if (fs.isOpened()) {
        detector.read(fs.getFirstTopLevelNode());
        printf("Successfully loaded %s.\n", model_filename.c_str());
    } else {
        printf("The file not found and can not be read. Let's train the model.\n");
        printf("Step 1. Finding the robust keypoints ...\n");
        ldetector.setVerbose(true);
        ldetector.getMostStable2D(object, objKeypoints, 100, gen);
        printf("Done.\nStep 2. Training ferns-based planar object detector ...\n");
        detector.setVerbose(true);
        detector.train(objpyr, objKeypoints, patchSize.width, 100, 11, 10000, ldetector, gen);
        // FIX: this format string was split across two physical source lines
        // in the original paste — a raw newline inside a string literal is
        // invalid C++. Rejoined into a single literal.
        printf("Done.\nStep 3. Saving the model to %s ...\n", model_filename.c_str());
        if (fs.open(model_filename, FileStorage::WRITE))
            detector.write(fs, "ferns_model");
    }
    printf("Now find the keypoints in the image, try recognize them and compute the homography matrix\n");
    fs.release();

    // Visualize the model keypoints (red dot + green scale circle per octave).
    objKeypoints = detector.getModelPoints();
    std::cout << "Object keypoints: " << objKeypoints.size() << "\n";
    Mat objectColor;
    cvtColor(object, objectColor, CV_GRAY2BGR);
    for (int i = 0; i < (int)objKeypoints.size(); i++) {
        circle(objectColor, objKeypoints[i].pt, 2, Scalar(0, 0, 255), -1);
        circle(objectColor, objKeypoints[i].pt,
               (1 << objKeypoints[i].octave) * 15, Scalar(0, 255, 0), 1);
    }
    imshow("Object", objectColor);

    Mat image;
    while (true) {
        // Grab a frame and convert it to grayscale.
        Mat frame, frame_gray;
        cap >> frame;
        cvtColor(frame, frame_gray, CV_BGR2GRAY);

        double imgscale = 1;
        resize(frame_gray, image, Size(), 1. / imgscale, 1. / imgscale, INTER_CUBIC);
        GaussianBlur(image, image, Size(blurKSize, blurKSize), sigma, sigma);

        vector<Mat> imgpyr;
        buildPyramid(image, imgpyr, ldetector.nOctaves - 1);

        // Stacked canvas: model image on top, camera image below.
        Mat correspond(object.rows + image.rows,
                       std::max(object.cols, image.cols), CV_8UC3);
        correspond = Scalar(0.);
        Mat part(correspond, Rect(0, 0, object.cols, object.rows));
        cvtColor(object, part, CV_GRAY2BGR);
        part = Mat(correspond, Rect(0, object.rows, image.cols, image.rows));
        cvtColor(image, part, CV_GRAY2BGR);

        double t = (double)getTickCount();
        vector<KeyPoint> imgKeypoints;
        ldetector(imgpyr, imgKeypoints, 300);
        std::cout << "Image keypoints: " << imgKeypoints.size() << "\n";

        vector<Point2f> dst_corners;
        vector<int> pairs;
        Mat H;
        bool found = detector(imgpyr, imgKeypoints, H, dst_corners, &pairs);
        t = (double)getTickCount() - t;
        printf("%gms\n", t * 1000 / getTickFrequency());

        // Draw the detected quadrilateral only when it is convex — a
        // non-convex result indicates a degenerate homography.
        if (found && checkConvexHull(dst_corners)) {
            for (int i = 0; i < 4; i++) {
                Point r1 = dst_corners[i % 4];
                Point r2 = dst_corners[(i + 1) % 4];
                line(correspond,
                     Point(r1.x, r1.y + object.rows),
                     Point(r2.x, r2.y + object.rows),
                     Scalar(0, 0, 255));
            }
        }

        // Draw matched pairs: pairs[i] is a model index, pairs[i+1] an image index.
        for (int i = 0; i < (int)pairs.size(); i += 2) {
            line(correspond, objKeypoints[pairs[i]].pt,
                 imgKeypoints[pairs[i + 1]].pt + Point2f(0, object.rows),
                 Scalar(0, 255, 0));
        }
        imshow("Object Correspondence", correspond);

        // Visualize the per-frame keypoints on the camera image.
        Mat imageColor;
        cvtColor(image, imageColor, CV_GRAY2BGR);
        for (int i = 0; i < (int)imgKeypoints.size(); i++) {
            circle(imageColor, imgKeypoints[i].pt, 2, Scalar(0, 0, 255), -1);
            circle(imageColor, imgKeypoints[i].pt,
                   (1 << imgKeypoints[i].octave) * 15, Scalar(0, 255, 0), 1);
        }
        imshow("Image", imageColor);

        int c = waitKey(20);
        if (c == 's') {                       // 's' saves the correspondence view
            imwrite("correspond.png", correspond);
        } else if (c == 27) {                 // ESC quits
            break;
        }
    }
    return 0;
}