function name = createClassFromWsdl(wsdl)
%createClassFromWsdl Create a MATLAB object based on a WSDL-file.
%   createClassFromWsdl('source') creates MATLAB classes based on a WSDL 
%   application programming interface (API). The source argument specifies a URL
%   or file path to a WSDL API, which defines web service methods, arguments, 
%   and transactions. It returns the name of the new class, or a cell array
%   of names when the WSDL-file yields more than one class.
%  
%   Based on the WSDL API, the createClassFromWsdl function creates a new
%   folder in the current directory. The folder contains an M-file for each
%   web service method. In addition, two default M-files are created, the
%   object's display method (display.m) and its constructor (servicename.m).
%
%   An error is raised if the WSDL-file contains no SOAP binding with
%   supported operations, because there is then no class to create.
%
%   Example
%  
%   cd(tempdir)
%   % Create a class for the web service provided by xmethods.net.
%   createClassFromWsdl('http://www.xmethods.net/sd/2001/BNQuoteService.wsdl');
%   % Instantiate the object.
%   bq = BNQuoteService;
%   % getPrice returns the price of a book based on its ISBN.
%   getPrice(bq, '0735712719')
%
%   See also createSoapMessage, callSoapService, parseSoapResponse.

% Matthew J. Simoneau, June 2003
% $Revision: 1.1.6.4 $  $Date: 2004/12/27 23:33:02 $
% Copyright 1984-2004 The MathWorks, Inc.

% Parse the WSDL-file.
wsdlUrl = xmlstringinput(wsdl,true,false);
R = parseWsdl(wsdlUrl);

% parseWsdl returns [] when no binding produced any operation.  Referencing
% R.name on [] would fail with an unhelpful message, so report it explicitly.
if isempty(R)
    error('MATLAB:createClassFromWsdl:NoSoapBindings', ...
        'No SOAP bindings with supported operations were found in the WSDL-file.');
end

% Create the constructor and methods
for i = 1:length(R)
    makeconstructor(R(i))
    makemethods(R(i))
end
% Make the newly written class directories visible to MATLAB.
rehash path

% Return the name of the class (a string for one class, a cell otherwise).
if (length(R) == 1)
    name = R.name;
else
    name = {R.name};
end

%===============================================================================
function services = parseWsdl(wsdl)
%parseWsdl Parse a WSDL document into class descriptions.
%   services = parseWsdl(wsdl) runs the Axis WSDL parser on WSDL and returns
%   a structure array with one element per binding that has at least one
%   SOAP operation, or [] when there is none.  Each element has the fields
%      name         - MATLAB class name derived from the service name
%      wsdlLocation - the WSDL argument, unchanged
%      endpoint     - location URI of the port bound to this binding
%      style        - binding style as a string
%      methods      - structure array built by makeOperation, extended
%                     with targetNamespaceURI and soapAction
%
%   Java exceptions raised by the parser are translated into shorter MATLAB
%   error messages where the exception type is recognized.

% Parse and process WSDL file. 
parser = org.apache.axis.wsdl.gen.Parser;
try
    parser.run(wsdl);
catch
    exception = regexp(lasterr, ...
        'Java exception occurred: \n(.*?)\s*\n','tokens','once');
    if ~isempty(exception)
        exception = exception{1};
        if strcmp(exception, ...
                'java.net.ConnectException: Connection refused: connect')
            error('Connection refused.');
        end
        if strcmp(exception, ...
                'ice.net.URLNotFoundException: Document not found on server')
            error('The requested URL was not found on this server.')
        end
        host = regexp(exception,'java.net.UnknownHostException: (.*)','tokens','once');
        if ~isempty(host)
            error('Unknown host: %s',host{1})
        end
        sax = regexp(exception,'org.xml.sax.SAXException: (.*)','tokens','once');
        if ~isempty(sax)
            error('Could not parse XML: %s',sax{1})
        end
        error(exception)
    else
        rethrow(lasterror)
    end
end
    
definition = parser.getCurrentDefinition;
symbolTable = parser.getSymbolTable;
typeData = getTypeData(symbolTable);

% This is the structure to return.
services = [];

