# Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation.  The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA

"""MySQL X DevAPI Python implementation"""

import re
import json
import logging

from . import constants
from .compat import (INT_TYPES, STRING_TYPES, JSONDecodeError, urlparse,
                     unquote, parse_qsl)
from .connection import Client, Session
from .constants import Auth, LockContention, SSLMode
from .crud import Schema, Collection, Table, View
from .dbdoc import DbDoc
# pylint: disable=W0622
from .errors import (Error, InterfaceError, DatabaseError, NotSupportedError,
                     DataError, IntegrityError, ProgrammingError,
                     OperationalError, InternalError, PoolError, TimeoutError)
from .result import (Column, Row, Result, BufferingResult, RowResult,
                     SqlResult, DocResult, ColumnType)
from .statement import (Statement, FilterableStatement, SqlStatement,
                        FindStatement, AddStatement, RemoveStatement,
                        ModifyStatement, SelectStatement, InsertStatement,
                        DeleteStatement, UpdateStatement,
                        CreateCollectionIndexStatement, Expr, ReadStatement,
                        WriteStatement)

from .expr import ExprParser as expr

# Splits an address list on commas that are not inside parentheses, so the
# "," in "(address=host:port,priority=n)" stays part of its own entry.
_SPLIT = re.compile(r',(?![^\(\)]*\))')
# Matches a single "(address=<address>,priority=<digits>)" entry and captures
# the address and the priority. _parse_address_list strips spaces before
# matching; re.VERBOSE has no effect since the pattern has no whitespace/"#".
_PRIORITY = re.compile(r'^\(address=(.+),priority=(\d+)\)$', re.VERBOSE)
# Option names holding SSL settings; in a URI their values are unquoted and
# stripped of parentheses instead of being coerced to booleans.
_SSL_OPTS = ["ssl-cert", "ssl-ca", "ssl-key", "ssl-crl"]
# Every option name a settings dict may contain; _validate_settings rejects
# anything else with ProgrammingError.
_SESS_OPTS = _SSL_OPTS + ["user", "password", "schema", "host", "port",
                          "routers", "socket", "ssl-mode", "auth", "use-pure",
                          "connect-timeout", "connection-attributes"]

# Library logging convention: a NullHandler on the package logger keeps it
# quiet unless the application configures logging itself.
logging.getLogger(__name__).addHandler(logging.NullHandler())


def _parse_address_list(path):
    """Parses a list of host, port pairs

    ``path`` is either a single address (``host``, ``host:port``,
    ``[ipv6]:port``) or a bracketed, comma separated list whose entries may
    carry a priority: ``[(address=host:port,priority=n), host2:port2]``.
    Spaces are ignored.

    Args:
        path: String containing a list of routers or just router

    Returns:
        Returns a dict with parsed values of host, port and priority if
        specified. For a list the result is ``{"routers": [router, ...]}``;
        for a single address it is the router dict itself. ``port`` is
        ``None`` when the address does not specify one.

    Raises:
        InterfaceError: If an address has no host or an invalid port.
    """
    path = path.replace(" ", "")

    # A lone bracketed IPv6 address ("[::1]" or "[::1]:33060") also starts
    # with "[", but it is a single address, not a list.
    single_ipv6 = ("," not in path and path.count(":") > 1
                   and path.count("[") == 1)
    array = (not single_ipv6 and path.startswith("[")
             and path.endswith("]"))

    routers = []
    for address in _SPLIT.split(path[1:-1] if array else path):
        router = {}

        match = _PRIORITY.match(address)
        if match:
            address = match.group(1)
            router["priority"] = int(match.group(2))

        parsed = urlparse("//{0}".format(address))
        if not parsed.hostname:
            raise InterfaceError("Invalid address: {0}".format(address))
        try:
            # .port raises ValueError for a non-numeric or out-of-range
            # port; report it like any other bad address.
            port = parsed.port
        except ValueError:
            raise InterfaceError("Invalid address: {0}".format(address))

        router.update(host=parsed.hostname, port=port)
        routers.append(router)

    return {"routers": routers} if array else routers[0]


