#!/usr/bin/python3

# Copyright (c) 2020, SUSE LLC, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public
# License along with this library.

"""This script obtains information from the configured region server in the
   cloud environment and uses the information to register the guest with
   the SMT server based on the information provided by the region server.

   The configuration is in INI format and is located in
   /etc/regionserverclnt.cfg

   Logic:
   1.) Check if we are in the same region
       + Comparing information received from the region server and the
         cached data
   2.) Check if already registered
   3.) Register"""

import argparse
import json
import logging
import os
import requests
import subprocess
import sys
import time
import urllib.parse
import urllib3
import uuid

import cloudregister.registerutils as utils

from cloudregister import smt
from lxml import etree
from requests.auth import HTTPBasicAuth

# Disable the urllib warnings
# We have server certs that have no subject alt names
# We have to check the server state API without certificate validation
urllib3.disable_warnings()

# ----------------------------------------------------------------------------
def get_products():
    """Get a list of the installed products.

    Waits for a running zypper process to finish, queries zypper for the
    installed products and collects every product except the base product.

    Returns:
        A list of unique "name/version/arch" strings. The list is empty
        when the zypper lock could not be acquired, the zypper query
        failed, no base product link exists, or only the base product is
        installed.
    """
    products = []
    # It is possible for users to force a zypper process before we had
    # a chance to set up the repos. We'll wait for the lock for a little while
    for wait_cnt in range(4):
        if utils.is_zypper_running():
            time.sleep(5 * wait_cnt)
        else:
            break
    else:
        errMsg = 'Wait time expired could not acquire zypper lock file'
        logging.error(errMsg)
        return products

    try:
        cmd = subprocess.Popen(
            ["zypper", "--no-remote", "-x", "products"], stdout=subprocess.PIPE
        )
        product_xml = cmd.communicate()
    except OSError as err:
        # Popen itself may have failed, in which case there is no process
        # object to inspect; report the exception instead.
        errMsg = 'Could not get product list %s' % err
        logging.error(errMsg)
        return products

    # Just in case something else started zypper again
    if cmd.returncode != 0:
        errMsg = 'zypper product query returned with zypper code %d'
        logging.error(errMsg % cmd.returncode)
        return products

    # Determine the base product
    baseProdSet = '/etc/products.d/baseproduct'
    baseprodName = None
    if os.path.islink(baseProdSet):
        baseprod = os.path.realpath(baseProdSet)
        # The link target file name without its extension is the product name
        baseprodName = baseprod.split(os.sep)[-1].split('.')[0]
    else:
        errMsg = 'No baseproduct installed system cannot be registerd'
        logging.error(errMsg)
        return products

    product_tree = etree.fromstring(product_xml[0].decode())
    product_list = product_tree.find("product-list")
    if product_list is None:
        # Nothing to iterate over, treat as no additional products
        return products

    for child in product_list:
        name = child.attrib['name']
        # The base product is registered separately, skip it here
        if name == baseprodName:
            continue
        vers = child.attrib['version']
        arch = child.attrib['arch']
        prod = name + "/" + vers + "/" + arch
        if prod not in products:
            products.append(prod)

    return products

# ----------------------------------------------------------------------------
def register_modules(extensions, products, registered=None):
    """Register modules obeying dependencies.

    Walks the extension tree depth first: an extension is handled before
    the extensions nested under it. An extension is registered when its
    "identifier/version/arch" triplet is in products and has not been
    handled before. The registration command, the target server and the
    instance data file path are read from module scope (register_cmd,
    registration_target, instance_data_filepath).

    Args:
        extensions: Iterable of extension dictionaries providing the keys
            'identifier', 'version' and 'arch', and optionally a nested
            'extensions' list. None is treated as no extensions.
        products: Collection of installed product triplets.
        registered: List of triplets already handled. It is extended in
            place; a fresh list is created when omitted.
    """
    if registered is None:
        # A fresh list per top level call, state must not leak between calls
        registered = []

    for extension in extensions or []:
        arch = extension.get('arch')
        identifier = extension.get('identifier')
        version = extension.get('version')
        triplet = '/'.join((identifier, version, arch))
        if triplet in products and triplet not in registered:
            # Record the attempt up front so the triplet is only tried once
            registered.append(triplet)
            cmd = [
                register_cmd,
                '--url',
                'https://%s' % registration_target.get_FQDN(),
                '--product',
                triplet
            ]
            if os.path.exists(instance_data_filepath):
                cmd.append('--instance-data')
                cmd.append(instance_data_filepath)

            logging.info('Registration: %s' % ' '.join(cmd))
            p = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )
            res = p.communicate()
            if p.returncode != 0:
                # The message is taken from stdout, not stderr
                error_message = res[0].decode()
                logging.error('\tRegistration failed: %s' % error_message)

        # Nested extensions are processed whether or not the parent was
        # registered in this call
        register_modules(
            extension.get('extensions'), products, registered
        )

