1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.ssl;
21
22 import java.net.InetSocketAddress;
23 import java.nio.ByteBuffer;
24 import java.util.Queue;
25 import java.util.concurrent.ConcurrentLinkedQueue;
26
27 import javax.net.ssl.SSLContext;
28 import javax.net.ssl.SSLEngine;
29 import javax.net.ssl.SSLEngineResult;
30 import javax.net.ssl.SSLException;
31 import javax.net.ssl.SSLHandshakeException;
32
33 import org.apache.mina.core.buffer.IoBuffer;
34 import org.apache.mina.core.filterchain.IoFilterEvent;
35 import org.apache.mina.core.filterchain.IoFilter.NextFilter;
36 import org.apache.mina.core.future.DefaultWriteFuture;
37 import org.apache.mina.core.future.WriteFuture;
38 import org.apache.mina.core.session.IoEventType;
39 import org.apache.mina.core.session.IoSession;
40 import org.apache.mina.core.write.DefaultWriteRequest;
41 import org.apache.mina.core.write.WriteRequest;
42 import org.apache.mina.util.CircularQueue;
43 import org.slf4j.Logger;
44 import org.slf4j.LoggerFactory;
45
46
47
48
49
50
51
52
53
54
55
56
57 class SslHandler {
58
59 private final Logger logger = LoggerFactory.getLogger(getClass());
60 private final SslFilter parent;
61 private final SSLContext sslContext;
62 private final IoSession session;
63 private final Queue<IoFilterEvent> preHandshakeEventQueue = new CircularQueue<IoFilterEvent>();
64 private final Queue<IoFilterEvent> filterWriteEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
65 private final Queue<IoFilterEvent> messageReceivedEventQueue = new ConcurrentLinkedQueue<IoFilterEvent>();
66 private SSLEngine sslEngine;
67
68
69
70
71 private IoBuffer inNetBuffer;
72
73
74
75
76 private IoBuffer outNetBuffer;
77
78
79
80
81 private IoBuffer appBuffer;
82
83
84
85
86 private final IoBuffer emptyBuffer = IoBuffer.allocate(0);
87
88 private SSLEngineResult.HandshakeStatus handshakeStatus;
89 private boolean initialHandshakeComplete;
90 private boolean handshakeComplete;
91 private boolean writingEncryptedData;
92
93
94
95
96
97
98
99 public SslHandler(SslFilter parent, SSLContext sslContext, IoSession session)
100 throws SSLException {
101 this.parent = parent;
102 this.session = session;
103 this.sslContext = sslContext;
104 init();
105 }
106
107
108
109
110
111
112 public void init() throws SSLException {
113 if (sslEngine != null) {
114
115 return;
116 }
117
118 InetSocketAddress peer = (InetSocketAddress) session
119 .getAttribute(SslFilter.PEER_ADDRESS);
120
121
122 if (peer == null) {
123 sslEngine = sslContext.createSSLEngine();
124 } else {
125 sslEngine = sslContext.createSSLEngine(peer.getHostName(), peer.getPort());
126 }
127
128
129 sslEngine.setUseClientMode(parent.isUseClientMode());
130
131
132 if (parent.isWantClientAuth()) {
133 sslEngine.setWantClientAuth(true);
134 }
135
136 if (parent.isNeedClientAuth()) {
137 sslEngine.setNeedClientAuth(true);
138 }
139
140 if (parent.getEnabledCipherSuites() != null) {
141 sslEngine.setEnabledCipherSuites(parent.getEnabledCipherSuites());
142 }
143
144 if (parent.getEnabledProtocols() != null) {
145 sslEngine.setEnabledProtocols(parent.getEnabledProtocols());
146 }
147
148
149 sslEngine.beginHandshake();
150
151
152 handshakeStatus = sslEngine.getHandshakeStatus();
153
154 handshakeComplete = false;
155 initialHandshakeComplete = false;
156 writingEncryptedData = false;
157 }
158
159
160
161
162 public void destroy() {
163 if (sslEngine == null) {
164 return;
165 }
166
167
168 try {
169 sslEngine.closeInbound();
170 } catch (SSLException e) {
171 logger.debug(
172 "Unexpected exception from SSLEngine.closeInbound().", e);
173 }
174
175
176 if (outNetBuffer != null) {
177 outNetBuffer.capacity(sslEngine.getSession().getPacketBufferSize());
178 } else {
179 createOutNetBuffer(0);
180 }
181 try {
182 do {
183 outNetBuffer.clear();
184 } while (sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf()).bytesProduced() > 0);
185 } catch (SSLException e) {
186
187 } finally {
188 destroyOutNetBuffer();
189 }
190
191 sslEngine.closeOutbound();
192 sslEngine = null;
193
194 preHandshakeEventQueue.clear();
195 }
196
197 private void destroyOutNetBuffer() {
198 outNetBuffer.free();
199 outNetBuffer = null;
200 }
201
202 public SslFilter getParent() {
203 return parent;
204 }
205
206 public IoSession getSession() {
207 return session;
208 }
209
210
211
212
213 public boolean isWritingEncryptedData() {
214 return writingEncryptedData;
215 }
216
217
218
219
220 public boolean isHandshakeComplete() {
221 return handshakeComplete;
222 }
223
224 public boolean isInboundDone() {
225 return sslEngine == null || sslEngine.isInboundDone();
226 }
227
228 public boolean isOutboundDone() {
229 return sslEngine == null || sslEngine.isOutboundDone();
230 }
231
232
233
234
235 public boolean needToCompleteHandshake() {
236 return handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP && !isInboundDone();
237 }
238
239 public void schedulePreHandshakeWriteRequest(NextFilter nextFilter,
240 WriteRequest writeRequest) {
241 preHandshakeEventQueue.add(new IoFilterEvent(nextFilter,
242 IoEventType.WRITE, session, writeRequest));
243 }
244
245 public void flushPreHandshakeEvents() throws SSLException {
246 IoFilterEvent scheduledWrite;
247
248 while ((scheduledWrite = preHandshakeEventQueue.poll()) != null) {
249 parent.filterWrite(scheduledWrite.getNextFilter(), session,
250 (WriteRequest) scheduledWrite.getParameter());
251 }
252 }
253
254 public void scheduleFilterWrite(NextFilter nextFilter, WriteRequest writeRequest) {
255 filterWriteEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.WRITE, session, writeRequest));
256 }
257
258 public void scheduleMessageReceived(NextFilter nextFilter, Object message) {
259 messageReceivedEventQueue.add(new IoFilterEvent(nextFilter, IoEventType.MESSAGE_RECEIVED, session, message));
260 }
261
262 public void flushScheduledEvents() {
263
264 if (Thread.holdsLock(this)) {
265 return;
266 }
267
268 IoFilterEvent e;
269
270
271
272 synchronized (this) {
273 while ((e = filterWriteEventQueue.poll()) != null) {
274 e.getNextFilter().filterWrite(session, (WriteRequest) e.getParameter());
275 }
276 }
277
278 while ((e = messageReceivedEventQueue.poll()) != null) {
279 e.getNextFilter().messageReceived(session, e.getParameter());
280 }
281 }
282
283
284
285
286
287
288
289
290
291
292 public void messageReceived(NextFilter nextFilter, ByteBuffer buf) throws SSLException {
293
294 if (inNetBuffer == null) {
295 inNetBuffer = IoBuffer.allocate(buf.remaining()).setAutoExpand(true);
296 }
297
298 inNetBuffer.put(buf);
299 if (!handshakeComplete) {
300 handshake(nextFilter);
301 } else {
302 decrypt(nextFilter);
303 }
304
305 if (isInboundDone()) {
306
307 int inNetBufferPosition = inNetBuffer == null? 0 : inNetBuffer.position();
308 buf.position(buf.position() - inNetBufferPosition);
309 inNetBuffer = null;
310 }
311 }
312
313
314
315
316
317
318 public IoBuffer fetchAppBuffer() {
319 IoBuffer appBuffer = this.appBuffer.flip();
320 this.appBuffer = null;
321 return appBuffer;
322 }
323
324
325
326
327
328
329 public IoBuffer fetchOutNetBuffer() {
330 IoBuffer answer = outNetBuffer;
331 if (answer == null) {
332 return emptyBuffer;
333 }
334
335 outNetBuffer = null;
336 return answer.shrink();
337 }
338
339
340
341
342
343
344
345 public void encrypt(ByteBuffer src) throws SSLException {
346 if (!handshakeComplete) {
347 throw new IllegalStateException();
348 }
349
350 if (!src.hasRemaining()) {
351 if (outNetBuffer == null) {
352 outNetBuffer = emptyBuffer;
353 }
354 return;
355 }
356
357 createOutNetBuffer(src.remaining());
358
359
360 while (src.hasRemaining()) {
361
362 SSLEngineResult result = sslEngine.wrap(src, outNetBuffer.buf());
363 if (result.getStatus() == SSLEngineResult.Status.OK) {
364 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
365 doTasks();
366 }
367 } else if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
368 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
369 outNetBuffer.limit(outNetBuffer.capacity());
370 } else {
371 throw new SSLException("SSLEngine error during encrypt: "
372 + result.getStatus() + " src: " + src
373 + "outNetBuffer: " + outNetBuffer);
374 }
375 }
376
377 outNetBuffer.flip();
378 }
379
380
381
382
383
384
385
386
387 public boolean closeOutbound() throws SSLException {
388 if (sslEngine == null || sslEngine.isOutboundDone()) {
389 return false;
390 }
391
392 sslEngine.closeOutbound();
393
394 createOutNetBuffer(0);
395 SSLEngineResult result;
396 for (;;) {
397 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
398 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
399 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
400 outNetBuffer.limit(outNetBuffer.capacity());
401 } else {
402 break;
403 }
404 }
405
406 if (result.getStatus() != SSLEngineResult.Status.CLOSED) {
407 throw new SSLException("Improper close state: " + result);
408 }
409 outNetBuffer.flip();
410 return true;
411 }
412
413
414
415
416
417
418 private void decrypt(NextFilter nextFilter) throws SSLException {
419
420 if (!handshakeComplete) {
421 throw new IllegalStateException();
422 }
423
424 unwrap(nextFilter);
425 }
426
427
428
429
430
431 private void checkStatus(SSLEngineResult res)
432 throws SSLException {
433
434 SSLEngineResult.Status status = res.getStatus();
435
436
437
438
439
440
441
442
443
444 if (status != SSLEngineResult.Status.OK
445 && status != SSLEngineResult.Status.CLOSED
446 && status != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
447 throw new SSLException("SSLEngine error during decrypt: " + status
448 + " inNetBuffer: " + inNetBuffer + "appBuffer: "
449 + appBuffer);
450 }
451 }
452
453
454
455
456 public void handshake(NextFilter nextFilter) throws SSLException {
457 for (;;) {
458 switch (handshakeStatus) {
459 case FINISHED :
460 session.setAttribute(
461 SslFilter.SSL_SESSION, sslEngine.getSession());
462 handshakeComplete = true;
463
464 if (!initialHandshakeComplete
465 && session.containsAttribute(SslFilter.USE_NOTIFICATION)) {
466
467
468 initialHandshakeComplete = true;
469 scheduleMessageReceived(nextFilter,
470 SslFilter.SESSION_SECURED);
471 }
472
473 return;
474
475 case NEED_TASK :
476 handshakeStatus = doTasks();
477 break;
478
479 case NEED_UNWRAP :
480
481 SSLEngineResult.Status status = unwrapHandshake(nextFilter);
482
483 if (status == SSLEngineResult.Status.BUFFER_UNDERFLOW &&
484 handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED ||
485 isInboundDone()) {
486
487 return;
488 }
489
490 break;
491
492 case NEED_WRAP :
493
494
495 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
496 return;
497 }
498
499 SSLEngineResult result;
500 createOutNetBuffer(0);
501
502 for (;;) {
503 result = sslEngine.wrap(emptyBuffer.buf(), outNetBuffer.buf());
504 if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
505 outNetBuffer.capacity(outNetBuffer.capacity() << 1);
506 outNetBuffer.limit(outNetBuffer.capacity());
507 } else {
508 break;
509 }
510 }
511
512 outNetBuffer.flip();
513 handshakeStatus = result.getHandshakeStatus();
514 writeNetBuffer(nextFilter);
515 break;
516
517 default :
518 throw new IllegalStateException("Invalid Handshaking State"
519 + handshakeStatus);
520 }
521 }
522 }
523
524 private void createOutNetBuffer(int expectedRemaining) {
525
526
527 int capacity = Math.max(
528 expectedRemaining,
529 sslEngine.getSession().getPacketBufferSize());
530
531 if (outNetBuffer != null) {
532 outNetBuffer.capacity(capacity);
533 } else {
534 outNetBuffer = IoBuffer.allocate(capacity).minimumCapacity(0);
535 }
536 }
537
538 public WriteFuture writeNetBuffer(NextFilter nextFilter)
539 throws SSLException {
540
541 if (outNetBuffer == null || !outNetBuffer.hasRemaining()) {
542
543 return null;
544 }
545
546
547
548 writingEncryptedData = true;
549
550
551 WriteFuture writeFuture = null;
552
553 try {
554 IoBuffer writeBuffer = fetchOutNetBuffer();
555 writeFuture = new DefaultWriteFuture(session);
556 parent.filterWrite(nextFilter, session, new DefaultWriteRequest(
557 writeBuffer, writeFuture));
558
559
560 while (needToCompleteHandshake()) {
561 try {
562 handshake(nextFilter);
563 } catch (SSLException ssle) {
564 SSLException newSsle = new SSLHandshakeException(
565 "SSL handshake failed.");
566 newSsle.initCause(ssle);
567 throw newSsle;
568 }
569
570 IoBuffer outNetBuffer = fetchOutNetBuffer();
571 if (outNetBuffer != null && outNetBuffer.hasRemaining()) {
572 writeFuture = new DefaultWriteFuture(session);
573 parent.filterWrite(nextFilter, session,
574 new DefaultWriteRequest(outNetBuffer, writeFuture));
575 }
576 }
577 } finally {
578 writingEncryptedData = false;
579 }
580
581 return writeFuture;
582 }
583
584 private void unwrap(NextFilter nextFilter) throws SSLException {
585
586 if (inNetBuffer != null) {
587 inNetBuffer.flip();
588 }
589
590 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
591 return;
592 }
593
594 SSLEngineResult res = unwrap0();
595
596
597 if (inNetBuffer.hasRemaining()) {
598 inNetBuffer.compact();
599 } else {
600 inNetBuffer = null;
601 }
602
603 checkStatus(res);
604
605 renegotiateIfNeeded(nextFilter, res);
606 }
607
608 private SSLEngineResult.Status unwrapHandshake(NextFilter nextFilter) throws SSLException {
609
610 if (inNetBuffer != null) {
611 inNetBuffer.flip();
612 }
613
614 if (inNetBuffer == null || !inNetBuffer.hasRemaining()) {
615
616 return SSLEngineResult.Status.BUFFER_UNDERFLOW;
617 }
618
619 SSLEngineResult res = unwrap0();
620 handshakeStatus = res.getHandshakeStatus();
621
622 checkStatus(res);
623
624
625
626 if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED
627 && res.getStatus() == SSLEngineResult.Status.OK
628 && inNetBuffer.hasRemaining()) {
629 res = unwrap0();
630
631
632 if (inNetBuffer.hasRemaining()) {
633 inNetBuffer.compact();
634 } else {
635 inNetBuffer = null;
636 }
637
638 renegotiateIfNeeded(nextFilter, res);
639 } else {
640
641 if (inNetBuffer.hasRemaining()) {
642 inNetBuffer.compact();
643 } else {
644 inNetBuffer = null;
645 }
646 }
647
648 return res.getStatus();
649 }
650
651 private void renegotiateIfNeeded(NextFilter nextFilter, SSLEngineResult res)
652 throws SSLException {
653 if (res.getStatus() != SSLEngineResult.Status.CLOSED
654 && res.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW
655 && res.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
656
657 handshakeComplete = false;
658 handshakeStatus = res.getHandshakeStatus();
659 handshake(nextFilter);
660 }
661 }
662
663 private SSLEngineResult unwrap0() throws SSLException {
664 if (appBuffer == null) {
665 appBuffer = IoBuffer.allocate(inNetBuffer.remaining());
666 } else {
667 appBuffer.expand(inNetBuffer.remaining());
668 }
669
670 SSLEngineResult res;
671 do {
672 res = sslEngine.unwrap(inNetBuffer.buf(), appBuffer.buf());
673 if (res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
674 appBuffer.capacity(appBuffer.capacity() << 1);
675 appBuffer.limit(appBuffer.capacity());
676 continue;
677 }
678 } while ((res.getStatus() == SSLEngineResult.Status.OK || res.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) &&
679 (handshakeComplete && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ||
680 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP));
681
682 return res;
683 }
684
685
686
687
688 private SSLEngineResult.HandshakeStatus doTasks() {
689
690
691
692
693 Runnable runnable;
694 while ((runnable = sslEngine.getDelegatedTask()) != null) {
695
696 runnable.run();
697 }
698 return sslEngine.getHandshakeStatus();
699 }
700
701
702
703
704
705
706
707
708 public static IoBuffer copy(ByteBuffer src) {
709 IoBuffer copy = IoBuffer.allocate(src.remaining());
710 copy.put(src);
711 copy.flip();
712 return copy;
713 }
714 }