Source code for galaxy.openid.providers

"""
Contains OpenID provider functionality
"""
import logging
import os

import six

from galaxy.util import parse_xml, string_as_bool
from galaxy.util.odict import odict


# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Provider id that no configured provider is allowed to use: it is listed in
# RESERVED_PROVIDER_IDS, and OpenIDProvider.from_elem rejects any provider
# whose id appears in that list.
NO_PROVIDER_ID = 'None'
RESERVED_PROVIDER_IDS = [NO_PROVIDER_ID]


class OpenIDProvider(object):
    """An OpenID Provider object.

    Holds the provider's id, display name and endpoint URL together with the
    sreg fields to request, how those fields map onto user attributes
    (``use_for``) and which of them are stored as user preferences
    (``store_user_preference``).
    """

    @classmethod
    def from_file(cls, filename):
        """Build a provider from the XML file at ``filename``.

        The file is parsed with ``parse_xml`` and its root element is handed
        to :meth:`from_elem`.
        """
        return cls.from_elem(parse_xml(filename).getroot())

    @classmethod
    def from_elem(cls, xml_root):
        """Build a provider from an XML element.

        Reads the ``id``, ``name`` (defaults to the id) and
        ``never_associate_with_user`` attributes and the text of the
        ``op_endpoint_url`` child. Each ``sreg`` child may contain ``field``
        elements with a ``name`` and an optional ``required`` attribute; a
        field may in turn contain ``use_for`` and ``store_user_preference``
        children whose ``name`` attribute is mapped to the field name.

        When no ``sreg`` element is present, ``None`` is passed for the sreg
        options so that the constructor applies its defaults.

        :raises AssertionError: if the id, name or endpoint URL is missing or
            empty, if the id is listed in ``RESERVED_PROVIDER_IDS``, or if a
            ``field`` element has no name.
        """
        provider_elem = xml_root
        provider_id = provider_elem.get('id', None)
        provider_name = provider_elem.get('name', provider_id)
        op_endpoint_url = provider_elem.find('op_endpoint_url')
        if op_endpoint_url is not None:
            op_endpoint_url = op_endpoint_url.text
        never_associate_with_user = string_as_bool(provider_elem.get('never_associate_with_user', 'False'))

        # Validation is raised explicitly instead of using ``assert`` so that
        # it is still enforced when Python runs with -O. AssertionError is
        # kept as the exception type for backward compatibility.
        if not (provider_id and provider_name and op_endpoint_url):
            raise AssertionError("OpenID Provider improperly configured")
        if provider_id in RESERVED_PROVIDER_IDS:
            raise AssertionError('Specified OpenID Provider uses a reserved id: %s' % (provider_id))

        sreg_required = []
        sreg_optional = []
        use_for = {}
        store_user_preference = {}
        use_default_sreg = True
        for elem in provider_elem.findall('sreg'):
            # Any explicit <sreg> element disables the default sreg options,
            # even if it contains no fields.
            use_default_sreg = False
            for field_elem in elem.findall('field'):
                sreg_name = field_elem.get('name')
                if not sreg_name:
                    raise AssertionError('A name is required for a sreg element')
                if string_as_bool(field_elem.get('required')):
                    sreg_required.append(sreg_name)
                else:
                    sreg_optional.append(sreg_name)
                for use_elem in field_elem.findall('use_for'):
                    use_for[use_elem.get('name')] = sreg_name
                for store_user_preference_elem in field_elem.findall('store_user_preference'):
                    store_user_preference[store_user_preference_elem.get('name')] = sreg_name
        if use_default_sreg:
            # None tells __init__ to fall back to its defaults.
            sreg_required = None
            sreg_optional = None
            use_for = None
        return cls(provider_id,
                   provider_name,
                   op_endpoint_url,
                   sreg_required=sreg_required,
                   sreg_optional=sreg_optional,
                   use_for=use_for,
                   store_user_preference=store_user_preference,
                   never_associate_with_user=never_associate_with_user)

    def __init__(self, id, name, op_endpoint_url, sreg_required=None, sreg_optional=None, use_for=None, store_user_preference=None, never_associate_with_user=None):
        """Create a provider; when sreg options are not specified, defaults are used.

        :param id: provider identifier.
        :param name: provider display name.
        :param op_endpoint_url: URL of the provider's endpoint.
        :param sreg_required: list of required sreg field names; a falsy
            value becomes an empty list.
        :param sreg_optional: list of optional sreg field names; ``None``
            becomes ``['nickname', 'email']``.
        :param use_for: mapping of user attribute name to sreg field name.
            ``None`` derives a default mapping from the requested fields
            (``username`` from ``nickname``, ``email`` from ``email``).
        :param store_user_preference: mapping of preference name to sreg
            field name; a falsy value becomes an empty dict.
        :param never_associate_with_user: stored as a strict boolean.
        """
        self.id = id
        self.name = name
        self.op_endpoint_url = op_endpoint_url
        self.sreg_optional = ['nickname', 'email'] if sreg_optional is None else sreg_optional
        self.sreg_required = sreg_required if sreg_required else []
        if use_for is not None:
            self.use_for = use_for
        else:
            # No explicit mapping: derive one from whichever of the two
            # well-known fields are actually requested.
            self.use_for = {}
            requested = self.sreg_optional + self.sreg_required
            if 'nickname' in requested:
                self.use_for['username'] = 'nickname'
            if 'email' in requested:
                self.use_for['email'] = 'email'
        self.store_user_preference = store_user_preference if store_user_preference else {}
        self.never_associate_with_user = bool(never_associate_with_user)

    def post_authentication(self, trans, openid_manager, info):
        """Store the configured sreg values as preferences of ``trans.user``.

        The sreg attributes are obtained from ``openid_manager.get_sreg(info)``.
        For every entry of ``store_user_preference`` the matching attribute is
        written to ``trans.user.preferences``; afterwards the user is added to
        ``trans.sa_session`` and the session is flushed.

        :raises Exception: if a preference refers to a field that is neither
            in ``sreg_optional`` nor in ``sreg_required``. Preferences set
            before the offending entry are not rolled back, and the session
            is not flushed.
        """
        sreg_attributes = openid_manager.get_sreg(info)
        for store_pref_name, store_pref_value_name in self.store_user_preference.items():
            if store_pref_value_name in (self.sreg_optional + self.sreg_required):
                trans.user.preferences[store_pref_name] = sreg_attributes.get(store_pref_value_name)
            else:
                raise Exception('Only sreg is currently supported.')
        trans.sa_session.add(trans.user)
        trans.sa_session.flush()

    def has_post_authentication_actions(self):
        """Return True when at least one user preference is configured to be stored."""
        return bool(self.store_user_preference)
