模板元编程实例---如何设计通用的几何库

模板元编程实例—如何设计通用的几何库

设计原理

假设你需要使用c++程序来计算两点间的距离.你可能会这样做:

  • 先定义一个struct:

    1
    2
    3
    4
    struct mypoint
    {
    double x, y;
    };
  • 然后定义一个包含计算算法的函数:

    1
    2
    3
    4
    5
    6
    double distance(mypoint const& a, mypoint const& b)
    {
    double dx = a.x - b.x;
    double dy = a.y - b.y;
    return sqrt(dx * dx + dy * dy);
    }

相当简单而实用,但是不够通用.一个库的设计需要考虑未来可能的变化.
上面的设计只能用于笛卡尔坐标系中的2D点.
通用的库需要能够计算如下距离:

  • 适用于任何point struct或者point class,而不是只适用于mypoint.
  • 不只是二维
  • 适用于其它坐标系统,如地球或球体上
  • 能够计算点与线或者其它几何图形之间的距离
  • double更高的精度
  • 尽可能避免使用sqrt:通常我们不希望调用它,因为它的开销比较大.而且对于比较距离时没有必要.

接下来,我们将一步一步给出一个更通用的实现.

使用模板

我们可以将距离函数改为模板函数.这样就可以计算除mypoint之外的其他点类型之间的距离.
我们添加两个模板参数,允许输入两种不同的点类型.

1
2
3
4
5
6
7
template <typename P1, typename P2>
double distance(P1 const& a, P2 const& b)
{
double dx = a.x - b.x;
double dy = a.y - b.y;
return std::sqrt(dx * dx + dy * dy);
}

模板版本比之前的实现好一些,但是还不够.
考虑c++类的成员变量为protected或者不能直接访问x,y.

使用Traits

我们需要使用一种更通用的方法来允许任意的点类型都能够作为距离函数的输入.
除了直接访问xy,我们将添加一层间接层,使用traits系统.
距离函数可以变为:

1
2
3
4
5
6
7
template <typename P1, typename P2>
double distance(P1 const& a, P2 const& b)
{
double dx = get<0>(a) - get<0>(b);
double dy = get<1>(a) - get<1>(b);
return std::sqrt(dx * dx + dy * dy);
}

上面的距离函数使用了get函数来访问一个点的坐标系统,使用点的维度作为模板参数.
get可以这样实现:

1
2
3
4
5
namespace traits
{
template <typename P, int D>
struct access {};
}

定义mypoint的模板特例:

1
2
3
4
5
6
7
8
9
10
11
12
13
namespace traits
{
template <>
struct access<mypoint, 0>
{
static double get(mypoint const& p)
{
return p.x;
}
};
// same for 1: p.y
...
}

现在通过调用traits::access<mypoint, 0>::get(a)就可以返回坐标系中的x.我们可以通过定义get来进一步简化调用方式:

1
2
3
4
5
template <int D, typename P>
inline double get(P const& p)
{
return traits::access<P, D>::get(p);
}

通过上面的实现,我们就可以对任何特化了traits::accesspoint a调用get<0>(a).
同样的原理,我们也可以实现对于坐标yget<1>(a).

任意维度

为了实现对任意维度的计算,我们可以通过循环来遍历所有维度.但是循环调用相对于直接计算会有性能开销.因此我们可以通过使用模板实现这样的算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
template <typename P1, typename P2, int D>
struct pythagoras
{
static double apply(P1 const& a, P2 const& b)
{
double d = get<D-1>(a) - get<D-1>(b);
return d * d + pythagoras<P1, P2, D-1>::apply(a, b);
}
};

template <typename P1, typename P2 >
struct pythagoras<P1, P2, 0>
{
static double apply(P1 const&, P2 const&)
{
return 0;
}
};

然后距离函数可以调用pythagoras并指定维度:

1
2
3
4
5
6
7
template <typename P1, typename P2>
double distance(P1 const& a, P2 const& b)
{
BOOST_STATIC_ASSERT(( dimension<P1>::value == dimension<P2>::value ));

return sqrt(pythagoras<P1, P2, dimension<P1>::value>::apply(a, b));
}

维度可以通过定义另外一个traits类来实现:

1
2
3
4
5
namespace traits
{
template <typename P>
struct dimension {};
}

