#!/usr/bin/python3

# Copyright (c) 2019 SUSE LLC, All Rights Reserved
#
# slecompliancereport is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# slecompliancereport is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with slecompliancereport.

"""
This script will create a report of all SLES and SLES For SAP on-demand
instances running in a given account and report by instance ID and IP if
the instance is compliant with the terms of use. An on-demand instance
registered to the SUSE update infrastructure must be running with the
proper billing construct in order to be compliant with the update
infrastructure access requirements.

By default the tool will scan all the connected regions.
"""

import argparse
import boto3
import json
import paramiko
import sys
import time

from  more_itertools import unique_everseen

# Command line interface definition.
# The parser gets its own name so the argparse module is not shadowed by
# the parser instance.
parser = argparse.ArgumentParser(
    description='Generate SLE Compliance Report'
)
parser.add_argument(
    '--access-id',
    dest='access_key',
    help='AWS access key',
    metavar='AWS_ACCESS_KEY',
    required=True
)
parser.add_argument(
    '--client-update',
    action='store_true',
    default=False,
    dest='auto_update_client',
    help='Automatically update the registration client when needed'
)
parser.add_argument(
    '--instance-ids',
    dest='instance_ids',
    help='Comma separated list of instance-ids to check (Optional)',
    metavar='INSTANCE_IDS'
)
# NOTE: the destination name is misspelled ("instace_user") but the rest of
# the script reads args.instace_user, so the attribute name must not change.
parser.add_argument(
    '--instance-user',
    dest='instace_user',
    help='The user name for the user that can login with the given ssh key',
    metavar='INSTANCE_USER',
    required=True
)
parser.add_argument(
    '-p', '--private-key-file',
    dest='private_key',
    help='Private SSH key file',
    metavar='PRIVATE_KEY',
    required=True
)
# The help text is assembled from two pieces; the second piece has to be
# appended, a plain assignment would drop the first half of the message.
help_msg = 'Comma separated list of regions (Optional, '
help_msg += 'scan all connected regions by default)'
parser.add_argument(
    '-r', '--regions',
    dest='regions',
    help=help_msg,
    metavar='EC2_REGIONS'
)
parser.add_argument(
    '-s', '--secret-key',
    dest='secret_key',
    help='AWS secret access key',
    metavar='AWS_SECRET_KEY',
    required=True
)
parser.add_argument(
    '--verbose',
    action='store_true',
    default=False,
    dest='verbose',
    help='Enable verbose output (Optional)'
)

# Parsed options, consumed by the remainder of the script
args = parser.parse_args()

# Determine the regions to scan: the user supplied list when given,
# otherwise every EC2 region known to the installed boto3.
regions = []
if args.regions:
    # Tolerate blanks around the separators and ignore empty entries such
    # as the one produced by a trailing comma.
    regions = [
        name.strip() for name in args.regions.split(',') if name.strip()
    ]
else:
    # Some connected regions may be missing depending on the version of
    # boto3. Unfortunately the region list is not an API call to EC2 but
    # handled intrinsically in boto3 as a hard coded list.
    regions = boto3.session.Session(
        aws_access_key_id=args.access_key,
        aws_secret_access_key=args.secret_key
    ).get_available_regions('ec2')
    # Special region may not be available to everyone but we'll test it.
    # Only add it when boto3 does not already list it, a duplicate entry
    # would scan the region twice and reset the results of the first scan.
    if 'ap-northeast-3' not in regions:
        regions.append('ap-northeast-3')

# Oldest major version of cloud-regionsrv-client that is considered current
_MIN_CLIENT_MAJOR = 9


