001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 * image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause).
006 * Copyright (c) 2006-2023 Wilhelm Burger, Mark J. Burge. All rights reserved.
007 * Visit https://imagingbook.com for additional details.
008 ******************************************************************************/
009
010package imagingbook.common.color.quantize;
011
import imagingbook.common.color.RgbUtils;
import imagingbook.common.color.statistics.ColorHistogram;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Random;

import static imagingbook.common.math.Arithmetic.sqr;
022
023/**
024 * <p>
025 * This class implements color quantization using k-means clustering of image pixels in RGB color space. It provides two
026 * modes for selecting initial color clusters: (a) random sampling of the input colors, (b) using the K most frequent
027 * colors. Note that this implementation is mainly to demonstrate the concept. Depending on the data size and number of
028 * quantization colors, this process may be excessively slow compared to other methods. It is usually impractical even
029 * for medium-sized images. During clustering all input pixels are used, i.e., no stochastic sub-sampling is applied
030 * (which could make the process a lot more efficient).
031 * </p>
032 *
033 * @author WB
034 * @version 2022/11/06
035 * @see MedianCutQuantizer
036 * @see OctreeQuantizer
037 */
038public class KMeansClusteringQuantizer implements ColorQuantizer {
039
040        /**
041         * Seed for random number generation (set to a nonzero value to obtain repeatable results for debugging and
042         * testing).
043         */
044        public static long RandomSeed = 0;      
045        private final Random random = (RandomSeed == 0) ? new Random() : new Random(RandomSeed);
046        
047        public static int DefaultIterations = 500;
048        
049        private final ColorCluster[] clusters;
050        private final double totalError;
051        private final float[][] colormap;
052        
053        
054        /** Method for choosing initial color clusters. */
055        public enum InitialClusterMethod {
056                /** Use K different random colors to initialize clusters. */
057                Random,
058                /** Use the K most frequent image colors to initialize clusters. */
059                MostFrequent
060        };
061        
062        // --------------------------------------------------------------
063
064        /**
065         * Constructor, creates a new {@link KMeansClusteringQuantizer} with up to K colors, using default parameters.
066         *
067         * @param pixels an image as a aRGB-encoded int array
068         * @param K the desired number of colors (1 or more)
069         */
070        public KMeansClusteringQuantizer(int[] pixels, int K) {
071                this(pixels, K, InitialClusterMethod.Random, DefaultIterations);
072        }
073
074        /**
075         * Constructor, creates a new {@link KMeansClusteringQuantizer} with up to K colors, but never more than the number
076         * of colors found in the supplied pixel data.
077         *
078         * @param pixels an image as a aRGB-encoded int array
079         * @param K the desired number of colors (1 or more)
080         * @param initMethod the method to initialize color clusters ({@link InitialClusterMethod})
081         * @param maxIterations the maximum number of clustering iterations
082         */
083        public KMeansClusteringQuantizer(int[] pixels, int K, InitialClusterMethod initMethod, int maxIterations) {
084                clusters = initClusters(pixels, K, initMethod);
085                totalError = doCluster(pixels, maxIterations);
086                colormap = makeColorMap();
087        }
088        
089        // --------------------------------------------------------------
090
091        private ColorCluster[] initClusters(int[] pixels, int K, InitialClusterMethod method) {
092                int[] samples = null;
093                switch (method) {
094                case Random:
095                        samples = getRandomColors(pixels, K);
096                case MostFrequent:
097                        samples = getMostFrequentColors(pixels, K);
098                }
099                int k = Math.min(samples.length, K);
100                ColorCluster[] clstrs = new ColorCluster[k];    // create an array of k clusters
101                for (int i = 0; i < k; i++) {
102                        clstrs[i] = new ColorCluster(samples[i]);       // initialize cluster center
103                }
104                return clstrs; 
105        }
106
107        /**
108         * Returns (up to) k colors randomly selected from the given pixel data.
109         */
110        private int[] getRandomColors(int[] pixels, int k) {
111                ColorHistogram colorHist = new ColorHistogram(pixels);
112                int[] colors = colorHist.getColors();
113                if (colors.length <= k) {
114                        return colors;
115                }
116                else {
117                        shuffle(colors, this.random);   // randomly permute colors
118                        return Arrays.copyOf(colors, k);
119                }
120        }
121        
122        /**
123         * Perform random permutation on the specified array.
124         * https://stackoverflow.com/questions/1519736/random-shuffling-of-an-array
125         */
126        private void shuffle(int[] arr, Random random) {
127                for (int i = arr.length - 1; i > 0; i--) {
128                        int idx = random.nextInt(i + 1);
129                        int tmp = arr[idx];
130                        arr[idx] = arr[i];
131                        arr[i] = tmp;
132                }
133        }
134
135        /**
136         * Returns the (maximally) k most frequent color values in the given pixel data. If fewer than k colors are
137         * available, these are returned, i.e., the resulting array may have less than k elements.
138         */
139        private int[] getMostFrequentColors(int[] pixels, int k) {
140                ColorHistogram colorHist = new ColorHistogram(pixels, true);    // sorts color bins by frequency
141                int[] colors = colorHist.getColors();
142                if (colors.length <= k) {
143                        return colors;
144                }
145                else {
146                        return Arrays.copyOf(colors, k);
147                }
148        }
149        
150        private double doCluster(int[] pixels, int maxIterations) {
151                int changed = Integer.MAX_VALUE;
152                double distSum = Double.POSITIVE_INFINITY;
153                int j = 0;
154                while (changed > 0 && j < maxIterations) {
155                        distSum = assignSamples(pixels);
156                        changed = updateClusters();
157                        j++;
158                }
159                return distSum;
160        }
161
162        
163        private double assignSamples(int[] pixels) {
164                double distSum = 0;
165                for (int p : pixels) {
166                        double dist = addToClosestCluster(p);
167                        distSum = distSum + dist;
168                }
169                return distSum;
170        }
171        
172        private int updateClusters() {
173                int changed = 0;
174                for (ColorCluster c : clusters) {
175                        changed = changed + c.update();
176                }
177                return changed;
178        }
179        
180        private double addToClosestCluster(int p) {
181                double minDist = Double.POSITIVE_INFINITY;
182                ColorCluster closest = null;
183                for (ColorCluster c : clusters) {
184                        double d = c.getSquaredDistance(p);
185                        if (d < minDist) {
186                                minDist = d;
187                                closest = c;
188                        }
189                }
190                closest.addPixel(p);
191                return minDist;
192        }
193
194        private float[][] makeColorMap() {
195                List<float[]> colList = new LinkedList<>();
196                for (ColorCluster c : clusters) {
197                        if (!c.isEmpty()) {
198                                colList.add(c.getCenterColor());
199                        }
200                }               
201                return colList.toArray(new float[0][]);
202        }
203        
204        // ------- methods required by abstract super class -----------------------
205        
206        @Override
207        public float[][] getColorMap() {
208                return colormap;
209        }
210        
211        
212        // ------------------------------------------------------------------------
213        /**
214         * Lists the color clusters to System.out (for debugging only).
215         */
216        public void listClusters() {
217                for (ColorCluster c : clusters) {
218                        System.out.println(c.toString());
219                }
220        }
221
222        /**
223         * Returns the final clustering error, calculated as the sum of the squared distances of the color samples to the
224         * associated cluster centers. This calculation is performed during the final iteration.
225         *
226         * @return the final clustering error
227         */
228        public double getTotalError() {
229                return totalError;
230        }
231
232        // ------------------------------------------------------------------------
233        
234        /**
235         * This inner class represents a color cluster. */
236        private static class ColorCluster {
237                private int sR, sG, sB;                         // RGB sums of contained pixels
238                private int pcount;                                     // pixel counter, used during pixel assignment
239                private int population = 0;                     // number of pixels contained in this cluster
240                private double cR, cG, cB;                      // center of this cluster
241
242                private ColorCluster(int p) {
243                        int[] rgb = RgbUtils.intToRgb(p);
244                        cR = rgb[0];
245                        cG = rgb[1];
246                        cB = rgb[2];
247                        reset();
248                }
249
250                private float[] getCenterColor() {
251                        return new float[] {(float)cR, (float)cG, (float)cB};
252                }
253
254                private boolean isEmpty() {
255                        return (population == 0);
256                }
257
258                private void reset() {  // reset sums, used at the start of the pixel assignment.
259                        sR = 0;
260                        sG = 0;
261                        sB = 0;
262                        pcount = 0;
263                }
264                
265                private void addPixel(int p) {
266                        int[] rgb = RgbUtils.intToRgb(p);
267                        sR += rgb[0];
268                        sG += rgb[1];
269                        sB += rgb[2];
270                        pcount = pcount + 1;
271                }
272
273                /**
274                 * This method is invoked after all samples have been assigned to clusters. It updates the cluster's center and
275                 * returns by how much its population changed from the previous clustering (absolute count).
276                 *
277                 * @return the change in cluster population from the previous clustering
278                 */
279                private int update() {
280                        if (pcount > 0) {
281                                double scale = 1.0 / pcount;
282                                cR = sR * scale;
283                                cG = sG * scale;
284                                cB = sB * scale;
285                        }
286                        int changeCount = Math.abs(pcount - population);        // change in cluster population
287                        population = pcount;
288                        reset();
289                        return changeCount;     
290                }
291
292                /**
293                 * Calculates and returns the squared Euclidean distance between the color p and this cluster's center in RGB
294                 * space.
295                 *
296                 * @param p Color sample
297                 * @return squared distance to the cluster center
298                 */
299                private double getSquaredDistance(int p) {
300                        int[] rgb = RgbUtils.intToRgb(p);
301                        return sqr(rgb[0] - cR) + sqr(rgb[1] - cG) + sqr(rgb[2] - cB);
302                }
303                
304                @Override
305                public String toString() {
306                        return String.format(Locale.US, this.getClass().getSimpleName() +
307                                        ": ctr=(%.1f,%.1f,%.1f), pop=%d", cR, cG, cB, population);
308                }
309        }
310
311} 
312