#!/usr/bin/env python3
#
# Copyright (c) 2025 Klara, Inc.
# Copyright (c) 2026 Mark Johnston <markj@FreeBSD.org>
#
# SPDX-License-Identifier: BSD-2-Clause
#
# git-mfc - Helper command for MFCing commits to stable or release branches.
#

import argparse
import datetime
import os
import re
import subprocess
import sys


def err(code, msg):
    """
    Report a fatal error and terminate.

    Writes "<program name>: <msg>" to standard error, where the program
    name is the basename of sys.argv[0], then exits with status `code`.
    """
    prog = os.path.basename(sys.argv[0])
    print(': '.join((prog, msg)), file=sys.stderr)
    sys.exit(code)


# GitPython is imported under a guard so that a missing module is reported
# through err() (a one-line message and exit status 1) rather than as an
# ImportError traceback.
try:
    import git
except ImportError:
    err(1, "gitpython module missing; install it with pip or the devel/py-gitpython package")


def commit_summary(commit):
    """
    Format a commit for display as: <full hash> ("<summary line>").
    """
    sha = commit.hexsha
    subject = commit.summary
    return f'{sha} ("{subject}")'


def origin_branch():
    """
    Work out the branch that commits are MFCed from.

    Reads the REVISION and BRANCH assignments from sys/conf/newvers.sh
    (looked up relative to the current directory):

      - BRANCH="STABLE" means commits come from "main";
      - BRANCH="RELEASE" or "RELEASE-pN" means commits come from
        "stable/<major>", where <major> is the part of REVISION before the
        first '.'.

    Every other situation (missing file, missing variables, a CURRENT tree,
    an unrecognized BRANCH value) is reported through err(), which exits.
    """
    newvers = 'sys/conf/newvers.sh'
    if not os.path.exists(newvers):
        err(1, f"{newvers} not found; use --origin to specify the upstream branch")

    revision_re = re.compile(r'^REVISION="([^"]+)"')
    branch_re = re.compile(r'^BRANCH="([^"]+)"')

    # If a variable is assigned more than once, the last assignment wins.
    revision = branch = None
    with open(newvers) as f:
        for line in f:
            found = revision_re.match(line)
            if found:
                revision = found.group(1)
                continue
            found = branch_re.match(line)
            if found:
                branch = found.group(1)

    if branch is None:
        err(1, f"could not determine BRANCH from {newvers}")
    if branch == "CURRENT":
        err(1, "this is a CURRENT tree, we do not MFC to CURRENT")
    if branch == "STABLE":
        return "main"
    if not re.match(r'^RELEASE(-p\d+)?$', branch):
        err(1, f"unexpected BRANCH value '{branch}' in {newvers}")

    # A release branch: the origin is the matching stable branch.
    if revision is None:
        err(1, f"could not determine REVISION from {newvers}")
    major, _, _ = revision.partition('.')
    return f"stable/{major}"