def _run_remote(ssh, command):
    """Run *command* on the connected instance.

    Returns a tuple ``(succeeded, output)`` where ``succeeded`` is True when
    the command exited with status 0 and ``output`` is the decoded, stripped
    text the command produced.
    """
    # The commands are run with a pty (presumably so sudo works on hosts
    # that require a tty - confirm). With a pty the remote stderr arrives on
    # the stdout stream, so an empty stderr says nothing about success; the
    # exit status is used instead.
    _stdin, stdout, _stderr = ssh.exec_command(command, get_pty=True)
    # Drain the output before asking for the exit status so a chatty
    # command cannot stall on a full channel buffer.
    output = stdout.read().decode('utf-8', 'replace').strip()
    return stdout.channel.recv_exit_status() == 0, output


def _instance_addresses(instance):
    """Return the unique, non-empty addresses of *instance* in probe order.

    Public, IPv6 and private addresses of every network interface are
    collected to make every effort to connect to the instance.
    """
    candidates = []
    for nic in instance.get('NetworkInterfaces') or []:
        candidates.append(nic.get('Association', {}).get('PublicIp'))
        for ipv6 in nic.get('Ipv6Addresses', []):
            candidates.append(ipv6.get('Ipv6Address'))
        candidates.append(nic.get('PrivateIpAddress'))
        for private_ip in nic.get('PrivateIpAddresses', []):
            candidates.append(
                private_ip.get('Association', {}).get('PublicIp')
            )
            candidates.append(private_ip.get('PrivateIpAddress'))
    return [address for address in unique_everseen(candidates) if address]


def _connect(ssh, addresses, user, key_file, verbose,
             attempts=2, retry_delay=10):
    """Open an SSH session to the first reachable entry of *addresses*.

    Every address is tried *attempts* times with *retry_delay* seconds
    between the tries. Returns True once a connection is established and
    False when no address could be reached (including an empty list).
    """
    for address in addresses:
        for attempt in range(attempts):
            if verbose:
                print('\tAttempt connection to: "%s"' % address)
            try:
                ssh.connect(
                    key_filename=key_file,
                    username=user,
                    hostname=address,
                    timeout=30
                )
            except Exception as err:
                # Deliberately broad: whatever the reason (network,
                # authentication, protocol) the reaction is the same,
                # retry and then move on to the next address.
                if verbose:
                    print(
                        '\tConnection to "%s" failed: %s' % (address, err)
                    )
                if attempt + 1 < attempts:
                    time.sleep(retry_delay)
            else:
                return True
    return False


def _detect_sle(ssh, base_cmd):
    """Return ``(is_sle, is_sle11)`` for the connected instance.

    /etc/os-release is consulted first, /etc/SuSE-release is the fallback
    for systems that do not have an os-release file.
    """
    is_sle = False
    is_sle11 = False
    found, release = _run_remote(ssh, base_cmd + 'cat /etc/os-release')
    if found:
        for entry in release.splitlines():
            entry = entry.strip()
            if entry.startswith('VERSION_ID'):
                # Accept quoted and unquoted values alike
                version = entry.split('=', 1)[-1].strip().strip('"\'')
                if version.startswith('11'):
                    is_sle11 = True
            elif entry.startswith('PRETTY_NAME'):
                if 'SUSE' in entry and 'Enterprise' in entry:
                    is_sle = True
        return is_sle, is_sle11

    # Has no os-release try the old file
    found, release = _run_remote(ssh, base_cmd + 'cat /etc/SuSE-release')
    if not found:
        return False, False
    # Inspect the file line by line, not character by character
    for entry in release.splitlines():
        if 'SUSE' in entry and 'Enterprise' in entry:
            is_sle = True
        if 'VERSION' in entry and '11' in entry:
            is_sle11 = True
    return is_sle, is_sle11


