#!/usr/bin/env python
"""
Tool for checking kABI and providing diagnostics when kABI is broken.
See Documentation/uek/kabi.txt for usage instructions, or try "kabi -h".
"""
from __future__ import print_function

import argparse
import collections
import difflib
import os
import re
import shutil
import sys
import tempfile
import zlib


def pretty(s):
    """
    Format a one-line C declaration for a human reader.

    Adds newlines after semicolons and around braces, tightens the spacing
    around brackets, parens, commas, and pointer stars, then indents the
    contents of each brace-delimited block with tabs.

    :param s: single-line declaration text (e.g. from SymTypes.gen_short_decl)
    :returns: multi-line, tab-indented string
    """
    # semicolon -> add newline
    s = s.replace(" ;", ";\n")
    # add newlines surrounding braces
    s = s.replace("{", "{\n")
    s = s.replace("}", "\n}")
    # fix extra spacing surrounding brackets, parens, commas, and pointers;
    # a backreference replacement is equivalent to (and simpler than) a
    # lambda that returns its own match group.
    s = re.sub(r"([\[\(\*]) ", r"\1", s)
    s = re.sub(r" ([\]\),])", r"\1", s)
    final_result = []
    indent = 0
    for line in s.split("\n"):
        line = line.strip()
        if not line:
            continue
        # A closing brace de-indents itself; an opening brace indents
        # everything that follows it.
        if line[0] == "}":
            indent -= 1
        final_result.append("\t" * indent + line)
        if line[-1] == "{":
            indent += 1
    return "\n".join(final_result)


