connection.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122
  1. from __future__ import unicode_literals
  2. from distutils.version import StrictVersion
  3. from itertools import chain
  4. import io
  5. import os
  6. import socket
  7. import sys
  8. import threading
  9. import warnings
  10. try:
  11. import ssl
  12. ssl_available = True
  13. except ImportError:
  14. ssl_available = False
  15. from redis._compat import (xrange, imap, byte_to_chr, unicode, bytes, long,
  16. nativestr, basestring, iteritems,
  17. LifoQueue, Empty, Full, urlparse, parse_qs,
  18. recv, recv_into, select, unquote)
  19. from redis.exceptions import (
  20. DataError,
  21. RedisError,
  22. ConnectionError,
  23. TimeoutError,
  24. BusyLoadingError,
  25. ResponseError,
  26. InvalidResponse,
  27. AuthenticationError,
  28. NoScriptError,
  29. ExecAbortError,
  30. ReadOnlyError
  31. )
  32. from redis.utils import HIREDIS_AVAILABLE
  if HIREDIS_AVAILABLE:
      import hiredis

      # Optional hiredis features are gated on the installed version.
      hiredis_version = StrictVersion(hiredis.__version__)
      # hiredis < 0.1.3 cannot take a callable as the ``replyError`` factory
      HIREDIS_SUPPORTS_CALLABLE_ERRORS = (
          hiredis_version >= StrictVersion('0.1.3'))
      # feeding the reader from a reusable bytearray (recv_into) needs 0.1.4+
      HIREDIS_SUPPORTS_BYTE_BUFFER = (
          hiredis_version >= StrictVersion('0.1.4'))

      if not HIREDIS_SUPPORTS_BYTE_BUFFER:
          msg = ("redis-py works best with hiredis >= 0.1.4. You're running "
                 "hiredis %s. Please consider upgrading." % hiredis.__version__)
          warnings.warn(msg)

      # only use byte buffer if hiredis supports it
      HIREDIS_USE_BYTE_BUFFER = bool(HIREDIS_SUPPORTS_BYTE_BUFFER)

  # Byte sequences used to frame commands in the Redis protocol.
  SYM_STAR = b'*'
  SYM_DOLLAR = b'$'
  SYM_CRLF = b'\r\n'
  SYM_EMPTY = b''

  SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
  class Token(object):
      """
      Wrapper for literal strings in Redis commands (command names and any
      hard-coded arguments) so that no encoding rules are applied to them.

      ``get_token`` interns instances, and each instance stores its encoded
      bytes once in ``encoded_value``.
      """

      # shared literal -> Token lookup used by get_token()
      _cache = {}

      @classmethod
      def get_token(cls, value):
          "Gets a cached token object or creates a new one if not already cached"
          cache = cls._cache
          # EAFP: after a short warm-up almost every lookup is a hit
          try:
              return cache[value]
          except KeyError:
              pass
          token = cache[value] = Token(value)
          return token

      def __init__(self, value):
          # unwrap an existing Token rather than nesting it
          self.value = value.value if isinstance(value, Token) else value
          self.encoded_value = self.value.encode()

      def __repr__(self):
          return self.value

      __str__ = __repr__
  class Encoder(object):
      """
      Encode outgoing values to bytes and optionally decode replies to text.

      ``encoding`` and ``encoding_errors`` are passed to ``encode``/``decode``
      on text and bytes; ``decode_responses`` makes ``decode`` convert bytes
      by default.
      """

      def __init__(self, encoding, encoding_errors, decode_responses):
          self.encoding = encoding
          self.encoding_errors = encoding_errors
          self.decode_responses = decode_responses

      def encode(self, value):
          """
          Return a bytestring representation of ``value``.

          Tokens yield their pre-encoded bytes and bytes pass through as-is.
          Floats are rendered with ``repr()``, integers with ``str()``, and text
          is encoded with the configured encoding.

          Raises DataError for bools and any other unsupported type.
          """
          if isinstance(value, Token):
              return value.encoded_value
          if isinstance(value, bytes):
              return value
          # bool has to be rejected before the int branch (bool subclasses int)
          if isinstance(value, bool):
              raise DataError("Invalid input of type: 'bool'. Convert to a "
                              "byte, string or number first.")
          if isinstance(value, float):
              return repr(value).encode()
          if isinstance(value, (int, long)):
              # python 2 repr() on longs is '123L', so use str() instead
              return str(value).encode()
          if isinstance(value, basestring):
              if isinstance(value, unicode):
                  return value.encode(self.encoding, self.encoding_errors)
              return value
          # a value we don't know how to deal with. throw an error
          type_name = type(value).__name__
          raise DataError("Invalid input of type: '%s'. Convert to a "
                          "byte, string or number first." % type_name)

      def decode(self, value, force=False):
          """
          Return a unicode string from the byte representation.

          Only bytes are decoded, and only when ``decode_responses`` is set or
          ``force`` is true; anything else is returned unchanged.
          """
          wants_text = self.decode_responses or force
          if wants_text and isinstance(value, bytes):
              return value.decode(self.encoding, self.encoding_errors)
          return value
  class BaseParser(object):
      "Shared error-reply handling for the response parsers"

      # Error-code prefix -> exception class. A nested dict picks the class by
      # the exact message after the prefix, defaulting to ResponseError.
      EXCEPTION_CLASSES = {
          'ERR': {
              'max number of clients reached': ConnectionError
          },
          'EXECABORT': ExecAbortError,
          'LOADING': BusyLoadingError,
          'NOSCRIPT': NoScriptError,
          'READONLY': ReadOnlyError,
      }

      def parse_error(self, response):
          """
          Parse an error response into an exception instance (returned, not
          raised).

          Known error codes map to their specific class constructed with the
          message after the code; unknown codes yield ResponseError built from
          the full response.
          """
          error_code, _, message = response.partition(' ')
          exception_class = self.EXCEPTION_CLASSES.get(error_code)
          if exception_class is None:
              return ResponseError(response)
          if isinstance(exception_class, dict):
              exception_class = exception_class.get(message, ResponseError)
          return exception_class(message)
  class SocketBuffer(object):
      """
      In-memory buffer between a socket and the pure-Python parser.

      Bytes received from the socket are appended at offset ``bytes_written``
      and consumed from offset ``bytes_read``; once everything written has been
      read, the buffer is emptied so it does not grow without bound.
      """

      def __init__(self, socket, socket_read_size):
          self._sock = socket
          self.socket_read_size = socket_read_size
          self._buffer = io.BytesIO()
          # number of bytes written to the buffer from the socket
          self.bytes_written = 0
          # number of bytes read from the buffer
          self.bytes_read = 0

      @property
      def length(self):
          "Number of buffered bytes that have not been consumed yet"
          return self.bytes_written - self.bytes_read

      def _read_from_socket(self, length=None):
          """
          Append data from the socket to the end of the buffer.

          Always performs at least one receive; when ``length`` is given it
          keeps receiving until at least that many new bytes have arrived.

          Raises TimeoutError on a socket timeout and ConnectionError on any
          other socket error, including the server closing the connection.
          """
          read_size = self.socket_read_size
          buf = self._buffer
          buf.seek(self.bytes_written)
          received = 0

          try:
              while True:
                  chunk = recv(self._sock, read_size)
                  # an empty string indicates the server shutdown the socket
                  if isinstance(chunk, bytes) and len(chunk) == 0:
                      raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                  buf.write(chunk)
                  chunk_length = len(chunk)
                  self.bytes_written += chunk_length
                  received += chunk_length
                  if length is None or received >= length:
                      break
          except socket.timeout:
              raise TimeoutError("Timeout reading from socket")
          except socket.error:
              e = sys.exc_info()[1]
              raise ConnectionError("Error while reading from socket: %s" %
                                    (e.args,))

      def _purge_if_drained(self):
          # purge the buffer when we've consumed it all so it doesn't
          # grow forever
          if self.bytes_read == self.bytes_written:
              self.purge()

      def read(self, length):
          """
          Consume ``length`` payload bytes plus the 2-byte terminator that
          follows them, returning only the payload.
          """
          # make sure to read the \r\n terminator
          wanted = length + 2
          missing = wanted - self.length
          if missing > 0:
              self._read_from_socket(missing)

          self._buffer.seek(self.bytes_read)
          payload = self._buffer.read(wanted)
          self.bytes_read += len(payload)
          self._purge_if_drained()
          return payload[:-2]

      def readline(self):
          "Consume and return the next CRLF-terminated line, minus the CRLF"
          buf = self._buffer
          while True:
              buf.seek(self.bytes_read)
              line = buf.readline()
              if line.endswith(SYM_CRLF):
                  break
              # there's more data in the socket that we need
              self._read_from_socket()

          self.bytes_read += len(line)
          self._purge_if_drained()
          return line[:-2]

      def purge(self):
          "Discard all buffered data and reset both offsets"
          self._buffer.seek(0)
          self._buffer.truncate()
          self.bytes_written = 0
          self.bytes_read = 0

      def close(self):
          "Release the buffer and drop the socket reference"
          try:
              self.purge()
              self._buffer.close()
          except Exception:
              # issue #633 suggests the purge/close somehow raised a
              # BadFileDescriptor error. Perhaps the client ran out of
              # memory or something else? It's probably OK to ignore
              # any error being raised from purge/close since we're
              # removing the reference to the instance below.
              pass
          self._buffer = None
          self._sock = None
  class PythonParser(BaseParser):
      "Plain Python parsing class"

      def __init__(self, socket_read_size):
          self.socket_read_size = socket_read_size
          self.encoder = None
          self._sock = None
          self._buffer = None

      def __del__(self):
          # best-effort cleanup when garbage collected; errors are ignored
          try:
              self.on_disconnect()
          except Exception:
              pass

      def on_connect(self, connection):
          "Called when the socket connects"
          self._sock = connection._sock
          self._buffer = SocketBuffer(self._sock, self.socket_read_size)
          self.encoder = connection.encoder

      def on_disconnect(self):
          "Called when the socket disconnects"
          if self._sock is not None:
              self._sock.close()
              self._sock = None
          if self._buffer is not None:
              self._buffer.close()
              self._buffer = None
          self.encoder = None

      def can_read(self):
          "Return True if unconsumed bytes are already buffered locally"
          # always a bool (the old form leaked None when disconnected)
          return self._buffer is not None and bool(self._buffer.length)

      def read_response(self):
          """
          Read and parse one reply from the buffered socket.

          Error replies are returned as exception instances rather than raised
          (so a pipeline can place them among its results), except for
          ConnectionError instances, which are raised immediately.

          Raises ConnectionError if the parser is not connected or the server
          closed the connection, and InvalidResponse on an unknown reply type.
          """
          if self._buffer is None:
              # on_connect() never ran or on_disconnect() already did; report
              # it the same way HiredisParser does instead of AttributeError
              raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

          response = self._buffer.readline()
          if not response:
              raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

          byte, response = byte_to_chr(response[0]), response[1:]

          if byte not in ('-', '+', ':', '$', '*'):
              raise InvalidResponse("Protocol Error: %s, %s" %
                                    (str(byte), str(response)))

          # server returned an error
          if byte == '-':
              response = nativestr(response)
              error = self.parse_error(response)
              # if the error is a ConnectionError, raise immediately so the user
              # is notified
              if isinstance(error, ConnectionError):
                  raise error
              # otherwise, we're dealing with a ResponseError that might belong
              # inside a pipeline response. the connection's read_response()
              # and/or the pipeline's execute() will raise this error if
              # necessary, so just return the exception instance here.
              return error
          # single value
          elif byte == '+':
              pass
          # int value
          elif byte == ':':
              response = long(response)
          # bulk response
          elif byte == '$':
              length = int(response)
              if length == -1:
                  return None
              response = self._buffer.read(length)
          # multi-bulk response
          elif byte == '*':
              length = int(response)
              if length == -1:
                  return None
              response = [self.read_response() for i in xrange(length)]

          if isinstance(response, bytes):
              response = self.encoder.decode(response)
          return response
  class HiredisParser(BaseParser):
      "Parser class for connections using Hiredis"

      def __init__(self, socket_read_size):
          if not HIREDIS_AVAILABLE:
              raise RedisError("Hiredis is not installed")
          self.socket_read_size = socket_read_size
          # start in the disconnected state so can_read()/read_response()
          # raise ConnectionError (not AttributeError) before on_connect()
          self._sock = None
          self._reader = None
          self._next_response = False

          if HIREDIS_USE_BYTE_BUFFER:
              # reusable receive buffer filled by recv_into()
              self._buffer = bytearray(socket_read_size)

      def __del__(self):
          # best-effort cleanup when garbage collected; errors are ignored
          try:
              self.on_disconnect()
          except Exception:
              pass

      def on_connect(self, connection):
          "Called when the socket connects; creates a fresh hiredis Reader"
          self._sock = connection._sock
          kwargs = {
              'protocolError': InvalidResponse,
              'replyError': self.parse_error,
          }

          # hiredis < 0.1.3 doesn't support functions that create exceptions
          if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
              kwargs['replyError'] = ResponseError

          if connection.encoder.decode_responses:
              kwargs['encoding'] = connection.encoder.encoding
          self._reader = hiredis.Reader(**kwargs)
          self._next_response = False

      def on_disconnect(self):
          "Called when the socket disconnects"
          self._sock = None
          self._reader = None
          self._next_response = False

      def can_read(self):
          """
          Return True if a complete reply is already available from data the
          reader holds. The reply is cached for the next read_response().

          Raises ConnectionError when not connected.
          """
          if not self._reader:
              raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

          if self._next_response is False:
              self._next_response = self._reader.gets()
          return self._next_response is not False

      def read_response(self):
          """
          Return the next reply, reading from the socket as needed.

          Error replies are returned as exception instances, except
          ConnectionError replies (top-level or first list element), which are
          raised. Raises ConnectionError when not connected or on socket
          failure, and TimeoutError on a socket timeout.
          """
          if not self._reader:
              raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

          # _next_response might be cached from a can_read() call; it gets the
          # same error handling as a reply read directly from the socket
          if self._next_response is not False:
              response = self._next_response
              self._next_response = False
          else:
              response = self._read_until_complete()
          return self._process_response(response)

      def _read_until_complete(self):
          "Feed socket data to the reader until it produces a complete reply"
          response = self._reader.gets()
          socket_read_size = self.socket_read_size
          while response is False:
              try:
                  if HIREDIS_USE_BYTE_BUFFER:
                      bufflen = recv_into(self._sock, self._buffer)
                      if bufflen == 0:
                          raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                  else:
                      buffer = recv(self._sock, socket_read_size)
                      # an empty string indicates the server shutdown the socket
                      if not isinstance(buffer, bytes) or len(buffer) == 0:
                          raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
              except socket.timeout:
                  raise TimeoutError("Timeout reading from socket")
              except socket.error:
                  e = sys.exc_info()[1]
                  raise ConnectionError("Error while reading from socket: %s" %
                                        (e.args,))
              if HIREDIS_USE_BYTE_BUFFER:
                  self._reader.feed(self._buffer, 0, bufflen)
              else:
                  self._reader.feed(buffer)
              response = self._reader.gets()
          return response

      def _process_response(self, response):
          "Upgrade legacy error replies and raise ConnectionError replies"
          # if an older version of hiredis is installed, we need to attempt
          # to convert ResponseErrors to their appropriate types.
          if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
              if isinstance(response, ResponseError):
                  response = self.parse_error(response.args[0])
              elif isinstance(response, list) and response and \
                      isinstance(response[0], ResponseError):
                  response[0] = self.parse_error(response[0].args[0])
          # if the response is a ConnectionError or the response is a list and
          # the first item is a ConnectionError, raise it as something bad
          # happened
          if isinstance(response, ConnectionError):
              raise response
          elif isinstance(response, list) and response and \
                  isinstance(response[0], ConnectionError):
              raise response[0]
          return response
  367. elif isinstance(response, list) and response and \
  368. isinstance(response[0], ConnectionError):
  369. raise response[0]
  370. return response
  371. if HIREDIS_AVAILABLE:
  372. DefaultParser = HiredisParser
  373. else:
  374. DefaultParser = PythonParser
  375. class Connection(object):
  376. "Manages TCP communication to and from a Redis server"
  377. description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"
  378. def __init__(self, host='localhost', port=6379, db=0, password=None,
  379. socket_timeout=None, socket_connect_timeout=None,
  380. socket_keepalive=False, socket_keepalive_options=None,
  381. socket_type=0, retry_on_timeout=False, encoding='utf-8',
  382. encoding_errors='strict', decode_responses=False,
  383. parser_class=DefaultParser, socket_read_size=65536):
  384. self.pid = os.getpid()
  385. self.host = host
  386. self.port = int(port)
  387. self.db = db
  388. self.password = password
  389. self.socket_timeout = socket_timeout
  390. self.socket_connect_timeout = socket_connect_timeout or socket_timeout
  391. self.socket_keepalive = socket_keepalive
  392. self.socket_keepalive_options = socket_keepalive_options or {}
  393. self.socket_type = socket_type
  394. self.retry_on_timeout = retry_on_timeout
  395. self.encoder = Encoder(encoding, encoding_errors, decode_responses)
  396. self._sock = None
  397. self._parser = parser_class(socket_read_size=socket_read_size)
  398. self._description_args = {
  399. 'host': self.host,
  400. 'port': self.port,
  401. 'db': self.db,
  402. }
  403. self._connect_callbacks = []
  404. self._buffer_cutoff = 6000
  405. def __repr__(self):
  406. return self.description_format % self._description_args
  407. def __del__(self):
  408. try:
  409. self.disconnect()
  410. except Exception:
  411. pass
  412. def register_connect_callback(self, callback):
  413. self._connect_callbacks.append(callback)
  414. def clear_connect_callbacks(self):
  415. self._connect_callbacks = []
  416. def connect(self):
  417. "Connects to the Redis server if not already connected"
  418. if self._sock:
  419. return
  420. try:
  421. sock = self._connect()
  422. except socket.timeout:
  423. raise TimeoutError("Timeout connecting to server")
  424. except socket.error:
  425. e = sys.exc_info()[1]
  426. raise ConnectionError(self._error_message(e))
  427. self._sock = sock
  428. try:
  429. self.on_connect()
  430. except RedisError:
  431. # clean up after any error in on_connect
  432. self.disconnect()
  433. raise
  434. # run any user callbacks. right now the only internal callback
  435. # is for pubsub channel/pattern resubscription
  436. for callback in self._connect_callbacks:
  437. callback(self)
  438. def _connect(self):
  439. "Create a TCP socket connection"
  440. # we want to mimic what socket.create_connection does to support
  441. # ipv4/ipv6, but we want to set options prior to calling
  442. # socket.connect()
  443. err = None
  444. for res in socket.getaddrinfo(self.host, self.port, self.socket_type,
  445. socket.SOCK_STREAM):
  446. family, socktype, proto, canonname, socket_address = res
  447. sock = None
  448. try:
  449. sock = socket.socket(family, socktype, proto)
  450. # TCP_NODELAY
  451. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  452. # TCP_KEEPALIVE
  453. if self.socket_keepalive:
  454. sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
  455. for k, v in iteritems(self.socket_keepalive_options):
  456. sock.setsockopt(socket.SOL_TCP, k, v)
  457. # set the socket_connect_timeout before we connect
  458. sock.settimeout(self.socket_connect_timeout)
  459. # connect
  460. sock.connect(socket_address)
  461. # set the socket_timeout now that we're connected
  462. sock.settimeout(self.socket_timeout)
  463. return sock
  464. except socket.error as _:
  465. err = _
  466. if sock is not None:
  467. sock.close()
  468. if err is not None:
  469. raise err
  470. raise socket.error("socket.getaddrinfo returned an empty list")
  471. def _error_message(self, exception):
  472. # args for socket.error can either be (errno, "message")
  473. # or just "message"
  474. if len(exception.args) == 1:
  475. return "Error connecting to %s:%s. %s." % \
  476. (self.host, self.port, exception.args[0])
  477. else:
  478. return "Error %s connecting to %s:%s. %s." % \
  479. (exception.args[0], self.host, self.port, exception.args[1])
  480. def on_connect(self):
  481. "Initialize the connection, authenticate and select a database"
  482. self._parser.on_connect(self)
  483. # if a password is specified, authenticate
  484. if self.password:
  485. self.send_command('AUTH', self.password)
  486. if nativestr(self.read_response()) != 'OK':
  487. raise AuthenticationError('Invalid Password')
  488. # if a database is specified, switch to it
  489. if self.db:
  490. self.send_command('SELECT', self.db)
  491. if nativestr(self.read_response()) != 'OK':
  492. raise ConnectionError('Invalid Database')
  493. def disconnect(self):
  494. "Disconnects from the Redis server"
  495. self._parser.on_disconnect()
  496. if self._sock is None:
  497. return
  498. try:
  499. self._sock.shutdown(socket.SHUT_RDWR)
  500. self._sock.close()
  501. except socket.error:
  502. pass
  503. self._sock = None
  504. def send_packed_command(self, command):
  505. "Send an already packed command to the Redis server"
  506. if not self._sock:
  507. self.connect()
  508. try:
  509. if isinstance(command, str):
  510. command = [command]
  511. for item in command:
  512. self._sock.sendall(item)
  513. except socket.timeout:
  514. self.disconnect()
  515. raise TimeoutError("Timeout writing to socket")
  516. except socket.error:
  517. e = sys.exc_info()[1]
  518. self.disconnect()
  519. if len(e.args) == 1:
  520. errno, errmsg = 'UNKNOWN', e.args[0]
  521. else:
  522. errno = e.args[0]
  523. errmsg = e.args[1]
  524. raise ConnectionError("Error %s while writing to socket. %s." %
  525. (errno, errmsg))
  526. except Exception as e:
  527. self.disconnect()
  528. raise e
  529. def send_command(self, *args):
  530. "Pack and send a command to the Redis server"
  531. self.send_packed_command(self.pack_command(*args))
  532. def can_read(self, timeout=0):
  533. "Poll the socket to see if there's data that can be read."
  534. sock = self._sock
  535. if not sock:
  536. self.connect()
  537. sock = self._sock
  538. return self._parser.can_read() or \
  539. bool(select([sock], [], [], timeout)[0])
  540. def read_response(self):
  541. "Read the response from a previously sent command"
  542. try:
  543. response = self._parser.read_response()
  544. except Exception as e:
  545. self.disconnect()
  546. raise e
  547. if isinstance(response, ResponseError):
  548. raise response
  549. return response
  550. def pack_command(self, *args):
  551. "Pack a series of arguments into the Redis protocol"
  552. output = []
  553. # the client might have included 1 or more literal arguments in
  554. # the command name, e.g., 'CONFIG GET'. The Redis server expects these
  555. # arguments to be sent separately, so split the first argument
  556. # manually. All of these arguements get wrapped in the Token class
  557. # to prevent them from being encoded.
  558. command = args[0]
  559. if ' ' in command:
  560. args = tuple(Token.get_token(s)
  561. for s in command.split()) + args[1:]
  562. else:
  563. args = (Token.get_token(command),) + args[1:]
  564. buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
  565. buffer_cutoff = self._buffer_cutoff
  566. for arg in imap(self.encoder.encode, args):
  567. # to avoid large string mallocs, chunk the command into the
  568. # output list if we're sending large values
  569. if len(buff) > buffer_cutoff or len(arg) > buffer_cutoff:
  570. buff = SYM_EMPTY.join(
  571. (buff, SYM_DOLLAR, str(len(arg)).encode(), SYM_CRLF))
  572. output.append(buff)
  573. output.append(arg)
  574. buff = SYM_CRLF
  575. else:
  576. buff = SYM_EMPTY.join(
  577. (buff, SYM_DOLLAR, str(len(arg)).encode(),
  578. SYM_CRLF, arg, SYM_CRLF))
  579. output.append(buff)
  580. return output
  581. def pack_commands(self, commands):
  582. "Pack multiple commands into the Redis protocol"
  583. output = []
  584. pieces = []
  585. buffer_length = 0
  586. buffer_cutoff = self._buffer_cutoff
  587. for cmd in commands:
  588. for chunk in self.pack_command(*cmd):
  589. chunklen = len(chunk)
  590. if buffer_length > buffer_cutoff or chunklen > buffer_cutoff:
  591. output.append(SYM_EMPTY.join(pieces))
  592. buffer_length = 0
  593. pieces = []
  594. if chunklen > self._buffer_cutoff:
  595. output.append(chunk)
  596. else:
  597. pieces.append(chunk)
  598. buffer_length += chunklen
  599. if pieces:
  600. output.append(SYM_EMPTY.join(pieces))
  601. return output
  602. class SSLConnection(Connection):
  603. description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"
  604. def __init__(self, ssl_keyfile=None, ssl_certfile=None,
  605. ssl_cert_reqs='required', ssl_ca_certs=None, **kwargs):
  606. if not ssl_available:
  607. raise RedisError("Python wasn't built with SSL support")
  608. super(SSLConnection, self).__init__(**kwargs)
  609. self.keyfile = ssl_keyfile
  610. self.certfile = ssl_certfile
  611. if ssl_cert_reqs is None:
  612. ssl_cert_reqs = ssl.CERT_NONE
  613. elif isinstance(ssl_cert_reqs, basestring):
  614. CERT_REQS = {
  615. 'none': ssl.CERT_NONE,
  616. 'optional': ssl.CERT_OPTIONAL,
  617. 'required': ssl.CERT_REQUIRED
  618. }
  619. if ssl_cert_reqs not in CERT_REQS:
  620. raise RedisError(
  621. "Invalid SSL Certificate Requirements Flag: %s" %
  622. ssl_cert_reqs)
  623. ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
  624. self.cert_reqs = ssl_cert_reqs
  625. self.ca_certs = ssl_ca_certs
  626. def _connect(self):
  627. "Wrap the socket with SSL support"
  628. sock = super(SSLConnection, self)._connect()
  629. sock = ssl.wrap_socket(sock,
  630. cert_reqs=self.cert_reqs,
  631. keyfile=self.keyfile,
  632. certfile=self.certfile,
  633. ca_certs=self.ca_certs)
  634. return sock
  635. class UnixDomainSocketConnection(Connection):
  636. description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>"
  637. def __init__(self, path='', db=0, password=None,
  638. socket_timeout=None, encoding='utf-8',
  639. encoding_errors='strict', decode_responses=False,
  640. retry_on_timeout=False,
  641. parser_class=DefaultParser, socket_read_size=65536):
  642. self.pid = os.getpid()
  643. self.path = path
  644. self.db = db
  645. self.password = password
  646. self.socket_timeout = socket_timeout
  647. self.retry_on_timeout = retry_on_timeout
  648. self.encoder = Encoder(encoding, encoding_errors, decode_responses)
  649. self._sock = None
  650. self._parser = parser_class(socket_read_size=socket_read_size)
  651. self._description_args = {
  652. 'path': self.path,
  653. 'db': self.db,
  654. }
  655. self._connect_callbacks = []
  656. self._buffer_cutoff = 6000
  657. def _connect(self):
  658. "Create a Unix domain socket connection"
  659. sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  660. sock.settimeout(self.socket_timeout)
  661. sock.connect(self.path)
  662. return sock
  663. def _error_message(self, exception):
  664. # args for socket.error can either be (errno, "message")
  665. # or just "message"
  666. if len(exception.args) == 1:
  667. return "Error connecting to unix socket: %s. %s." % \
  668. (self.path, exception.args[0])
  669. else:
  670. return "Error %s connecting to unix socket: %s. %s." % \
  671. (exception.args[0], self.path, exception.args[1])
  # Strings (compared upper-cased) that to_bool() interprets as False.
  FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO')


  def to_bool(value):
      """
      Interpret a querystring value as a boolean.

      Returns None for None or the empty string, False for a string whose
      upper-cased form is in FALSE_STRINGS, and bool(value) otherwise.
      """
      if value is None or value == '':
          return None
      looks_false = (isinstance(value, basestring) and
                     value.upper() in FALSE_STRINGS)
      return False if looks_false else bool(value)


  # Converters that ConnectionPool.from_url applies to these querystring
  # arguments; any other argument is passed through as a string.
  URL_QUERY_ARGUMENT_PARSERS = {
      'socket_timeout': float,
      'socket_connect_timeout': float,
      'socket_keepalive': to_bool,
      'retry_on_timeout': to_bool,
      'max_connections': int,
  }
  686. class ConnectionPool(object):
  687. "Generic connection pool"
  688. @classmethod
  689. def from_url(cls, url, db=None, decode_components=False, **kwargs):
  690. """
  691. Return a connection pool configured from the given URL.
  692. For example::
  693. redis://[:password]@localhost:6379/0
  694. rediss://[:password]@localhost:6379/0
  695. unix://[:password]@/path/to/socket.sock?db=0
  696. Three URL schemes are supported:
  697. - ```redis://``
  698. <https://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a
  699. normal TCP socket connection
  700. - ```rediss://``
  701. <https://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates
  702. a SSL wrapped TCP socket connection
  703. - ``unix://`` creates a Unix Domain Socket connection
  704. There are several ways to specify a database number. The parse function
  705. will return the first specified option:
  706. 1. A ``db`` querystring option, e.g. redis://localhost?db=0
  707. 2. If using the redis:// scheme, the path argument of the url, e.g.
  708. redis://localhost/0
  709. 3. The ``db`` argument to this function.
  710. If none of these options are specified, db=0 is used.
  711. The ``decode_components`` argument allows this function to work with
  712. percent-encoded URLs. If this argument is set to ``True`` all ``%xx``
  713. escapes will be replaced by their single-character equivalents after
  714. the URL has been parsed. This only applies to the ``hostname``,
  715. ``path``, and ``password`` components.
  716. Any additional querystring arguments and keyword arguments will be
  717. passed along to the ConnectionPool class's initializer. The querystring
  718. arguments ``socket_connect_timeout`` and ``socket_timeout`` if supplied
  719. are parsed as float values. The arguments ``socket_keepalive`` and
  720. ``retry_on_timeout`` are parsed to boolean values that accept
  721. True/False, Yes/No values to indicate state. Invalid types cause a
  722. ``UserWarning`` to be raised. In the case of conflicting arguments,
  723. querystring arguments always win.
  724. """
  725. url = urlparse(url)
  726. url_options = {}
  727. for name, value in iteritems(parse_qs(url.query)):
  728. if value and len(value) > 0:
  729. parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
  730. if parser:
  731. try:
  732. url_options[name] = parser(value[0])
  733. except (TypeError, ValueError):
  734. warnings.warn(UserWarning(
  735. "Invalid value for `%s` in connection URL." % name
  736. ))
  737. else:
  738. url_options[name] = value[0]
  739. if decode_components:
  740. password = unquote(url.password) if url.password else None
  741. path = unquote(url.path) if url.path else None
  742. hostname = unquote(url.hostname) if url.hostname else None
  743. else:
  744. password = url.password
  745. path = url.path
  746. hostname = url.hostname
  747. # We only support redis:// and unix:// schemes.
  748. if url.scheme == 'unix':
  749. url_options.update({
  750. 'password': password,
  751. 'path': path,
  752. 'connection_class': UnixDomainSocketConnection,
  753. })
  754. else:
  755. url_options.update({
  756. 'host': hostname,
  757. 'port': int(url.port or 6379),
  758. 'password': password,
  759. })
  760. # If there's a path argument, use it as the db argument if a
  761. # querystring value wasn't specified
  762. if 'db' not in url_options and path:
  763. try:
  764. url_options['db'] = int(path.replace('/', ''))
  765. except (AttributeError, ValueError):
  766. pass
  767. if url.scheme == 'rediss':
  768. url_options['connection_class'] = SSLConnection
  769. # last shot at the db value
  770. url_options['db'] = int(url_options.get('db', db or 0))
  771. # update the arguments from the URL values
  772. kwargs.update(url_options)
  773. # backwards compatability
  774. if 'charset' in kwargs:
  775. warnings.warn(DeprecationWarning(
  776. '"charset" is deprecated. Use "encoding" instead'))
  777. kwargs['encoding'] = kwargs.pop('charset')
  778. if 'errors' in kwargs:
  779. warnings.warn(DeprecationWarning(
  780. '"errors" is deprecated. Use "encoding_errors" instead'))
  781. kwargs['encoding_errors'] = kwargs.pop('errors')
  782. return cls(**kwargs)
  783. def __init__(self, connection_class=Connection, max_connections=None,
  784. **connection_kwargs):
  785. """
  786. Create a connection pool. If max_connections is set, then this
  787. object raises redis.ConnectionError when the pool's limit is reached.
  788. By default, TCP connections are created unless connection_class is
  789. specified. Use redis.UnixDomainSocketConnection for unix sockets.
  790. Any additional keyword arguments are passed to the constructor of
  791. connection_class.
  792. """
  793. max_connections = max_connections or 2 ** 31
  794. if not isinstance(max_connections, (int, long)) or max_connections < 0:
  795. raise ValueError('"max_connections" must be a positive integer')
  796. self.connection_class = connection_class
  797. self.connection_kwargs = connection_kwargs
  798. self.max_connections = max_connections
  799. self.reset()
  800. def __repr__(self):
  801. return "%s<%s>" % (
  802. type(self).__name__,
  803. self.connection_class.description_format % self.connection_kwargs,
  804. )
  805. def reset(self):
  806. self.pid = os.getpid()
  807. self._created_connections = 0
  808. self._available_connections = []
  809. self._in_use_connections = set()
  810. self._check_lock = threading.Lock()
  811. def _checkpid(self):
  812. if self.pid != os.getpid():
  813. with self._check_lock:
  814. if self.pid == os.getpid():
  815. # another thread already did the work while we waited
  816. # on the lock.
  817. return
  818. self.disconnect()
  819. self.reset()
  820. def get_connection(self, command_name, *keys, **options):
  821. "Get a connection from the pool"
  822. self._checkpid()
  823. try:
  824. connection = self._available_connections.pop()
  825. except IndexError:
  826. connection = self.make_connection()
  827. self._in_use_connections.add(connection)
  828. return connection
  829. def get_encoder(self):
  830. "Return an encoder based on encoding settings"
  831. kwargs = self.connection_kwargs
  832. return Encoder(
  833. encoding=kwargs.get('encoding', 'utf-8'),
  834. encoding_errors=kwargs.get('encoding_errors', 'strict'),
  835. decode_responses=kwargs.get('decode_responses', False)
  836. )
  837. def make_connection(self):
  838. "Create a new connection"
  839. if self._created_connections >= self.max_connections:
  840. raise ConnectionError("Too many connections")
  841. self._created_connections += 1
  842. return self.connection_class(**self.connection_kwargs)
  843. def release(self, connection):
  844. "Releases the connection back to the pool"
  845. self._checkpid()
  846. if connection.pid != self.pid:
  847. return
  848. self._in_use_connections.remove(connection)
  849. self._available_connections.append(connection)
  850. def disconnect(self):
  851. "Disconnects all connections in the pool"
  852. all_conns = chain(self._available_connections,
  853. self._in_use_connections)
  854. for connection in all_conns:
  855. connection.disconnect()
  856. class BlockingConnectionPool(ConnectionPool):
  857. """
  858. Thread-safe blocking connection pool::
  859. >>> from redis.client import Redis
  860. >>> client = Redis(connection_pool=BlockingConnectionPool())
  861. It performs the same function as the default
  862. ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
  863. it maintains a pool of reusable connections that can be shared by
  864. multiple redis clients (safely across threads if required).
  865. The difference is that, in the event that a client tries to get a
  866. connection from the pool when all of connections are in use, rather than
  867. raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
  868. ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
  869. makes the client wait ("blocks") for a specified number of seconds until
  870. a connection becomes available.
  871. Use ``max_connections`` to increase / decrease the pool size::
  872. >>> pool = BlockingConnectionPool(max_connections=10)
  873. Use ``timeout`` to tell it either how many seconds to wait for a connection
  874. to become available, or to block forever:
  875. # Block forever.
  876. >>> pool = BlockingConnectionPool(timeout=None)
  877. # Raise a ``ConnectionError`` after five seconds if a connection is
  878. # not available.
  879. >>> pool = BlockingConnectionPool(timeout=5)
  880. """
  881. def __init__(self, max_connections=50, timeout=20,
  882. connection_class=Connection, queue_class=LifoQueue,
  883. **connection_kwargs):
  884. self.queue_class = queue_class
  885. self.timeout = timeout
  886. super(BlockingConnectionPool, self).__init__(
  887. connection_class=connection_class,
  888. max_connections=max_connections,
  889. **connection_kwargs)
  890. def reset(self):
  891. self.pid = os.getpid()
  892. self._check_lock = threading.Lock()
  893. # Create and fill up a thread safe queue with ``None`` values.
  894. self.pool = self.queue_class(self.max_connections)
  895. while True:
  896. try:
  897. self.pool.put_nowait(None)
  898. except Full:
  899. break
  900. # Keep a list of actual connection instances so that we can
  901. # disconnect them later.
  902. self._connections = []
  903. def make_connection(self):
  904. "Make a fresh connection."
  905. connection = self.connection_class(**self.connection_kwargs)
  906. self._connections.append(connection)
  907. return connection
  908. def get_connection(self, command_name, *keys, **options):
  909. """
  910. Get a connection, blocking for ``self.timeout`` until a connection
  911. is available from the pool.
  912. If the connection returned is ``None`` then creates a new connection.
  913. Because we use a last-in first-out queue, the existing connections
  914. (having been returned to the pool after the initial ``None`` values
  915. were added) will be returned before ``None`` values. This means we only
  916. create new connections when we need to, i.e.: the actual number of
  917. connections will only increase in response to demand.
  918. """
  919. # Make sure we haven't changed process.
  920. self._checkpid()
  921. # Try and get a connection from the pool. If one isn't available within
  922. # self.timeout then raise a ``ConnectionError``.
  923. connection = None
  924. try:
  925. connection = self.pool.get(block=True, timeout=self.timeout)
  926. except Empty:
  927. # Note that this is not caught by the redis client and will be
  928. # raised unless handled by application code. If you want never to
  929. raise ConnectionError("No connection available.")
  930. # If the ``connection`` is actually ``None`` then that's a cue to make
  931. # a new connection to add to the pool.
  932. if connection is None:
  933. connection = self.make_connection()
  934. return connection
  935. def release(self, connection):
  936. "Releases the connection back to the pool."
  937. # Make sure we haven't changed process.
  938. self._checkpid()
  939. if connection.pid != self.pid:
  940. return
  941. # Put the connection back into the pool.
  942. try:
  943. self.pool.put_nowait(connection)
  944. except Full:
  945. # perhaps the pool has been reset() after a fork? regardless,
  946. # we don't want this connection
  947. pass
  948. def disconnect(self):
  949. "Disconnects all connections in the pool."
  950. for connection in self._connections:
  951. connection.disconnect()