#!/usr/bin/env python3

import argparse
import os
import sys
from pathlib import Path

import repomd
import psycopg2


def connect_db(connstr):
    """Open a PostgreSQL connection for the given connection string.

    Returns the psycopg2 connection on success. If psycopg2 reports an
    error, the message is written to stderr and None is returned instead
    of raising.
    """
    try:
        return psycopg2.connect(connstr)
    except psycopg2.Error as err:
        sys.stderr.write("Error connecting to the Postgres server: %s\n" % err)

    # Only reached when the connection attempt failed
    return None


def get_repo_kind_id(cur, name):
    """Return the ID of the downloads_repokind row called *name*.

    The row is looked up first; if no row exists one is inserted and the
    newly generated ID is returned. All statements run on the supplied
    cursor; nothing is committed here.
    """
    params = (name,)

    cur.execute("""SELECT 
                       id 
                   FROM 
                       downloads_repokind 
                   WHERE
                       name=%s;""", params)
    existing = cur.fetchone()

    # fetchone() hands back a single-column tuple, so unwrap the ID
    if existing is not None:
        return existing[0]

    # Unknown repo kind: create it and use the ID the database assigned
    cur.execute("""INSERT INTO 
                           downloads_repokind (name) 
                       VALUES (%s)
                       RETURNING id;""", params)
    created = cur.fetchone()

    return created[0]


def _get_or_create_id(cur, known_ids, key, insert_sql, params):
    """Return the row ID cached under *key*, inserting the row if needed.

    ``insert_sql`` must be an ``INSERT ... RETURNING`` statement whose first
    returned column is the new row's ID. A newly created ID is stored in
    ``known_ids`` so later packages reuse it without another INSERT.
    """
    if key not in known_ids:
        cur.execute(insert_sql, params)
        known_ids[key] = cur.fetchone()[0]

    return known_ids[key]


def load_packages(conn, repo_type, packages):
    """Replace the stored packages of one repository kind with *packages*.

    All existing downloads_repopackage rows belonging to ``repo_type`` are
    deleted and the given packages are inserted, creating any missing
    distribution, repository and architecture rows along the way. The whole
    operation is committed as a single transaction; if anything fails the
    transaction is rolled back and the exception is re-raised.

    Parameters:
        conn: an open database connection (provides cursor(), commit() and
            rollback()).
        repo_type: name of the repository kind, e.g. 'apt'.
        packages: iterable of dicts with the keys 'pkg', 'ver', 'bld',
            'desc', 'lic', 'mntr', 'url', 'file', 'arch', 'repo' and 'dist'.
    """
    cur = conn.cursor()

    try:
        # Get the ID for the repo type. If it's not present, create it
        repo_kind_id = get_repo_kind_id(cur, repo_type)

        # Delete all the existing records. We can't really attempt to do a
        # merge as we need to ensure records get removed when packages are
        # removed
        cur.execute("""DELETE FROM 
                           downloads_repopackage rp 
                       USING 
                           downloads_reporepository rr, 
                           downloads_repodistribution rd 
                       WHERE 
                           rp.repo_id = rr.id AND 
                           rr.distro_id = rd.id AND 
                           rd.kind_id=%s;""",
                    (repo_kind_id,))

        # Index the existing distros, repos and architectures by their
        # natural key. Columns are named explicitly so the lookups do not
        # depend on the physical column order of the tables, and
        # setdefault() keeps the first row should duplicates exist.
        cur.execute("""SELECT id, kind_id, name
                       FROM downloads_repodistribution;""")
        distro_ids = {}
        for row_id, kind_id, name in cur.fetchall():
            distro_ids.setdefault((kind_id, name), row_id)

        cur.execute("""SELECT id, distro_id, name
                       FROM downloads_reporepository;""")
        repo_ids = {}
        for row_id, distro_id, name in cur.fetchall():
            repo_ids.setdefault((distro_id, name), row_id)

        cur.execute("""SELECT id, name
                       FROM downloads_repoarchitecture;""")
        arch_ids = {}
        for row_id, name in cur.fetchall():
            arch_ids.setdefault(name, row_id)

        # Loop round all the packages, and insert the details into the DB
        for package in packages:
            # Create the distribution if needed
            distro_id = _get_or_create_id(
                cur, distro_ids, (repo_kind_id, package['dist']),
                """INSERT INTO 
                       downloads_repodistribution (kind_id, name) 
                   VALUES (%s, %s) 
                   RETURNING id, kind_id, name;""",
                (repo_kind_id, package['dist'],))

            # Create the repository if needed
            repo_id = _get_or_create_id(
                cur, repo_ids, (distro_id, package['repo']),
                """INSERT INTO 
                       downloads_reporepository (distro_id, name) 
                   VALUES (%s, %s)
                   RETURNING id, distro_id, name;""",
                (distro_id, package['repo'],))

            # Create the architecture if needed
            arch_id = _get_or_create_id(
                cur, arch_ids, package['arch'],
                """INSERT INTO 
                       downloads_repoarchitecture (name) 
                   VALUES (%s) 
                   RETURNING id, name;""",
                (package['arch'],))

            # A file that is already present is left untouched
            cur.execute("""INSERT INTO 
                               downloads_repopackage (
                                   name, 
                                   version, 
                                   build, 
                                   description, 
                                   licence, 
                                   maintainer, 
                                   url, 
                                   file, 
                                   arch_id,  
                                   repo_id
                               ) 
                           VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                           ON CONFLICT (file) DO NOTHING""",
                        (package['pkg'],
                         package['ver'],
                         package['bld'],
                         package['desc'],
                         package['lic'],
                         package['mntr'],
                         package['url'],
                         package['file'],
                         arch_id,
                         repo_id))

        conn.commit()
    except Exception:
        # Don't leave a half-applied delete/insert pending on the connection
        conn.rollback()
        raise
    finally:
        cur.close()


