K-Nearest Neighbor in D language - leonardo
By leonardo maffi,
version 1.1.

Warning: currently nearest4/nearest5 don't work correctly; I'll try to fix them.

(Later this article will be moved to my site.)

These blog posts present a performance comparison between F# and OCaml implementations of a simple K-Nearest Neighbor (with K=1) classifier algorithm to recognize handwritten digits:

http://philtomson.github.io/blog/2014/05/29/comparing-a-machine-learning-algorithm-implemented-in-f-number-and-ocaml/

http://philtomson.github.io/blog/2014/05/30/stop-the-presses-ocaml-wins/

One of the Reddit discussions:

http://www.reddit.com/r/fsharp/comments/26vl3w/stop_the_presses_ocaml_wins_in_terms_of_speed_vs/

The original F# and OCaml code:
https://github.com/philtomson/ClassifyDigits

So I've studied this code and written and compared several versions of it in the D language.

First I've written a functional-style version, 27 lines of code long (as counted by cloc), that is similar to the array-based OCaml version:
import std.stdio, std.algorithm, std.range, std.array, std.conv,
       std.string, std.math, std.typecons;

struct LabelPixels { ubyte label; ubyte[] pixels; }

immutable readData = (in string fileName) =>
    fileName
    .File
    .byLine
    .dropOne
    .map!(r => r.chomp.split(",").to!(ubyte[]))
    .map!(a => LabelPixels(a[0], a.dropOne))
    .array;

immutable distance = (in ubyte[] p1, in ubyte[] p2)
pure /*nothrow*/ @safe /*@nogc*/ =>
    double(p1.zip(p2).map!(ab => (ab[0] - ab[1]) ^^ 2).sum).sqrt;

immutable classify = (in LabelPixels[] trainingSet, in ubyte[] pixels)
/*pure nothrow*/ @safe =>
    trainingSet
    .map!(s => tuple(pixels.distance(s.pixels), ubyte(s.label)))
    .reduce!min[1];

void main() {
    const trainingSet = "training_sample.csv".readData;
    const validationSample = "validation_sample.csv".readData;

    immutable nCorrect = validationSample
                         .map!(s => trainingSet.classify(s.pixels) == s.label)
                         .sum;
    writeln("Percentage correct: ", double(nCorrect) / validationSample.length * 100);
}

Notes on this first version:
- Inside LabelPixels the pixels are represented with one unsigned byte each. This reduces the RAM used to store the data and lets the CPU cache be used more efficiently.
- "readData" is a lambda assigned to a global immutable value. This is not fully idiomatic D style, but I've used it because it looks more similar to the F# code. In practice I usually don't use such global lambdas, because a written-out return type gives me useful information to understand the purpose of the function, and lambdas are less easy to find in the assembly (you can't just search for "readData" in the asm listing to find this lambda).
- The D standard library contains a module to read CSV files (std.csv), but I find its API awkward, so I've used File.byLine and I have processed the lines lazily.
- In readData I've used the eager function split instead of the lazy splitter (which uses less RAM) because for the data sets used by this program split is faster (readData loads both CSV files in about 0.8-0.9 seconds on my PC). A tiny comparison of the two styles is sketched after these notes.
- In the "distance" lambda the nothrow and @nogc attributes are commented out because zip() doesn't yet respect such attributes (zip throws an exception in certain cases, and the code that throws the exception allocates it on the heap). The "classify" lambda calls "distance", so it can't be nothrow and @nogc either; moreover reduce doesn't have a default value to return when its input range is empty, so it can throw, and classify contains a heap-allocating lambda.
- The distance lambda contains code like "map!(ab => (ab[0] - ab[1]) ^^ 2)" because D is not yet able to de-structure tuples, like the 2-tuples generated lazily by zip.
- The classify lambda contains a map of tuples because the max/min functions of Phobos still lack an optional "key" function like the max/min functions of Python. So I've created tuples of the mapped value plus the label, and I've used reduce to find the min. I have used "ubyte(s.label)" to produce a mutable tuple, because reduce doesn't accept an immutable one (unless you use a mutable seed).
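
Here is the tiny self-contained comparison of the two parsing styles (a sketch with made-up example data):

import std.algorithm, std.array, std.conv, std.string;

void main() {
    immutable line = "5,0,128,255";
    // Eager: split allocates a temporary string[] per line (faster here).
    const eager = line.split(",").to!(ubyte[]);
    // Lazy: splitter yields the pieces on demand, with lower peak RAM.
    const lazyOne = line.splitter(",").map!(to!ubyte).array;
    assert(eager == lazyOne);
}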

Most of the run time of this program is spent inside the distance function, so if you want to optimize this program that's the function to improve.
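
To verify this you can time the distance function alone. A minimal timing sketch (the timeDistance helper is hypothetical and reuses the LabelPixels and distance definitions of the program above; it uses the std.datetime.stopwatch module of newer compilers, in 2014 StopWatch lived in std.datetime):

import std.datetime.stopwatch : AutoStart, StopWatch;
import std.stdio;

// Hypothetical helper: time a million distance calls between two
// records, to confirm that distance dominates the total run time.
void timeDistance(in LabelPixels[] data) {
    auto sw = StopWatch(AutoStart.yes);
    double tot = 0;
    foreach (immutable _; 0 .. 1_000_000)
        tot += distance(data[0].pixels, data[1].pixels);
    sw.stop;
    writeln("1e6 distance calls: ", sw.peek.total!"msecs", " ms (checksum ", tot, ")");
}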

In the zip you can find a "nearest1b.d" file that is very similar to this first version, but adapted to the current version of the LDC2 compiler, which is based on a slightly older front-end (it doesn't support the @nogc attribute or the T(x) syntax for implicit casts, and its Phobos doesn't have a sum function).

The D compilers (even the well optimizing LDC2 compiler) are not yet very good at optimizing that functional distance function. So in the "nearest2.d" version of the program (which you can find in the zip archive) I've written distance in an imperative way; compiled with ldc2 this produces a program that is more than 3 times faster:
uint distance(in ubyte[] p1, in ubyte[] p2) pure nothrow @safe {
    uint tot = 0;
    foreach (immutable i, immutable p1i; p1)
        tot += (p1i - p2[i]) ^^ 2;
    return tot;
}

Now this distance function can be pure nothrow @safe. If you look closely, this function also cheats a little by not computing the square root, which is not needed when you only want to find the closest point. This function also computes the total using integers.
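
The cheat is safe because sqrt is monotonic on non-negative values, so squared distances select the same nearest neighbor as true distances. A one-line check:

import std.math : sqrt;

void main() {
    // For non-negative a and b, a < b holds exactly when sqrt(a) < sqrt(b).
    immutable a = 3_600.0, b = 170_000.0;
    assert((a < b) == (a.sqrt < b.sqrt));
}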

I have not shown the asm generated by LDC2 for the distance lambda of the "nearest1.d" program because it's too complex and messy, but I can show the asm for this second, simpler distance function:
__D8nearest38distanceFNaNbNfyAhyAhZk:
	pushl	%ebx
	pushl	%edi
	pushl	%esi
	movl	24(%esp), %ecx
	xorl	%eax, %eax
	testl	%ecx, %ecx
	je	LBB8_3
	movl	28(%esp), %edx
	movl	20(%esp), %esi
	.align	16, 0x90
LBB8_2:
	movzbl	(%edx), %edi
	movzbl	(%esi), %ebx
	subl	%ebx, %edi
	imull	%edi, %edi
	addl	%edi, %eax
	incl	%edx
	incl	%esi
	decl	%ecx
	jne	LBB8_2
LBB8_3:
	popl	%esi
	popl	%edi
	popl	%ebx
	ret	$16

As you can see, the loop gets compiled to quite simple asm, with 9 x86 (32 bit) instructions.

To improve performance I have changed the data layout a little. With the "binary_generator1.d" program that you can find in the zip archive I have generated binary files of the training set and validation sample. So now the loading of the data is very simple and very fast (see the zip archive for the full source of this "nearest3.d" program; a sketch of such a converter is shown after the notes below):
enum nCols = 785;
enum size_t labelIndex = 0;
alias TData = ubyte;
alias TLabel = TData;

immutable(TData[nCols])[] readData(in string fileName) {
    return cast(typeof(return))std.file.read(fileName);
}

uint distance(immutable ref TData[nCols - 1] p1,
              immutable ref TData[nCols - 1] p2)
pure nothrow @safe /*@nogc*/ {
    uint tot = 0;
    foreach (immutable i, immutable p1i; p1)
        tot += (p1i - p2[i]) ^^ 2;
    return tot;
}

TLabel classify(immutable TData[nCols][] trainingSet,
                immutable ref TData[nCols - 1] pixels)
pure nothrow @safe /*@nogc*/ {
    auto closestDistance = uint.max;
    auto closestLabel = TLabel.max;

    foreach (immutable ref s; trainingSet) {
        immutable dist = pixels.distance(s[1 .. $]);
        if (dist < closestDistance) {
            closestDistance = dist;
            closestLabel = s[labelIndex];
        }
    }

    return closestLabel;
}

Notes:
- The classify function is now imperative, but this improves the performance just a little.
- @nogc is still commented out because the current version of the LDC2 compiler doesn't support it yet.
- The main performance difference of this third program comes from the data layout. Now I have hard-coded (defined statically) the number of columns nCols, so the data sets are essentially a ubyte[N][]. In D this is represented as a dense array, which minimizes cache misses and removes one level of indirection. Thanks to the compile-time knowledge of the loop bounds inside the distance function, the LLVM back-end of the LDC2 compiler performs some loop unrolling (in theory this could be done with dynamic bounds too, but in practice LLVM 3.4.1 does not perform loop unrolling on them).

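As promised, here is a hypothetical sketch of what a binary_generator1.d-style converter can look like (the real program is in the zip archive). It flattens every CSV record (the label plus 784 pixels, one ubyte each) into one contiguous binary file, so the rows are stored inline with no per-row indirection:

import std.algorithm, std.array, std.conv, std.file, std.range,
       std.stdio, std.string;

// Hypothetical sketch of binary_generator1.d: flatten the CSV records
// into one dense binary file of 785-byte rows.
void main() {
    auto flat = "training_sample.csv"
                .File
                .byLine
                .dropOne
                .map!(r => r.chomp.split(",").to!(ubyte[]))
                .join;
    std.file.write("training_sample.bin", flat);
}
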
So the asm of the distance function of the "nearest3.d" program is longer, because of a partial loop unrolling:
__D8nearest38distanceFNaNbNfKyG784hKyG784hZk:
	pushl	%ebx
	pushl	%edi
	pushl	%esi
	xorl	%ecx, %ecx
	movl	$7, %edx
	movl	16(%esp), %esi
	.align	16, 0x90
LBB1_1:
	movzbl	-7(%esi,%edx), %edi
	movzbl	-7(%eax,%edx), %ebx
	subl	%ebx, %edi
	imull	%edi, %edi
	addl	%ecx, %edi
	movzbl	-6(%esi,%edx), %ecx
	movzbl	-6(%eax,%edx), %ebx
	subl	%ebx, %ecx
	imull	%ecx, %ecx
	addl	%edi, %ecx
	movzbl	-5(%esi,%edx), %edi
	movzbl	-5(%eax,%edx), %ebx
	subl	%ebx, %edi
	imull	%edi, %edi
	addl	%ecx, %edi
	movzbl	-4(%esi,%edx), %ecx
	movzbl	-4(%eax,%edx), %ebx
	subl	%ebx, %ecx
	imull	%ecx, %ecx
	addl	%edi, %ecx
	movzbl	-3(%esi,%edx), %edi
	movzbl	-3(%eax,%edx), %ebx
	subl	%ebx, %edi
	imull	%edi, %edi
	addl	%ecx, %edi
	movzbl	-2(%esi,%edx), %ecx
	movzbl	-2(%eax,%edx), %ebx
	subl	%ebx, %ecx
	imull	%ecx, %ecx
	addl	%edi, %ecx
	movzbl	-1(%esi,%edx), %edi
	movzbl	-1(%eax,%edx), %ebx
	subl	%ebx, %edi
	imull	%edi, %edi
	addl	%ecx, %edi
	movzbl	(%esi,%edx), %ecx
	movzbl	(%eax,%edx), %ebx
	subl	%ebx, %ecx
	imull	%ecx, %ecx
	addl	%edi, %ecx
	addl	$8, %edx
	cmpl	$791, %edx
	jne	LBB1_1
	movl	%ecx, %eax
	popl	%esi
	popl	%edi
	popl	%ebx
	ret	$4

Thanks to the data layout changes and other smaller improvements, the "nearest3.d" program is almost three times faster than "nearest2.d".

To reduce the run time and better utilize one CPU core, we can use the SIMD registers of the CPU.

To perform the subtractions in parallel I change the data layout again: this time I represent each pixel with a short (in D a signed 16 bit integer), so I can use a short8 to perform eight subtractions at a time. A problem comes from the label: if I slice it away as in the "nearest3.d" program, the array pointers are no longer aligned to 16 bytes, and this is not accepted by my CPU. To solve this problem I add a padding of 7 shorts after the label in every line of the matrix. This is performed by the "binary_generator2.d" program, sketched below.
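
A hypothetical sketch of the binary_generator2.d idea (the real program is in the zip archive): every row becomes 1 label short, 7 zeroed padding shorts, then the 784 pixel shorts, so each row is 792 shorts = 1584 bytes long, a multiple of 16, and every pixel block starts at a 16-byte aligned offset:

import std.algorithm, std.conv, std.file, std.range, std.stdio, std.string;

// Hypothetical sketch of binary_generator2.d: emit rows of
// label + 7 padding shorts + 784 pixel shorts, keeping every
// pixel block 16-byte aligned.
void main() {
    short[] flat;
    foreach (row; "training_sample.csv".File.byLine.dropOne
                  .map!(r => r.chomp.split(",").to!(short[]))) {
        flat ~= row[0];        // the label
        flat ~= new short[7];  // zeroed padding
        flat ~= row[1 .. $];   // the 784 pixels
    }
    std.file.write("training_sample2.bin", flat);
}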

So the 4th version of the program contains (see in the zip archive for the full code of "nearest4.d"):
enum nCols = 785;
enum size_t labelIndex = 0;
alias TData = short;
alias TLabel = TData;

immutable(TData[nCols + 7])[] readData(in string fileName) {
    return cast(typeof(return))std.file.read(fileName);
}

uint distance(immutable ref TData[nCols - 1] p1,
              immutable ref TData[nCols - 1] p2) pure nothrow /*@nogc*/ {
    alias TV = short8;
    enum size_t Vlen = TV.init.array.length;
    assert(p1.length % Vlen == 0);
    immutable v1 = cast(immutable TV*)p1.ptr;
    immutable v2 = cast(immutable TV*)p2.ptr;

    TV totV = 0;
    foreach (immutable i; 0 .. p1.length / Vlen) {
        TV d = v1[i] - v2[i];
        totV += d * d;
    }

    uint tot = 0;
    foreach (immutable t; totV.array)
        tot += t;
    return tot;
}

TLabel classify(immutable TData[nCols + 7][] trainingSet,
                immutable ref TData[nCols - 1] pixels) pure nothrow /*@nogc*/ {
    auto closestDistance = uint.max;
    auto closestLabel = short.max;

    foreach (immutable ref s; trainingSet) {
        immutable dist = pixels.distance(s[8 .. $]);
        if (dist < closestDistance) {
            closestDistance = dist;
            closestLabel = s[labelIndex];
        }
    }

    return closestLabel;
}

Notes:
- The classify function has not changed much; it performs a slicing to throw away the first eight items (the short label plus the 7 padding shorts).
- The distance function is a little more complex, but not too much. It uses basic SIMD operations like subtraction, addition and multiplication, plus the ".array" property that allows me to treat a short8 as a fixed-size array to sum up its lanes.
- The code of distance is designed to be adaptable to other sizes of SIMD registers (but it's not able to adapt automatically).
- This "nearest4.d" program is only 1.6 times faster than "nearest3.d", probably because the SIMD code that manages shorts is not very clean. (Note also that a squared difference can reach 65025 while the vector sums stay in 16 bit lanes, so the accumulator can overflow; this is probably related to the warning at the top about nearest4/nearest5.)

The asm for the distance function of the "nearest4.d" program is rather long, despite being fast:
__D8nearest48distanceFNaNbKyG784sKyG784sZk:
	pushl	%edi
	pushl	%esi
	pxor	%xmm0, %xmm0
	movl	$208, %ecx
	movl	12(%esp), %edx
	.align	16, 0x90
LBB1_1:
	movdqa	-208(%edx,%ecx), %xmm1
	movdqa	-192(%edx,%ecx), %xmm2
	psubw	-208(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	psubw	-192(%eax,%ecx), %xmm2
	pmullw	%xmm2, %xmm2
	paddw	%xmm1, %xmm2
	movdqa	-176(%edx,%ecx), %xmm0
	psubw	-176(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm2, %xmm0
	movdqa	-160(%edx,%ecx), %xmm1
	psubw	-160(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	movdqa	-144(%edx,%ecx), %xmm0
	psubw	-144(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm1, %xmm0
	movdqa	-128(%edx,%ecx), %xmm1
	psubw	-128(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	movdqa	-112(%edx,%ecx), %xmm0
	psubw	-112(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm1, %xmm0
	movdqa	-96(%edx,%ecx), %xmm1
	psubw	-96(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	movdqa	-80(%edx,%ecx), %xmm0
	psubw	-80(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm1, %xmm0
	movdqa	-64(%edx,%ecx), %xmm1
	psubw	-64(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	movdqa	-48(%edx,%ecx), %xmm0
	psubw	-48(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm1, %xmm0
	movdqa	-32(%edx,%ecx), %xmm1
	psubw	-32(%eax,%ecx), %xmm1
	pmullw	%xmm1, %xmm1
	paddw	%xmm0, %xmm1
	movdqa	-16(%edx,%ecx), %xmm2
	psubw	-16(%eax,%ecx), %xmm2
	pmullw	%xmm2, %xmm2
	paddw	%xmm1, %xmm2
	movdqa	(%edx,%ecx), %xmm0
	psubw	(%eax,%ecx), %xmm0
	pmullw	%xmm0, %xmm0
	paddw	%xmm2, %xmm0
	addl	$224, %ecx
	cmpl	$1776, %ecx
	jne	LBB1_1
	pshufd	$3, %xmm0, %xmm1
	movd	%xmm1, %eax
	movdqa	%xmm0, %xmm1
	movhlps	%xmm1, %xmm1
	movd	%xmm1, %ecx
	pshufd	$1, %xmm0, %xmm1
	movd	%xmm1, %edx
	movd	%xmm0, %esi
	movswl	%si, %edi
	sarl	$16, %esi
	addl	%edi, %esi
	movswl	%dx, %edi
	addl	%esi, %edi
	sarl	$16, %edx
	addl	%edi, %edx
	movswl	%cx, %esi
	addl	%edx, %esi
	sarl	$16, %ecx
	addl	%esi, %ecx
	movswl	%ax, %edx
	addl	%ecx, %edx
	sarl	$16, %eax
	addl	%edx, %eax
	popl	%esi
	popl	%edi
	ret	$4

The run times of the various versions, in seconds, best of 3:
  nearest1:   91     (dmd)
  nearest1b:  19.7   (ldc2)
  nearest2:    5.91  (ldc2)
  nearest3:    2.05  (ldc2)
  nearest4:    1.28  (ldc2)

I have compiled the programs with:
dmd -wi -d -O -release -inline -boundscheck=off nearest1.d
ldmd2 -wi -unroll-allow-partial -O -release -inline -noboundscheck nearest1b.d
ldmd2 -wi -unroll-allow-partial -O -release -inline -noboundscheck nearest2.d
ldmd2 -wi -unroll-allow-partial -O -release -inline -noboundscheck nearest3.d
ldmd2 -wi -unroll-allow-partial -O -release -inline -noboundscheck nearest4.d
strip nearest1b.exe nearest2.exe nearest3.exe nearest4.exe

I have used the compilers:

DMD32 D Compiler v2.066

And LDC2 v0.13.0-beta1, based on DMD v2.064 and LLVM 3.4.1; default target: i686-pc-mingw32, host CPU: core2.

On my 32 bit Windows the sizes (in bytes) of the sources and of the generated binaries are:
    1.692 nearest1.d
  221.212 nearest1.exe
    1.757 nearest1b.d
1.079.822 nearest1b.exe
    1.716 nearest2.d
1.070.606 nearest2.exe
    2.059 nearest2b.d
1.071.118 nearest2b.exe
    2.037 nearest3.d
1.353.230 nearest3.exe
    2.365 nearest4.d
1.353.742 nearest4.exe
    2.513 nearest5.d

In the zip archive I have also added a "nearest5.d" program that shows a recent improvement of the D type system (not yet available in LDC2): support for fixed-size array length inference for slices passed to templated functions.
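
I don't show nearest5.d here, but a minimal sketch of that type system improvement looks like this (assuming DMD 2.066; the bounds of the slice are known at compile time, so the length of the static array parameter is inferred):

import std.stdio;

// The compile-time-known length of data[1 .. $] lets the compiler
// infer N = 4 for the static array parameter.
uint sumSquares(size_t N)(ref const ubyte[N] p) pure nothrow @safe {
    uint tot = 0;
    foreach (immutable x; p)
        tot += x ^^ 2;
    return tot;
}

void main() {
    immutable ubyte[5] data = [9, 1, 2, 3, 4];
    writeln(sumSquares(data[1 .. $]));  // N inferred as 4
}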

With a run time of 1.28 seconds for the "nearest4.d" program there are still ways to reduce the run time. I am using an old 2.3 GHz CPU, so on a modern Intel 4 GHz CPU like the Core i7-4790K, with faster memory, you should see a significant speedup. On a modern CPU you can also use the YMM registers with short16 SIMD, which should offer some speedup (LDC2 is already able to target such YMM registers; you need to replace short8 with short16 in the nearest4.d program and change the padding). And this program is very easy to parallelize on two or four cores: you can use the std.parallelism module for the computation of nCorrect in the main function, which should give a further speedup, as in the sketch below.
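
A hypothetical parallel rewrite of the nCorrect computation of the first program (an untested sketch that reuses the readData and classify definitions of nearest1.d):

import std.algorithm, std.parallelism, std.stdio;

void main() {
    const trainingSet = "training_sample.csv".readData;
    const validationSample = "validation_sample.csv".readData;

    // Classify the validation digits on all the cores. Every iteration
    // writes only its own slot of hits, so no synchronization is needed.
    auto hits = new int[validationSample.length];
    foreach (i, ref s; parallel(validationSample))
        hits[i] = trainingSet.classify(s.pixels) == s.label;

    immutable nCorrect = hits.sum;
    writeln("Percentage correct: ",
            double(nCorrect) / validationSample.length * 100);
}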

You can find all the code and data here:
http://www.fantascienza.net/leonardo/js/nearest.7z

(This 7z archive contains the data sets too, I hope this is OK. Otherwise I'll remove them from the 7z archive.)
Comments:


thedeemon
Time: 2014-06-10 04:25 am (UTC)
Can you run nearest4 built with dmd? That would show how much of the speedup comes from the manual optimizations and how much from the compiler change.

leonardo_m
Time: 2014-06-10 07:25 am (UTC)
I can't, because I am using a 32 bit system and currently DMD supports SIMD only on 64 bit systems.


thedeemon
Time: 2014-06-10 07:39 am (UTC)
Ouch, I didn't know about this limitation. The core.simd docs are silent about it.


thedeemon
Time: 2014-06-11 07:55 am (UTC)
> The "classify" lambda calls "distance" so it can't be nothrow and @nogc (and because reduce doesn't have a default value in case its input range is empty, so it can throw).

At least in DMD 2.065 the first version of "classify" actually allocates 12 bytes on each call to store the things captured by the lambda, so it could not be @nogc even if it were not using "zip". So during the processing stage there are 500 allocations overall. I wonder if 2.066 will do better.

leonardo_m
Time: 2014-06-11 11:47 am (UTC)
> At least in DMD 2.065 the first version of "classify" actually allocates 12 bytes on each call to store things captured by lambda, so it cannot be @nogc even were it not using "zip". So during processing stage there are 500 allocations overall.

Right. And they don't influence the overall run time of that version of the program much.

> I wonder if 2.066 will do better.

I think lambda closures have not changed in the meantime, and there are no immediate plans to change them. It's a price to pay to avoid the very large complexity and unsafety of the C++11 lambdas, or the complexity and strictness of the Rust type system.

I think the profiler allows you to spot the performance problems, and @nogc and the -vgc switch to spot the hidden heap allocations. If you want to remove the allocations you can often find workarounds (not in this case, it seems), and when you can't, you can rewrite the function in a more imperative and simpler style to remove the heap allocation.