def _parse_connection_uri(uri):
    """Parses the connection string and returns a dictionary with the
    connection settings.

    Args:
        uri: mysqlx URI scheme to connect to a MySQL server/farm.

    Returns:
        Returns a dict with parsed values of credentials and address of the
        MySQL server/farm.
    """
    settings = {"schema": ""}
    uri = "{0}{1}".format("" if uri.startswith("mysqlx://")
                          else "mysqlx://", uri)
    _, temp = uri.split("://", 1)
    userinfo, temp = temp.partition("@")[::2]
    host, query_str = temp.partition("?")[::2]

    pos = host.rfind("/")
    if host[pos:].find(")") == -1 and pos > 0:
        host, settings["schema"] = host.rsplit("/", 1)
    host = host.strip("()")

    if not host or not userinfo or ":" not in userinfo:
        raise InterfaceError("Malformed URI '{0}'".format(uri))
    user, password = userinfo.split(":", 1)
    settings["user"], settings["password"] = unquote(user), unquote(password)

    if host.startswith(("/", "..", ".")):
        settings["socket"] = unquote(host)
    elif host.startswith("\\."):
        raise InterfaceError("Windows Pipe is not supported.")
    else:
        settings.update(_parse_address_list(host))

    for key, val in parse_qsl(query_str, True):
        opt = key.replace("_", "-").lower()
        if opt in settings:
            raise InterfaceError("Duplicate option '{0}'.".format(key))
        if opt in _SSL_OPTS:
            settings[opt] = unquote(val.strip("()"))
        else:
            val_str = val.lower()
            if val_str in ("1", "true"):
                settings[opt] = True
            elif val_str in ("0", "false"):
                settings[opt] = False
            else:
                settings[opt] = val_str
    return settings


def _validate_settings(settings):
    """Check, and normalize in place, the settings for a Session object.

    Unknown option names are rejected first. Host entries (every router in
    ``settings["routers"]``, or the settings themselves when they carry a
    ``host``) go through ``_validate_hosts``, which converts port/priority
    to ``int`` and defaults the port to 33060. ``ssl-mode`` and ``auth``
    are lower-cased and checked against ``SSLMode``/``Auth``, and the SSL
    options are checked for consistency with each other and the mode.

    Args:
        settings: dict containing connection settings.

    Raises:
        ProgrammingError: If an unknown option name is present.
        InterfaceError: If ssl-mode or auth is invalid or the SSL options
            are inconsistent; the host and connection-attributes helpers
            may raise it as well.
    """
    unknown = set(settings.keys()).difference(_SESS_OPTS)
    if unknown:
        raise ProgrammingError("Invalid options: {0}."
                               "".format(", ".join(unknown)))

    if "routers" in settings:
        host_entries = settings["routers"]
    elif "host" in settings:
        host_entries = [settings]
    else:
        host_entries = []
    for entry in host_entries:
        _validate_hosts(entry)

    def _normalize_choice(key, choices, message):
        # Lower-case settings[key] in place and require it to be one of
        # `choices`; a non-string value surfaces as AttributeError.
        try:
            settings[key] = settings[key].lower()
            choices.index(settings[key])
        except (AttributeError, ValueError):
            raise InterfaceError(message.format(settings[key]))

    if "ssl-mode" in settings:
        _normalize_choice("ssl-mode", SSLMode, "Invalid SSL Mode '{0}'.")
        disabled = settings["ssl-mode"] == SSLMode.DISABLED
        if disabled and any(opt in settings for opt in _SSL_OPTS):
            raise InterfaceError("SSL options used with ssl-mode 'disabled'.")

    if "ssl-crl" in settings and "ssl-ca" not in settings:
        raise InterfaceError("CA Certificate not provided.")
    if "ssl-key" in settings and "ssl-cert" not in settings:
        raise InterfaceError("Client Certificate not provided.")

    verify_modes = (SSLMode.VERIFY_IDENTITY, SSLMode.VERIFY_CA)
    verifies_server = settings.get("ssl-mode") in verify_modes
    has_ca = "ssl-ca" in settings
    if not has_ca and verifies_server:
        raise InterfaceError("Cannot verify Server without CA.")
    if has_ca and not verifies_server:
        raise InterfaceError("Must verify Server if CA is provided.")

    if "auth" in settings:
        _normalize_choice("auth", Auth, "Invalid Auth '{0}'")

    if "connection-attributes" in settings:
        validate_connection_attributes(settings)