def get_apt_package_files(repo_dir):
    """Return absolute paths of every entry named 'Packages' under repo_dir.

    The directory tree is searched recursively; the result is a list of
    pathlib.Path objects.
    """
    return [found.absolute() for found in Path(repo_dir).rglob('Packages')]


def get_apt_distribution(package_file):
    """Return the distribution name for a Packages file, or None.

    The name is read from the "Archive: " line of the Release file that
    sits in the same directory as ``package_file``. The value is returned
    without the line's trailing newline or surrounding whitespace. None is
    returned when the Release file has no "Archive: " line.

    Raises OSError (e.g. FileNotFoundError) if the Release file cannot be
    opened.
    """
    prefix = "Archive: "

    # os.path.join keeps a bare relative name ("Packages") pointing at the
    # current directory rather than turning it into "/Release"
    release_file = os.path.join(os.path.dirname(package_file), 'Release')

    with open(release_file, "r") as release_data:
        for release_line in release_data:
            if release_line.startswith(prefix):
                return release_line[len(prefix):].strip()

    return None


def get_apt_licence(deb_file):
    """Return the licence of a .deb file; currently always an empty string.

    ``deb_file`` is accepted but not inspected yet.
    """
    # TODO: Figure out a sane way to get the license
    return ''


def get_apt_packages(package_file):
    """Split a Packages file into one text block per package.

    A new block starts at every line beginning with "Package: ". Each
    returned string holds the raw lines of one package and, when the
    distribution can be determined from the neighbouring Release file, a
    trailing "Distribution: <name>" line.

    The last package in the file is included, and chunks that contain only
    whitespace (e.g. blank lines before the first package) are skipped.
    """
    stanzas = []
    current = []

    with open(package_file, "r") as package_data:
        for line in package_data:
            # A new "Package: " line closes the stanza collected so far
            if line.startswith('Package: ') and current:
                stanzas.append(''.join(current))
                current = []

            current.append(line)

    # The final stanza has no following "Package: " line to close it
    if current:
        stanzas.append(''.join(current))

    stanzas = [stanza for stanza in stanzas if stanza.strip()]
    if not stanzas:
        return []

    # The distribution is the same for every package in the file, so the
    # Release file only needs to be read once
    distribution = get_apt_distribution(package_file)
    if distribution is None:
        return stanzas

    suffix = '\nDistribution: ' + distribution
    return [stanza + suffix for stanza in stanzas]


def get_apt_package(package_data):
    """Decode one Packages entry into a dictionary.

    ``package_data`` is the text of a single package stanza. The result
    always has the keys 'pkg', 'ver', 'bld', 'arch', 'file', 'desc',
    'dist', 'url', 'mntr', 'repo' and 'lic'; fields missing from the
    stanza are left as empty strings.

    The version is split at the first '-' into version and build; the
    build keeps any further hyphens it contains. A version that instead
    carries a ".pgdg" suffix is split there, with the build starting at
    "pgdg".
    """
    package = {'pkg': '',
               'ver': '',
               'bld': '',
               'arch': '',
               'file': '',
               'desc': '',
               'dist': '',
               'url': '',
               'mntr': '',
               'repo': '',
               'lic': ''}
    in_description = False

    for line in package_data.splitlines():
        # Description. This can be multi-line. The first line is handled
        # as a normal field below; the rest is collected here until we hit
        # the end
        if in_description:
            if line.strip() == ".":
                # A lone "." marks an empty line in the description
                package['desc'] += "\n"
                continue

            if line.startswith(" "):
                package['desc'] += '\n' + line.strip()
                continue

            # The description ends when we encounter a line that doesn't
            # start with a space. That line is a field in its own right.
            in_description = False

        field, separator, value = line.partition(': ')
        if not separator:
            continue

        if field == 'Package':
            package['pkg'] = value

        elif field == 'Version':
            # The build is normally prefixed with a -, but sometimes
            # just .pgdg
            version, _, build = value.partition('-')

            if '.pgdg' in version:
                pieces = value.split('.pgdg')
                version = pieces[0]
                build = 'pgdg' + pieces[1]

            package['ver'] = version
            package['bld'] = build

        elif field == 'Architecture':
            package['arch'] = value

        elif field == 'Filename':
            package['file'] = value

            # Licence
            licence = get_apt_licence(value)
            if licence is not None:
                package['lic'] = licence

        elif field == 'Description':
            package['desc'] = value
            in_description = True

        elif field == 'Distribution':
            # Distribution/Repo: the distribution is the part before the
            # first '-', the repo is the full value
            package['dist'] = value.split('-')[0]
            package['repo'] = value

        elif field == 'Homepage':
            package['url'] = value

        elif field == 'Maintainer':
            package['mntr'] = value

    return package


