#!/usr/bin/python

import os
import sys
import gettext
# gettext text domain for this program's translatable messages.
PROGNAME="system-config-bind"
gettext.bindtextdomain(PROGNAME, "/usr/share/locale")
gettext.textdomain(PROGNAME)

# Install the `_` translation function into builtins; every user-visible
# message in this module is passed through `_`.  If installation fails with
# IOError, fall back to `unicode` so `_("text")` still returns the text,
# untranslated.  (The third install() argument is presumably Python 2's
# `unicode` flag -- confirm against the gettext documentation.)
try:
    gettext.install(PROGNAME, "/usr/share/locale", 1)
except IOError:
    import __builtin__
    __builtin__.__dict__['_'] = unicode

from Conf import *
# Boolean aliases used throughout this module (pre-bool Python style).
TRUE=1
FALSE=not TRUE

# Host / zone names: letters, digits, '.', '-' and the '*' wildcard.
hname_re = re.compile(r'^[\*a-zA-Z0-9.\-]+$')

# Dotted-quad IPv4 address, optionally followed by ".in-addr.arpa" (with an
# optional trailing dot).  The groups only bound the digit count; checkIpNum()
# enforces the 0-255 range.
ip_re = re.compile(r'^([0-2]?[0-9]?[0-9])\.([0-2]?[0-9]?[0-9])\.([0-2]?[0-9]?[0-9])\.([0-2]?[0-9]?[0-9])(?:\.in-addr\.arpa\.?)?$')

# One to four octets followed by ".in-addr.arpa", e.g. "1.168.192.in-addr.arpa"
# or "4.3.2.1.in-addr.arpa".  Unused trailing groups are None.  (The old
# pattern had a stray "\." inside the third group, so four-octet names could
# never match.)
revip_re = re.compile(r'^([0-2]?[0-9]?[0-9])(?:\.([0-2]?[0-9]?[0-9]))?(?:\.([0-2]?[0-9]?[0-9]))?(?:\.([0-2]?[0-9]?[0-9]))?\.in-addr\.arpa$')


def testName(value):
    """Validate a zone name.

    Returns None when *value* matches hname_re; otherwise raises TestError
    with a translated message naming the rejected value.
    """
    if hname_re.match(value):
        return
    raise TestError(_("Zone name %s is not a valid domain name.") % value)
		
class TestError(Exception):
    """Validation error raised by the test* helpers of this module.

    The single optional argument is the (already translated) message; it
    is available as str(error) and as error.args[0].
    """

    def __init__(self, args=None):
        # Let Exception store the message.  Assigning a string straight to
        # self.args (as this class used to) makes Python >= 2.5 split it
        # into a tuple of characters, so str(e) came out garbled, and
        # assigning None raised TypeError.
        if args is None:
            Exception.__init__(self)
        else:
            Exception.__init__(self, args)

def testServedBy(value):
    pass

def testFile(value):
    pass

def testHost(value):
    """Validate a record's host field.

    Accepts anything matching hname_re as well as the bare origin
    shorthand '@'; otherwise raises TestError.
    """
    acceptable = hname_re.match(value) or value == '@'
    if acceptable:
        return
    raise TestError(_("%s is not a valid hostname.") % value)
    
def testIp(value):
    """Validate an IPv4 address via checkIpNum(); raise TestError if it fails."""
    if checkIpNum(value):
        return
    raise TestError(_("%s is not a valid IP address.") % value)

def checkIpNum(value):
    """Return TRUE if *value* is a dotted-quad IPv4 address with octets 0-255.

    The forms accepted by ip_re (including an ".in-addr.arpa" suffix) are
    allowed.  A non-string argument yields FALSE instead of raising.
    """
    try:
        found = ip_re.match(value)
    except TypeError:
        # e.g. None or a number handed in place of a string
        return FALSE
    if not found:
        return FALSE

    octets = found.groups()
    if len(octets) != 4:
        return FALSE
    for octet in octets:
        if octet is None or not (0 <= int(octet) <= 255):
            return FALSE
    return TRUE


