#!/usr/bin/python3
# -*- mode: Python; coding: utf-8 -*-

# Recombine /lib passwd and group files to /etc
#
# Copyright (C) 2016  Endless Mobile, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

# On OSTree systems, we keep the system users in /lib/passwd and
# /lib/group so they can be maintained via ostree upgrades. However, on
# converted systems, Debian maintainer scripts try to update the passwd
# databases via the shadow utilities like useradd and usermod. These
# programs have no idea about /lib/passwd and /lib/group, so we want to
# merge the files back together when converting.

from argparse import ArgumentParser
import os
import shutil
import subprocess
import sys

# Module-wide switch consulted by debug() on every call
DEBUG = False
def debug(*args, **kwargs):
    """Forward the arguments to print() only while DEBUG is set"""
    if not DEBUG:
        return
    print(*args, **kwargs)

def parse_passwd_entry(line):
    """Turn one colon-separated passwd line into a dict of its fields

    The returned dict maps name, password, uid, gid, gecos, homedir and
    shell to the corresponding string values. An Exception is raised
    when the line does not contain exactly that many fields.
    """
    names = ('name', 'password', 'uid', 'gid', 'gecos', 'homedir', 'shell')
    parts = line.split(':')
    if len(parts) == len(names):
        return {key: value for key, value in zip(names, parts)}
    raise Exception('There are not %d fields in passwd entry "%s"'
                    %(len(names), line))

def parse_group_entry(line):
    """Turn one colon-separated group line into a dict of its fields

    The returned dict maps name, password, gid and users to the
    corresponding string values. An Exception is raised when the line
    does not contain exactly that many fields.
    """
    names = ('name', 'password', 'gid', 'users')
    parts = line.split(':')
    if len(parts) == len(names):
        return {key: value for key, value in zip(names, parts)}
    raise Exception('There are not %d fields in group entry "%s"'
                    %(len(names), line))

