#!/opt/cloudlinux/venv/bin/python3 -bb
# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2021 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENCE.TXT
#
import doctest
import importlib
import logging
import os
import sys
# Make this script's directory importable so the tests
# can be run regardless of the current working directory.
sys.path.append(os.path.dirname(__file__))
import unittest
from unittest.loader import _make_failed_load_tests

import coverage
from tap import TAPTestRunner

from lvestats.lib.dbengine import make_db_engine, get_db_client_library_name
from lvestats.lib.parsers.run_tests_argparse import run_tests_parser, ALLOWED_DB_TYPES
from test.base import base_db_test
from test.base.base_db_test import BaseDb

# Dotted names of the modules whose doctests are collected by
# gathering_doctests() and added to the discovered test suite.
DOCTESTS = ['lvestats.lib.commons.func',
            'lvestats.lib.chart.dbgovchartmain',
            'lvestats.lib.commons.sizeutil']


class SortedTestLoader(unittest.TestLoader):
    """Test loader used by this runner to discover and filter test cases.

    On top of ``unittest.TestLoader`` it can restrict the loaded test cases
    to subclasses of ``base_db_test.BaseDbTest``.

    :param db_tests_only: when true, ``loadTestsFromModule`` skips every
        ``TestCase`` subclass that is not derived from ``BaseDbTest``.
    """

    def __init__(self, db_tests_only=False):
        self.db_tests_only = db_tests_only
        super().__init__()

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        :param start_dir: directory path or dotted module name to start from.
        :param pattern: shell-style pattern that test file names must match.
        :param top_level_dir: top level directory of the project; defaults to
            the previously stored one or, failing that, to ``start_dir``.
        :return: an instance of ``self.suiteClass`` with the discovered tests.
        :raises ImportError: if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                # a sub-directory is importable only if it is a package
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname(the_module.__file__))
                if set_implicit_top:
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests_ = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests_)

    def loadTestsFromModule(self, module, *args, pattern=None, use_load_tests=True, **kws):  # pylint: disable=arguments-differ
        """Return a suite of all test cases contained in the given module.

        Every ``unittest.TestCase`` subclass found among the module attributes
        is loaded; when ``self.db_tests_only`` is set, classes that do not
        derive from ``base_db_test.BaseDbTest`` are skipped.

        If the module defines ``load_tests`` and ``use_load_tests`` is true,
        the hook is called as ``load_tests(loader, tests, pattern)`` and its
        result is returned. If the hook raises, the error message is recorded
        in ``self.errors`` and a suite holding a single failing test is
        returned, so the problem shows up in the test report.

        :param module: module object to inspect.
        :param args: extra positional arguments; accepted and ignored.
        :param pattern: discovery pattern forwarded to ``load_tests``.
        :param use_load_tests: whether to honour the module's ``load_tests``.
        :param kws: extra keyword arguments; accepted and ignored.
        :return: an instance of ``self.suiteClass``, or whatever the module's
            ``load_tests`` hook returns.
        """
        suites = []
        for name in dir(module):
            obj = getattr(module, name)
            if not (isinstance(obj, type) and issubclass(obj, unittest.case.TestCase)):
                continue
            if self.db_tests_only and not issubclass(obj, base_db_test.BaseDbTest):
                continue  # Skip not DB tests if --db-tests-only argument is true
            suites.append(self.loadTestsFromTestCase(obj))

        tests = self.suiteClass(suites)
        load_tests = getattr(module, 'load_tests', None)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                # _make_failed_load_tests returns a (suite, message) pair;
                # only the suite may be handed back to the caller.
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests


def gathering_doctests():
    """Collect the doctests of every module listed in DOCTESTS.

    A single recursive ``DocTestFinder`` (which also keeps objects without
    examples) is shared by all modules.

    :return: a ``unittest.TestSuite`` holding the doctests of all modules.
    """
    finder = doctest.DocTestFinder(recurse=True, exclude_empty=False)
    suite = unittest.TestSuite()
    for module_name in DOCTESTS:
        module_suite = doctest.DocTestSuite(module_name, test_finder=finder)
        suite.addTests(module_suite)
    return suite


if __name__ == "__main__":
    # Imported explicitly: ``import importlib`` alone does not guarantee that
    # the ``importlib.util`` submodule used for the coverage report is loaded.
    import importlib.util

    parser = run_tests_parser()
    opts = parser.parse_args()

    cov = None
    if opts.coverage:
        cov = coverage.coverage()
        cov.start()

    # --quiet overrides any number of --verbose flags
    verbose_level = 0 if opts.quiet else opts.verbose + 1

    # CLTEST_DBURI=postgresql://postgres:sW%3EKMpW%24@127.0.0.1 /opt/alt/python37/bin/nosetests -a postgresql -v
    # CLTEST_DBURI=mysql://root:Jzb%2477%21JkS%7C%7D%3Bi%23%29b@localhost /opt/alt/python37/bin/nosetests -a mysql -v

    if opts.dbtype not in ALLOWED_DB_TYPES:
        print("Db type should be one of %s" % str(ALLOWED_DB_TYPES))
        sys.exit(1)

    if opts.path:
        sys.path.insert(0, os.path.normpath(os.path.abspath(opts.path)))

    if opts.dbtype != "sqlite" and not all((opts.login, opts.password)):
        print("login (-dl) and password (-dp) should be specified for %s database" % opts.dbtype)
        sys.exit(1)
    if verbose_level < 4:
        # silence all logging unless the highest verbosity was requested
        logging.disable(logging.CRITICAL)

    # Build the database URI: "<dbtype>+<library>://<login>:<password>@<host>"
    # for server databases, plain "<dbtype>://" for sqlite.
    if opts.dbtype != 'sqlite':
        dburi = "{dbtype}+{library}://".format(dbtype=opts.dbtype,
                                               library=get_db_client_library_name(opts.dbtype))
        dburi += "{login}:{password}@{host}".format(login=opts.login, password=opts.password, host=opts.host)
    else:
        dburi = "{dbtype}://".format(dbtype=opts.dbtype)
    if opts.dbname:
        dburi += '/' + opts.dbname

    BaseDb.init_database(dburi, drop_database_=False)
    BaseDb.set_superglobal_engine(make_db_engine(BaseDb.cfg))
    BaseDb.create_tables()

    tests = SortedTestLoader(db_tests_only=opts.db_only).discover(".", "{0}*".format(opts.pattern))
    tests.addTests(gathering_doctests())  # add some doctests
    try:
        runner = TAPTestRunner(verbosity=verbose_level)
        if opts.with_tap:
            runner.set_stream(True)

        # remove header
        runner.set_header(False)
        res = runner.run(tests)
    except KeyboardInterrupt:
        print('\nStopped')
        sys.exit(1)
    finally:
        # always clean up, whether the run finished, failed or was interrupted
        BaseDb.drop_trash()

    tests_failed = bool(res.errors or res.failures)

    if cov is not None:
        # Report coverage before exiting so that it is not lost on a failed run.
        cov.stop()
        cov.save()

        print("=" * 60)
        print("Coverage report:")
        lvestats_location = importlib.util.find_spec("lvestats").submodule_search_locations[0]
        include_mask = os.path.join(lvestats_location, '*')
        cov.report(file=sys.stdout, ignore_errors=True, show_missing=True, include=[include_mask])
        print("=" * 60)

    if tests_failed:
        sys.exit(1)