def _validate_hosts(settings):
    """Validate hosts.

    Args:
        settings (dict): Settings dictionary.

    Raises:
        :class:`mysqlx.InterfaceError`: If priority or port are invalid.
    """
    if "priority" in settings and settings["priority"]:
        try:
            settings["priority"] = int(settings["priority"])
        except NameError:
            raise InterfaceError("Invalid priority")

    if "port" in settings and settings["port"]:
        try:
            settings["port"] = int(settings["port"])
        except NameError:
            raise InterfaceError("Invalid port")
    elif "host" in settings:
        settings["port"] = 33060


def _parse_connection_attribute_pairs(pairs):
    """Build an attributes dict from ``"name"`` / ``"name=value"`` items.

    Empty items are skipped. An item without ``=`` maps to an empty value.
    Only the first ``=`` separates name from value, so values may contain
    ``=`` themselves.

    Args:
        pairs (iterable): Items of the form ``"name"`` or ``"name=value"``.

    Returns:
        dict: Mapping of attribute name to value.

    Raises:
        :class:`mysqlx.InterfaceError`: If an item is not a string or a name
                                        appears twice.
    """
    attributes = {}
    for pair in pairs:
        if pair == "":
            continue
        if not isinstance(pair, STRING_TYPES):
            raise InterfaceError("connection-attributes items must be "
                                 "'key=value' strings, found: '{}'"
                                 "".format(pair))
        attr_name, _, attr_val = pair.partition("=")
        if attr_name in attributes:
            raise InterfaceError("Duplicate key '{}' used in "
                                 "connection-attributes"
                                 "".format(attr_name))
        attributes[attr_name] = attr_val
    return attributes


def validate_connection_attributes(settings):
    """Validate connection-attributes.

    Normalizes ``settings["connection-attributes"]`` in place:

    * ``""``, ``True``/``1`` and ``"True"``/``"true"`` become ``{}``;
    * ``False``/``0`` and ``"False"``/``"false"`` become ``False``;
    * a ``"[k1=v1,k2]"`` string, a list of ``"k=v"`` strings, a set of names
      or a dict becomes a dict of name -> string value (dict values that
      are not strings are converted with ``repr``).

    Nothing happens when the key is absent.

    Args:
        settings (dict): Settings dictionary.

    Raises:
        :class:`mysqlx.InterfaceError`: If the value has an unsupported form,
            a key is duplicated, a name is not a string, starts with ``_`` or
            exceeds 32 characters, or a value exceeds 1024 characters.
    """
    if "connection-attributes" not in settings:
        return

    conn_attrs = settings["connection-attributes"]

    if isinstance(conn_attrs, STRING_TYPES):
        if conn_attrs == "":
            settings["connection-attributes"] = {}
            return
        if conn_attrs in ("False", "false", "True", "true"):
            if conn_attrs in ("False", "false"):
                settings["connection-attributes"] = False
            else:
                settings["connection-attributes"] = {}
            return
        if not (conn_attrs.startswith("[") and conn_attrs.endswith("]")):
            raise InterfaceError("connection-attributes must be Boolean or a "
                                 "list of key-value pairs, found: '{}'"
                                 "".format(conn_attrs))
        attributes = _parse_connection_attribute_pairs(
            conn_attrs[1:-1].split(","))
    elif isinstance(conn_attrs, dict):
        attributes = {}
        for attr_name, attr_value in conn_attrs.items():
            if not isinstance(attr_value, STRING_TYPES):
                attr_value = repr(attr_value)
            attributes[attr_name] = attr_value
    elif isinstance(conn_attrs, bool) or conn_attrs in [0, 1]:
        settings["connection-attributes"] = {} if conn_attrs else False
        return
    elif isinstance(conn_attrs, set):
        attributes = dict.fromkeys(conn_attrs, "")
    elif isinstance(conn_attrs, list):
        attributes = _parse_connection_attribute_pairs(conn_attrs)
    else:
        raise InterfaceError("connection-attributes must be Boolean or a list "
                             "of key-value pairs, found: '{}'"
                             "".format(conn_attrs))

    for attr_name, attr_value in attributes.items():
        # Validate name type
        if not isinstance(attr_name, STRING_TYPES):
            raise InterfaceError("Attribute name '{}' must be a string "
                                 "type".format(attr_name))
        # Validate attribute name limit 32 characters
        if len(attr_name) > 32:
            raise InterfaceError("Attribute name '{}' exceeds 32 "
                                 "characters limit size.".format(attr_name))
        # Validate names in connection-attributes cannot start with "_"
        if attr_name.startswith("_"):
            raise InterfaceError("Key names in connection-attributes "
                                 "cannot start with '_', found: '{}'"
                                 "".format(attr_name))

        # Validate value type
        if not isinstance(attr_value, STRING_TYPES):
            raise InterfaceError("Attribute '{}' value: '{}' must "
                                 "be a string type."
                                 "".format(attr_name, attr_value))
        # Validate attribute value limit 1024 characters
        if len(attr_value) > 1024:
            raise InterfaceError("Attribute '{}' value: '{}' "
                                 "exceeds 1024 characters limit size"
                                 "".format(attr_name, attr_value))

    settings["connection-attributes"] = attributes


