/* XmpSession.java

   COPYRIGHT 2008 KRUPCZAK.ORG, LLC.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation; either version 2 of the
   License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
   General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
   USA
 
   For more information, visit:
   http://www.krupczak.org/
*/

package org.krupczak.xmp;

import java.net.*;
import java.io.*;
import java.util.Date;
import java.io.UnsupportedEncodingException;
import javax.net.*;
import java.security.*;
import java.security.cert.*;
import javax.net.ssl.*;
import java.util.Scanner;

/**
 *  XmpSession is used to communicate with a Cartographer agent or
 *  another manager using the XMP protocol.  Instantiating an object
 *  causes an XMP session to be established with an XMP entity.  A
 *  session consists of an SSL/TCP connection over which multiple XMP
 *  messages can be sent and received.  Each XMP message is prepended
 *  by an all-ASCII wire header which contains the protocol version
 *  number (4-chars), length of XMP message (12-chars), and the
 *  user/profile (16-chars) in which the message should be evaluated.
 *  User/profiles are configured a priori in each agent.
 *  @author Bobby Krupczak, rdk@krupczak.org
 *  @version $Id: XmpSession.java 15 2008-07-17 14:20:37Z rdk $
 *  @see Xmp
 *  @see XmpVar
 *  @see XmpMessage
 **/

public class XmpSession {

  /* class variables and methods *********************** */

  /** State value: the session has no usable connection. **/
  public static final int STATE_CLOSED = 0;

  /** State value: the session is connected and ready for I/O. **/
  public static final int STATE_OPEN = 1;

  /* Wire header layout (XMP version 1): 4-char version, 12-char message
     length, 16-char authenUser; 32 ASCII bytes in total. */
  private static final int VERSION_FIELD_LEN = 4;
  private static final int LENGTH_FIELD_LEN = 12;
  private static final int USER_FIELD_LEN = 16;
  private static final int WIRE_HDR_LEN =
      VERSION_FIELD_LEN + LENGTH_FIELD_LEN + USER_FIELD_LEN;

  /* instance variables ****************************** */
  public String targetAddr;
  public int targetPort;
  Socket s;
  int state;
  int msgsIn, msgsOut, bytesIn, bytesOut;
  OutputStream sockout;
  InputStream sockin;
  Date lastUsage;
  SocketOpts sockopts;
  String authenUser;
  public boolean dumpPDUs;
  public int errorStatus; // error status of last PDU received

  /* constructors  ***************************** */

  /**
   * Create an unconnected session that will use the given socket
   * options.  The other constructors delegate to this one.
   *
   * @param sockopts socket factories and connect timeout to use
   **/
  public XmpSession(SocketOpts sockopts) {
    this.sockopts = sockopts;
    targetPort = 0;
    targetAddr = null;
    s = null;
    state = STATE_CLOSED;

    msgsIn = 0;
    msgsOut = 0;
    bytesIn = 0;
    bytesOut = 0;
    lastUsage = null;
    authenUser = null;
    dumpPDUs = false;
    errorStatus = Xmp.ERROR_NOERROR;
  }

  /**
   * Create a session and connect to target/port.  An SSL socket is
   * tried first; if that connect fails, a plain TCP socket is tried.
   * If both fail, the session is left closed; check isOpen().
   *
   * @param sockopts socket factories and connect timeout to use
   * @param target hostname or IP address in string form
   * @param port TCP port of the XMP entity
   **/
  public XmpSession(SocketOpts sockopts, String target, int port)
  {
    this(sockopts);
    targetAddr = target;
    targetPort = port;
    connect();
  } /* XmpSession(String target,port) */

  /**
   * Create a session as {@link #XmpSession(SocketOpts,String,int)} and
   * set the authenUser to send in outgoing wire headers.
   **/
  public XmpSession(SocketOpts sockopts, String target, int port,
                    String authenUser)
  {
    this(sockopts, target, port);
    this.authenUser = authenUser;
  }

