Efficient Graph-Based Image Segmentation

Efficient Graph-Based Image Segmentation（Felzenszwalb & Huttenlocher, IJCV 2004）のセグメンテーションを、RGBの3チャンネルだけでなく深度マップなどを加えたより多くの次元で行いたいので、練習として実装してみた。
前処理の平滑化は、手間を省くためOpenCVのガウシアンフィルタ（cvSmooth）で済ませた。なお、sigmaを0（平滑化なし）にしても本家の実装とは結果が微妙に異なるが、ここでは気にしないことにする。
手法自体の解説は論文を参照してください。
コードは続きから。

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>

#include <cv.h>
#include <highgui.h>

using namespace std;

// Returns `value` multiplied by itself (used for squared channel differences).
template<typename T>
inline T square(T value) {
  const T product = value * value;
  return product;
}

// Disjoint-set forest over the indices [0, size) using union by size and
// path compression. Each set also records the largest weight passed to
// UnionSet() for it; main() uses that value as the internal difference Int(C).
// Storage lives in std::vector, so it is released automatically and no
// allocation can leak if construction throws.
class UnionFind {
private:
  // parent_[i] <  0 : i is a root and -parent_[i] is the size of its set.
  // parent_[i] >= 0 : index of i's parent in the forest.
  std::vector<int> parent_;
  // Largest weight merged into the set; only meaningful at root indices.
  std::vector<double> weights_;

  // Non-copyable: declared but intentionally never defined.
  UnionFind(const UnionFind &);
  UnionFind &operator=(const UnionFind &);
public:
  UnionFind() {}
  explicit UnionFind(int size) { Create(size); }

  // (Re)initialises the structure to `size` singleton sets with weight 0.
  // A negative size is treated as 0.
  void Create(int size) {
    const size_t n = size > 0 ? static_cast<size_t>(size) : 0;
    parent_.assign(n, -1);
    weights_.assign(n, 0.0);
  }
  // Drops all sets and frees the underlying storage.
  void Release() {
    std::vector<int>().swap(parent_);
    std::vector<double>().swap(weights_);
  }

  // Merges the sets containing x and y and records max(w, both set weights)
  // as the merged set's weight. Returns false, changing nothing, when x and y
  // are already in the same set.
  bool UnionSet(int x, int y, double w) {
    x = Root(x);
    y = Root(y);
    if (x == y) { return false; }
    // Roots store -size, so the larger value is the smaller set: hang it under the larger one.
    if (parent_[x] > parent_[y]) { std::swap(x, y); }
    weights_[x] = std::max(w, std::max(weights_[x], weights_[y]));
    parent_[x] += parent_[y];
    parent_[y] = x;
    return true;
  }
  // True when x and y belong to the same set.
  bool FindSet(int x, int y) {
    return Root(x) == Root(y);
  }
  // Returns the root of x's set, pointing every node on the path directly at
  // it. Iterative, so the path length never limits the call-stack depth.
  int Root(int x) {
    int root = x;
    while (parent_[root] >= 0) { root = parent_[root]; }
    while (parent_[x] >= 0) {
      const int next = parent_[x];
      parent_[x] = root;
      x = next;
    }
    return root;
  }
  // Number of elements in x's set.
  int Size(int x) { return -parent_[Root(x)]; }
  // Largest weight recorded for x's set (0 for a never-merged singleton).
  double Weight(int x) { return weights_[Root(x)]; }
  // Internal difference Int(C) of x's component; alias of Weight().
  double Int(int x) { return Weight(x); }
};

// Weighted edge between two pixel indices (row-major index y * width + x in main()).
// Compared by weight only, so sorting yields edges in non-decreasing weight order.
struct Edge {
  int from;
  int to;
  double weight;
  // Zero-initialise so a default-constructed edge never holds indeterminate values.
  Edge() : from(0), to(0), weight(0.0) {}
  Edge(int from, int to, double weight) : from(from), to(to), weight(weight) {}
  bool operator<(const Edge &rhs) const {
    return weight < rhs.weight;
  }
};



// Edge weight between pixels (fx, fy) and (tx, ty): the Euclidean distance
// between their values over all src->nChannels channels. Each channel is read
// as an unsigned 8-bit sample, so this assumes an IPL_DEPTH_8U image.
inline double calcWeight(const IplImage *src, int fx, int fy, int tx, int ty) {
  const unsigned char *base = reinterpret_cast<const unsigned char *>(src->imageData);
  const unsigned char *a = base + fy * src->widthStep + fx * src->nChannels;
  const unsigned char *b = base + ty * src->widthStep + tx * src->nChannels;
  double sumOfSquares = 0;
  for (int c = 0; c < src->nChannels; ++c) {
    const int diff = static_cast<int>(a[c]) - static_cast<int>(b[c]);
    sumOfSquares += square(diff);
  }
  return sqrt(sumOfSquares);
}

// Writes `color` into pixel (fx, fy) of dest, one byte per channel: channel c
// receives bits [8c, 8c + 8) of color. main() creates dest with 3 channels,
// so only the low 24 bits are used (a dest with more than 4 channels would
// shift by 32 or more, which is not valid for int).
inline void setColor(IplImage *dest, int fx, int fy, int color) {
  unsigned char *pixel = reinterpret_cast<unsigned char *>(dest->imageData)
                         + fy * dest->widthStep + fx * dest->nChannels;
  for (int c = 0; c < dest->nChannels; ++c) {
    pixel[c] = static_cast<unsigned char>((color >> (8 * c)) & 0xFF);
  }
}