def fixes(repo, commit, warn=False):
    """
    Look at the commit's log message and return a list of commits that it fixes.

    A line starting with "Fixes:", "MFC-with:" or "MFC with:" is treated as
    a reference when the text after the tag begins with a commit hash.  A
    hash is a run of at least four lowercase hex digits (git's minimum
    abbreviation length) that is not immediately followed by another word
    character, so that ordinary words which merely start with hex-looking
    letters ("Fixes: deadlock in ...") are not mistaken for hashes.

    Hard-code a bunch of special cases in the src repo, where the fixes tag is
    valid but malformed somehow.

    Parameters:
        repo:   object whose commit(rev) method resolves a revision string
                to a commit object.
        commit: the commit whose message is examined; its hexsha,
                committed_date and message attributes are used.
        warn:   if true, print a warning to stderr for each tag whose hash
                cannot be resolved; otherwise such tags are silently skipped.

    Returns a list of commit objects, in the order the tags appear in the
    message.  Commits committed before 2021-01-01 (local time) that are not
    special-cased always yield an empty list.
    """
    special = {
        "a693d17b9985a03bd9b5108e890d669005ab41eb": ["8a42005d1e491932666eb9f0be3e70ea1a28a3f7"],
        "a04ca1c229195c7089b878a94fbe76505ea693b5": ["93b7818226cf5270646725805b4a8c17a1ad3761"],
        "852088f6af6c5cd44542dde72aa5c3f4c4f9353c": ["b5fb9ae6872c499f1a02bec41f48b163a73a2aaa"],
        "9c5d7e4a0c02bc45b61f565586da2abcc65d70fa": ["bec000c9c1ef409989685bb03ff0532907befb4a"],
        "894cb08f0d3656fdb81f4d89085bedc4235f3cb6": ["5678d1d98a348f315453555377ccb28821a2ffcd"],
        "cbddb2f02c7687d1039abcffd931e94e481c11a5": ["0118b0c8e58a438a931a5ce1bf8d7ae6208cc61b"],
        "e9da71cd35d46ca13da4396d99e0af1703290e68": [],
        "43cd6bbba053f02999d846ef297655ebec2d218b": ["38cbdae33b7c3f772845c06f52b86c0ddeab6a17"],
        "94efe9f91be7f3aa501983716db5a4378282a734": ["7520b88860d7a79432e12ffcc47056844518bb62"],
        "b332adfa96218148dfbb936a9c09d00484c868e3": ["7520b88860d7a79432e12ffcc47056844518bb62"],
        "b911f504005df67f8c25f9b2f817c16588cd309c": ["801fb66a7e34c340f23d82f2b375feee4bd87df4"],
        "3d37e7e5f540f513ab1d8fa61d9208c43b889401": ["fb9baa9b2045a193a3caf0a46b5cac5ef7a84b61"],
        "0912408a281f203c43d0b3f73c38117336588342": ["9e33a616939fcff87f7539e3c41323deca5c74ce"],
        "c036339ddf0cf80164f41ea31f1d8d27f4a068a9": ["a305b20ead13bb29880e15ff20c3bb83b5397a82"],
        "9b56dfd27c64fcaf5dfbaa1eb3e2bd2b163fa56c": [],
        "ab92cab02633580f763a38a329a5b25050bb4fbf": [],
        "28fdb212adc0431fff683749a1307038e25ff58e": [],
        "811912c46b5886f1aa3bb7a51a6ec1270bc947a8": [],
        "813d981e1e78daffde4b2a05df35d054fcb4343f": [],
        "ecf2a89a997ad4a14339b6a2f544e44b422620a0": [],
        "f6517a7e69c10c6057d6c990a9f3ea22a2b62398": [],
        "a5770eb54f7d13717098b5c34cc2dd51d2772021": ["86c06ff8864bc0e2233980c741b689714000850d"],
        "b84d0aaa4e64fb95b105d0d38f6295fec7a82110": [],
        "b586c66baf4824d175d051b3f5b06588c9aa2bc8": [],
    }
    if commit.hexsha in special:
        return [repo.commit(c) for c in special[commit.hexsha]]

    # Tags in older commits are ignored entirely.
    # NOTE(review): presumably because history from before 2021 does not
    # use these tags with git hashes -- confirm.
    if commit.committed_date < datetime.datetime(2021, 1, 1).timestamp():
        return []

    # At least four hex digits, and the hash must end at a word boundary:
    # "deadlock" or "ad" must not be looked up as a revision.
    regexp = re.compile(r'^([0-9a-f]{4,})(?![0-9A-Za-z_])')
    tags = ("Fixes", "MFC-with", "MFC with")

    fixed = []
    for line in commit.message.splitlines():
        for tag in tags:
            if not line.startswith(tag + ':'):
                continue
            # Everything after "<tag>:", with any further colons treated
            # as separators (e.g. "Fixes: <hash>: summary").
            rest = line[len(tag) + 1:].replace(':', ' ').strip()
            m = regexp.match(rest)
            if m:
                try:
                    fixed.append(repo.commit(m.group(1)))
                except (git.exc.BadName, ValueError):
                    if warn:
                        print(f"warning: {commit.hexsha[:12]}: {tag} tag "
                              f"references unknown commit {m.group(1)}",
                              file=sys.stderr)
            # A line carries at most one tag; no need to try the others.
            break
    return fixed