def _get_connection_settings(*args, **kwargs):
    """Parses the connection string and returns a dictionary with the
    connection settings.

    Args:
        *args: Variable length argument list with the connection data used
               to connect to the database. Only the first item is used; it
               can be a dictionary or a connection string.
        **kwargs: Arbitrary keyword arguments with connection data used to
                  connect to the database. Ignored when a positional
                  argument is given.

    Returns:
        dict: The validated connection settings, with ``_`` in option names
              replaced by ``-``.

    Raises:
        InterfaceError: If no usable settings are provided, or a setting is
            invalid.
        ProgrammingError: If an unknown option name is given.
        TypeError: If connection timeout is not a positive integer.
    """
    settings = {}
    if args:
        source = args[0]
        if isinstance(source, STRING_TYPES):
            settings = _parse_connection_uri(source)
        elif isinstance(source, dict):
            settings = dict((key.replace("_", "-"), val)
                            for key, val in source.items())
    elif kwargs:
        settings = dict((key.replace("_", "-"), val)
                        for key, val in kwargs.items())

    if not settings:
        raise InterfaceError("Settings not provided")

    if "connect-timeout" in settings:
        timeout = settings["connect-timeout"]
        try:
            if isinstance(timeout, STRING_TYPES):
                timeout = int(timeout)
            elif isinstance(timeout, bool):
                # The URI parser turns "connect-timeout=0"/"=1" into
                # False/True; restore the integer the user wrote.
                timeout = int(timeout)
            if not isinstance(timeout, INT_TYPES) or timeout < 0:
                raise ValueError
        except ValueError:
            raise TypeError("The connection timeout value must be a positive "
                            "integer (including 0)")
        settings["connect-timeout"] = timeout

    _validate_settings(settings)
    return settings


def get_session(*args, **kwargs):
    """Build a Session from connection data.

    The connection data is parsed and validated into a settings dict, and
    that dict is handed to :class:`mysqlx.Session`.

    Args:
        *args: Variable length argument list with the connection data used
               to connect to a MySQL server. It can be a dictionary or a
               connection string.
        **kwargs: Arbitrary keyword arguments with connection data used to
                  connect to the database.

    Returns:
        mysqlx.Session: Session object.
    """
    return Session(_get_connection_settings(*args, **kwargs))