// Threshold function tau(C) = k / |C|. main() passes the component size as R,
// so the margin added to a component's internal difference shrinks as it grows.
inline double Tau(double k, double R) {
  const double margin = k / R;
  return margin;
}


// Prints the command-line synopsis to stderr and terminates the process with
// exit status 1; never returns.
void usage() {
  static const char synopsis[] = "usage: ./segment sigma k min input output\n";
  fputs(synopsis, stderr);
  exit(1);
}

// Segments an image with the graph-based method of "Efficient Graph-Based
// Image Segmentation" and writes a randomly coloured component image.
//
// usage: ./segment sigma k min input output
//   sigma   Gaussian pre-smoothing sigma passed to cvSmooth (0 = no smoothing)
//   k       scale constant of the threshold Tau(k, |C|) = k / |C|
//   min     components with fewer pixels than this are merged into a neighbour
//   input   image to segment (loaded with cvLoadImage)
//   output  path the coloured segmentation is written to
//
// Returns 0 on success, 1 if the input cannot be loaded or the output cannot
// be saved; a wrong argument count prints usage and exits with status 1.
int main(int argc, char *argv[])
{
  if (6 != argc) { usage(); }
  const double sigma = atof(argv[1]);
  const double k = atof(argv[2]);
  const int minRegionSize = atoi(argv[3]);
  const char *input = argv[4];
  const char *output = argv[5];


  // load image; cvLoadImage returns NULL when the file is missing or unreadable
  IplImage *src = cvLoadImage(input);
  if (src == NULL) {
    fprintf(stderr, "error: cannot load image '%s'\n", input);
    return 1;
  }
  const int width = src->width;
  const int height = src->height;
  const int numPixels = width * height;
  IplImage *dest = cvCreateImage(cvSize(width, height), IPL_DEPTH_8U, 3);
  if (sigma != 0) {
    cvSmooth(src, src, CV_GAUSSIAN, 0, 0, sigma);
  }


  // create edges: every pixel links to its right, lower-right, lower and
  // lower-left neighbours, covering each 8-neighbour pair exactly once, so at
  // most 4 edges per pixel are needed.
  vector<Edge> edges;
  edges.reserve(static_cast<size_t>(numPixels) * 4);
  const int dx[4] = { 1, 1, 0, -1 };
  const int dy[4] = { 0, 1, 1, 1 };
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      for (int dir = 0; dir < 4; dir++) {
        const int nx = x + dx[dir];
        const int ny = y + dy[dir];
        if (nx < 0 || nx >= width || ny < 0 || ny >= height) { continue; }
        const double w = calcWeight(src, x, y, nx, ny);
        edges.push_back(Edge(y * width + x, ny * width + nx, w));
      }
    }
  }
  // process edges from lightest to heaviest
  sort(edges.begin(), edges.end());


  // region merge: join two components when the connecting edge is no heavier
  // than min(Int(C1) + Tau(C1), Int(C2) + Tau(C2))
  UnionFind uFind(numPixels);
  int regionCount = numPixels;
  for (size_t i = 0; i < edges.size(); i++) {
    const Edge &edge = edges[i];
    if (uFind.FindSet(edge.from, edge.to)) { continue; }
    const double MInt = min(uFind.Int(edge.from) + Tau(k, uFind.Size(edge.from)),
                            uFind.Int(edge.to)   + Tau(k, uFind.Size(edge.to)));
    if (edge.weight <= MInt) {
      uFind.UnionSet(edge.from, edge.to, edge.weight);
      regionCount--;
    }
  }

  // small region merge: absorb components below minRegionSize along the
  // lightest remaining edges
  for (size_t i = 0; i < edges.size(); i++) {
    const Edge &edge = edges[i];
    if (uFind.FindSet(edge.from, edge.to)) { continue; }
    if (uFind.Size(edge.from) < minRegionSize || uFind.Size(edge.to) < minRegionSize) {
      uFind.UnionSet(edge.from, edge.to, edge.weight);
      regionCount--;
    }
  }

  // set color: one random 24-bit colour per component. rand() only guarantees
  // 15 random bits (RAND_MAX >= 32767), so each 8-bit channel is drawn separately.
  vector<int> color(regionCount);
  for (int i = 0; i < regionCount; i++) {
    int c = 0;
    for (int ch = 0; ch < 3; ch++) {
      c |= (rand() & 255) << (8 * ch);
    }
    color[i] = c;
  }
  // label[root] is the dense component number of that root (-1 = unseen);
  // a flat vector gives O(1) lookups instead of a map search per pixel.
  vector<int> label(numPixels, -1);
  int n = 0;
  for (int y = 0; y < height; y++) {
    for (int x = 0; x < width; x++) {
      const int root = uFind.Root(y * width + x);
      if (label[root] < 0) { label[root] = n++; }
      setColor(dest, x, y, color[label[root]]);
    }
  }
  assert(regionCount == n);

  printf("got %d components\n", regionCount);
  int status = 0;
  if (!cvSaveImage(output, dest)) {
    fprintf(stderr, "error: cannot save image '%s'\n", output);
    status = 1;
  }

  //cvNamedWindow("segment", CV_WINDOW_AUTOSIZE);
  //cvShowImage("segment", dest);
  //cvWaitKey(0);

  // closing
  cvReleaseImage(&dest);
  cvReleaseImage(&src);
  return status;
}