/*
 * ISMEMBC2.CPP
 *
 * Helper function for ISMEMBER.M.
 *
 * This MEX-file handles the work for the ISMEMBER(A,S) syntax.
 * ISMEMBER must make sure that A and S are of the same class and that S is sorted 
 * by real part before calling this function.
 * This function returns the location of the found members.
 *
 * MATLAB Usage:  B = ISMEMBC2(A,S)
 *
 * Copyright 1984-2004 The MathWorks, Inc. 
 * $Revision: 1.1.6.1 $  $Date: 2004/12/06 16:35:37 $
 */

static char rcsid[] = "$Id: ismembc2.cpp,v 1.1.6.1 2004/12/06 16:35:37 batserve Exp $";

#include "mex.h"

void ValidateInputs(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[],
                    const mxArray **A, const mxArray **S)
{
    const int NumInputs = 2;
    const int NumOutputs = 1;
    if (nrhs > NumInputs) mexErrMsgIdAndTxt("MATLAB:ismembc2:TooManyInputs",
	"Too many input arguments.");  

    if (nrhs < NumInputs) mexErrMsgIdAndTxt("MATLAB:ismembc2:TooFewInputs",
	"Too few input arguments.");

    if (nlhs > NumOutputs) mexErrMsgIdAndTxt("MATLAB:ismembc2:TooManyOutputs",
	"Too many output arguments.");
    
    *A = prhs[0];
    if (!(mxIsNumeric(*A) || mxIsChar(*A) || mxIsLogical(*A)))
    {
        mexErrMsgIdAndTxt("MATLAB:ismembc2:InvalidA",
        "A must be a numeric,logical or char array.");
    }
    
    *S = prhs[1];
    if (mxGetClassID(*A) != mxGetClassID(*S))
    {
        mexErrMsgIdAndTxt("MATLAB:ismembc2:InvalidInputClass",
        "The set A and set S must have the same class.");
    }
}

/*
 * Where is the value (realPart + i*imagPart) found in the list of values in the
 * real-part and imaginary-part arrays prSet and piSet?
 *
 * Parameters:
 *   realPart  - pointer to the real part of the value to look up.
 *   imagPart  - pointer to its imaginary part, or NULL when the value is real.
 *   prSet     - real parts of the set; must be sorted in ascending order,
 *               because a binary search is run over it.
 *   piSet     - imaginary parts of the set, or NULL when the set is real.
 *   numelSet  - number of elements in prSet (and in piSet when non-NULL).
 *
 * Returns the one-based position of the matching set element as a double, or
 * 0.0 when there is no match. When several elements match, the one with the
 * highest index is reported, since the candidates are scanned from the top
 * down.
 *
 * A value that compares unordered with the set bounds (e.g. a floating-point
 * NaN in *realPart, prSet[0] or prSet[numelSet-1]) fails the range check
 * below, so 0.0 is returned in that case.
 */
template <class T> double IsInSet(T* realPart, T* imagPart, T* prSet, T* piSet, int numelSet)
{
    double found = 0.0;
    int lower;
    int upper;
    int midpoint;
    int k;

    if (realPart == NULL || prSet == NULL)
    {
        /* This should never happen */
        return found;
    }

    if (numelSet <= 0)
    {
        /* Nothing can be a member of an empty set. */
        return found;
    }

    /* Reject values whose real part lies outside the range covered by the set. */
    if (!((*realPart >= prSet[0]) && (*realPart <= prSet[numelSet - 1])))
    {
        return found;
    }

    /*
     * Binary search. It maintains prSet[lower] <= *realPart and stops once
     * lower and upper are at most one apart; every element above upper then
     * has a real part greater than *realPart.
     */
    lower = 0;
    upper = numelSet - 1;
    while ((upper - lower) > 1)
    {
        /*
         * Middle of the current region, computed from the region width so
         * that the sum of two large indices cannot overflow int.
         */
        midpoint = lower + (upper - lower) / 2;

        /*
         * How we shrink the region depends on whether realPart is in
         * the upper half or the lower half.
         */
        if (*realPart >= prSet[midpoint])
        {
            lower = midpoint;
        }
        else
        {
            upper = midpoint;
        }
    }

    /*
     * There may be more than one value in the Set that has the same real part,
     * so we walk downwards from upper until we reach a Set value whose real
     * part is lower than realPart (or run off the start of the Set).
     */
    for (k = upper; (k >= 0) && (prSet[k] >= *realPart); k--)
    {
        if (!(*realPart == prSet[k]))
        {
            /* Real part is greater than the one we want; keep walking down. */
            continue;
        }

        if (piSet != NULL)
        {
            /*
             * The set is complex, so we have to check the imaginary part
             * explicitly. A real value only matches an element whose
             * imaginary part is 0.
             */
            if ((imagPart != NULL && *imagPart == piSet[k]) ||
                (imagPart == NULL && piSet[k] == 0))
            {
                found = k + 1;
                break;
            }
        }
        else
        {
            /*
             * The set is real, so the value can only match if it is real
             * itself or its imaginary part is 0.
             */
            if (imagPart == NULL || *imagPart == 0.0)
            {
                found = k + 1;
                break;
            }
        }
    }

    return found;
}