# ----------------------------------------------------------------------------
# Command line interface. The parser gets its own name so the imported
# argparse module is not shadowed by the parser instance.
parser = argparse.ArgumentParser(description='Register on-demand instance')
parser.add_argument(
    '--clean',
    action='store_true',
    dest='clean_up',
    default=False,
    help='Clean up registration data'
)
parser.add_argument(
    '-d', '--delay',
    default=0,
    dest='delay_time',
    help='Delay the start of registartion by given value in seconds',
    type=int
)
# The default config file location is not set here, it is encoded in
# utils.get_config()
parser.add_argument(
    '-f', '--config-file',
    dest='config_file',
    help='Path to config file, default: /etc/regionserverclnt.cfg',
)
parser.add_argument(
    '--force-new',
    action='store_true',
    dest='force_new_registration',
    default=False,
    help='Force a new registration, for exceptional cases only',
)
# The three --smt-* options describe a user specified update server
help_msg = 'The target update server cert fingerprint. '
help_msg += 'Use in exceptional cases only'
parser.add_argument(
    '--smt-fp',
    dest='user_smt_fp',
    help=help_msg
)
help_msg = 'The target update server FQDN. '
help_msg += 'Use in exceptional cases only'
parser.add_argument(
    '--smt-fqdn',
    dest='user_smt_fqdn',
    help=help_msg
)
help_msg = 'The target update server IP. '
help_msg += 'Use in exceptional cases only'
parser.add_argument(
    '--smt-ip',
    dest='user_smt_ip',
    help=help_msg
)

args = parser.parse_args()

# The update server override options are only valid as a complete set
smt_overrides = (args.user_smt_ip, args.user_smt_fqdn, args.user_smt_fp)
if any(smt_overrides) and not all(smt_overrides):
    print(
        '--smt-ip, --smt-fqdn, and --smt-fp must be used together',
        file=sys.stderr
    )
    sys.exit(1)

# Cleaning up and forcing a new registration contradict each other
if args.force_new_registration and args.clean_up:
    print(
        '--clean and --force-new are incompatible, use one or the other',
        file=sys.stderr
    )
    sys.exit(1)

# Honor the requested start delay before doing any work
time.sleep(int(args.delay_time))

# Clean up mode: remove the registration data, the update server cache and
# the new registration flag, then stop
if args.clean_up:
    for clean_step in (
            utils.remove_registration_data,
            utils.clean_smt_cache,
            utils.clear_new_registration_flag
    ):
        clean_step()
    sys.exit(0)

# Resolve a user supplied config file path; without one the value is passed
# through unchanged to utils.get_config()
config_file = (
    os.path.expanduser(args.config_file) if args.config_file
    else args.config_file
)
cfg = utils.get_config(config_file)
utils.start_logging()

# Make sure the registration data directory exists
if not os.path.isdir(utils.REGISTRATION_DATA_DIR):
    os.makedirs(utils.REGISTRATION_DATA_DIR)

# Start out flagged as a new registration, the flag is cleared again below
# when cached update server data is found
utils.set_new_registration_flag()
if args.force_new_registration:
    logging.info('Forced new registration')

if args.user_smt_ip:
    logging.info(
        'Using user specified SMT server:\n'
        '\n\t"IP:%s"'
        '\n\t"FQDN:%s"'
        '\n\t"Fingerprint:%s"' % (
            args.user_smt_ip, args.user_smt_fqdn, args.user_smt_fp
        )
    )