  /**
   * Create a session to the given address.  The address is converted
   * with getHostAddress().  IPv4 is the intended use; IPv6 is not
   * supported at this time.
   **/
  public XmpSession(SocketOpts sockopts, InetAddress target, int port) {
    this(sockopts, target.getHostAddress(), port);
  }

  /**
   * Create a session to the given address and set the authenUser to
   * send in outgoing wire headers.
   **/
  public XmpSession(SocketOpts sockopts, InetAddress target, int port,
                    String authenUser)
  {
    this(sockopts, target.getHostAddress(), port);
    this.authenUser = authenUser;
  }

  /* private methods **************************** */

  private void setErrorStatus(int e) { errorStatus = e; }

  /** Return true for ASCII letters and digits only. **/
  private static boolean isAlphaNumeric(byte aCh)
  {
    return (aCh >= 'A' && aCh <= 'Z')
        || (aCh >= 'a' && aCh <= 'z')
        || (aCh >= '0' && aCh <= '9');
  }

  /** Close a socket, ignoring errors; used when discarding a failed socket. **/
  private static void closeQuietly(Socket sock)
  {
    if (sock == null)
      return;
    try {
      sock.close();
    } catch (IOException ignored) {
      // best effort: the socket is being abandoned anyway
    }
  }

  /**
   * Encode str as US-ASCII.  If that charset is unavailable, fall back to
   * the platform default.
   **/
  private static byte[] asciiBytes(String str)
  {
    try {
      return str.getBytes("US-ASCII");
    } catch (UnsupportedEncodingException e) {
      return str.getBytes();
    }
  }

  /**
   * Copy value into hdr[offset, offset+width).  The value is truncated
   * to width bytes; unused bytes are left as they are (zero in a fresh
   * header).
   **/
  private static void putField(byte[] hdr, int offset, int width,
                               String value)
  {
    byte[] bytes = asciiBytes(value);
    System.arraycopy(bytes, 0, hdr, offset, Math.min(bytes.length, width));
  }

  /** Parse the first integer token in field, or return -1 if there is none. **/
  private static int parseIntField(String field)
  {
    try {
      return new Scanner(field).nextInt();
    } catch (Exception e) {
      return -1;
    }
  }

  /**
   * Set reuse-address, connect sock to targetAddr/targetPort with the
   * configured timeout, and adopt it as this session's socket.
   * SO_REUSEADDR is set before connect because connecting binds the
   * socket, and setting the option on a bound socket has undefined
   * behaviour.  If any step fails, sock is closed so it does not leak,
   * and the exception is rethrown.
   **/
  private void attach(Socket sock) throws IOException
  {
    try {
      sock.setReuseAddress(true);
      sock.connect(new InetSocketAddress(targetAddr, targetPort),
                   sockopts.getConnectTimeout());
      InputStream in = sock.getInputStream();
      OutputStream out = sock.getOutputStream();
      s = sock;
      sockin = in;
      sockout = out;
      state = STATE_OPEN;
    } catch (IOException e) {
      closeQuietly(sock);
      throw e;
    }
  }

  /**
   * Connect to targetAddr/targetPort.  An SSL socket is tried first and
   * a plain TCP socket second.  On success the state becomes STATE_OPEN,
   * and the message/byte counters and errorStatus are reset.  On failure
   * the state is STATE_CLOSED and s is null.  Called only from
   * constructors and synchronized methods.
   *
   * NOTE(review): no startHandshake() call is made here, so the plain
   * fallback is only tried when the SSL socket's connect itself fails.
   *
   * @return true if a connection was established
   **/
  private boolean connect()
  {
    state = STATE_CLOSED;
    if (targetAddr == null)
      return false;

    try {
      attach(sockopts.sslSocketFactory.createSocket());
    } catch (IOException sslFailure) {
      try {
        attach(sockopts.socketFactory.createSocket());
      } catch (IOException plainFailure) {
        s = null;
        state = STATE_CLOSED;
        return false;
      }
    }

    msgsIn = 0;
    msgsOut = 0;
    bytesIn = 0;
    bytesOut = 0;
    errorStatus = Xmp.ERROR_NOERROR;
    return true;
  }