def reverts(commit):
    """
    Return the hash of the commit that this commit reverts, or None.

    The hash is taken from the first line of the log message that starts
    with "This reverts commit <40 hex digits>.".
    """
    pattern = re.compile(r'^This reverts commit ([0-9a-f]{40})\.')
    hits = (pattern.match(text) for text in commit.message.splitlines())
    return next((hit.group(1) for hit in hits if hit), None)


def mfcclosure(repo, upstream, commits, _reverted=None):
    """
    Given a set of commits to MFC from the upstream branch, return a tuple
    (tomfc, reverted) where tomfc is a list of all commits in the original
    set plus those which fix up commits in the original set (excluding
    reverted fixups), and reverted is the set of commit hashes that have
    been reverted on the upstream branch.

    Reverts are collected during the upstream walk, so no separate scan is
    needed.  Recursive calls share the same reverted set via _reverted.

    Parameters:
        repo:      repository object providing commit() and iter_commits().
        upstream:  revision naming the upstream branch to walk.
        commits:   iterable of revision strings; a string containing ".."
                   is expanded as a range, anything else is resolved as a
                   single commit.
        _reverted: internal; the shared set of reverted hashes passed down
                   by recursive calls.  Callers should leave it as None.

    Exits through err() if a revision cannot be resolved or if some
    requested commit is never encountered while walking upstream.
    """
    # Resolve the requested revisions into commit objects, keeping the
    # order in which they were given.
    tomfc = []
    for rev in commits:
        try:
            if ".." in rev:
                # The range is reversed before being appended; presumably
                # iter_commits() yields newest-first, so this puts the
                # commits in oldest-first (application) order.
                revlist = list(repo.iter_commits(rev))
                revlist.reverse()
                tomfc.extend(revlist)
            else:
                tomfc.append(repo.commit(rev))
        except git.exc.BadName:
            err(1, f"Revision '{rev}' is invalid")

    # Only the outermost call creates the reverted set; recursive calls
    # receive it and add to it.
    toplevel = _reverted is None
    if toplevel:
        _reverted = set()

    # Requested commits not yet seen in the upstream walk.
    tovisit = tomfc.copy()

    for commit in repo.iter_commits(upstream):
        # Record every revert seen along the way, whether or not it is
        # related to the requested commits.
        r = reverts(commit)
        if r:
            _reverted.add(r)

        if commit in tovisit:
            tovisit.remove(commit)
            if len(tovisit) == 0:
                # All requested commits were found; stop walking.  This
                # break also skips the for-loop's else clause below.
                break
            continue

        # commit was not requested; pull it in if it fixes something that
        # is already in the list.
        for fix in fixes(repo, commit):
            if fix in tomfc and commit not in tomfc:
                if commit.hexsha in _reverted:
                    print(f"warning: skipping {commit_summary(commit)}: "
                          f"fixup was reverted upstream", file=sys.stderr)
                    continue
                # Splice the fixup, together with its own closure, into the
                # list directly after the commit it fixes.
                i = tomfc.index(fix) + 1
                sub, _ = mfcclosure(repo, upstream, [commit.hexsha],
                                    _reverted=_reverted)
                tomfc[i:i] = sub
    else:
        # The walk ran to completion without the break above, so at least
        # one requested commit was never seen on the upstream branch.
        err(1, "the following commits are not in the upstream branch: " +
            ', '.join([c.hexsha for c in tovisit]))
    return (tomfc, _reverted)


