##
# Copyright (c) 2008 Guido Guenther <agx@sigxcpu.org>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##

"""
NSS Directory service interfaces.

Uses libc's Name Service Switch (configured in /etc/nsswitch.conf) to look up
users and groups, and PAM to authenticate users.
"""

# Public API of this module: only the directory service itself is exported.
__all__ = ["NssDirectoryService"]

from twistedcaldav.directory.cachingdirectory import CachingDirectoryService,\
    CachingDirectoryRecord
from twisted.cred.credentials import UsernamePassword
from twistedcaldav.scheduling.cuaddress import normalizeCUAddr
from twisted.python import log
import pwd, grp, socket
import PAM

class NsSwitch(object):
    """
    Minimal facade over libc's Name Service Switch lookups.

    The single-name lookups return the C{pwd}/C{grp} entry for the name,
    or C{None} when the name is unknown (instead of raising C{KeyError}).
    The C{get_users}/C{get_groups} methods return every entry.
    """

    @staticmethod
    def _entry_or_none(lookup, name):
        # pwd/grp signal "no such entry" with KeyError; map that to None.
        try:
            entry = lookup(name)
        except KeyError:
            entry = None
        return entry

    def get_user(self, username):
        """Return the passwd entry for C{username}, or C{None}."""
        return self._entry_or_none(pwd.getpwnam, username)

    def get_group(self, groupname):
        """Return the group entry for C{groupname}, or C{None}."""
        return self._entry_or_none(grp.getgrnam, groupname)

    def get_users(self):
        """Return all passwd entries."""
        return pwd.getpwall()

    def get_groups(self):
        """Return all group entries."""
        return grp.getgrall()