def checkRevIpNum(value):
    """Return TRUE if *value* matches revip_re and every matched octet is 0-255.

    Octet groups the pattern left unmatched (None) are skipped.  A
    non-string argument yields FALSE instead of raising.
    """
    try:
        found = revip_re.match(value)
    except TypeError:
        return FALSE
    if not found:
        return FALSE

    octets = found.groups()
    if not octets:
        return FALSE
    for octet in octets:
        if octet is not None and not (0 <= int(octet) <= 255):
            return FALSE
    return TRUE

def checkTTL(ttl):
    """Return TRUE unless *ttl* compares below zero, in which case FALSE."""
    # TODO: add more TTL checking code here (e.g. an upper bound).
    if not ttl < 0:
        return TRUE
    return FALSE

# SOA
#  Data-oriented holder for a zone's start-of-authority record; the Zone
#  class below uses it when making changes to the /var/named/ zone files.
class SOA:
    """Start-of-authority (SOA) record of a zone.

    Name fields:
      pns     - owner name written at the start of the record ("@" by default)
      server  - first name field of the record
      contact - second name field of the record
    The numeric fields are kept in the list self.SOA, in output order:
      [serial, refresh, retry, expire, ttl]
    The test* methods return None for an acceptable value and raise
    TestError otherwise.
    """

    def __init__(self, pns="@", server="localhost", contact="root"):
        self.pns, self.server, self.contact = pns, server, contact
        # serial, refresh, retry, expire, ttl
        self.SOA = [1, 28800, 14400, 3600000, 86400]

    def load(self, SOAList):
        """Adopt SOAList (the list itself, not a copy) as the numeric fields."""
        self.SOA = SOAList

    # -- name fields ---------------------------------------------------

    def getPNS(self):
        return self.pns

    def setPNS(self, pns):
        self.pns = pns

    def getServer(self):
        return self.server

    def setServer(self, server):
        self.server = server

    def getContact(self):
        return self.contact

    def setContact(self, contact):
        self.contact = contact

    # -- numeric fields (positions in self.SOA) -------------------------

    def getSerial(self):
        return self.SOA[0]

    def setSerial(self, i):
        self.SOA[0] = i

    def incSerial(self):
        self.SOA[0] += 1

    def getRefresh(self):
        return self.SOA[1]

    def setRefresh(self, refresh):
        self.SOA[1] = refresh

    def getRetry(self):
        return self.SOA[2]

    def setRetry(self, retry):
        self.SOA[2] = retry

    def getExpire(self):
        return self.SOA[3]

    def setExpire(self, expire):
        self.SOA[3] = expire

    def getTTL(self):
        return self.SOA[4]

    def setTTL(self, ttl):
        self.SOA[4] = ttl

    # -- validation ----------------------------------------------------

    def testPNS(self, value):
        """Require a non-empty value that is '@' or a dot-terminated hostname."""
        if not value:
            raise TestError(_("Primary Nameserver (SOA) not defined."))
        if value == "@":
            return
        if len(value) < 2 or value[-1] != '.':
            raise TestError(_("Primary Nameserver (SOA) '%s' has no . at the end. You must use a full hostname.") % value)
        if not hname_re.match(value):
            raise TestError(_("Primary Nameserver (SOA) '%s' is not a valid hostname or IP address.") % value)

    def testSerial(self, value):
        if value <= 0:
            raise TestError(_("Serial number must be >= 1."))

    def testRefresh(self, value):
        if not checkTTL(value):
            raise TestError(_("Refresh value must be >= 0."))

    def testRetry(self, value):
        if not checkTTL(value):
            raise TestError(_("Retry value must be >= 0."))

    def testExpire(self, value):
        if not checkTTL(value):
            raise TestError(_("Expire value must be >= 0."))

    def testTTL(self, value):
        if not checkTTL(value):
            raise TestError(_("TTL value must be >= 0."))

    # -- output --------------------------------------------------------

    def out(self):
        """Render the record as zone-file text, one numeric field per line."""
        pieces = ["%s\tIN\tSOA\t%s\t%s\t(\n" % (self.getPNS(), self.getServer(), self.getContact())]
        rows = (("\t\t\t\t%d ; serial\n", self.getSerial),
                ("\t\t\t\t%d ; refresh\n", self.getRefresh),
                ("\t\t\t\t%d ; retry\n", self.getRetry),
                ("\t\t\t\t%d ; expire\n", self.getExpire),
                ("\t\t\t\t%d ; ttl\n", self.getTTL))
        for template, getter in rows:
            pieces.append(template % getter())
        pieces.append("\t\t\t\t)\n\n\n")
        return "".join(pieces)

    