class SymTypes(object):
    """
    This class represents a "symtypes" database.

    To compute "symvers", the genksyms program needs to take the pre-processed C
    source file, index all of the structs, unions, enums, and typedefs it sees,
    and then create a C pseudo-declaration of each exported function, which
    contains only fundamental C types. For example, given a simple set of
    declarations:

        typedef unsigned int flags_t;
        struct obj {
            void *data;
        }
        enum result {
            RESULT_OK,
            RESULT_ERROR,
        }

        enum result do_operation(struct obj *obj, flags_t flags);

    Then genksyms would inline all declarations to the following declaration:

        enum result { RESULT_OK, RESULT_ERROR, } do_operation(
            struct obj { void *data; } *obj, typedef unsigned int flags_t flags
        );

    This, give or take some whitespace, is what genksyms would then compute a
    CRC for, and store as the "symbol version" for the do_operation() function.

    In the process, genksyms indexes all type declarations it sees. It can
    output this index for any particular compilation unit (CU) to
    "somefile.symtypes". This class parses and stores this indexed data for
    later use. Since there are typically hundreds or thousands of CUs, this
    class supports reading multiple CU Symtypes files into a single index and
    outputting them to a "Combined Symtypes" file. This reduces disk space,
    since most declarations are repeated in several CUs.  Further, this class
    can filter the set of declarations stored in the index to only those related
    to a set of exported symbols (e.g. a kABI).

    Since each CU includes a unique set of header files, in a unique order, the
    type declarations visible to genksyms can differ between CUs. For the most
    part, these duplicates come down to two cases:
    1. There may be multiple structs by the same names, used in completely
       different subsystems so they don't conflict.
    2. Some structs are forward declared and never defined (appearing as
       UNKNOWN) in the .symtypes file.
    Both cases are expected, and fairly uninteresting for the purpose of kABI
    checking. However, this class must be prepared to handle these duplicate
    symbol declarations.

    Whenever a symbol name is duplicated, we create a "versioned name" by
    appending the CRC32 of its declaration line to the symbol name: e.g.
    "original@DEADBEEF". Then, we create an index for every compilation unit,
    which contains (a) the list of exported function names defined in that CU,
    and (b) the versioned name of any duplicated symbols referred to in that CU.
    When later reconstructing the full declaration of a function, we can use the
    context information of its compilation unit to identify the correct
    declarations, ensuring we compute the correct CRC or provide the correct
    regression information.
    """

    # Maps the single-character prefix of a symtypes token (e.g. "s#foo")
    # to the C keyword used when printing it for humans.
    PREFIXES = {
        "t": "",  # typedef
        "E": "",  # enum element
        "e": "enum ",
        "s": "struct ",
        "u": "union ",
    }

    def __init__(self):
        """Create empty SymTypes"""
        # Maps a symbol name to list of tokens in its declaration.
        self.symtok = {}
        # Maps a symbol name to the CRC32 of its declaration line.
        self.symcrc = {}
        # Maps an original symbol name to its duplicate (versioned) names.
        self.dupes = collections.defaultdict(set)
        # Maps a filename to original symbol name to crc
        self.file_symvers = collections.defaultdict(dict)
        # Maps exported symbol name to filename where it is defined.
        self.exports = {}

    @classmethod
    def from_file(cls, filename):
        """
        Create a SymTypes database from a file. We use the filename argument as
        an alias for the compilation unit of exported symbols found here, unless
        we find "F#" lines created by self.write(), in which case we treat this
        as an archive.
        """
        with open(filename) as fp:
            t = cls()
            t.add_file(fp, filename)
            return t

    def name(self, tok):
        """
        Given any string, if it is a token code, change it to a full name.
        This means doing two things:
        1. Replace "?#" prefixes (e.g "s#list_head" -> "struct list_head")
        2. Remove "@" versions (e.g. "s#task_struct@ef30a39b"
                                -> "struct task_struct")
        Leave other strings unmodified.
        """
        if len(tok) >= 2 and tok[1] == "#":
            tok = self.PREFIXES[tok[0]] + tok[2:]
        if "@" in tok:
            tok = tok[:tok.index("@")]
        return tok

    def versioned(self, symbol, filename):
        """
        Given a symbol and the file in which it was defined, return the
        versioned symbol if it was a duplicate.
        """
        if symbol in self.dupes:
            # The CU's record tells us which duplicate version it saw.
            crc = self.file_symvers[filename][symbol]
            return "{}@{:08x}".format(symbol, crc)
        return symbol

    def _gen(self, tok, seen, out, filename):
        """
        Recursive function to generate a symbol's full declaration.
        tok: the token we're generating from
        seen: set of seen tokens, which prevents us from expanding a struct more
          than once.
        out: list we output to
        filename: the compilation unit filename, to resolve duplicates
        """
        if tok in seen:
            # Already expanded: emit just the name, not the full body.
            out.append(self.name(tok))
            return
        # Mark as seen before recursing so self-referential types terminate.
        seen.add(tok)
        tok = self.versioned(tok, filename)
        for decl_tok in self.symtok[tok]:
            if "#" in decl_tok:
                self._gen(decl_tok, seen, out, filename)
            elif decl_tok != "extern":
                out.append(decl_tok)

    def gen(self, token, filename):
        """Generate a symbol's full (recursive) declaration."""
        seen = set()
        tokens = []
        self._gen(token, seen, tokens, filename)
        tokens.append("")  # add space at end
        return " ".join(tokens)

    def gen_short_decl(self, token):
        """Generate a short declaration of a symbol (one level, not recursive)."""
        result = []
        for tok in self.symtok[token]:
            result.append(self.name(tok))
        return pretty(" ".join(result))

    def crc(self, token, fn=None):
        """
        Compute the genksyms CRC32 corresponding to this symbol.

        fn: the compilation unit used to resolve duplicates; defaults to the
          CU in which the exported symbol was defined.
        """
        # Note that the self.symcrc is different from genksyms
        if not fn:
            fn = self.exports[token]
        return zlib.crc32(self.gen(token, fn).encode()) & 0xFFFFFFFF

    def _add_duplicate(self, symbol, crc):
        """
        Given a symbol which we already know has duplicates, add this one to
        the list.
        """
        new_name = "{}@{:08x}".format(symbol, crc)
        self.dupes[symbol].add(new_name)
        return new_name

    def _new_duplicate(self, symbol, crc):
        """
        Given a symbol which we previously assumed was unique, turn it into a
        duplicate symbol. This involves finding the old, unversioned name and
        removing it from our tables, replacing it with a versioned name. Then,
        add the currently found symbol as a duplicate.
        """
        new_name = "{}@{:08x}".format(symbol, self.symcrc[symbol])
        self.dupes[symbol].add(new_name)
        self.symtok[new_name] = self.symtok[symbol]
        self.symcrc[new_name] = self.symcrc[symbol]
        del self.symtok[symbol]
        del self.symcrc[symbol]
        return self._add_duplicate(symbol, crc)

    def resolve_duplicate(self, symbol, crc):
        """
        For each declaration line, we need to consider a few cases to ensure
        that we get the correct symbol CRC, taking into account duplicates.
        Return 3 items:
        [0] - the original symbol name (e.g. s#list_head)
        [1] - a maybe versioned name (e.g. s#list_head@deadbeef)
        [2] - the correct CRC
        """
        # Case 1: this symbol is already a versioned symbol, because we're
        # reading a Combined Symtypes file. The correct CRC has already been
        # computed for us, but we need to track this versioned symbol.
        if "@" in symbol:
            orig, crc = symbol.split("@", 1)
            self.dupes[orig].add(symbol)
            return orig, symbol, int(crc, 16)
        # Case 2: this symbol name is not in the dupes hash.
        if symbol not in self.dupes:
            if symbol not in self.symcrc or self.symcrc[symbol] == crc:
                # 2a: Either we've never seen this symbol before, or we've seen
                # the exact declaration before.
                return symbol, symbol, crc
            else:
                # 2b: We've seen a different declaration for this symbol, we
                # need to mark that older declaration as a duplicate, then add
                # this one.
                return symbol, self._new_duplicate(symbol, crc), crc
        # Case 3: this name is in the dupes hash - we've already seen at least
        # two different definitions of this name. Search them to see whether
        # we've seen this exact one.
        for potential_symbol in self.dupes[symbol]:
            if self.symcrc[potential_symbol] == crc:
                return symbol, potential_symbol, crc
        # Case 4: yet another distinct declaration, add it as well
        return symbol, self._add_duplicate(symbol, crc), crc

    def add(self, symbol_line, filename):
        """
        Add a line containing a declaration

        Note that this line may come from a "Combined Symtypes" file or a "CU
        Symtypes" file (see class docstring for details). If we're reading a
        combined file, then symbol duplicates should already have been resolved
        by a prior run, and we're just loading that information again. If we're
        reading a CU file, then we may encounter new duplicates due to previous
        CU files, and we need to handle them correctly.
        """
        # Compute CRC and split into the token and definition
        symbol_line = symbol_line.strip()
        line_crc = zlib.crc32(symbol_line.encode()) & 0xFFFFFFFF
        arr = symbol_line.split()
        token_name = arr[0]
        tokens = arr[1:]

        # For Combined Symvers, F# lines contain per-file information: what
        # duplicate symbol versions were used in that file, and what exported
        # symbols were defined in that file.
        if token_name.startswith("F#"):
            arc_filename = token_name[2:]
            for tok in tokens:
                if "@" in tok:
                    # "name@crc" records which duplicate version this CU saw.
                    tok, crc = tok.split("@")
                    crc = int(crc, 16)
                    self.file_symvers[arc_filename][tok] = crc
                else:
                    # A bare name is an exported symbol defined in this CU.
                    self.exports[tok] = arc_filename
            # If we have F# entries, this is an archive and not a symtypes file,
            # so remove any older file_symvers.
            if filename in self.file_symvers:
                del self.file_symvers[filename]
            return

        # Otherwise, convert the token name to one which contains a CRC if
        # necessary to resolve duplication.
        orig, token_name, line_crc = self.resolve_duplicate(token_name, line_crc)

        # If we're in a CU's symvers file, we need to record the CRC of every
        # symbol used in this CU, so that later we can disambiguate between
        # potential duplicates. If we're loading a combined symvers file, this
        # will be None.
        if filename:
            self.file_symvers[filename][orig] = line_crc

        # Don't bother processing an identical token to one we've seen before.
        if token_name in self.symcrc:
            return  # duplicate with same CRC, skip
        self.symcrc[token_name] = line_crc
        if "#" not in token_name and filename:
            # A token without "#" is an exported symbol, not a type.
            self.exports[token_name] = filename
        self.symtok[token_name] = tokens

    def add_file(self, fileobj, filename):
        """Add every declaration line of fileobj, attributed to filename."""
        for line in fileobj:
            self.add(line.strip(), filename)

    def consolidate_symvers(self):
        """
        Once all files for a kernel build are added, use this to remove any
        entries from file_symvers which do not relate to duplicated symbols.
        """
        for filename, symvers_map in self.file_symvers.items():
            # Copy the keys so we can delete while iterating.
            for symbol in list(symvers_map.keys()):
                if symbol not in self.dupes:
                    del symvers_map[symbol]

    def filter_exports(self, symbols, verbose=True):
        """
        Given a set of symbols we care about (e.g. a kABI), filter out all
        unnecessary data from the symtypes.
        """
        files_prev = set(self.exports.values())
        for symbol in list(self.exports.keys()):
            if symbol not in symbols:
                del self.exports[symbol]
        files_to_remove = files_prev - set(self.exports.values())
        for fn in files_to_remove:
            if fn in self.file_symvers:
                del self.file_symvers[fn]

        # BFS from the remaining exports to find every reachable type.
        seen_symbols = set()
        queue = collections.deque(self.exports.keys())
        while queue:
            symbol = queue.popleft()
            seen_symbols.add(symbol)
            if symbol in self.dupes:
                # Keep every duplicate version; they are disambiguated later.
                queue.extend(self.dupes[symbol])
            else:
                for dep in self.deps(symbol):
                    if dep not in seen_symbols:
                        queue.append(dep)
        symbols_to_remove = set(self.symtok.keys()) - seen_symbols
        for symbol in symbols_to_remove:
            del self.symtok[symbol]
            del self.symcrc[symbol]

        # Note: we could try to eliminate duplicates here. It seems possible
        # (detect symbols with only one duplicate version available, remove
        # alternatives, rename them, etc). But it's a lot of bookkeeping for
        # something which saves very little space.

        if verbose:
            # NOTE(review): if both counts are zero, orig is 0 and the
            # percentage below divides by zero -- presumably the database is
            # never empty at this point; confirm.
            rmvd = len(files_to_remove)
            curr = len(self.file_symvers)
            orig = rmvd + curr
            print("Reduced files from {} to {}, a {:2.1f}% reduction.".format(
                orig, curr, 100 * rmvd / orig,
            ))
            rmvd = len(symbols_to_remove)
            curr = len(self.symtok)
            orig = rmvd + curr
            print("Reduced symbols from {} to {}. A {:2.1f}% reduction.".format(
                orig, curr, 100 * rmvd / orig,
            ))

    def write(self, fp):
        """
        Write the contents of this database out to a collection file.
        The contents of file_symvers and exports are combined into a specially
        formatted line after all symbols. This is what makes a "Consolidated
        Symvers" file.
        """
        # Output in alphabetical order by symbol name. This reduces diff
        # churn.
        for symbol, tokens in sorted(self.symtok.items()):
            fp.write("{} {}\n".format(symbol, " ".join(tokens)))
        files = set(self.exports.values()) | set(self.file_symvers.keys())
        filename_to_symbols = collections.defaultdict(list)
        for sym, fn in self.exports.items():
            filename_to_symbols[fn].append(sym)
        for filename in sorted(files):
            outline = []
            outline.extend(filename_to_symbols.get(filename, []))
            # file_symvers is a defaultdict, so a file with exports but no
            # duplicated symbols simply contributes an empty list here.
            outline.extend([
                "{}@{:08x}".format(s, v)
                for s, v in self.file_symvers[filename].items()
            ])
            outline.sort()
            fp.write("F#{} {}\n".format(filename, " ".join(outline)))

    def deps(self, symbol):
        """Return dependencies of a symbol (the type tokens it references)."""
        deps = set()
        for tok in self.symtok[symbol]:
            if "#" not in tok:
                continue  # not a symbol
            deps.add(tok)
        return deps

    @staticmethod
    def identify_kabi_difference(st1, st2, symbol):
        """
        Given two symbol types databases and a symbol, identify any difference

        The goal here is not to identify every type which references a changed
        type (as the genksyms approach does). Instead, we want to identify just
        the types which actually changed.
        """
        queue = collections.deque()
        queue.append(symbol)
        changed_symbols = set()
        seen = set()
        # Resolve duplicates using the CU where each side exports the symbol.
        fn1 = st1.exports[symbol]
        fn2 = st2.exports[symbol]
        while queue:
            symbol = queue.popleft()
            seen.add(symbol)
            vs1 = st1.versioned(symbol, fn1)
            vs2 = st2.versioned(symbol, fn2)
            if st1.symcrc[vs1] != st2.symcrc[vs2]:
                changed_symbols.add((symbol, vs1, vs2))
            # Only descend into types referenced on both sides.
            for sym in st1.deps(vs1) & st2.deps(vs2):
                if sym not in seen:
                    queue.append(sym)
        return changed_symbols


