Warning

This document is for an in-development version of Galaxy. You can alternatively view this page in the latest release if it exists or view the top of the latest release's documentation.

Source code for galaxy.openid.providers

"""
Contains OpenID provider functionality
"""
import logging
import os

import six

from galaxy.util import parse_xml, string_as_bool
from galaxy.util.odict import odict


# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Provider id string that is reserved: OpenIDProvider.from_elem refuses any
# configured provider whose id appears in RESERVED_PROVIDER_IDS.
NO_PROVIDER_ID = 'None'
RESERVED_PROVIDER_IDS = [NO_PROVIDER_ID]


class OpenIDProvider(object):
    '''An OpenID Provider object.

    Holds the provider's id, display name and OP endpoint URL together with
    the simple-registration (sreg) fields to request, how those fields map
    onto Galaxy user attributes (``use_for``) and which of them are stored
    as user preferences (``store_user_preference``).
    '''

    @classmethod
    def from_file(cls, filename):
        '''Build a provider from the XML file at ``filename``.

        The file is parsed with ``parse_xml`` and its root element is handed
        to :meth:`from_elem`.
        '''
        return cls.from_elem(parse_xml(filename).getroot())

    @classmethod
    def from_elem(cls, xml_root):
        '''Build a provider from an XML ``provider`` element.

        Reads the ``id``, ``name`` (defaults to the id) and
        ``never_associate_with_user`` attributes and the ``op_endpoint_url``
        child element.  Each ``sreg`` child may contain ``field`` elements
        (attributes ``name`` and ``required``), which in turn may contain
        ``use_for`` and ``store_user_preference`` elements keyed by their
        ``name`` attribute.  When no ``sreg`` element is present, the
        constructor defaults are used.

        :raises AssertionError: if the id, name or endpoint URL is missing,
            if the id is reserved, or if an sreg field has no name.
        '''
        provider_elem = xml_root
        provider_id = provider_elem.get('id', None)
        provider_name = provider_elem.get('name', provider_id)
        op_endpoint_url = provider_elem.find('op_endpoint_url')
        if op_endpoint_url is not None:
            op_endpoint_url = op_endpoint_url.text
        # Validate with explicit raises rather than ``assert`` statements so
        # that the checks still run when Python is started with -O.  The
        # exception type stays AssertionError for backward compatibility.
        if not (provider_id and provider_name and op_endpoint_url):
            raise AssertionError("OpenID Provider improperly configured")
        if provider_id in RESERVED_PROVIDER_IDS:
            raise AssertionError('Specified OpenID Provider uses a reserved id: %s' % (provider_id))
        never_associate_with_user = string_as_bool(provider_elem.get('never_associate_with_user', 'False'))

        sreg_required = []
        sreg_optional = []
        use_for = {}
        store_user_preference = {}
        use_default_sreg = True
        for elem in provider_elem.findall('sreg'):
            # Any explicit <sreg> element switches off the built-in defaults.
            use_default_sreg = False
            for field_elem in elem.findall('field'):
                sreg_name = field_elem.get('name')
                if not sreg_name:
                    raise AssertionError('A name is required for a sreg element')
                if string_as_bool(field_elem.get('required')):
                    sreg_required.append(sreg_name)
                else:
                    sreg_optional.append(sreg_name)
                for use_elem in field_elem.findall('use_for'):
                    use_for[use_elem.get('name')] = sreg_name
                for store_user_preference_elem in field_elem.findall('store_user_preference'):
                    store_user_preference[store_user_preference_elem.get('name')] = sreg_name
        if use_default_sreg:
            # Passing None lets __init__ fill in its default sreg settings.
            sreg_required = None
            sreg_optional = None
            use_for = None
        return cls(provider_id, provider_name, op_endpoint_url,
                   sreg_required=sreg_required,
                   sreg_optional=sreg_optional,
                   use_for=use_for,
                   store_user_preference=store_user_preference,
                   never_associate_with_user=never_associate_with_user)

    def __init__(self, id, name, op_endpoint_url, sreg_required=None, sreg_optional=None, use_for=None, store_user_preference=None, never_associate_with_user=None):
        '''When sreg options are not specified, defaults are used.

        :param id: provider id.
        :param name: provider display name.
        :param op_endpoint_url: the provider's OP endpoint URL.
        :param sreg_required: list of required sreg field names; a falsy
            value becomes an empty list.
        :param sreg_optional: list of optional sreg field names; ``None``
            becomes ``['nickname', 'email']``.
        :param use_for: mapping of a name to the sreg field used for it;
            ``None`` yields a mapping derived from the requested fields
            (``username`` from ``nickname``, ``email`` from ``email``).
        :param store_user_preference: mapping of preference name to sreg
            field name; a falsy value becomes an empty dict.
        :param never_associate_with_user: stored as a strict boolean.
        '''
        self.id = id
        self.name = name
        self.op_endpoint_url = op_endpoint_url
        if sreg_optional is None:
            self.sreg_optional = ['nickname', 'email']
        else:
            self.sreg_optional = sreg_optional
        if sreg_required:
            self.sreg_required = sreg_required
        else:
            self.sreg_required = []
        if use_for is not None:
            self.use_for = use_for
        else:
            # Derive the default mapping from whichever fields are requested.
            self.use_for = {}
            requested_fields = self.sreg_optional + self.sreg_required
            if 'nickname' in requested_fields:
                self.use_for['username'] = 'nickname'
            if 'email' in requested_fields:
                self.use_for['email'] = 'email'
        if store_user_preference:
            self.store_user_preference = store_user_preference
        else:
            self.store_user_preference = {}
        self.never_associate_with_user = bool(never_associate_with_user)

    def post_authentication(self, trans, openid_manager, info):
        '''Copy configured sreg values into the current user's preferences.

        The sreg attributes are obtained from ``openid_manager.get_sreg(info)``.
        For every entry in ``store_user_preference`` the matching sreg value
        is written to ``trans.user.preferences``; the user is then added to
        ``trans.sa_session`` and the session is flushed.

        :raises Exception: if a preference refers to a field that is not one
            of this provider's requested sreg fields.
        '''
        sreg_attributes = openid_manager.get_sreg(info)
        # Build the list of requested fields once instead of per iteration.
        requested_fields = self.sreg_optional + self.sreg_required
        for store_pref_name, store_pref_value_name in self.store_user_preference.items():
            if store_pref_value_name in requested_fields:
                trans.user.preferences[store_pref_name] = sreg_attributes.get(store_pref_value_name)
            else:
                raise Exception('Only sreg is currently supported.')
        trans.sa_session.add(trans.user)
        trans.sa_session.flush()

    def has_post_authentication_actions(self):
        '''Return True when there are user preferences to store after login.'''
        return bool(self.store_user_preference)