class ZoneList:
    """Ordered collection of zone records (objects providing getType())."""

    def __init__(self):
        self.list = []

    def append(self, x):
        """Add record x at the end."""
        self.list.append(x)

    def getList(self, type=""):
        """Return a new list of the records whose getType() equals type.

        The default, an empty type, selects every record.
        """
        return [rec for rec in self.list if rec.getType() == type or type == ""]

    def remove(self, rec):
        """Drop record rec (ValueError if it is not present)."""
        self.list.remove(rec)
    
class ZoneRec:
    """One resource-record line of a zone file.

    Fields, in output order:
      f0     - owner name (may be empty)
      f1     - TTL ("" when omitted)
      f2     - record type, e.g. "A" or "MX"
      f3, f4 - record data (either may be "")
    The class is always written as IN.  parent is the container that owns
    the record and implements unlink(rec) (the Zone, or None).
    """

    def __init__(self,f0,f1,f2,f3,f4,parent):
        self.f0=f0
        self.f1=f1
        self.f2=f2
        self.f3=f3
        self.f4=f4
        self.parent=parent

    def getType(self):
        return self.f2

    def getTtl(self):
        return self.f1

    def setTtl(self,ttl):
        self.f1=ttl

    def unlink(self):
        """Ask the parent container to drop this record.

        (Previously defined twice with identical bodies; the duplicate is gone.)
        """
        self.parent.unlink(self)

    def out(self):
        """Return the record as one tab-separated zone-file line (no newline)."""
        return "%s\t%s\tIN\t%s\t%s\t%s" % (self.f0,self.f1,self.f2,self.f3,self.f4)


class A(ZoneRec):
    """Address record: f0 holds the host name, f3 the IPv4 address."""

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

    # host name (owner)
    def getHost(self):
        return self.f0

    def setHost(self, host):
        self.f0 = host

    def getName(self):
        return self.getHost()

    # address (first data field)
    def getIp(self):
        return self.f3

    def setIp(self, ip):
        self.f3 = ip

    # validation delegates to the module-level helpers of the same name
    def testHost(self, value):
        testHost(value)

    def testIp(self, value):
        testIp(value)

    def get_str(self):
        """Short translated description, e.g. "host 'www'"."""
        return _("host '%s'") % self.getHost()
		
class PTR(ZoneRec):
    """Pointer record: f0 holds the owner (the IP side), f3 the target host."""

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

    # target host (first data field)
    def getHost(self):
        return self.f3

    def setHost(self, host):
        self.f3 = host

    def getName(self):
        return self.getHost()

    # owner name, i.e. the address part of the reverse zone
    def getIp(self):
        return self.f0

    def setIp(self, ip):
        self.f0 = ip

    # validation delegates to the module-level helpers of the same name
    def testHost(self, value):
        testHost(value)

    def testIp(self, value):
        testIp(value)

    def get_str(self):
        """Short translated description, e.g. "host 'www.example.com.'"."""
        return _("host '%s'") % self.getHost()
		
class CNAME(ZoneRec):
    """Alias record: f0 holds the alias (owner), f3 the canonical host."""

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

    # canonical host (first data field)
    def getHost(self):
        return self.f3

    def setHost(self, host):
        self.f3 = host

    def getName(self):
        return self.getHost()

    # alias (owner name)
    def getAlias(self):
        return self.f0

    def setAlias(self, alias):
        self.f0 = alias

    def testAlias(self, value):
        """Validate an alias with the module-level hostname check."""
        testHost(value)

class HINFO(ZoneRec):
    """Host-information record: f0 holds the host, f3 the first data field.

    NOTE(review): the first HINFO data field is conventionally the CPU
    description; the "alias" accessor names are kept for existing callers.
    """

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

    # host (owner name)
    def getHost(self):
        return self.f0

    def setHost(self, host):
        self.f0 = host

    def getName(self):
        return self.getHost()

    # first data field
    def getAlias(self):
        return self.f3

    def setAlias(self, alias):
        self.f3 = alias

    def testAlias(self, value):
        """Validate with the module-level hostname check."""
        testHost(value)