  /**
   * Read until len bytes are stored in buf or the peer closes the
   * connection.  A single InputStream.read() may return fewer bytes than
   * requested, so one read is not enough to frame a message.  Every byte
   * read is added to bytesIn.
   *
   * @return number of bytes read; less than len only at EOF
   * @throws IOException if the underlying read fails
   **/
  private int readFully(byte[] buf, int len) throws IOException
  {
    int count = 0;
    while (count < len) {
      int ret = sockin.read(buf, count, len - count);
      if (ret < 0)
        break; // EOF: the peer closed the connection
      count += ret;
      bytesIn += ret;
    }
    return count;
  }

  /**
   * Read a message body of exactly msgLen bytes.  After an I/O error or
   * a premature EOF the stream can no longer be framed, so the session
   * is closed.
   *
   * NOTE(review): msgLen comes from the peer and is not bounded here.
   *
   * @return the body bytes, or null on failure
   **/
  private byte[] readBody(int msgLen)
  {
    byte[] buf = new byte[msgLen];
    int got;

    try {
      got = readFully(buf, msgLen);
    } catch (IOException e) {
      closeSession();
      return null;
    }

    if (got != msgLen) {
      System.out.println(got+" does not equal "+msgLen);
      closeSession();
      return null;
    }
    return buf;
  }

  /**
   * Read and parse the 32-byte wire header that precedes every message.
   * The header holds a 4-char version, a 12-char length and a 16-char
   * authenUser.  The received authenUser (or null if the field is blank)
   * is stored in this.authenUser.  A version mismatch is reported on
   * stdout but not rejected.
   *
   * @return the message length, or -1 if the header could not be read
   *         or the length field did not parse
   **/
  synchronized private int readWireHdr() {
    byte[] wireHdr = new byte[WIRE_HDR_LEN];
    int got;

    if (sockin == null) {
      return -1;
    }

    try {
      got = readFully(wireHdr, WIRE_HDR_LEN);
    } catch (IOException e) {
      return -1;
    }
    if (got != WIRE_HDR_LEN) {
      System.out.println("XmpSession: read "+got+" bytes instead of "+
                         WIRE_HDR_LEN);
      return -1;
    }

    // The numeric fields may be left or right justified and padded with
    // spaces or NULs.  Blank out every non-alphanumeric byte so Scanner
    // sees only whitespace-separated tokens.
    for (int i = 0; i < wireHdr.length; i++) {
      if (!isAlphaNumeric(wireHdr[i]))
        wireHdr[i] = ' ';
    }

    String versionStr = new String(wireHdr, 0, VERSION_FIELD_LEN);
    String lengthStr = new String(wireHdr, VERSION_FIELD_LEN,
                                  LENGTH_FIELD_LEN);
    String userStr = new String(wireHdr, VERSION_FIELD_LEN + LENGTH_FIELD_LEN,
                                USER_FIELD_LEN);

    int version = parseIntField(versionStr);
    if (version != Xmp.XMP_VERSION) {
      System.out.println("XmpSession: readWireHdr found version "+
                         version+" instead of "+Xmp.XMP_VERSION);
    }

    int messageLength = parseIntField(lengthStr);

    String user;
    try {
      user = new Scanner(userStr).next();
    } catch (Exception e) {
      user = null;
    }
    this.authenUser = user;

    return messageLength;
  } /* readWireHdr */