def read_symvers(filename):
    """
    Read a Module.symvers formatted file.

    Each line has the form: CRC SYMBOL DIRECTORY TYPE [NAMESPACE]

    :param filename: path of the Module.symvers file
    :returns: dict mapping
        symbol name -> (CRC32 hash, containing module, module license)
    """
    vers = {}
    # Use a context manager so the file is closed deterministically (the
    # original relied on the garbage collector to close it).
    with open(filename) as f:
        for line in f:
            split = line.strip().split()
            symbol = split[1]
            hash_ = split[0]
            # Some lines have 5 fields (the 5th being the symbol namespace).
            # Even in this case, index 2 and 3 are dir_ and type_.
            dir_, type_ = split[2], split[3]
            vers[symbol] = (hash_, dir_, type_)
    return vers


def read_lockedlist(filename):
    """Read a "kabi_lockedlist" file, returning the set of locked symbols.

    Lines starting with "[" are section headers and are skipped.
    """
    symbols = set()
    with open(filename) as f:
        for line in f:
            if line.startswith("["):
                continue
            symbols.add(line.strip())
    return symbols


def read_symbols(filename):
    """Read a set of symbols from either a lockedlist or a symvers file.

    The format is sniffed from the first line: lockedlist files begin with
    a "[" section header; anything else is parsed as Module.symvers.
    """
    with open(filename) as f:
        first_line = f.readline()
    if first_line.startswith("["):
        return read_lockedlist(filename)
    return set(read_symvers(filename))