cached_smt_servers = utils.get_available_smt_servers()
if cached_smt_servers:
    # If we have an update server cache the system is registered in
    # some way shape or form
    utils.clear_new_registration_flag()

# A forced registration on a system with cached servers, or a user
# specified update server, starts over from a clean slate
forced_with_cache = args.force_new_registration and cached_smt_servers
if forced_with_cache or args.user_smt_ip:
    if utils.is_zypper_running():
        print(
            'zypper is running: Registration with the update '
            'infrastructure is only possible if zypper is not running.\n'
            'Please re-run the force registration process after zypper '
            'has completed'
        )
        sys.exit(1)
    utils.remove_registration_data()
    utils.clean_smt_cache()
    utils.set_new_registration_flag()
    cached_smt_servers = []

# Proxy setup
# When utils.set_proxy() reports a proxy, the proxy related environment
# variables are collected into a dictionary and handed to the region server
# query.
# NOTE(review): the requests library expects the keys 'http' and 'https' in
# a proxies mapping; confirm how utils.fetch_smt_data() consumes this
# dictionary before changing the key names.
proxies = None
proxy = utils.set_proxy()
if proxy:
    http_proxy = os.environ.get('http_proxy')
    https_proxy = os.environ.get('https_proxy')
    no_proxy = os.environ.get('no_proxy')
    proxies = {'http_proxy': http_proxy,
               'https_proxy': https_proxy,
               'no_proxy': no_proxy}
    logging.info('Using proxy settings: %s' % proxies)

if args.user_smt_ip:
    # Build the same structure the region server delivers, one smtInfo
    # element below a regionSMTdata root. The elements are created directly
    # rather than parsed from an interpolated string so that characters
    # such as quotes or ampersands in the user supplied values cannot
    # produce malformed XML.
    region_smt_data = etree.Element('regionSMTdata')
    etree.SubElement(
        region_smt_data,
        'smtInfo',
        fingerprint=args.user_smt_fp,
        SMTserverIP=args.user_smt_ip,
        SMTserverName=args.user_smt_fqdn
    )
else:
    region_smt_data = utils.fetch_smt_data(cfg, proxies)

# The update server this system is currently registered with, as reported
# by utils.get_current_smt()
registration_smt = utils.get_current_smt()

# Check if we are in the same region
# This implies that at least one of the cached update servers matches the
# data received or the user provided update server data
region_smt_servers = {'cached': [], 'new': []}
# Compare the smt information received from the region server with the
# cached update server data

for child in region_smt_data:
    smt_server = smt.SMT(child, utils.https_only(cfg))
    for cached_smt in cached_smt_servers:
        if cached_smt == smt_server:
            # Matched entries move from the cached list to the 'cached'
            # bucket. Removing while iterating is safe here because the
            # loop is left immediately afterwards.
            cached_smt_servers.remove(cached_smt)
            region_smt_servers['cached'].append(cached_smt)
            break
    else:
        # No cached server matched, this server is new to this system
        region_smt_servers['new'].append(smt_server)

# If we have extra SMT data check if the extra server is the registration
# target if yes clean up the registration. Clean up the entire update
# server cache
# At this point cached_smt_servers only holds the cached entries that did
# not match any server in the received data
if cached_smt_servers:
    logging.info('Have extra cached SMT data, clearing cache')
    for smt_srv in cached_smt_servers:
        if registration_smt and smt_srv.is_equivalent(registration_smt):
            msg = 'Extra cached server is current registration target, '
            msg += 'cleaning up registration'
            logging.info(msg)
            utils.remove_registration_data()
            registration_smt = None
            break
    # Clean the cache and re-write all the cache data later
    # NOTE(review): only the servers in the 'new' bucket are written to the
    # cache below; confirm whether the matched 'cached' entries need to be
    # written again after the cache was cleaned.
    utils.clean_smt_cache()
    cached_smt_servers = []

if region_smt_servers['new']:
    # Create a new cache
    # Numbering of the cache files continues after the matched servers
    smt_count = len(region_smt_servers['cached']) + 1
    for smt_server in region_smt_servers['new']:
        store_file_name = (
            utils.REGISTRATION_DATA_DIR +
            utils.AVAILABLE_SMT_SERVER_DATA_FILE_NAME % smt_count
        )
        utils.store_smt_data(store_file_name, smt_server)
        smt_count += 1