% Extracting information about each binding to create a MATLAB class.
it = symbolTable.getHashMap().values().iterator();
while it.hasNext
    v = it.next();
    for i = 1:v.size
        entry = v.elementAt(i-1);
        if ~isa(entry,'org.apache.axis.wsdl.symbolTable.BindingEntry')
            continue
        end
        bindingEntry = entry;
        
        % For each binding...
        binding = bindingEntry.getBinding();

        % Find the service and port for this binding.  The service is only
        % recorded together with its matching port, so a binding that no
        % port refers to can never pick up an unrelated service.
        serviceIterator = definition.getServices.values.iterator;
        port = [];
        service = [];
        while (isempty(port) && serviceIterator.hasNext)
            testService = serviceIterator.next;
            portIterator = testService.getPorts.values.iterator;
            while portIterator.hasNext
                testPort = portIterator.next;
                if binding.getQName.equals(testPort.getBinding.getQName)
                    % Found it.  Keep variables port and service.
                    port = testPort;
                    service = testService;
                    break
                end
            end
        end
        if isempty(port)
            % No port uses this binding, so there is no endpoint to call
            % and no service to name the class after.  Skip.
            continue
        end

        % Determine the MATLAB object name from the service name, keeping
        % only the part after the last '.'.
        name = char(service.getQName.getLocalPart);
        name = genvarname(name(max([0 find(name == '.')])+1:end));

        % Construct the operations for this binding's portType.
        ptEntry = symbolTable.getPortTypeEntry(binding.getPortType().getQName());
        portType = ptEntry.getPortType;
        operations = portType.getOperations().iterator();
        ops = [];
        while (operations.hasNext)
            % For each operation...
            operation = operations.next;
            if isempty(operation.getInput) || isempty(operation.getOutput)
                % One-way and notification operations lack an input or an
                % output message; they are not supported.  Skip.
                continue
            end
            parameters = bindingEntry.getParameters(operation);
            bindingOperation = binding.getBindingOperation( ...
                operation.getName, ...
                operation.getInput.getName, ...
                operation.getOutput.getName);
            if isempty(bindingOperation)
                % The binding does not define this operation.  Skip.
                continue
            end
            extensions = bindingOperation.getExtensibilityElements;
            if (extensions.size == 0)
                % No protocol information, so not a SOAP operation.  Skip.
                continue
            end
            extension = extensions.elementAt(0);
            if ~isa(extension,'javax.wsdl.extensions.soap.SOAPOperation')
                % Not a SOAP operation.  Skip.
                continue
            end
            soapOperation = extension;
            op = makeOperation(operation,parameters,typeData);
            if isempty(op)
                % makeOperation returns [] for operation styles other than
                % request-response.  Skip.
                continue
            end
            soapBody = bindingOperation.getBindingInput.getExtensibilityElements.elementAt(0);
            op.targetNamespaceURI = char(soapBody.getNamespaceURI);
            if isempty(op.targetNamespaceURI)
                % Fall back to the namespace of the WSDL definition itself.
                op.targetNamespaceURI = char(definition.getTargetNamespace);
            end
            op.soapAction = char(soapOperation.getSoapActionURI);
            ops = [ops op];
        end
        % If there are SOAP operations defined, add it to the list.
        if ~isempty(ops)
            services(end+1).name = name;
            services(end).wsdlLocation = wsdl;
            services(end).endpoint = char(port.getExtensibilityElements.elementAt(0).getLocationURI);
            services(end).style = char(bindingEntry.getBindingStyle.toString);
            services(end).methods = ops;
        end
    end
end

