1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math.optimization;
19
20 import java.util.Arrays;
21 import java.util.Comparator;
22
23 import org.apache.commons.math.ConvergenceException;
24 import org.apache.commons.math.DimensionMismatchException;
25 import org.apache.commons.math.linear.RealMatrix;
26 import org.apache.commons.math.random.CorrelatedRandomVectorGenerator;
27 import org.apache.commons.math.random.JDKRandomGenerator;
28 import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
29 import org.apache.commons.math.random.RandomGenerator;
30 import org.apache.commons.math.random.RandomVectorGenerator;
31 import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
32 import org.apache.commons.math.random.UniformRandomGenerator;
33 import org.apache.commons.math.stat.descriptive.moment.VectorialCovariance;
34 import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
35
36 /**
37 * This class implements simplex-based direct search optimization
38 * algorithms.
39 *
40 * <p>Direct search methods only use cost function values, they don't
41 * need derivatives and don't either try to compute approximation of
42 * the derivatives. According to a 1996 paper by Margaret H. Wright
43 * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
44 * Search Methods: Once Scorned, Now Respectable</a>), they are used
45 * when either the computation of the derivative is impossible (noisy
46 * functions, unpredictable discontinuities) or difficult (complexity,
47 * computation cost). In the first cases, rather than an optimum, a
48 * <em>not too bad</em> point is desired. In the latter cases, an
49 * optimum is desired but cannot be reasonably found. In all cases
50 * direct search methods can be useful.</p>
51 *
52 * <p>Simplex-based direct search methods are based on comparison of
53 * the cost function values at the vertices of a simplex (which is a
54 * set of n+1 points in dimension n) that is updated by the algorithms
55 * steps.</p>
56 *
57 * <p>Minimization can be attempted either in single-start or in
58 * multi-start mode. Multi-start is a traditional way to try to avoid
59 * being trapped in a local minimum and miss the global minimum of a
60 * function. It can also be used to verify the convergence of an
61 * algorithm. The various multi-start-enabled <code>minimize</code>
62 * methods return the best minimum found after all starts, and the
63 * {@link #getMinima getMinima} method can be used to retrieve all
64 * minima from all starts (including the one already provided by the
65 * {@link #minimize(CostFunction, int, ConvergenceChecker, double[],
66 * double[]) minimize} method).</p>
67 *
68 * <p>This class is the base class performing the boilerplate simplex
69 * initialization and handling. The simplex update by itself is
70 * performed by the derived classes according to the implemented
71 * algorithms.</p>
72 *
73 * @see CostFunction
74 * @see NelderMead
75 * @see MultiDirectional
76 * @version $Revision: 628000 $ $Date: 2008-02-15 03:31:48 -0700 (Fri, 15 Feb 2008) $
77 * @since 1.2
78 */
79 public abstract class DirectSearchOptimizer {
80
/** Builds the common part of a direct search optimizer.
 * <p>Nothing is initialized here: the simplex and the start mode
 * are set up by each <code>minimize</code> call.</p>
 */
protected DirectSearchOptimizer() {
    // intentionally empty
}
85
86 /** Minimizes a cost function.
87 * <p>The initial simplex is built from two vertices that are
88 * considered to represent two opposite vertices of a box parallel
89 * to the canonical axes of the space. The simplex is the subset of
90 * vertices encountered while going from vertexA to vertexB
91 * traveling along the box edges only. This can be seen as a scaled
92 * regular simplex using the projected separation between the given
93 * points as the scaling factor along each coordinate axis.</p>
94 * <p>The optimization is performed in single-start mode.</p>
95 * @param f cost function
96 * @param maxEvaluations maximal number of function calls for each
97 * start (note that the number will be checked <em>after</em>
98 * complete simplices have been evaluated, this means that in some
99 * cases this number will be exceeded by a few units, depending on
100 * the dimension of the problem)
101 * @param checker object to use to check for convergence
102 * @param vertexA first vertex
103 * @param vertexB last vertex
104 * @return the point/cost pairs giving the minimal cost
105 * @exception CostException if the cost function throws one during
106 * the search
107 * @exception ConvergenceException if none of the starts did
108 * converge (it is not thrown if at least one start did converge)
109 */
110 public PointCostPair minimize(CostFunction f, int maxEvaluations,
111 ConvergenceChecker checker,
112 double[] vertexA, double[] vertexB)
113 throws CostException, ConvergenceException {
114
115 // set up optimizer
116 buildSimplex(vertexA, vertexB);
117 setSingleStart();
118
119 // compute minimum
120 return minimize(f, maxEvaluations, checker);
121
122 }
123
124 /** Minimizes a cost function.
125 * <p>The initial simplex is built from two vertices that are
126 * considered to represent two opposite vertices of a box parallel
127 * to the canonical axes of the space. The simplex is the subset of
128 * vertices encountered while going from vertexA to vertexB
129 * traveling along the box edges only. This can be seen as a scaled
130 * regular simplex using the projected separation between the given
131 * points as the scaling factor along each coordinate axis.</p>
132 * <p>The optimization is performed in multi-start mode.</p>
133 * @param f cost function
134 * @param maxEvaluations maximal number of function calls for each
135 * start (note that the number will be checked <em>after</em>
136 * complete simplices have been evaluated, this means that in some
137 * cases this number will be exceeded by a few units, depending on
138 * the dimension of the problem)
139 * @param checker object to use to check for convergence
140 * @param vertexA first vertex
141 * @param vertexB last vertex
142 * @param starts number of starts to perform (including the
143 * first one), multi-start is disabled if value is less than or
144 * equal to 1
145 * @param seed seed for the random vector generator
146 * @return the point/cost pairs giving the minimal cost
147 * @exception CostException if the cost function throws one during
148 * the search
149 * @exception ConvergenceException if none of the starts did
150 * converge (it is not thrown if at least one start did converge)
151 */
152 public PointCostPair minimize(CostFunction f, int maxEvaluations,
153 ConvergenceChecker checker,
154 double[] vertexA, double[] vertexB,
155 int starts, long seed)
156 throws CostException, ConvergenceException {
157
158 // set up the simplex traveling around the box
159 buildSimplex(vertexA, vertexB);
160
161 // we consider the simplex could have been produced by a generator
162 // having its mean value at the center of the box, the standard
163 // deviation along each axe being the corresponding half size
164 double[] mean = new double[vertexA.length];
165 double[] standardDeviation = new double[vertexA.length];
166 for (int i = 0; i < vertexA.length; ++i) {
167 mean[i] = 0.5 * (vertexA[i] + vertexB[i]);
168 standardDeviation[i] = 0.5 * Math.abs(vertexA[i] - vertexB[i]);
169 }
170
171 RandomGenerator rg = new JDKRandomGenerator();
172 rg.setSeed(seed);
173 UniformRandomGenerator urg = new UniformRandomGenerator(rg);
174 RandomVectorGenerator rvg =
175 new UncorrelatedRandomVectorGenerator(mean, standardDeviation, urg);
176 setMultiStart(starts, rvg);
177
178 // compute minimum
179 return minimize(f, maxEvaluations, checker);
180
181 }
182
183 /** Minimizes a cost function.
184 * <p>The simplex is built from all its vertices.</p>
185 * <p>The optimization is performed in single-start mode.</p>
186 * @param f cost function
187 * @param maxEvaluations maximal number of function calls for each
188 * start (note that the number will be checked <em>after</em>
189 * complete simplices have been evaluated, this means that in some
190 * cases this number will be exceeded by a few units, depending on
191 * the dimension of the problem)
192 * @param checker object to use to check for convergence
193 * @param vertices array containing all vertices of the simplex
194 * @return the point/cost pairs giving the minimal cost
195 * @exception CostException if the cost function throws one during
196 * the search
197 * @exception ConvergenceException if none of the starts did
198 * converge (it is not thrown if at least one start did converge)
199 */
200 public PointCostPair minimize(CostFunction f, int maxEvaluations,
201 ConvergenceChecker checker,
202 double[][] vertices)
203 throws CostException, ConvergenceException {
204
205 // set up optimizer
206 buildSimplex(vertices);
207 setSingleStart();
208
209 // compute minimum
210 return minimize(f, maxEvaluations, checker);
211
212 }
213
214 /** Minimizes a cost function.
215 * <p>The simplex is built from all its vertices.</p>
216 * <p>The optimization is performed in multi-start mode.</p>
217 * @param f cost function
218 * @param maxEvaluations maximal number of function calls for each
219 * start (note that the number will be checked <em>after</em>
220 * complete simplices have been evaluated, this means that in some
221 * cases this number will be exceeded by a few units, depending on
222 * the dimension of the problem)
223 * @param checker object to use to check for convergence
224 * @param vertices array containing all vertices of the simplex
225 * @param starts number of starts to perform (including the
226 * first one), multi-start is disabled if value is less than or
227 * equal to 1
228 * @param seed seed for the random vector generator
229 * @return the point/cost pairs giving the minimal cost
230 * @exception NotPositiveDefiniteMatrixException if the vertices
231 * array is degenerated
232 * @exception CostException if the cost function throws one during
233 * the search
234 * @exception ConvergenceException if none of the starts did
235 * converge (it is not thrown if at least one start did converge)
236 */
237 public PointCostPair minimize(CostFunction f, int maxEvaluations,
238 ConvergenceChecker checker,
239 double[][] vertices,
240 int starts, long seed)
241 throws NotPositiveDefiniteMatrixException,
242 CostException, ConvergenceException {
243
244 try {
245 // store the points into the simplex
246 buildSimplex(vertices);
247
248 // compute the statistical properties of the simplex points
249 VectorialMean meanStat = new VectorialMean(vertices[0].length);
250 VectorialCovariance covStat = new VectorialCovariance(vertices[0].length, true);
251 for (int i = 0; i < vertices.length; ++i) {
252 meanStat.increment(vertices[i]);
253 covStat.increment(vertices[i]);
254 }
255 double[] mean = meanStat.getResult();
256 RealMatrix covariance = covStat.getResult();
257
258
259 RandomGenerator rg = new JDKRandomGenerator();
260 rg.setSeed(seed);
261 RandomVectorGenerator rvg =
262 new CorrelatedRandomVectorGenerator(mean,
263 covariance, 1.0e-12 * covariance.getNorm(),
264 new UniformRandomGenerator(rg));
265 setMultiStart(starts, rvg);
266
267 // compute minimum
268 return minimize(f, maxEvaluations, checker);
269
270 } catch (DimensionMismatchException dme) {
271 // this should not happen
272 throw new RuntimeException("internal error");
273 }
274
275 }
276
277 /** Minimizes a cost function.
278 * <p>The simplex is built randomly.</p>
279 * <p>The optimization is performed in single-start mode.</p>
280 * @param f cost function
281 * @param maxEvaluations maximal number of function calls for each
282 * start (note that the number will be checked <em>after</em>
283 * complete simplices have been evaluated, this means that in some
284 * cases this number will be exceeded by a few units, depending on
285 * the dimension of the problem)
286 * @param checker object to use to check for convergence
287 * @param generator random vector generator
288 * @return the point/cost pairs giving the minimal cost
289 * @exception CostException if the cost function throws one during
290 * the search
291 * @exception ConvergenceException if none of the starts did
292 * converge (it is not thrown if at least one start did converge)
293 */
294 public PointCostPair minimize(CostFunction f, int maxEvaluations,
295 ConvergenceChecker checker,
296 RandomVectorGenerator generator)
297 throws CostException, ConvergenceException {
298
299 // set up optimizer
300 buildSimplex(generator);
301 setSingleStart();
302
303 // compute minimum
304 return minimize(f, maxEvaluations, checker);
305
306 }
307
308 /** Minimizes a cost function.
309 * <p>The simplex is built randomly.</p>
310 * <p>The optimization is performed in multi-start mode.</p>
311 * @param f cost function
312 * @param maxEvaluations maximal number of function calls for each
313 * start (note that the number will be checked <em>after</em>
314 * complete simplices have been evaluated, this means that in some
315 * cases this number will be exceeded by a few units, depending on
316 * the dimension of the problem)
317 * @param checker object to use to check for convergence
318 * @param generator random vector generator
319 * @param starts number of starts to perform (including the
320 * first one), multi-start is disabled if value is less than or
321 * equal to 1
322 * @return the point/cost pairs giving the minimal cost
323 * @exception CostException if the cost function throws one during
324 * the search
325 * @exception ConvergenceException if none of the starts did
326 * converge (it is not thrown if at least one start did converge)
327 */
328 public PointCostPair minimize(CostFunction f, int maxEvaluations,
329 ConvergenceChecker checker,
330 RandomVectorGenerator generator,
331 int starts)
332 throws CostException, ConvergenceException {
333
334 // set up optimizer
335 buildSimplex(generator);
336 setMultiStart(starts, generator);
337
338 // compute minimum
339 return minimize(f, maxEvaluations, checker);
340
341 }
342
343 /** Build a simplex from two extreme vertices.
344 * <p>The two vertices are considered to represent two opposite
345 * vertices of a box parallel to the canonical axes of the
346 * space. The simplex is the subset of vertices encountered while
347 * going from vertexA to vertexB traveling along the box edges
348 * only. This can be seen as a scaled regular simplex using the
349 * projected separation between the given points as the scaling
350 * factor along each coordinate axis.</p>
351 * @param vertexA first vertex
352 * @param vertexB last vertex
353 */
354 private void buildSimplex(double[] vertexA, double[] vertexB) {
355
356 int n = vertexA.length;
357 simplex = new PointCostPair[n + 1];
358
359 // set up the simplex traveling around the box
360 for (int i = 0; i <= n; ++i) {
361 double[] vertex = new double[n];
362 if (i > 0) {
363 System.arraycopy(vertexB, 0, vertex, 0, i);
364 }
365 if (i < n) {
366 System.arraycopy(vertexA, i, vertex, i, n - i);
367 }
368 simplex[i] = new PointCostPair(vertex, Double.NaN);
369 }
370
371 }
372
373 /** Build a simplex from all its points.
374 * @param vertices array containing all vertices of the simplex
375 */
376 private void buildSimplex(double[][] vertices) {
377 int n = vertices.length - 1;
378 simplex = new PointCostPair[n + 1];
379 for (int i = 0; i <= n; ++i) {
380 simplex[i] = new PointCostPair(vertices[i], Double.NaN);
381 }
382 }
383
384 /** Build a simplex randomly.
385 * @param generator random vector generator
386 */
387 private void buildSimplex(RandomVectorGenerator generator) {
388
389 // use first vector size to compute the number of points
390 double[] vertex = generator.nextVector();
391 int n = vertex.length;
392 simplex = new PointCostPair[n + 1];
393 simplex[0] = new PointCostPair(vertex, Double.NaN);
394
395 // fill up the vertex
396 for (int i = 1; i <= n; ++i) {
397 simplex[i] = new PointCostPair(generator.nextVector(), Double.NaN);
398 }
399
400 }
401
402 /** Set up single-start mode.
403 */
404 private void setSingleStart() {
405 starts = 1;
406 generator = null;
407 minima = null;
408 }
409
410 /** Set up multi-start mode.
411 * @param starts number of starts to perform (including the
412 * first one), multi-start is disabled if value is less than or
413 * equal to 1
414 * @param generator random vector generator to use for restarts
415 */
416 private void setMultiStart(int starts, RandomVectorGenerator generator) {
417 if (starts < 2) {
418 this.starts = 1;
419 this.generator = null;
420 minima = null;
421 } else {
422 this.starts = starts;
423 this.generator = generator;
424 minima = null;
425 }
426 }
427
428 /** Get all the minima found during the last call to {@link
429 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
430 * minimize}.
431 * <p>The optimizer stores all the minima found during a set of
432 * restarts when multi-start mode is enabled. The {@link
433 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
434 * minimize} method returns the best point only. This method
435 * returns all the points found at the end of each starts, including
436 * the best one already returned by the {@link #minimize(CostFunction,
437 * int, ConvergenceChecker, double[], double[]) minimize} method.
438 * The array as one element for each start as specified in the constructor
439 * (it has one element only if optimizer has been set up for single-start).</p>
440 * <p>The array containing the minima is ordered with the results
441 * from the runs that did converge first, sorted from lowest to
442 * highest minimum cost, and null elements corresponding to the runs
443 * that did not converge (all elements will be null if the {@link
444 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
445 * minimize} method did throw a {@link ConvergenceException
446 * ConvergenceException}).</p>
447 * @return array containing the minima, or null if {@link
448 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
449 * minimize} has not been called
450 */
451 public PointCostPair[] getMinima() {
452 return (PointCostPair[]) minima.clone();
453 }
454
455 /** Minimizes a cost function.
456 * @param f cost function
457 * @param maxEvaluations maximal number of function calls for each
458 * start (note that the number will be checked <em>after</em>
459 * complete simplices have been evaluated, this means that in some
460 * cases this number will be exceeded by a few units, depending on
461 * the dimension of the problem)
462 * @param checker object to use to check for convergence
463 * @return the point/cost pairs giving the minimal cost
464 * @exception CostException if the cost function throws one during
465 * the search
466 * @exception ConvergenceException if none of the starts did
467 * converge (it is not thrown if at least one start did converge)
468 */
469 private PointCostPair minimize(CostFunction f, int maxEvaluations,
470 ConvergenceChecker checker)
471 throws CostException, ConvergenceException {
472
473 this.f = f;
474 minima = new PointCostPair[starts];
475
476 // multi-start loop
477 for (int i = 0; i < starts; ++i) {
478
479 evaluations = 0;
480 evaluateSimplex();
481
482 for (boolean loop = true; loop;) {
483 if (checker.converged(simplex)) {
484 // we have found a minimum
485 minima[i] = simplex[0];
486 loop = false;
487 } else if (evaluations >= maxEvaluations) {
488 // this start did not converge, try a new one
489 minima[i] = null;
490 loop = false;
491 } else {
492 iterateSimplex();
493 }
494 }
495
496 if (i < (starts - 1)) {
497 // restart
498 buildSimplex(generator);
499 }
500
501 }
502
503 // sort the minima from lowest cost to highest cost, followed by
504 // null elements
505 Arrays.sort(minima, pointCostPairComparator);
506
507 // return the found point given the lowest cost
508 if (minima[0] == null) {
509 throw new ConvergenceException("none of the {0} start points" +
510 " lead to convergence",
511 new Object[] {
512 Integer.toString(starts)
513 });
514 }
515 return minima[0];
516
517 }
518
519 /** Compute the next simplex of the algorithm.
520 * @exception CostException if the function cannot be evaluated at
521 * some point
522 */
523 protected abstract void iterateSimplex()
524 throws CostException;
525
526 /** Evaluate the cost on one point.
527 * <p>A side effect of this method is to count the number of
528 * function evaluations</p>
529 * @param x point on which the cost function should be evaluated
530 * @return cost at the given point
531 * @exception CostException if no cost can be computed for the parameters
532 */
533 protected double evaluateCost(double[] x)
534 throws CostException {
535 evaluations++;
536 return f.cost(x);
537 }
538
539 /** Evaluate all the non-evaluated points of the simplex.
540 * @exception CostException if no cost can be computed for the parameters
541 */
542 protected void evaluateSimplex()
543 throws CostException {
544
545 // evaluate the cost at all non-evaluated simplex points
546 for (int i = 0; i < simplex.length; ++i) {
547 PointCostPair pair = simplex[i];
548 if (Double.isNaN(pair.getCost())) {
549 simplex[i] = new PointCostPair(pair.getPoint(), evaluateCost(pair.getPoint()));
550 }
551 }
552
553 // sort the simplex from lowest cost to highest cost
554 Arrays.sort(simplex, pointCostPairComparator);
555
556 }
557
558 /** Replace the worst point of the simplex by a new point.
559 * @param pointCostPair point to insert
560 */
561 protected void replaceWorstPoint(PointCostPair pointCostPair) {
562 int n = simplex.length - 1;
563 for (int i = 0; i < n; ++i) {
564 if (simplex[i].getCost() > pointCostPair.getCost()) {
565 PointCostPair tmp = simplex[i];
566 simplex[i] = pointCostPair;
567 pointCostPair = tmp;
568 }
569 }
570 simplex[n] = pointCostPair;
571 }
572
573 /** Comparator for {@link PointCostPair PointCostPair} objects. */
574 private static Comparator pointCostPairComparator = new Comparator() {
575 public int compare(Object o1, Object o2) {
576 if (o1 == null) {
577 return (o2 == null) ? 0 : +1;
578 } else if (o2 == null) {
579 return -1;
580 }
581 double cost1 = ((PointCostPair) o1).getCost();
582 double cost2 = ((PointCostPair) o2).getCost();
583 return (cost1 < cost2) ? -1 : ((o1 == o2) ? 0 : +1);
584 }
585 };
586
587 /** Simplex. */
588 protected PointCostPair[] simplex;
589
590 /** Cost function. */
591 private CostFunction f;
592
593 /** Number of evaluations already performed. */
594 private int evaluations;
595
596 /** Number of starts to go. */
597 private int starts;
598
599 /** Random generator for multi-start. */
600 private RandomVectorGenerator generator;
601
602 /** Found minima. */
603 private PointCostPair[] minima;
604
605 }