class OpenIDProviders(object):
    '''Collection of OpenID Providers

    Providers are kept in ``self.providers``, keyed by provider id.
    Iterating over the collection yields the provider objects.
    '''
    NO_PROVIDER_ID = NO_PROVIDER_ID

    @classmethod
    def from_file(cls, filename):
        '''Load the collection from the XML file at ``filename``.

        Any failure is logged and an empty collection is returned instead.
        '''
        try:
            root = parse_xml(filename).getroot()
            return cls.from_elem(root)
        except Exception as e:
            log.error('Failed to load OpenID Providers: %s' % (e))
            return cls()

    @classmethod
    def from_elem(cls, xml_root):
        '''Load the collection from an XML element.

        Every ``provider`` child names, through its ``file`` attribute, a
        provider definition under the ``openid`` directory.  A provider that
        fails to load is logged and skipped.
        '''
        loaded = odict()
        for provider_elem in xml_root.findall('provider'):
            try:
                provider_path = os.path.join('openid', provider_elem.get('file'))
                provider = OpenIDProvider.from_file(provider_path)
                loaded[provider.id] = provider
                log.debug('Loaded OpenID provider: %s (%s)' % (provider.name, provider.id))
            except Exception as e:
                log.error('Failed to add OpenID provider: %s' % (e))
        return cls(loaded)

    def __init__(self, providers=None):
        '''Store ``providers`` (an empty ``odict`` when none are given).

        Also records the endpoint URLs of providers flagged
        ``never_associate_with_user`` in ``self._banned_identifiers``.
        '''
        self.providers = providers if providers else odict()
        banned = []
        for provider in self.providers.values():
            if provider.never_associate_with_user:
                banned.append(provider.op_endpoint_url)
        self._banned_identifiers = banned

    def __iter__(self):
        '''Yield each provider in the collection.'''
        for each_provider in six.itervalues(self.providers):
            yield each_provider

    def get(self, name, default=None):
        '''Return the provider stored under ``name``, or ``default``.'''
        if name not in self.providers:
            return default
        return self.providers[name]

    def new_provider_from_identifier(self, identifier):
        '''Create an ad-hoc provider whose name and endpoint are ``identifier``.

        The new provider has no id, and is marked
        ``never_associate_with_user`` when the identifier is one of the
        banned endpoint URLs.
        '''
        is_banned = identifier in self._banned_identifiers
        return OpenIDProvider(None, identifier, identifier, never_associate_with_user=is_banned)