def print_diffs(diffs, st1, st2, desc1="kABI", desc2="Build"):
    """Print the differences returned by identify_kabi_difference().

    For each changed symbol, emits a unified diff of its pretty-printed
    short declaration in st1 versus st2, labeled with desc1/desc2.
    """
    for orig, v1, v2 in diffs:
        name = st1.name(orig)
        before = [l + "\n" for l in st1.gen_short_decl(v1).split("\n")]
        after = [l + "\n" for l in st2.gen_short_decl(v2).split("\n")]
        diff = difflib.unified_diff(
            before,
            after,
            fromfile="{} - {}".format(name, desc1),
            tofile="{} - {}".format(name, desc2),
        )
        sys.stdout.writelines(diff)
        print()

# Sub-commands:
def check(args):
    """
    Do a kABI check as the check-kabi script would have done it.

    Compares the recorded kABI symvers against a freshly built
    Module.symvers, reporting moved, relicensed, changed, and removed
    symbols. If both symtypes collections were provided and symbols
    changed, prints type-level diffs explaining the breakage. Exits 0
    only when there is no breakage.
    """
    symvers_kabi = read_symvers(args.symvers_kabi)
    symvers_build = read_symvers(args.symvers_build)

    changed_license = []
    moved = []
    changed_abi = []
    unexported = []
    for sym, (kabi_hash, kabi_dir, kabi_type) in symvers_kabi.items():
        build = symvers_build.get(sym)
        if build is None:
            unexported.append(sym)
            continue
        build_hash, build_dir, build_type = build
        if build_hash != kabi_hash:
            changed_abi.append(sym)
        if build_dir != kabi_dir:
            moved.append(sym)
        if build_type != kabi_type:
            changed_license.append(sym)

    if moved:
        print("*** WARNING - ABI SYMBOLS MOVED ***\n")
        print("The following symbols moved (typically caused by moving a symbol from being\n"
              "provided by the kernel vmlinux out to a loadable module). This is not an\n"
              "error, but is being reported for completeness:\n")
        print("\n".join(moved))
        print()
    if changed_license:
        print("*** ERROR - SYMBOL LICENSE HAS CHANGED ***\n")
        print("The usage license for the following symbols has changed (this will cause an ABI breakage):\n")
        print("\n".join(changed_license))
        print()
    if changed_abi:
        print("*** ERROR - ABI BREAKAGE WAS DETECTED ***\n")
        print("The following symbols have been changed (this will cause an ABI breakage):\n"
              "For help diagnosing why, please see the diffs below the listing.\n")
        print("\n".join(changed_abi))
        print()
    if unexported:
        print("*** ERROR - ABI SYMBOL WAS REMOVED ***\n")
        print("The following symbols have been removed, or unexported. This will cause\n"
              "an ABI breakage:\n")
        print("\n".join(unexported))
        print()

    if not (changed_license or changed_abi or unexported):
        # Moves alone are not an error condition. Exit success.
        sys.exit(0)

    if not changed_abi:
        # Our diagnostics only apply to changed_abi symbols. If there are none,
        # exit with failure status because we still have an ABI breakage.
        sys.exit(1)

    if not (args.symtypes_kabi and args.symtypes_build):
        print("*** MISSING --symtypes-kabi AND --symtypes-build ***\n")
        print("Without this information, we cannot provide detailed diagnostics.")
        sys.exit(1)

    st_kabi = SymTypes.from_file(args.symtypes_kabi)
    st_build = SymTypes.from_file(args.symtypes_build)

    print("*** DETECTED TYPE DIFFERENCES ***\n")
    diffs = set()
    for symbol in changed_abi:
        diffs |= SymTypes.identify_kabi_difference(st_kabi, st_build, symbol)
    print_diffs(diffs, st_kabi, st_build)

    sys.exit(1)