def get_client(connection_string, options_string):
    """Creates a Client instance with the provided connection data and settings.

    Args:
        connection_string: A string or a dict type object to indicate the \
            connection data used to connect to a MySQL server.

            The string must have the following uri format::

                cnx_str = 'mysqlx://{user}:{pwd}@{host}:{port}'
                cnx_str = ('mysqlx://{user}:{pwd}@['
                           '    (address={host}:{port}, priority=n),'
                           '    (address={host}:{port}, priority=n), ...]'
                           '       ?[option=value]')

            And the dictionary::

                cnx_dict = {
                    'host': 'The host where the MySQL product is running',
                    'port': '(int) the port number configured for X protocol',
                    'user': 'The user name account',
                    'password': 'The password for the given user account',
                    'ssl-mode': 'The flags for ssl mode in mysqlx.SSLMode.FLAG',
                    'ssl-ca': 'The path to the ca.cert',
                    "connect-timeout": '(int) milliseconds to wait on timeout'
                }

        options_string: A string in the form of a document or a dictionary \
            type with configuration for the client.

            Current options include::

                options = {
                    'pooling': {
                        'enabled': (bool), # [True | False], True by default
                        'max_size': (int), # Maximum connections per pool
                            # (25 by default)
                        "max_idle_time": (int), # milliseconds that a
                            # connection will remain active while not in use.
                            # By default 0, means infinite.
                        "queue_timeout": (int), # milliseconds a request will
                            # wait for a connection to become available.
                            # By default 0, means infinite.
                    }
                }

            ``pooling`` is the only supported top-level option. The caller's
            dict is not modified.

    Returns:
        mysqlx.Client: Client object.

    Raises:
        InterfaceError: If either argument has the wrong type, the options
            string is not a valid document, or the options contain
            unrecognized entries.

    .. versionadded:: 8.0.13
    """
    if not isinstance(connection_string, (STRING_TYPES, dict)):
        raise InterfaceError("connection_data must be a string or dict")

    settings_dict = _get_connection_settings(connection_string)

    if not isinstance(options_string, (STRING_TYPES, dict)):
        raise InterfaceError("connection_options must be a string or dict")

    if isinstance(options_string, STRING_TYPES):
        try:
            options_dict = json.loads(options_string)
        except JSONDecodeError:
            raise InterfaceError("'pooling' options must be given in the form "
                                 "of a document or dict")
    else:
        options_dict = dict((key.replace("-", "_"), value)
                            for key, value in options_string.items())

    if not isinstance(options_dict, dict):
        raise InterfaceError("'pooling' options must be given in the form of a "
                             "document or dict")
    pooling_options_dict = {}
    if "pooling" in options_dict:
        pooling_options = options_dict.pop("pooling")
        if not isinstance(pooling_options, dict):
            raise InterfaceError("'pooling' options must be given in the form "
                                 "of a document or dict")
        # Work on a copy: the pops below would otherwise empty the caller's
        # nested dict, and reusing it for another client would silently
        # fall back to the defaults.
        pooling_options = dict(pooling_options)

        # Fill default pooling settings
        pooling_options_dict["enabled"] = pooling_options.pop("enabled", True)
        pooling_options_dict["max_size"] = pooling_options.pop("max_size", 25)
        pooling_options_dict["max_idle_time"] = \
            pooling_options.pop("max_idle_time", 0)
        pooling_options_dict["queue_timeout"] = \
            pooling_options.pop("queue_timeout", 0)

        # Only the four pooling settings above are supported
        if pooling_options:
            raise InterfaceError("Unrecognized pooling options: {}"
                                 "".format(pooling_options))

    # No other options besides pooling are supported, whether or not
    # "pooling" itself was given.
    if options_dict:
        raise InterfaceError("Unrecognized connection options: {}"
                             "".format(options_dict.keys()))

    return Client(settings_dict, pooling_options_dict)


# Public API of the package: the names bound by ``from mysqlx import *``.
# NOTE(review): ReadStatement and WriteStatement are imported above but not
# listed here, nor is validate_connection_attributes — confirm intentional.
__all__ = [
    # mysqlx.connection (get_client/get_session are defined in this module;
    # expr is mysqlx.expr.ExprParser)
    "Client", "Session", "get_client", "get_session", "expr",

    # mysqlx.constants
    "Auth", "LockContention", "SSLMode",

    # mysqlx.crud
    "Schema", "Collection", "Table", "View",

    # mysqlx.errors
    "Error", "InterfaceError", "DatabaseError", "NotSupportedError",
    "DataError", "IntegrityError", "ProgrammingError", "OperationalError",
    "InternalError", "PoolError", "TimeoutError",

    # mysqlx.result
    "Column", "Row", "Result", "BufferingResult", "RowResult",
    "SqlResult", "DocResult", "ColumnType",

    # mysqlx.dbdoc (DbDoc) and mysqlx.statement (the rest)
    "DbDoc", "Statement", "FilterableStatement", "SqlStatement",
    "FindStatement", "AddStatement", "RemoveStatement", "ModifyStatement",
    "SelectStatement", "InsertStatement", "DeleteStatement", "UpdateStatement",
    "CreateCollectionIndexStatement", "Expr",
]