def mfc_origins(repo, commit, warn=False):
    """
    Return the upstream commits that the given commit was cherry-picked
    from, as a list of commit objects.

    The origins are taken from "(cherry picked from commit <hash>" lines in
    the log message, in order of appearance.  There is normally one, but a
    squashed MFC may name several.  A hash that cannot be resolved is
    skipped, with a warning on stderr if warn is true.

    A table of known commits whose cherry-pick metadata is wrong overrides
    whatever the log message says.
    """
    overrides = {
        # stable/15
        "e224f2c1b98ba065942c83d322cdb8e3987d0c81": ["068fea0aa15bceb7b6b01687542b58ee81d1d887"],
        "48578fbefa666abfcfe3a47269343febefae94f9": ["728ec0c094ce473ae17ebd1adb05f0959bf3a68e"],
        "2bf7d850c53e55ffced055235de432055f0cb9e2": ["7e79bc8ce70693a892c443c42af5ec16a95ba466"],
        "e554d44c3a7ffa8664e89f35aec80b917f69e6b3": ["561dc357c2f5892af3aa481a1020860b7ff473e0"],
        # stable/14
        "aab45924bdb10338654e25ab9ecec106b7eb368b": ["c57c02ebf7bcc9b02a0dc11711e8d8a6960ad34b"],
        "11dde2c8b7156a7d2072589c22c2f3c0de6880d8": ["62af5b9dc6205289a0ace964d060fba64e71ef28"],
        "19f7b2bbc4e4fec33b7e8d546dd05a79533ca8e4": ["d0cbb1930e82a53b07b1091402ff14cdfe7a4898"],
        "d034ff89b84f8470eb5bbcbb65c64c762e07f0bd": ["ac2156c139f8f685b84a71a7f0f164d6cccc7656"],
        "fce2a3509a65a374820dc889929f8e8f5dbd1707": ["a9722e5ae8519a9a28e950dbd8d489178e9bc27e"],
        "5946b0c6cbc77e6c5f62f5f7e635c6036e14f4d0": [],
        "06bb8e1dab004ccb283f7a20fe84aa1326baf6b7": [],
        "e5fadc41b48045d8978993d6c4ac72c64542b470": ["e20971500194d2f7299e9d01ca3b20e9bc6b4009"],
        # stable/13
        "fb4bc4c325eca15247a2e21ded0a70f92ec15488": ["c57c02ebf7bcc9b02a0dc11711e8d8a6960ad34b"],
        "213406054e46b50d5df6d28b01a0a3410132b322": ["62af5b9dc6205289a0ace964d060fba64e71ef28"],
        "eb6edafe8f484a69ff074018e5b93e481787923e": ["d0cbb1930e82a53b07b1091402ff14cdfe7a4898"],
        "ced9fa71eaf95a27d5672c3419d7d9ca6c189168": ["ac2156c139f8f685b84a71a7f0f164d6cccc7656"],
        "65bab39e140f97cace92a2923e50c6b654b02e22": ["346483b1f10454c5617a25d5e136829f60fb1184"],
        "972637dc06a04432dc58e240b8ef3e9f538b98bb": ["9209ea69bc03e7e9f678b2294da7a0317b5c9c5b"],
        "93a95ebbf7c8eb85aeb53a9fde329348992333e7": ["6ccff5c0452c4bc2a4dd497e39a801ab8db8a021"],
        "9d31ae318711825d3a6ffa544d197708905435cf": ["aef815e7873b006bd040ac1690425709635e32e7"],
        "a75324d674f5df10f1407080a49cfe933dbb06ec": ["64ae2f785e0672286901c15045a24cbc533538d3"],
        "0729ba2f49c956789701aecb70f4f555181fd3a7": ["91a8bed5a49eb2d1e4e096a4c68c108cebec8818"],
        "bec0d2c9c8413707b0fff8e65fb96aa53f149be3": ["8d5c7813061dfa0b187500dfe3aeea7a28181c13"],
    }
    # An empty override list is meaningful: it says the commit has no
    # origins, regardless of its log message.
    forced = overrides.get(commit.hexsha)
    if forced is not None:
        return [repo.commit(sha) for sha in forced]

    trailer = re.compile(r'^\(cherry picked from commit ([0-9a-f]{40})')
    origins = []
    for line in commit.message.splitlines():
        found = trailer.match(line)
        if found is None:
            continue
        sha = found.group(1)
        try:
            origin = repo.commit(sha)
        except (git.exc.BadName, ValueError):
            if warn:
                print(f"warning: {commit.hexsha[:12]}: cherry-picked from "
                      f"unknown commit {sha}", file=sys.stderr)
            continue
        origins.append(origin)
    return origins