def report(args):
    """
    Read a symtypes file and print every exported function's CRC32
    symbol version, one per line.
    """
    db = SymTypes.from_file(args.symtypes)
    for symbol, cu_file in db.exports.items():
        declaration = db.gen(symbol, cu_file)
        version = zlib.crc32(declaration.encode()) & 0xFFFFFFFF
        print("0x{:08x}\t{}".format(version, symbol))


def collect_helper(directory, output, minimize_kabi, verbose=True):
    """Combine every .symtypes file under a build tree into one file.

    directory: build tree to walk recursively
    output: path of the combined Symtypes file to write
    minimize_kabi: optional set of symbols to filter the database down to
    verbose: passed through to filter_exports() for reduction stats
    """
    st = SymTypes()
    for dirpath, _dirnames, filenames in os.walk(directory):
        for name in filenames:
            if not name.endswith(".symtypes"):
                continue
            full_path = os.path.join(dirpath, name)
            with open(full_path) as fp:
                # Attribute symbols to the path relative to the build root.
                st.add_file(fp, os.path.relpath(full_path, directory))
    st.consolidate_symvers()
    if minimize_kabi:
        st.filter_exports(minimize_kabi, verbose)
    with open(output, "w") as fp:
        st.write(fp)


def collect(args):
    """Sub-command: gather all .symtypes from a build tree into one file."""
    kabi_syms = read_symbols(args.minimize_kabi) if args.minimize_kabi else None
    return collect_helper(args.directory, args.output, kabi_syms)