def get_yum_repos(repo_dir):
    """Return the absolute path of every repo directory under repo_dir.

    A repo directory is taken to be the grandparent of each 'repomd.xml'
    found by a recursive search, i.e. the directory that contains the
    directory holding repomd.xml.
    """
    return [metadata.parent.parent.absolute()
            for metadata in Path(repo_dir).rglob('repomd.xml')]


def get_yum_distribution(repo):
    """Derive the distribution name from a repo directory path.

    The final path component is split on '-' and its first two pieces are
    joined back together, so 'rhel-7-x86_64' yields 'rhel-7'. Names with
    fewer than two pieces are returned unchanged.
    """
    leading_parts = repo.name.split('-')[:2]

    return '-'.join(leading_parts)


def get_yum_package_info(repo, base_path):
    """Build a package dictionary for every package in a yum-style repo.

    ``repo`` is the repo directory (a path object with a ``name``
    attribute) and ``base_path`` the directory the scan started from. The
    repo is loaded through repomd from a file:// URL and each entry is
    mapped to the same keys used for apt packages: 'pkg', 'ver', 'bld',
    'arch', 'file', 'desc', 'dist', 'url', 'mntr', 'repo' and 'lic'.
    """
    repo_path = str(repo)
    repo_data = repomd.load('file://' + repo_path)

    # These depend only on the repo location, not on the individual package
    distribution = get_yum_distribution(repo)
    # First path component of the repo, relative to the scanned directory
    repo_name = os.path.relpath(repo_path, base_path).split('/')[0]

    return [
        {
            'pkg': entry.name,
            'ver': entry.version,
            'bld': entry.release,
            'arch': entry.arch,
            # Package file location relative to the scanned directory
            'file': os.path.relpath(repo_path + '/' + entry.location,
                                    base_path),
            'desc': entry.description,
            'dist': distribution,
            'url': entry.url,
            'mntr': entry.vendor,
            'repo': repo_name,
            'lic': entry.license,
        }
        for entry in repo_data
    ]


def main():
    """Parse the command line, scan the requested repos and load them.

    At least one of --apt, --yum or --zypp must be given. Each requested
    repository tree is scanned, the number of packages found is printed,
    and the packages are loaded into the database named by --conn. Exits
    with status 1 if the database connection cannot be established. The
    connection is closed even if scanning or loading fails.
    """
    # Command line arguments
    parser = argparse.ArgumentParser(description='Scan a set of '
                                                 'apt/yum/zypper repos and '
                                                 'load the contents into '
                                                 'the pgweb database.')
    parser.add_argument('--conn', required=True,
                        help='a PostgreSQL connection string for a pgweb '
                             'database.')

    parser.add_argument('--apt', help='an apt repo directory to scan')
    parser.add_argument('--yum', help='a yum repo directory to scan')
    parser.add_argument('--zypp', help='a zypper repo directory to scan')

    args = parser.parse_args()

    if args.apt is None and args.yum is None and args.zypp is None:
        parser.error('at least one of the following arguments must be '
                     'specified: --apt, --yum, --zypp')

    conn = connect_db(args.conn)
    if conn is None:
        sys.exit(1)

    try:
        if args.apt:
            apt_package_info = []

            for package_file in get_apt_package_files(args.apt):
                for package in get_apt_packages(package_file):
                    apt_package_info.append(get_apt_package(package))

            print('{} apt packages found'.format(len(apt_package_info)))
            load_packages(conn, 'apt', apt_package_info)

        # yum and zypper repos share the same on-disk metadata handling;
        # only the repo kind stored in the DB and the printed label differ
        rpm_sources = (('yum', 'yum', args.yum),
                       ('zypp', 'zypper', args.zypp))

        for repo_type, label, repo_dir in rpm_sources:
            if not repo_dir:
                continue

            package_info = []
            for repo in get_yum_repos(repo_dir):
                package_info.extend(get_yum_package_info(repo, repo_dir))

            print('{} {} packages found'.format(len(package_info), label))
            load_packages(conn, repo_type, package_info)
    finally:
        # Fin
        conn.close()


# Run the loader only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