class NssDirectoryService(CachingDirectoryService):
    """
    Nss based Directory Service of L{IDirectoryService}.

    Users and groups are resolved through libc's Name Service Switch via
    L{NsSwitch}. Only users whose uid lies in the inclusive range
    [firstValidUid, lastValidUid] are exposed. Only groups whose gid lies in
    [firstValidGid, lastValidGid] and whose name starts with C{groupPrefix}
    are exposed; the prefix is not part of a group's short name.
    """

    baseGUID = "8EFFFAF1-5221-4813-B971-58506B963573"

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, self.realmName)

    def __init__(self, params):
        """
        @param params: a dictionary containing the following keys:
            cacheTimeout, realmName, groupPrefix, mailDomain, firstValidUid,
            lastValidUid, firstValidGid, lastValidGid, augmentService,
            groupMembershipCache
        """

        defaults = {
            "realmName": "Test Realm",
            # we only consider groups starting with:
            "groupPrefix": "caldavd-",
            # don't set calendarUserAddresses by default
            "mailDomain": None,
            # exclude system users and nobody by "default":
            "firstValidUid": 1000,
            "lastValidUid": 65533,
            "firstValidGid": 1000,
            "lastValidGid": 65533,
            "cacheTimeout": 1, # Minutes
            "augmentService" : None,
            "groupMembershipCache" : None,
        }
        ignored = None
        params = self.getParams(params, defaults, ignored)

        super(NssDirectoryService, self).__init__(params['cacheTimeout'])

        self.nsswitch = NsSwitch()
        self.realmName = params["realmName"]
        self.mailDomain = params["mailDomain"]
        self.groupPrefix = params["groupPrefix"]
        self.first_valid_uid = params["firstValidUid"]
        self.first_valid_gid = params["firstValidGid"]
        self.last_valid_uid = params["lastValidUid"]
        self.last_valid_gid = params["lastValidGid"]
        self.augmentService = params["augmentService"]
        self.groupMembershipCache = params["groupMembershipCache"]

    def recordTypes(self):
        """Return the record types served: users and groups."""
        return (
            self.recordType_users,
            self.recordType_groups,
        )

    def _isValidUid(self, uid):
        """Return C{True} if C{uid} is inside the configured uid range."""
        return self.first_valid_uid <= uid <= self.last_valid_uid

    def _isValidGid(self, gid):
        """Return C{True} if C{gid} is inside the configured gid range."""
        return self.first_valid_gid <= gid <= self.last_valid_gid

    def queryDirectory(self, recordTypes, indexType, indexKey):
        """
        Look up a record matching C{indexKey} and add it to the record cache.

        Record types are tried in order, and the search stops at the first
        match. The matching record is added to that type's cache under
        (C{indexType}, C{indexKey}). Nothing is returned.

        GUID and auth-id lookups are not implemented and never match.
        Calendar user addresses match only C{mailto:} addresses in
        C{mailDomain}; when no mail domain is configured they never match.
        """
        self.log_debug("Querying directory for recordTypes %s, "
                       "indexType %s and indexKey %s" %
                       (recordTypes, indexType, indexKey),
                       system="NssDirectoryService")

        def _recordWithGUID(recordType, guid):
            # Code has to be written to query on GUID
            return None

        def _recordWithShortName(recordType, shortName):
            # Returns None when the name is unknown or its uid/gid is
            # outside the configured range.
            if recordType == self.recordType_users:
                result = self.nsswitch.get_user(shortName)
                if result and self._isValidUid(result[2]):
                    return NssUserRecord(
                        service=self,
                        userName=result[0],
                        gecos=result[4],
                    )
            elif recordType == self.recordType_groups:
                # Group short names exclude the prefix; NSS names include it.
                result = self.nsswitch.get_group(self.groupPrefix + shortName)
                if result and self._isValidGid(result[2]):
                    return NssGroupRecord(
                        service=self,
                        groupName=result[0],
                        members=result[3],
                    )
            return None

        for recordType in recordTypes:
            record = None
            if indexType == self.INDEX_TYPE_GUID:
                record = _recordWithGUID(recordType, indexKey)
            elif indexType == self.INDEX_TYPE_SHORTNAME:
                record = _recordWithShortName(recordType, indexKey)
            elif indexType == self.INDEX_TYPE_CUA:
                address = normalizeCUAddr(indexKey)
                if address.startswith("urn:uuid:"):
                    guid = address[len("urn:uuid:"):]
                    record = _recordWithGUID(recordType, guid)
                # mailDomain defaults to None; guard before concatenating
                # so mailto: lookups don't raise TypeError in that case.
                elif (self.mailDomain and
                      address.startswith("mailto:") and
                      address.endswith("@" + self.mailDomain)):
                    shortName = address[len("mailto:"):].partition("@")[0]
                    record = _recordWithShortName(recordType, shortName)
            elif indexType == self.INDEX_TYPE_AUTHID:
                pass

            if record:
                self.recordCacheForType(recordType).addRecord(
                        record, indexType, indexKey
                )

                # We got a match, so don't bother checking other types
                break


class NssDirectoryRecord(CachingDirectoryRecord):
    """
    Nss Directory Record.

    Common base for NSS user and group records. NSS records carry no GUID,
    so C{guid=None} is passed to the base class.
    """
    def __init__(self, service, recordType, shortNames,
                 fullName=None, emailAddresses=None,
                 enabledForCalendaring=None,
                 enabledForAddressBooks=None,
                 enabledForLogin=True
             ):
        if emailAddresses is None:
            # Build a fresh set per record. A shared mutable default would
            # leak addresses between records if anything mutated it.
            emailAddresses = set()
        super(NssDirectoryRecord, self).__init__(
            service               = service,
            recordType            = recordType,
            guid                  = None,
            shortNames            = shortNames,
            fullName              = fullName,
            emailAddresses        = emailAddresses,
            enabledForCalendaring = enabledForCalendaring,
            enabledForAddressBooks= enabledForAddressBooks,
            enabledForLogin       = enabledForLogin,
        )