然后针对相应的类(如mypoint)进行特例化,因为这个traits只是发布一个值,因此为了简便我们可以继承Boost.MPL中的class boost::mpl::int_:

1
2
3
4
5
6
namespace traits
{
template <>
struct dimension<mypoint> : boost::mpl::int_<2>
{};
}

现在我们就实现了对任意维度点进行计算距离的算法.我们还使用编译期断言来防止对两个不同维度的点进行计算.

坐标类型

在上面的实现中,我们假设了double类型,如果点是integer呢?

1
2
3
4
5
6
7
8
9
10
11
12
namespace traits
{
template <typename P>
struct coordinate_type{};

// specialization for our mypoint
template <>
struct coordinate_type<mypoint>
{
typedef double type;
};
}

access函数类似,我们同样添加一个代理:

1
2
template <typename P>
struct coordinate_type : traits::coordinate_type<P> {};

然后我们可以修改我们的距离计算函数.因为计算的两个point类型可能有不同的类型,我们必须处理这种情况.我们需要选择其中一种具有更高精度的类型作为结果类型,我们假设有一个select_most_precise元函数用于选择最佳类型.

这样我们的计算函数可以改为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
template <typename P1, typename P2, int D>
struct pythagoras
{
typedef typename select_most_precise
<
typename coordinate_type<P1>::type,
typename coordinate_type<P2>::type
>::type computation_type;

static computation_type apply(P1 const& a, P2 const& b)
{
computation_type d = get<D-1>(a) - get<D-1>(b);
return d * d + pythagoras <P1, P2, D-1> ::apply(a, b);
}
};

不同的形状

我们已经设计了一个支持任意维度和任意坐标系统中的点的实现.
现在我们需要看看如何支持计算点与多边形或者点与线之间的距离.
支持这些形式对之前的设计会有较大的影响,我们不想添加另外一个名称的函数,如:

1
2
template <typename P, typename S>
double distance_point_segment(P const& p, S const& s)

我们想更加通用,距离函数的调用者最好不用关心形状的类型,我们也无法通过重载类实现,因为模板的签名相同,会有二义性.
有两种解决方法:

  • tag dispatching
  • SFINAE

在这里,我们选择tag dispatching因为它适合于`traits`系统.

使用tag dispatching,距离计算算法检查输入的几何形状类型.
我们的距离函数将变成:

1
2
3
4
5
6
7
8
9
10
template <typename G1, typename G2>
double distance(G1 const& g1, G2 const& g2)
{
return dispatch::distance
<
typename tag<G1>::type,
typename tag<G2>::type,
G1, G2
>::apply(g1, g2);
}

使用tag元函数获取类型然后将调用转交给dispatch::distanceapply方法.
tag元函数是另一个traits类,需要被point类特例化:

1
2
3
4
5
6
7
8
9
10
11
12
namespace traits
{
template <typename G>
struct tag {};

// specialization
template <>
struct tag<mypoint>
{
typedef point_tag type;
};
}

Tags (point_tag, segment_tag, etc)是用于特例化dispatch struct的空结构.
distancedispatch struct和其特例化都定义于另外一个单独的命名空间中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
namespace dispatch {
template < typename Tag1, typename Tag2, typename G1, typename G2 >
struct distance
{};

template <typename P1, typename P2>
struct distance < point_tag, point_tag, P1, P2 >
{
static double apply(P1 const& a, P2 const& b)
{
// here we call pythagoras
// exactly like we did before
...
}
};

template <typename P, typename S>
struct distance
<
point_tag, segment_tag, P, S
>
{
static double apply(P const& p, S const& s)
{
// here we refer to another function
// implementing point-segment
// calculations in 2 or 3
// dimensions...
...
}
};

// here we might have many more
// specializations,
// for point-polygon, box-circle, etc.

} // namespace

现在,距离算法对所有不同的几何形状都是通用的.
还有一个缺点是:我们必须为point,segment特例化2个dispatch.

1
2
3
4
5
6
7
8
point a(1,1);
point b(2,2);
std::cout << distance(a,b) << std::endl;
segment s1(0,0,5,3);
std::cout << distance(a, s1) << std::endl;
rgb red(255, 0, 0);
rbc orange(255, 128, 0);
std::cout << "color distance: " << distance(red, orange) << std::endl;