def merge_files(root, type, dry_run=False, keep=False):
    """Merge /lib account file into /etc

    Entries in <root>/lib/<type> whose ID and name are both absent from
    <root>/etc/<type> are appended to the /etc file. Entries that are
    already present are skipped. Entries whose ID or name is present
    but paired with a different name or ID are reported as conflicts
    and are not merged.

    Arguments:
    root    -- directory containing the etc and lib directories
    type    -- which database to merge, either 'passwd' or 'group'
    dry_run -- only print the entries that would be added
    keep    -- do not delete the /lib file after a clean merge

    Nothing is returned. An Exception is raised for an unrecognized
    type, for a missing /etc file and for malformed entries. A failing
    pwconv or grpconv run raises subprocess.CalledProcessError.
    """
    # Make sure type is valid. This has to stay ahead of everything
    # else so that bad input is rejected before any file is touched.
    if type not in ['passwd', 'group']:
        raise Exception('Unrecognized file type %s' % type)

    # Handle passwd or group files
    etc_file = os.path.join(root, 'etc', type)
    lib_file = os.path.join(root, 'lib', type)
    if type == 'passwd':
        id_field = 'uid'
        parser = parse_passwd_entry
        converter = 'pwconv'
    else:
        id_field = 'gid'
        parser = parse_group_entry
        converter = 'grpconv'

    # Sanity checks
    if not os.path.exists(etc_file):
        raise Exception('%s not found' % etc_file)
    if not os.path.exists(lib_file):
        print('No', lib_file, 'present, skipping')
        return

    # Read out the current etc file into maps by id and by name. Also
    # remember whether the last line was newline terminated so that
    # appended entries cannot be glued onto it.
    etc_file_ids = {}
    etc_file_names = {}
    etc_ends_with_newline = True
    with open(etc_file) as f:
        for raw_line in f:
            etc_ends_with_newline = raw_line.endswith('\n')
            line = raw_line.strip()
            if not line:
                # Blank lines carry no entry
                continue
            entry = parser(line)
            etc_file_ids[entry[id_field]] = entry
            etc_file_names[entry['name']] = entry

    debug(etc_file, 'ids:')
    debug('\n'.join(map(str, etc_file_ids.items())))
    debug(etc_file, 'names:')
    debug('\n'.join(map(str, etc_file_names.items())))

    # List of non-blank lines from the /lib file
    lib_file_lines = []
    with open(lib_file) as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                lib_file_lines.append(line)

    debug(lib_file, 'entries:')
    debug('\n'.join(lib_file_lines))

    # Find which entries need merging
    merge_entries = []
    bad_entries = []
    for line in lib_file_lines:
        entry = parser(line)
        entry_id = entry[id_field]
        name = entry['name']
        if entry_id in etc_file_ids:
            # This id already exists in /etc. Make sure the name matches.
            etc_entry = etc_file_ids[entry_id]
            if etc_entry['name'] != name:
                bad_entries.append(line)
                print('Warning: ID number', entry_id, 'in', etc_file,
                      'with different name')
            continue
        if name in etc_file_names:
            # This name already exists in /etc. Make sure the id matches.
            # The lookup must go through the name map; the id map is
            # keyed by id and would raise KeyError here.
            etc_entry = etc_file_names[name]
            if etc_entry[id_field] != entry_id:
                bad_entries.append(line)
                print('Warning: name', name, 'in', etc_file,
                      'with different ID number')
            continue
        merge_entries.append(line)

    debug(lib_file, 'entries to merge:')
    debug('\n'.join(map(str, merge_entries)))
    debug(lib_file, 'bad entries:')
    debug('\n'.join(map(str, bad_entries)))

    if not merge_entries:
        debug('No entries to merge')
        return

    if dry_run:
        print('Dry run. Would add the following entries to %s:'
              % etc_file)
        print('\n'.join(merge_entries))
        return

    # Write out the new entries in /etc. This is mimicking shadow, which
    # will make a backup with a - suffix, write the new file with a +
    # suffix, then move the new file in place.
    etc_backup = etc_file + '-'
    etc_new = etc_file + '+'
    debug('Backing up', etc_file, 'to', etc_backup)
    shutil.copy2(etc_file, etc_backup)
    debug('Copying', etc_file, 'to', etc_new)
    shutil.copy2(etc_file, etc_new)

    # Append the entries for merging
    debug('Merging new entries into', etc_new)
    with open(etc_new, 'a') as f:
        if not etc_ends_with_newline:
            # Terminate the existing last line first
            f.write('\n')
        for line in merge_entries:
            f.write(line + '\n')

    # Rename into place
    debug('Renaming', etc_new, 'to', etc_file)
    os.rename(etc_new, etc_file)

    # If there were no bad entries, remove the /lib file
    if bad_entries:
        print('Keeping', lib_file, 'since there were conflicts with',
              etc_file)
    elif keep:
        print('Keeping', lib_file, 'per user request')
    else:
        debug('Removing', lib_file)
        os.unlink(lib_file)

    # Regenerate entries in the corresponding shadow files
    debug('Running', converter)
    subprocess.check_call([converter, '-R', os.path.abspath(root)])

def check_file(root, type):
    """Check /etc account file valid after merging

    Runs pwck (for 'passwd') or grpck (for 'group') quietly against the
    given root directory.

    Arguments:
    root -- directory handed to the checker through its -R option
    type -- which database to check, either 'passwd' or 'group'

    Nothing is returned. An Exception is raised for an unrecognized
    type, and subprocess.CalledProcessError is raised when the checker
    exits with a non-zero status.
    """
    # Make sure type is valid
    if type not in ['passwd', 'group']:
        raise Exception('Unrecognized file type %s' % type)

    # Handle passwd or group files
    if type == 'passwd':
        checker = 'pwck'
    else:
        checker = 'grpck'

    debug('Running', checker)
    subprocess.check_call([checker, '-q', '-R', os.path.abspath(root)])

# Command line interface
argparser = ArgumentParser(
    description='Merge /lib/passwd entries into /etc/passwd')
argparser.add_argument(
    '-r', '--root', default='/',
    help='alternate root directory')
argparser.add_argument(
    '-n', '--dry-run', action='store_true',
    help='just show what would happen')
argparser.add_argument(
    '-k', '--keep', action='store_true',
    help='keep /lib files even if not needed')
argparser.add_argument(
    '-c', '--check', action='store_true',
    help='check /etc files after merging')
argparser.add_argument(
    '--debug', action='store_true',
    help='enable debugging messages')
args = argparser.parse_args()
DEBUG = args.debug

# Merge both databases first, then optionally validate both of them
for _account_type in ('passwd', 'group'):
    merge_files(args.root, _account_type, args.dry_run, args.keep)
if args.check:
    for _account_type in ('passwd', 'group'):
        check_file(args.root, _account_type)