def _uses_update_infrastructure(ssh, base_cmd):
    """Return True when a service or repository refers to 'susecloud'.

    The first location that can be listed decides which definitions are
    inspected: zypper services, then the service plugins of the newer
    registration client, then plain repositories.
    """
    listings = (
        'ls /etc/zypp/services.d/*',
        'ls /usr/lib/zypp/plugins/services/*',
        'ls /etc/zypp/repos.d/*',
    )
    for listing in listings:
        found, names = _run_remote(ssh, base_cmd + listing)
        if found:
            break
    else:
        # No repos and no service
        return False
    # ls writes to a pty and may print several names per line, so split on
    # any whitespace instead of on newlines only.
    for name in names.split():
        _ok, definition = _run_remote(ssh, base_cmd + 'cat %s' % name)
        if 'susecloud' in definition:
            return True
    return False


def _registration_client_outdated(ssh):
    """Return True when cloud-regionsrv-client is older than required.

    A missing package or an unparsable version is reported as outdated.
    """
    _ok, installed = _run_remote(ssh, 'rpm -qa cloud-regionsrv-client')
    version = installed.rpartition('cloud-regionsrv-client-')[2]
    # Compare the complete major number, not only its first digit,
    # otherwise version 10 would be read as version 1.
    major = version.split('.', 1)[0]
    if not major.isdigit():
        return True
    return int(major) < _MIN_CLIENT_MAJOR


def _inspect_instance(ssh, instance_id, user, auto_update, verbose):
    """Determine the compliance state of the connected instance.

    Returns ``(state, is_sle11)``. ``state`` is None when the instance is
    not a registered SLE system, ``['Unknown']`` when the instance metadata
    could not be read, otherwise ``[True]`` or ``[False]`` with ``'warn'``
    appended when the registration client does not meet the minimum
    version.
    """
    base_cmd = '' if user == 'root' else 'sudo '

    # Figure out the distro
    if verbose:
        print('\tDetermine if this is SLE')
    is_sle, is_sle11 = _detect_sle(ssh, base_cmd)
    if not is_sle:
        if verbose:
            print('\tNot a SLE based instance')
        return None, False

    # Look for registration indicators
    if verbose:
        print('\tDetermine if the instance uses the SUSE '
              'update infrastructure')
    if not _uses_update_infrastructure(ssh, base_cmd):
        if verbose:
            print('\tInstance is not registered')
        return None, is_sle11

    # It's a SLE instance that points to the update infrastructure
    # check for compliance
    cmd = 'curl http://169.254.169.254/latest/dynamic/'
    cmd += 'instance-identity/document'
    _ok, metadata_raw = _run_remote(ssh, cmd)
    try:
        metadata = json.loads(metadata_raw)
    except ValueError:
        metadata = None
    if not isinstance(metadata, dict):
        # Without the identity document no statement can be made
        return ['Unknown'], is_sle11
    # Do not verify the data, just drive off presence
    mp_code = metadata.get('marketplaceProductCodes')
    bp = metadata.get('billingProducts')
    state = [bool(mp_code or bp)]

    # Check the client version
    if _registration_client_outdated(ssh):
        updated = False
        if auto_update:
            if verbose:
                msg = '\tUpdating cloud-regionsrv-client package '
                msg += 'on %s' % instance_id
                print(msg)
            # Run non-interactively and wait for completion, the session
            # is closed right after and nobody could answer a prompt.
            updated, _output = _run_remote(
                ssh,
                base_cmd + 'zypper --non-interactive up '
                'cloud-regionsrv-client'
            )
        if not updated:
            state.append('warn')
    return state, is_sle11