%===============================================================================
function typeData = getTypeData(symbolTable)
%getTypeData Collect the structure of the types declared in a WSDL-file.
%   typeData = getTypeData(symbolTable) returns a java.util.Hashtable keyed
%   by local type name (namespace stripped).  Each value is a
%   java.util.Vector holding either
%      - alternating element names and element type names, for a type with
%        contained element declarations,
%      - a single base or array type name, for a restriction or array, or
%      - nothing, for a type that was not recognized.
types = symbolTable.getTypes;
typeData = java.util.Hashtable;
for i = 1:types.size
    type = types.elementAt(i-1);
    name = char(type.getQName.toString);
    % TODO: preserve namespaces.
    name = getLocalName(name);    
    isType = isa(type,'org.apache.axis.wsdl.symbolTable.Type') || ...
        isa(type,'org.apache.axis.wsdl.symbolTable.CollectionElement');
    if (~isempty(type.getNode) && isType && isempty(type.getBaseType))
        v = org.apache.axis.wsdl.symbolTable.SchemaUtils.getContainedElementDeclarations(type.getNode,symbolTable);
        store = java.util.Vector;
        if ~isempty(v)
            % Use a loop variable distinct from the enclosing loop's.
            for k = 1:v.size
                x = v.elementAt(k-1);
                store.add(char(x.getName.toString));
                store.add(char(x.getType.getQName.toString));
            end
        end
        if (store.size == 0)
            % No element declarations; look for a restriction or an array.
            children = type.getNode.getChildNodes;
            for j = 1:children.getLength
                child = children.item(j-1);
                childName = char(child.getLocalName);
                if strcmp('restriction',childName)
                    % This is a restriction type, so use the base class.
                    store.add(char(child.getAttribute('base')));
                elseif strcmp('complexContent',childName)
                    grandChild = getFirstChildNode(child);
                    % A complexContent element without any non-text child
                    % has nothing to inspect.
                    if ~isempty(grandChild) && ...
                            strcmp('restriction',char(grandChild.getLocalName))
                        % This is an array, so extract the arrayType.
                        base = grandChild.getAttribute('base');
                        % See if there is more information on the attribute node.
                        attributeNode = getFirstChildNode(grandChild);
                        if ~isempty(attributeNode)
                            n = char(attributeNode.getAttribute('wsdl:arrayType'));
                            if ~isempty(n)
                                base = n;
                            end
                        end
                        store.add(base);
                    end
                end
            end
            if (store.size > 1)
                error('Error parsing WSDL-file');
            end
            % TODO: There are other types we're not parsing completely.
        end
        typeData.put(name,store);
    end
end


%===============================================================================
function op = makeOperation(operation,parameters,typeData)
%makeOperation Describe one WSDL operation as a MATLAB structure.
%   op = makeOperation(operation,parameters,typeData) returns [] unless the
%   operation style is request-response.  Otherwise op has the fields
%      methodName    - name of the operation
%      input         - structure array (name, type, isArray), or []
%      output        - structure (name, type, isArray), or []
%      documentation - help text, every line prefixed with '%   '

op = [];
isRequestResponse = isequal(operation.getStyle, ...
    javax.wsdl.OperationType.REQUEST_RESPONSE);
if ~isRequestResponse
    return
end

op.methodName = char(operation.getName);

% Start the M-file help with the operation's own documentation, if any.
docElement = operation.getDocumentationElement;
if isempty(docElement)
    doc = '';
else
    doc = sprintf('%s\n\n',char(docElement.getTextContent));
end

hasParameters = ~isempty(parameters);

% Calling parameters.
op.input = [];
if hasParameters
    parameterList = parameters.list;
    for k = 0:parameterList.size-1
        p = parameterList.get(k);
        argName = getLocalName(char(p.getQName.toString));
        [argType,argIsArray] = extractType( ...
            char(p.getType.getQName.toString),typeData);
        arg = struct('name',argName,'type',argType,'isArray',argIsArray);
        op.input = [op.input arg];
    end
end
doc = sprintf('%s  Input:\n',doc);
doc = buildDoc(doc,op.input,'    ');

% Return parameter.
op.output = [];
if hasParameters
    returnParam = parameters.returnParam;
    if ~isempty(returnParam)
        resultName = getLocalName(char(returnParam.getQName.toString));
        [resultType,resultIsArray] = extractType( ...
            char(returnParam.getType.getQName.toString),typeData);
        op.output = struct('name',resultName,'type',resultType, ...
            'isArray',resultIsArray);
    end
end
doc = sprintf('%s\n  Output:\n',doc);
doc = buildDoc(doc,op.output,'    ');

% Trim trailing whitespace and turn every line into a comment line.
doc = regexprep(doc,'\s*$','');
op.documentation = ['%   ' regexprep(doc,'\n','\n%   ')];


%===============================================================================
function doc = buildDoc(doc,x,prefix)
%buildDoc Append one help line per leaf parameter to DOC.
%   x is a structure array with fields name, type and isArray.  A structure
%   in the type field is a nested type; it is expanded recursively with the
%   parent's name (plus '(:)' for arrays) added to PREFIX.  Any other type
%   gives a line 'PREFIXname = (localtype)', with '{:}' after the name for
%   arrays.
for k = 1:length(x)
    item = x(k);
    isNested = isstruct(item.type);

    % Arrays of nested types are indexed with (), arrays of leaves with {}.
    if ~item.isArray
        suffix = '';
    elseif isNested
        suffix = '(:)';
    else
        suffix = '{:}';
    end

    if isNested
        doc = buildDoc(doc,item.type,[prefix item.name suffix '.']);
    else
        doc = sprintf('%s%s%s%s = (%s)\n', ...
            doc,prefix,item.name,suffix,getLocalName(item.type));
    end
end

