View Javadoc

1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */
20  package org.apache.mina.filter.ssl;
21  
22  import java.net.InetSocketAddress;
23  import java.nio.ByteBuffer;
24  import java.util.Queue;
25  import java.util.concurrent.ConcurrentLinkedQueue;
26  
27  import javax.net.ssl.SSLContext;
28  import javax.net.ssl.SSLEngine;
29  import javax.net.ssl.SSLEngineResult;
30  import javax.net.ssl.SSLException;
31  import javax.net.ssl.SSLHandshakeException;
32  
33  import org.apache.mina.core.buffer.IoBuffer;
34  import org.apache.mina.core.filterchain.IoFilterEvent;
35  import org.apache.mina.core.filterchain.IoFilter.NextFilter;
36  import org.apache.mina.core.future.DefaultWriteFuture;
37  import org.apache.mina.core.future.WriteFuture;
38  import org.apache.mina.core.session.IoEventType;
39  import org.apache.mina.core.session.IoSession;
40  import org.apache.mina.core.write.DefaultWriteRequest;
41  import org.apache.mina.core.write.WriteRequest;
42  import org.apache.mina.util.CircularQueue;
43  import org.slf4j.Logger;
44  import org.slf4j.LoggerFactory;
45  
46  /**
47   * A helper class using the SSLEngine API to decrypt/encrypt data.
48   * <p/>
49   * Each connection has a SSLEngine that is used through the lifetime of the connection.
50   * We allocate buffers for use as the outbound and inbound network buffers.
51   * These buffers handle all of the intermediary data for the SSL connection. To make things easy,
52   * we'll require outNetBuffer be completely flushed before trying to wrap any more data.
53   *
54   * @author The Apache MINA Project (dev@mina.apache.org)
55   * @version $Rev: 713364 $, $Date: 2008-11-12 14:35:51 +0100 (Wed, 12 Nov 2008) $
56   */
57  class SslHandler {
58  
59      private final Logger logger = LoggerFactory.getLogger(getClass());
60      private final SslFilter parent;
61      private final SSLContext sslContext;
62      private final IoSession session;
63      private final Queue<IoFilterEvent> preHandshakeEventQueue = new CircularQueue<IoFilterEvent>();
64      private final Queue<IoFilterEvent> filterWriteEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
65      private final Queue<IoFilterEvent> messageReceivedEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
66      private SSLEngine sslEngine;
67  
68      /**
69       * Encrypted data from the net
70       */
71      private IoBuffer inNetBuffer;
72  
73      /**
74       * Encrypted data to be written to the net
75       */
76      private IoBuffer outNetBuffer;
77  
78      /**
79       * Applicaton cleartext data to be read by application
80       */
81      private IoBuffer appBuffer;
82  
83      /**
84       * Empty buffer used during initial handshake and close operations
85       */
86      private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
87  
88      private SSLEngineResult.HandshakeStatus handshakeStatus;
89      private boolean initialHandshakeComplete;
90      private boolean handshakeComplete;
91      private boolean writingEncryptedData;
92  
93      /**
94       * Constuctor.
95       *
96       * @param sslc
97       * @throws SSLException
98       */
99      public SslHandler(SslFilter parent, SSLContext sslContext, IoSession session)
100             throws SSLException {
101         this.parent = parent;
102         this.session = session;
103         this.sslContext = sslContext;
104         init();
105     }
106 
107     /**
108      * Initialize the SSL handshake.
109      *
110      * @throws SSLException
111      */
112     public void init() throws SSLException {
113         if (sslEngine != null) {
114             // We already have a SSL engine created, no need to create a new one
115             return;
116         }
117 
118         InetSocketAddress peer = (InetSocketAddress) session
119                 .getAttribute(SslFilter.PEER_ADDRESS);
120         
121         // Create the SSL engine here
122         if (peer == null) {
123             sslEngine = sslContext.createSSLEngine();
124         } else {
125             sslEngine = sslContext.createSSLEngine(peer.getHostName(), peer.getPort());
126         }
127         
128         // Initialize the engine in client mode if necessary
129         sslEngine.setUseClientMode(parent.isUseClientMode());
130 
131         // Initialize the different SslEngine modes
132         if (parent.isWantClientAuth()) {
133             sslEngine.setWantClientAuth(true);
134         }
135 
136         if (parent.isNeedClientAuth()) {
137             sslEngine.setNeedClientAuth(true);
138         }
139 
140         if (parent.getEnabledCipherSuites() != null) {
141             sslEngine.setEnabledCipherSuites(parent.getEnabledCipherSuites());
142         }
143 
144         if (parent.getEnabledProtocols() != null) {
145             sslEngine.setEnabledProtocols(parent.getEnabledProtocols());
146         }
147 
148         // TODO : we may not need to call this method...
149         sslEngine.beginHandshake();
150         
151         
152         handshakeStatus = sslEngine.getHandshakeStatus();
153 
154         handshakeComplete = false;
155         initialHandshakeComplete = false;
156         writingEncryptedData = false;
157     }
158 
159     /**
160      * Release allocated buffers.
161      */
162     public void destroy() {
163         if (sslEngine == null) {
164             return;
165         }
166 
167         // Close inbound and flush all remaining data if available.
168         try {
169             sslEngine.closeInbound();
170         } catch (SSLException e) {
171             logger.debug(
172                     "Unexpected exception from SSLEngine.closeInbound().", e);
173         }
174 
175 
176         if (outNetBuffer != null) {
177             outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
178         } else {
179             createOutNetBuffer(0);
180         }
181         try {
182             do {
183                 outNetBuffer.clear();
184             } while (sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf()).bytesProduced() > 0);
185         } catch (SSLException e) {
186             // Ignore.
187         } finally {
188             destroyOutNetBuffer();
189         }
190 
191         sslEngine.closeOutbound();
192         sslEngine = null;
193 
194         preHandshakeEventQueue.clear();
195     }
196 
197     private void destroyOutNetBuffer() {
198         outNetBuffer.free();
199         outNetBuffer = null;
200     }
201 
202     public SslFilter getParent() {
203         return parent;
204     }
205 
206     public IoSession getSession() {
207         return session;
208     }
209 
210     /**
211      * Check we are writing encrypted data.
212      */
213     public boolean isWritingEncryptedData() {
214         return writingEncryptedData;
215     }
216 
217     /**
218      * Check if handshake is completed.
219      */
220     public boolean isHandshakeComplete() {
221         return handshakeComplete;
222     }
223 
224     public boolean isInboundDone() {
225         return sslEngine == null || sslEngine.isInboundDone();
226     }
227 
228     public boolean isOutboundDone() {
229         return sslEngine == null || sslEngine.isOutboundDone();
230     }
231 
232     /**
233      * Check if there is any need to complete handshake.
234      */
235     public boolean needToCompleteHandshake() {
236         return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
237     }
238 
239     public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
240                                                  WriteRequest writeRequest) {
241         preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
242                 IoEventType.WRITE, session, writeRequest));
243     }
244 
245     public void flushPreHandshakeEvents() throws SSLException {
246         IoFilterEvent scheduledWrite;
247 
248         while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
249             parent.filterWrite(scheduledWrite.getNextFilter(), session,
250                     (WriteRequest) scheduledWrite.getParameter());
251         }
252     }
253 
254     public void scheduleFilterWrite(NextFilter nextFilter, WriteRequest writeRequest) {
255         filterWriteEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
256     }
257 
258     public void scheduleMessageReceived(NextFilter nextFilter, Object message) {
259         messageReceivedEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.MESSAGE_RECEIVED, session, message));
260     }
261 
262     public void flushScheduledEvents() {
263         // Fire events only when no lock is hold for this handler.
264         if (Thread.holdsLock(this)) {
265             return;
266         }
267 
268         IoFilterEvent e;
269 
270         // We need synchronization here inevitably because filterWrite can be
271         // called simultaneously and cause 'bad record MAC' integrity error.
272         synchronized (this) {
273             while ((e = filterWriteEventQueue.poll()) != null) {
274                 e.getNextFilter().filterWrite(session, (WriteRequest) e.getParameter());
275             }
276         }
277 
278         while ((e = messageReceivedEventQueue.poll()) != null) {
279             e.getNextFilter().messageReceived(session, e.getParameter());
280         }
281     }
282 
283     /**
284      * Call when data read from net. Will perform inial hanshake or decrypt provided
285      * Buffer.
286      * Decrytpted data reurned by getAppBuffer(), if any.
287      *
288      * @param buf        buffer to decrypt
289      * @param nextFilter Next filter in chain
290      * @throws SSLException on errors
291      */
292     public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
293         // append buf to inNetBuffer
294         if (inNetBuffer == null) {
295             inNetBuffer = IoBuffer.allocate(buf.remaining()).setAutoExpand(true);
296         }
297 
298         inNetBuffer.put(buf);
299         if (!handshakeComplete) {
300             handshake(nextFilter);
301         } else {
302             decrypt(nextFilter);
303         }
304 
305         if (isInboundDone()) {
306             // Rewind the MINA buffer if not all data is processed and inbound is finished.
307             int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
308             buf.position(buf.position() - inNetBufferPosition);
309             inNetBuffer = null;
310         }
311     }
312 
313     /**
314      * Get decrypted application data.
315      *
316      * @return buffer with data
317      */
318     public IoBuffer fetchAppBuffer() {
319         IoBuffer appBuffer = this.appBuffer.flip();
320         this.appBuffer = null;
321         return appBuffer;
322     }
323 
324     /**
325      * Get encrypted data to be sent.
326      *
327      * @return buffer with data
328      */
329     public IoBuffer fetchOutNetBuffer() {
330         IoBuffer answer = outNetBuffer;
331         if (answer == null) {
332             return emptyBuffer;
333         }
334 
335         outNetBuffer = null;
336         return answer.shrink();
337     }
338 
339     /**
340      * Encrypt provided buffer. Encrypted data returned by getOutNetBuffer().
341      *
342      * @param src data to encrypt
343      * @throws SSLException on errors
344      */
345     public void encrypt(ByteBuffer src) throws SSLException {
346         if (!handshakeComplete) {
347             throw new IllegalStateException();
348         }
349 
350         if (!src.hasRemaining()) {
351             if (outNetBuffer == null) {
352                 outNetBuffer = emptyBuffer;
353             }
354             return;
355         }
356 
357         createOutNetBuffer(src.remaining());
358 
359         // Loop until there is no more data in src
360         while (src.hasRemaining()) {
361 
362             SSLEngineResult result = sslEngine.wrap(src, outNetBuffer.buf());
363             if (result.getStatus() == SSLEngineResult.Status.OK) {
364                 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
365                     doTasks();
366                 }
367             } else if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
368                 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
369                 outNetBuffer.limit(outNetBuffer.capacity());
370             } else {
371                 throw new SSLException("SSLEngine error during encrypt: "
372                         + result.getStatus() + " src: " + src
373                         + "outNetBuffer: " + outNetBuffer);
374             }
375         }
376 
377         outNetBuffer.flip();
378     }
379 
380     /**
381      * Start SSL shutdown process.
382      *
383      * @return <tt>true</tt> if shutdown process is started.
384      *         <tt>false</tt> if shutdown process is already finished.
385      * @throws SSLException on errors
386      */
387     public boolean closeOutbound() throws SSLException {
388         if (sslEngine == null || sslEngine.isOutboundDone()) {
389             return false;
390         }
391 
392         sslEngine.closeOutbound();
393 
394         createOutNetBuffer(0);
395         SSLEngineResult result;
396         for (;;) {
397             result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
398             if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
399                 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
400                 outNetBuffer.limit(outNetBuffer.capacity());
401             } else {
402                 break;
403             }
404         }
405 
406         if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
407             throw new SSLException("Improper close state: " + result);
408         }
409         outNetBuffer.flip();
410         return true;
411     }
412 
413     /**
414      * Decrypt in net buffer. Result is stored in app buffer.
415      *
416      * @throws SSLException
417      */
418     private void decrypt(NextFilter nextFilter) throws SSLException {
419 
420         if (!handshakeComplete) {
421             throw new IllegalStateException();
422         }
423 
424         unwrap(nextFilter);
425     }
426 
427     /**
428      * @param res
429      * @throws SSLException
430      */
431     private void checkStatus(SSLEngineResult res)
432             throws SSLException {
433 
434         SSLEngineResult.Status status = res.getStatus();
435 
436         /*
437         * The status may be:
438         * OK - Normal operation
439         * OVERFLOW - Should never happen since the application buffer is
440         *      sized to hold the maximum packet size.
441         * UNDERFLOW - Need to read more data from the socket. It's normal.
442         * CLOSED - The other peer closed the socket. Also normal.
443         */
444         if (status != SSLEngineResult.Status.OK
445                 && status != SSLEngineResult.Status.CLOSED
446                 && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
447             throw new SSLException("SSLEngine error during decrypt: " + status
448                     + " inNetBuffer: " + inNetBuffer + "appBuffer: "
449                     + appBuffer);
450         }
451     }
452 
453     /**
454      * Perform any handshaking processing.
455      */
456     public void handshake(NextFilter nextFilter) throws SSLException {
457         for (;;) {
458             switch (handshakeStatus) {
459                 case FINISHED :
460                     session.setAttribute(
461                             SslFilter.SSL_SESSION, sslEngine.getSession());
462                     handshakeComplete = true;
463                     
464                     if (!initialHandshakeComplete
465                             && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
466                         // SESSION_SECURED is fired only when it's the first handshake.
467                         // (i.e. renegotiation shouldn't trigger SESSION_SECURED.)
468                         initialHandshakeComplete = true;
469                         scheduleMessageReceived(nextFilter,
470                                 SslFilter.SESSION_SECURED);
471                     }
472                     
473                     return;
474                     
475                 case NEED_TASK :
476                     handshakeStatus = doTasks();
477                     break;
478                     
479                 case NEED_UNWRAP :
480                     // we need more data read
481                     SSLEngineResult.Status status = unwrapHandshake(nextFilter);
482                     
483                     if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW &&
484                             handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED ||
485                             isInboundDone()) {
486                         // We need more data or the session is closed
487                         return;
488                     }
489                     
490                     break;
491 
492                 case NEED_WRAP :
493                     // First make sure that the out buffer is completely empty. Since we
494                     // cannot call wrap with data left on the buffer
495                     if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
496                         return;
497                     }
498 
499                     SSLEngineResult result;
500                     createOutNetBuffer(0);
501                     
502                     for (;;) {
503                         result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
504                         if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
505                             outNetBuffer.capacity(outNetBuffer.capacity() << 1);
506                             outNetBuffer.limit(outNetBuffer.capacity());
507                         } else {
508                             break;
509                         }
510                     }
511 
512                     outNetBuffer.flip();
513                     handshakeStatus = result.getHandshakeStatus();
514                     writeNetBuffer(nextFilter);
515                     break;
516             
517                 default :
518                     throw new IllegalStateException("Invalid Handshaking State"
519                             + handshakeStatus);
520             }
521         }
522     }
523 
524     private void createOutNetBuffer(int expectedRemaining) {
525         // SSLEngine requires us to allocate unnecessarily big buffer
526         // even for small data.  *Shrug*
527         int capacity = Math.max(
528                 expectedRemaining,
529                 sslEngine.getSession().getPacketBufferSize());
530 
531         if (outNetBuffer != null) {
532             outNetBuffer.capacity(capacity);
533         } else {
534             outNetBuffer = IoBuffer.allocate(capacity).minimumCapacity(0);
535         }
536     }
537 
538     public WriteFuture writeNetBuffer(NextFilter nextFilter)
539             throws SSLException {
540         // Check if any net data needed to be writen
541         if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
542             // no; bail out
543             return null;
544         }
545 
546         // set flag that we are writing encrypted data
547         // (used in SSLFilter.filterWrite())
548         writingEncryptedData = true;
549 
550         // write net data
551         WriteFuture writeFuture = null;
552 
553         try {
554             IoBuffer writeBuffer = fetchOutNetBuffer();
555             writeFuture = new DefaultWriteFuture(session);
556             parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
557                     writeBuffer, writeFuture));
558 
559             // loop while more writes required to complete handshake
560             while (needToCompleteHandshake()) {
561                 try {
562                     handshake(nextFilter);
563                 } catch (SSLException ssle) {
564                     SSLException newSsle = new SSLHandshakeException(
565                             "SSL handshake failed.");
566                     newSsle.initCause(ssle);
567                     throw newSsle;
568                 }
569 
570                 IoBuffer outNetBuffer = fetchOutNetBuffer();
571                 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
572                     writeFuture = new DefaultWriteFuture(session);
573                     parent.filterWrite(nextFilter, session,
574                             new DefaultWriteRequest(outNetBuffer, writeFuture));
575                 }
576             }
577         } finally {
578             writingEncryptedData = false;
579         }
580 
581         return writeFuture;
582     }
583 
584     private void unwrap(NextFilter nextFilter) throws SSLException {
585         // Prepare the net data for reading.
586         if (inNetBuffer != null) {
587             inNetBuffer.flip();
588         }
589 
590         if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
591             return;
592         }
593 
594         SSLEngineResult res = unwrap0();
595 
596         // prepare to be written again
597         if (inNetBuffer.hasRemaining()) {
598             inNetBuffer.compact();
599         } else {
600             inNetBuffer = null;
601         }
602 
603         checkStatus(res);
604 
605         renegotiateIfNeeded(nextFilter, res);
606     }
607 
608     private SSLEngineResult.Status unwrapHandshake(NextFilter nextFilter) throws SSLException {
609         // Prepare the net data for reading.
610         if (inNetBuffer != null) {
611             inNetBuffer.flip();
612         }
613 
614         if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
615             // Need more data.
616             return SSLEngineResult.Status.BUFFER_UNDERFLOW;
617         }
618 
619         SSLEngineResult res = unwrap0();
620         handshakeStatus = res.getHandshakeStatus();
621 
622         checkStatus(res);
623 
624         // If handshake finished, no data was produced, and the status is still ok,
625         // try to unwrap more
626         if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
627                 && res.getStatus() == SSLEngineResult.Status.OK
628                 && inNetBuffer.hasRemaining()) {
629             res = unwrap0();
630 
631             // prepare to be written again
632             if (inNetBuffer.hasRemaining()) {
633                 inNetBuffer.compact();
634             } else {
635                 inNetBuffer = null;
636             }
637 
638             renegotiateIfNeeded(nextFilter, res);
639         } else {
640             // prepare to be written again
641             if (inNetBuffer.hasRemaining()) {
642                 inNetBuffer.compact();
643             } else {
644                 inNetBuffer = null;
645             }
646         }
647 
648         return res.getStatus();
649     }
650 
651     private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
652             throws SSLException {
653         if (res.getStatus() != SSLEngineResult.Status.CLOSED
654                 && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
655                 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
656             // Renegotiation required.
657             handshakeComplete = false;
658             handshakeStatus = res.getHandshakeStatus();
659             handshake(nextFilter);
660         }
661     }
662 
663     private SSLEngineResult unwrap0() throws SSLException {
664         if (appBuffer == null) {
665             appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
666         } else {
667             appBuffer.expand(inNetBuffer.remaining());
668         }
669 
670         SSLEngineResult res;
671         do {
672             res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
673             if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
674                 appBuffer.capacity(appBuffer.capacity() << 1);
675                 appBuffer.limit(appBuffer.capacity());
676                 continue;
677             }
678         } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
679                  (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
680                   res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
681 
682         return res;
683     }
684 
685     /**
686      * Do all the outstanding handshake tasks in the current Thread.
687      */
688     private SSLEngineResult.HandshakeStatus doTasks() {
689         /*
690          * We could run this in a separate thread, but I don't see the need
691          * for this when used from SSLFilter. Use thread filters in MINA instead?
692          */
693         Runnable runnable;
694         while ((runnable = sslEngine.getDelegatedTask()) != null) {
695             // TODO : we may have to use a thread pool here to improve the performances
696             runnable.run();
697         }
698         return sslEngine.getHandshakeStatus();
699     }
700 
701     /**
702      * Creates a new MINA buffer that is a deep copy of the remaining bytes
703      * in the given buffer (between index buf.position() and buf.limit())
704      *
705      * @param src the buffer to copy
706      * @return the new buffer, ready to read from
707      */
708     public static IoBuffer copy(ByteBuffer src) {
709         IoBuffer copy = IoBuffer.allocate(src.remaining());
710         copy.put(src);
711         copy.flip();
712         return copy;
713     }
714 }