  /**
   * Write the 32-byte wire header for a message of len bytes.  The
   * header holds XMP_VERSION, len and authenUser, each field truncated
   * to its width.
   *
   * @return 32 on success, or -1 if len is less than 1, there is no
   *         output stream, authenUser is null, or the write fails
   **/
  synchronized private int writeWireHdr(int len)
  {
    byte[] wireHdr = new byte[WIRE_HDR_LEN];

    if (len < 1)
      return -1;

    if (sockout == null) {
      System.out.println("writeWireHdr: sockout is null");
      return -1;
    }

    if (authenUser == null) {
      System.out.println("writeWireHdr: null authenUser");
      return -1;
    }

    putField(wireHdr, 0, VERSION_FIELD_LEN,
             Integer.toString(Xmp.XMP_VERSION));
    putField(wireHdr, VERSION_FIELD_LEN, LENGTH_FIELD_LEN,
             Integer.toString(len));
    putField(wireHdr, VERSION_FIELD_LEN + LENGTH_FIELD_LEN, USER_FIELD_LEN,
             authenUser);

    try {
      sockout.write(wireHdr);
    } catch (IOException e) {
      return -1;
    }

    bytesOut += WIRE_HDR_LEN;
    return WIRE_HDR_LEN;
  } /* writeWireHdr() */

  /**
   * Check a reply against its query.  The reply must be non-null, be a
   * RESPONSE, carry no error and echo the query's message id.  Each
   * failure is reported on stdout, prefixed with who.
   *
   * @return reply if it passes every check, otherwise null
   **/
  private XmpMessage checkReply(String who, XmpMessage query,
                                XmpMessage reply)
  {
    if (reply == null) {
      System.out.println(who+" received null response");
      return null;
    }

    if (reply.getMsgType() != Xmp.MSGTYPE_RESPONSE) {
      System.out.println(who+" returned "+
                         Xmp.messageTypeToString(reply.getMsgType())+
                         " instead of response");
      return null;
    }

    if (reply.getErrorStatus() != Xmp.ERROR_NOERROR) {
      System.out.println(who+" returned error "+
          "Reply Error: "+Xmp.errorStatusToString(reply.getErrorStatus()));
      return null;
    }

    if (reply.getMessageID() != query.getMessageID()) {
      System.out.println(who+" message ids dont match "+
                         query.getMessageID()+" and "+
                         reply.getMessageID());
      return null;
    }

    return reply;
  }

  /* public methods ***************************** */

  public int getConnectTimeout() { return sockopts.getConnectTimeout(); }

  public int getErrorStatus() { return errorStatus; }

  /**
   * Set the connection timeout.  A timeout of 0 effectively means block
   * for a long time (the default timeout).  When socket() is called
   * without a timeout value, the timeout is unknown.
   **/
  public void setConnectTimeout(int newTimeout)
  {
    System.out.println("XmpSession: setting timeout to "+newTimeout);

    sockopts.setConnectTimeout(newTimeout);
  }

  /** Set the authenUser to send in messages. **/
  public void setAuthenUser(String user) { authenUser = user; }

  /**
   * Return the authenUser that is configured for sending or that was
   * most recently received.
   **/
  public String getAuthenUser() { return authenUser; }

  /**
   * Return the SSL protocol used for this session.  Returns "No proto"
   * when the session is not on an SSL socket (for example after the
   * plain-TCP fallback, or if it never connected) or has no SSL session.
   **/
  public String getSessionProto()
  {
    if (!(s instanceof SSLSocket)) {
      return "No proto";
    }

    SSLSession se = ((SSLSocket) s).getSession();
    if (se != null) {
      return se.getProtocol();
    }
    return "No proto";
  }

  /**
   * Open the session to target/port.
   *
   * @return 1 on success, or -1 if the session is not closed or the
   *         connect failed
   **/
  synchronized public int openSession(InetAddress target, int port) {

    if (state != STATE_CLOSED)
      return -1;

    // getHostAddress() gives a literal IP.  toString() would give the
    // "host/addr" form, which is not a resolvable hostname.
    targetAddr = target.getHostAddress();
    targetPort = port;

    return connect() ? 1 : -1;
  }