%===============================================================================
function [nextType,isArray] = extractType(typeName,typeData,nesting)
%extractType Resolve a type name into a type description.
%   [nextType,isArray] = extractType(typeName,typeData) looks up the local
%   part of typeName in the typeData table and returns
%      - a string, for a simple type or a recursively defined type,
%      - the resolution of the base type, for a single-entry type, or
%      - a structure array with fields name, type and isArray, for a type
%        with several entries.
%   isArray is true when typeName ends in ']'.  The optional third argument
%   lists the type names already being resolved and is used to detect
%   recursive definitions.

if (nargin < 3)
    nesting = {};
end

% Defaults, also returned in the recursive case.
nextType = [];
isArray = false;
typeName = getLocalName(typeName);
typeInfo = typeData.get(typeName);

% Stop when this type is already being resolved further up the tree.
if ~isempty(strmatch(typeName,nesting,'exact'))
    warning('MATLAB:createClassFromWsdl:Unsupported', ...
        '"%s" is defined recursively.',typeName);
    nextType = typeName;
    return
end
nesting{end+1} = typeName;

% Number of entries recorded for this type; unknown types count as none.
if isempty(typeInfo)
    entryCount = 0;
else
    entryCount = typeInfo.size;
end

switch entryCount
    case 0
        if ~isempty(regexp(typeName,'\]$'))
            % It is an array; resolve the element type before the '['.
            bracket = find(typeName == '[');
            nextType = extractType(typeName(1:bracket-1),typeData,nesting);
            isArray = true;
        else
            % Simple type.
            nextType = restoreNamespace(typeName);
        end
    case 1
        % Restriction: the type is that of the base class.
        [nextType,isArray] = extractType(typeInfo.elementAt(0),typeData,nesting);
    otherwise
        % Nested type: entries alternate between field name and field type.
        for k = 1:ceil(entryCount/2)
            nextType(k).name = getLocalName(char(typeInfo.elementAt(2*k-2)));
            [nextType(k).type,nextType(k).isArray] = ...
                extractType(typeInfo.elementAt(2*k-1),typeData,nesting);
        end
end


%===============================================================================
function s = getLocalName(s)
%getLocalName Remove everything up to and including the last '}' or ':'.
%   For example, '{http://www.w3.org/2001/XMLSchema}string' and
%   'xsd:string' both become 'string'.
qualifier = '.*[}:]';
s = regexprep(s,qualifier,'');

%===============================================================================
function s = restoreNamespace(s)
%restoreNamespace Put the XML Schema namespace back on known simple types.
%   Only 'string' is recognized; any other name is returned unchanged.
ns = '';
switch s
    case 'string'
        ns = '{http://www.w3.org/2001/XMLSchema}';
end
s = [ns s];

%===============================================================================
function child = getFirstChildNode(node)
%getFirstChildNode Return the first child of NODE that is not a text node.
%   Returns empty when NODE has no such child.
child = node.getFirstChild;
while true
    if isempty(child)
        break
    end
    if (child.getNodeType ~= child.TEXT_NODE)
        break
    end
    % Skip text nodes, such as the whitespace between elements.
    child = child.getNextSibling;
end


%===============================================================================
%===============================================================================
function makeconstructor(R)
%makeconstructor Create the class directory, constructor and display method.
%   R is one element of the structure array returned by parseWsdl.  Any
%   existing directory @R.name in the current directory is removed and
%   created again.

% Read the constructor template that lives next to this file.
templateDir = fullfile(fileparts(mfilename('fullpath')),'private');
template = textread(fullfile(templateDir,'constructor.mtl'), ...
    '%s','delimiter','\n','whitespace','');

% Fill in the placeholders.
tokens = {'$CLASSNAME$','$ENDPOINT$','$WSDLLOCATION$'};
values = {R.name,R.endpoint,R.wsdlLocation};
for k = 1:length(tokens)
    template = strrep(template,tokens{k},values{k});
end

% Start from a fresh class directory.  Requesting outputs from rmdir keeps
% it from raising an error when the directory does not exist.
classDir = ['@' R.name];
[removed msg msgId] = rmdir(classDir,'s');
mkdir(classDir)

writemfile([classDir filesep R.name '.m'],template);

% Also create a display method
displayCode = {'function display(obj)','disp(struct(obj))'};
writemfile([classDir filesep 'display.m'],displayCode);


%===============================================================================
function makemethods(R)
%makemethods Write one method M-file per web service operation.
%   R is one element of the structure array returned by parseWsdl.  Each
%   element of R.methods is merged into the genericmethod.mtl template and
%   written to the @R.name directory.

