1   package org.apache.mina.example.echoserver.ssl;
2   
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import junit.framework.TestCase;

import org.apache.mina.common.IoAcceptor;
import org.apache.mina.common.IoHandlerAdapter;
import org.apache.mina.common.IoSession;
import org.apache.mina.filter.SSLFilter;
import org.apache.mina.filter.codec.ProtocolCodecFilter;
import org.apache.mina.filter.codec.textline.TextLineCodecFactory;
import org.apache.mina.transport.socket.nio.SocketAcceptor;
25  
26  public class SSLFilterTest extends TestCase {
27  
28      private static final int PORT = 17887;
29  
30      private IoAcceptor acceptor;
31  
32      SocketAddress socketAddress = new InetSocketAddress(PORT);
33  
34      protected void setUp() throws Exception {
35          super.setUp();
36          acceptor = new SocketAcceptor();
37      }
38  
39      protected void tearDown() throws Exception {
40          acceptor.unbindAll();
41          super.tearDown();
42      }
43  
44      public void testMessageSentIsCalled() throws Exception {
45          testMessageSentIsCalled(false);
46      }
47  
48      public void testMessageSentIsCalled_With_SSL() throws Exception {
49          testMessageSentIsCalled(true);
50      }
51  
52      private void testMessageSentIsCalled(boolean useSSL) throws Exception {
53          SSLFilter sslFilter = null;
54          if (useSSL) {
55              sslFilter = new SSLFilter(BogusSSLContextFactory.getInstance(true));
56              acceptor.getFilterChain().addLast("sslFilter", sslFilter);
57          }
58          acceptor.getFilterChain().addLast(
59                  "codec",
60                  new ProtocolCodecFilter(new TextLineCodecFactory(Charset
61                          .forName("UTF-8"))));
62  
63          EchoHandler handler = new EchoHandler();
64          acceptor.bind(socketAddress, handler);
65          System.out.println("MINA server started.");
66  
67          Socket socket = getClientSocket(useSSL);
68          int bytesSent = 0;
69          bytesSent += writeMessage(socket, "test-1\n");
70  
71          if (useSSL) {
72              // Test renegotiation
73              SSLSocket ss = (SSLSocket) socket;
74              //ss.getSession().invalidate();
75              ss.startHandshake();
76          }
77  
78          bytesSent += writeMessage(socket, "test-2\n");
79          byte[] response = new byte[bytesSent];
80          for (int i = 0; i < response.length; i++) {
81              response[i] = (byte) socket.getInputStream().read();
82          }
83          
84          if (useSSL) {
85              // Read SSL close notify.
86              while (socket.getInputStream().read() >= 0) {
87                  continue;
88              }
89          }
90          
91          socket.close();
92          
93          long millis = System.currentTimeMillis();
94          while (handler.sentMessages.size() < 2
95                  && System.currentTimeMillis() < millis + 5000) {
96              Thread.sleep(200);
97          }
98          assertEquals("received what we sent", "test-1\ntest-2\n", new String(
99                  response, "UTF-8"));
100 
101         System.out.println("handler: " + handler.sentMessages);
102         assertEquals("handler should have sent 2 messages:", 2,
103                 handler.sentMessages.size());
104         assertTrue(handler.sentMessages.contains("test-1"));
105         assertTrue(handler.sentMessages.contains("test-2"));
106     }
107 
108     private int writeMessage(Socket socket, String message) throws Exception {
109         byte request[] = message.getBytes("UTF-8");
110         socket.getOutputStream().write(request);
111         return request.length;
112     }
113 
114     private Socket getClientSocket(boolean ssl) throws Exception {
115         if (ssl) {
116             SSLContext ctx = SSLContext.getInstance("TLS");
117             ctx.init(null, trustManagers, null);
118             return ctx.getSocketFactory().createSocket("localhost", PORT);
119         }
120         return new Socket("localhost", PORT);
121     }
122 
123     private static class EchoHandler extends IoHandlerAdapter {
124 
125         List<String> sentMessages = new ArrayList<String>();
126 
127         public void exceptionCaught(IoSession session, Throwable cause)
128                 throws Exception {
129             cause.printStackTrace();
130         }
131 
132         public void messageReceived(IoSession session, Object message)
133                 throws Exception {
134             session.write(message);
135         }
136 
137         public void messageSent(IoSession session, Object message)
138                 throws Exception {
139             sentMessages.add(message.toString());
140             System.out.println(message);
141             if (sentMessages.size() >= 2) {
142                 session.close();
143             }
144         }
145 
146     }
147 
148     TrustManager[] trustManagers = new TrustManager[] { new TrustAnyone() };
149 
150     private static class TrustAnyone implements X509TrustManager {
151         public void checkClientTrusted(
152                 java.security.cert.X509Certificate[] x509Certificates, String s)
153                 throws CertificateException {
154         }
155 
156         public void checkServerTrusted(
157                 java.security.cert.X509Certificate[] x509Certificates, String s)
158                 throws CertificateException {
159         }
160 
161         public java.security.cert.X509Certificate[] getAcceptedIssuers() {
162             return new java.security.cert.X509Certificate[0];
163         }
164     }
165 
166 }