  /**
   * Resolve target and open the session to it.
   *
   * @return 1 on success, or -1 if the session is not closed, name
   *         resolution fails, or the connect fails
   **/
  synchronized public int openSession(String target, int port)
  {
    if (state != STATE_CLOSED)
      return -1;

    try {
      return openSession(InetAddress.getByName(target), port);
    } catch (Exception e) {
      return -1;
    }
  }

  /**
   * Close the underlying socket.
   *
   * @return 1 if the session was open, -1 otherwise
   **/
  synchronized public int closeSession()
  {
    if (state != STATE_OPEN)
      return -1;

    closeQuietly(s);
    state = STATE_CLOSED;
    return 1;
  }

  synchronized public boolean isClosed()
  {
    return state == STATE_CLOSED;
  }

  synchronized public boolean isOpen()
  {
    return state == STATE_OPEN;
  }

  /**
   * Re-establish the TCP/SSL connection of a closed session using the
   * previous targetAddr/targetPort.
   *
   * @return true on success; false if the session is not closed, no
   *         target is known, or the connect fails
   **/
  synchronized public boolean reConnect()
  {
    if (!isClosed())
      return false;

    return connect();
  } /* reConnect() */

  /**
   * Send a message string over this session, preceded by its wire
   * header.  The string is encoded with the platform default charset.
   * The header carries the encoded byte count, not the char count, so
   * the peer can frame the message even when it contains multi-byte
   * characters.  The session is closed if either write fails.
   *
   * @return 1 on success; -1 if message is null or the header could not
   *         be written; -2 if the body write failed
   **/
  synchronized public int sendMessage(String message)
  {
    if (message == null)
      return -1;

    byte[] buf = message.getBytes();

    if (writeWireHdr(buf.length) < WIRE_HDR_LEN) {
      closeSession();
      return -1;
    }

    try {
      sockout.write(buf, 0, buf.length);
    } catch (IOException e) {
      closeSession();
      return -2;
    }

    msgsOut++;
    bytesOut += buf.length;
    lastUsage = new Date();

    return 1;
  }

  /**
   * XML-encode msg if it is not already encoded, optionally dump it,
   * and send it.
   *
   * @return as sendMessage(String); -1 if msg is null; -3 if encoding
   *         fails
   **/
  synchronized public int sendMessage(XmpMessage msg) {

    if (msg == null)
      return -1;

    if (!msg.isEncoded() && msg.xmlEncodeMessage() < 1) {
      return -3;
    }

    if (dumpPDUs) {
      msg.dump();
    }

    return sendMessage(msg.getXmlEncoding());
  }

  /**
   * Receive one message and decode it into an XmpMessage.  errorStatus
   * is updated: ERROR_INVALID on framing/I/O failure (the session is
   * also closed), ERROR_PARSEERROR if decoding fails, otherwise the
   * received message's own error status.
   *
   * @return the decoded message, or null on failure
   **/
  synchronized public XmpMessage recvMessage()
  {
    // Created up front, as before, in case the XmpMessage constructor
    // has side effects that callers rely on.
    XmpMessage msg = new XmpMessage();

    int msgLen = readWireHdr();
    if (msgLen < 4) {
      closeSession();
      setErrorStatus(Xmp.ERROR_INVALID);
      return null;
    }

    byte[] buf = readBody(msgLen);
    if (buf == null) {
      setErrorStatus(Xmp.ERROR_INVALID);
      return null;
    }

    int ret = msg.xmlDecodeMessage(buf);

    if (!msg.isDecoded()) {
      System.out.println("recvMessage: decode returned "+ret);
      setErrorStatus(Xmp.ERROR_PARSEERROR);
      return null;
    }

    msgsIn++;
    lastUsage = new Date();

    if (dumpPDUs)
      msg.dump();

    setErrorStatus(msg.getErrorStatus());

    return msg;
  }