def dangling(repo, upstream, author=None):
    """
    Find commits in the current branch which have been cherry-picked from
    upstream but are missing fixup commits.

    Parameters:
        repo:     repository object (merge_base(), iter_commits(), commit()).
        upstream: revision naming the upstream branch.
        author:   if not None, only cherry-picks whose author email contains
                  this string are checked for missing fixups.  Fixups count
                  as present no matter who cherry-picked them.

    Returns a list of (fixup, original) pairs of commit objects, where
    fixup is an upstream commit that fixes original, original has been
    cherry-picked to the current branch, and fixup has neither been
    cherry-picked nor reverted upstream.
    """
    # Of the merge bases reported, the last one bounds both walks below.
    base = repo.merge_base('HEAD', upstream, all=True)[-1]

    # picked: every upstream commit cherry-picked to this branch, by anyone.
    #         A set, since it is only used for membership tests.
    # mine:   the subset picked by the requested author, kept as a list so
    #         that results come out in walk order.
    picked = set()
    mine = []
    for commit in repo.iter_commits(base.hexsha + '..HEAD'):
        for ref in mfc_origins(repo, commit, warn=True):
            picked.add(ref.hexsha)
            if author is None or author in commit.author.email:
                mine.append(ref.hexsha)

    # Create a mapping of all upstream commits to their fixups,
    # tracking reverts as we go.
    fixups = {}
    reverted = set()
    for commit in repo.iter_commits(base.hexsha + '..' + upstream):
        r = reverts(commit)
        if r:
            reverted.add(r)
        for fix in fixes(repo, commit, warn=True):
            fixups.setdefault(fix.hexsha, []).append(commit.hexsha)

    # For each of the user's cherry-picks, find fixups that haven't been
    # cherry-picked (by anyone) or reverted.
    result = []
    for commit in mine:
        for f in fixups.get(commit, []):
            if f not in picked and f not in reverted:
                result.append((repo.commit(f), repo.commit(commit)))
    return result


def mfc_after(commit):
    """
    Return the waiting period requested by the commit's "MFC after:" tag as
    a datetime.timedelta, or None if no line of the log message carries a
    tag of the form "MFC after: <count> <unit>" with a recognized unit.

    The tag is matched case-insensitively and may also be spelled
    "MFC-after:".  Recognized units are day(s), week(s) and month(s), where
    a week is 7 days and a month is 30 days.  A tag with an unknown unit is
    skipped and later lines are still considered.
    """
    units = {
        'day': 1, 'days': 1,
        'week': 7, 'weeks': 7,
        'month': 30, 'months': 30,
    }
    tag = re.compile(r'^MFC[ -]after:\s*(\d+)\s+(\w+)', re.IGNORECASE)
    for line in commit.message.splitlines():
        found = tag.match(line)
        if found is None:
            continue
        count = int(found.group(1))
        days_per_unit = units.get(found.group(2).lower())
        if days_per_unit is not None:
            return datetime.timedelta(days=count * days_per_unit)
    return None