class SRV(ZoneRec):
    """Service-location record; fields are reached through ZoneRec (f0..f4)."""

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

class MX(ZoneRec):
    """Mail-exchanger record: f0 owner, f3 priority (kept as a string), f4 server."""

    def __init__(self, f1, f2, f3, f4, f5, parent=None):
        ZoneRec.__init__(self, f1, f2, f3, f4, f5, parent)

    # owner name
    def getName(self):
        return self.f0

    def setName(self, name):
        self.f0 = name

    # priority: stored as text, exposed as int
    def getPriority(self):
        return int(self.f3)

    def setPriority(self, priority):
        self.f3 = str(priority)

    # mail server (second data field)
    def getServer(self):
        return self.f4

    def setServer(self, server):
        self.f4 = server

class NS(ZoneRec):
    """Name-server record: f0 holds the owner ("served by"), f3 the server host."""

    def __init__(self, f0, f1, f2, f3, f4, parent=None):
        ZoneRec.__init__(self, f0, f1, f2, f3, f4, parent)

    # name-server host; an empty field is reported as the origin "@"
    def getHost(self):
        if self.f3 == "":
            return "@"
        return self.f3

    def setHost(self, host):
        self.f3 = host

    def getShortHost(self):
        """First dot-separated label of getHost()."""
        return self.getHost().split(".")[0]

    def getName(self):
        return self.getShortHost()

    # owner name
    def getServedBy(self):
        return self.f0

    def setServedBy(self, served_by):
        self.f0 = served_by
        
