001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033
034import org.apache.activemq.thread.TaskRunnerFactory;
035import org.apache.activemq.util.IOExceptionSupport;
036import org.apache.activemq.util.ServiceStopper;
037import org.apache.activemq.wireformat.WireFormat;
038
039/**
040 * This transport initializes the SSLEngine and reads the first command before
041 * handing off to the detected transport.
042 *
043 */
044public class AutoInitNioSSLTransport extends NIOSSLTransport {
045
046    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
047        super(wireFormat, socketFactory, remoteLocation, localLocation);
048    }
049
050    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
051        super(wireFormat, socket, null, null, null);
052    }
053
054    @Override
055    public void setSslContext(SSLContext sslContext) {
056        this.sslContext = sslContext;
057    }
058
059    public ByteBuffer getInputBuffer() {
060        return this.inputBuffer;
061    }
062
063    @Override
064    protected void initializeStreams() throws IOException {
065        NIOOutputStream outputStream = null;
066        try {
067            channel = socket.getChannel();
068            channel.configureBlocking(false);
069
070            if (sslContext == null) {
071                sslContext = SSLContext.getDefault();
072            }
073
074            String remoteHost = null;
075            int remotePort = -1;
076
077            try {
078                URI remoteAddress = new URI(this.getRemoteAddress());
079                remoteHost = remoteAddress.getHost();
080                remotePort = remoteAddress.getPort();
081            } catch (Exception e) {
082            }
083
084            // initialize engine, the initial sslSession we get will need to be
085            // updated once the ssl handshake process is completed.
086            if (remoteHost != null && remotePort != -1) {
087                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
088            } else {
089                sslEngine = sslContext.createSSLEngine();
090            }
091
092            sslEngine.setUseClientMode(false);
093            if (enabledCipherSuites != null) {
094                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
095            }
096
097            if (enabledProtocols != null) {
098                sslEngine.setEnabledProtocols(enabledProtocols);
099            }
100
101            if (wantClientAuth) {
102                sslEngine.setWantClientAuth(wantClientAuth);
103            }
104
105            if (needClientAuth) {
106                sslEngine.setNeedClientAuth(needClientAuth);
107            }
108
109            sslSession = sslEngine.getSession();
110
111            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
112            inputBuffer.clear();
113
114            outputStream = new NIOOutputStream(channel);
115            outputStream.setEngine(sslEngine);
116            this.dataOut = new DataOutputStream(outputStream);
117            this.buffOut = outputStream;
118            sslEngine.beginHandshake();
119            handshakeStatus = sslEngine.getHandshakeStatus();
120            doHandshake();
121
122        } catch (Exception e) {
123            try {
124                if(outputStream != null) {
125                    outputStream.close();
126                }
127                super.closeStreams();
128            } catch (Exception ex) {}
129            throw new IOException(e);
130        }
131    }
132
133    @Override
134    protected void doOpenWireInit() throws Exception {
135
136    }
137
138    public SSLEngine getSslSession() {
139        return this.sslEngine;
140    }
141
142    private volatile byte[] readData;
143
144    private final AtomicInteger readSize = new AtomicInteger();
145
146    public byte[] getReadData() {
147        return readData != null ? readData : new byte[0];
148    }
149
150    public AtomicInteger getReadSize() {
151        return readSize;
152    }
153
154    @Override
155    public void serviceRead() {
156        try {
157            if (handshakeInProgress) {
158                doHandshake();
159            }
160
161            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
162            plain.position(plain.limit());
163
164            while (true) {
165                if (!plain.hasRemaining()) {
166                    int readCount = secureRead(plain);
167
168                    if (readCount == 0) {
169                        break;
170                    }
171
172                    // channel is closed, cleanup
173                    if (readCount == -1) {
174                        onException(new EOFException());
175                        break;
176                    }
177
178                    receiveCounter += readCount;
179                    readSize.addAndGet(readCount);
180                }
181
182                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
183                    processCommand(plain);
184                    //we have received enough bytes to detect the protocol
185                    if (receiveCounter >= 8) {
186                        break;
187                    }
188                }
189            }
190        } catch (IOException e) {
191            onException(e);
192        } catch (Throwable e) {
193            onException(IOExceptionSupport.create(e));
194        }
195    }
196
197    @Override
198    protected void processCommand(ByteBuffer plain) throws Exception {
199        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
200        if (readData != null) {
201            newBuffer.put(readData);
202        }
203        newBuffer.put(plain);
204        newBuffer.flip();
205        readData = newBuffer.array();
206    }
207
208
209    @Override
210    public void doStart() throws Exception {
211        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
212        // no need to init as we can delay that until demand (eg in doHandshake)
213        connect();
214    }
215
216
217    @Override
218    protected void doStop(ServiceStopper stopper) throws Exception {
219        if (taskRunnerFactory != null) {
220            taskRunnerFactory.shutdownNow();
221            taskRunnerFactory = null;
222        }
223    }
224
225
226}