def pending(repo, upstream, author=None, baking=False):
    """
    List upstream commits that carry an "MFC after:" tag and have not been
    cherry-picked to the current branch.

    By default only commits whose waiting period has elapsed are listed;
    with baking=True, commits that are still waiting are listed as well.
    If author is not None, only commits whose author email contains that
    string are listed.  A commit that is reverted later on the upstream
    branch is dropped from the list.

    Returns a list of (commit, ready_date, ready) tuples, oldest first,
    where ready_date is the timezone-aware (UTC) time at which the waiting
    period ends and ready says whether that time has been reached.
    """
    base = repo.merge_base('HEAD', upstream, all=True)[-1]

    # Upstream commits that already have a cherry-pick on this branch.
    picked = {
        origin.hexsha
        for mfc in repo.iter_commits(base.hexsha + '..HEAD')
        for origin in mfc_origins(repo, mfc)
    }

    now = datetime.datetime.now(datetime.timezone.utc)

    # Keyed by hash so that a later revert can remove an earlier entry; the
    # dict also preserves the order in which candidates were added.
    candidates = {}
    upstream_commits = list(repo.iter_commits(base.hexsha + '..' + upstream))
    for commit in reversed(upstream_commits):
        target = reverts(commit)
        if target and target in candidates:
            # The candidate was reverted; report neither it nor the revert.
            del candidates[target]
            continue

        if commit.hexsha in picked:
            # Already cherry-picked.
            continue

        delta = mfc_after(commit)
        if delta is None:
            # No usable "MFC after:" tag, so nothing was requested.
            continue

        if author is not None and author not in commit.author.email:
            # Belongs to somebody else.
            continue

        committed = datetime.datetime.fromtimestamp(
            commit.committed_date, tz=datetime.timezone.utc)
        ready_date = committed + delta
        ready = now >= ready_date
        if baking or ready:
            candidates[commit.hexsha] = (commit, ready_date, ready)

    return list(candidates.values())


def cherry_pick(commit, edit=False):
    """
    Run "git cherry-pick -x" on a single commit, adding -e when edit is
    true so the message can be edited.

    git's own output goes straight to the terminal.  If git exits with a
    non-zero status, print recovery instructions to stderr and exit with
    status 1.
    """
    cmd = ['git', 'cherry-pick', '-x'] + (['-e'] if edit else []) + [commit.hexsha]
    if subprocess.run(cmd, capture_output=False).returncode == 0:
        return

    advice = (
        f"\nCherry-pick of {commit.hexsha} ({commit.summary}) failed.",
        "Resolve the conflict and then re-run git mfc to continue,",
        "or run 'git cherry-pick --abort' to give up.",
    )
    for text in advice:
        print(text, file=sys.stderr)
    sys.exit(1)