class Zone(Conf):
    """Editable in-memory copy of one zone file.

    Construction reads and parses the file:
      - "$ORIGIN" lines are kept verbatim (as token lists) in self.origins;
      - the "$TTL" value, if present, is stored in self.TTL;
      - the SOA record becomes self.SOA;
      - every other record is stored in self.zoneList as a ZoneRec subclass.
    out() renders the zone as text and save() writes it back to disk.
    """

    # Record types add() can locate on a line that omits the "IN" class.
    # Matching is case-sensitive so owner names such as "ns" or "mx" are not
    # mistaken for a type.
    _KNOWN_TYPES = ("A", "AAAA", "AFSDB", "CNAME", "DNAME", "DNSKEY", "DS",
                    "HINFO", "KEY", "LOC", "MX", "NAPTR", "NS", "NSEC",
                    "PTR", "RP", "RRSIG", "SIG", "SPF", "SRV", "SSHFP", "TXT")

    # Seconds per time-unit suffix accepted by translate() (any letter case).
    _TIME_UNITS = {"S": 1, "M": 60, "H": 60*60, "D": 60*60*24,
                   "W": 60*60*24*7}

    def __init__(self, name, filename):
        Conf.__init__(self, filename, commenttype=';')
        self.SOA = SOA()
        self.name = name
        self.read()
        self.rewind()
        self.zones = []
        self.origins = []
        self.zoneList = ZoneList()
        while self.findnextcodeline():
            # Tokenise the current line without its trailing ";" comment.
            val = self.getline().split(";")[0].split()
            self.nextline()
            if not val:
                continue
            if val[0] == "$ORIGIN":
                self.origins.append(val)
            elif val[0] == "$TTL":
                self.TTL = val[1]
            elif "SOA" in val:
                self._readSOA(val)
            else:
                self.add(val)

    def _readSOA(self, val):
        """Parse the SOA record whose first line is token list val into self.SOA.

        The numeric fields may continue over later lines while a "(" is open;
        those lines are consumed from the Conf cursor.
        """
        i = val.index("SOA")
        if val[0] != "IN":
            pns = val[0]
        else:
            pns = ""
        self.SOA = SOA(pns, val[i+1], val[i+2])
        val = val[i+3:]
        SOAList = []
        inParens = FALSE
        while 1:
            if not val:
                # A line break only continues the record inside "( ... )";
                # stop at end of file as well instead of looping forever.
                if not inParens or not self.findnextcodeline():
                    break
                # Skip blank/comment lines first, then read the code line.
                val = self.getline().split(";")[0].split()
                self.nextline()
                continue
            tok = val[0]
            val = val[1:]
            if tok[0] == "(":
                inParens = TRUE
                tok = tok[1:]          # "(" alone, or glued as in "(2004010101"
                if not tok:
                    continue
            if tok == ")":
                break
            SOAList.append(self.translate(tok))
            if tok[-1] == ")":         # closing paren glued to the last value
                break
        self.SOA.load(SOAList)

    def translate(self, val):
        """Convert a time value such as "3600", "1D", "1w2d" or "86400)" to seconds.

        A trailing ")" is ignored.  Unit letters S/M/H/D/W are accepted in
        either case and may be combined.  Raises ValueError for anything else.
        """
        if val[-1:] == ")":
            val = val[:-1]
        if val[-1:].upper() not in self._TIME_UNITS:
            return int(val)
        total = 0
        digits = ""
        for ch in val.upper():
            if ch.isdigit():
                digits = digits + ch
            elif ch in self._TIME_UNITS and digits:
                total = total + int(digits) * self._TIME_UNITS[ch]
                digits = ""
            else:
                raise ValueError("invalid time value: %r" % val)
        return total

    def unlink(self, rec):
        """Remove record rec, of any type, from this zone.

        Raises ValueError if rec is not in the zone.
        """
        self.zoneList.remove(rec)

    def getName(self):
        return self.name

    def isReverse(self):
        """True when the zone file name contains ".in-addr.arpa"."""
        return ".in-addr.arpa" in self.filename

    def addPTR(self, name, ip):
        """Add "ip PTR name" to a reverse zone; TestError if ip already has one.

        Does nothing for a forward zone.
        """
        if not self.isReverse():
            return
        for rec in self.getPTRList():
            if rec.getIp() == ip:
                raise TestError(_("%s already exists") % ip)
        self.zoneList.append(PTR(ip, "", "PTR", name, "", self))

    def addA(self, name, ip):
        """Add "name A ip"; TestError if an A record for name already exists."""
        for rec in self.getAList():
            if rec.getName() == name:
                raise TestError(_("%s already exists") % name)
        self.zoneList.append(A(name, "", "A", ip, "", self))

    def addMX(self, name, priority, server):
        """Add "name MX priority server"; TestError if name already has an MX."""
        for rec in self.getMXList():
            if rec.getName() == name:
                raise TestError(_("%s already exists") % name)
        self.zoneList.append(MX(name, "", "MX", priority, server, self))

    def addCNAME(self, alias, host):
        """Add "alias CNAME host"; TestError if the alias already exists."""
        for rec in self.getCNAMEList():
            if rec.getAlias() == alias:
                raise TestError(_("%s already exists") % alias)
        self.zoneList.append(CNAME(alias, "", "CNAME", host, "", self))

    def modify(self, name, ip):
        """Re-point an existing record.

        Reverse zone: the PTR record for ip is changed to name host `name`.
        Forward zone: the A record `name` is changed to address `ip`.
        Raises TestError when no matching record exists.
        """
        if self.isReverse():
            for rec in self.getPTRList():
                if rec.getIp() == ip:
                    rec.setHost(name)
                    return
            raise TestError(_("%s does not exist") % ip)
        for rec in self.getAList():
            if rec.getName() == name:
                rec.setIp(ip)
                return
        raise TestError(_("%s does not exist") % name)

    # -- SOA pass-throughs ------------------------------------------------

    def setServer(self, server):
        self.SOA.setServer(server)

    def getServer(self):
        return self.SOA.getServer()

    def setContact(self, contact):
        self.SOA.setContact(contact)

    def getContact(self):
        return self.SOA.getContact()

    def getSOA(self):
        return self.SOA

    def getTTL(self):
        return self.SOA.getTTL()

    def setTTL(self, ttl):
        # Previously called self.setTTL(), recursing until the stack overflowed.
        self.SOA.setTTL(ttl)

    def getRefresh(self):
        return self.SOA.getRefresh()

    def setRefresh(self, refresh):
        self.SOA.setRefresh(refresh)

    def getRetry(self):
        return self.SOA.getRetry()

    def setRetry(self, retry):
        self.SOA.setRetry(retry)

    def getExpire(self):
        return self.SOA.getExpire()

    def setExpire(self, expire):
        self.SOA.setExpire(expire)

    def getSerial(self):
        return self.SOA.getSerial()

    def setSerial(self, i):
        self.SOA.setSerial(i)

    def incSerial(self):
        self.SOA.incSerial()

    # -- record access ----------------------------------------------------

    def getNSList(self):
        return self.zoneList.getList("NS")

    def addNS(self, host, served_by):
        self.zoneList.append(NS(served_by, "", "NS", host, "", self))

    def getNS(self, val):
        """Return the first NS record whose server host (getHost()) equals val, or None."""
        for rec in self.getNSList():
            if rec.getHost() == val:
                return rec
        return None

    def getAList(self):
        return self.zoneList.getList("A")

    def add(self, val):
        """Build a record from one tokenised zone-file line and append it.

        Expected layout: [name [ttl]] [IN] type data...  When "IN" is
        missing, the first token listed in _KNOWN_TYPES is taken as the
        type.  Data tokens beyond the second are joined into the last field
        so that multi-token data (SRV, TXT, ...) is written back intact.
        """
        if "IN" in val:
            pos = val.index("IN")
            head = val[:pos]
            rest = val[pos+1:]
        else:
            pos = -1
            for idx in range(len(val)):
                if val[idx] in self._KNOWN_TYPES:
                    pos = idx
                    break
            if pos < 0:
                # Unrecognised layout: owner first, the remainder as type + data.
                pos = min(1, len(val))
            head = val[:pos]
            rest = val[pos:]

        name = ""
        ttl = ""
        if head:
            name = head[0]
            if len(head) > 1:
                ttl = head[1]

        # Normalise to exactly [type, data1, data2].
        if len(rest) > 3:
            rest = rest[:2] + [" ".join(rest[2:])]
        while len(rest) < 3:
            rest.append("")

        recClass = {"A": A, "MX": MX, "PTR": PTR, "NS": NS, "CNAME": CNAME,
                    "HINFO": HINFO, "SRV": SRV}.get(rest[0], ZoneRec)
        self.zoneList.append(recClass(name, ttl, rest[0], rest[1], rest[2], self))

    def getA(self, host):
        for rec in self.getAList():
            if rec.getHost() == host:
                return rec
        return None

    def getPTRList(self):
        return self.zoneList.getList("PTR")

    def getMXList(self):
        return self.zoneList.getList("MX")

    def getMX(self, name):
        for rec in self.getMXList():
            if rec.getName() == name:
                return rec
        return None

    def getCNAMEList(self):
        return self.zoneList.getList("CNAME")

    def getCNAME(self, val):
        """Return the first CNAME whose target host (getHost()) equals val, or None."""
        for rec in self.getCNAMEList():
            if rec.getHost() == val:
                return rec
        return None

    def getSRVList(self):
        return self.zoneList.getList("SRV")

    def getSRV(self, val):
        """Return the first SRV record whose owner name (f0) equals val, or None.

        NOTE(review): presumably val is the service label such as
        "_ldap._tcp"; the old code referenced non-existent attributes.
        """
        for rec in self.getSRVList():
            if rec.f0 == val:
                return rec
        return None

    # -- output -----------------------------------------------------------

    def save(self, uid=-1, gid=-1):
        """Increment the serial and replace the zone file with out().

        The text is written to "<filename>.<pid>", given owner uid/gid
        (-1 leaves that id unchanged, as with os.chown), and then renamed
        over the original.  The temporary file is removed if any step fails.
        """
        self.incSerial()
        tmpFile = "%s.%d" % (self.filename, os.getpid())
        fd = open(tmpFile, "w")
        renamed = FALSE
        try:
            try:
                fd.write(self.out())
            finally:
                fd.close()
            os.chown(tmpFile, uid, gid)
            os.rename(tmpFile, self.filename)
            renamed = TRUE
        finally:
            if not renamed and os.path.exists(tmpFile):
                os.remove(tmpFile)

    def findnextcommentline(self):
        # Advance to the next line that begins with one or more comment
        # characters (self.commenttype); leading whitespace is not allowed.
        return self.findnextline('^[' + self.commenttype + ']+')

    def outComments(self):
        """Re-read the file and return all of its comment lines, in order, newline-terminated."""
        self.read()
        self.rewind()
        lines = []
        while self.findnextcommentline():
            lines.append("%s\n" % self.getline())
            self.nextline()
        return "".join(lines)

    def out(self):
        """Render the zone: comments, $ORIGIN lines, $TTL, SOA, then every record."""
        parts = [self.outComments()]
        for origin in self.origins:
            parts.append("%s\n" % " ".join(origin))
        # NOTE(review): this writes the SOA ttl field as $TTL, not the parsed
        # self.TTL value -- confirm that is intended.
        parts.append("$TTL %s\n" % self.getTTL())
        parts.append("%s\n" % self.SOA.out())
        for rec in self.zoneList.getList():
            parts.append("%s\n" % rec.out())
        return "".join(parts)