% Read in the template.
tf = fullfile(fileparts(mfilename('fullpath')),'private','genericmethod.mtl');
originalTemplate = textread(tf,'%s','delimiter','\n','whitespace','');

% Legal, unique file names that do not clash with the constructor.
methodNames = genvarname({R.methods.methodName},R.name);
for iMethod = 1:length(R.methods)
    method = R.methods(iMethod);
    
    % Left-hand side of the generated function line.  Always assigned, so
    % a value from a previous method can never leak into this one.
    if isempty(method.output)
        outputNames = {};
    else
        outputNames = {method.output.name};
    end
    legalOutputNames = genvarname(outputNames);
    if isempty(legalOutputNames)
        outputString = '';
    else
        outputString = sprintf('%s = ',legalOutputNames{1});
    end

    % Argument list of the generated function line.  'obj' is reserved for
    % the object itself, so a web service parameter with that name must be
    % given a different variable name.
    if isempty(method.input)
        inputNames = {};
        legalInputNames = {};
    else
        inputNames = {method.input.name};
        legalInputNames = genvarname(inputNames,{'obj'});
    end
    if isempty(legalInputNames)
        inputString = '(obj)';
    else
        inputString = sprintf('%s,',legalInputNames{:});
        inputString = sprintf('(obj,%s)',inputString(1:end-1));
    end

    % Write out the parameter name, input name, and type mapping.
    s = sprintf('values = { ...\n');
    for i = 1:length(legalInputNames)
        s = sprintf('%s   %s, ...\n', ...
            s, ...
            legalInputNames{i});
    end
    s = sprintf('%s   };\n',s);
    s = sprintf('%snames = { ...\n',s);
    for i = 1:length(inputNames)
        s = sprintf('%s   ''%s'', ...\n', ...
            s, ...
            inputNames{i});
    end
    s = sprintf('%s   };\n',s);
    s = sprintf('%stypes = { ...\n',s);
    % NOTE(review): when the first input has a nested (structure) type, the
    % parameter names are written in place of the types; confirm against
    % the consumer of the generated 'types' variable.
    if isempty(method.input)
        inputTypes = {};
    elseif isstruct(method.input(1).type)
        inputTypes = {method.input.name};
    else
        inputTypes = {method.input.type};
    end
    for i = 1:length(inputNames)
        if isstruct(inputTypes{i})
            % Nested types have no single type name.
            t = '';
        else
            t = inputTypes{i};
        end
        s = sprintf('%s   ''%s'', ...\n', ...
            s, ...
            t);
    end
    parameterDefinition = sprintf('%s   };',s);

    % Merge everything into a fresh copy of the template.
    replacements = {'$METHODNAME$',method.methodName,...
        '$TARGETNAMESPACEURI$',method.targetNamespaceURI,...
        '$SOAPACTION$',method.soapAction,...
        '$OUTPUT$',outputString,...
        '$INPUT$',inputString, ...
        '$PARAMETERDEFINITION$',parameterDefinition, ...
        '$STYLE$',R.style, ...
        '$DOCUMENTATION$',method.documentation};
    template = originalTemplate;
    for i = 1:2:length(replacements)
        template = strrep(template,replacements{i},char(replacements{i+1}));
    end
    writemfile(['@' R.name filesep methodNames{iMethod} '.m'],template);
end


%===============================================================================
function status = writemfile(fname,C)
%writemfile Write a cell array of strings to a file, one string per line.
%   status = writemfile(fname,C) writes each string of C, followed by a
%   newline, to the file FNAME relative to the current directory.  It
%   returns the status of fclose and raises an error if the file cannot be
%   opened, written completely, or closed.

C = cellstr(C);
fid = fopen([pwd filesep fname],'w');
if (fid == -1)
    % Report the failure here instead of letting fprintf fail on an
    % invalid file identifier.
    error(['Error writing file ' fname])
end

count = 0;
try
    for i = 1:length(C)
        count = count + fprintf(fid,'%s\n',C{i});
    end
catch
    % Do not leak the file handle when a write fails.
    fclose(fid);
    rethrow(lasterror)
end
status = fclose(fid);

% fprintf reports bytes, which can exceed the number of characters when the
% file encoding is multi-byte, so only a shortfall indicates a failure.
expected = sum(cellfun('length',C)) + length(C);
if (count < expected) || (status == -1)
    error(['Error writing file ' fname])
end