def consolidate(args):
    """Sub-command: re-write a Symtypes file, optionally filtered to a kABI."""
    db = SymTypes.from_file(args.input)
    if args.kabi:
        db.filter_exports(read_symbols(args.kabi), True)
    with open(args.output, "w") as fp:
        db.write(fp)


def compare_helper(symtypes_lhs, symtypes_rhs, print_missing=False, print_symbols=True):
    """Compare the exported symbols of two Symtypes files.

    Recomputes each common symbol's CRC in both databases, prints any
    type-level differences, and exits 0 (identical) or 1 (differences,
    or no symbols in common).
    """
    baseline = SymTypes.from_file(symtypes_lhs)
    comparison = SymTypes.from_file(symtypes_rhs)
    lhs_syms = set(baseline.exports.keys())
    rhs_syms = set(comparison.exports.keys())
    common_symbols = lhs_syms & rhs_syms
    if not common_symbols:
        print("error: these share no symbols in common, nothing to compare!")
        sys.exit(1)
    if print_missing:
        for only, label in ((lhs_syms - rhs_syms, "Baseline"),
                            (rhs_syms - lhs_syms, "Comparison")):
            if only:
                print("The following symbols appear only in the {}:".format(label))
                print("\n".join(only))
    diffs = set()
    diff_syms = []
    for symbol in common_symbols:
        if baseline.crc(symbol) == comparison.crc(symbol):
            continue
        diff_syms.append(symbol)
        diffs |= SymTypes.identify_kabi_difference(baseline, comparison, symbol)
    if print_symbols:
        print("The following symbols differ:")
        print("\n".join(diff_syms))
    print_diffs(diffs, baseline, comparison, "Baseline", "Comparison")
    sys.exit(1 if diff_syms else 0)


def compare(args):
    """Sub-command wrapper around compare_helper()."""
    return compare_helper(
        args.symtypes_lhs,
        args.symtypes_rhs,
        args.print_missing,
        args.print_symbols,
    )