# We are in a new region none of the previously cached update servers
# matched the new servers. The system needs to be re-registered
if not region_smt_servers['cached']:
    registration_smt = None
    utils.remove_registration_data()

# We no longer need to differentiate between new and existing SMT servers
# From here on region_smt_servers is a flat list, matched servers first
region_smt_servers = region_smt_servers['cached'] + region_smt_servers['new']

# Check if the target SMT for the registration is alive or if we can
# find a server that is alive in this region
if registration_smt:
    registration_smt_cache_file_name = (
        utils.REGISTRATION_DATA_DIR +
        utils.REGISTERED_SMT_SERVER_DATA_FILE_NAME
    )
    alive = registration_smt.is_responsive()
    if alive:
        msg = 'Instance is registered, and update server is reachable, '
        msg += 'nothing to do'
        # The cache data may have been cleared, write if necessary
        if not os.path.exists(registration_smt_cache_file_name):
            utils.store_smt_data(
                registration_smt_cache_file_name,
                registration_smt
            )
        logging.info(msg)
        sys.exit(0)
    else:
        # The configured server is not responsive, lets check if we can
        # find another server
        new_target = utils.find_equivalent_smt_server(
            registration_smt,
            region_smt_servers
        )
        if new_target:
            # Report the address that will be used, IPv6 when available
            smt_ip = new_target.get_ipv4()
            if utils.has_ipv6_access(new_target):
                smt_ip = new_target.get_ipv6()
            msg = 'Configured update server is unresponsive, switching '
            msg += 'to equivalent update server with ip %s' % smt_ip
            # Record the switch in the log, the message was previously
            # built but never emitted
            logging.info(msg)
            utils.replace_hosts_entry(registration_smt, new_target)
            # NOTE(review): this stores the data of the unresponsive
            # server rather than new_target; confirm this is intended.
            utils.store_smt_data(
                registration_smt_cache_file_name,
                registration_smt
            )
        else:
            msg = 'Configured update server is unresponsive. Could not find '
            msg += 'a replacement update server in this region. '
            msg += 'Possible network configuration issue'
            logging.error(msg)
            sys.exit(1)

# Figure out which server is responsive and use it as registration target
registration_target = None
tested_smt_servers = []
for smt_srv in region_smt_servers:
    # Remember every address pair that was tried for the error report
    tested_smt_servers.append((smt_srv.get_ipv4(), smt_srv.get_ipv6()))
    alive = smt_srv.is_responsive()
    if alive:
        registration_target = smt_srv
        # Use the first server that responds
        break

if not registration_target:
    logging.error('No response from: %s' % str(tested_smt_servers))
    sys.exit(1)

# Add the target SMT server to the hosts file
utils.add_hosts_entry(registration_target)

# Create location to store data if it does not exist
# Created in process instead of shelling out to "mkdir -p", which avoids
# passing the path through a shell
if not os.path.exists(utils.REGISTRATION_DATA_DIR):
    os.makedirs(utils.REGISTRATION_DATA_DIR, exist_ok=True)

# Write the data of the current target server
utils.set_as_current_smt(registration_target)

# Check if we need to send along any instance data
# The data is written to a uniquely named file in the registration data
# directory; the file is only created when there is data to send
instance_data_filepath = utils.REGISTRATION_DATA_DIR + str(uuid.uuid4())
instance_data = utils.get_instance_data(cfg)
if instance_data:
    # The context manager closes the file even if the write fails
    with open(instance_data_filepath, 'w') as inst_data_out:
        inst_data_out.write(instance_data)

register_cmd = '/usr/sbin/SUSEConnect'
if not (os.path.exists(register_cmd) and os.access(register_cmd, os.X_OK)):
    logging.error('No registration executable found')
    sys.exit(1)

# get product list
# NOTE(review): get_products() in this file returns a list on every path,
# so the None check below is not expected to trigger. An empty list is a
# valid result when only the base product is installed.
products = get_products()
if products is None:
    logging.error('No products installed on system')
    sys.exit(1)