def main():
    """
    Command-line entry point.

    Operates in one of three modes:

      --dangling   list cherry-picked commits on the current branch whose
                   upstream fixups have not been cherry-picked;
      --pending    list upstream commits with an "MFC after:" tag that have
                   not been cherry-picked;
      (default)    cherry-pick the given revisions, together with any
                   upstream fixups for them, onto the current branch.

    The upstream branch is "<remote>/<origin>".  Unless -n is given, the
    origin branch is fetched from the remote before anything else is done.
    Usage and environment errors are reported through err(), which exits.
    """
    parser = argparse.ArgumentParser(
        prog='git mfc',
        description='Cherry-pick commits from main to a stable branch, '
                    'automatically including fixup commits.',
    )
    parser.add_argument(
        '-e', '--edit', action='store_true',
        help='Edit the commit message before committing',
    )
    parser.add_argument(
        '-n', action='store_true',
        help='List the commits that would be cherry-picked, but do not act',
    )
    parser.add_argument(
        '-a', '--author', type=str, default=None,
        help='Filter --pending/--dangling results by author (default: current user)',
    )
    parser.add_argument(
        '--all', action='store_true',
        help='With --pending/--dangling, show commits from all authors',
    )
    parser.add_argument(
        '--baking', action='store_true',
        help='With --pending, also show commits whose MFC-after period '
             'has not yet elapsed',
    )
    parser.add_argument(
        '--dangling', action='store_true',
        help='Find cherry-picked commits in the current branch that are '
             'missing fixup commits from the origin branch',
    )
    parser.add_argument(
        '-f', '--force', action='store_true',
        help='Cherry-pick commits even if they appear to be already present',
    )
    parser.add_argument(
        '--ignore-reverts', action='store_true',
        help='Cherry-pick commits even if they were reverted upstream',
    )
    parser.add_argument(
        '--pending', action='store_true',
        help='Show upstream commits with MFC-after tags that are ready to '
             'be cherry-picked',
    )
    parser.add_argument(
        '--origin', type=str, default=None,
        help='Upstream branch (default: derived from sys/conf/newvers.sh)',
    )
    parser.add_argument(
        '-r', '--remote', type=str, default=None,
        help='Upstream remote (default: tracking remote of current branch)',
    )
    parser.add_argument(
        'commits', nargs='*', metavar='revision',
        help='Commits to MFC (hashes, ranges with "..", etc.)',
    )
    args = parser.parse_args()

    if args.dangling and args.pending:
        err(1, 'usage error: --dangling and --pending are mutually exclusive')
    if (args.dangling or args.pending) and args.commits:
        err(1, 'usage error: revisions cannot be specified with --dangling or --pending')
    if not args.dangling and not args.pending and not args.commits:
        err(1, 'usage error: at least one revision is required')

    origin = args.origin if args.origin else origin_branch()

    repo = git.Repo()
    if args.remote:
        remote = args.remote
    else:
        # Use the remote that the current branch tracks, falling back to
        # 'freebsd' when there is no tracking branch or the current branch
        # cannot be determined (presumably a detached HEAD -- confirm).
        try:
            tracking = repo.active_branch.tracking_branch()
            remote = tracking.remote_name if tracking else 'freebsd'
        except (TypeError, ValueError):
            remote = 'freebsd'

    upstream = remote + '/' + origin

    if not args.n:
        # Looking up a remote that does not exist raises IndexError; turn
        # that into a readable message instead of a traceback.
        try:
            upstream_remote = repo.remotes[remote]
        except IndexError:
            err(1, f"remote '{remote}' not found; use --remote to specify "
                   f"the upstream remote")
        upstream_remote.fetch(origin)

    if args.all:
        author = None
    elif args.author:
        author = args.author
    elif args.dangling or args.pending:
        # get_value() only falls back to its default when the default is
        # not None; with default=None a missing user.email raises instead
        # of returning None.  Ask for an empty string so that the check
        # below can actually fire.
        author = repo.config_reader().get_value('user', 'email', default='')
        if not author:
            err(1, "could not determine user email; use -a or --all")
    else:
        author = None

    if args.dangling:
        missing = dangling(repo, upstream, author=author)
        for fixup, commit in missing:
            print(f'{commit_summary(fixup)} fixes {commit_summary(commit)}')
    elif args.pending:
        results = pending(repo, upstream, author=author,
                          baking=args.baking)
        for commit, ready_date, ready in results:
            date_str = ready_date.strftime('%Y-%m-%d')
            if ready:
                status = f"ready since {date_str}"
            else:
                status = f"ready on {date_str}"
            print(f'{commit_summary(commit)} ({status})')
    else:
        tomfc, reverted = mfcclosure(repo, upstream, args.commits)

        # Refuse to MFC anything that was reverted upstream, unless told
        # otherwise.
        if not args.ignore_reverts:
            for commit in tomfc:
                if commit.hexsha in reverted:
                    err(1, f"{commit_summary(commit)} was reverted "
                        f"upstream; use --ignore-reverts to override")

        # Find commits already cherry-picked to the current branch so we
        # can skip them.  This lets the user re-run the same command after
        # resolving a conflict without re-applying earlier commits.  With
        # --force the set stays empty and nothing is skipped.
        already_picked = set()
        if not args.force:
            base = repo.merge_base('HEAD', upstream, all=True)[-1]
            for c in repo.iter_commits(base.hexsha + '..HEAD'):
                for ref in mfc_origins(repo, c):
                    already_picked.add(ref.hexsha)

        if args.n:
            # Dry run: just list what would be cherry-picked.
            for commit in tomfc:
                if commit.hexsha in already_picked:
                    continue
                print(commit.hexsha, commit.summary)
            return

        for commit in tomfc:
            if commit.hexsha in already_picked:
                print(f"Skipping {commit.hexsha[:12]} {commit.summary} "
                      f"(already cherry-picked)")
                continue
            print(f"Cherry-picking {commit.hexsha[:12]} {commit.summary}")
            cherry_pick(commit, edit=args.edit)


# Run the command-line interface only when executed as a script, not when
# the file is imported as a module.
if __name__ == '__main__':
    main()