def debug(args):
    """
    Sub-command: compare a local build directory against a kABI Symtypes.

    Collects the build tree's symtypes into a temporary file (filtered to
    the kABI's exported symbols), then compares. Roughly equivalent to
    running collect, consolidate, and compare in sequence.
    """
    workdir = tempfile.mkdtemp()
    build_symtypes = os.path.join(workdir, "Symtypes.build")
    try:
        kabi_symbols = set(SymTypes.from_file(args.symtypes).exports.keys())
        collect_helper(args.directory, build_symtypes, kabi_symbols, verbose=False)
        # compare_helper() exits the process via sys.exit(); the finally
        # clause still runs on SystemExit, so the tempdir is always removed.
        compare_helper(build_symtypes, args.symtypes, True, True)
    finally:
        shutil.rmtree(workdir)


def smoke(args):
    """
    Smoke test the consistency of the kABI maintenance files.

    Cross-checks three data sources: kabi_lockedlist, Module.kabi
    (symvers), and Symtypes.kabi. Each pair must describe the same symbol
    set, and every CRC recomputed from Symtypes.kabi must match the value
    recorded in Module.kabi. Exits nonzero if any check fails.
    """
    kabi_st = SymTypes.from_file(args.symtypes)
    kabi_st_syms = set(kabi_st.exports.keys())
    kabi_sv = read_symvers(args.symvers)
    kabi_sv_syms = set(kabi_sv.keys())
    lockedlist = read_lockedlist(args.lockedlist)
    ret = 0

    def report_extra(message, symbols):
        # Print an error header followed by the offending symbols, sorted.
        print(message)
        print("\n".join(sorted(symbols)))

    # First, compare lockedlist with Module.kabi.
    sv_only = kabi_sv_syms - lockedlist
    ll_only = lockedlist - kabi_sv_syms
    if sv_only:
        report_extra("ERROR: Module.kabi contains symbols not in kabi_lockedlist!", sv_only)
        ret = 1
    if ll_only:
        report_extra("ERROR: kabi_lockedlist contains symbols not in Module.kabi!", ll_only)
        ret = 1

    # Second, compare Symtypes.kabi with Module.kabi
    st_only = kabi_st_syms - kabi_sv_syms
    sv_only = kabi_sv_syms - kabi_st_syms
    if sv_only:
        report_extra("ERROR: Module.kabi contains symbols not in Symtypes.kabi!", sv_only)
        ret = 1
    if st_only:
        report_extra("ERROR: Symtypes.kabi contains symbols not in Module.kabi!", st_only)
        ret = 1

    # Third, recompute each common symbol's CRC from the symtypes data and
    # compare it against the CRC recorded in Module.kabi.
    differing_syms = []
    for sym in kabi_st_syms & kabi_sv_syms:
        crc = int(kabi_sv[sym][0], 16)
        computed_crc = kabi_st.crc(sym, kabi_st.exports[sym])
        if computed_crc != crc:
            differing_syms.append((sym, crc, computed_crc))
    if differing_syms:
        differing_syms.sort()
        print(
            "ERROR: some symbol versions computed via Symtypes.kabi do not "
            "match their corresponding entries from Module.kabi:"
        )
        print("Computed\tRecorded\tSymbol")
        for sym, crc, computed_crc in differing_syms:
            print("{:08x}\t{:08x}\t{}".format(computed_crc, crc, sym))
        ret = 1

    if ret:
        # Message grammar fixed: "smoke tests errors" -> "smoke test errors".
        print("NOTE: These smoke test errors do not indicate kABI breakages.")
        print("Instead, they indicate an error in the maintenance of UEK kABI.")
        print("Unless you've modified Module.kabi or Symtypes.kabi, you should")
        print("report these to the UEK maintainers.")
    sys.exit(ret)