class NssUserRecord(NssDirectoryRecord):
    """
    NSS Users implementation of L{IDirectoryRecord}.
    """
    def __init__(self, service, userName, gecos):
        """
        @param service: the owning directory service.
        @param userName: the user's login name; used as the only short name.
        @param gecos: the GECOS field; the text before the first comma is
            used as the full name.
        """
        recordType = service.recordType_users
        shortNames = (userName,)
        fullName = gecos.split(",", 1)[0]
        emailAddresses = set()
        if service.mailDomain:
            emailAddresses.add("%s@%s" % (userName, service.mailDomain))
        super(NssUserRecord, self).__init__(service, recordType, shortNames,
                                            fullName=fullName,
                                            emailAddresses=emailAddresses,
                                            enabledForCalendaring=True,
                                            enabledForAddressBooks=True)

    def groups(self):
        """
        Yield the group records this user belongs to.

        Only groups with a valid gid and a name starting with the service's
        C{groupPrefix} are considered. Lookups that find no record are
        skipped instead of yielding C{None}.
        """
        service = self.service
        prefix = service.groupPrefix
        for result in service.nsswitch.get_groups():
            if (service._isValidGid(result[2]) and
                    result[0].startswith(prefix) and
                    self.shortNames[0] in result[3]):
                # The name starts with the prefix, so slicing strips it.
                record = service.recordWithShortName(
                    service.recordType_groups,
                    result[0][len(prefix):]
                )
                if record is not None:
                    yield record

    def verifyCredentials(self, credentials):
        """
        Verify C{credentials} for this user.

        L{UsernamePassword} credentials are first compared with a password
        cached on this record by an earlier successful authentication, then
        checked against PAM using the C{"caldav"} service. Any other
        credential type is delegated to the base class.
        """
        if isinstance(credentials, UsernamePassword):
            # Check that the username supplied matches the shortName
            # (The DCS might already enforce this constraint, not sure)
            if credentials.username not in self.shortNames:
                return False

            # Check cached password
            try:
                if credentials.password == self.password:
                    return True
            except AttributeError:
                pass

            # Authenticate against PAM. PAM may send several messages in one
            # conversation call and expects exactly one (response, retcode)
            # pair per message, in order. Returning fewer would leave
            # modules reading past the end of the response array.
            def pam_conv(auth, query_list, userData):
                responses = []
                for query, queryType in query_list:
                    if queryType in (PAM.PAM_PROMPT_ECHO_OFF,
                                     PAM.PAM_PROMPT_ECHO_ON):
                        responses.append((credentials.password, 0))
                    else:
                        # Informational / error messages need no answer.
                        responses.append(("", 0))
                return responses

            auth = PAM.pam()
            auth.start("caldav")
            auth.set_item(PAM.PAM_USER, credentials.username)
            auth.set_item(PAM.PAM_CONV, pam_conv)
            try:
                auth.authenticate()
            except PAM.error:
                return False
            # NOTE(review): auth.acct_mgmt() is not called, so PAM account
            # checks (expiry, lockout) are not enforced — confirm intended.

            # Cache the password to avoid future PAM round trips
            self.password = credentials.password
            return True

        return super(NssUserRecord, self).verifyCredentials(credentials)


class NssGroupRecord(NssDirectoryRecord):
    """
    NSS Groups implementation of L{IDirectoryRecord}.
    """
    def __init__(self, service, groupName, members=()):
        """
        @param service: the owning directory service.
        @param groupName: the NSS group name; a leading C{groupPrefix} is
            removed to form the short name.
        @param members: login names of the group's members.
        """
        recordType = service.recordType_groups
        prefix = service.groupPrefix
        # Strip only a leading prefix. str.replace would also remove an
        # occurrence in the middle of a name.
        if prefix and groupName.startswith(prefix):
            shortName = groupName[len(prefix):]
        else:
            shortName = groupName
        super(NssGroupRecord, self).__init__(service, recordType, (shortName,),
                                             enabledForCalendaring=False,
                                             enabledForAddressBooks=False,
                                             enabledForLogin=False)
        self._members = members

    def members(self):
        """
        Yield the user records of this group's members.

        Members that do not resolve to a record, e.g. users whose uid is
        outside the service's valid range, are skipped instead of yielding
        C{None}.
        """
        for shortName in self._members:
            record = self.service.recordWithShortName(
                    self.service.recordType_users,
                    shortName
                  )
            if record is not None:
                yield record