# Scan every region. compliance_state maps region -> instance id -> state
# list, warning_sle11 maps region -> ids of SLE 11 based instances.
compliance_state = {}
warning_sle11 = {}
for region in regions:
    if args.verbose:
        print('Working in region: "%s"' % region)
    compliance_state[region] = {}
    warning_sle11[region] = []
    ec2 = boto3.client(
        aws_access_key_id=args.access_key,
        aws_secret_access_key=args.secret_key,
        region_name=region,
        service_name='ec2'
    )
    filters = [
        {
            'Name': 'instance-state-name',
            'Values': ['running']
        }
    ]
    if args.instance_ids:
        wanted_ids = [
            entry.strip()
            for entry in args.instance_ids.split(',') if entry.strip()
        ]
        if wanted_ids:
            # Select by filter: unlike the InstanceIds argument a filter
            # does not fail in regions that lack some of the given ids.
            filters.append({'Name': 'instance-id', 'Values': wanted_ids})
    # Creating the client never fails for bad credentials or inaccessible
    # regions, the first API call does.
    try:
        instance_info = ec2.describe_instances(Filters=filters)
    except ec2.exceptions.ClientError as err:
        if region == 'ap-northeast-3':
            # Special region that was only added for testing
            continue
        print(
            'Unable to access region "%s" with provided access keys' % region
        )
        if args.verbose:
            print('\t%s' % err)
        sys.exit(1)
    instances = []
    for reservation in instance_info['Reservations']:
        instances += reservation['Instances']
    for instance in instances:
        instance_id = instance.get('InstanceId')
        if args.verbose:
            print('\tProcessing instance: "%s"' % instance_id)
        ip_addresses = _instance_addresses(instance)
        if args.verbose:
            print(
                '\tPotential connection addresses: "%s"' % str(ip_addresses)
            )
        client = paramiko.client.SSHClient()
        client.set_missing_host_key_policy(paramiko.WarningPolicy())
        try:
            connected = _connect(
                client, ip_addresses, args.instace_user,
                args.private_key, args.verbose
            )
            if not connected:
                # Stored as a list like every other state so that the
                # first element can be inspected uniformly.
                compliance_state[region][instance_id] = ['Unknown']
                continue
            state, is_sle11 = _inspect_instance(
                client, instance_id, args.instace_user,
                args.auto_update_client, args.verbose
            )
        finally:
            # Release the SSH session on every path, including skips
            client.close()
        if is_sle11:
            warning_sle11[region].append(instance_id)
        if state is not None:
            compliance_state[region][instance_id] = state

# Print the report, one section per scanned region
for region, instance_data in compliance_state.items():
    print('Compliance data for region: "%s"' % region)
    client_warn_instances = []
    compliant_instances = []
    non_compliant_instances = []
    unknown_state_instances = []
    for instance_id, state in instance_data.items():
        # A state is normally a list whose first element is True, False or
        # 'Unknown'. A bare string is wrapped first, indexing it would
        # yield its first character, which is truthy and would list an
        # unreachable instance as compliant.
        if isinstance(state, str):
            state = [state]
        verdict = state[0]
        if verdict == 'Unknown':
            unknown_state_instances.append(instance_id)
        elif verdict:
            compliant_instances.append(instance_id)
        else:
            non_compliant_instances.append(instance_id)
        # A second element marks an outdated registration client
        if len(state) == 2:
            client_warn_instances.append(instance_id)
    if compliant_instances:
        print('\n\tCompliant instances:')
        for inst_id in compliant_instances:
            print('\t\t%s' % inst_id)
    if unknown_state_instances:
        print('\n\tInstances for which the state could not be determined:')
        for inst_id in unknown_state_instances:
            print('\t\t%s' % inst_id)
    if non_compliant_instances:
        print('\n\tNon compliant instances:')
        for inst_id in non_compliant_instances:
            print('\t\t%s' % inst_id)
    if client_warn_instances:
        header = '\n\tInstances not meeting minimum registration '
        header += 'client requirement:'
        print(header)
        for inst_id in client_warn_instances:
            print('\t\t%s' % inst_id)

    # SLE 11 based instances are listed independently of their state
    eol_instances = warning_sle11.get(region, [])
    if eol_instances:
        msg = '\n\tFollowing instances are based on SLE 11 which reached EOL '
        msg += 'on March 31st, 2019'
        print(msg)
        for inst_id in eol_instances:
            print('\t\t%s' % inst_id)
    print('\n\n')