def main():
    """Parse command-line arguments and dispatch to the chosen sub-command."""
    p = argparse.ArgumentParser(
        description="kABI checking and diagnostics",
    )
    subp = p.add_subparsers(title="sub-command")

    # "report": dump each exported symbol's computed CRC (debugging aid).
    report_p = subp.add_parser(
        "report",
        help=("Compute and report symbol versions using only data from "
              "a symtypes file. Mostly for debugging this script."),
    )
    report_p.set_defaults(func=report)
    report_p.add_argument(
        "symtypes", help="symtypes file or collection"
    )

    # "check": the main kABI verification, with optional type diagnostics.
    check_p = subp.add_parser(
        "check",
        help=("Compare Module.symvers with Module.kabi. If kABI is broken, "
              "print a report. If both symtypes files are provided, then "
              "broken kABI will also trigger a diagnostic that identifies "
              "the type declarations causing the breakage."),
    )
    check_p.set_defaults(func=check)
    check_p.add_argument(
        "--symvers-kabi", "-k",
        help="Module.kabi file (the kABI declaration)",
        required=True,
    )
    check_p.add_argument(
        "--symvers-build", "-s",
        help="Module.symvers file (created from a recent build)",
        required=True,
    )
    check_p.add_argument(
        "--symtypes-kabi", "-K",
        help="Symtypes.kabi file (symtypes collection from kABI)",
    )
    check_p.add_argument(
        "--symtypes-build", "-S",
        help="Symtypes.build file (symtypes collection from build)",
    )


    # "collect": walk a build tree and merge all .symtypes into one file.
    collect_p = subp.add_parser(
        "collect",
        help="Collect all symtypes files from a kernel build to a file",
    )
    collect_p.set_defaults(func=collect)
    collect_p.add_argument(
        "directory", help="directory to search recursively",
    )
    collect_p.add_argument(
        "-o", "--output",
        default="Symtypes.build",
        help="File to write the collection to",
    )
    collect_p.add_argument(
        "--minimize-kabi",
        help=("A Module.kabi or kabi_lockedlist file to use for minimizing "
              "the size of the Symtypes. Similar in operation to the "
              "consolidate subcommand."),
    )

    # "consolidate": filter an existing Symtypes down to a kABI's needs.
    consolidate_p = subp.add_parser(
        "consolidate",
        help=("Given a kABI, filter the Symtypes to just data necessary "
              "to validate kABI symbols. With no kABI, just reads and "
              "re-writes the Symtypes, which could be good for testing."),
    )
    consolidate_p.set_defaults(func=consolidate)
    consolidate_p.add_argument(
        "--input", "-i",
        help="Input Symtypes",
        required=True,
    )
    consolidate_p.add_argument(
        "--output", "-o",
        help="Output Symtypes",
        required=True,
    )
    consolidate_p.add_argument(
        "--kabi", "-k",
        help=("Module.symvers or kabi_lockedlist containing the list of "
              "kABI symbols to filter to."),
    )

    # "compare": diff the exported symbols of two Symtypes files.
    compare_p = subp.add_parser(
        "compare",
        help="Given two Symtypes files, compare all exported symbols in both.",
    )
    compare_p.set_defaults(func=compare)
    compare_p.add_argument(
        "symtypes_lhs",
        help="Symtypes file for baseline (e.g. kABI reference)",
    )
    compare_p.add_argument(
        "symtypes_rhs",
        help=("Symtypes file for comparison (e.g. mm/slub.symtypes or "
              "Symtypes.build)"),
    )
    compare_p.add_argument(
        "--no-print-symbols",
        action="store_false",
        dest="print_symbols",
        help=("Do not print the symbols whose symvers differ. Only print the "
              "underlying type differences."),
    )
    compare_p.add_argument(
        "--print-missing",
        action="store_true",
        help="Print symbols which are present in only one of the files.",
    )

    # "debug": one-shot collect + consolidate + compare for a local build.
    debug_p = subp.add_parser(
        "debug",
        help=("Given a local build directory, compare against a Symtypes.kabi. "
              "Roughly equivalent to running collect, consolidate, and then "
              "compare."),
    )
    debug_p.set_defaults(func=debug)
    debug_p.add_argument(
        "directory",
        help="local full build directory",
    )
    debug_p.add_argument(
        "symtypes",
        help="kABI symtypes file",
    )

    # "smoke": consistency checks across the kABI maintenance files.
    smoke_p = subp.add_parser(
        "smoke",
        help="Smoke test a given lockedlist, Module.kabi, and Symtypes.kabi",
    )
    smoke_p.set_defaults(func=smoke)
    smoke_p.add_argument(
        "--symtypes", "-t",
        help="Symtypes.kabi file",
    )
    smoke_p.add_argument(
        "--symvers", "-v",
        help="Module.kabi file",
    )
    smoke_p.add_argument(
        "--lockedlist", "-l",
        help="kabi_lockedlist file",
    )

    args = p.parse_args()
    # "func" is only set by a sub-parser's set_defaults(); it is absent when
    # no sub-command was given, in which case show usage and fail.
    if not getattr(args, "func", None):
        p.print_help()
        sys.exit(1)
    args.func(args)


# Script entry point: parse arguments and dispatch to the chosen sub-command.
if __name__ == '__main__':
    main()