class OpenIDProviders(object):
    """Collection of OpenID Providers, keyed by provider id."""
    NO_PROVIDER_ID = NO_PROVIDER_ID

    @classmethod
    def from_file(cls, filename):
        """Load the collection from the XML file at ``filename``.

        Any failure is logged and an empty collection is returned instead.
        """
        try:
            root = parse_xml(filename).getroot()
            return cls.from_elem(root)
        except Exception as err:
            log.error('Failed to load OpenID Providers: %s' % (err))
            return cls()

    @classmethod
    def from_elem(cls, xml_root):
        """Load the collection from an XML element.

        Every ``provider`` child names a file (``file`` attribute) that is
        looked up under the ``openid`` directory and loaded with
        ``OpenIDProvider.from_file``. A provider that fails to load is logged
        and skipped; the remaining providers are still loaded.
        """
        loaded = odict()
        for provider_elem in xml_root.findall('provider'):
            try:
                provider_path = os.path.join('openid', provider_elem.get('file'))
                provider = OpenIDProvider.from_file(provider_path)
                loaded[provider.id] = provider
                log.debug('Loaded OpenID provider: %s (%s)' % (provider.name, provider.id))
            except Exception as err:
                log.error('Failed to add OpenID provider: %s' % (err))
        return cls(loaded)

    def __init__(self, providers=None):
        """Wrap ``providers`` (a mapping of id to provider); a falsy value gives an empty collection.

        Also records the endpoint URLs of all providers flagged
        ``never_associate_with_user``.
        """
        self.providers = providers if providers else odict()
        banned = []
        for candidate in self.providers.values():
            if candidate.never_associate_with_user:
                banned.append(candidate.op_endpoint_url)
        self._banned_identifiers = banned

    def __iter__(self):
        """Iterate over the provider objects."""
        for member in six.itervalues(self.providers):
            yield member

    def get(self, name, default=None):
        """Return the provider stored under ``name``, or ``default`` if there is none."""
        return self.providers[name] if name in self.providers else default

    def new_provider_from_identifier(self, identifier):
        """Create an ad-hoc provider whose name and endpoint URL are both ``identifier``.

        The new provider has no id and is flagged ``never_associate_with_user``
        when ``identifier`` is one of the recorded banned endpoint URLs.
        """
        is_banned = identifier in self._banned_identifiers
        return OpenIDProvider(None, identifier, identifier, never_associate_with_user=is_banned)