  /**
   * Receive one message and return its raw text, decoded with the
   * platform default charset.  The argument is ignored; it only
   * distinguishes this overload.  The session is closed on any
   * framing/I/O failure.
   *
   * @return the message text, or null on failure
   **/
  synchronized public String recvMessage(String ignore) {
    int msgLen = readWireHdr();
    if (msgLen < 4) {
      closeSession();
      return null;
    }

    byte[] buf = readBody(msgLen);
    if (buf == null)
      return null;

    msgsIn++;
    lastUsage = new Date();

    return new String(buf, 0, msgLen);
  } /* recvMessage(String) */

  /**
   * Return true if more than 4 bytes can be read without blocking.
   * Returns false if there is no input stream or the check fails.
   **/
  synchronized public boolean messageAvailable() {
    if (sockin == null)
      return false;
    try {
      return sockin.available() > 4;
    } catch (IOException e) {
      return false;
    }
  }

  /**
   * XML-encode msg, send it, and wait for the reply.  A closed session
   * is reconnected first.
   *
   * @return the decoded reply, or null on any failure
   **/
  synchronized public XmpMessage sendMessageReceiveReply(XmpMessage msg) {

    if (state == STATE_CLOSED && !reConnect())
      return null;

    if (msg.xmlEncodeMessage() < 1)
      return null;

    if (sendMessage(msg) < 0)
      return null;

    return recvMessage();
  }  /* sendMessageReceiveReply(XmpMessage) */

  /**
   * Send a message already encoded as XML and return the XML-encoded
   * reply.  A closed session is reconnected first.
   *
   * @return the reply text, or null on any failure
   **/
  synchronized public String sendMessageReceiveReply(String msg) {

    if (state == STATE_CLOSED && !reConnect())
      return null;

    if (sendMessage(msg) < 0)
      return null;

    return recvMessage((String) null);
  } /* sendMessageReceiveReply(String) */

  public int getMessagesIn() { return msgsIn; }
  public int getBytesIn() { return bytesIn; }
  public int getMessagesOut() { return msgsOut; }
  public int getBytesOut() { return bytesOut; }

  /**
   * Return whole seconds since the last successful send or receive, or
   * -1 if the session has never been used.
   **/
  public int getSessionIdleTime()
  {
    if (lastUsage == null)
      return -1;

    long diffMillis = new Date().getTime() - lastUsage.getTime();
    return (int) (diffMillis / 1000);
  }

  /**
   * Send a GetRequest for vars and return the validated response.
   *
   * @return the reply, or null if sending failed or the reply was
   *         missing, not a RESPONSE, carried an error, or had a
   *         mismatched message id
   **/
  public XmpMessage queryVars(XmpVar[] vars)
  {
    XmpMessage query = new XmpMessage(Xmp.MSGTYPE_GETREQUEST);
    query.setDecoded();
    query.setMIBVars(vars);

    if (sendMessage(query) < 0)
      return null;

    return checkReply("session.queryVars", query, recvMessage());
  }

  /**
   * Send a SelectTableRequest and return the validated response.
   *
   * @param table {mibname, tablename, keyvalue}; must have exactly 3
   *              elements
   * @param maxRows row limit passed through to the request
   * @param vars the MIB variables (columns) to select
   * @return the reply, or null on bad arguments, send failure, or a
   *         reply that fails validation
   **/
  public XmpMessage queryTableVars(String[] table, int maxRows,
                                   XmpVar[] vars)
  {
    if (table == null || table.length != 3)
      return null;

    XmpMessage query = new XmpMessage(Xmp.MSGTYPE_SELECTTABLEREQUEST);
    query.setDecoded();

    // mibname, tablename, keyvalue, maxRows
    query.setTableOperands(table[0], table[1], table[2], maxRows);

    // set the MIB vars (e.g. columns) from table
    query.setMIBVars(vars);

    if (sendMessage(query) < 0)
      return null;

    return checkReply("session.queryTableVars", query, recvMessage());
  } /* queryTableVars() */

  public boolean dumpPDUs() { return dumpPDUs; }
  public void setDumpPDUs(boolean val) { dumpPDUs = val; }

} /* class XmpSession */