/*
 * Look up each of the numelA elements of A in the set S and store the result
 * of IsInSet for element k in B[k].
 *
 *   B        - output buffer with room for numelA doubles.
 *   prA, piA - real and imaginary data of A, reinterpreted as arrays of T;
 *              piA may be NULL when A has no imaginary part.
 *   prS, piS - real and imaginary data of S; their type selects T.
 *   numelA   - number of elements in A.
 *   numelS   - number of elements in S.
 */
template <class T> void checkSet(double* B, void *prA, void *piA, T *prS, T *piS, int numelA, int numelS)
{
    T *const reA = static_cast<T*>(prA);
    T *const imA = static_cast<T*>(piA);

    if (reA == NULL)
    {
        /* Should not happen */
        return;
    }

    for (int idx = 0; idx < numelA; ++idx)
    {
        /* Only offset the imaginary pointer when A actually has imaginary data. */
        T *const imElem = (imA != NULL) ? (imA + idx) : imA;
        B[idx] = IsInSet(reA + idx, imElem, prS, piS, numelS);
    }
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    int numDimsA;
    const int *dimsA;
    int numelA;
    int numelS;
    mxArray *B = NULL;
    const mxArray *A = NULL;
    const mxArray *S = NULL;
    void *prS;
    void *piS;
    double *plB;
    void *prA;
    void *piA;
    mxClassID classA;
    
    ValidateInputs(nlhs, plhs, nrhs, prhs, &A, &S);

    numDimsA = mxGetNumberOfDimensions(A);
    dimsA = mxGetDimensions(A);

    B = mxCreateNumericArray(numDimsA, dimsA, mxDOUBLE_CLASS, mxREAL);
    prS = mxGetData(S);
    piS = mxGetImagData(S);
    numelS = mxGetNumberOfElements(S);
    numelA = mxGetNumberOfElements(A);
    plB = (double*) mxGetData(B);
    prA = mxGetData(A);
    piA = mxGetImagData(A);
    classA = mxGetClassID(A);
    
    switch (classA) {
		case mxUINT8_CLASS:
		case mxLOGICAL_CLASS:
		{
			checkSet(plB, prA, piA, (uint8_T*)prS, (uint8_T*)piS, numelA, numelS);
			break;
		}
		case mxINT8_CLASS: 
		{
			checkSet(plB, prA, piA, (int8_T*)prS, (int8_T*)piS, numelA, numelS);
			break;
		}
		case mxCHAR_CLASS:
		case mxUINT16_CLASS: 
		{
			checkSet(plB, prA, piA, (uint16_T*)prS, (uint16_T*)piS, numelA, numelS);
			break;
		}
		case mxINT16_CLASS: 
		{
			checkSet(plB, prA, piA, (int16_T*)prS, (int16_T*)piS, numelA, numelS);
			break;
		}
		case mxUINT32_CLASS: 
		{
			checkSet(plB, prA, piA, (uint32_T*)prS, (uint32_T*)piS, numelA, numelS);
			break;
		}
		case mxINT32_CLASS: 
		{
			checkSet(plB, prA, piA, (int32_T*)prS, (int32_T*)piS, numelA, numelS);
			break;
		}
		case mxUINT64_CLASS: 
		{
			checkSet(plB, prA, piA, (uint64_T*)prS, (uint64_T*)piS, numelA, numelS);
			break;
		}
		case mxINT64_CLASS: 
		{
			checkSet(plB, prA, piA, (int64_T*)prS, (int64_T*)piS, numelA, numelS);
			break;
		}
		case mxDOUBLE_CLASS: 
		{
			checkSet(plB, prA, piA, (real64_T*)prS, (real64_T*)piS, numelA, numelS);
			break;
		}
		case mxSINGLE_CLASS: 
		{
			checkSet(plB, prA, piA, (real32_T*)prS, (real32_T*)piS, numelA, numelS);
			break;
		}
	}
    plhs[0] = B;
}