if not utils.import_smt_cert(registration_target):
    logging.error('SMT certificate import failed')
    sys.exit(1)

# Register the base product first
# A failed attempt marks the current target as failed and moves on to the
# next update server in the region that has not failed yet. The script
# gives up once the number of failures equals the number of servers.
failed_smts = []
base_registered = False
while not base_registered:
    connect_cmd = [
        register_cmd,
        '--url',
        'https://%s' % registration_target.get_FQDN()
    ]
    if os.path.exists(instance_data_filepath):
        connect_cmd.extend(['--instance-data', instance_data_filepath])
    proc = subprocess.Popen(
        connect_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    proc_output = proc.communicate()

    if proc.returncode == 0:
        base_registered = True
        utils.clear_new_registration_flag()
        continue

    # Even on error SUSEConnect writes messages to stdout, go figure
    error_message = proc_output[0].decode()
    failed_smts.append(registration_target.get_ipv4())
    if len(failed_smts) == len(region_smt_servers):
        logging.error('Baseproduct registration failed')
        logging.error('\t%s' % error_message)
        sys.exit(1)

    for smt_srv in region_smt_servers:
        target_ips = (
            registration_target.get_ipv4(), registration_target.get_ipv6()
        )
        candidate_ips = (smt_srv.get_ipv4(), smt_srv.get_ipv6())
        # Skip the server that just failed and any server that failed before
        if candidate_ips[0] == target_ips[0] or candidate_ips[0] in failed_smts:
            continue
        logging.error(
            'Registration with %s failed. Trying %s' % (
                str(target_ips), str(candidate_ips)
            )
        )
        utils.remove_registration_data()
        utils.add_hosts_entry(smt_srv)
        registration_target = smt_srv
        break

# Read the base product definition to find out which product the update
# server has to be queried for. The context manager makes sure the file
# handle is closed again.
with open('/etc/products.d/baseproduct') as base_prod_file:
    base_prod_xml = base_prod_file.read()
# Parse from the start of the product element, anything in front of it is
# skipped
prod_def_start = base_prod_xml.index('<product')
product_tree = etree.fromstring(base_prod_xml[prod_def_start:])
prod_identifier = product_tree.find('name').text.lower()
version = product_tree.find('version').text
arch = product_tree.find('arch').text
headers = {'Accept': 'application/vnd.scc.suse.com.v4+json'}
query_args = 'identifier=%s&version=%s&arch=%s' % (
    prod_identifier, version, arch)
# Query the registration target with the credentials stored for it
user, password = utils.get_credentials(
    utils.get_credentials_file(registration_target)
)
auth_creds = HTTPBasicAuth(user, password)
res = requests.get(
    'https://%s/connect/systems/products?%s' % (
        registration_target.get_FQDN(), query_args
    ),
    auth=auth_creds,
    headers=headers
)
if res.status_code != 200:
    err_msg = 'Unable to obtain product information from server "%s"\n'
    err_msg += '\t%s\n\t%s\nUnable to register modules, exiting.'
    ips = '%s,%s' % (
        registration_target.get_ipv4(), registration_target.get_ipv6()
    )
    logging.error(err_msg % (ips, res.reason, res.content.decode("UTF-8")))
    sys.exit(1)

prod_data = json.loads(res.text)
# A response without an 'extensions' entry means there is nothing to
# register, hand over an empty list rather than None
extensions = prod_data.get('extensions') or []
register_modules(extensions, products)

# The instance data file is only needed while registering, drop it
if os.path.exists(instance_data_filepath):
    os.unlink(instance_data_filepath)

# Enable Nvidia repo if repo(s) are configured and destination can be reached
if utils.has_nvidia_support():
    for repo_name in utils.find_repos('nvidia'):
        repo_url = urllib.parse.urlparse(utils.get_repo_url(repo_name))
        ping_cmd = ['ping', '-c', '2', repo_url.hostname]
        # A false result from exec_subprocess is taken as host reachable
        if not utils.exec_subprocess(ping_cmd):
            utils.enable_repository(repo_name)
            continue
        logging.info(
            'Cannot reach host: "%s", will not enable repo "%s"' % (
                repo_url.hostname, repo_name
            )
        )

utils.switch_services_to_plugin()
