Initial commit

2020-05-08 14:39:22 +01:00
commit 57828567af
1662 changed files with 248701 additions and 0 deletions

pymongo/__init__.py

@@ -0,0 +1,99 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python driver for MongoDB."""
ASCENDING = 1
"""Ascending sort order."""
DESCENDING = -1
"""Descending sort order."""
GEO2D = "2d"
"""Index specifier for a 2-dimensional `geospatial index`_.
.. _geospatial index: http://docs.mongodb.org/manual/core/2d/
"""
GEOHAYSTACK = "geoHaystack"
"""Index specifier for a 2-dimensional `haystack index`_.
.. versionadded:: 2.1
.. _haystack index: http://docs.mongodb.org/manual/core/geohaystack/
"""
GEOSPHERE = "2dsphere"
"""Index specifier for a `spherical geospatial index`_.
.. versionadded:: 2.5
.. _spherical geospatial index: http://docs.mongodb.org/manual/core/2dsphere/
"""
HASHED = "hashed"
"""Index specifier for a `hashed index`_.
.. versionadded:: 2.5
.. _hashed index: http://docs.mongodb.org/manual/core/index-hashed/
"""
TEXT = "text"
"""Index specifier for a `text index`_.
.. versionadded:: 2.7.1
.. _text index: http://docs.mongodb.org/manual/core/index-text/
"""
OFF = 0
"""No database profiling."""
SLOW_ONLY = 1
"""Only profile slow operations."""
ALL = 2
"""Profile all operations."""
version_tuple = (3, 9, 0)
def get_version_string():
if isinstance(version_tuple[-1], str):
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
return '.'.join(map(str, version_tuple))
__version__ = version = get_version_string()
"""Current version of PyMongo."""
from pymongo.collection import ReturnDocument
from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION,
MAX_SUPPORTED_WIRE_VERSION)
from pymongo.cursor import CursorType
from pymongo.mongo_client import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
from pymongo.operations import (IndexModel,
InsertOne,
DeleteOne,
DeleteMany,
UpdateOne,
UpdateMany,
ReplaceOne)
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
def has_c():
"""Is the C extension installed?"""
try:
from pymongo import _cmessage
return True
except ImportError:
return False
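
For orientation, a minimal usage sketch of the constants and helpers this module exports (a hedged example, assuming a local mongod on the default port; database and collection names are illustrative):

import pymongo

print(pymongo.version)   # "3.9.0"
print(pymongo.has_c())   # True when the C extensions are installed

client = pymongo.MongoClient("mongodb://localhost:27017")
events = client.test.events
# The sort-order constants double as index direction specifiers.
events.create_index([("ts", pymongo.DESCENDING)])
for doc in events.find().sort("ts", pymongo.DESCENDING):
    print(doc)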

pymongo/aggregation.py

@@ -0,0 +1,235 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Perform aggregation operations on a collection or database."""
from bson.son import SON
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference
class _AggregationCommand(object):
"""The internal abstract base class for aggregation cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.Collection.aggregate`, or
:meth:`pymongo.database.Database.aggregate` instead.
"""
def __init__(self, target, cursor_class, pipeline, options,
explicit_session, user_fields=None, result_processor=None):
if "explain" in options:
raise ConfigurationError("The explain option is not supported. "
"Use Database.command instead.")
self._target = target
common.validate_list('pipeline', pipeline)
self._pipeline = pipeline
self._performs_write = False
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
self._performs_write = True
common.validate_is_mapping('options', options)
self._options = options
# This is the batchSize that will be used for setting the initial
# batchSize for the cursor, as well as the subsequent getMores.
self._batch_size = common.validate_non_negative_integer_or_none(
"batchSize", self._options.pop("batchSize", None))
# If the cursor option is already specified, avoid overriding it.
self._options.setdefault("cursor", {})
# If the pipeline performs a write, we ignore the initial batchSize
# since the server doesn't return results in this case.
if self._batch_size is not None and not self._performs_write:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
self._collation = validate_collation_or_none(
options.pop('collation', None))
self._max_await_time_ms = options.pop('maxAwaitTimeMS', None)
@property
def _aggregation_target(self):
"""The argument to pass to the aggregate command."""
raise NotImplementedError
@property
def _cursor_namespace(self):
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
    def _cursor_collection(self, cursor_doc):
"""The Collection used for the aggregate command cursor."""
raise NotImplementedError
@property
def _database(self):
"""The database against which the aggregation command is run."""
raise NotImplementedError
@staticmethod
def _check_compat(sock_info):
"""Check whether the server version in-use supports aggregation."""
pass
def _process_result(self, result, session, server, sock_info, slave_ok):
if self._result_processor:
self._result_processor(
result, session, server, sock_info, slave_ok)
def get_read_preference(self, session):
if self._performs_write:
return ReadPreference.PRIMARY
return self._target._read_preference_for(session)
def get_cursor(self, session, server, sock_info, slave_ok):
# Ensure command compatibility.
self._check_compat(sock_info)
# Serialize command.
cmd = SON([("aggregate", self._aggregation_target),
("pipeline", self._pipeline)])
cmd.update(self._options)
# Apply this target's read concern if:
# readConcern has not been specified as a kwarg and either
# - server version is >= 4.2 or
# - server version is >= 3.2 and pipeline doesn't use $out
if (('readConcern' not in cmd) and
((sock_info.max_wire_version >= 4 and
not self._performs_write) or
(sock_info.max_wire_version >= 8))):
read_concern = self._target.read_concern
else:
read_concern = None
# Apply this target's write concern if:
# writeConcern has not been specified as a kwarg and pipeline doesn't
# perform a write operation
if 'writeConcern' not in cmd and self._performs_write:
write_concern = self._target._write_concern_for(session)
else:
write_concern = None
# Run command.
result = sock_info.command(
self._database.name,
cmd,
slave_ok,
self.get_read_preference(session),
self._target.codec_options,
parse_write_concern_error=True,
read_concern=read_concern,
write_concern=write_concern,
collation=self._collation,
session=session,
client=self._database.client,
user_fields=self._user_fields)
self._process_result(result, session, server, sock_info, slave_ok)
# Extract cursor from result or mock/fake one if necessary.
if 'cursor' in result:
cursor = result['cursor']
else:
# Pre-MongoDB 2.6 or unacknowledged write. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result.get("result", []),
"ns": self._cursor_namespace,
}
# Create and return cursor instance.
return self._cursor_class(
self._cursor_collection(cursor), cursor, sock_info.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session, explicit_session=self._explicit_session)
class _CollectionAggregationCommand(_AggregationCommand):
def __init__(self, *args, **kwargs):
# Pop additional option and initialize parent class.
use_cursor = kwargs.pop("use_cursor", True)
super(_CollectionAggregationCommand, self).__init__(*args, **kwargs)
# Remove the cursor document if the user has set use_cursor to False.
self._use_cursor = use_cursor
if not self._use_cursor:
self._options.pop("cursor", None)
@property
def _aggregation_target(self):
return self._target.name
@property
def _cursor_namespace(self):
return self._target.full_name
def _cursor_collection(self, cursor):
"""The Collection used for the aggregate command cursor."""
return self._target
@property
def _database(self):
return self._target.database
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
def __init__(self, *args, **kwargs):
super(_CollectionRawAggregationCommand, self).__init__(*args, **kwargs)
# For raw-batches, we set the initial batchSize for the cursor to 0.
if self._use_cursor and not self._performs_write:
self._options["cursor"]["batchSize"] = 0
class _DatabaseAggregationCommand(_AggregationCommand):
@property
def _aggregation_target(self):
return 1
@property
def _cursor_namespace(self):
return "%s.$cmd.aggregate" % (self._target.name,)
@property
def _database(self):
return self._target
def _cursor_collection(self, cursor):
"""The Collection used for the aggregate command cursor."""
# Collection level aggregate may not always return the "ns" field
# according to our MockupDB tests. Let's handle that case for db level
# aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
_, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
return self._database[collname]
@staticmethod
def _check_compat(sock_info):
        # Older server versions don't raise a descriptive error, so we raise
        # one instead.
        if sock_info.max_wire_version < 6:
err_msg = "Database.aggregate() is only supported on MongoDB 3.6+."
raise ConfigurationError(err_msg)
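
A hedged sketch of the public entry points that drive these internal command classes (assumes a reachable mongod; the test.orders collection and its fields are illustrative):

from pymongo import MongoClient

client = MongoClient()
pipeline = [
    {"$match": {"status": "A"}},
    {"$group": {"_id": "$cust_id", "total": {"$sum": "$amount"}}},
]
# Collection-level aggregation runs through _CollectionAggregationCommand.
for doc in client.test.orders.aggregate(pipeline, batchSize=100):
    print(doc)
# Database-level aggregation (MongoDB 3.6+, hence _check_compat above) runs
# through _DatabaseAggregationCommand; $currentOp must target the admin db.
for doc in client.admin.aggregate([{"$currentOp": {}}]):
    print(doc)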

pymongo/auth.py

@@ -0,0 +1,569 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Authentication helpers."""
import functools
import hashlib
import hmac
import os
import socket
try:
from urllib import quote
except ImportError:
from urllib.parse import quote
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
import winkerberos as kerberos
if tuple(map(int, kerberos.__version__.split('.')[:2])) >= (0, 5):
_USE_PRINCIPAL = True
except ImportError:
try:
import kerberos
except ImportError:
HAVE_KERBEROS = False
from base64 import standard_b64decode, standard_b64encode
from collections import namedtuple
from bson.binary import Binary
from bson.py3compat import string_type, _unicode, PY3
from bson.son import SON
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
MECHANISMS = frozenset(
['GSSAPI',
'MONGODB-CR',
'MONGODB-X509',
'PLAIN',
'SCRAM-SHA-1',
'SCRAM-SHA-256',
'DEFAULT'])
"""The authentication mechanisms supported by PyMongo."""
class _Cache(object):
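    # A shared, mutable holder for computed SCRAM secrets. Instances always
    # compare equal and share one hash value so that the cache never affects
    # MongoCredential equality or hashing.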
__slots__ = ("data",)
_hash_val = hash('_Cache')
def __init__(self):
self.data = None
def __eq__(self, other):
# Two instances must always compare equal.
if isinstance(other, _Cache):
return True
return NotImplemented
def __ne__(self, other):
if isinstance(other, _Cache):
return False
return NotImplemented
def __hash__(self):
return self._hash_val
MongoCredential = namedtuple(
'MongoCredential',
['mechanism',
'source',
'username',
'password',
'mechanism_properties',
'cache'])
"""A hashable namedtuple of values used for authentication."""
GSSAPIProperties = namedtuple('GSSAPIProperties',
['service_name',
'canonicalize_host_name',
'service_realm'])
"""Mechanism properties for GSSAPI authentication."""
def _build_credentials_tuple(mech, source, user, passwd, extra, database):
"""Build and return a mechanism specific credentials tuple.
"""
if mech != 'MONGODB-X509' and user is None:
raise ConfigurationError("%s requires a username." % (mech,))
if mech == 'GSSAPI':
if source is not None and source != '$external':
raise ValueError(
"authentication source must be $external or None for GSSAPI")
properties = extra.get('authmechanismproperties', {})
service_name = properties.get('SERVICE_NAME', 'mongodb')
canonicalize = properties.get('CANONICALIZE_HOST_NAME', False)
service_realm = properties.get('SERVICE_REALM')
props = GSSAPIProperties(service_name=service_name,
canonicalize_host_name=canonicalize,
service_realm=service_realm)
# Source is always $external.
return MongoCredential(mech, '$external', user, passwd, props, None)
elif mech == 'MONGODB-X509':
if passwd is not None:
raise ConfigurationError(
"Passwords are not supported by MONGODB-X509")
if source is not None and source != '$external':
raise ValueError(
"authentication source must be "
"$external or None for MONGODB-X509")
# user can be None.
return MongoCredential(mech, '$external', user, None, None, None)
elif mech == 'PLAIN':
source_database = source or database or '$external'
return MongoCredential(mech, source_database, user, passwd, None, None)
else:
source_database = source or database or 'admin'
if passwd is None:
raise ConfigurationError("A password is required.")
return MongoCredential(
mech, source_database, user, passwd, None, _Cache())
if PY3:
def _xor(fir, sec):
"""XOR two byte strings together (python 3.x)."""
return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
_from_bytes = int.from_bytes
_to_bytes = int.to_bytes
else:
from binascii import (hexlify as _hexlify,
unhexlify as _unhexlify)
def _xor(fir, sec):
"""XOR two byte strings together (python 2.x)."""
return b"".join([chr(ord(x) ^ ord(y)) for x, y in zip(fir, sec)])
def _from_bytes(value, dummy, _int=int, _hexlify=_hexlify):
"""An implementation of int.from_bytes for python 2.x."""
return _int(_hexlify(value), 16)
def _to_bytes(value, length, dummy, _unhexlify=_unhexlify):
"""An implementation of int.to_bytes for python 2.x."""
fmt = '%%0%dx' % (2 * length,)
return _unhexlify(fmt % value)
try:
# The fastest option, if it's been compiled to use OpenSSL's HMAC.
from backports.pbkdf2 import pbkdf2_hmac as _hi
except ImportError:
try:
# Python 2.7.8+, or Python 3.4+.
from hashlib import pbkdf2_hmac as _hi
except ImportError:
def _hi(hash_name, data, salt, iterations):
"""A simple implementation of PBKDF2-HMAC."""
mac = hmac.HMAC(data, None, getattr(hashlib, hash_name))
def _digest(msg, mac=mac):
"""Get a digest for msg."""
_mac = mac.copy()
_mac.update(msg)
return _mac.digest()
from_bytes = _from_bytes
to_bytes = _to_bytes
_u1 = _digest(salt + b'\x00\x00\x00\x01')
_ui = from_bytes(_u1, 'big')
for _ in range(iterations - 1):
_u1 = _digest(_u1)
_ui ^= from_bytes(_u1, 'big')
return to_bytes(_ui, mac.digest_size, 'big')
try:
from hmac import compare_digest
except ImportError:
if PY3:
def _xor_bytes(a, b):
return a ^ b
else:
def _xor_bytes(a, b, _ord=ord):
return _ord(a) ^ _ord(b)
# Python 2.x < 2.7.7
# Note: This method is intentionally obtuse to prevent timing attacks. Do
# not refactor it!
# References:
# - http://bugs.python.org/issue14532
# - http://bugs.python.org/issue14955
# - http://bugs.python.org/issue15061
def compare_digest(a, b, _xor_bytes=_xor_bytes):
left = None
right = b
if len(a) == len(b):
left = a
result = 0
if len(a) != len(b):
left = b
result = 1
for x, y in zip(left, right):
result |= _xor_bytes(x, y)
return result == 0
def _parse_scram_response(response):
"""Split a scram response into key, value pairs."""
return dict(item.split(b"=", 1) for item in response.split(b","))
def _authenticate_scram(credentials, sock_info, mechanism):
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == 'SCRAM-SHA-256':
digest = "sha256"
digestmod = hashlib.sha256
data = saslprep(credentials.password).encode("utf-8")
else:
digest = "sha1"
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
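    # Per RFC 5802, escape '=' and ',' in the username as '=3D' and '=2C'.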
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
nonce = standard_b64encode(os.urandom(32))
first_bare = b"n=" + user + b",r=" + nonce
cmd = SON([('saslStart', 1),
('mechanism', mechanism),
('payload', Binary(b"n,," + first_bare)),
('autoAuthorize', 1)])
res = sock_info.command(source, cmd)
server_first = res['payload']
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b'i'])
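    # The SCRAM specs require rejecting iteration counts below 4096.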
if iterations < 4096:
raise OperationFailure("Server returned an invalid iteration count.")
salt = parsed[b's']
rnonce = parsed[b'r']
if not rnonce.startswith(nonce):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = _hi(
digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_sig = standard_b64encode(
_hmac(server_key, auth_msg, digestmod).digest())
cmd = SON([('saslContinue', 1),
('conversationId', res['conversationId']),
('payload', Binary(client_final))])
res = sock_info.command(source, cmd)
parsed = _parse_scram_response(res['payload'])
if not compare_digest(parsed[b'v'], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# Depending on how it's configured, Cyrus SASL (which the server uses)
# requires a third empty challenge.
if not res['done']:
cmd = SON([('saslContinue', 1),
('conversationId', res['conversationId']),
('payload', Binary(b''))])
res = sock_info.command(source, cmd)
if not res['done']:
raise OperationFailure('SASL conversation failed to complete.')
def _password_digest(username, password):
"""Get a password digest to use for authentication.
"""
if not isinstance(password, string_type):
raise TypeError("password must be an "
"instance of %s" % (string_type.__name__,))
if len(password) == 0:
raise ValueError("password can't be empty")
    if not isinstance(username, string_type):
        raise TypeError("username must be an "
                        "instance of %s" % (string_type.__name__,))
md5hash = hashlib.md5()
data = "%s:mongo:%s" % (username, password)
md5hash.update(data.encode('utf-8'))
return _unicode(md5hash.hexdigest())
def _auth_key(nonce, username, password):
"""Get an auth key to use for authentication.
"""
digest = _password_digest(username, password)
md5hash = hashlib.md5()
data = "%s%s%s" % (nonce, username, digest)
md5hash.update(data.encode('utf-8'))
return _unicode(md5hash.hexdigest())
def _authenticate_gssapi(credentials, sock_info):
"""Authenticate using GSSAPI.
"""
if not HAVE_KERBEROS:
raise ConfigurationError('The "kerberos" module must be '
'installed to use GSSAPI authentication.')
try:
username = credentials.username
password = credentials.password
props = credentials.mechanism_properties
        # Starting here and continuing through the for loop below, establish
        # the security context. See RFC 4752, Section 3.1, first paragraph.
host = sock_info.address[0]
if props.canonicalize_host_name:
host = socket.getfqdn(host)
service = props.service_name + '@' + host
if props.service_realm is not None:
service = service + '@' + props.service_realm
if password is not None:
if _USE_PRINCIPAL:
# Note that, though we use unquote_plus for unquoting URI
# options, we use quote here. Microsoft's UrlUnescape (used
# by WinKerberos) doesn't support +.
principal = ":".join((quote(username), quote(password)))
result, ctx = kerberos.authGSSClientInit(
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
else:
if '@' in username:
user, domain = username.split('@', 1)
else:
user, domain = username, None
result, ctx = kerberos.authGSSClientInit(
service, gssflags=kerberos.GSS_C_MUTUAL_FLAG,
user=user, domain=domain, password=password)
else:
result, ctx = kerberos.authGSSClientInit(
service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
if result != kerberos.AUTH_GSS_COMPLETE:
raise OperationFailure('Kerberos context failed to initialize.')
try:
# pykerberos uses a weird mix of exceptions and return values
# to indicate errors.
# 0 == continue, 1 == complete, -1 == error
# Only authGSSClientStep can return 0.
if kerberos.authGSSClientStep(ctx, '') != 0:
raise OperationFailure('Unknown kerberos '
'failure in step function.')
# Start a SASL conversation with mongod/s
# Note: pykerberos deals with base64 encoded byte strings.
# Since mongo accepts base64 strings as the payload we don't
# have to use bson.binary.Binary.
payload = kerberos.authGSSClientResponse(ctx)
cmd = SON([('saslStart', 1),
('mechanism', 'GSSAPI'),
('payload', payload),
('autoAuthorize', 1)])
response = sock_info.command('$external', cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
result = kerberos.authGSSClientStep(ctx,
str(response['payload']))
if result == -1:
raise OperationFailure('Unknown kerberos '
'failure in step function.')
payload = kerberos.authGSSClientResponse(ctx) or ''
cmd = SON([('saslContinue', 1),
('conversationId', response['conversationId']),
('payload', payload)])
response = sock_info.command('$external', cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
else:
raise OperationFailure('Kerberos '
'authentication failed to complete.')
# Once the security context is established actually authenticate.
# See RFC 4752, Section 3.1, last two paragraphs.
if kerberos.authGSSClientUnwrap(ctx,
str(response['payload'])) != 1:
raise OperationFailure('Unknown kerberos '
'failure during GSS_Unwrap step.')
if kerberos.authGSSClientWrap(ctx,
kerberos.authGSSClientResponse(ctx),
username) != 1:
raise OperationFailure('Unknown kerberos '
'failure during GSS_Wrap step.')
payload = kerberos.authGSSClientResponse(ctx)
cmd = SON([('saslContinue', 1),
('conversationId', response['conversationId']),
('payload', payload)])
sock_info.command('$external', cmd)
finally:
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc))
def _authenticate_plain(credentials, sock_info):
"""Authenticate using SASL PLAIN (RFC 4616)
"""
source = credentials.source
username = credentials.username
password = credentials.password
payload = ('\x00%s\x00%s' % (username, password)).encode('utf-8')
cmd = SON([('saslStart', 1),
('mechanism', 'PLAIN'),
('payload', Binary(payload)),
('autoAuthorize', 1)])
sock_info.command(source, cmd)
def _authenticate_cram_md5(credentials, sock_info):
"""Authenticate using CRAM-MD5 (RFC 2195)
"""
source = credentials.source
username = credentials.username
password = credentials.password
# The password used as the mac key is the
# same as what we use for MONGODB-CR
passwd = _password_digest(username, password)
cmd = SON([('saslStart', 1),
('mechanism', 'CRAM-MD5'),
('payload', Binary(b'')),
('autoAuthorize', 1)])
response = sock_info.command(source, cmd)
# MD5 as implicit default digest for digestmod is deprecated
# in python 3.4
mac = hmac.HMAC(key=passwd.encode('utf-8'), digestmod=hashlib.md5)
mac.update(response['payload'])
challenge = username.encode('utf-8') + b' ' + mac.hexdigest().encode('utf-8')
cmd = SON([('saslContinue', 1),
('conversationId', response['conversationId']),
('payload', Binary(challenge))])
sock_info.command(source, cmd)
def _authenticate_x509(credentials, sock_info):
"""Authenticate using MONGODB-X509.
"""
query = SON([('authenticate', 1),
('mechanism', 'MONGODB-X509')])
if credentials.username is not None:
query['user'] = credentials.username
elif sock_info.max_wire_version < 5:
raise ConfigurationError(
"A username is required for MONGODB-X509 authentication "
"when connected to MongoDB versions older than 3.4.")
sock_info.command('$external', query)
def _authenticate_mongo_cr(credentials, sock_info):
"""Authenticate using MONGODB-CR.
"""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = sock_info.command(source, {'getnonce': 1})
nonce = response['nonce']
key = _auth_key(nonce, username, password)
# Actually authenticate
query = SON([('authenticate', 1),
('user', username),
('nonce', nonce),
('key', key)])
sock_info.command(source, query)
def _authenticate_default(credentials, sock_info):
if sock_info.max_wire_version >= 7:
source = credentials.source
cmd = SON([
('ismaster', 1),
('saslSupportedMechs', source + '.' + credentials.username)])
mechs = sock_info.command(
source, cmd, publish_events=False).get('saslSupportedMechs', [])
if 'SCRAM-SHA-256' in mechs:
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256')
else:
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1')
elif sock_info.max_wire_version >= 3:
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1')
else:
return _authenticate_mongo_cr(credentials, sock_info)
_AUTH_MAP = {
'CRAM-MD5': _authenticate_cram_md5,
'GSSAPI': _authenticate_gssapi,
'MONGODB-CR': _authenticate_mongo_cr,
'MONGODB-X509': _authenticate_x509,
'PLAIN': _authenticate_plain,
'SCRAM-SHA-1': functools.partial(
_authenticate_scram, mechanism='SCRAM-SHA-1'),
'SCRAM-SHA-256': functools.partial(
_authenticate_scram, mechanism='SCRAM-SHA-256'),
'DEFAULT': _authenticate_default,
}
def authenticate(credentials, sock_info):
"""Authenticate sock_info."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP.get(mechanism)
auth_func(credentials, sock_info)
def logout(source, sock_info):
"""Log out from a database."""
sock_info.command(source, {'logout': 1})
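
These helpers are reached indirectly through MongoClient. A hedged sketch of selecting a mechanism from the connection URI (hosts and credentials are placeholders):

from pymongo import MongoClient

# DEFAULT negotiates SCRAM-SHA-256 when the server advertises it (see
# _authenticate_default above) and falls back to SCRAM-SHA-1 otherwise.
client = MongoClient("mongodb://user:secret@localhost/?authSource=admin")

# GSSAPI requires the pykerberos or winkerberos package.
gssapi_client = MongoClient(
    "mongodb://user%40EXAMPLE.COM@kerberized-host/"
    "?authMechanism=GSSAPI"
    "&authMechanismProperties=SERVICE_NAME:mongodb")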

pymongo/bulk.py

@@ -0,0 +1,691 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The bulk write operations interface.
.. versionadded:: 2.7
"""
import copy
from itertools import islice
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo.client_session import _validate_session_write_concern
from pymongo.common import (validate_is_mapping,
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update)
from pymongo.helpers import _RETRYABLE_ERROR_CODES
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (BulkWriteError,
ConfigurationError,
InvalidOperation,
OperationFailure)
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
_do_batched_insert,
_randint,
_BulkWriteContext,
_EncryptedBulkWriteContext)
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
_DELETE_ALL = 0
_DELETE_ONE = 1
# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err
_BAD_VALUE = 2
_UNKNOWN_ERROR = 8
_WRITE_CONCERN_ERROR = 64
_COMMANDS = ('insert', 'update', 'delete')
# These string literals are used when we create fake server return
# documents client side. We use unicode literals in python 2.x to
# match the actual return values from the server.
_UOP = u"op"
class _Run(object):
"""Represents a batch of write operations.
"""
def __init__(self, op_type):
"""Initialize a new Run object.
"""
self.op_type = op_type
self.index_map = []
self.ops = []
self.idx_offset = 0
def index(self, idx):
"""Get the original index of an operation in this run.
:Parameters:
- `idx`: The Run index that maps to the original index.
"""
return self.index_map[idx]
def add(self, original_index, operation):
"""Add an operation to this Run instance.
:Parameters:
- `original_index`: The original index of this operation
within a larger bulk operation.
- `operation`: The operation document.
"""
self.index_map.append(original_index)
self.ops.append(operation)
def _merge_command(run, full_result, offset, result):
"""Merge a write command result into the full bulk result.
"""
affected = result.get("n", 0)
if run.op_type == _INSERT:
full_result["nInserted"] += affected
elif run.op_type == _DELETE:
full_result["nRemoved"] += affected
elif run.op_type == _UPDATE:
upserted = result.get("upserted")
if upserted:
n_upserted = len(upserted)
for doc in upserted:
doc["index"] = run.index(doc["index"] + offset)
full_result["upserted"].extend(upserted)
full_result["nUpserted"] += n_upserted
full_result["nMatched"] += (affected - n_upserted)
else:
full_result["nMatched"] += affected
full_result["nModified"] += result["nModified"]
write_errors = result.get("writeErrors")
if write_errors:
for doc in write_errors:
# Leave the server response intact for APM.
replacement = doc.copy()
idx = doc["index"] + offset
replacement["index"] = run.index(idx)
# Add the failed operation to the error document.
replacement[_UOP] = run.ops[idx]
full_result["writeErrors"].append(replacement)
wc_error = result.get("writeConcernError")
if wc_error:
full_result["writeConcernErrors"].append(wc_error)
def _raise_bulk_write_error(full_result):
"""Raise a BulkWriteError from the full bulk api result.
"""
if full_result["writeErrors"]:
full_result["writeErrors"].sort(
key=lambda error: error["index"])
raise BulkWriteError(full_result)
class _Bulk(object):
"""The private guts of the bulk write API.
"""
def __init__(self, collection, ordered, bypass_document_validation):
"""Initialize a _Bulk instance.
"""
self.collection = collection.with_options(
codec_options=collection.codec_options._replace(
unicode_decode_error_handler='replace',
document_class=dict))
self.ordered = ordered
self.ops = []
self.executed = False
self.bypass_doc_val = bypass_document_validation
self.uses_collation = False
self.uses_array_filters = False
self.is_retryable = True
self.retrying = False
self.started_retryable_write = False
# Extra state so that we know where to pick up on a retry attempt.
self.current_run = None
@property
def bulk_ctx_class(self):
encrypter = self.collection.database.client._encrypter
if encrypter and not encrypter._bypass_auto_encryption:
return _EncryptedBulkWriteContext
else:
return _BulkWriteContext
def add_insert(self, document):
"""Add an insert document to the list of ops.
"""
validate_is_document_type("document", document)
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or '_id' in document):
document['_id'] = ObjectId()
self.ops.append((_INSERT, document))
def add_update(self, selector, update, multi=False, upsert=False,
collation=None, array_filters=None):
"""Create an update document and add it to the list of ops.
"""
validate_ok_for_update(update)
cmd = SON([('q', selector), ('u', update),
('multi', multi), ('upsert', upsert)])
collation = validate_collation_or_none(collation)
if collation is not None:
self.uses_collation = True
cmd['collation'] = collation
if array_filters is not None:
self.uses_array_filters = True
cmd['arrayFilters'] = array_filters
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append((_UPDATE, cmd))
def add_replace(self, selector, replacement, upsert=False,
collation=None):
"""Create a replace document and add it to the list of ops.
"""
validate_ok_for_replace(replacement)
cmd = SON([('q', selector), ('u', replacement),
('multi', False), ('upsert', upsert)])
collation = validate_collation_or_none(collation)
if collation is not None:
self.uses_collation = True
cmd['collation'] = collation
self.ops.append((_UPDATE, cmd))
def add_delete(self, selector, limit, collation=None):
"""Create a delete document and add it to the list of ops.
"""
cmd = SON([('q', selector), ('limit', limit)])
collation = validate_collation_or_none(collation)
if collation is not None:
self.uses_collation = True
cmd['collation'] = collation
if limit == _DELETE_ALL:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append((_DELETE, cmd))
def gen_ordered(self):
"""Generate batches of operations, batched by type of
operation, in the order **provided**.
"""
run = None
for idx, (op_type, operation) in enumerate(self.ops):
if run is None:
run = _Run(op_type)
elif run.op_type != op_type:
yield run
run = _Run(op_type)
run.add(idx, operation)
yield run
def gen_unordered(self):
"""Generate batches of operations, batched by type of
operation, in arbitrary order.
"""
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
for idx, (op_type, operation) in enumerate(self.ops):
operations[op_type].add(idx, operation)
for run in operations:
if run.ops:
yield run
def _execute_command(self, generator, write_concern, session,
sock_info, op_id, retryable, full_result):
if sock_info.max_wire_version < 5 and self.uses_collation:
raise ConfigurationError(
'Must be connected to MongoDB 3.4+ to use a collation.')
if sock_info.max_wire_version < 6 and self.uses_array_filters:
raise ConfigurationError(
'Must be connected to MongoDB 3.6+ to use arrayFilters.')
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
if not self.current_run:
self.current_run = next(generator)
run = self.current_run
# sock_info.command validates the session, but we use
# sock_info.write_command.
sock_info.validate_session(client, session)
while run:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', self.ordered)])
if not write_concern.is_server_default:
cmd['writeConcern'] = write_concern.document
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
cmd['bypassDocumentValidation'] = True
bwc = self.bulk_ctx_class(
db_name, cmd, sock_info, op_id, listeners, session,
run.op_type, self.collection.codec_options)
while run.idx_offset < len(run.ops):
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY)
sock_info.send_cluster_time(cmd, session, client)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command.
result, to_send = bwc.execute(ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get('writeConcernError', {})
if wce.get('code', 0) in _RETRYABLE_ERROR_CODES:
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
full = copy.deepcopy(full_result)
_merge_command(run, full, run.idx_offset, result)
_raise_bulk_write_error(full)
_merge_command(run, full_result, run.idx_offset, result)
# We're no longer in a retry once a command succeeds.
self.retrying = False
self.started_retryable_write = False
if self.ordered and "writeErrors" in result:
break
run.idx_offset += len(to_send)
# We're supposed to continue if errors are
# at the write concern level (e.g. wtimeout)
if self.ordered and full_result['writeErrors']:
break
# Reset our state
self.current_run = run = next(generator, None)
def execute_command(self, generator, write_concern, session):
"""Execute using write commands.
"""
# nModified is only reported for write commands, not legacy ops.
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
op_id = _randint()
def retryable_bulk(session, sock_info, retryable):
self._execute_command(
generator, write_concern, session, sock_info, op_id,
retryable, full_result)
client = self.collection.database.client
with client._tmp_session(session) as s:
client._retry_with_session(
self.is_retryable, retryable_bulk, s, self)
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
_raise_bulk_write_error(full_result)
return full_result
def execute_insert_no_results(self, sock_info, run, op_id, acknowledged):
"""Execute insert, returning no results.
"""
command = SON([('insert', self.collection.name),
('ordered', self.ordered)])
concern = {'w': int(self.ordered)}
command['writeConcern'] = concern
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
command['bypassDocumentValidation'] = True
db = self.collection.database
bwc = _BulkWriteContext(
db.name, command, sock_info, op_id, db.client._event_listeners,
None, _INSERT, self.collection.codec_options)
# Legacy batched OP_INSERT.
_do_batched_insert(
self.collection.full_name, run.ops, True, acknowledged, concern,
not self.ordered, self.collection.codec_options, bwc)
def execute_op_msg_no_results(self, sock_info, generator):
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered.
"""
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
op_id = _randint()
if not self.current_run:
self.current_run = next(generator)
run = self.current_run
while run:
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
('ordered', False),
('writeConcern', {'w': 0})])
bwc = self.bulk_ctx_class(
db_name, cmd, sock_info, op_id, listeners, None,
run.op_type, self.collection.codec_options)
while run.idx_offset < len(run.ops):
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = bwc.execute_unack(ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
def execute_command_no_results(self, sock_info, generator):
"""Execute write commands with OP_MSG and w=0 WriteConcern, ordered.
"""
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
# Ordered bulk writes have to be acknowledged so that we stop
# processing at the first error, even when the application
# specified unacknowledged writeConcern.
write_concern = WriteConcern()
op_id = _randint()
try:
self._execute_command(
generator, write_concern, None,
sock_info, op_id, False, full_result)
except OperationFailure:
pass
def execute_no_results(self, sock_info, generator):
"""Execute all operations, returning no results (w=0).
"""
if self.uses_collation:
raise ConfigurationError(
'Collation is unsupported for unacknowledged writes.')
if self.uses_array_filters:
raise ConfigurationError(
'arrayFilters is unsupported for unacknowledged writes.')
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
raise OperationFailure("Cannot set bypass_document_validation with"
" unacknowledged write concern")
# OP_MSG
if sock_info.max_wire_version > 5:
if self.ordered:
return self.execute_command_no_results(sock_info, generator)
return self.execute_op_msg_no_results(sock_info, generator)
coll = self.collection
# If ordered is True we have to send GLE or use write
# commands so we can abort on the first error.
write_concern = WriteConcern(w=int(self.ordered))
op_id = _randint()
next_run = next(generator)
while next_run:
# An ordered bulk write needs to send acknowledged writes to short
# circuit the next run. However, the final message on the final
# run can be unacknowledged.
run = next_run
next_run = next(generator, None)
needs_ack = self.ordered and next_run is not None
try:
if run.op_type == _INSERT:
self.execute_insert_no_results(
sock_info, run, op_id, needs_ack)
elif run.op_type == _UPDATE:
for operation in run.ops:
doc = operation['u']
check_keys = True
if doc and next(iter(doc)).startswith('$'):
check_keys = False
coll._update(
sock_info,
operation['q'],
doc,
operation['upsert'],
check_keys,
operation['multi'],
write_concern=write_concern,
op_id=op_id,
ordered=self.ordered,
bypass_doc_val=self.bypass_doc_val)
else:
for operation in run.ops:
coll._delete(sock_info,
operation['q'],
not operation['limit'],
write_concern,
op_id,
self.ordered)
except OperationFailure:
if self.ordered:
break
def execute(self, write_concern, session):
"""Execute operations.
"""
if not self.ops:
raise InvalidOperation('No operations to execute')
if self.executed:
raise InvalidOperation('Bulk operations can '
'only be executed once.')
self.executed = True
write_concern = write_concern or self.collection.write_concern
session = _validate_session_write_concern(session, write_concern)
if self.ordered:
generator = self.gen_ordered()
else:
generator = self.gen_unordered()
client = self.collection.database.client
if not write_concern.acknowledged:
with client._socket_for_writes(session) as sock_info:
self.execute_no_results(sock_info, generator)
else:
return self.execute_command(generator, write_concern, session)
class BulkUpsertOperation(object):
"""An interface for adding upsert operations.
"""
__slots__ = ('__selector', '__bulk', '__collation')
def __init__(self, selector, bulk, collation):
self.__selector = selector
self.__bulk = bulk
self.__collation = collation
def update_one(self, update):
"""Update one document matching the selector.
:Parameters:
- `update` (dict): the update operations to apply
"""
self.__bulk.add_update(self.__selector,
update, multi=False, upsert=True,
collation=self.__collation)
def update(self, update):
"""Update all documents matching the selector.
:Parameters:
- `update` (dict): the update operations to apply
"""
self.__bulk.add_update(self.__selector,
update, multi=True, upsert=True,
collation=self.__collation)
def replace_one(self, replacement):
"""Replace one entire document matching the selector criteria.
:Parameters:
- `replacement` (dict): the replacement document
"""
self.__bulk.add_replace(self.__selector, replacement, upsert=True,
collation=self.__collation)
class BulkWriteOperation(object):
"""An interface for adding update or remove operations.
"""
__slots__ = ('__selector', '__bulk', '__collation')
def __init__(self, selector, bulk, collation):
self.__selector = selector
self.__bulk = bulk
self.__collation = collation
def update_one(self, update):
"""Update one document matching the selector criteria.
:Parameters:
- `update` (dict): the update operations to apply
"""
self.__bulk.add_update(self.__selector, update, multi=False,
collation=self.__collation)
def update(self, update):
"""Update all documents matching the selector criteria.
:Parameters:
- `update` (dict): the update operations to apply
"""
self.__bulk.add_update(self.__selector, update, multi=True,
collation=self.__collation)
def replace_one(self, replacement):
"""Replace one entire document matching the selector criteria.
:Parameters:
- `replacement` (dict): the replacement document
"""
self.__bulk.add_replace(self.__selector, replacement,
collation=self.__collation)
def remove_one(self):
"""Remove a single document matching the selector criteria.
"""
self.__bulk.add_delete(self.__selector, _DELETE_ONE,
collation=self.__collation)
def remove(self):
"""Remove all documents matching the selector criteria.
"""
self.__bulk.add_delete(self.__selector, _DELETE_ALL,
collation=self.__collation)
def upsert(self):
"""Specify that all chained update operations should be
upserts.
:Returns:
- A :class:`BulkUpsertOperation` instance, used to add
update operations to this bulk operation.
"""
return BulkUpsertOperation(self.__selector, self.__bulk,
self.__collation)
class BulkOperationBuilder(object):
"""**DEPRECATED**: An interface for executing a batch of write operations.
"""
__slots__ = '__bulk'
def __init__(self, collection, ordered=True,
bypass_document_validation=False):
"""**DEPRECATED**: Initialize a new BulkOperationBuilder instance.
:Parameters:
- `collection`: A :class:`~pymongo.collection.Collection` instance.
- `ordered` (optional): If ``True`` all operations will be executed
serially, in the order provided, and the entire execution will
abort on the first error. If ``False`` operations will be executed
in arbitrary order (possibly in parallel on the server), reporting
any errors that occurred after attempting all operations. Defaults
to ``True``.
- `bypass_document_validation`: (optional) If ``True``, allows the
write to opt-out of document level validation. Default is
``False``.
.. note:: `bypass_document_validation` requires server version
**>= 3.2**
.. versionchanged:: 3.5
Deprecated. Use :meth:`~pymongo.collection.Collection.bulk_write`
instead.
.. versionchanged:: 3.2
Added bypass_document_validation support
"""
self.__bulk = _Bulk(collection, ordered, bypass_document_validation)
def find(self, selector, collation=None):
"""Specify selection criteria for bulk operations.
:Parameters:
- `selector` (dict): the selection criteria for update
and remove operations.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
:Returns:
- A :class:`BulkWriteOperation` instance, used to add
update and remove operations to this bulk operation.
.. versionchanged:: 3.4
Added the `collation` option.
"""
validate_is_mapping("selector", selector)
return BulkWriteOperation(selector, self.__bulk, collation)
def insert(self, document):
"""Insert a single document.
:Parameters:
- `document` (dict): the document to insert
.. seealso:: :ref:`writes-and-ids`
"""
self.__bulk.add_insert(document)
def execute(self, write_concern=None):
"""Execute all provided operations.
:Parameters:
- write_concern (optional): the write concern for this bulk
execution.
"""
if write_concern is not None:
write_concern = WriteConcern(**write_concern)
return self.__bulk.execute(write_concern, session=None)
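
A hedged sketch of the supported entry point, Collection.bulk_write, which feeds _Bulk with the operation classes from pymongo.operations (assumes a local mongod; names are illustrative):

from pymongo import DeleteMany, InsertOne, MongoClient, UpdateOne

coll = MongoClient().test.things
result = coll.bulk_write([
    InsertOne({"_id": 1}),
    UpdateOne({"_id": 1}, {"$set": {"n": 2}}, upsert=True),
    # A multi-document delete makes the batch non-retryable
    # (see _Bulk.add_delete above).
    DeleteMany({"obsolete": True}),
], ordered=True)
print(result.bulk_api_result)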

pymongo/change_stream.py

@@ -0,0 +1,381 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Watch changes on a collection, a database, or the entire cluster."""
import copy
from bson import _bson_to_dict
from bson.raw_bson import RawBSONDocument
from pymongo import common
from pymongo.aggregation import (_CollectionAggregationCommand,
_DatabaseAggregationCommand)
from pymongo.collation import validate_collation_or_none
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
OperationFailure,
PyMongoError)
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
_NON_RESUMABLE_GETMORE_ERRORS = frozenset([
11601, # Interrupted
136, # CappedPositionLost
237, # CursorKilled
None, # No error code was returned.
])
class ChangeStream(object):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.collection.Collection.watch`,
:meth:`pymongo.database.Database.watch`, or
:meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.6
.. mongodoc:: changeStreams
"""
def __init__(self, target, pipeline, full_document, resume_after,
max_await_time_ms, batch_size, collation,
start_at_operation_time, session, start_after):
if pipeline is None:
pipeline = []
elif not isinstance(pipeline, list):
raise TypeError("pipeline must be a list")
common.validate_string_or_none('full_document', full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._decode_custom = False
self._orig_codec_options = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options(
codec_options=target.codec_options.with_options(
document_class=RawBSONDocument))
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._uses_start_after = start_after is not None
self._uses_resume_after = resume_after is not None
self._resume_token = copy.deepcopy(start_after or resume_after)
self._max_await_time_ms = max_await_time_ms
self._batch_size = batch_size
self._collation = collation
self._start_at_operation_time = start_at_operation_time
self._session = session
# Initialize cursor.
self._cursor = self._create_cursor()
@property
def _aggregation_command_class(self):
"""The aggregation command class to be used."""
raise NotImplementedError
@property
def _client(self):
"""The client against which the aggregation commands for
this ChangeStream will be run. """
raise NotImplementedError
def _change_stream_options(self):
"""Return the options dict for the $changeStream pipeline stage."""
options = {}
if self._full_document is not None:
options['fullDocument'] = self._full_document
resume_token = self.resume_token
if resume_token is not None:
if self._uses_start_after:
options['startAfter'] = resume_token
if self._uses_resume_after:
options['resumeAfter'] = resume_token
if self._start_at_operation_time is not None:
options['startAtOperationTime'] = self._start_at_operation_time
return options
def _command_options(self):
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
options["maxAwaitTimeMS"] = self._max_await_time_ms
if self._batch_size is not None:
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self):
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline = [{'$changeStream': options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result, session, server, sock_info, slave_ok):
"""Callback that caches the startAtOperationTime from a changeStream
aggregate command response containing an empty batch of change
documents.
This is implemented as a callback because we need access to the wire
version in order to determine whether to cache this value.
"""
if not result['cursor']['firstBatch']:
if (self._start_at_operation_time is None and
self.resume_token is None and
sock_info.max_wire_version >= 7):
self._start_at_operation_time = result["operationTime"]
def _run_aggregation_cmd(self, session, explicit_session):
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding CommandCursor.
"""
cmd = self._aggregation_command_class(
self._target, CommandCursor, self._aggregation_pipeline(),
self._command_options(), explicit_session,
result_processor=self._process_result)
return self._client._retryable_read(
cmd.get_cursor, self._target._read_preference_for(session),
session)
def _create_cursor(self):
with self._client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(
session=s,
explicit_session=self._session is not None)
def _resume(self):
"""Reestablish this change stream after a resumable error."""
try:
self._cursor.close()
except PyMongoError:
pass
self._cursor = self._create_cursor()
def close(self):
"""Close this ChangeStream."""
self._cursor.close()
def __iter__(self):
return self
@property
def resume_token(self):
"""The cached resume token that will be used to resume after the most
recently returned change.
.. versionadded:: 3.9
"""
return copy.deepcopy(self._resume_token)
def next(self):
"""Advance the cursor.
This method blocks until the next change document is returned or an
unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
resume_token = None
pipeline = [{'$match': {'operationType': 'insert'}}]
with db.collection.watch(pipeline) as stream:
for insert_change in stream:
print(insert_change)
resume_token = stream.resume_token
except pymongo.errors.PyMongoError:
# The ChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
if resume_token is None:
# There is no usable resume token because there was a
# failure during ChangeStream initialization.
logging.error('...')
else:
# Use the interrupted ChangeStream's resume token to create
# a new ChangeStream. The new stream will continue from the
# last seen insert change without missing any events.
with db.collection.watch(
pipeline, resume_after=resume_token) as stream:
for insert_change in stream:
print(insert_change)
Raises :exc:`StopIteration` if this ChangeStream is closed.
"""
while self.alive:
doc = self.try_next()
if doc is not None:
return doc
raise StopIteration
__next__ = next
@property
def alive(self):
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return self._cursor.alive
def try_next(self):
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
with db.collection.watch() as stream:
while stream.alive:
change = stream.try_next()
# Note that the ChangeStream's resume token may be updated
# even when no changes are returned.
print("Current resume token: %r" % (stream.resume_token,))
if change is not None:
print("Change document: %r" % (change,))
continue
# We end up here when there are no recent changes.
# Sleep for a while before trying again to avoid flooding
# the server with getMore requests when no changes are
# available.
time.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:Returns:
The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
change = self._cursor._try_next(True)
except ConnectionFailure:
self._resume()
change = self._cursor._try_next(False)
except OperationFailure as exc:
if (exc.code in _NON_RESUMABLE_GETMORE_ERRORS or
exc.has_error_label("NonResumableChangeStreamError")):
raise
self._resume()
change = self._cursor._try_next(False)
# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
# OR the most-recently returned batch is empty. In either case,
# update the cached resume token with the postBatchResumeToken if
# one was returned. We also clear the startAtOperationTime.
if self._cursor._post_batch_resume_token is not None:
self._resume_token = self._cursor._post_batch_resume_token
self._start_at_operation_time = None
return change
# Else, changes are available.
try:
resume_token = change['_id']
except KeyError:
self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume "
"token is missing.")
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.
if (not self._cursor._has_next() and
self._cursor._post_batch_resume_token):
resume_token = self._cursor._post_batch_resume_token
# Hereafter, don't use startAfter; instead use resumeAfter.
self._uses_start_after = False
self._uses_resume_after = True
# Cache the resume token and clear startAtOperationTime.
self._resume_token = resume_token
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class CollectionChangeStream(ChangeStream):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.collection.Collection.watch` instead.
.. versionadded:: 3.7
"""
@property
def _aggregation_command_class(self):
return _CollectionAggregationCommand
@property
def _client(self):
return self._target.database.client
class DatabaseChangeStream(ChangeStream):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.database.Database.watch` instead.
.. versionadded:: 3.7
"""
@property
def _aggregation_command_class(self):
return _DatabaseAggregationCommand
@property
def _client(self):
return self._target.client
class ClusterChangeStream(DatabaseChangeStream):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead.
.. versionadded:: 3.7
"""
def _change_stream_options(self):
options = super(ClusterChangeStream, self)._change_stream_options()
options["allChangesForCluster"] = True
return options
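# --- Illustrative sketch (not part of this module's API) --------------------
# A minimal resume loop built on the public watch()/resume_token API
# described in the docstrings above. The ``client`` argument and the
# ``demo.events`` namespace are assumptions for the example.
def _example_resumable_watch(client):
    from pymongo.errors import PyMongoError
    resume_token = None
    while True:
        try:
            with client.demo.events.watch(
                    resume_after=resume_token) as stream:
                for change in stream:
                    # Cache the token so a new stream can pick up where
                    # this one left off without missing events.
                    resume_token = stream.resume_token
                    print(change)
        except PyMongoError:
            # The stream resumes transient errors internally; anything
            # surfacing here is non-resumable, so re-create the stream
            # from the last seen token, or give up if we never got one.
            if resume_token is None:
                raise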

View File

@@ -0,0 +1,249 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse mongo client options."""
from bson.codec_options import _parse_codec_options
from pymongo.auth import _build_credentials_tuple
from pymongo.common import validate_boolean
from pymongo import common
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListeners
from pymongo.pool import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (make_read_preference,
read_pref_mode_from_name)
from pymongo.server_selectors import any_server_selector
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern
def _parse_credentials(username, password, database, options):
"""Parse authentication credentials."""
mechanism = options.get('authmechanism', 'DEFAULT' if username else None)
source = options.get('authsource')
if username or mechanism:
return _build_credentials_tuple(
mechanism, source, username, password, options, database)
return None
def _parse_read_preference(options):
"""Parse read preference options."""
if 'read_preference' in options:
return options['read_preference']
name = options.get('readpreference', 'primary')
mode = read_pref_mode_from_name(name)
tags = options.get('readpreferencetags')
max_staleness = options.get('maxstalenessseconds', -1)
return make_read_preference(mode, tags, max_staleness)
def _parse_write_concern(options):
"""Parse write concern options."""
concern = options.get('w')
wtimeout = options.get('wtimeoutms')
j = options.get('journal')
fsync = options.get('fsync')
return WriteConcern(concern, wtimeout, j, fsync)
def _parse_read_concern(options):
"""Parse read concern options."""
concern = options.get('readconcernlevel')
return ReadConcern(concern)
def _parse_ssl_options(options):
"""Parse ssl options."""
use_ssl = options.get('ssl')
if use_ssl is not None:
validate_boolean('ssl', use_ssl)
certfile = options.get('ssl_certfile')
keyfile = options.get('ssl_keyfile')
passphrase = options.get('ssl_pem_passphrase')
ca_certs = options.get('ssl_ca_certs')
cert_reqs = options.get('ssl_cert_reqs')
match_hostname = options.get('ssl_match_hostname', True)
crlfile = options.get('ssl_crlfile')
ssl_kwarg_keys = [k for k in options
if k.startswith('ssl_') and options[k]]
    if use_ssl is False and ssl_kwarg_keys:
        raise ConfigurationError("ssl has not been enabled but the "
                                 "following ssl parameters have been set: "
                                 "%s. Please set `ssl=True` or remove these "
                                 "options." % ', '.join(ssl_kwarg_keys))
if ssl_kwarg_keys and use_ssl is None:
# ssl options imply ssl = True
use_ssl = True
if use_ssl is True:
ctx = get_ssl_context(
certfile,
keyfile,
passphrase,
ca_certs,
cert_reqs,
crlfile,
match_hostname)
return ctx, match_hostname
return None, match_hostname
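# Illustrative sketch of the helper above (the option dicts are
# hypothetical, mirroring already-parsed URI/keyword options):
def _example_parse_ssl_options():
    # With no ssl-related options, no SSLContext is created.
    ctx, match_hostname = _parse_ssl_options({})
    assert ctx is None and match_hostname is True
    # An explicit ssl=False combined with other ssl_* options is rejected.
    try:
        _parse_ssl_options({'ssl': False, 'ssl_ca_certs': '/path/ca.pem'})
    except ConfigurationError as exc:
        print(exc)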
def _parse_pool_options(options):
"""Parse connection pool options."""
max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE)
min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE)
max_idle_time_seconds = options.get(
'maxidletimems', common.MAX_IDLE_TIME_SEC)
if max_pool_size is not None and min_pool_size > max_pool_size:
raise ValueError("minPoolSize must be smaller or equal to maxPoolSize")
connect_timeout = options.get('connecttimeoutms', common.CONNECT_TIMEOUT)
socket_keepalive = options.get('socketkeepalive', True)
socket_timeout = options.get('sockettimeoutms')
wait_queue_timeout = options.get(
'waitqueuetimeoutms', common.WAIT_QUEUE_TIMEOUT)
wait_queue_multiple = options.get('waitqueuemultiple')
event_listeners = options.get('event_listeners')
appname = options.get('appname')
driver = options.get('driver')
compression_settings = CompressionSettings(
options.get('compressors', []),
options.get('zlibcompressionlevel', -1))
ssl_context, ssl_match_hostname = _parse_ssl_options(options)
return PoolOptions(max_pool_size,
min_pool_size,
max_idle_time_seconds,
connect_timeout, socket_timeout,
wait_queue_timeout, wait_queue_multiple,
ssl_context, ssl_match_hostname, socket_keepalive,
_EventListeners(event_listeners),
appname,
driver,
compression_settings)
class ClientOptions(object):
"""ClientOptions"""
def __init__(self, username, password, database, options):
self.__options = options
self.__codec_options = _parse_codec_options(options)
self.__credentials = _parse_credentials(
username, password, database, options)
self.__local_threshold_ms = options.get(
'localthresholdms', common.LOCAL_THRESHOLD_MS)
# self.__server_selection_timeout is in seconds. Must use full name for
# common.SERVER_SELECTION_TIMEOUT because it is set directly by tests.
self.__server_selection_timeout = options.get(
'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT)
self.__pool_options = _parse_pool_options(options)
self.__read_preference = _parse_read_preference(options)
self.__replica_set_name = options.get('replicaset')
self.__write_concern = _parse_write_concern(options)
self.__read_concern = _parse_read_concern(options)
self.__connect = options.get('connect')
self.__heartbeat_frequency = options.get(
'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES)
self.__retry_reads = options.get('retryreads', common.RETRY_READS)
self.__server_selector = options.get(
'server_selector', any_server_selector)
self.__auto_encryption_opts = options.get('auto_encryption_opts')
@property
def _options(self):
"""The original options used to create this ClientOptions."""
return self.__options
@property
def connect(self):
"""Whether to begin discovering a MongoDB topology automatically."""
return self.__connect
@property
def codec_options(self):
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options
@property
def credentials(self):
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def local_threshold_ms(self):
"""The local threshold for this instance."""
return self.__local_threshold_ms
@property
def server_selection_timeout(self):
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self):
return self.__server_selector
@property
def heartbeat_frequency(self):
"""The monitoring frequency in seconds."""
return self.__heartbeat_frequency
@property
def pool_options(self):
"""A :class:`~pymongo.pool.PoolOptions` instance."""
return self.__pool_options
@property
def read_preference(self):
"""A read preference instance."""
return self.__read_preference
@property
def replica_set_name(self):
"""Replica set name or None."""
return self.__replica_set_name
@property
def write_concern(self):
"""A :class:`~pymongo.write_concern.WriteConcern` instance."""
return self.__write_concern
@property
def read_concern(self):
"""A :class:`~pymongo.read_concern.ReadConcern` instance."""
return self.__read_concern
@property
def retry_writes(self):
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self):
"""If this instance should retry supported read operations."""
return self.__retry_reads
@property
def auto_encryption_opts(self):
"""A :class:`~pymongo.encryption.AutoEncryptionOpts` or None."""
return self.__auto_encryption_opts
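# Illustrative sketch: ClientOptions is normally constructed by MongoClient
# from validated URI options; building one directly looks like this (the
# option values are assumptions for the example):
def _example_client_options():
    opts = ClientOptions(None, None, None, {
        'replicaset': 'rs0',
        'retrywrites': True,
        'maxpoolsize': 50,
    })
    print(opts.replica_set_name, opts.retry_writes,
          opts.pool_options.max_pool_size)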

View File

@@ -0,0 +1,908 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logical sessions for ordering sequential operations.
Requires MongoDB 3.6.
.. versionadded:: 3.6
Causally Consistent Reads
=========================
.. code-block:: python
with client.start_session(causal_consistency=True) as session:
collection = client.db.collection
collection.update_one({'_id': 1}, {'$set': {'x': 10}}, session=session)
secondary_c = collection.with_options(
read_preference=ReadPreference.SECONDARY)
# A secondary read waits for replication of the write.
secondary_c.find_one({'_id': 1}, session=session)
If `causal_consistency` is True (the default), read operations that use
the session are causally after previous read and write operations. Using a
causally consistent session, an application can read its own writes and is
guaranteed monotonic reads, even when reading from replica set secondaries.
.. mongodoc:: causal-consistency
.. _transactions-ref:
Transactions
============
MongoDB 4.0 adds support for transactions on replica set primaries. A
transaction is associated with a :class:`ClientSession`. To start a transaction
on a session, use :meth:`ClientSession.start_transaction` in a with-statement.
Then, execute an operation within the transaction by passing the session to the
operation:
.. code-block:: python
orders = client.db.orders
inventory = client.db.inventory
with client.start_session() as session:
with session.start_transaction():
orders.insert_one({"sku": "abc123", "qty": 100}, session=session)
inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}},
{"$inc": {"qty": -100}}, session=session)
Upon normal completion of the ``with session.start_transaction()`` block, the
transaction automatically calls :meth:`ClientSession.commit_transaction`.
If the block exits with an exception, the transaction automatically calls
:meth:`ClientSession.abort_transaction`.
For multi-document transactions, you can only specify read/write (CRUD)
operations on existing collections. For example, a multi-document transaction
cannot include create or drop collection/index operations, nor an insert
operation that would result in the creation of a new collection.
A session may only have a single active transaction at a time; multiple
transactions on the same session can be executed in sequence.
.. versionadded:: 3.7
Sharded Transactions
^^^^^^^^^^^^^^^^^^^^
PyMongo 3.9 adds support for transactions on sharded clusters running MongoDB
4.2. Sharded transactions have the same API as replica set transactions.
When running a transaction against a sharded cluster, the session is
pinned to the mongos server selected for the first operation in the
transaction. All subsequent operations that are part of the same transaction
are routed to the same mongos server. When the transaction is completed, by
running either commitTransaction or abortTransaction, the session is unpinned.
.. versionadded:: 3.9
.. mongodoc:: transactions
Classes
=======
"""
import collections
import sys
import uuid
from bson.binary import Binary
from bson.int64 import Int64
from bson.py3compat import abc, integer_types, reraise_instance
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo import monotonic
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
InvalidOperation,
OperationFailure,
PyMongoError,
ServerSelectionTimeoutError,
WTimeoutError)
from pymongo.helpers import _RETRYABLE_ERROR_CODES
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.write_concern import WriteConcern
class SessionOptions(object):
"""Options for a new :class:`ClientSession`.
:Parameters:
- `causal_consistency` (optional): If True (the default), read
operations are causally ordered within the session.
- `default_transaction_options` (optional): The default
TransactionOptions to use for transactions started on this session.
"""
def __init__(self,
causal_consistency=True,
default_transaction_options=None):
self._causal_consistency = causal_consistency
if default_transaction_options is not None:
if not isinstance(default_transaction_options, TransactionOptions):
raise TypeError(
"default_transaction_options must be an instance of "
"pymongo.client_session.TransactionOptions, not: %r" %
(default_transaction_options,))
self._default_transaction_options = default_transaction_options
@property
def causal_consistency(self):
"""Whether causal consistency is configured."""
return self._causal_consistency
@property
def default_transaction_options(self):
"""The default TransactionOptions to use for transactions started on
this session.
.. versionadded:: 3.7
"""
return self._default_transaction_options
class TransactionOptions(object):
"""Options for :meth:`ClientSession.start_transaction`.
:Parameters:
      - `read_concern` (optional): The
        :class:`~pymongo.read_concern.ReadConcern` to use for this transaction.
        If ``None`` (the default) the :attr:`read_concern` of
        the :class:`MongoClient` is used.
      - `write_concern` (optional): The
        :class:`~pymongo.write_concern.WriteConcern` to use for this
        transaction. If ``None`` (the default) the :attr:`write_concern` of
        the :class:`MongoClient` is used.
- `read_preference` (optional): The read preference to use. If
``None`` (the default) the :attr:`read_preference` of this
:class:`MongoClient` is used. See :mod:`~pymongo.read_preferences`
for options. Transactions which read must use
:attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.
- `max_commit_time_ms` (optional): The maximum amount of time to allow a
single commitTransaction command to run. This option is an alias for
        the maxTimeMS option on the commitTransaction command. If ``None`` (the
default) maxTimeMS is not used.
.. versionchanged:: 3.9
Added the ``max_commit_time_ms`` option.
.. versionadded:: 3.7
"""
def __init__(self, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
self._read_concern = read_concern
self._write_concern = write_concern
self._read_preference = read_preference
self._max_commit_time_ms = max_commit_time_ms
if read_concern is not None:
if not isinstance(read_concern, ReadConcern):
raise TypeError("read_concern must be an instance of "
"pymongo.read_concern.ReadConcern, not: %r" %
(read_concern,))
if write_concern is not None:
if not isinstance(write_concern, WriteConcern):
raise TypeError("write_concern must be an instance of "
"pymongo.write_concern.WriteConcern, not: %r" %
(write_concern,))
if not write_concern.acknowledged:
raise ConfigurationError(
"transactions do not support unacknowledged write concern"
": %r" % (write_concern,))
if read_preference is not None:
if not isinstance(read_preference, _ServerMode):
raise TypeError("%r is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options." % (read_preference,))
if max_commit_time_ms is not None:
if not isinstance(max_commit_time_ms, integer_types):
raise TypeError(
"max_commit_time_ms must be an integer or None")
@property
def read_concern(self):
"""This transaction's :class:`~pymongo.read_concern.ReadConcern`."""
return self._read_concern
@property
def write_concern(self):
"""This transaction's :class:`~pymongo.write_concern.WriteConcern`."""
return self._write_concern
@property
def read_preference(self):
"""This transaction's :class:`~pymongo.read_preferences.ReadPreference`.
"""
return self._read_preference
@property
def max_commit_time_ms(self):
"""The maxTimeMS to use when running a commitTransaction command.
.. versionadded:: 3.9
"""
return self._max_commit_time_ms
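# Illustrative sketch: passing default TransactionOptions to start_session
# so every transaction started on the session inherits them (``client`` is
# an assumed connected MongoClient; database/collection names are
# illustrative):
def _example_session_defaults(client):
    txn_opts = TransactionOptions(
        read_concern=ReadConcern("snapshot"),
        write_concern=WriteConcern(w="majority"),
        max_commit_time_ms=5000)
    with client.start_session(
            default_transaction_options=txn_opts) as session:
        with session.start_transaction():
            client.db.coll.insert_one({"x": 1}, session=session)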
def _validate_session_write_concern(session, write_concern):
"""Validate that an explicit session is not used with an unack'ed write.
Returns the session to use for the next operation.
"""
if session:
if write_concern is not None and not write_concern.acknowledged:
# For unacknowledged writes without an explicit session,
# drivers SHOULD NOT use an implicit session. If a driver
# creates an implicit session for unacknowledged writes
# without an explicit session, the driver MUST NOT send the
# session ID.
if session._implicit:
return None
else:
raise ConfigurationError(
'Explicit sessions are incompatible with '
'unacknowledged write concern: %r' % (
write_concern,))
return session
class _TransactionContext(object):
"""Internal transaction context manager for start_transaction."""
def __init__(self, session):
self.__session = session
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.__session._in_transaction:
if exc_val is None:
self.__session.commit_transaction()
else:
self.__session.abort_transaction()
class _TxnState(object):
NONE = 1
STARTING = 2
IN_PROGRESS = 3
COMMITTED = 4
COMMITTED_EMPTY = 5
ABORTED = 6
class _Transaction(object):
"""Internal class to hold transaction information in a ClientSession."""
def __init__(self, opts):
self.opts = opts
self.state = _TxnState.NONE
self.sharded = False
self.pinned_address = None
self.recovery_token = None
def active(self):
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
def reset(self):
self.state = _TxnState.NONE
self.sharded = False
self.pinned_address = None
self.recovery_token = None
def _reraise_with_unknown_commit(exc):
"""Re-raise an exception with the UnknownTransactionCommitResult label."""
exc._add_error_label("UnknownTransactionCommitResult")
reraise_instance(exc, trace=sys.exc_info()[2])
def _max_time_expired_error(exc):
"""Return true if exc is a MaxTimeMSExpired error."""
return isinstance(exc, OperationFailure) and exc.code == 50
# From the transactions spec, all the retryable writes errors plus
# WriteConcernFailed.
_UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset([
64, # WriteConcernFailed
50, # MaxTimeMSExpired
])
# From the Convenient API for Transactions spec, with_transaction must
# halt retries after 120 seconds.
# This limit is non-configurable and was chosen to be twice the 60 second
# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter.
_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120
def _within_time_limit(start_time):
"""Are we within the with_transaction retry limit?"""
return monotonic.time() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
class ClientSession(object):
"""A session for ordering sequential operations."""
def __init__(self, client, server_session, options, authset, implicit):
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
self._client = client
self._server_session = server_session
self._options = options
self._authset = authset
self._cluster_time = None
self._operation_time = None
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None)
def end_session(self):
"""Finish this session. If a transaction has started, abort it.
It is an error to use the session after the session has ended.
"""
self._end_session(lock=True)
def _end_session(self, lock):
if self._server_session is not None:
try:
if self._in_transaction:
self.abort_transaction()
finally:
self._client._return_server_session(self._server_session, lock)
self._server_session = None
def _check_ended(self):
if self._server_session is None:
raise InvalidOperation("Cannot use ended session")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._end_session(lock=True)
@property
def client(self):
"""The :class:`~pymongo.mongo_client.MongoClient` this session was
created from.
"""
return self._client
@property
def options(self):
"""The :class:`SessionOptions` this session was created with."""
return self._options
@property
def session_id(self):
"""A BSON document, the opaque server session identifier."""
self._check_ended()
return self._server_session.session_id
@property
def cluster_time(self):
"""The cluster time returned by the last operation executed
in this session.
"""
return self._cluster_time
@property
def operation_time(self):
"""The operation time returned by the last operation executed
in this session.
"""
return self._operation_time
def _inherit_option(self, name, val):
"""Return the inherited TransactionOption value."""
if val:
return val
txn_opts = self.options.default_transaction_options
val = txn_opts and getattr(txn_opts, name)
if val:
return val
return getattr(self.client, name)
def with_transaction(self, callback, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
"""Execute a callback in a transaction.
This method starts a transaction on this session, executes ``callback``
once, and then commits the transaction. For example::
def callback(session):
orders = session.client.db.orders
inventory = session.client.db.inventory
orders.insert_one({"sku": "abc123", "qty": 100}, session=session)
inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}},
{"$inc": {"qty": -100}}, session=session)
with client.start_session() as session:
session.with_transaction(callback)
To pass arbitrary arguments to the ``callback``, wrap your callable
with a ``lambda`` like this::
def callback(session, custom_arg, custom_kwarg=None):
# Transaction operations...
with client.start_session() as session:
session.with_transaction(
lambda s: callback(s, "custom_arg", custom_kwarg=1))
In the event of an exception, ``with_transaction`` may retry the commit
or the entire transaction, therefore ``callback`` may be invoked
multiple times by a single call to ``with_transaction``. Developers
        should be mindful of this possibility when writing a ``callback`` that
modifies application state or has any other side-effects.
Note that even when the ``callback`` is invoked multiple times,
``with_transaction`` ensures that the transaction will be committed
at-most-once on the server.
The ``callback`` should not attempt to start new transactions, but
should simply run operations meant to be contained within a
transaction. The ``callback`` should also not commit the transaction;
this is handled automatically by ``with_transaction``. If the
``callback`` does commit or abort the transaction without error,
however, ``with_transaction`` will return without taking further
action.
When ``callback`` raises an exception, ``with_transaction``
automatically aborts the current transaction. When ``callback`` or
:meth:`~ClientSession.commit_transaction` raises an exception that
includes the ``"TransientTransactionError"`` error label,
``with_transaction`` starts a new transaction and re-executes
the ``callback``.
When :meth:`~ClientSession.commit_transaction` raises an exception with
the ``"UnknownTransactionCommitResult"`` error label,
``with_transaction`` retries the commit until the result of the
transaction is known.
This method will cease retrying after 120 seconds has elapsed. This
timeout is not configurable and any exception raised by the
``callback`` or by :meth:`ClientSession.commit_transaction` after the
timeout is reached will be re-raised. Applications that desire a
different timeout duration should not use this method.
:Parameters:
- `callback`: The callable ``callback`` to run inside a transaction.
            The callable must accept a single argument, this session. Note
            that under certain error conditions the callback may be run
            multiple times.
- `read_concern` (optional): The
:class:`~pymongo.read_concern.ReadConcern` to use for this
transaction.
- `write_concern` (optional): The
:class:`~pymongo.write_concern.WriteConcern` to use for this
transaction.
          - `read_preference` (optional): The read preference to use for this
            transaction. If ``None`` (the default) the :attr:`read_preference`
            of this :class:`MongoClient` is used. See
            :mod:`~pymongo.read_preferences` for options.
          - `max_commit_time_ms` (optional): The maximum amount of time to
            allow a single commitTransaction command to run.
:Returns:
The return value of the ``callback``.
.. versionadded:: 3.9
"""
start_time = monotonic.time()
while True:
self.start_transaction(
read_concern, write_concern, read_preference,
max_commit_time_ms)
try:
ret = callback(self)
except Exception as exc:
if self._in_transaction:
self.abort_transaction()
if (isinstance(exc, PyMongoError) and
exc.has_error_label("TransientTransactionError") and
_within_time_limit(start_time)):
# Retry the entire transaction.
continue
raise
if self._transaction.state in (
_TxnState.NONE, _TxnState.COMMITTED, _TxnState.ABORTED):
# Assume callback intentionally ended the transaction.
return ret
while True:
try:
self.commit_transaction()
except PyMongoError as exc:
if (exc.has_error_label("UnknownTransactionCommitResult")
and _within_time_limit(start_time)
and not _max_time_expired_error(exc)):
# Retry the commit.
continue
if (exc.has_error_label("TransientTransactionError") and
_within_time_limit(start_time)):
# Retry the entire transaction.
break
raise
# Commit succeeded.
return ret
def start_transaction(self, read_concern=None, write_concern=None,
read_preference=None, max_commit_time_ms=None):
"""Start a multi-statement transaction.
Takes the same arguments as :class:`TransactionOptions`.
.. versionchanged:: 3.9
Added the ``max_commit_time_ms`` option.
.. versionadded:: 3.7
"""
self._check_ended()
if self._in_transaction:
raise InvalidOperation("Transaction already in progress")
read_concern = self._inherit_option("read_concern", read_concern)
write_concern = self._inherit_option("write_concern", write_concern)
read_preference = self._inherit_option(
"read_preference", read_preference)
if max_commit_time_ms is None:
opts = self.options.default_transaction_options
if opts:
max_commit_time_ms = opts.max_commit_time_ms
self._transaction.opts = TransactionOptions(
read_concern, write_concern, read_preference, max_commit_time_ms)
self._transaction.reset()
self._transaction.state = _TxnState.STARTING
self._start_retryable_write()
return _TransactionContext(self)
def commit_transaction(self):
"""Commit a multi-statement transaction.
.. versionadded:: 3.7
"""
self._check_ended()
retry = False
state = self._transaction.state
if state is _TxnState.NONE:
raise InvalidOperation("No transaction started")
elif state in (_TxnState.STARTING, _TxnState.COMMITTED_EMPTY):
# Server transaction was never started, no need to send a command.
self._transaction.state = _TxnState.COMMITTED_EMPTY
return
elif state is _TxnState.ABORTED:
raise InvalidOperation(
"Cannot call commitTransaction after calling abortTransaction")
elif state is _TxnState.COMMITTED:
# We're explicitly retrying the commit, move the state back to
# "in progress" so that _in_transaction returns true.
self._transaction.state = _TxnState.IN_PROGRESS
retry = True
try:
self._finish_transaction_with_retry("commitTransaction", retry)
except ConnectionFailure as exc:
# We do not know if the commit was successfully applied on the
# server or if it satisfied the provided write concern, set the
# unknown commit error label.
exc._remove_error_label("TransientTransactionError")
_reraise_with_unknown_commit(exc)
except WTimeoutError as exc:
# We do not know if the commit has satisfied the provided write
# concern, add the unknown commit error label.
_reraise_with_unknown_commit(exc)
except OperationFailure as exc:
if exc.code not in _UNKNOWN_COMMIT_ERROR_CODES:
                # The server reports errorLabels in this case.
raise
# We do not know if the commit was successfully applied on the
# server or if it satisfied the provided write concern, set the
# unknown commit error label.
_reraise_with_unknown_commit(exc)
finally:
self._transaction.state = _TxnState.COMMITTED
def abort_transaction(self):
"""Abort a multi-statement transaction.
.. versionadded:: 3.7
"""
self._check_ended()
state = self._transaction.state
if state is _TxnState.NONE:
raise InvalidOperation("No transaction started")
elif state is _TxnState.STARTING:
# Server transaction was never started, no need to send a command.
self._transaction.state = _TxnState.ABORTED
return
elif state is _TxnState.ABORTED:
raise InvalidOperation("Cannot call abortTransaction twice")
elif state in (_TxnState.COMMITTED, _TxnState.COMMITTED_EMPTY):
raise InvalidOperation(
"Cannot call abortTransaction after calling commitTransaction")
try:
self._finish_transaction_with_retry("abortTransaction", False)
except (OperationFailure, ConnectionFailure):
# The transactions spec says to ignore abortTransaction errors.
pass
finally:
self._transaction.state = _TxnState.ABORTED
    def _finish_transaction_with_retry(self, command_name, explicit_retry):
        """Run commit or abort with one retry after any retryable error.
        :Parameters:
          - `command_name`: Either "commitTransaction" or "abortTransaction".
          - `explicit_retry`: True when this is an explicit commit retry
            attempt, i.e. the application called session.commit_transaction()
            twice.
        """
        # This can be refactored with MongoClient._retry_with_session.
        try:
            return self._finish_transaction(command_name, explicit_retry)
except ServerSelectionTimeoutError:
raise
except ConnectionFailure as exc:
try:
return self._finish_transaction(command_name, True)
except ServerSelectionTimeoutError:
# Raise the original error so the application can infer that
# an attempt was made.
raise exc
except OperationFailure as exc:
if exc.code not in _RETRYABLE_ERROR_CODES:
raise
try:
return self._finish_transaction(command_name, True)
except ServerSelectionTimeoutError:
# Raise the original error so the application can infer that
# an attempt was made.
raise exc
def _finish_transaction(self, command_name, retrying):
opts = self._transaction.opts
wc = opts.write_concern
cmd = SON([(command_name, 1)])
if command_name == "commitTransaction":
if opts.max_commit_time_ms:
cmd['maxTimeMS'] = opts.max_commit_time_ms
# Transaction spec says that after the initial commit attempt,
# subsequent commitTransaction commands should be upgraded to use
# w:"majority" and set a default value of 10 seconds for wtimeout.
if retrying:
wc_doc = wc.document
wc_doc["w"] = "majority"
wc_doc.setdefault("wtimeout", 10000)
wc = WriteConcern(**wc_doc)
if self._transaction.recovery_token:
cmd['recoveryToken'] = self._transaction.recovery_token
with self._client._socket_for_writes(self) as sock_info:
return self._client.admin._command(
sock_info,
cmd,
session=self,
write_concern=wc,
parse_write_concern_error=True)
def _advance_cluster_time(self, cluster_time):
"""Internal cluster time helper."""
if self._cluster_time is None:
self._cluster_time = cluster_time
elif cluster_time is not None:
if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]:
self._cluster_time = cluster_time
def advance_cluster_time(self, cluster_time):
"""Update the cluster time for this session.
:Parameters:
- `cluster_time`: The
:data:`~pymongo.client_session.ClientSession.cluster_time` from
another `ClientSession` instance.
"""
if not isinstance(cluster_time, abc.Mapping):
            raise TypeError(
                "cluster_time must be an instance of collections.Mapping")
if not isinstance(cluster_time.get("clusterTime"), Timestamp):
raise ValueError("Invalid cluster_time")
self._advance_cluster_time(cluster_time)
def _advance_operation_time(self, operation_time):
"""Internal operation time helper."""
if self._operation_time is None:
self._operation_time = operation_time
elif operation_time is not None:
if operation_time > self._operation_time:
self._operation_time = operation_time
def advance_operation_time(self, operation_time):
"""Update the operation time for this session.
:Parameters:
- `operation_time`: The
:data:`~pymongo.client_session.ClientSession.operation_time` from
another `ClientSession` instance.
"""
if not isinstance(operation_time, Timestamp):
raise TypeError("operation_time must be an instance "
"of bson.timestamp.Timestamp")
self._advance_operation_time(operation_time)
def _process_response(self, reply):
"""Process a response to a command that was run with this session."""
self._advance_cluster_time(reply.get('$clusterTime'))
self._advance_operation_time(reply.get('operationTime'))
if self._in_transaction and self._transaction.sharded:
recovery_token = reply.get('recoveryToken')
if recovery_token:
self._transaction.recovery_token = recovery_token
@property
def has_ended(self):
"""True if this session is finished."""
return self._server_session is None
@property
def _in_transaction(self):
"""True if this session has an active multi-statement transaction."""
return self._transaction.active()
@property
def _pinned_address(self):
"""The mongos address this transaction was created on."""
if self._transaction.active():
return self._transaction.pinned_address
return None
def _pin_mongos(self, server):
"""Pin this session to the given mongos Server."""
self._transaction.sharded = True
self._transaction.pinned_address = server.description.address
def _unpin_mongos(self):
"""Unpin this session from any pinned mongos address."""
self._transaction.pinned_address = None
def _txn_read_preference(self):
"""Return read preference of this transaction or None."""
if self._in_transaction:
return self._transaction.opts.read_preference
return None
def _apply_to(self, command, is_retryable, read_preference):
self._check_ended()
self._server_session.last_use = monotonic.time()
command['lsid'] = self._server_session.session_id
if not self._in_transaction:
self._transaction.reset()
if is_retryable:
command['txnNumber'] = self._server_session.transaction_id
return
if self._in_transaction:
if read_preference != ReadPreference.PRIMARY:
raise InvalidOperation(
'read preference in a transaction must be primary, not: '
'%r' % (read_preference,))
if self._transaction.state == _TxnState.STARTING:
# First command begins a new transaction.
self._transaction.state = _TxnState.IN_PROGRESS
command['startTransaction'] = True
if self._transaction.opts.read_concern:
rc = self._transaction.opts.read_concern.document
else:
rc = {}
if (self.options.causal_consistency
and self.operation_time is not None):
rc['afterClusterTime'] = self.operation_time
if rc:
command['readConcern'] = rc
command['txnNumber'] = self._server_session.transaction_id
command['autocommit'] = False
def _start_retryable_write(self):
self._check_ended()
self._server_session.inc_transaction_id()
class _ServerSession(object):
def __init__(self):
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.
self.session_id = {'id': Binary(uuid.uuid4().bytes, 4)}
self.last_use = monotonic.time()
self._transaction_id = 0
self.dirty = False
def mark_dirty(self):
"""Mark this session as dirty.
A server session is marked dirty when a command fails with a network
error. Dirty sessions are later discarded from the server session pool.
"""
self.dirty = True
def timed_out(self, session_timeout_minutes):
idle_seconds = monotonic.time() - self.last_use
# Timed out if we have less than a minute to live.
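        # e.g. with the server default logicalSessionTimeoutMinutes of 30,
        # a session idle for more than 29 minutes ((30 - 1) * 60 seconds)
        # is considered timed out and is not reused.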
return idle_seconds > (session_timeout_minutes - 1) * 60
@property
def transaction_id(self):
"""Positive 64-bit integer."""
return Int64(self._transaction_id)
def inc_transaction_id(self):
self._transaction_id += 1
class _ServerSessionPool(collections.deque):
"""Pool of _ServerSession objects.
This class is not thread-safe, access it while holding the Topology lock.
"""
def pop_all(self):
ids = []
while self:
ids.append(self.pop().session_id)
return ids
def get_server_session(self, session_timeout_minutes):
# Although the Driver Sessions Spec says we only clear stale sessions
# in return_server_session, PyMongo can't take a lock when returning
# sessions from a __del__ method (like in Cursor.__die), so it can't
# clear stale sessions there. In case many sessions were returned via
# __del__, check for stale sessions here too.
self._clear_stale(session_timeout_minutes)
# The most recently used sessions are on the left.
while self:
s = self.popleft()
if not s.timed_out(session_timeout_minutes):
return s
return _ServerSession()
def return_server_session(self, server_session, session_timeout_minutes):
self._clear_stale(session_timeout_minutes)
if not server_session.timed_out(session_timeout_minutes):
self.return_server_session_no_lock(server_session)
def return_server_session_no_lock(self, server_session):
if not server_session.dirty:
self.appendleft(server_session)
def _clear_stale(self, session_timeout_minutes):
# Clear stale sessions. The least recently used are on the right.
while self:
if self[-1].timed_out(session_timeout_minutes):
self.pop()
else:
# The remaining sessions also haven't timed out.
break

View File

@@ -0,0 +1,225 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with `collations`_.
.. _collations: http://userguide.icu-project.org/collation/concepts
"""
from pymongo import common
class CollationStrength(object):
"""
An enum that defines values for `strength` on a
:class:`~pymongo.collation.Collation`.
"""
PRIMARY = 1
"""Differentiate base (unadorned) characters."""
SECONDARY = 2
"""Differentiate character accents."""
TERTIARY = 3
"""Differentiate character case."""
QUATERNARY = 4
"""Differentiate words with and without punctuation."""
IDENTICAL = 5
"""Differentiate unicode code point (characters are exactly identical)."""
class CollationAlternate(object):
"""
An enum that defines values for `alternate` on a
:class:`~pymongo.collation.Collation`.
"""
NON_IGNORABLE = 'non-ignorable'
"""Spaces and punctuation are treated as base characters."""
SHIFTED = 'shifted'
"""Spaces and punctuation are *not* considered base characters.
Spaces and punctuation are distinguished only when the
:class:`~pymongo.collation.Collation` strength is at least
:data:`~pymongo.collation.CollationStrength.QUATERNARY`.
"""
class CollationMaxVariable(object):
"""
An enum that defines values for `max_variable` on a
:class:`~pymongo.collation.Collation`.
"""
PUNCT = 'punct'
"""Both punctuation and spaces are ignored."""
SPACE = 'space'
"""Spaces alone are ignored."""
class CollationCaseFirst(object):
"""
An enum that defines values for `case_first` on a
:class:`~pymongo.collation.Collation`.
"""
UPPER = 'upper'
"""Sort uppercase characters first."""
LOWER = 'lower'
"""Sort lowercase characters first."""
OFF = 'off'
"""Default for locale or collation strength."""
class Collation(object):
"""Collation
:Parameters:
- `locale`: (string) The locale of the collation. This should be a string
that identifies an `ICU locale ID` exactly. For example, ``en_US`` is
valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB
documentation for a list of supported locales.
- `caseLevel`: (optional) If ``True``, turn on case sensitivity if
`strength` is 1 or 2 (case sensitivity is implied if `strength` is
greater than 2). Defaults to ``False``.
- `caseFirst`: (optional) Specify that either uppercase or lowercase
characters take precedence. Must be one of the following values:
* :data:`~CollationCaseFirst.UPPER`
* :data:`~CollationCaseFirst.LOWER`
* :data:`~CollationCaseFirst.OFF` (the default)
- `strength`: (optional) Specify the comparison strength. This is also
known as the ICU comparison level. This must be one of the following
values:
* :data:`~CollationStrength.PRIMARY`
* :data:`~CollationStrength.SECONDARY`
* :data:`~CollationStrength.TERTIARY` (the default)
* :data:`~CollationStrength.QUATERNARY`
* :data:`~CollationStrength.IDENTICAL`
Each successive level builds upon the previous. For example, a
`strength` of :data:`~CollationStrength.SECONDARY` differentiates
characters based both on the unadorned base character and its accents.
- `numericOrdering`: (optional) If ``True``, order numbers numerically
instead of in collation order (defaults to ``False``).
- `alternate`: (optional) Specify whether spaces and punctuation are
considered base characters. This must be one of the following values:
* :data:`~CollationAlternate.NON_IGNORABLE` (the default)
* :data:`~CollationAlternate.SHIFTED`
- `maxVariable`: (optional) When `alternate` is
:data:`~CollationAlternate.SHIFTED`, this option specifies what
characters may be ignored. This must be one of the following values:
* :data:`~CollationMaxVariable.PUNCT` (the default)
* :data:`~CollationMaxVariable.SPACE`
- `normalization`: (optional) If ``True``, normalizes text into Unicode
NFD. Defaults to ``False``.
- `backwards`: (optional) If ``True``, accents on characters are
considered from the back of the word to the front, as it is done in some
French dictionary ordering traditions. Defaults to ``False``.
- `kwargs`: (optional) Keyword arguments supplying any additional options
to be sent with this Collation object.
.. versionadded:: 3.4
"""
__slots__ = ("__document",)
def __init__(self, locale,
caseLevel=None,
caseFirst=None,
strength=None,
numericOrdering=None,
alternate=None,
maxVariable=None,
normalization=None,
backwards=None,
**kwargs):
locale = common.validate_string('locale', locale)
self.__document = {'locale': locale}
if caseLevel is not None:
self.__document['caseLevel'] = common.validate_boolean(
'caseLevel', caseLevel)
if caseFirst is not None:
self.__document['caseFirst'] = common.validate_string(
'caseFirst', caseFirst)
if strength is not None:
self.__document['strength'] = common.validate_integer(
'strength', strength)
if numericOrdering is not None:
self.__document['numericOrdering'] = common.validate_boolean(
'numericOrdering', numericOrdering)
if alternate is not None:
self.__document['alternate'] = common.validate_string(
'alternate', alternate)
if maxVariable is not None:
self.__document['maxVariable'] = common.validate_string(
'maxVariable', maxVariable)
if normalization is not None:
self.__document['normalization'] = common.validate_boolean(
'normalization', normalization)
if backwards is not None:
self.__document['backwards'] = common.validate_boolean(
'backwards', backwards)
self.__document.update(kwargs)
@property
def document(self):
"""The document representation of this collation.
.. note::
:class:`Collation` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`Collation`.
"""
return self.__document.copy()
def __repr__(self):
document = self.document
return 'Collation(%s)' % (
', '.join('%s=%r' % (key, document[key]) for key in document),)
def __eq__(self, other):
if isinstance(other, Collation):
return self.document == other.document
return NotImplemented
def __ne__(self, other):
return not self == other
def validate_collation_or_none(value):
if value is None:
return None
if isinstance(value, Collation):
return value.document
if isinstance(value, dict):
return value
raise TypeError(
'collation must be a dict, an instance of collation.Collation, '
'or None.')
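# Illustrative usage sketch (the collection and field names are
# assumptions): a case-insensitive index and a query that uses it. With
# strength SECONDARY, base characters and accents are compared but case
# is ignored.
def _example_case_insensitive(collection):
    collation = Collation(locale='en_US',
                          strength=CollationStrength.SECONDARY)
    collection.create_index([('name', 1)], collation=collation)
    return list(collection.find({'name': 'alice'}, collation=collation))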

File diff suppressed because it is too large

View File

@@ -0,0 +1,308 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CommandCursor class to iterate over command results."""
from collections import deque
from bson.py3compat import integer_types
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
NotMasterError,
OperationFailure)
from pymongo.message import (_CursorAddress,
_GetMore,
_RawBatchGetMore)
class CommandCursor(object):
"""A cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(self, collection, cursor_info, address, retrieved=0,
batch_size=0, max_await_time_ms=None, session=None,
explicit_session=False):
"""Create a new command cursor.
The parameter 'retrieved' is unused.
"""
self.__collection = collection
self.__id = cursor_info['id']
self.__data = deque(cursor_info['firstBatch'])
self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken')
self.__address = address
self.__batch_size = batch_size
self.__max_await_time_ms = max_await_time_ms
self.__session = session
self.__explicit_session = explicit_session
self.__killed = (self.__id == 0)
if self.__killed:
self.__end_session(True)
if "ns" in cursor_info:
self.__ns = cursor_info["ns"]
else:
self.__ns = collection.full_name
self.batch_size(batch_size)
if (not isinstance(max_await_time_ms, integer_types)
and max_await_time_ms is not None):
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self):
if self.__id and not self.__killed:
self.__die()
def __die(self, synchronous=False):
"""Closes this cursor.
"""
already_killed = self.__killed
self.__killed = True
if self.__id and not already_killed:
address = _CursorAddress(
self.__address, self.__collection.full_name)
if synchronous:
self.__collection.database.client._close_cursor_now(
self.__id, address, session=self.__session)
else:
# The cursor will be closed later in a different session.
self.__collection.database.client._close_cursor(
self.__id, address)
self.__end_session(synchronous)
def __end_session(self, synchronous):
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
def close(self):
"""Explicitly close / kill this cursor.
"""
self.__die(True)
def batch_size(self, batch_size):
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
        .. note:: batch_size cannot override MongoDB's internal limits on the
           amount of data it will return to the client in a single batch (i.e.
           if you set batch size to 1,000,000,000, MongoDB will currently only
           return 4-16MB of results per batch).
Raises :exc:`TypeError` if `batch_size` is not an integer.
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
:Parameters:
- `batch_size`: The size of each batch of results requested.
"""
if not isinstance(batch_size, integer_types):
raise TypeError("batch_size must be an integer")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
        # A batch size of 1 is translated to 2, since in the legacy wire
        # protocol a request for a single document closes the cursor (it
        # behaves like a limit).
        self.__batch_size = 2 if batch_size == 1 else batch_size
return self
def _has_next(self):
"""Returns `True` if the cursor has documents remaining from the
previous batch."""
return len(self.__data) > 0
@property
def _post_batch_resume_token(self):
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore."""
return self.__postbatchresumetoken
def __send_message(self, operation):
"""Send a getmore message and handle the response.
"""
def kill():
self.__killed = True
self.__end_session(True)
client = self.__collection.database.client
try:
response = client._run_operation_with_response(
operation, self._unpack_response, address=self.__address)
except OperationFailure:
kill()
raise
except NotMasterError:
# Don't send kill cursors to another server after a "not master"
# error. It's completely pointless.
kill()
raise
except ConnectionFailure:
# Don't try to send kill cursors on another socket
# or to another server. It can cause a _pinValue
# assertion on some server releases if we get here
# due to a socket timeout.
kill()
raise
except Exception:
# Close the cursor
self.__die()
raise
from_command = response.from_command
reply = response.data
docs = response.docs
if from_command:
cursor = docs[0]['cursor']
documents = cursor['nextBatch']
self.__postbatchresumetoken = cursor.get('postBatchResumeToken')
self.__id = cursor['id']
else:
documents = docs
self.__id = reply.cursor_id
if self.__id == 0:
kill()
self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.unpack_response(cursor_id, codec_options, user_fields,
legacy_response)
def _refresh(self):
"""Refreshes the cursor with more data from the server.
Returns the length of self.__data after refresh. Will exit early if
self.__data is already non-empty. Raises OperationFailure when the
cursor cannot be refreshed due to an error on the query.
"""
if len(self.__data) or self.__killed:
return len(self.__data)
if self.__id: # Get More
dbname, collname = self.__ns.split('.', 1)
read_pref = self.__collection._read_preference_for(self.session)
self.__send_message(
self._getmore_class(dbname,
collname,
self.__batch_size,
self.__id,
self.__collection.codec_options,
read_pref,
self.__session,
self.__collection.database.client,
self.__max_await_time_ms,
False))
        else:  # Cursor id is zero; nothing else to return.
self.__killed = True
self.__end_session(True)
return len(self.__data)
@property
def alive(self):
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopIteration`. Best to use a for loop::
for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self.__data) or (not self.__killed))
@property
def cursor_id(self):
"""Returns the id of the cursor."""
return self.__id
@property
def address(self):
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self.__address
@property
def session(self):
"""The cursor's :class:`~pymongo.client_session.ClientSession`, or None.
.. versionadded:: 3.6
"""
if self.__explicit_session:
return self.__session
def __iter__(self):
return self
def next(self):
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
doc = self._try_next(True)
if doc is not None:
return doc
raise StopIteration
__next__ = next
def _try_next(self, get_more_allowed):
"""Advance the cursor blocking for at most one getMore command."""
if not len(self.__data) and not self.__killed and get_more_allowed:
self._refresh()
if len(self.__data):
coll = self.__collection
return coll.database._fix_outgoing(self.__data.popleft(), coll)
else:
return None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class RawBatchCommandCursor(CommandCursor):
_getmore_class = _RawBatchGetMore
def __init__(self, collection, cursor_info, address, retrieved=0,
batch_size=0, max_await_time_ms=None, session=None,
explicit_session=False):
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
see :meth:`~pymongo.collection.Collection.aggregate_raw_batches`
instead.
.. mongodoc:: cursors
"""
assert not cursor_info.get('firstBatch')
super(RawBatchCommandCursor, self).__init__(
collection, cursor_info, address, retrieved, batch_size,
max_await_time_ms, session, explicit_session)
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):
return response.raw_response(cursor_id)
def __getitem__(self, index):
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")

View File

@@ -0,0 +1,935 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Functions and classes common to multiple pymongo modules."""
import datetime
import warnings
from bson import SON
from bson.binary import (STANDARD, PYTHON_LEGACY,
JAVA_LEGACY, CSHARP_LEGACY)
from bson.codec_options import CodecOptions, TypeRegistry
from bson.py3compat import abc, integer_types, iteritems, string_type
from bson.raw_bson import RawBSONDocument
from pymongo.auth import MECHANISMS
from pymongo.compression_support import (validate_compressors,
validate_zlib_compression_level)
from pymongo.driver_info import DriverInfo
from pymongo.encryption_options import validate_auto_encryption_opts_or_none
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _validate_event_listeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _MONGOS_MODES, _ServerMode
from pymongo.ssl_support import (validate_cert_reqs,
validate_allow_invalid_certs)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
try:
from collections import OrderedDict
ORDERED_TYPES = (SON, OrderedDict)
except ImportError:
ORDERED_TYPES = (SON,)
# Defaults until we connect to a server and get updated limits.
MAX_BSON_SIZE = 16 * (1024 ** 2)
MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE
MIN_WIRE_VERSION = 0
MAX_WIRE_VERSION = 0
MAX_WRITE_BATCH_SIZE = 1000
# What this version of PyMongo supports.
MIN_SUPPORTED_SERVER_VERSION = "2.6"
MIN_SUPPORTED_WIRE_VERSION = 2
MAX_SUPPORTED_WIRE_VERSION = 8
# Frequency to call ismaster on servers, in seconds.
HEARTBEAT_FREQUENCY = 10
# Frequency to process kill-cursors, in seconds. See MongoClient.close_cursor.
KILL_CURSOR_FREQUENCY = 1
# Frequency to process events queue, in seconds.
EVENTS_QUEUE_FREQUENCY = 1
# How long to wait, in seconds, for a suitable server to be found before
# aborting an operation. For example, if the client attempts an insert
# during a replica set election, SERVER_SELECTION_TIMEOUT governs the
# longest it is willing to wait for a new primary to be found.
SERVER_SELECTION_TIMEOUT = 30
# Spec requires at least 500ms between ismaster calls.
MIN_HEARTBEAT_INTERVAL = 0.5
# Spec requires at least 60s between SRV rescans.
MIN_SRV_RESCAN_INTERVAL = 60
# Default connectTimeout in seconds.
CONNECT_TIMEOUT = 20.0
# Default value for maxPoolSize.
MAX_POOL_SIZE = 100
# Default value for minPoolSize.
MIN_POOL_SIZE = 0
# Default value for maxIdleTimeMS.
MAX_IDLE_TIME_MS = None
# Default value for maxIdleTimeMS in seconds.
MAX_IDLE_TIME_SEC = None
# Default value for waitQueueTimeoutMS in seconds.
WAIT_QUEUE_TIMEOUT = None
# Default value for localThresholdMS.
LOCAL_THRESHOLD_MS = 15
# Default value for retryWrites.
RETRY_WRITES = True
# Default value for retryReads.
RETRY_READS = True
# mongod/s 2.6 and above return code 59 when a command doesn't exist.
COMMAND_NOT_FOUND_CODES = (59,)
# Error codes to ignore if GridFS calls createIndex on a secondary
UNAUTHORIZED_CODES = (13, 16547, 16548)
# Maximum number of sessions to send in a single endSessions command.
# From the driver sessions spec.
_MAX_END_SESSIONS = 10000
def partition_node(node):
"""Split a host:port string into (host, int(port)) pair."""
host = node
port = 27017
idx = node.rfind(':')
if idx != -1:
host, port = node[:idx], int(node[idx + 1:])
if host.startswith('['):
host = host[1:-1]
return host, port
def clean_node(node):
"""Split and normalize a node name from an ismaster response."""
host, port = partition_node(node)
# Normalize hostname to lowercase, since DNS is case-insensitive:
# http://tools.ietf.org/html/rfc4343
# This prevents useless rediscovery if "foo.com" is in the seed list but
# "FOO.com" is in the ismaster response.
return host.lower(), port
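# Example: how partition_node/clean_node normalize seed list entries,
# including the default port and bracketed IPv6 literals:
#
#   >>> partition_node('example.com:27018')
#   ('example.com', 27018)
#   >>> partition_node('[::1]:27017')
#   ('::1', 27017)
#   >>> clean_node('FOO.example.com')
#   ('foo.example.com', 27017)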
def raise_config_error(key, dummy):
"""Raise ConfigurationError with the given key name."""
raise ConfigurationError("Unknown option %s" % (key,))
# Mapping of URI uuid representation options to valid subtypes.
_UUID_REPRESENTATIONS = {
'standard': STANDARD,
'pythonLegacy': PYTHON_LEGACY,
'javaLegacy': JAVA_LEGACY,
'csharpLegacy': CSHARP_LEGACY
}
def validate_boolean(option, value):
"""Validates that 'value' is True or False."""
if isinstance(value, bool):
return value
raise TypeError("%s must be True or False" % (option,))
def validate_boolean_or_string(option, value):
"""Validates that value is True, False, 'true', or 'false'."""
if isinstance(value, string_type):
if value not in ('true', 'false'):
raise ValueError("The value of %s must be "
"'true' or 'false'" % (option,))
return value == 'true'
return validate_boolean(option, value)
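# Example: boolean options accept real bools and the URI strings
# 'true'/'false'; anything else is rejected:
#
#   >>> validate_boolean_or_string('retryWrites', 'true')
#   True
#   >>> validate_boolean_or_string('retryWrites', 'yes')
#   Traceback (most recent call last):
#     ...
#   ValueError: The value of retryWrites must be 'true' or 'false'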
def validate_integer(option, value):
"""Validates that 'value' is an integer (or basestring representation).
"""
if isinstance(value, integer_types):
return value
elif isinstance(value, string_type):
try:
return int(value)
except ValueError:
raise ValueError("The value of %s must be "
"an integer" % (option,))
raise TypeError("Wrong type for %s, value must be an integer" % (option,))
def validate_positive_integer(option, value):
"""Validate that 'value' is a positive integer, which does not include 0.
"""
val = validate_integer(option, value)
if val <= 0:
raise ValueError("The value of %s must be "
"a positive integer" % (option,))
return val
def validate_non_negative_integer(option, value):
"""Validate that 'value' is a positive integer or 0.
"""
val = validate_integer(option, value)
if val < 0:
raise ValueError("The value of %s must be "
"a non negative integer" % (option,))
return val
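# Example: the integer validators coerce string input and enforce sign
# constraints:
#
#   >>> validate_integer('maxPoolSize', '42')
#   42
#   >>> validate_non_negative_integer('minPoolSize', 0)
#   0
#   >>> validate_positive_integer('maxPoolSize', 0)
#   Traceback (most recent call last):
#     ...
#   ValueError: The value of maxPoolSize must be a positive integer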
def validate_readable(option, value):
"""Validates that 'value' is file-like and readable.
"""
if value is None:
return value
    # First make sure it's a string: in Python 3.3, open(True, 'r') succeeds.
    # Used in SSL cert checking due to poor ssl module error reporting.
value = validate_string(option, value)
open(value, 'r').close()
return value
def validate_positive_integer_or_none(option, value):
"""Validate that 'value' is a positive integer or None.
"""
if value is None:
return value
return validate_positive_integer(option, value)
def validate_non_negative_integer_or_none(option, value):
"""Validate that 'value' is a positive integer or 0 or None.
"""
if value is None:
return value
return validate_non_negative_integer(option, value)
def validate_string(option, value):
"""Validates that 'value' is an instance of `basestring` for Python 2
or `str` for Python 3.
"""
if isinstance(value, string_type):
return value
raise TypeError("Wrong type for %s, value must be "
"an instance of %s" % (option, string_type.__name__))
def validate_string_or_none(option, value):
"""Validates that 'value' is an instance of `basestring` or `None`.
"""
if value is None:
return value
return validate_string(option, value)
def validate_int_or_basestring(option, value):
"""Validates that 'value' is an integer or string.
"""
if isinstance(value, integer_types):
return value
elif isinstance(value, string_type):
try:
return int(value)
except ValueError:
return value
raise TypeError("Wrong type for %s, value must be an "
"integer or a string" % (option,))
def validate_non_negative_int_or_basestring(option, value):
"""Validates that 'value' is an integer or string.
"""
if isinstance(value, integer_types):
return value
elif isinstance(value, string_type):
try:
val = int(value)
except ValueError:
return value
return validate_non_negative_integer(option, val)
raise TypeError("Wrong type for %s, value must be an "
"non negative integer or a string" % (option,))
def validate_positive_float(option, value):
"""Validates that 'value' is a float, or can be converted to one, and is
positive.
"""
errmsg = "%s must be an integer or float" % (option,)
try:
value = float(value)
except ValueError:
raise ValueError(errmsg)
except TypeError:
raise TypeError(errmsg)
# float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at
# one billion - this is a reasonable approximation for infinity
if not 0 < value < 1e9:
raise ValueError("%s must be greater than 0 and "
"less than one billion" % (option,))
return value
def validate_positive_float_or_zero(option, value):
"""Validates that 'value' is 0 or a positive float, or can be converted to
0 or a positive float.
"""
if value == 0 or value == "0":
return 0
return validate_positive_float(option, value)
def validate_timeout_or_none(option, value):
"""Validates a timeout specified in milliseconds returning
a value in floating point seconds.
"""
if value is None:
return value
return validate_positive_float(option, value) / 1000.0
def validate_timeout_or_zero(option, value):
"""Validates a timeout specified in milliseconds returning
a value in floating point seconds for the case where None is an error
and 0 is valid. Setting the timeout to nothing in the URI string is a
config error.
"""
if value is None:
raise ConfigurationError("%s cannot be None" % (option, ))
if value == 0 or value == "0":
return 0
return validate_positive_float(option, value) / 1000.0
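# Example: timeouts are specified in milliseconds in the URI but used
# internally as floating point seconds:
#
#   >>> validate_timeout_or_none('connectTimeoutMS', 20000)
#   20.0
#   >>> validate_timeout_or_zero('serverSelectionTimeoutMS', '500')
#   0.5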
def validate_max_staleness(option, value):
"""Validates maxStalenessSeconds according to the Max Staleness Spec."""
if value == -1 or value == "-1":
# Default: No maximum staleness.
return -1
return validate_positive_integer(option, value)
def validate_read_preference(dummy, value):
"""Validate a read preference.
"""
if not isinstance(value, _ServerMode):
raise TypeError("%r is not a read preference." % (value,))
return value
def validate_read_preference_mode(dummy, value):
"""Validate read preference mode for a MongoReplicaSetClient.
.. versionchanged:: 3.5
Returns the original ``value`` instead of the validated read preference
mode.
"""
if value not in _MONGOS_MODES:
raise ValueError("%s is not a valid read preference" % (value,))
return value
def validate_auth_mechanism(option, value):
"""Validate the authMechanism URI option.
"""
# CRAM-MD5 is for server testing only. Undocumented,
# unsupported, may be removed at any time. You have
# been warned.
if value not in MECHANISMS and value != 'CRAM-MD5':
raise ValueError("%s must be in %s" % (option, tuple(MECHANISMS)))
return value
def validate_uuid_representation(dummy, value):
"""Validate the uuid representation option selected in the URI.
"""
try:
return _UUID_REPRESENTATIONS[value]
except KeyError:
raise ValueError("%s is an invalid UUID representation. "
"Must be one of "
"%s" % (value, tuple(_UUID_REPRESENTATIONS)))
def validate_read_preference_tags(name, value):
"""Parse readPreferenceTags if passed as a client kwarg.
"""
if not isinstance(value, list):
value = [value]
tag_sets = []
for tag_set in value:
if tag_set == '':
tag_sets.append({})
continue
try:
tag_sets.append(dict([tag.split(":")
for tag in tag_set.split(",")]))
except Exception:
raise ValueError("%r not a valid "
"value for %s" % (tag_set, name))
return tag_sets
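# Example: readPreferenceTags parse into a list of tag set documents; an
# empty string is a tag set that matches any member:
#
#   >>> validate_read_preference_tags('readPreferenceTags', 'dc:ny,rack:1')
#   [{'dc': 'ny', 'rack': '1'}]
#   >>> validate_read_preference_tags('readPreferenceTags', ['dc:ny', ''])
#   [{'dc': 'ny'}, {}]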
_MECHANISM_PROPS = frozenset(['SERVICE_NAME',
'CANONICALIZE_HOST_NAME',
'SERVICE_REALM'])
def validate_auth_mechanism_properties(option, value):
"""Validate authMechanismProperties."""
value = validate_string(option, value)
props = {}
for opt in value.split(','):
try:
key, val = opt.split(':')
except ValueError:
raise ValueError("auth mechanism properties must be "
"key:value pairs like SERVICE_NAME:"
"mongodb, not %s." % (opt,))
if key not in _MECHANISM_PROPS:
raise ValueError("%s is not a supported auth "
"mechanism property. Must be one of "
"%s." % (key, tuple(_MECHANISM_PROPS)))
if key == 'CANONICALIZE_HOST_NAME':
props[key] = validate_boolean_or_string(key, val)
else:
props[key] = val
return props
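# Example: authMechanismProperties are parsed from comma-separated
# key:value pairs, with CANONICALIZE_HOST_NAME coerced to a boolean:
#
#   >>> validate_auth_mechanism_properties(
#   ...     'authMechanismProperties',
#   ...     'SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:true')
#   {'SERVICE_NAME': 'mongodb', 'CANONICALIZE_HOST_NAME': True}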
def validate_document_class(option, value):
"""Validate the document_class option."""
if not issubclass(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError("%s must be dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or a "
"sublass of collections.MutableMapping" % (option,))
return value
def validate_type_registry(option, value):
"""Validate the type_registry option."""
if value is not None and not isinstance(value, TypeRegistry):
raise TypeError("%s must be an instance of %s" % (
option, TypeRegistry))
return value
def validate_list(option, value):
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError("%s must be a list" % (option,))
return value
def validate_list_or_none(option, value):
"""Validates that 'value' is a list or None."""
if value is None:
return value
return validate_list(option, value)
def validate_list_or_mapping(option, value):
"""Validates that 'value' is a list or a document."""
if not isinstance(value, (abc.Mapping, list)):
raise TypeError("%s must either be a list or an instance of dict, "
"bson.son.SON, or any other type that inherits from "
"collections.Mapping" % (option,))
def validate_is_mapping(option, value):
"""Validate the type of method arguments that expect a document."""
if not isinstance(value, abc.Mapping):
raise TypeError("%s must be an instance of dict, bson.son.SON, or "
"any other type that inherits from "
"collections.Mapping" % (option,))
def validate_is_document_type(option, value):
"""Validate the type of method arguments that expect a MongoDB document."""
if not isinstance(value, (abc.MutableMapping, RawBSONDocument)):
raise TypeError("%s must be an instance of dict, bson.son.SON, "
"bson.raw_bson.RawBSONDocument, or "
"a type that inherits from "
"collections.MutableMapping" % (option,))
def validate_appname_or_none(option, value):
"""Validate the appname option."""
if value is None:
return value
validate_string(option, value)
# We need length in bytes, so encode utf8 first.
if len(value.encode('utf-8')) > 128:
raise ValueError("%s must be <= 128 bytes" % (option,))
return value
def validate_driver_or_none(option, value):
"""Validate the driver keyword arg."""
if value is None:
return value
if not isinstance(value, DriverInfo):
raise TypeError("%s must be an instance of DriverInfo" % (option,))
return value
def validate_is_callable_or_none(option, value):
"""Validates that 'value' is a callable."""
if value is None:
return value
if not callable(value):
raise ValueError("%s must be a callable" % (option,))
return value
def validate_ok_for_replace(replacement):
"""Validate a replacement document."""
validate_is_mapping("replacement", replacement)
# Replacement can be {}
if replacement and not isinstance(replacement, RawBSONDocument):
first = next(iter(replacement))
if first.startswith('$'):
raise ValueError('replacement can not include $ operators')
def validate_ok_for_update(update):
"""Validate an update document."""
validate_list_or_mapping("update", update)
# Update cannot be {}.
if not update:
raise ValueError('update cannot be empty')
is_document = not isinstance(update, list)
first = next(iter(update))
if is_document and not first.startswith('$'):
raise ValueError('update only works with $ operators')
_UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore'])
def validate_unicode_decode_error_handler(dummy, value):
"""Validate the Unicode decode error handler option of CodecOptions.
"""
if value not in _UNICODE_DECODE_ERROR_HANDLERS:
raise ValueError("%s is an invalid Unicode decode error handler. "
"Must be one of "
"%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS)))
return value
def validate_tzinfo(dummy, value):
"""Validate the tzinfo option
"""
if value is not None and not isinstance(value, datetime.tzinfo):
raise TypeError("%s must be an instance of datetime.tzinfo" % value)
return value
# Dictionary where keys are the names of public URI options, and values
# are lists of aliases for that option. Aliases of option names are assumed
# to have been deprecated.
URI_OPTIONS_ALIAS_MAP = {
'journal': ['j'],
'wtimeoutms': ['wtimeout'],
'tls': ['ssl'],
'tlsallowinvalidcertificates': ['ssl_cert_reqs'],
'tlsallowinvalidhostnames': ['ssl_match_hostname'],
'tlscrlfile': ['ssl_crlfile'],
'tlscafile': ['ssl_ca_certs'],
'tlscertificatekeyfile': ['ssl_certfile'],
'tlscertificatekeyfilepassword': ['ssl_pem_passphrase'],
}
# Dictionary where keys are the names of URI options, and values
# are functions that validate user-input values for that option. If an option
# alias uses a different validator than its public counterpart, it should be
# included here as a key, value pair.
URI_OPTIONS_VALIDATOR_MAP = {
'appname': validate_appname_or_none,
'authmechanism': validate_auth_mechanism,
'authmechanismproperties': validate_auth_mechanism_properties,
'authsource': validate_string,
'compressors': validate_compressors,
'connecttimeoutms': validate_timeout_or_none,
'heartbeatfrequencyms': validate_timeout_or_none,
'journal': validate_boolean_or_string,
'localthresholdms': validate_positive_float_or_zero,
'maxidletimems': validate_timeout_or_none,
'maxpoolsize': validate_positive_integer_or_none,
'maxstalenessseconds': validate_max_staleness,
'readconcernlevel': validate_string_or_none,
'readpreference': validate_read_preference_mode,
'readpreferencetags': validate_read_preference_tags,
'replicaset': validate_string_or_none,
'retryreads': validate_boolean_or_string,
'retrywrites': validate_boolean_or_string,
'serverselectiontimeoutms': validate_timeout_or_zero,
'sockettimeoutms': validate_timeout_or_none,
'ssl_keyfile': validate_readable,
'tls': validate_boolean_or_string,
'tlsallowinvalidcertificates': validate_allow_invalid_certs,
'ssl_cert_reqs': validate_cert_reqs,
'tlsallowinvalidhostnames': lambda *x: not validate_boolean_or_string(*x),
'ssl_match_hostname': validate_boolean_or_string,
'tlscafile': validate_readable,
'tlscertificatekeyfile': validate_readable,
'tlscertificatekeyfilepassword': validate_string_or_none,
'tlsinsecure': validate_boolean_or_string,
'w': validate_non_negative_int_or_basestring,
'wtimeoutms': validate_non_negative_integer,
'zlibcompressionlevel': validate_zlib_compression_level,
}
# Dictionary where keys are the names of URI options specific to pymongo,
# and values are functions that validate user-input values for those options.
NONSPEC_OPTIONS_VALIDATOR_MAP = {
'connect': validate_boolean_or_string,
'driver': validate_driver_or_none,
'fsync': validate_boolean_or_string,
'minpoolsize': validate_non_negative_integer,
'socketkeepalive': validate_boolean_or_string,
'tlscrlfile': validate_readable,
'tz_aware': validate_boolean_or_string,
'unicode_decode_error_handler': validate_unicode_decode_error_handler,
'uuidrepresentation': validate_uuid_representation,
'waitqueuemultiple': validate_non_negative_integer_or_none,
'waitqueuetimeoutms': validate_timeout_or_none,
}
# Dictionary where keys are the names of keyword-only options for the
# MongoClient constructor, and values are functions that validate user-input
# values for those options.
KW_VALIDATORS = {
'document_class': validate_document_class,
'type_registry': validate_type_registry,
'read_preference': validate_read_preference,
'event_listeners': _validate_event_listeners,
'tzinfo': validate_tzinfo,
'username': validate_string_or_none,
'password': validate_string_or_none,
'server_selector': validate_is_callable_or_none,
'auto_encryption_opts': validate_auto_encryption_opts_or_none,
}
# Dictionary where keys are any URI option name, and values are the
# internally-used names of that URI option. Options with only one name
# variant need not be included here. Options whose public and internal
# names are the same need not be included here.
INTERNAL_URI_OPTION_NAME_MAP = {
'j': 'journal',
'wtimeout': 'wtimeoutms',
'tls': 'ssl',
'tlsallowinvalidcertificates': 'ssl_cert_reqs',
'tlsallowinvalidhostnames': 'ssl_match_hostname',
'tlscrlfile': 'ssl_crlfile',
'tlscafile': 'ssl_ca_certs',
'tlscertificatekeyfile': 'ssl_certfile',
'tlscertificatekeyfilepassword': 'ssl_pem_passphrase',
}
# Map from deprecated URI option names to a tuple indicating the method of
# their deprecation and any additional information that may be needed to
# construct the warning message.
URI_OPTIONS_DEPRECATION_MAP = {
# format: <deprecated option name>: (<mode>, <message>),
# Supported <mode> values:
# - 'renamed': <message> should be the new option name. Note that case is
# preserved for renamed options as they are part of user warnings.
# - 'removed': <message> may suggest the rationale for deprecating the
# option and/or recommend remedial action.
'j': ('renamed', 'journal'),
'wtimeout': ('renamed', 'wTimeoutMS'),
'ssl_cert_reqs': ('renamed', 'tlsAllowInvalidCertificates'),
'ssl_match_hostname': ('renamed', 'tlsAllowInvalidHostnames'),
'ssl_crlfile': ('renamed', 'tlsCRLFile'),
'ssl_ca_certs': ('renamed', 'tlsCAFile'),
'ssl_pem_passphrase': ('renamed', 'tlsCertificateKeyFilePassword'),
'waitqueuemultiple': ('removed', (
'Instead of using waitQueueMultiple to bound queuing, limit the size '
'of the thread pool in your application server'))
}
# Augment the option validator map with pymongo-specific option information.
URI_OPTIONS_VALIDATOR_MAP.update(NONSPEC_OPTIONS_VALIDATOR_MAP)
for optname, aliases in iteritems(URI_OPTIONS_ALIAS_MAP):
for alias in aliases:
if alias not in URI_OPTIONS_VALIDATOR_MAP:
URI_OPTIONS_VALIDATOR_MAP[alias] = (
URI_OPTIONS_VALIDATOR_MAP[optname])
# Map containing all URI option and keyword argument validators.
VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy()
VALIDATORS.update(KW_VALIDATORS)
# List of timeout-related options.
TIMEOUT_OPTIONS = [
'connecttimeoutms',
'heartbeatfrequencyms',
'maxidletimems',
'maxstalenessseconds',
'serverselectiontimeoutms',
'sockettimeoutms',
'waitqueuetimeoutms',
]
_AUTH_OPTIONS = frozenset(['authmechanismproperties'])
def validate_auth_option(option, value):
"""Validate optional authentication parameters.
"""
lower, value = validate(option, value)
if lower not in _AUTH_OPTIONS:
raise ConfigurationError('Unknown '
'authentication option: %s' % (option,))
return lower, value
def validate(option, value):
"""Generic validation function.
"""
lower = option.lower()
validator = VALIDATORS.get(lower, raise_config_error)
value = validator(option, value)
return lower, value
def get_validated_options(options, warn=True):
"""Validate each entry in options and raise a warning if it is not valid.
Returns a copy of options with invalid entries removed.
:Parameters:
      - `options`: A dict containing MongoDB URI options.
- `warn` (optional): If ``True`` then warnings will be logged and
invalid options will be ignored. Otherwise, invalid options will
cause errors.
"""
if isinstance(options, _CaseInsensitiveDictionary):
validated_options = _CaseInsensitiveDictionary()
get_normed_key = lambda x: x
get_setter_key = lambda x: options.cased_key(x)
else:
validated_options = {}
get_normed_key = lambda x: x.lower()
get_setter_key = lambda x: x
for opt, value in iteritems(options):
normed_key = get_normed_key(opt)
try:
validator = URI_OPTIONS_VALIDATOR_MAP.get(
normed_key, raise_config_error)
value = validator(opt, value)
except (ValueError, TypeError, ConfigurationError) as exc:
if warn:
warnings.warn(str(exc))
else:
raise
else:
validated_options[get_setter_key(normed_key)] = value
return validated_options
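# Example: with warn=True (the default), an unknown option is dropped and a
# UserWarning is emitted instead of raising; known option names are
# normalized to lowercase for a plain dict input:
#
#   >>> get_validated_options(
#   ...     {'appName': 'myApp', 'maxPoolSize': '50', 'notAnOption': 1})
#   {'appname': 'myApp', 'maxpoolsize': 50}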
# List of write-concern-related options.
WRITE_CONCERN_OPTIONS = frozenset([
'w',
'wtimeout',
'wtimeoutms',
'fsync',
'j',
'journal'
])
class BaseObject(object):
"""A base class that provides attributes and methods common
to multiple pymongo classes.
SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB.
"""
def __init__(self, codec_options, read_preference, write_concern,
read_concern):
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of "
"bson.codec_options.CodecOptions")
self.__codec_options = codec_options
if not isinstance(read_preference, _ServerMode):
raise TypeError("%r is not valid for read_preference. See "
"pymongo.read_preferences for valid "
"options." % (read_preference,))
self.__read_preference = read_preference
if not isinstance(write_concern, WriteConcern):
raise TypeError("write_concern must be an instance of "
"pymongo.write_concern.WriteConcern")
self.__write_concern = write_concern
if not isinstance(read_concern, ReadConcern):
raise TypeError("read_concern must be an instance of "
"pymongo.read_concern.ReadConcern")
self.__read_concern = read_concern
@property
def codec_options(self):
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
of this instance.
"""
return self.__codec_options
@property
def write_concern(self):
"""Read only access to the :class:`~pymongo.write_concern.WriteConcern`
of this instance.
.. versionchanged:: 3.0
The :attr:`write_concern` attribute is now read only.
"""
return self.__write_concern
def _write_concern_for(self, session):
"""Read only access to the write concern of this instance or session.
"""
# Override this operation's write concern with the transaction's.
if session and session._in_transaction:
return DEFAULT_WRITE_CONCERN
return self.write_concern
@property
def read_preference(self):
"""Read only access to the read preference of this instance.
.. versionchanged:: 3.0
The :attr:`read_preference` attribute is now read only.
"""
return self.__read_preference
def _read_preference_for(self, session):
"""Read only access to the read preference of this instance or session.
"""
# Override this operation's read preference with the transaction's.
if session:
return session._txn_read_preference() or self.__read_preference
return self.__read_preference
@property
def read_concern(self):
"""Read only access to the :class:`~pymongo.read_concern.ReadConcern`
of this instance.
.. versionadded:: 3.2
"""
return self.__read_concern
class _CaseInsensitiveDictionary(abc.MutableMapping):
def __init__(self, *args, **kwargs):
self.__casedkeys = {}
self.__data = {}
self.update(dict(*args, **kwargs))
def __contains__(self, key):
return key.lower() in self.__data
def __len__(self):
return len(self.__data)
def __iter__(self):
return (key for key in self.__casedkeys)
def __repr__(self):
return str({self.__casedkeys[k]: self.__data[k] for k in self})
def __setitem__(self, key, value):
lc_key = key.lower()
self.__casedkeys[lc_key] = key
self.__data[lc_key] = value
def __getitem__(self, key):
return self.__data[key.lower()]
def __delitem__(self, key):
lc_key = key.lower()
del self.__casedkeys[lc_key]
del self.__data[lc_key]
def __eq__(self, other):
if not isinstance(other, abc.Mapping):
return NotImplemented
if len(self) != len(other):
return False
for key in other:
if self[key] != other[key]:
return False
return True
def get(self, key, default=None):
return self.__data.get(key.lower(), default)
def pop(self, key, *args, **kwargs):
lc_key = key.lower()
self.__casedkeys.pop(lc_key, None)
return self.__data.pop(lc_key, *args, **kwargs)
def popitem(self):
lc_key, cased_key = self.__casedkeys.popitem()
value = self.__data.pop(lc_key)
return cased_key, value
def clear(self):
self.__casedkeys.clear()
self.__data.clear()
def setdefault(self, key, default=None):
lc_key = key.lower()
if key in self:
return self.__data[lc_key]
else:
self.__casedkeys[lc_key] = key
self.__data[lc_key] = default
return default
def update(self, other):
if isinstance(other, _CaseInsensitiveDictionary):
for key in other:
self[other.cased_key(key)] = other[key]
else:
for key in other:
self[key] = other[key]
def cased_key(self, key):
return self.__casedkeys[key.lower()]
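# Example: _CaseInsensitiveDictionary normalizes lookups but remembers the
# original spelling of each key:
#
#   >>> d = _CaseInsensitiveDictionary()
#   >>> d['connectTimeoutMS'] = 5000
#   >>> d['connecttimeoutms']
#   5000
#   >>> d.cased_key('CONNECTTIMEOUTMS')
#   'connectTimeoutMS'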

View File

@@ -0,0 +1,157 @@
# Copyright 2018 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
try:
import snappy
_HAVE_SNAPPY = True
except ImportError:
# python-snappy isn't available.
_HAVE_SNAPPY = False
try:
import zlib
_HAVE_ZLIB = True
except ImportError:
# Python built without zlib support.
_HAVE_ZLIB = False
try:
from zstandard import ZstdCompressor, ZstdDecompressor
_HAVE_ZSTD = True
except ImportError:
_HAVE_ZSTD = False
from pymongo.monitoring import _SENSITIVE_COMMANDS
_SUPPORTED_COMPRESSORS = set(["snappy", "zlib", "zstd"])
_NO_COMPRESSION = set(['ismaster'])
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def validate_compressors(dummy, value):
try:
# `value` is string.
compressors = value.split(",")
except AttributeError:
# `value` is an iterable.
compressors = list(value)
for compressor in compressors[:]:
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn("Unsupported compressor: %s" % (compressor,))
elif compressor == "snappy" and not _HAVE_SNAPPY:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support.")
elif compressor == "zlib" and not _HAVE_ZLIB:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available.")
elif compressor == "zstd" and not _HAVE_ZSTD:
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
"You must install the zstandard module for zstandard support.")
return compressors
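# Example: unsupported or unavailable compressors are dropped with a warning
# rather than raising (the result below assumes python-snappy is installed):
#
#   >>> validate_compressors('compressors', 'snappy,zlib,bogus')
#   ['snappy', 'zlib']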
def validate_zlib_compression_level(option, value):
try:
level = int(value)
    except (TypeError, ValueError):
raise TypeError("%s must be an integer, not %r." % (option, value))
if level < -1 or level > 9:
raise ValueError(
"%s must be between -1 and 9, not %d." % (option, level))
return level
class CompressionSettings(object):
def __init__(self, compressors, zlib_compression_level):
self.compressors = compressors
self.zlib_compression_level = zlib_compression_level
def get_compression_context(self, compressors):
if compressors:
chosen = compressors[0]
if chosen == "snappy":
return SnappyContext()
elif chosen == "zlib":
return ZlibContext(self.zlib_compression_level)
elif chosen == "zstd":
return ZstdContext()
def _zlib_no_compress(data):
"""Compress data with zlib level 0."""
cobj = zlib.compressobj(0)
return b"".join([cobj.compress(data), cobj.flush()])
class SnappyContext(object):
compressor_id = 1
@staticmethod
def compress(data):
return snappy.compress(data)
class ZlibContext(object):
compressor_id = 2
def __init__(self, level):
# Jython zlib.compress doesn't support -1
if level == -1:
self.compress = zlib.compress
# Jython zlib.compress also doesn't support 0
elif level == 0:
self.compress = _zlib_no_compress
else:
self.compress = lambda data: zlib.compress(data, level)
class ZstdContext(object):
compressor_id = 3
@staticmethod
def compress(data):
# ZstdCompressor is not thread safe.
# TODO: Use a pool?
return ZstdCompressor().compress(data)
def decompress(data, compressor_id):
if compressor_id == SnappyContext.compressor_id:
# python-snappy doesn't support the buffer interface.
# https://github.com/andrix/python-snappy/issues/65
# This only matters when data is a memoryview since
# id(bytes(data)) == id(data) when data is a bytes.
# NOTE: bytes(memoryview) returns the memoryview repr
# in Python 2.7. The right thing to do in 2.7 is call
# memoryview.tobytes(), but we currently only use
# memoryview in Python 3.x.
return snappy.uncompress(bytes(data))
elif compressor_id == ZlibContext.compressor_id:
return zlib.decompress(data)
elif compressor_id == ZstdContext.compressor_id:
# ZstdDecompressor is not thread safe.
# TODO: Use a pool?
return ZstdDecompressor().decompress(data)
else:
raise ValueError("Unknown compressorId %d" % (compressor_id,))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DEPRECATED - A manager to handle when cursors are killed after they are
closed.
New cursor managers should be defined as subclasses of CursorManager and can be
installed on a client by calling
:meth:`~pymongo.mongo_client.MongoClient.set_cursor_manager`.
.. versionchanged:: 3.3
Deprecated, for real this time.
.. versionchanged:: 3.0
Undeprecated. :meth:`~pymongo.cursor_manager.CursorManager.close` now
requires an `address` argument. The ``BatchCursorManager`` class is removed.
"""
import warnings
import weakref
from bson.py3compat import integer_types
class CursorManager(object):
"""DEPRECATED - The cursor manager base class."""
def __init__(self, client):
"""Instantiate the manager.
:Parameters:
- `client`: a MongoClient
"""
warnings.warn(
"Cursor managers are deprecated.",
DeprecationWarning,
stacklevel=2)
self.__client = weakref.ref(client)
def close(self, cursor_id, address):
"""Kill a cursor.
Raises TypeError if cursor_id is not an instance of (int, long).
:Parameters:
- `cursor_id`: cursor id to close
- `address`: the cursor's server's (host, port) pair
.. versionchanged:: 3.0
Now requires an `address` argument.
"""
if not isinstance(cursor_id, integer_types):
raise TypeError("cursor_id must be an integer")
self.__client().kill_cursors([cursor_id], address)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,39 @@
# Copyright 2018-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Advanced options for MongoDB drivers implemented on top of PyMongo."""
from collections import namedtuple
from bson.py3compat import string_type
class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
"""Info about a driver wrapping PyMongo.
The MongoDB server logs PyMongo's name, version, and platform whenever
PyMongo establishes a connection. A driver implemented on top of PyMongo
can add its own info to this log message. Initialize with three strings
like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be
None to accept PyMongo's default.
"""
def __new__(cls, name=None, version=None, platform=None):
self = super(DriverInfo, cls).__new__(cls, name, version, platform)
for name, value in self._asdict().items():
if value is not None and not isinstance(value, string_type):
raise TypeError("Wrong type for DriverInfo %s option, value "
"must be an instance of %s" % (
name, string_type.__name__))
return self
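# Example: a wrapping driver identifies itself via the ``driver`` keyword
# when constructing a client (connect=False avoids opening a connection):
#
#   >>> from pymongo import MongoClient
#   >>> info = DriverInfo(name='MyDriver', version='1.2.3')
#   >>> info
#   DriverInfo(name='MyDriver', version='1.2.3', platform=None)
#   >>> client = MongoClient(driver=info, connect=False)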

View File

@@ -0,0 +1,522 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for explicit client side encryption.
**Support for client side encryption is in beta. Backwards-breaking changes
may be made before the final release.**
"""
import contextlib
import subprocess
import uuid
import weakref
try:
from pymongocrypt.auto_encrypter import AutoEncrypter
from pymongocrypt.errors import MongoCryptError
from pymongocrypt.explicit_encrypter import ExplicitEncrypter
from pymongocrypt.mongocrypt import MongoCryptOptions
from pymongocrypt.state_machine import MongoCryptCallback
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
MongoCryptCallback = object
from bson import _bson_to_dict, _dict_to_bson, decode, encode
from bson.codec_options import CodecOptions
from bson.binary import (Binary,
STANDARD,
UUID_SUBTYPE)
from bson.errors import BSONError
from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
RawBSONDocument,
_inflate_bson)
from bson.son import SON
from pymongo.errors import (ConfigurationError,
EncryptionError,
InvalidOperation,
ServerSelectionTimeoutError)
from pymongo.message import (_COMMAND_OVERHEAD,
_MAX_ENC_BSON_SIZE,
_raise_document_too_large)
from pymongo.mongo_client import MongoClient
from pymongo.pool import _configured_socket, PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern
_HTTPS_PORT = 443
_KMS_CONNECT_TIMEOUT = 10 # TODO: CDRIVER-3262 will define this value.
_MONGOCRYPTD_TIMEOUT_MS = 1000
_DATA_KEY_OPTS = CodecOptions(document_class=SON, uuid_representation=STANDARD)
# Use RawBSONDocument codec options to avoid needlessly decoding
# documents from the key vault.
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument,
uuid_representation=STANDARD)
@contextlib.contextmanager
def _wrap_encryption_errors():
"""Context manager to wrap encryption related errors."""
try:
yield
except BSONError:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except Exception as exc:
raise EncryptionError(exc)
class _EncryptionIO(MongoCryptCallback):
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
"""Internal class to perform I/O on behalf of pymongocrypt."""
# Use a weak ref to break reference cycle.
if client is not None:
self.client_ref = weakref.ref(client)
else:
self.client_ref = None
self.key_vault_coll = key_vault_coll.with_options(
codec_options=_KEY_VAULT_OPTS,
read_concern=ReadConcern(level='majority'),
write_concern=WriteConcern(w='majority'))
self.mongocryptd_client = mongocryptd_client
self.opts = opts
self._spawned = False
def kms_request(self, kms_context):
"""Complete a KMS request.
:Parameters:
- `kms_context`: A :class:`MongoCryptKmsContext`.
:Returns:
None
"""
endpoint = kms_context.endpoint
message = kms_context.message
ctx = get_ssl_context(None, None, None, None, None, None, True)
opts = PoolOptions(connect_timeout=_KMS_CONNECT_TIMEOUT,
socket_timeout=_KMS_CONNECT_TIMEOUT,
ssl_context=ctx)
conn = _configured_socket((endpoint, _HTTPS_PORT), opts)
try:
conn.sendall(message)
while kms_context.bytes_needed > 0:
data = conn.recv(kms_context.bytes_needed)
kms_context.feed(data)
finally:
conn.close()
def collection_info(self, database, filter):
"""Get the collection info for a namespace.
The returned collection info is passed to libmongocrypt which reads
the JSON schema.
:Parameters:
- `database`: The database on which to run listCollections.
- `filter`: The filter to pass to listCollections.
:Returns:
The first document from the listCollections command response as BSON.
"""
with self.client_ref()[database].list_collections(
filter=RawBSONDocument(filter)) as cursor:
for doc in cursor:
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
def spawn(self):
"""Spawn mongocryptd.
Note this method is thread safe; at most one mongocryptd will start
successfully.
"""
self._spawned = True
args = [self.opts._mongocryptd_spawn_path or 'mongocryptd']
args.extend(self.opts._mongocryptd_spawn_args)
subprocess.Popen(args)
def mark_command(self, database, cmd):
"""Mark a command for encryption.
:Parameters:
- `database`: The database on which to run this command.
- `cmd`: The BSON command to run.
:Returns:
The marked command response from mongocryptd.
"""
if not self._spawned and not self.opts._mongocryptd_bypass_spawn:
self.spawn()
# Database.command only supports mutable mappings so we need to decode
# the raw BSON command first.
inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS)
try:
res = self.mongocryptd_client[database].command(
inflated_cmd,
codec_options=DEFAULT_RAW_BSON_OPTIONS)
except ServerSelectionTimeoutError:
if self.opts._mongocryptd_bypass_spawn:
raise
self.spawn()
res = self.mongocryptd_client[database].command(
inflated_cmd,
codec_options=DEFAULT_RAW_BSON_OPTIONS)
return res.raw
def fetch_keys(self, filter):
"""Yields one or more keys from the key vault.
:Parameters:
- `filter`: The filter to pass to find.
:Returns:
A generator which yields the requested keys from the key vault.
"""
with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
for key in cursor:
yield key.raw
def insert_data_key(self, data_key):
"""Insert a data key into the key vault.
:Parameters:
- `data_key`: The data key document to insert.
:Returns:
The _id of the inserted data key document.
"""
# insert does not return the inserted _id when given a RawBSONDocument.
doc = _bson_to_dict(data_key, _DATA_KEY_OPTS)
if not isinstance(doc.get('_id'), uuid.UUID):
raise TypeError(
'data_key _id must be a bson.binary.Binary with subtype 4')
res = self.key_vault_coll.insert_one(doc)
return Binary(res.inserted_id.bytes, subtype=UUID_SUBTYPE)
def bson_encode(self, doc):
"""Encode a document to BSON.
A document can be any mapping type (like :class:`dict`).
:Parameters:
- `doc`: mapping type representing a document
:Returns:
The encoded BSON bytes.
"""
return encode(doc)
def close(self):
"""Release resources.
Note it is not safe to call this method from __del__ or any GC hooks.
"""
self.client_ref = None
self.key_vault_coll = None
if self.mongocryptd_client:
self.mongocryptd_client.close()
self.mongocryptd_client = None
class _Encrypter(object):
def __init__(self, io_callbacks, opts):
"""Encrypts and decrypts MongoDB commands.
This class is used to support automatic encryption and decryption of
MongoDB commands.
:Parameters:
- `io_callbacks`: A :class:`MongoCryptCallback`.
- `opts`: The encrypted client's :class:`AutoEncryptionOpts`.
"""
if opts._schema_map is None:
schema_map = None
else:
schema_map = _dict_to_bson(opts._schema_map, False, _DATA_KEY_OPTS)
self._auto_encrypter = AutoEncrypter(io_callbacks, MongoCryptOptions(
opts._kms_providers, schema_map))
self._bypass_auto_encryption = opts._bypass_auto_encryption
self._closed = False
def encrypt(self, database, cmd, check_keys, codec_options):
"""Encrypt a MongoDB command.
:Parameters:
- `database`: The database for this command.
- `cmd`: A command document.
- `check_keys`: If True, check `cmd` for invalid keys.
- `codec_options`: The CodecOptions to use while encoding `cmd`.
:Returns:
The encrypted command to execute.
"""
self._check_closed()
# Workaround for $clusterTime which is incompatible with
# check_keys.
cluster_time = check_keys and cmd.pop('$clusterTime', None)
encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options)
max_cmd_size = _MAX_ENC_BSON_SIZE + _COMMAND_OVERHEAD
if len(encoded_cmd) > max_cmd_size:
            _raise_document_too_large(
                next(iter(cmd)), len(encoded_cmd), max_cmd_size)
with _wrap_encryption_errors():
encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
# TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
encrypt_cmd = _inflate_bson(
encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
if cluster_time:
encrypt_cmd['$clusterTime'] = cluster_time
return encrypt_cmd
def decrypt(self, response):
"""Decrypt a MongoDB command response.
:Parameters:
- `response`: A MongoDB command response as BSON.
:Returns:
The decrypted command response.
"""
self._check_closed()
with _wrap_encryption_errors():
return self._auto_encrypter.decrypt(response)
def _check_closed(self):
if self._closed:
raise InvalidOperation("Cannot use MongoClient after close")
def close(self):
"""Cleanup resources."""
self._closed = True
self._auto_encrypter.close()
@staticmethod
def create(client, opts):
"""Create a _CommandEncyptor for a client.
:Parameters:
- `client`: The encrypted MongoClient.
- `opts`: The encrypted client's :class:`AutoEncryptionOpts`.
:Returns:
          A :class:`_Encrypter` for this client.
"""
key_vault_client = opts._key_vault_client or client
db, coll = opts._key_vault_namespace.split('.', 1)
key_vault_coll = key_vault_client[db][coll]
mongocryptd_client = MongoClient(
opts._mongocryptd_uri, connect=False,
serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS)
io_callbacks = _EncryptionIO(
client, key_vault_coll, mongocryptd_client, opts)
return _Encrypter(io_callbacks, opts)
class Algorithm(object):
"""An enum that defines the supported encryption algorithms."""
AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = (
"AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
AEAD_AES_256_CBC_HMAC_SHA_512_Random = (
"AEAD_AES_256_CBC_HMAC_SHA_512-Random")
class ClientEncryption(object):
"""Explicit client side encryption."""
def __init__(self, kms_providers, key_vault_namespace, key_vault_client,
codec_options):
"""Explicit client side encryption.
The ClientEncryption class encapsulates explicit operations on a key
vault collection that cannot be done directly on a MongoClient. Similar
to configuring auto encryption on a MongoClient, it is constructed with
a MongoClient (to a MongoDB cluster containing the key vault
collection), KMS provider configuration, and keyVaultNamespace. It
provides an API for explicitly encrypting and decrypting values, and
creating data keys. It does not provide an API to query keys from the
key vault collection, as this can be done directly on the MongoClient.
.. note:: Support for client side encryption is in beta.
Backwards-breaking changes may be made before the final release.
:Parameters:
- `kms_providers`: Map of KMS provider options. Two KMS providers
are supported: "aws" and "local". The kmsProviders map values
differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages.
- `local`: Map with "key" as a 96-byte array or string. "key"
is the master key used to encrypt/decrypt data keys. This key
should be generated and stored as securely as possible.
- `key_vault_namespace`: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
- `key_vault_client`: A MongoClient connected to a MongoDB cluster
containing the `key_vault_namespace` collection.
- `codec_options`: An instance of
:class:`~bson.codec_options.CodecOptions` to use when encoding a
value for encryption and decoding the decrypted BSON value.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install pymongo['encryption']")
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of "
"bson.codec_options.CodecOptions")
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._codec_options = codec_options
db, coll = key_vault_namespace.split('.', 1)
key_vault_coll = key_vault_client[db][coll]
self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, None)
self._encryption = ExplicitEncrypter(
self._io_callbacks, MongoCryptOptions(kms_providers, None))
def create_data_key(self, kms_provider, master_key=None,
key_alt_names=None):
"""Create and insert a new data key into the key vault collection.
:Parameters:
- `kms_provider`: The KMS provider to use. Supported values are
"aws" and "local".
- `master_key`: The `master_key` identifies a KMS-specific key used
to encrypt the new data key. If the kmsProvider is "local" the
`master_key` is not applicable and may be omitted.
If the `kms_provider` is "aws", `master_key` is required and must
have the following fields:
- `region` (string): The AWS region as a string.
- `key` (string): The Amazon Resource Name (ARN) to the AWS
customer master key (CMK).
- `key_alt_names` (optional): An optional list of string alternate
names used to reference a key. If a key is created with alternate
names, then encryption may refer to the key by the unique alternate
name instead of by ``key_id``. The following example shows creating
and referring to a data key by alternate name::
            client_encryption.create_data_key("local", key_alt_names=["name1"])
            # reference the key with the alternate name
            client_encryption.encrypt("457-55-5462", key_alt_name="name1",
                algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random)
:Returns:
The ``_id`` of the created data key document.
"""
self._check_closed()
with _wrap_encryption_errors():
return self._encryption.create_data_key(
kms_provider, master_key=master_key,
key_alt_names=key_alt_names)
def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
"""Encrypt a BSON value with a given key and algorithm.
Note that exactly one of ``key_id`` or ``key_alt_name`` must be
provided.
:Parameters:
- `value`: The BSON value to encrypt.
- `algorithm` (string): The encryption algorithm to use. See
:class:`Algorithm` for some valid options.
- `key_id`: Identifies a data key by ``_id`` which must be a
:class:`~bson.binary.Binary` with subtype 4 (
:attr:`~bson.binary.UUID_SUBTYPE`).
- `key_alt_name`: Identifies a key vault document by 'keyAltName'.
:Returns:
The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
"""
self._check_closed()
if (key_id is not None and not (
isinstance(key_id, Binary) and
key_id.subtype == UUID_SUBTYPE)):
raise TypeError(
'key_id must be a bson.binary.Binary with subtype 4')
doc = encode({'v': value}, codec_options=self._codec_options)
with _wrap_encryption_errors():
encrypted_doc = self._encryption.encrypt(
doc, algorithm, key_id=key_id, key_alt_name=key_alt_name)
return decode(encrypted_doc)['v']
def decrypt(self, value):
"""Decrypt an encrypted value.
:Parameters:
- `value` (Binary): The encrypted value, a
:class:`~bson.binary.Binary` with subtype 6.
:Returns:
The decrypted BSON value.
"""
self._check_closed()
if not (isinstance(value, Binary) and value.subtype == 6):
raise TypeError(
'value to decrypt must be a bson.binary.Binary with subtype 6')
with _wrap_encryption_errors():
doc = encode({'v': value})
decrypted_doc = self._encryption.decrypt(doc)
return decode(decrypted_doc,
codec_options=self._codec_options)['v']
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def _check_closed(self):
if self._encryption is None:
raise InvalidOperation("Cannot use closed ClientEncryption")
def close(self):
"""Release resources.
Note that using this class in a with-statement will automatically call
:meth:`close`::
with ClientEncryption(...) as client_encryption:
encrypted = client_encryption.encrypt(value, ...)
decrypted = client_encryption.decrypt(encrypted)
"""
if self._io_callbacks:
self._io_callbacks.close()
self._encryption.close()
self._io_callbacks = None
self._encryption = None
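# Example: a minimal sketch of the explicit encryption round trip, assuming a
# locally running MongoDB cluster and the pymongocrypt library (the 'local'
# KMS master key below is throwaway; real keys must be stored securely):
#
#   import os
#   from bson.codec_options import CodecOptions
#   from pymongo import MongoClient
#   from pymongo.encryption import Algorithm, ClientEncryption
#
#   client = MongoClient()
#   with ClientEncryption({'local': {'key': os.urandom(96)}},
#                         'encryption.__keyVault', client,
#                         CodecOptions()) as client_encryption:
#       key_id = client_encryption.create_data_key('local')
#       encrypted = client_encryption.encrypt(
#           '457-55-5462',
#           Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
#           key_id=key_id)
#       assert client_encryption.decrypt(encrypted) == '457-55-5462'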

View File

@@ -0,0 +1,145 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for automatic client side encryption.
**Support for client side encryption is in beta. Backwards-breaking changes
may be made before the final release.**
"""
import copy
try:
import pymongocrypt
_HAVE_PYMONGOCRYPT = True
except ImportError:
_HAVE_PYMONGOCRYPT = False
from pymongo.errors import ConfigurationError
class AutoEncryptionOpts(object):
"""Options to configure automatic encryption."""
def __init__(self, kms_providers, key_vault_namespace,
key_vault_client=None, schema_map=None,
bypass_auto_encryption=False,
mongocryptd_uri='mongodb://localhost:27020',
mongocryptd_bypass_spawn=False,
mongocryptd_spawn_path='mongocryptd',
mongocryptd_spawn_args=None):
"""Options to configure automatic encryption.
Automatic encryption is an enterprise only feature that only
applies to operations on a collection. Automatic encryption is not
supported for operations on a database or view and will result in
        an error. To bypass automatic encryption (but enable automatic
decryption), set ``bypass_auto_encryption=True`` in
AutoEncryptionOpts.
        Explicit encryption/decryption and automatic decryption are
        community features. A MongoClient configured with
        bypassAutoEncryption=true will still automatically decrypt.
.. note:: Support for client side encryption is in beta.
Backwards-breaking changes may be made before the final release.
:Parameters:
- `kms_providers`: Map of KMS provider options. Two KMS providers
are supported: "aws" and "local". The kmsProviders map values
differ by provider:
- `aws`: Map with "accessKeyId" and "secretAccessKey" as strings.
These are the AWS access key ID and AWS secret access key used
to generate KMS messages.
- `local`: Map with "key" as a 96-byte array or string. "key"
is the master key used to encrypt/decrypt data keys. This key
should be generated and stored as securely as possible.
- `key_vault_namespace`: The namespace for the key vault collection.
The key vault collection contains all data keys used for encryption
and decryption. Data keys are stored as documents in this MongoDB
collection. Data keys are protected with encryption by a KMS
provider.
- `key_vault_client` (optional): By default the key vault collection
is assumed to reside in the same MongoDB cluster as the encrypted
MongoClient. Use this option to route data key queries to a
separate MongoDB cluster.
- `schema_map` (optional): Map of collection namespace ("db.coll") to
JSON Schema. By default, a collection's JSONSchema is periodically
polled with the listCollections command. But a JSONSchema may be
specified locally with the schemaMap option.
**Supplying a `schema_map` provides more security than relying on
JSON Schemas obtained from the server. It protects against a
malicious server advertising a false JSON Schema, which could trick
the client into sending unencrypted data that should be
encrypted.**
Schemas supplied in the schemaMap only apply to configuring
automatic encryption for client side encryption. Other validation
rules in the JSON schema will not be enforced by the driver and
will result in an error.
- `bypass_auto_encryption` (optional): If ``True``, automatic
encryption will be disabled but automatic decryption will still be
enabled. Defaults to ``False``.
- `mongocryptd_uri` (optional): The MongoDB URI used to connect
to the *local* mongocryptd process. Defaults to
``'mongodb://localhost:27020'``.
- `mongocryptd_bypass_spawn` (optional): If ``True``, the encrypted
MongoClient will not attempt to spawn the mongocryptd process.
Defaults to ``False``.
- `mongocryptd_spawn_path` (optional): Used for spawning the
mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
mongocryptd from the system path.
- `mongocryptd_spawn_args` (optional): A list of string arguments to
use when spawning the mongocryptd process. Defaults to
``['--idleShutdownTimeoutSecs=60']``. If the list does not include
the ``idleShutdownTimeoutSecs`` option then
``'--idleShutdownTimeoutSecs=60'`` will be added.
.. versionadded:: 3.9
"""
if not _HAVE_PYMONGOCRYPT:
raise ConfigurationError(
"client side encryption requires the pymongocrypt library: "
"install a compatible version with: "
"python -m pip install pymongo['encryption']")
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
self._schema_map = schema_map
self._bypass_auto_encryption = bypass_auto_encryption
self._mongocryptd_uri = mongocryptd_uri
self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
self._mongocryptd_spawn_path = mongocryptd_spawn_path
self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or
['--idleShutdownTimeoutSecs=60'])
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError('mongocryptd_spawn_args must be a list')
if not any('idleShutdownTimeoutSecs' in s
for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append('--idleShutdownTimeoutSecs=60')
def validate_auto_encryption_opts_or_none(option, value):
"""Validate the driver keyword arg."""
if value is None:
return value
if not isinstance(value, AutoEncryptionOpts):
raise TypeError("%s must be an instance of AutoEncryptionOpts" % (
option,))
return value
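# Example: a minimal sketch of enabling automatic encryption with a local KMS
# master key (assumes pymongocrypt is installed and mongocryptd is on the
# system path; a real deployment must persist the master key securely):
#
#   import os
#   from pymongo import MongoClient
#   from pymongo.encryption_options import AutoEncryptionOpts
#
#   opts = AutoEncryptionOpts(
#       kms_providers={'local': {'key': os.urandom(96)}},
#       key_vault_namespace='encryption.__keyVault')
#   client = MongoClient(auto_encryption_opts=opts)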

View File

@@ -0,0 +1,268 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exceptions raised by PyMongo."""
import sys
from bson.errors import *
try:
from ssl import CertificateError
except ImportError:
from pymongo.ssl_match_hostname import CertificateError
class PyMongoError(Exception):
"""Base class for all PyMongo exceptions."""
def __init__(self, message='', error_labels=None):
super(PyMongoError, self).__init__(message)
self._message = message
self._error_labels = set(error_labels or [])
def has_error_label(self, label):
"""Return True if this error contains the given label.
.. versionadded:: 3.7
"""
return label in self._error_labels
def _add_error_label(self, label):
"""Add the given label to this error."""
self._error_labels.add(label)
def _remove_error_label(self, label):
"""Remove the given label from this error."""
self._error_labels.remove(label)
def __str__(self):
if sys.version_info[0] == 2 and isinstance(self._message, unicode):
return self._message.encode('utf-8', errors='replace')
return str(self._message)
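# Example: error labels let applications branch on the nature of a failure
# without matching on error codes or messages. A minimal retry sketch,
# assuming ``client`` is an existing MongoClient:
#
#   try:
#       client.admin.command('ping')
#   except PyMongoError as exc:
#       if exc.has_error_label('TransientTransactionError'):
#           pass  # the operation is safe to retry
#       else:
#           raise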
class ProtocolError(PyMongoError):
"""Raised for failures related to the wire protocol."""
class ConnectionFailure(PyMongoError):
"""Raised when a connection to the database cannot be made or is lost."""
def __init__(self, message='', error_labels=None):
if error_labels is None:
# Connection errors are transient errors by default.
error_labels = ("TransientTransactionError",)
super(ConnectionFailure, self).__init__(
message, error_labels=error_labels)
class AutoReconnect(ConnectionFailure):
"""Raised when a connection to the database is lost and an attempt to
auto-reconnect will be made.
In order to auto-reconnect you must handle this exception, recognizing that
the operation which caused it has not necessarily succeeded. Future
operations will attempt to open a new connection to the database (and
will continue to raise this exception until the first successful
connection is made).
Subclass of :exc:`~pymongo.errors.ConnectionFailure`.
"""
def __init__(self, message='', errors=None):
super(AutoReconnect, self).__init__(message)
self.errors = self.details = errors or []
class NetworkTimeout(AutoReconnect):
"""An operation on an open connection exceeded socketTimeoutMS.
The remaining connections in the pool stay open. In the case of a write
operation, you cannot know whether it succeeded or failed.
Subclass of :exc:`~pymongo.errors.AutoReconnect`.
"""
class NotMasterError(AutoReconnect):
"""The server responded "not master" or "node is recovering".
These errors result from a query, write, or command. The operation failed
because the client thought it was using the primary but the primary has
stepped down, or the client thought it was using a healthy secondary but
the secondary is stale and trying to recover.
The client launches a refresh operation on a background thread, to update
its view of the server as soon as possible after throwing this exception.
Subclass of :exc:`~pymongo.errors.AutoReconnect`.
"""
class ServerSelectionTimeoutError(AutoReconnect):
"""Thrown when no MongoDB server is available for an operation
If there is no suitable server for an operation PyMongo tries for
``serverSelectionTimeoutMS`` (default 30 seconds) to find one, then
throws this exception. For example, it is thrown after attempting an
operation when PyMongo cannot connect to any server, or if you attempt
an insert into a replica set that has no primary and does not elect one
within the timeout window, or if you attempt to query with a Read
Preference that the replica set cannot satisfy.
"""
class ConfigurationError(PyMongoError):
"""Raised when something is incorrectly configured.
"""
class OperationFailure(PyMongoError):
"""Raised when a database operation fails.
.. versionadded:: 2.7
The :attr:`details` attribute.
"""
def __init__(self, error, code=None, details=None):
error_labels = None
if details is not None:
error_labels = details.get('errorLabels')
super(OperationFailure, self).__init__(
error, error_labels=error_labels)
self.__code = code
self.__details = details
@property
def code(self):
"""The error code returned by the server, if any.
"""
return self.__code
@property
def details(self):
"""The complete error document returned by the server.
Depending on the error that occurred, the error document
may include useful information beyond just the error
message. When connected to a mongos the error document
may contain one or more subdocuments if errors occurred
on multiple shards.
"""
return self.__details
class CursorNotFound(OperationFailure):
"""Raised while iterating query results if the cursor is
invalidated on the server.
.. versionadded:: 2.7
"""
class ExecutionTimeout(OperationFailure):
"""Raised when a database operation times out, exceeding the $maxTimeMS
set in the query or command option.
.. note:: Requires server version **>= 2.6.0**
.. versionadded:: 2.7
"""
class WriteConcernError(OperationFailure):
"""Base exception type for errors raised due to write concern.
.. versionadded:: 3.0
"""
class WriteError(OperationFailure):
"""Base exception type for errors raised during write operations.
.. versionadded:: 3.0
"""
class WTimeoutError(WriteConcernError):
"""Raised when a database operation times out (i.e. wtimeout expires)
before replication completes.
With newer versions of MongoDB the `details` attribute may include
write concern fields like 'n', 'updatedExisting', or 'writtenTo'.
.. versionadded:: 2.7
"""
class DuplicateKeyError(WriteError):
"""Raised when an insert or update fails due to a duplicate key error."""
class BulkWriteError(OperationFailure):
"""Exception class for bulk write errors.
.. versionadded:: 2.7
"""
def __init__(self, results):
super(BulkWriteError, self).__init__(
"batch op errors occurred", 65, results)
class InvalidOperation(PyMongoError):
"""Raised when a client attempts to perform an invalid operation."""
class InvalidName(PyMongoError):
"""Raised when an invalid name is used."""
class CollectionInvalid(PyMongoError):
"""Raised when collection validation fails."""
class InvalidURI(ConfigurationError):
"""Raised when trying to parse an invalid mongodb URI."""
class ExceededMaxWaiters(PyMongoError):
"""Raised when a thread tries to get a connection from a pool and
``maxPoolSize * waitQueueMultiple`` threads are already waiting.
.. versionadded:: 2.6
"""
pass
class DocumentTooLarge(InvalidDocument):
"""Raised when an encoded document is too large for the connected server.
"""
pass
class EncryptionError(PyMongoError):
"""Raised when encryption or decryption fails.
This error always wraps another exception which can be retrieved via the
:attr:`cause` property.
.. versionadded:: 3.9
"""
def __init__(self, cause):
super(EncryptionError, self).__init__(str(cause))
self.__cause = cause
@property
def cause(self):
"""The exception that caused this encryption or decryption error."""
return self.__cause
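# Illustrative sketch of catch ordering (not from the original source):
# AutoReconnect subclasses ConnectionFailure, so catch the more specific
# type first when the two cases should be handled differently.
#
#     >>> try:
#     ...     pass  # some operation on a MongoClient
#     ... except AutoReconnect:
#     ...     pass  # transient: later operations will retry the connection
#     ... except ConnectionFailure:
#     ...     pass  # the connection could not be made at all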

View File

@@ -0,0 +1,276 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
import sys
import traceback
from bson.py3compat import abc, iteritems, itervalues, string_type
from bson.son import SON
from pymongo import ASCENDING
from pymongo.errors import (CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotMasterError,
OperationFailure,
WriteError,
WriteConcernError,
WTimeoutError)
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES = frozenset([
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
])
# From the SDAM spec, the "not master" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_MASTER_CODES = frozenset([
10107, # NotMaster
13435, # NotMasterNoSlaveOk
11602, # InterruptedDueToReplStateChange
13436, # NotMasterOrSecondary
189, # PrimarySteppedDown
]) | _SHUTDOWN_CODES
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset([
7, # HostNotFound
6, # HostUnreachable
89, # NetworkTimeout
9001, # SocketException
])
_UUNDER = u"_"
def _gen_index_name(keys):
"""Generate an index name from the set of fields it is over."""
return _UUNDER.join(["%s_%s" % item for item in keys])
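# For example (illustrative):
#
#     >>> _gen_index_name([("user_id", 1), ("created", -1)])
#     'user_id_1_created_-1'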
def _index_list(key_or_list, direction=None):
"""Helper to generate a list of (key, direction) pairs.
Takes such a list, or a single key, or a single key and direction.
"""
if direction is not None:
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, string_type):
return [(key_or_list, ASCENDING)]
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, "
"key_or_list must be an instance of list")
return key_or_list
def _index_document(index_list):
"""Helper to generate an index specifying document.
Takes a list of (key, direction) pairs.
"""
if isinstance(index_list, abc.Mapping):
raise TypeError("passing a dict to sort/create_index/hint is not "
"allowed - use a list of tuples instead. did you "
"mean %r?" % list(iteritems(index_list)))
elif not isinstance(index_list, (list, tuple)):
raise TypeError("must use a list of (key, direction) pairs, "
"not: " + repr(index_list))
if not len(index_list):
raise ValueError("key_or_list must not be the empty list")
index = SON()
for (key, value) in index_list:
if not isinstance(key, string_type):
raise TypeError("first item in each key pair must be a string")
if not isinstance(value, (string_type, int, abc.Mapping)):
raise TypeError("second item in each key pair must be 1, -1, "
"'2d', 'geoHaystack', or another valid MongoDB "
"index specifier.")
index[key] = value
return index
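# Taken together, these helpers normalize user input (illustrative):
#
#     >>> _index_list("field")
#     [('field', 1)]
#     >>> _index_document([("loc", "2dsphere"), ("name", 1)])
#     SON([('loc', '2dsphere'), ('name', 1)])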
def _check_command_response(response, msg=None, allowable_errors=None,
parse_write_concern_error=False):
"""Check the response to a command for errors.
"""
if "ok" not in response:
# Server didn't recognize our message as a command.
raise OperationFailure(response.get("$err"),
response.get("code"),
response)
if parse_write_concern_error and 'writeConcernError' in response:
_raise_write_concern_error(response['writeConcernError'])
if not response["ok"]:
details = response
# Mongos returns the error details in a 'raw' object
# for some errors.
if "raw" in response:
for shard in itervalues(response["raw"]):
# Grab the first non-empty raw error from a shard.
if shard.get("errmsg") and not shard.get("ok"):
details = shard
break
errmsg = details["errmsg"]
if allowable_errors is None or errmsg not in allowable_errors:
code = details.get("code")
# Server is "not master" or "recovering"
if code in _NOT_MASTER_CODES:
raise NotMasterError(errmsg, response)
elif ("not master" in errmsg
or "node is recovering" in errmsg):
raise NotMasterError(errmsg, response)
# Server assertion failures
if errmsg == "db assertion failure":
errmsg = ("db assertion failure, assertion: '%s'" %
details.get("assertion", ""))
raise OperationFailure(errmsg,
details.get("assertionCode"),
response)
# Other errors
# findAndModify with upsert can raise duplicate key error
if code in (11000, 11001, 12582):
raise DuplicateKeyError(errmsg, code, response)
elif code == 50:
raise ExecutionTimeout(errmsg, code, response)
elif code == 43:
raise CursorNotFound(errmsg, code, response)
msg = msg or "%s"
raise OperationFailure(msg % errmsg, code, response)
def _check_gle_response(result):
"""Return getlasterror response as a dict, or raise OperationFailure."""
# Did getlasterror itself fail?
_check_command_response(result)
if result.get("wtimeout", False):
        # MongoDB versions before 1.8.0 return the error message in an
        # "errmsg" field. If "errmsg" exists, "err" will also exist, set to
        # None, so we have to check for "errmsg" first.
raise WTimeoutError(result.get("errmsg", result.get("err")),
result.get("code"),
result)
error_msg = result.get("err", "")
if error_msg is None:
return result
if error_msg.startswith("not master"):
raise NotMasterError(error_msg, result)
details = result
# mongos returns the error code in an error object for some errors.
if "errObjects" in result:
for errobj in result["errObjects"]:
if errobj.get("err") == error_msg:
details = errobj
break
code = details.get("code")
if code in (11000, 11001, 12582):
raise DuplicateKeyError(details["err"], code, result)
raise OperationFailure(details["err"], code, result)
def _raise_last_write_error(write_errors):
# If the last batch had multiple errors only report
# the last error to emulate continue_on_error.
error = write_errors[-1]
if error.get("code") == 11000:
raise DuplicateKeyError(error.get("errmsg"), 11000, error)
raise WriteError(error.get("errmsg"), error.get("code"), error)
def _raise_write_concern_error(error):
if "errInfo" in error and error["errInfo"].get('wtimeout'):
# Make sure we raise WTimeoutError
raise WTimeoutError(
error.get("errmsg"), error.get("code"), error)
raise WriteConcernError(
error.get("errmsg"), error.get("code"), error)
def _check_write_command_response(result):
"""Backward compatibility helper for write command error handling.
"""
# Prefer write errors over write concern errors
write_errors = result.get("writeErrors")
if write_errors:
_raise_last_write_error(write_errors)
error = result.get("writeConcernError")
if error:
_raise_write_concern_error(error)
def _raise_last_error(bulk_write_result):
"""Backward compatibility helper for insert error handling.
"""
# Prefer write errors over write concern errors
write_errors = bulk_write_result.get("writeErrors")
if write_errors:
_raise_last_write_error(write_errors)
_raise_write_concern_error(bulk_write_result["writeConcernErrors"][-1])
def _fields_list_to_dict(fields, option_name):
"""Takes a sequence of field names and returns a matching dictionary.
["a", "b"] becomes {"a": 1, "b": 1}
and
["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
"""
if isinstance(fields, abc.Mapping):
return fields
if isinstance(fields, (abc.Sequence, abc.Set)):
if not all(isinstance(field, string_type) for field in fields):
raise TypeError("%s must be a list of key names, each an "
"instance of %s" % (option_name,
string_type.__name__))
return dict.fromkeys(fields, 1)
raise TypeError("%s must be a mapping or "
"list of key names" % (option_name,))
def _handle_exception():
"""Print exceptions raised by subscribers to stderr."""
# Heavily influenced by logging.Handler.handleError.
# See note here:
# https://docs.python.org/3.4/library/sys.html#sys.__stderr__
if sys.stderr:
einfo = sys.exc_info()
try:
traceback.print_exception(einfo[0], einfo[1], einfo[2],
None, sys.stderr)
except IOError:
pass
finally:
del einfo

View File

@@ -0,0 +1,158 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parse a response to the 'ismaster' command."""
import itertools
from bson.py3compat import imap
from pymongo import common
from pymongo.server_type import SERVER_TYPE
def _get_server_type(doc):
"""Determine the server type from an ismaster response."""
if not doc.get('ok'):
return SERVER_TYPE.Unknown
if doc.get('isreplicaset'):
return SERVER_TYPE.RSGhost
elif doc.get('setName'):
if doc.get('hidden'):
return SERVER_TYPE.RSOther
elif doc.get('ismaster'):
return SERVER_TYPE.RSPrimary
elif doc.get('secondary'):
return SERVER_TYPE.RSSecondary
elif doc.get('arbiterOnly'):
return SERVER_TYPE.RSArbiter
else:
return SERVER_TYPE.RSOther
elif doc.get('msg') == 'isdbgrid':
return SERVER_TYPE.Mongos
else:
return SERVER_TYPE.Standalone
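# For example (illustrative), a primary's ismaster response carries both a
# setName and ismaster=True:
#
#     >>> doc = {'ok': 1, 'setName': 'rs0', 'ismaster': True}
#     >>> _get_server_type(doc) == SERVER_TYPE.RSPrimary
#     True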
class IsMaster(object):
__slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable')
def __init__(self, doc):
"""Parse an ismaster response from the server."""
self._server_type = _get_server_type(doc)
self._doc = doc
self._is_writable = self._server_type in (
SERVER_TYPE.RSPrimary,
SERVER_TYPE.Standalone,
SERVER_TYPE.Mongos)
self._is_readable = (
self.server_type == SERVER_TYPE.RSSecondary
or self._is_writable)
@property
def document(self):
"""The complete ismaster command response document.
.. versionadded:: 3.4
"""
return self._doc.copy()
@property
def server_type(self):
return self._server_type
@property
def all_hosts(self):
"""List of hosts, passives, and arbiters known to this server."""
return set(imap(common.clean_node, itertools.chain(
self._doc.get('hosts', []),
self._doc.get('passives', []),
self._doc.get('arbiters', []))))
@property
def tags(self):
"""Replica set member tags or empty dict."""
return self._doc.get('tags', {})
@property
def primary(self):
"""This server's opinion about who the primary is, or None."""
if self._doc.get('primary'):
return common.partition_node(self._doc['primary'])
else:
return None
@property
def replica_set_name(self):
"""Replica set name or None."""
return self._doc.get('setName')
@property
def max_bson_size(self):
return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE)
@property
def max_message_size(self):
return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size)
@property
def max_write_batch_size(self):
return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE)
@property
def min_wire_version(self):
return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION)
@property
def max_wire_version(self):
return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION)
@property
def set_version(self):
return self._doc.get('setVersion')
@property
def election_id(self):
return self._doc.get('electionId')
@property
def cluster_time(self):
return self._doc.get('$clusterTime')
@property
def logical_session_timeout_minutes(self):
return self._doc.get('logicalSessionTimeoutMinutes')
@property
def is_writable(self):
return self._is_writable
@property
def is_readable(self):
return self._is_readable
@property
def me(self):
me = self._doc.get('me')
if me:
return common.clean_node(me)
@property
def last_write_date(self):
return self._doc.get('lastWrite', {}).get('lastWriteDate')
@property
def compressors(self):
return self._doc.get('compression')
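# Illustrative usage (the response document below is a minimal fabricated
# example, not real server output):
#
#     >>> ism = IsMaster({'ok': 1, 'setName': 'rs0', 'ismaster': True,
#     ...                 'hosts': ['a:27017', 'b:27017']})
#     >>> ism.is_writable
#     True
#     >>> sorted(ism.all_hosts)
#     [('a', 27017), ('b', 27017)]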

View File

@@ -0,0 +1,116 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select ServerDescriptions based on maxStalenessSeconds.
The Max Staleness Spec says: When there is a known primary P,
a secondary S's staleness is estimated with this formula:
(S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate)
+ heartbeatFrequencyMS
When there is no known primary, a secondary S's staleness is estimated with:
SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS
where "SMax" is the secondary with the greatest lastWriteDate.
"""
from pymongo.errors import ConfigurationError
from pymongo.server_type import SERVER_TYPE
# Constant defined in Max Staleness Spec: An idle primary writes a no-op every
# 10 seconds to refresh secondaries' lastWriteDate values.
IDLE_WRITE_PERIOD = 10
SMALLEST_MAX_STALENESS = 90
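# Worked example (illustrative): with heartbeatFrequencyMS=10000 (10s), a
# secondary whose (lastUpdateTime - lastWriteDate) gap is 42s, and a primary
# whose gap is 2s, the estimated staleness is (42 - 2) + 10 = 50 seconds, so
# maxStalenessSeconds=90 keeps that secondary eligible.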
def _validate_max_staleness(max_staleness,
heartbeat_frequency):
    # We checked for max staleness -1 before this; it must be positive here.
if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD:
raise ConfigurationError(
"maxStalenessSeconds must be at least heartbeatFrequencyMS +"
" %d seconds. maxStalenessSeconds is set to %d,"
" heartbeatFrequencyMS is set to %d." % (
IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000))
if max_staleness < SMALLEST_MAX_STALENESS:
raise ConfigurationError(
"maxStalenessSeconds must be at least %d. "
"maxStalenessSeconds is set to %d." % (
SMALLEST_MAX_STALENESS, max_staleness))
def _with_primary(max_staleness, selection):
"""Apply max_staleness, in seconds, to a Selection with a known primary."""
primary = selection.primary
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
staleness = (
(s.last_update_time - s.last_write_date) -
(primary.last_update_time - primary.last_write_date) +
selection.heartbeat_frequency)
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def _no_primary(max_staleness, selection):
"""Apply max_staleness, in seconds, to a Selection with no known primary."""
# Secondary that's replicated the most recent writes.
smax = selection.secondary_with_max_last_write_date()
if not smax:
# No secondaries and no primary, short-circuit out of here.
return selection.with_server_descriptions([])
sds = []
for s in selection.server_descriptions:
if s.server_type == SERVER_TYPE.RSSecondary:
# See max-staleness.rst for explanation of this formula.
staleness = (smax.last_write_date -
s.last_write_date +
selection.heartbeat_frequency)
if staleness <= max_staleness:
sds.append(s)
else:
sds.append(s)
return selection.with_server_descriptions(sds)
def select(max_staleness, selection):
"""Apply max_staleness, in seconds, to a Selection."""
if max_staleness == -1:
return selection
# Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or
# ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness <
# heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90.
_validate_max_staleness(max_staleness, selection.heartbeat_frequency)
if selection.primary:
return _with_primary(max_staleness, selection)
else:
return _no_primary(max_staleness, selection)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,48 @@
# Copyright 2011-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Deprecated. See :doc:`/examples/high_availability`."""
import warnings
from pymongo import mongo_client
class MongoReplicaSetClient(mongo_client.MongoClient):
"""Deprecated alias for :class:`~pymongo.mongo_client.MongoClient`.
:class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`
will be removed in a future version of PyMongo.
.. versionchanged:: 3.0
:class:`~pymongo.mongo_client.MongoClient` is now the one and only
client class for a standalone server, mongos, or replica set.
It includes the functionality that had been split into
:class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`: it
can connect to a replica set, discover all its members, and monitor
the set for stepdowns, elections, and reconfigs.
The ``refresh`` method is removed from
:class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`,
as are the ``seeds`` and ``hosts`` properties.
"""
def __init__(self, *args, **kwargs):
warnings.warn('MongoReplicaSetClient is deprecated, use MongoClient'
' to connect to a replica set',
DeprecationWarning, stacklevel=2)
super(MongoReplicaSetClient, self).__init__(*args, **kwargs)
def __repr__(self):
return "MongoReplicaSetClient(%s)" % (self._repr_helper(),)

View File

@@ -0,0 +1,261 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Class to monitor a MongoDB server on a background thread."""
import weakref
from pymongo import common, periodic_executor
from pymongo.errors import OperationFailure
from pymongo.monotonic import time as _time
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.server_type import SERVER_TYPE
from pymongo.srv_resolver import _SrvResolver
class MonitorBase(object):
def __init__(self, *args, **kwargs):
"""Override this method to create an executor."""
raise NotImplementedError
def open(self):
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._executor.open()
def close(self):
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self._executor.close()
def join(self, timeout=None):
"""Wait for the monitor to stop."""
self._executor.join(timeout)
def request_check(self):
"""If the monitor is sleeping, wake it soon."""
self._executor.wake()
class Monitor(MonitorBase):
def __init__(
self,
server_description,
topology,
pool,
topology_settings):
"""Class to monitor a MongoDB server on a background thread.
Pass an initial ServerDescription, a Topology, a Pool, and
TopologySettings.
The Topology is weakly referenced. The Pool must be exclusive to this
Monitor.
"""
self._server_description = server_description
self._pool = pool
self._settings = topology_settings
self._avg_round_trip_time = MovingAverage()
self._listeners = self._settings._pool_options.event_listeners
pub = self._listeners is not None
self._publish = pub and self._listeners.enabled_for_server_heartbeat
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
def target():
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
Monitor._run(monitor)
return True
executor = periodic_executor.PeriodicExecutor(
interval=self._settings.heartbeat_frequency,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
name="pymongo_server_monitor_thread")
self._executor = executor
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, executor.close)
def close(self):
super(Monitor, self).close()
# Increment the pool_id and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._pool.reset()
def _run(self):
try:
self._server_description = self._check_with_retry()
self._topology.on_change(self._server_description)
except ReferenceError:
# Topology was garbage-collected.
self.close()
def _check_with_retry(self):
"""Call ismaster once or twice. Reset server's pool on error.
Returns a ServerDescription.
"""
# According to the spec, if an ismaster call fails we reset the
# server's pool. If a server was once connected, change its type
# to Unknown only after retrying once.
address = self._server_description.address
retry = True
if self._server_description.server_type == SERVER_TYPE.Unknown:
retry = False
start = _time()
try:
return self._check_once()
except ReferenceError:
raise
except Exception as error:
error_time = _time() - start
if self._publish:
self._listeners.publish_server_heartbeat_failed(
address, error_time, error)
self._topology.reset_pool(address)
default = ServerDescription(address, error=error)
if not retry:
self._avg_round_trip_time.reset()
# Server type defaults to Unknown.
return default
        # Try a second and final time. If it fails, return the original error.
# Always send metadata: this is a new connection.
start = _time()
try:
return self._check_once()
except ReferenceError:
raise
except Exception as error:
error_time = _time() - start
if self._publish:
self._listeners.publish_server_heartbeat_failed(
address, error_time, error)
self._avg_round_trip_time.reset()
return default
def _check_once(self):
"""A single attempt to call ismaster.
Returns a ServerDescription, or raises an exception.
"""
address = self._server_description.address
if self._publish:
self._listeners.publish_server_heartbeat_started(address)
with self._pool.get_socket({}) as sock_info:
response, round_trip_time = self._check_with_socket(sock_info)
self._avg_round_trip_time.add_sample(round_trip_time)
sd = ServerDescription(
address=address,
ismaster=response,
round_trip_time=self._avg_round_trip_time.get())
if self._publish:
self._listeners.publish_server_heartbeat_succeeded(
address, round_trip_time, response)
return sd
def _check_with_socket(self, sock_info):
"""Return (IsMaster, round_trip_time).
Can raise ConnectionFailure or OperationFailure.
"""
start = _time()
try:
return (sock_info.ismaster(self._pool.opts.metadata,
self._topology.max_cluster_time()),
_time() - start)
except OperationFailure as exc:
# Update max cluster time even when isMaster fails.
self._topology.receive_cluster_time(
exc.details.get('$clusterTime'))
raise
class SrvMonitor(MonitorBase):
def __init__(self, topology, topology_settings):
"""Class to poll SRV records on a background thread.
Pass a Topology and a TopologySettings.
The Topology is weakly referenced.
"""
self._settings = topology_settings
self._seedlist = self._settings._seeds
self._fqdn = self._settings.fqdn
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
def target():
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
SrvMonitor._run(monitor)
return True
executor = periodic_executor.PeriodicExecutor(
interval=common.MIN_SRV_RESCAN_INTERVAL,
min_interval=self._settings.heartbeat_frequency,
target=target,
name="pymongo_srv_polling_thread")
self._executor = executor
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, executor.close)
def _run(self):
seedlist = self._get_seedlist()
if seedlist:
self._seedlist = seedlist
try:
self._topology.on_srv_update(self._seedlist)
except ReferenceError:
# Topology was garbage-collected.
self.close()
def _get_seedlist(self):
"""Poll SRV records for a seedlist.
Returns a list of ServerDescriptions.
"""
try:
seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
                raise Exception("SRV seedlist is empty")
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
# - SRV records must be rescanned every heartbeatFrequencyMS
# - Topology must be left unchanged
self.request_check()
return None
else:
self._executor.update_interval(
max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
return seedlist

File diff suppressed because it is too large

View File

@@ -0,0 +1,38 @@
# Copyright 2014-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Time. Monotonic if possible.
"""
from __future__ import absolute_import
__all__ = ['time']
try:
# Patches standard time module.
# From https://pypi.python.org/pypi/Monotime.
import monotime
except ImportError:
pass
try:
# From https://pypi.python.org/pypi/monotonic.
from monotonic import monotonic as time
except ImportError:
try:
# Monotime or Python 3.
from time import monotonic as time
except ImportError:
# Not monotonic.
from time import time
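# Illustrative usage: callers only ever take differences, which is why a
# monotonic source is preferred over wall-clock time.
#
#     >>> from pymongo.monotonic import time as _time
#     >>> start = _time()
#     >>> elapsed = _time() - start  # unaffected by clock adjustments when monotonic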

View File

@@ -0,0 +1,320 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal network layer helper methods."""
import datetime
import errno
import select
import struct
import threading
_HAS_POLL = True
_EVENT_MASK = 0
try:
from select import poll
_EVENT_MASK = (
select.POLLIN | select.POLLPRI | select.POLLERR | select.POLLHUP)
except ImportError:
_HAS_POLL = False
try:
from select import error as _SELECT_ERROR
except ImportError:
_SELECT_ERROR = OSError
from bson import _decode_all_selective
from bson.py3compat import PY3
from pymongo import helpers, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress, _NO_COMPRESSION
from pymongo.errors import (AutoReconnect,
NotMasterError,
OperationFailure,
ProtocolError)
from pymongo.message import _UNPACK_REPLY
_UNPACK_HEADER = struct.Struct("<iiii").unpack
def command(sock, dbname, spec, slave_ok, is_mongos,
read_preference, codec_options, session, client, check=True,
allowable_errors=None, address=None,
check_keys=False, listeners=None, max_bson_size=None,
read_concern=None,
parse_write_concern_error=False,
collation=None,
compression_ctx=None,
use_op_msg=False,
unacknowledged=False,
user_fields=None):
"""Execute a command over the socket, or raise socket.error.
:Parameters:
- `sock`: a raw socket instance
- `dbname`: name of the database on which to run the command
- `spec`: a command document as an ordered dict type, eg SON.
- `slave_ok`: whether to set the SlaveOkay wire protocol bit
- `is_mongos`: are we connected to a mongos?
- `read_preference`: a read preference
- `codec_options`: a CodecOptions instance
- `session`: optional ClientSession instance.
- `client`: optional MongoClient instance for updating $clusterTime.
- `check`: raise OperationFailure if there are errors
- `allowable_errors`: errors to ignore if `check` is True
- `address`: the (host, port) of `sock`
- `check_keys`: if True, check `spec` for invalid keys
- `listeners`: An instance of :class:`~pymongo.monitoring.EventListeners`
- `max_bson_size`: The maximum encoded bson size for this server
- `read_concern`: The read concern for this command.
- `parse_write_concern_error`: Whether to parse the ``writeConcernError``
field in the command response.
- `collation`: The collation for this command.
- `compression_ctx`: optional compression Context.
- `use_op_msg`: True if we should use OP_MSG.
- `unacknowledged`: True if this is an unacknowledged command.
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
"""
name = next(iter(spec))
ns = dbname + '.$cmd'
flags = 4 if slave_ok else 0
# Publish the original command document, perhaps with lsid and $clusterTime.
orig = spec
if is_mongos and not use_op_msg:
spec = message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session._in_transaction):
if read_concern.level:
spec['readConcern'] = read_concern.document
if (session and session.options.causal_consistency
and session.operation_time is not None):
spec.setdefault(
'readConcern', {})['afterClusterTime'] = session.operation_time
if collation is not None:
spec['collation'] = collation
publish = listeners is not None and listeners.enabled_for_commands
if publish:
start = datetime.datetime.now()
if compression_ctx and name.lower() in _NO_COMPRESSION:
compression_ctx = None
if (client and client._encrypter and
not client._encrypter._bypass_auto_encryption):
spec = orig = client._encrypter.encrypt(
dbname, spec, check_keys, codec_options)
# We already checked the keys, no need to do it again.
check_keys = False
if use_op_msg:
flags = 2 if unacknowledged else 0
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, slave_ok, check_keys,
codec_options, ctx=compression_ctx)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if (unacknowledged and max_bson_size is not None and
max_doc_size > max_bson_size):
message._raise_document_too_large(name, size, max_bson_size)
else:
request_id, msg, size = message.query(
flags, ns, 0, -1, spec, None, codec_options, check_keys,
compression_ctx)
if (max_bson_size is not None
and size > max_bson_size + message._COMMAND_OVERHEAD):
message._raise_document_too_large(
name, size, max_bson_size + message._COMMAND_OVERHEAD)
if publish:
encoding_duration = datetime.datetime.now() - start
listeners.publish_command_start(orig, dbname, request_id, address)
start = datetime.datetime.now()
try:
sock.sendall(msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc = {"ok": 1}
else:
reply = receive_message(sock, request_id)
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields)
response_doc = unpacked_docs[0]
if client:
client._process_response(response_doc, session)
if check:
helpers._check_command_response(
response_doc, None, allowable_errors,
parse_write_concern_error=parse_write_concern_error)
except Exception as exc:
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
if isinstance(exc, (NotMasterError, OperationFailure)):
failure = exc.details
else:
failure = message._convert_exception(exc)
listeners.publish_command_failure(
duration, failure, name, request_id, address)
raise
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
listeners.publish_command_success(
duration, response_doc, name, request_id, address)
if client and client._encrypter and reply:
decrypted = client._encrypter.decrypt(reply.raw_command_response())
response_doc = _decode_all_selective(decrypted, codec_options,
user_fields)[0]
return response_doc
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
"""Receive a raw BSON message or raise socket.error."""
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
_receive_data_on_socket(sock, 16))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError("Got response id %r but expected "
"%r" % (response_to, request_id))
if length <= 16:
raise ProtocolError("Message length (%r) not longer than standard "
"message header size (16)" % (length,))
if length > max_message_size:
raise ProtocolError("Message length (%r) is larger than server max "
"message size (%r)" % (length, max_message_size))
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(sock, 9))
data = decompress(
_receive_data_on_socket(sock, length - 25), compressor_id)
else:
data = _receive_data_on_socket(sock, length - 16)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError("Got opcode %r but expected "
"%r" % (op_code, _UNPACK_REPLY.keys()))
return unpack_reply(data)
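# Illustrative sketch of the 16-byte header that _UNPACK_HEADER parses:
# four little-endian int32s (length, requestID, responseTo, opCode).
#
#     >>> import struct
#     >>> _UNPACK_HEADER(struct.pack("<iiii", 100, 7, 42, 1))
#     (100, 7, 42, 1)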
# memoryview was introduced in Python 2.7 but we only use it on Python 3
# because before 2.7.4 the struct module did not support memoryview:
# https://bugs.python.org/issue10212.
# In Jython, using slice assignment on a memoryview results in a
# NullPointerException.
if not PY3:
def _receive_data_on_socket(sock, length):
buf = bytearray(length)
i = 0
while length:
try:
chunk = sock.recv(length)
except (IOError, OSError) as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk == b"":
raise AutoReconnect("connection closed")
buf[i:i + len(chunk)] = chunk
i += len(chunk)
length -= len(chunk)
return bytes(buf)
else:
def _receive_data_on_socket(sock, length):
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
chunk_length = sock.recv_into(mv[bytes_read:])
except (IOError, OSError) as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise AutoReconnect("connection closed")
bytes_read += chunk_length
return mv
def _errno_from_exception(exc):
if hasattr(exc, 'errno'):
return exc.errno
elif exc.args:
return exc.args[0]
else:
return None
class SocketChecker(object):
def __init__(self):
if _HAS_POLL:
self._lock = threading.Lock()
self._poller = poll()
else:
self._lock = None
self._poller = None
def socket_closed(self, sock):
"""Return True if we know socket has been closed, False otherwise.
"""
while True:
try:
if self._poller:
with self._lock:
self._poller.register(sock, _EVENT_MASK)
try:
rd = self._poller.poll(0)
finally:
self._poller.unregister(sock)
else:
rd, _, _ = select.select([sock], [], [], 0)
except (RuntimeError, KeyError):
# RuntimeError is raised during a concurrent poll. KeyError
# is raised by unregister if the socket is not in the poller.
# These errors should not be possible since we protect the
# poller with a mutex.
raise
except ValueError:
# ValueError is raised by register/unregister/select if the
# socket file descriptor is negative or outside the range for
# select (> 1023).
return True
except (_SELECT_ERROR, IOError) as exc:
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
continue
return True
except Exception:
# Any other exceptions should be attributed to a closed
# or invalid socket.
return True
return len(rd) > 0

View File

@@ -0,0 +1,377 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operation class definitions."""
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
from pymongo.collation import validate_collation_or_none
from pymongo.helpers import _gen_index_name, _index_document, _index_list
class InsertOne(object):
"""Represents an insert_one operation."""
__slots__ = ("_doc",)
def __init__(self, document):
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `document`: The document to insert. If the document is missing an
_id field one will be added.
"""
self._doc = document
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_insert(self._doc)
def __repr__(self):
return "InsertOne(%r)" % (self._doc,)
def __eq__(self, other):
if type(other) == type(self):
return other._doc == self._doc
return NotImplemented
def __ne__(self, other):
return not self == other
class DeleteOne(object):
"""Represents a delete_one operation."""
__slots__ = ("_filter", "_collation")
def __init__(self, filter, collation=None):
"""Create a DeleteOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `filter`: A query that matches the document to delete.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(self._filter, 1, collation=self._collation)
def __repr__(self):
return "DeleteOne(%r, %r)" % (self._filter, self._collation)
def __eq__(self, other):
if type(other) == type(self):
return ((other._filter, other._collation) ==
(self._filter, self._collation))
return NotImplemented
def __ne__(self, other):
return not self == other
class DeleteMany(object):
"""Represents a delete_many operation."""
__slots__ = ("_filter", "_collation")
def __init__(self, filter, collation=None):
"""Create a DeleteMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `filter`: A query that matches the documents to delete.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
self._filter = filter
self._collation = collation
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_delete(self._filter, 0, collation=self._collation)
def __repr__(self):
return "DeleteMany(%r, %r)" % (self._filter, self._collation)
def __eq__(self, other):
if type(other) == type(self):
return ((other._filter, other._collation) ==
(self._filter, self._collation))
return NotImplemented
def __ne__(self, other):
return not self == other
class ReplaceOne(object):
"""Represents a replace_one operation."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation")
def __init__(self, filter, replacement, upsert=False, collation=None):
"""Create a ReplaceOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `filter`: A query that matches the document to replace.
- `replacement`: The new document.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
.. versionchanged:: 3.5
Added the `collation` option.
"""
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
self._filter = filter
self._doc = replacement
self._upsert = upsert
self._collation = collation
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_replace(self._filter, self._doc, self._upsert,
collation=self._collation)
def __eq__(self, other):
if type(other) == type(self):
return (
(other._filter, other._doc, other._upsert, other._collation) ==
(self._filter, self._doc, self._upsert, self._collation))
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
return "%s(%r, %r, %r, %r)" % (
self.__class__.__name__, self._filter, self._doc, self._upsert,
self._collation)
class _UpdateOp(object):
"""Private base class for update operations."""
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters")
def __init__(self, filter, doc, upsert, collation, array_filters):
if filter is not None:
validate_is_mapping("filter", filter)
if upsert is not None:
validate_boolean("upsert", upsert)
if array_filters is not None:
validate_list("array_filters", array_filters)
self._filter = filter
self._doc = doc
self._upsert = upsert
self._collation = collation
self._array_filters = array_filters
def __eq__(self, other):
if type(other) == type(self):
return (
(other._filter, other._doc, other._upsert, other._collation,
other._array_filters) ==
(self._filter, self._doc, self._upsert, self._collation,
self._array_filters))
return NotImplemented
def __ne__(self, other):
return not self == other
def __repr__(self):
return "%s(%r, %r, %r, %r, %r)" % (
self.__class__.__name__, self._filter, self._doc, self._upsert,
self._collation, self._array_filters)
class UpdateOne(_UpdateOp):
"""Represents an update_one operation."""
__slots__ = ()
def __init__(self, filter, update, upsert=False, collation=None,
array_filters=None):
"""Represents an update_one operation.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `filter`: A query that matches the document to update.
- `update`: The modifications to apply.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
          - `array_filters` (optional): A list of filters specifying which
            array elements an update should apply to. Requires MongoDB 3.6+.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super(UpdateOne, self).__init__(filter, update, upsert, collation,
array_filters)
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(self._filter, self._doc, False, self._upsert,
collation=self._collation,
array_filters=self._array_filters)
class UpdateMany(_UpdateOp):
"""Represents an update_many operation."""
__slots__ = ()
def __init__(self, filter, update, upsert=False, collation=None,
array_filters=None):
"""Create an UpdateMany instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
:Parameters:
- `filter`: A query that matches the documents to update.
- `update`: The modifications to apply.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`. This option is only
supported on MongoDB 3.4 and above.
          - `array_filters` (optional): A list of filters specifying which
            array elements an update should apply to. Requires MongoDB 3.6+.
.. versionchanged:: 3.9
Added the ability to accept a pipeline as the `update`.
.. versionchanged:: 3.6
Added the `array_filters` option.
.. versionchanged:: 3.5
Added the `collation` option.
"""
super(UpdateMany, self).__init__(filter, update, upsert, collation,
array_filters)
def _add_to_bulk(self, bulkobj):
"""Add this operation to the _Bulk instance `bulkobj`."""
bulkobj.add_update(self._filter, self._doc, True, self._upsert,
collation=self._collation,
array_filters=self._array_filters)
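# Illustrative sketch combining the write models (assumes a reachable
# MongoDB server; the database and collection names below are examples):
#
#     >>> from pymongo import MongoClient
#     >>> coll = MongoClient().test.users
#     >>> result = coll.bulk_write([
#     ...     InsertOne({'_id': 1}),
#     ...     UpdateOne({'_id': 1}, {'$set': {'active': True}}),
#     ...     DeleteOne({'_id': 1}),
#     ... ])
#     >>> result.inserted_count, result.modified_count, result.deleted_count
#     (1, 1, 1)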
class IndexModel(object):
"""Represents an index to create."""
__slots__ = ("__document",)
def __init__(self, keys, **kwargs):
"""Create an Index instance.
For use with :meth:`~pymongo.collection.Collection.create_indexes`.
Takes either a single key or a list of (key, direction) pairs.
The key(s) must be an instance of :class:`basestring`
(:class:`str` in python 3), and the direction(s) must be one of
(:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
:data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`,
:data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`,
:data:`~pymongo.TEXT`).
Valid options include, but are not limited to:
- `name`: custom name to use for this index - if none is
given, a name will be generated.
- `unique`: if ``True`` creates a uniqueness constraint on the index.
- `background`: if ``True`` this index should be created in the
background.
- `sparse`: if ``True``, omit from the index any documents that lack
the indexed field.
- `bucketSize`: for use with geoHaystack indexes.
Number of documents to group together within a certain proximity
to a given longitude and latitude.
- `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
index.
- `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
index.
- `expireAfterSeconds`: <int> Used to create an expiring (TTL)
collection. MongoDB will automatically delete documents from
this collection after <int> seconds. The indexed field must
be a UTC datetime or the data will not expire.
- `partialFilterExpression`: A document that specifies a filter for
a partial index. Requires server version >= 3.2.
- `collation`: An instance of :class:`~pymongo.collation.Collation`
that specifies the collation to use in MongoDB >= 3.4.
- `wildcardProjection`: Allows users to include or exclude specific
field paths from a `wildcard index`_ using the { "$**" : 1} key
pattern. Requires server version >= 4.2.
See the MongoDB documentation for a full list of supported options by
server version.
:Parameters:
- `keys`: a single key or a list of (key, direction)
pairs specifying the index to create
- `**kwargs` (optional): any additional index creation
options (see the above list) should be passed as keyword
arguments
.. versionchanged:: 3.2
Added partialFilterExpression to support partial indexes.
.. _wildcard index: https://docs.mongodb.com/master/core/index-wildcard/#wildcard-index-core
"""
keys = _index_list(keys)
if "name" not in kwargs:
kwargs["name"] = _gen_index_name(keys)
kwargs["key"] = _index_document(keys)
collation = validate_collation_or_none(kwargs.pop('collation', None))
self.__document = kwargs
if collation is not None:
self.__document['collation'] = collation
@property
def document(self):
"""An index document suitable for passing to the createIndexes
command.
"""
return self.__document
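# Illustrative usage:
#
#     >>> from pymongo import ASCENDING, DESCENDING
#     >>> model = IndexModel([("user_id", ASCENDING), ("created", DESCENDING)],
#     ...                    unique=True)
#     >>> model.document["name"]
#     'user_id_1_created_-1'
#     >>> model.document["unique"]
#     True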

View File

@@ -0,0 +1,177 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
import atexit
import threading
import time
import weakref
from pymongo.monotonic import time as _time
class PeriodicExecutor(object):
def __init__(self, interval, min_interval, target, name=None):
""""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:Parameters:
- `interval`: Seconds between calls to `target`.
- `min_interval`: Minimum seconds between calls if `wake` is
called very often.
- `target`: A function.
- `name`: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread = None
self._name = name
self._thread_will_exit = False
self._lock = threading.Lock()
def open(self):
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
thread = threading.Thread(target=self._run, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
thread.start()
def close(self, dummy=None):
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout=None):
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self):
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval):
self._interval = new_interval
def __should_stop(self):
with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
def _run(self):
while not self.__should_stop():
try:
if not self._target():
self._stopped = True
break
except:
with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
deadline = _time() + self._interval
while not self._stopped and _time() < deadline:
time.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
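# Illustrative usage (the target and thread name below are examples only):
#
#     >>> def task():
#     ...     return True  # truthy: keep running
#     >>> ex = PeriodicExecutor(interval=5, min_interval=0.5, target=task,
#     ...                       name="example-executor")
#     >>> ex.open()   # start (or restart) the background thread
#     >>> ex.wake()   # run the target again soon
#     >>> ex.close()  # signal the thread to stop; join() waits for it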
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor):
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref):
_EXECUTORS.remove(ref)
def _shutdown_executors():
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None
atexit.register(_shutdown_executors)

File diff suppressed because it is too large

View File

@@ -0,0 +1,76 @@
# Copyright 2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for working with read concerns."""
from bson.py3compat import string_type
class ReadConcern(object):
"""ReadConcern
:Parameters:
- `level`: (string) The read concern level specifies the level of
isolation for read operations. For example, a read operation using a
read concern level of ``majority`` will only return data that has been
written to a majority of nodes. If the level is left unspecified, the
server default will be used.
.. versionadded:: 3.2
"""
def __init__(self, level=None):
if level is None or isinstance(level, string_type):
self.__level = level
else:
raise TypeError(
'level must be a string or None.')
@property
def level(self):
"""The read concern level."""
return self.__level
@property
def ok_for_legacy(self):
"""Return ``True`` if this read concern is compatible with
old wire protocol versions."""
return self.level is None or self.level == 'local'
@property
def document(self):
"""The document representation of this read concern.
.. note::
:class:`ReadConcern` is immutable. Mutating the value of
:attr:`document` does not mutate this :class:`ReadConcern`.
"""
doc = {}
if self.__level:
doc['level'] = self.level
return doc
def __eq__(self, other):
if isinstance(other, ReadConcern):
return self.document == other.document
return NotImplemented
def __repr__(self):
if self.level:
return 'ReadConcern(%s)' % self.level
return 'ReadConcern()'
DEFAULT_READ_CONCERN = ReadConcern()
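# Usage sketch: the document representation and legacy-compatibility check
# described above, exercised with plain asserts.
if __name__ == '__main__':
    majority = ReadConcern('majority')
    assert majority.level == 'majority'
    assert majority.document == {'level': 'majority'}
    assert not majority.ok_for_legacy  # 'majority' is not a legacy level.
    # An unspecified level defers to the server default: an empty document.
    assert DEFAULT_READ_CONCERN.document == {}
    assert DEFAULT_READ_CONCERN.ok_for_legacy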

View File

@@ -0,0 +1,471 @@
# Copyright 2012-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for choosing which member of a replica set to read from."""
from bson.py3compat import abc, integer_types
from pymongo import max_staleness_selectors
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import (member_with_tags_server_selector,
secondary_with_tags_server_selector)
_PRIMARY = 0
_PRIMARY_PREFERRED = 1
_SECONDARY = 2
_SECONDARY_PREFERRED = 3
_NEAREST = 4
_MONGOS_MODES = (
'primary',
'primaryPreferred',
'secondary',
'secondaryPreferred',
'nearest',
)
def _validate_tag_sets(tag_sets):
"""Validate tag sets for a MongoReplicaSetClient.
"""
if tag_sets is None:
return tag_sets
if not isinstance(tag_sets, list):
raise TypeError(
"Tag sets %r invalid, must be a list" % (tag_sets,))
if len(tag_sets) == 0:
raise ValueError(
"Tag sets %r invalid, must be None or contain at least one set of"
" tags" % (tag_sets,))
for tags in tag_sets:
if not isinstance(tags, abc.Mapping):
raise TypeError(
"Tag set %r invalid, must be an instance of dict, "
"bson.son.SON or other type that inherits from "
"collection.Mapping" % (tags,))
return tag_sets
def _invalid_max_staleness_msg(max_staleness):
return ("maxStalenessSeconds must be a positive integer, not %s" %
max_staleness)
# Some duplication with common.py to avoid import cycle.
def _validate_max_staleness(max_staleness):
"""Validate max_staleness."""
if max_staleness == -1:
return -1
if not isinstance(max_staleness, integer_types):
raise TypeError(_invalid_max_staleness_msg(max_staleness))
if max_staleness <= 0:
raise ValueError(_invalid_max_staleness_msg(max_staleness))
return max_staleness
class _ServerMode(object):
"""Base class for all read preferences.
"""
__slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness")
def __init__(self, mode, tag_sets=None, max_staleness=-1):
self.__mongos_mode = _MONGOS_MODES[mode]
self.__mode = mode
self.__tag_sets = _validate_tag_sets(tag_sets)
self.__max_staleness = _validate_max_staleness(max_staleness)
@property
def name(self):
"""The name of this read preference.
"""
return self.__class__.__name__
@property
def mongos_mode(self):
"""The mongos mode of this read preference.
"""
return self.__mongos_mode
@property
def document(self):
"""Read preference as a document.
"""
doc = {'mode': self.__mongos_mode}
if self.__tag_sets not in (None, [{}]):
doc['tags'] = self.__tag_sets
if self.__max_staleness != -1:
doc['maxStalenessSeconds'] = self.__max_staleness
return doc
@property
def mode(self):
"""The mode of this read preference instance.
"""
return self.__mode
@property
def tag_sets(self):
"""Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
read only from members whose ``dc`` tag has the value ``"ny"``.
To specify a priority-order for tag sets, provide a list of
tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
set, ``{}``, means "read from any member that matches the mode,
ignoring tags." MongoReplicaSetClient tries each set of tags in turn
until it finds a set of tags with at least one matching member.
.. seealso:: `Data-Center Awareness
<http://www.mongodb.org/display/DOCS/Data+Center+Awareness>`_
"""
return list(self.__tag_sets) if self.__tag_sets else [{}]
@property
def max_staleness(self):
"""The maximum estimated length of time (in seconds) a replica set
secondary can fall behind the primary in replication before it will
no longer be selected for operations, or -1 for no maximum."""
return self.__max_staleness
@property
def min_wire_version(self):
"""The wire protocol version the server must support.
Some read preferences impose version requirements on all servers (e.g.
maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
All servers' maxWireVersion must be at least this read preference's
`min_wire_version`, or the driver raises
:exc:`~pymongo.errors.ConfigurationError`.
"""
return 0 if self.__max_staleness == -1 else 5
def __repr__(self):
return "%s(tag_sets=%r, max_staleness=%r)" % (
self.name, self.__tag_sets, self.__max_staleness)
def __eq__(self, other):
if isinstance(other, _ServerMode):
return (self.mode == other.mode and
self.tag_sets == other.tag_sets and
self.max_staleness == other.max_staleness)
return NotImplemented
def __ne__(self, other):
return not self == other
def __getstate__(self):
"""Return value of object for pickling.
Needed explicitly because ``__slots__`` is defined.
"""
return {'mode': self.__mode,
'tag_sets': self.__tag_sets,
'max_staleness': self.__max_staleness}
def __setstate__(self, value):
"""Restore from pickling."""
self.__mode = value['mode']
self.__mongos_mode = _MONGOS_MODES[self.__mode]
self.__tag_sets = _validate_tag_sets(value['tag_sets'])
self.__max_staleness = _validate_max_staleness(value['max_staleness'])
class Primary(_ServerMode):
"""Primary read preference.
* When directly connected to one mongod queries are allowed if the server
is standalone or a replica set primary.
* When connected to a mongos queries are sent to the primary of a shard.
* When connected to a replica set queries are sent to the primary of
the replica set.
"""
__slots__ = ()
def __init__(self):
super(Primary, self).__init__(_PRIMARY)
def __call__(self, selection):
"""Apply this read preference to a Selection."""
return selection.primary_selection
def __repr__(self):
return "Primary()"
def __eq__(self, other):
if isinstance(other, _ServerMode):
return other.mode == _PRIMARY
return NotImplemented
class PrimaryPreferred(_ServerMode):
"""PrimaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are sent to the primary of a shard if
available, otherwise a shard secondary.
* When connected to a replica set queries are sent to the primary if
available, otherwise a secondary.
:Parameters:
- `tag_sets`: The :attr:`~tag_sets` to use if the primary is not
available.
- `max_staleness`: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
"""
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1):
super(PrimaryPreferred, self).__init__(_PRIMARY_PREFERRED,
tag_sets,
max_staleness)
def __call__(self, selection):
"""Apply this read preference to Selection."""
if selection.primary:
return selection.primary_selection
else:
return secondary_with_tags_server_selector(
self.tag_sets,
max_staleness_selectors.select(
self.max_staleness, selection))
class Secondary(_ServerMode):
"""Secondary read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries. An error is raised if no secondaries are available.
* When connected to a replica set queries are distributed among
secondaries. An error is raised if no secondaries are available.
:Parameters:
- `tag_sets`: The :attr:`~tag_sets` for this read preference.
- `max_staleness`: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
"""
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1):
super(Secondary, self).__init__(_SECONDARY, tag_sets, max_staleness)
def __call__(self, selection):
"""Apply this read preference to Selection."""
return secondary_with_tags_server_selector(
self.tag_sets,
max_staleness_selectors.select(
self.max_staleness, selection))
class SecondaryPreferred(_ServerMode):
"""SecondaryPreferred read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among shard
secondaries, or the shard primary if no secondary is available.
* When connected to a replica set queries are distributed among
secondaries, or the primary if no secondary is available.
:Parameters:
- `tag_sets`: The :attr:`~tag_sets` for this read preference.
- `max_staleness`: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
"""
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1):
super(SecondaryPreferred, self).__init__(_SECONDARY_PREFERRED,
tag_sets,
max_staleness)
def __call__(self, selection):
"""Apply this read preference to Selection."""
secondaries = secondary_with_tags_server_selector(
self.tag_sets,
max_staleness_selectors.select(
self.max_staleness, selection))
if secondaries:
return secondaries
else:
return selection.primary_selection
class Nearest(_ServerMode):
"""Nearest read preference.
* When directly connected to one mongod queries are allowed to standalone
servers, to a replica set primary, or to replica set secondaries.
* When connected to a mongos queries are distributed among all members of
a shard.
* When connected to a replica set queries are distributed among all
members.
:Parameters:
- `tag_sets`: The :attr:`~tag_sets` for this read preference.
- `max_staleness`: (integer, in seconds) The maximum estimated
length of time a replica set secondary can fall behind the primary in
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
"""
__slots__ = ()
def __init__(self, tag_sets=None, max_staleness=-1):
super(Nearest, self).__init__(_NEAREST, tag_sets, max_staleness)
def __call__(self, selection):
"""Apply this read preference to Selection."""
return member_with_tags_server_selector(
self.tag_sets,
max_staleness_selectors.select(
self.max_staleness, selection))
_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred,
Secondary, SecondaryPreferred, Nearest)
def make_read_preference(mode, tag_sets, max_staleness=-1):
if mode == _PRIMARY:
if tag_sets not in (None, [{}]):
raise ConfigurationError("Read preference primary "
"cannot be combined with tags")
if max_staleness != -1:
raise ConfigurationError("Read preference primary cannot be "
"combined with maxStalenessSeconds")
return Primary()
return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness)
_MODES = (
'PRIMARY',
'PRIMARY_PREFERRED',
'SECONDARY',
'SECONDARY_PREFERRED',
'NEAREST',
)
class ReadPreference(object):
"""An enum that defines the read preference modes supported by PyMongo.
See :doc:`/examples/high_availability` for code examples.
A read preference is used in three cases:
:class:`~pymongo.mongo_client.MongoClient` connected to a single mongod:
- ``PRIMARY``: Queries are allowed if the server is standalone or a replica
set primary.
- All other modes allow queries to standalone servers, to a replica set
primary, or to replica set secondaries.
:class:`~pymongo.mongo_client.MongoClient` initialized with the
``replicaSet`` option:
- ``PRIMARY``: Read from the primary. This is the default, and provides the
strongest consistency. If no primary is available, raise
:class:`~pymongo.errors.AutoReconnect`.
- ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is
none, read from a secondary.
- ``SECONDARY``: Read from a secondary. If no secondary is available,
raise :class:`~pymongo.errors.AutoReconnect`.
- ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise
from the primary.
- ``NEAREST``: Read from any member.
:class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a
sharded cluster of replica sets:
- ``PRIMARY``: Read from the primary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
This is the default.
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is
none, read from a secondary of the shard.
- ``SECONDARY``: Read from a secondary of the shard, or raise
:class:`~pymongo.errors.OperationFailure` if there is none.
- ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available,
otherwise from the shard primary.
- ``NEAREST``: Read from any shard member.
"""
PRIMARY = Primary()
PRIMARY_PREFERRED = PrimaryPreferred()
SECONDARY = Secondary()
SECONDARY_PREFERRED = SecondaryPreferred()
NEAREST = Nearest()
def read_pref_mode_from_name(name):
"""Get the read preference mode from mongos/uri name.
"""
return _MONGOS_MODES.index(name)
class MovingAverage(object):
"""Tracks an exponentially-weighted moving average."""
def __init__(self):
self.average = None
def add_sample(self, sample):
if sample < 0:
# Likely system time change while waiting for ismaster response
# and not using time.monotonic. Ignore it, the next one will
# probably be valid.
return
if self.average is None:
self.average = sample
else:
# The Server Selection Spec requires an exponentially weighted
# average with alpha = 0.2.
self.average = 0.8 * self.average + 0.2 * sample
def get(self):
"""Get the calculated average, or None if no samples yet."""
return self.average
def reset(self):
self.average = None
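# Usage sketch: how a read preference renders to the document sent to the
# server, plus the MovingAverage recurrence (alpha = 0.2) described above.
if __name__ == '__main__':
    pref = SecondaryPreferred(tag_sets=[{'dc': 'ny'}, {}], max_staleness=120)
    assert pref.document == {'mode': 'secondaryPreferred',
                             'tags': [{'dc': 'ny'}, {}],
                             'maxStalenessSeconds': 120}
    assert pref.min_wire_version == 5  # maxStalenessSeconds needs wire v5.
    # make_read_preference maps a mode integer back to an instance.
    assert make_read_preference(_SECONDARY, None) == ReadPreference.SECONDARY
    avg = MovingAverage()
    avg.add_sample(10.0)
    avg.add_sample(20.0)
    assert avg.get() == 0.8 * 10.0 + 0.2 * 20.0  # Exponentially weighted.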

View File

@@ -0,0 +1,107 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent a response from the server."""
class Response(object):
__slots__ = ('_data', '_address', '_request_id', '_duration',
'_from_command', '_docs')
def __init__(self, data, address, request_id, duration, from_command,
docs):
"""Represent a response from the server.
:Parameters:
- `data`: A network response message.
- `address`: (host, port) of the source server.
- `request_id`: The request id of this operation.
- `duration`: The duration of the operation.
- `from_command`: if the response is the result of a db command.
"""
self._data = data
self._address = address
self._request_id = request_id
self._duration = duration
self._from_command = from_command
self._docs = docs
@property
def data(self):
"""Server response's raw BSON bytes."""
return self._data
@property
def address(self):
"""(host, port) of the source server."""
return self._address
@property
def request_id(self):
"""The request id of this operation."""
return self._request_id
@property
def duration(self):
"""The duration of the operation."""
return self._duration
@property
def from_command(self):
"""If the response is a result from a db command."""
return self._from_command
@property
def docs(self):
"""The decoded document(s)."""
return self._docs
class ExhaustResponse(Response):
__slots__ = ('_socket_info', '_pool')
def __init__(self, data, address, socket_info, pool, request_id, duration,
from_command, docs):
"""Represent a response to an exhaust cursor's initial query.
:Parameters:
- `data`: A network response message.
- `address`: (host, port) of the source server.
- `socket_info`: The SocketInfo used for the initial query.
- `pool`: The Pool from which the SocketInfo came.
- `request_id`: The request id of this operation.
- `duration`: The duration of the operation.
- `from_command`: If the response is the result of a db command.
"""
super(ExhaustResponse, self).__init__(data,
address,
request_id,
duration,
from_command, docs)
self._socket_info = socket_info
self._pool = pool
@property
def socket_info(self):
"""The SocketInfo used for the initial query.
The server will send batches on this socket, without waiting for
getMores from the client, until the result set is exhausted or there
is an error.
"""
return self._socket_info
@property
def pool(self):
"""The Pool from which the SocketInfo came."""
return self._pool
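# Usage sketch: Response is a plain value object; the reply bytes and address
# below are placeholders, used only to show the accessors.
if __name__ == '__main__':
    r = Response(data=b'<raw reply>', address=('localhost', 27017),
                 request_id=42, duration=None, from_command=True,
                 docs=[{'ok': 1}])
    assert r.address == ('localhost', 27017)
    assert r.from_command
    assert r.docs == [{'ok': 1}]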

View File

@@ -0,0 +1,226 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Result class definitions."""
from pymongo.errors import InvalidOperation
class _WriteResult(object):
"""Base class for write result classes."""
__slots__ = ("__acknowledged",)
def __init__(self, acknowledged):
self.__acknowledged = acknowledged
def _raise_if_unacknowledged(self, property_name):
"""Raise an exception on property access if unacknowledged."""
if not self.__acknowledged:
raise InvalidOperation("A value for %s is not available when "
"the write is unacknowledged. Check the "
"acknowledged attribute to avoid this "
"error." % (property_name,))
@property
def acknowledged(self):
"""Is this the result of an acknowledged write operation?
The :attr:`acknowledged` attribute will be ``False`` when using
``WriteConcern(w=0)``, otherwise ``True``.
.. note::
If the :attr:`acknowledged` attribute is ``False`` all other
attributes of this class will raise
:class:`~pymongo.errors.InvalidOperation` when accessed. Values for
other attributes cannot be determined if the write operation was
unacknowledged.
.. seealso::
:class:`~pymongo.write_concern.WriteConcern`
"""
return self.__acknowledged
class InsertOneResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.insert_one`.
"""
__slots__ = ("__inserted_id", "__acknowledged")
def __init__(self, inserted_id, acknowledged):
self.__inserted_id = inserted_id
super(InsertOneResult, self).__init__(acknowledged)
@property
def inserted_id(self):
"""The inserted document's _id."""
return self.__inserted_id
class InsertManyResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.insert_many`.
"""
__slots__ = ("__inserted_ids", "__acknowledged")
def __init__(self, inserted_ids, acknowledged):
self.__inserted_ids = inserted_ids
super(InsertManyResult, self).__init__(acknowledged)
@property
def inserted_ids(self):
"""A list of _ids of the inserted documents, in the order provided.
.. note:: If ``False`` is passed for the `ordered` parameter to
:meth:`~pymongo.collection.Collection.insert_many` the server
may have inserted the documents in a different order than what
is presented here.
"""
return self.__inserted_ids
class UpdateResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.update_one`,
:meth:`~pymongo.collection.Collection.update_many`, and
:meth:`~pymongo.collection.Collection.replace_one`.
"""
__slots__ = ("__raw_result", "__acknowledged")
def __init__(self, raw_result, acknowledged):
self.__raw_result = raw_result
super(UpdateResult, self).__init__(acknowledged)
@property
def raw_result(self):
"""The raw result document returned by the server."""
return self.__raw_result
@property
def matched_count(self):
"""The number of documents matched for this update."""
self._raise_if_unacknowledged("matched_count")
if self.upserted_id is not None:
return 0
return self.__raw_result.get("n", 0)
@property
def modified_count(self):
"""The number of documents modified.
.. note:: modified_count is only reported by MongoDB 2.6 and later.
When connected to an earlier server version, or in certain mixed
version sharding configurations, this attribute will be set to
``None``.
"""
self._raise_if_unacknowledged("modified_count")
return self.__raw_result.get("nModified")
@property
def upserted_id(self):
"""The _id of the inserted document if an upsert took place. Otherwise
``None``.
"""
self._raise_if_unacknowledged("upserted_id")
return self.__raw_result.get("upserted")
class DeleteResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.delete_one`
and :meth:`~pymongo.collection.Collection.delete_many`"""
__slots__ = ("__raw_result", "__acknowledged")
def __init__(self, raw_result, acknowledged):
self.__raw_result = raw_result
super(DeleteResult, self).__init__(acknowledged)
@property
def raw_result(self):
"""The raw result document returned by the server."""
return self.__raw_result
@property
def deleted_count(self):
"""The number of documents deleted."""
self._raise_if_unacknowledged("deleted_count")
return self.__raw_result.get("n", 0)
class BulkWriteResult(_WriteResult):
"""An object wrapper for bulk API write results."""
__slots__ = ("__bulk_api_result", "__acknowledged")
def __init__(self, bulk_api_result, acknowledged):
"""Create a BulkWriteResult instance.
:Parameters:
- `bulk_api_result`: A result dict from the bulk API
- `acknowledged`: Was this write result acknowledged? If ``False``
then all properties of this object will raise
:exc:`~pymongo.errors.InvalidOperation`.
"""
self.__bulk_api_result = bulk_api_result
super(BulkWriteResult, self).__init__(acknowledged)
@property
def bulk_api_result(self):
"""The raw bulk API result."""
return self.__bulk_api_result
@property
def inserted_count(self):
"""The number of documents inserted."""
self._raise_if_unacknowledged("inserted_count")
return self.__bulk_api_result.get("nInserted")
@property
def matched_count(self):
"""The number of documents matched for an update."""
self._raise_if_unacknowledged("matched_count")
return self.__bulk_api_result.get("nMatched")
@property
def modified_count(self):
"""The number of documents modified.
.. note:: modified_count is only reported by MongoDB 2.6 and later.
When connected to an earlier server version, or in certain mixed
version sharding configurations, this attribute will be set to
``None``.
"""
self._raise_if_unacknowledged("modified_count")
return self.__bulk_api_result.get("nModified")
@property
def deleted_count(self):
"""The number of documents deleted."""
self._raise_if_unacknowledged("deleted_count")
return self.__bulk_api_result.get("nRemoved")
@property
def upserted_count(self):
"""The number of documents upserted."""
self._raise_if_unacknowledged("upserted_count")
return self.__bulk_api_result.get("nUpserted")
@property
def upserted_ids(self):
"""A map of operation index to the _id of the upserted document."""
self._raise_if_unacknowledged("upserted_ids")
if self.__bulk_api_result:
return dict((upsert["index"], upsert["_id"])
for upsert in self.bulk_api_result["upserted"])
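# Usage sketch: acknowledged results expose the server's counts, while any
# property of an unacknowledged (w=0) result raises InvalidOperation, as
# documented above. The raw result dicts are hand-written stand-ins.
if __name__ == '__main__':
    result = UpdateResult({'n': 1, 'nModified': 1}, acknowledged=True)
    assert result.matched_count == 1
    assert result.modified_count == 1
    unacked = DeleteResult({}, acknowledged=False)
    try:
        unacked.deleted_count
    except InvalidOperation:
        pass  # Expected: counts are unavailable for unacknowledged writes.
    else:
        raise AssertionError('expected InvalidOperation')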

View File

@@ -0,0 +1,108 @@
# Copyright 2016-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An implementation of RFC4013 SASLprep."""
from bson.py3compat import text_type as _text_type
try:
import stringprep
except ImportError:
HAVE_STRINGPREP = False
def saslprep(data):
"""SASLprep dummy"""
if isinstance(data, _text_type):
raise TypeError(
"The stringprep module is not available. Usernames and "
"passwords must be ASCII strings.")
return data
else:
HAVE_STRINGPREP = True
import unicodedata
# RFC4013 section 2.3 prohibited output.
_PROHIBITED = (
# A strict reading of RFC 4013 requires table c12 here, but
# characters from it are mapped to SPACE in the Map step. Can
# normalization reintroduce them somehow?
stringprep.in_table_c12,
stringprep.in_table_c21_c22,
stringprep.in_table_c3,
stringprep.in_table_c4,
stringprep.in_table_c5,
stringprep.in_table_c6,
stringprep.in_table_c7,
stringprep.in_table_c8,
stringprep.in_table_c9)
def saslprep(data, prohibit_unassigned_code_points=True):
"""An implementation of RFC4013 SASLprep.
:Parameters:
- `data`: The string to SASLprep. Unicode strings
(python 2.x unicode, 3.x str) are supported. Byte strings
(python 2.x str, 3.x bytes) are ignored.
- `prohibit_unassigned_code_points`: True / False. RFC 3454
and RFCs for various SASL mechanisms distinguish between
`queries` (unassigned code points allowed) and
`stored strings` (unassigned code points prohibited). Defaults
to ``True`` (unassigned code points are prohibited).
:Returns:
The SASLprep'ed version of `data`.
"""
if not isinstance(data, _text_type):
return data
if prohibit_unassigned_code_points:
prohibited = _PROHIBITED + (stringprep.in_table_a1,)
else:
prohibited = _PROHIBITED
# RFC3454 section 2, step 1 - Map
# RFC4013 section 2.1 mappings
# Map Non-ASCII space characters to SPACE (U+0020). Map
# commonly mapped to nothing characters to, well, nothing.
in_table_c12 = stringprep.in_table_c12
in_table_b1 = stringprep.in_table_b1
data = u"".join(
[u"\u0020" if in_table_c12(elt) else elt
for elt in data if not in_table_b1(elt)])
# RFC3454 section 2, step 2 - Normalize
# RFC4013 section 2.2 normalization
data = unicodedata.ucd_3_2_0.normalize('NFKC', data)
in_table_d1 = stringprep.in_table_d1
if in_table_d1(data[0]):
if not in_table_d1(data[-1]):
# RFC3454, Section 6, #3. If a string contains any
# RandALCat character, the first and last characters
# MUST be RandALCat characters.
raise ValueError("SASLprep: failed bidirectional check")
# RFC3454, Section 6, #2. If a string contains any RandALCat
# character, it MUST NOT contain any LCat character.
prohibited = prohibited + (stringprep.in_table_d2,)
else:
# RFC3454, Section 6, #3. Following the logic of #3, if
# the first character is not a RandALCat, no other character
# can be either.
prohibited = prohibited + (in_table_d1,)
# RFC3454 section 2, step 3 and 4 - Prohibit and check bidi
for char in data:
if any(in_table(char) for in_table in prohibited):
raise ValueError(
"SASLprep: failed prohibited character check")
return data
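# Usage sketch: the RFC 4013 steps above on three classic inputs. A soft
# hyphen (table B.1) maps to nothing, a non-ASCII space maps to U+0020, and
# a control character fails the prohibited-output check.
if __name__ == '__main__':
    if HAVE_STRINGPREP:
        assert saslprep(u'I\u00ADX') == u'IX'
        assert saslprep(u'user\u2000name') == u'user name'
        try:
            saslprep(u'\u0007')
        except ValueError:
            pass  # Expected: prohibited character.
        else:
            raise AssertionError('expected ValueError')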

View File

@@ -0,0 +1,232 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Communicate with one MongoDB server in a topology."""
from datetime import datetime
from bson import _decode_all_selective
from pymongo.errors import NotMasterError, OperationFailure
from pymongo.helpers import _check_command_response
from pymongo.message import _convert_exception
from pymongo.response import Response, ExhaustResponse
from pymongo.server_type import SERVER_TYPE
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
class Server(object):
def __init__(self, server_description, pool, monitor, topology_id=None,
listeners=None, events=None):
"""Represent one MongoDB server."""
self._description = server_description
self._pool = pool
self._monitor = monitor
self._topology_id = topology_id
self._publish = listeners is not None and listeners.enabled_for_server
self._listener = listeners
self._events = None
if self._publish:
self._events = events()
def open(self):
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._monitor.open()
def reset(self):
"""Clear the connection pool."""
self.pool.reset()
def close(self):
"""Clear the connection pool and stop the monitor.
Reconnect with open().
"""
if self._publish:
self._events.put((self._listener.publish_server_closed,
(self._description.address, self._topology_id)))
self._monitor.close()
self._pool.reset()
def request_check(self):
"""Check the server's state soon."""
self._monitor.request_check()
def run_operation_with_response(
self,
sock_info,
operation,
set_slave_okay,
listeners,
exhaust,
unpack_res):
"""Run a _Query or _GetMore operation and return a Response object.
This method is used only to run _Query/_GetMore operations from
cursors.
Can raise ConnectionFailure, OperationFailure, etc.
:Parameters:
- `sock_info`: The SocketInfo (connection) to run the operation on.
- `operation`: A _Query or _GetMore object.
- `set_slave_okay`: Pass to operation.get_message.
- `listeners`: Instance of _EventListeners or None.
- `exhaust`: If True, then this is an exhaust cursor operation.
- `unpack_res`: A callable that decodes the wire protocol response.
"""
duration = None
publish = listeners.enabled_for_commands
if publish:
start = datetime.now()
send_message = not operation.exhaust_mgr
if send_message:
use_cmd = operation.use_command(sock_info, exhaust)
message = operation.get_message(
set_slave_okay, sock_info, use_cmd)
request_id, data, max_doc_size = self._split_message(message)
else:
use_cmd = False
request_id = 0
if publish:
cmd, dbn = operation.as_command(sock_info)
listeners.publish_command_start(
cmd, dbn, request_id, sock_info.address)
start = datetime.now()
try:
if send_message:
sock_info.send_message(data, max_doc_size)
reply = sock_info.receive_message(request_id)
else:
reply = sock_info.receive_message(None)
# Unpack and check for command errors.
if use_cmd:
user_fields = _CURSOR_DOC_FIELDS
legacy_response = False
else:
user_fields = None
legacy_response = True
docs = unpack_res(reply, operation.cursor_id,
operation.codec_options,
legacy_response=legacy_response,
user_fields=user_fields)
if use_cmd:
first = docs[0]
operation.client._process_response(
first, operation.session)
_check_command_response(first)
except Exception as exc:
if publish:
duration = datetime.now() - start
if isinstance(exc, (NotMasterError, OperationFailure)):
failure = exc.details
else:
failure = _convert_exception(exc)
listeners.publish_command_failure(
duration, failure, operation.name,
request_id, sock_info.address)
raise
if publish:
duration = datetime.now() - start
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res = docs[0]
elif operation.name == "explain":
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id,
"ns": operation.namespace()},
"ok": 1}
if operation.name == "find":
res["cursor"]["firstBatch"] = docs
else:
res["cursor"]["nextBatch"] = docs
listeners.publish_command_success(
duration, res, operation.name, request_id,
sock_info.address)
# Decrypt response.
client = operation.client
if client and client._encrypter:
if use_cmd:
decrypted = client._encrypter.decrypt(
reply.raw_command_response())
docs = _decode_all_selective(
decrypted, operation.codec_options, user_fields)
if exhaust:
response = ExhaustResponse(
data=reply,
address=self._description.address,
socket_info=sock_info,
pool=self._pool,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs)
else:
response = Response(
data=reply,
address=self._description.address,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs)
return response
def get_socket(self, all_credentials, checkout=False):
return self.pool.get_socket(all_credentials, checkout)
@property
def description(self):
return self._description
@description.setter
def description(self, server_description):
assert server_description.address == self._description.address
self._description = server_description
@property
def pool(self):
return self._pool
def _split_message(self, message):
"""Return request_id, data, max_doc_size.
:Parameters:
- `message`: (request_id, data, max_doc_size) or (request_id, data)
"""
if len(message) == 3:
return message
else:
# get_more and kill_cursors messages don't include BSON documents.
request_id, data = message
return request_id, data, 0
def __str__(self):
d = self._description
return '<Server "%s:%s" %s>' % (
d.address[0], d.address[1],
SERVER_TYPE._fields[d.server_type])
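# Usage sketch: _split_message normalizes both message shapes described in
# its docstring. A bare subclass skips the full constructor, since the
# method does not touch instance state.
if __name__ == '__main__':
    class _BareServer(Server):
        def __init__(self):  # Skip the full constructor for this sketch.
            pass

    bare = _BareServer()
    assert bare._split_message((7, b'payload')) == (7, b'payload', 0)
    assert bare._split_message((7, b'payload', 1024)) == (7, b'payload', 1024)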

View File

@@ -0,0 +1,211 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represent one server the driver is connected to."""
from bson import EPOCH_NAIVE
from pymongo.server_type import SERVER_TYPE
from pymongo.ismaster import IsMaster
from pymongo.monotonic import time as _time
class ServerDescription(object):
"""Immutable representation of one server.
:Parameters:
- `address`: A (host, port) pair
- `ismaster`: Optional IsMaster instance
- `round_trip_time`: Optional float
- `error`: Optional, the last error attempting to connect to the server
"""
__slots__ = (
'_address', '_server_type', '_all_hosts', '_tags', '_replica_set_name',
'_primary', '_max_bson_size', '_max_message_size',
'_max_write_batch_size', '_min_wire_version', '_max_wire_version',
'_round_trip_time', '_me', '_is_writable', '_is_readable',
'_ls_timeout_minutes', '_error', '_set_version', '_election_id',
'_cluster_time', '_last_write_date', '_last_update_time')
def __init__(
self,
address,
ismaster=None,
round_trip_time=None,
error=None):
self._address = address
if not ismaster:
ismaster = IsMaster({})
self._server_type = ismaster.server_type
self._all_hosts = ismaster.all_hosts
self._tags = ismaster.tags
self._replica_set_name = ismaster.replica_set_name
self._primary = ismaster.primary
self._max_bson_size = ismaster.max_bson_size
self._max_message_size = ismaster.max_message_size
self._max_write_batch_size = ismaster.max_write_batch_size
self._min_wire_version = ismaster.min_wire_version
self._max_wire_version = ismaster.max_wire_version
self._set_version = ismaster.set_version
self._election_id = ismaster.election_id
self._cluster_time = ismaster.cluster_time
self._is_writable = ismaster.is_writable
self._is_readable = ismaster.is_readable
self._ls_timeout_minutes = ismaster.logical_session_timeout_minutes
self._round_trip_time = round_trip_time
self._me = ismaster.me
self._last_update_time = _time()
self._error = error
if ismaster.last_write_date:
# Convert from datetime to seconds.
delta = ismaster.last_write_date - EPOCH_NAIVE
self._last_write_date = delta.total_seconds()
else:
self._last_write_date = None
@property
def address(self):
"""The address (host, port) of this server."""
return self._address
@property
def server_type(self):
"""The type of this server."""
return self._server_type
@property
def server_type_name(self):
"""The server type as a human readable string.
.. versionadded:: 3.4
"""
return SERVER_TYPE._fields[self._server_type]
@property
def all_hosts(self):
"""List of hosts, passives, and arbiters known to this server."""
return self._all_hosts
@property
def tags(self):
return self._tags
@property
def replica_set_name(self):
"""Replica set name or None."""
return self._replica_set_name
@property
def primary(self):
"""This server's opinion about who the primary is, or None."""
return self._primary
@property
def max_bson_size(self):
return self._max_bson_size
@property
def max_message_size(self):
return self._max_message_size
@property
def max_write_batch_size(self):
return self._max_write_batch_size
@property
def min_wire_version(self):
return self._min_wire_version
@property
def max_wire_version(self):
return self._max_wire_version
@property
def set_version(self):
return self._set_version
@property
def election_id(self):
return self._election_id
@property
def cluster_time(self):
return self._cluster_time
@property
def election_tuple(self):
return self._set_version, self._election_id
@property
def me(self):
return self._me
@property
def logical_session_timeout_minutes(self):
return self._ls_timeout_minutes
@property
def last_write_date(self):
return self._last_write_date
@property
def last_update_time(self):
return self._last_update_time
@property
def round_trip_time(self):
"""The current average latency or None."""
# This override is for unittesting only!
if self._address in self._host_to_round_trip_time:
return self._host_to_round_trip_time[self._address]
return self._round_trip_time
@property
def error(self):
"""The last error attempting to connect to the server, or None."""
return self._error
@property
def is_writable(self):
return self._is_writable
@property
def is_readable(self):
return self._is_readable
@property
def mongos(self):
return self._server_type == SERVER_TYPE.Mongos
@property
def is_server_type_known(self):
return self.server_type != SERVER_TYPE.Unknown
@property
def retryable_writes_supported(self):
"""Checks if this server supports retryable writes."""
return (
self._ls_timeout_minutes is not None and
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary))
@property
def retryable_reads_supported(self):
"""Checks if this server supports retryable writes."""
return self._max_wire_version >= 6
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time = {}
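# Usage sketch: with no ismaster response the description defaults to an
# Unknown server that is neither writable nor readable, as the constructor
# above arranges via IsMaster({}).
if __name__ == '__main__':
    sd = ServerDescription(('localhost', 27017))
    assert sd.address == ('localhost', 27017)
    assert sd.server_type == SERVER_TYPE.Unknown
    assert not sd.is_server_type_known
    assert not sd.is_writable and not sd.is_readable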

View File

@@ -0,0 +1,156 @@
# Copyright 2014-2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Criteria to select some ServerDescriptions from a TopologyDescription."""
from pymongo.server_type import SERVER_TYPE
class Selection(object):
"""Input or output of a server selector function."""
@classmethod
def from_topology_description(cls, topology_description):
known_servers = topology_description.known_servers
primary = None
for sd in known_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break
return Selection(topology_description,
topology_description.known_servers,
topology_description.common_wire_version,
primary)
def __init__(self,
topology_description,
server_descriptions,
common_wire_version,
primary):
self.topology_description = topology_description
self.server_descriptions = server_descriptions
self.primary = primary
self.common_wire_version = common_wire_version
def with_server_descriptions(self, server_descriptions):
return Selection(self.topology_description,
server_descriptions,
self.common_wire_version,
self.primary)
def secondary_with_max_last_write_date(self):
secondaries = secondary_server_selector(self)
if secondaries.server_descriptions:
return max(secondaries.server_descriptions,
key=lambda sd: sd.last_write_date)
@property
def primary_selection(self):
primaries = [self.primary] if self.primary else []
return self.with_server_descriptions(primaries)
@property
def heartbeat_frequency(self):
return self.topology_description.heartbeat_frequency
@property
def topology_type(self):
return self.topology_description.topology_type
def __bool__(self):
return bool(self.server_descriptions)
__nonzero__ = __bool__ # Python 2.
def __getitem__(self, item):
return self.server_descriptions[item]
def any_server_selector(selection):
return selection
def readable_server_selector(selection):
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_readable])
def writable_server_selector(selection):
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if s.is_writable])
def secondary_server_selector(selection):
return selection.with_server_descriptions(
[s for s in selection.server_descriptions
if s.server_type == SERVER_TYPE.RSSecondary])
def arbiter_server_selector(selection):
return selection.with_server_descriptions(
[s for s in selection.server_descriptions
if s.server_type == SERVER_TYPE.RSArbiter])
def writable_preferred_server_selector(selection):
"""Like PrimaryPreferred but doesn't use tags or latency."""
return (writable_server_selector(selection) or
secondary_server_selector(selection))
def apply_single_tag_set(tag_set, selection):
"""All servers matching one tag set.
A tag set is a dict. A server matches if its tags are a superset:
A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}.
The empty tag set {} matches any server.
"""
def tags_match(server_tags):
for key, value in tag_set.items():
if key not in server_tags or server_tags[key] != value:
return False
return True
return selection.with_server_descriptions(
[s for s in selection.server_descriptions if tags_match(s.tags)])
def apply_tag_sets(tag_sets, selection):
"""All servers match a list of tag sets.
tag_sets is a list of dicts. The empty tag set {} matches any server,
and may be provided at the end of the list as a fallback. So
[{'a': 'value'}, {}] expresses a preference for servers tagged
{'a': 'value'}, but accepts any server if none matches the first
preference.
"""
for tag_set in tag_sets:
with_tag_set = apply_single_tag_set(tag_set, selection)
if with_tag_set:
return with_tag_set
return selection.with_server_descriptions([])
def secondary_with_tags_server_selector(tag_sets, selection):
"""All near-enough secondaries matching the tag sets."""
return apply_tag_sets(tag_sets, secondary_server_selector(selection))
def member_with_tags_server_selector(tag_sets, selection):
"""All near-enough members matching the tag sets."""
return apply_tag_sets(tag_sets, readable_server_selector(selection))
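# Usage sketch: tag-set matching with a minimal stand-in for
# ServerDescription (only a .tags attribute is needed here). The first tag
# set with any match wins, and {} matches every server.
if __name__ == '__main__':
    class _FakeServerDescription(object):
        def __init__(self, tags):
            self.tags = tags

    ny = _FakeServerDescription({'dc': 'ny'})
    la = _FakeServerDescription({'dc': 'la'})
    selection = Selection(topology_description=None,
                          server_descriptions=[ny, la],
                          common_wire_version=6, primary=None)
    assert apply_single_tag_set({'dc': 'ny'}, selection)[0] is ny
    matched = apply_tag_sets([{'dc': 'sf'}, {}], selection)
    assert matched.server_descriptions == [ny, la]  # Fallback {} matched.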

View File

@@ -0,0 +1,23 @@
# Copyright 2014-2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Type codes for MongoDB servers."""
from collections import namedtuple
SERVER_TYPE = namedtuple('ServerType',
['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary',
'RSArbiter', 'RSOther', 'RSGhost',
'Standalone'])(*range(8))
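# Usage sketch: SERVER_TYPE is a namedtuple instance, so each server type is
# an integer code and _fields maps codes back to names.
if __name__ == '__main__':
    assert SERVER_TYPE.Unknown == 0
    assert SERVER_TYPE.RSPrimary == 2
    assert SERVER_TYPE._fields[SERVER_TYPE.Mongos] == 'Mongos'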

View File

@@ -0,0 +1,129 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Represent MongoClient's configuration."""
import threading
from bson.objectid import ObjectId
from pymongo import common, monitor, pool
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.pool import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
class TopologySettings(object):
def __init__(self,
seeds=None,
replica_set_name=None,
pool_class=None,
pool_options=None,
monitor_class=None,
condition_class=None,
local_threshold_ms=LOCAL_THRESHOLD_MS,
server_selection_timeout=SERVER_SELECTION_TIMEOUT,
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
server_selector=None,
fqdn=None):
"""Represent MongoClient's configuration.
Take a list of (host, port) pairs and optional replica set name.
"""
if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
raise ConfigurationError(
"heartbeatFrequencyMS cannot be less than %d" % (
common.MIN_HEARTBEAT_INTERVAL * 1000,))
self._seeds = seeds or [('localhost', 27017)]
self._replica_set_name = replica_set_name
self._pool_class = pool_class or pool.Pool
self._pool_options = pool_options or PoolOptions()
self._monitor_class = monitor_class or monitor.Monitor
self._condition_class = condition_class or threading.Condition
self._local_threshold_ms = local_threshold_ms
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._fqdn = fqdn
self._heartbeat_frequency = heartbeat_frequency
self._direct = (len(self._seeds) == 1 and not replica_set_name)
self._topology_id = ObjectId()
@property
def seeds(self):
"""List of server addresses."""
return self._seeds
@property
def replica_set_name(self):
return self._replica_set_name
@property
def pool_class(self):
return self._pool_class
@property
def pool_options(self):
return self._pool_options
@property
def monitor_class(self):
return self._monitor_class
@property
def condition_class(self):
return self._condition_class
@property
def local_threshold_ms(self):
return self._local_threshold_ms
@property
def server_selection_timeout(self):
return self._server_selection_timeout
@property
def server_selector(self):
return self._server_selector
@property
def heartbeat_frequency(self):
return self._heartbeat_frequency
@property
def fqdn(self):
return self._fqdn
@property
def direct(self):
"""Connect directly to a single server, or use a set of servers?
True if there is one seed and no replica_set_name.
"""
return self._direct
def get_topology_type(self):
if self.direct:
return TOPOLOGY_TYPE.Single
elif self.replica_set_name is not None:
return TOPOLOGY_TYPE.ReplicaSetNoPrimary
else:
return TOPOLOGY_TYPE.Unknown
def get_server_descriptions(self):
"""Initial dict of (address, ServerDescription) for all seeds."""
return dict([
(address, ServerDescription(address))
for address in self.seeds])
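# Usage sketch: the direct/topology-type rules implemented above. One seed
# and no replica set name means a direct (Single) connection; naming a
# replica set starts discovery in ReplicaSetNoPrimary.
if __name__ == '__main__':
    direct = TopologySettings(seeds=[('localhost', 27017)])
    assert direct.direct
    assert direct.get_topology_type() == TOPOLOGY_TYPE.Single
    rs = TopologySettings(seeds=[('localhost', 27017)],
                          replica_set_name='rs0')
    assert not rs.direct
    assert rs.get_topology_type() == TOPOLOGY_TYPE.ReplicaSetNoPrimary
    assert list(rs.get_server_descriptions()) == [('localhost', 27017)]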

View File

@@ -0,0 +1,191 @@
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""**DEPRECATED**: Manipulators that can edit SON objects as they enter and exit
a database.
The :class:`~pymongo.son_manipulator.SONManipulator` API has limitations as a
technique for transforming your data. Instead, it is more flexible and
straightforward to transform outgoing documents in your own code before passing
them to PyMongo, and transform incoming documents after receiving them from
PyMongo. SON Manipulators will be removed from PyMongo in 4.0.
PyMongo does **not** apply SON manipulators to documents passed to
the modern methods :meth:`~pymongo.collection.Collection.bulk_write`,
:meth:`~pymongo.collection.Collection.insert_one`,
:meth:`~pymongo.collection.Collection.insert_many`,
:meth:`~pymongo.collection.Collection.update_one`, or
:meth:`~pymongo.collection.Collection.update_many`. SON manipulators are
**not** applied to documents returned by the modern methods
:meth:`~pymongo.collection.Collection.find_one_and_delete`,
:meth:`~pymongo.collection.Collection.find_one_and_replace`, and
:meth:`~pymongo.collection.Collection.find_one_and_update`.
"""
from bson.dbref import DBRef
from bson.objectid import ObjectId
from bson.py3compat import abc
from bson.son import SON
class SONManipulator(object):
"""A base son manipulator.
This manipulator just saves and restores objects without changing them.
"""
def will_copy(self):
"""Will this SON manipulator make a copy of the incoming document?
Derived classes that do need to make a copy should override this
method, returning True instead of False. All non-copying manipulators
will be applied first (so that the user's document will be updated
appropriately), followed by copying manipulators.
"""
return False
def transform_incoming(self, son, collection):
"""Manipulate an incoming SON object.
:Parameters:
- `son`: the SON object to be inserted into the database
- `collection`: the collection the object is being inserted into
"""
if self.will_copy():
return SON(son)
return son
def transform_outgoing(self, son, collection):
"""Manipulate an outgoing SON object.
:Parameters:
- `son`: the SON object being retrieved from the database
- `collection`: the collection this object was stored in
"""
if self.will_copy():
return SON(son)
return son
class ObjectIdInjector(SONManipulator):
"""A son manipulator that adds the _id field if it is missing.
.. versionchanged:: 2.7
ObjectIdInjector is no longer used by PyMongo, but remains in this
module for backwards compatibility.
"""
def transform_incoming(self, son, collection):
"""Add an _id field if it is missing.
"""
if not "_id" in son:
son["_id"] = ObjectId()
return son
# This is now handled during BSON encoding (for performance reasons),
# but I'm keeping this here as a reference for those implementing new
# SONManipulators.
class ObjectIdShuffler(SONManipulator):
"""A son manipulator that moves _id to the first position.
"""
def will_copy(self):
"""We need to copy to be sure that we are dealing with SON, not a dict.
"""
return True
def transform_incoming(self, son, collection):
"""Move _id to the front if it's there.
"""
if not "_id" in son:
return son
transformed = SON({"_id": son["_id"]})
transformed.update(son)
return transformed
class NamespaceInjector(SONManipulator):
"""A son manipulator that adds the _ns field.
"""
def transform_incoming(self, son, collection):
"""Add the _ns field to the incoming object
"""
son["_ns"] = collection.name
return son
class AutoReference(SONManipulator):
"""Transparently reference and de-reference already saved embedded objects.
This manipulator should probably only be used when the NamespaceInjector is
also being used, otherwise it doesn't make much sense: documents can
only be auto-referenced if they have an *_ns* field.
NOTE: this will behave poorly if you have a circular reference.
TODO: this only works for documents that are in the same database. To fix
this we'll need to add a DatabaseInjector that adds *_db* and then make
use of the optional *database* support for DBRefs.
"""
def __init__(self, db):
self.database = db
def will_copy(self):
"""We need to copy so the user's document doesn't get transformed refs.
"""
return True
def transform_incoming(self, son, collection):
"""Replace embedded documents with DBRefs.
"""
def transform_value(value):
if isinstance(value, abc.MutableMapping):
if "_id" in value and "_ns" in value:
return DBRef(value["_ns"], transform_value(value["_id"]))
else:
return transform_dict(SON(value))
elif isinstance(value, list):
return [transform_value(v) for v in value]
return value
def transform_dict(object):
for (key, value) in object.items():
object[key] = transform_value(value)
return object
return transform_dict(SON(son))
def transform_outgoing(self, son, collection):
"""Replace DBRefs with embedded documents.
"""
def transform_value(value):
if isinstance(value, DBRef):
return self.database.dereference(value)
elif isinstance(value, list):
return [transform_value(v) for v in value]
elif isinstance(value, abc.MutableMapping):
return transform_dict(SON(value))
return value
def transform_dict(object):
for (key, value) in object.items():
object[key] = transform_value(value)
return object
return transform_dict(SON(son))
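# Usage sketch: chaining two of the manipulators above. The collection is a
# minimal stand-in exposing only the .name attribute they consult.
if __name__ == '__main__':
    class _FakeCollection(object):
        name = 'things'

    doc = SON([('a', 1), ('_id', ObjectId())])
    doc = NamespaceInjector().transform_incoming(doc, _FakeCollection())
    assert doc['_ns'] == 'things'
    shuffled = ObjectIdShuffler().transform_incoming(doc, _FakeCollection())
    assert list(shuffled)[0] == '_id'  # _id moved to the front.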

View File

@@ -0,0 +1,107 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
try:
from dns import resolver
_HAVE_DNSPYTHON = True
except ImportError:
_HAVE_DNSPYTHON = False
from bson.py3compat import PY3
from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
if PY3:
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text):
if isinstance(text, bytes):
return text.decode()
return text
else:
def maybe_decode(text):
return text
class _SrvResolver(object):
def __init__(self, fqdn, connect_timeout=None):
self.__fqdn = fqdn
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
# Validate the fully qualified domain name.
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError("Invalid URI host: %s" % (fqdn,))
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError("Invalid URI host: %s" % (fqdn,))
def get_options(self):
try:
results = resolver.query(self.__fqdn, 'TXT',
lifetime=self.__connect_timeout)
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc))
if len(results) > 1:
raise ConfigurationError('Only one TXT record is supported')
return (
b'&'.join([b''.join(res.strings) for res in results])).decode(
'utf-8')
def _resolve_uri(self, encapsulate_errors):
try:
results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV',
lifetime=self.__connect_timeout)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc))
return results
def _get_srv_response_and_hosts(self, encapsulate_errors):
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
for res in results]
# Validate hosts
for node in nodes:
try:
nlist = node[0].split(".")[1:][-self.__slen:]
except Exception:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
if self.__plist != nlist:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
return results, nodes
def get_hosts(self):
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self):
results, nodes = self._get_srv_response_and_hosts(False)
return nodes, results.rrset.ttl
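# Usage sketch: the host validation performed in __init__ above. Actual SRV
# or TXT resolution would additionally require dnspython and a reachable
# DNS server; the hostname below is a placeholder.
if __name__ == '__main__':
    try:
        _SrvResolver('localhost')  # Too few domain components.
    except ConfigurationError:
        pass
    else:
        raise AssertionError('expected ConfigurationError')
    _SrvResolver('cluster0.example.com', connect_timeout=5.0)  # Validates.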

View File

@@ -0,0 +1,96 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""A fake SSLContext implementation."""
try:
    import ssl
except ImportError:
    pass


class SSLContext(object):
    """A fake SSLContext.

    This implements an API similar to ssl.SSLContext from python 3.2
    but does not implement methods or properties that would be
    incompatible with ssl.wrap_socket from python 2.7 < 2.7.9.

    You must pass protocol which must be one of the PROTOCOL_* constants
    defined in the ssl module. ssl.PROTOCOL_SSLv23 is recommended for
    maximum interoperability.
    """

    __slots__ = ('_cafile', '_certfile',
                 '_keyfile', '_protocol', '_verify_mode')

    def __init__(self, protocol):
        self._cafile = None
        self._certfile = None
        self._keyfile = None
        self._protocol = protocol
        self._verify_mode = ssl.CERT_NONE

    @property
    def protocol(self):
        """The protocol version chosen when constructing the context.

        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self):
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return self._verify_mode

    def __set_verify_mode(self, value):
        """Setter for verify_mode."""
        self._verify_mode = value

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def load_cert_chain(self, certfile, keyfile=None):
        """Load a private key and the corresponding certificate. The
        certfile string must be the path to a single file in PEM format
        containing the certificate as well as any number of CA certificates
        needed to establish the certificate's authenticity. The keyfile
        string, if present, must point to a file containing the private
        key. Otherwise the private key will be taken from certfile as well.
        """
        self._certfile = certfile
        self._keyfile = keyfile

    def load_verify_locations(self, cafile=None, dummy=None):
        """Load a set of "certification authority" (CA) certificates used
        to validate other peers' certificates when `~verify_mode` is other
        than ssl.CERT_NONE.
        """
        self._cafile = cafile

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True, dummy=None):
        """Wrap an existing Python socket sock and return an ssl.SSLSocket
        object.
        """
        return ssl.wrap_socket(sock, keyfile=self._keyfile,
                               certfile=self._certfile,
                               server_side=server_side,
                               cert_reqs=self._verify_mode,
                               ssl_version=self._protocol,
                               ca_certs=self._cafile,
                               do_handshake_on_connect=do_handshake_on_connect,
                               suppress_ragged_eofs=suppress_ragged_eofs)
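
# A usage sketch (not part of the original file): the shim mirrors the
# small subset of the ssl.SSLContext API PyMongo needs, then delegates to
# ssl.wrap_socket (so it only works on Pythons that still provide it).
# The host and file paths below are illustrative only.
if __name__ == "__main__":
    import socket

    ctx = SSLContext(ssl.PROTOCOL_SSLv23)
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(cafile="/path/to/ca.pem")
    ctx.load_cert_chain("/path/to/client.pem")  # key read from certfile

    sock = socket.create_connection(("server.example.com", 27017))
    ssl_sock = ctx.wrap_socket(sock)  # returns an ssl.SSLSocket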

View File

@@ -0,0 +1,135 @@
# Backport of the match_hostname logic from python 3.5, with small
# changes to support IP address matching on python 2.7 and 3.4.
import re
import sys

try:
    # Python 3.4+, or the ipaddress module from pypi.
    from ipaddress import ip_address
except ImportError:
    ip_address = lambda address: None

# ipaddress.ip_address requires unicode
if sys.version_info[0] < 3:
    _unicode = unicode
else:
    _unicode = lambda value: value


class CertificateError(ValueError):
    pass


def _dnsname_match(dn, hostname, max_wildcards=1):
    """Matching according to RFC 6125, section 6.4.3

    http://tools.ietf.org/html/rfc6125#section-6.4.3
    """
    pats = []
    if not dn:
        return False

    parts = dn.split(r'.')
    leftmost = parts[0]
    remainder = parts[1:]

    wildcards = leftmost.count('*')
    if wildcards > max_wildcards:
        # Issue #17980: avoid denials of service by refusing more
        # than one wildcard per fragment. A survey of established
        # policy among SSL implementations showed it to be a
        # reasonable choice.
        raise CertificateError(
            "too many wildcards in certificate DNS name: " + repr(dn))

    # speed up common case w/o wildcards
    if not wildcards:
        return dn.lower() == hostname.lower()

    # RFC 6125, section 6.4.3, subitem 1.
    # The client SHOULD NOT attempt to match a presented identifier in which
    # the wildcard character comprises a label other than the left-most label.
    if leftmost == '*':
        # When '*' is a fragment by itself, it matches a non-empty dotless
        # fragment.
        pats.append('[^.]+')
    elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
        # RFC 6125, section 6.4.3, subitem 3.
        # The client SHOULD NOT attempt to match a presented identifier
        # where the wildcard character is embedded within an A-label or
        # U-label of an internationalized domain name.
        pats.append(re.escape(leftmost))
    else:
        # Otherwise, '*' matches any dotless string, e.g. www*
        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))

    # add the remaining fragments, ignore any wildcards
    for frag in remainder:
        pats.append(re.escape(frag))

    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
    return pat.match(hostname)


def _ipaddress_match(ipname, host_ip):
    """Exact matching of IP addresses.

    RFC 6125 explicitly doesn't define an algorithm for this
    (section 1.7.2 - "Out of Scope").
    """
    # OpenSSL may add a trailing newline to a subjectAltName's IP address
    ip = ip_address(_unicode(ipname).rstrip())
    return ip == host_ip


def match_hostname(cert, hostname):
    """Verify that *cert* (in decoded format as returned by
    SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
    rules are followed.

    CertificateError is raised on failure. On success, the function
    returns nothing.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")
    try:
        host_ip = ip_address(_unicode(hostname))
    except (ValueError, UnicodeError):
        # Not an IP address (common case)
        host_ip = None
    dnsnames = []
    san = cert.get('subjectAltName', ())
    for key, value in san:
        if key == 'DNS':
            if host_ip is None and _dnsname_match(value, hostname):
                return
            dnsnames.append(value)
        elif key == 'IP Address':
            if host_ip is not None and _ipaddress_match(value, host_ip):
                return
            dnsnames.append(value)
    if not dnsnames:
        # The subject is only checked when there is no dNSName entry
        # in subjectAltName
        for sub in cert.get('subject', ()):
            for key, value in sub:
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if key == 'commonName':
                    if _dnsname_match(value, hostname):
                        return
                    dnsnames.append(value)
    if len(dnsnames) > 1:
        raise CertificateError("hostname %r "
                               "doesn't match either of %s"
                               % (hostname, ', '.join(map(repr, dnsnames))))
    elif len(dnsnames) == 1:
        raise CertificateError("hostname %r "
                               "doesn't match %r"
                               % (hostname, dnsnames[0]))
    else:
        raise CertificateError("no appropriate commonName or "
                               "subjectAltName fields were found")

View File

@@ -0,0 +1,204 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for SSL in PyMongo."""
import atexit
import sys
import threading

HAVE_SSL = True

try:
    import ssl
except ImportError:
    HAVE_SSL = False

HAVE_CERTIFI = False
try:
    import certifi
    HAVE_CERTIFI = True
except ImportError:
    pass

HAVE_WINCERTSTORE = False
try:
    from wincertstore import CertFile
    HAVE_WINCERTSTORE = True
except ImportError:
    pass

from bson.py3compat import string_type
from pymongo.errors import ConfigurationError

_WINCERTSLOCK = threading.Lock()
_WINCERTS = None

_PY37PLUS = sys.version_info[:2] >= (3, 7)

if HAVE_SSL:
    try:
        # Python 2.7.9+, PyPy 2.5.1+, etc.
        from ssl import SSLContext
    except ImportError:
        from pymongo.ssl_context import SSLContext

    def validate_cert_reqs(option, value):
        """Validate ssl_cert_reqs. The value must be None or one of
        ``ssl.CERT_NONE``, ``ssl.CERT_OPTIONAL`` or ``ssl.CERT_REQUIRED``.
        """
        if value is None:
            return value
        elif isinstance(value, string_type) and hasattr(ssl, value):
            value = getattr(ssl, value)

        if value in (ssl.CERT_NONE, ssl.CERT_OPTIONAL, ssl.CERT_REQUIRED):
            return value
        raise ValueError("The value of %s must be one of: "
                         "`ssl.CERT_NONE`, `ssl.CERT_OPTIONAL` or "
                         "`ssl.CERT_REQUIRED`" % (option,))

    def validate_allow_invalid_certs(option, value):
        """Validate the option allowing invalid certificates."""
        # Avoid circular import.
        from pymongo.common import validate_boolean_or_string
        boolean_cert_reqs = validate_boolean_or_string(option, value)
        if boolean_cert_reqs:
            return ssl.CERT_NONE
        return ssl.CERT_REQUIRED

    def _load_wincerts():
        """Set _WINCERTS to an instance of wincertstore.CertFile."""
        global _WINCERTS

        certfile = CertFile()
        certfile.addstore("CA")
        certfile.addstore("ROOT")
        atexit.register(certfile.close)

        _WINCERTS = certfile

    # XXX: Possible future work.
    # - OCSP? Not supported by python at all.
    #   http://bugs.python.org/issue17123
    # - Adding an ssl_context keyword argument to MongoClient? This might
    #   be useful for sites that have unusual requirements rather than
    #   trying to expose every SSLContext option through a keyword/uri
    #   parameter.
    def get_ssl_context(*args):
        """Create and return an SSLContext object."""
        (certfile,
         keyfile,
         passphrase,
         ca_certs,
         cert_reqs,
         crlfile,
         match_hostname) = args
        verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
        # Note PROTOCOL_SSLv23 is about the most misleading name imaginable.
        # This configures the server and client to negotiate the
        # highest protocol version they both support. A very good thing.
        # PROTOCOL_TLS_CLIENT was added in CPython 3.6, deprecating
        # PROTOCOL_SSLv23.
        ctx = SSLContext(
            getattr(ssl, "PROTOCOL_TLS_CLIENT", ssl.PROTOCOL_SSLv23))
        # SSLContext.check_hostname was added in CPython 2.7.9 and 3.4.
        # PROTOCOL_TLS_CLIENT (added in Python 3.6) enables it by default.
        if hasattr(ctx, "check_hostname"):
            if _PY37PLUS and verify_mode != ssl.CERT_NONE:
                # Python 3.7 uses OpenSSL's hostname matching implementation
                # making it the obvious version to start using this with.
                # Python 3.6 might have been a good version, but it suffers
                # from https://bugs.python.org/issue32185.
                # We'll use our bundled match_hostname for older Python
                # versions, which also supports IP address matching
                # with Python < 3.5.
                ctx.check_hostname = match_hostname
            else:
                ctx.check_hostname = False
        if hasattr(ctx, "options"):
            # Explicitly disable SSLv2, SSLv3 and TLS compression. Note that
            # up to date versions of MongoDB 2.4 and above already disable
            # SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7
            # and >= 3.3.4 and SSLv3 in >= 3.4.3. There is no way for us to do
            # any of this explicitly for python 2.7 before 2.7.9.
            ctx.options |= getattr(ssl, "OP_NO_SSLv2", 0)
            ctx.options |= getattr(ssl, "OP_NO_SSLv3", 0)
            # OpenSSL >= 1.0.0
            ctx.options |= getattr(ssl, "OP_NO_COMPRESSION", 0)
            # Python 3.7+ with OpenSSL >= 1.1.0h
            ctx.options |= getattr(ssl, "OP_NO_RENEGOTIATION", 0)
        if certfile is not None:
            try:
                if passphrase is not None:
                    vi = sys.version_info
                    # Since python just added a new parameter to an existing
                    # method this seems to be about the best we can do.
                    if (vi[0] == 2 and vi < (2, 7, 9) or
                            vi[0] == 3 and vi < (3, 3)):
                        raise ConfigurationError(
                            "Support for ssl_pem_passphrase requires "
                            "python 2.7.9+ (pypy 2.5.1+) or 3.3+")
                    ctx.load_cert_chain(certfile, keyfile, passphrase)
                else:
                    ctx.load_cert_chain(certfile, keyfile)
            except ssl.SSLError as exc:
                raise ConfigurationError(
                    "Private key doesn't match certificate: %s" % (exc,))
        if crlfile is not None:
            if not hasattr(ctx, "verify_flags"):
                raise ConfigurationError(
                    "Support for ssl_crlfile requires "
                    "python 2.7.9+ (pypy 2.5.1+) or 3.4+")
            # Match the server's behavior.
            ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
            ctx.load_verify_locations(crlfile)
        if ca_certs is not None:
            ctx.load_verify_locations(ca_certs)
        elif cert_reqs != ssl.CERT_NONE:
            # CPython >= 2.7.9 or >= 3.4.0, pypy >= 2.5.1
            if hasattr(ctx, "load_default_certs"):
                ctx.load_default_certs()
            # Python >= 3.2.0, useless on Windows.
            elif (sys.platform != "win32" and
                  hasattr(ctx, "set_default_verify_paths")):
                ctx.set_default_verify_paths()
            elif sys.platform == "win32" and HAVE_WINCERTSTORE:
                with _WINCERTSLOCK:
                    if _WINCERTS is None:
                        _load_wincerts()
                ctx.load_verify_locations(_WINCERTS.name)
            elif HAVE_CERTIFI:
                ctx.load_verify_locations(certifi.where())
            else:
                raise ConfigurationError(
                    "`ssl_cert_reqs` is not ssl.CERT_NONE and no system "
                    "CA certificates could be loaded. `ssl_ca_certs` is "
                    "required.")
        ctx.verify_mode = verify_mode
        return ctx
else:
    def validate_cert_reqs(option, dummy):
        """No ssl module, raise ConfigurationError."""
        raise ConfigurationError("The value of %s is set but can't be "
                                 "validated. The ssl module is not available"
                                 % (option,))

    def validate_allow_invalid_certs(option, dummy):
        """No ssl module, raise ConfigurationError."""
        return validate_cert_reqs(option, dummy)

    def get_ssl_context(*dummy):
        """No ssl module, raise ConfigurationError."""
        raise ConfigurationError("The ssl module is not available.")

Some files were not shown because too many files have changed in this diff.