def Usage():
    """Print the translated command-line help text."""
    text = _("Zone - Python named zone commandline tool\n\nUsage: zone [-a|-m|-d|-l] [-h] [-i <IP> ] [-n <Name> ] name")
    print(text)

if __name__ == "__main__":
    # Command-line front end: add, modify, delete, list or show the A
    # (forward zone) / PTR (reverse zone) records of the zone file given as
    # the last argument.
    import signal
    import getopt

    def findRecords(zone, recname):
        """Return the A records (forward zone) or PTR records (reverse zone) named recname."""
        if zone.isReverse():
            candidates = zone.getPTRList()
        else:
            candidates = zone.getAList()
        return [rec for rec in candidates if rec.getName() == recname]

    if os.getuid() != 0:
        print(_("Please restart %s with root permissions!") % (sys.argv[0]))
        sys.exit(10)

    if len(sys.argv) == 1:
        Usage()
        sys.exit(1)

    cmdline = sys.argv[1:]
    sys.argv = sys.argv[:1]

    signal.signal(signal.SIGINT, signal.SIG_DFL)

    class BadUsage(Exception):
        pass

    ip = None
    name = None
    add = FALSE
    modify = FALSE
    delete = FALSE
    zoneName = None

    try:
        opts, args = getopt.getopt(cmdline, "amdlhi:n:t:",
                                   ["add", "modify", "delete", "help", "ip=",
                                    "name=", "list", "type="])
        # The zone file is the last non-option argument.
        if args:
            zoneName = args[-1]
        for opt, val in opts:
            if opt in ('-a', '--add'):
                add = TRUE
            elif opt in ('-m', '--modify'):
                modify = TRUE
            elif opt in ('-d', '--delete'):
                delete = TRUE
            elif opt in ('-i', '--ip'):
                ip = val
            elif opt in ('-n', '--name'):
                name = val
            elif opt in ('-t', '--type'):
                # Accepted for compatibility; the record type follows from
                # the zone (PTR for reverse zones, A otherwise).
                pass
            elif opt in ('-l', '--list'):
                if zoneName is None:
                    raise BadUsage()
                z = Zone(zoneName, zoneName)
                print(z.out())
                sys.exit(0)
            elif opt in ('-h', '--help'):
                Usage()
                sys.exit(0)
        if zoneName is None:
            raise BadUsage()
    except (getopt.error, BadUsage):
        Usage()
        sys.exit(1)

    if (add and delete) or (add and modify) or (modify and delete):
        Usage()
        sys.exit(1)

    # Every change needs a record name; add/modify also need an address.
    if (add or modify or delete) and not name:
        Usage()
        sys.exit(1)
    if (add or modify) and not ip:
        Usage()
        sys.exit(1)

    try:
        z = Zone(zoneName, zoneName)
        if delete:
            recs = findRecords(z, name)
            if not recs:
                raise TestError(_("%s does not exist") % name)
            for rec in recs:
                rec.unlink()
        elif add:
            if z.isReverse():
                z.addPTR(name, ip)
            else:
                z.addA(name, ip)
        elif modify:
            z.modify(name, ip)
        else:
            # Query only: show the whole zone or the records called `name`.
            if name is None:
                print(z.out())
            else:
                for rec in findRecords(z, name):
                    print(rec.out())
            sys.exit(0)

        # Write the change back, keeping the file's current owner and group.
        st = os.stat(zoneName)
        z.save(st.st_uid, st.st_gid)
        sys.exit(0)

    except (ValueError, TestError):
        print(_("Error: %s") % sys.exc_info()[1])
        sys.exit(1)
        
