/* SPDX-License-Identifier: MIT */
/* SPDX-FileCopyrightText: (c) Copyright 2024 Andrew Bower <andrew@bower.uk> */

#include <assert.h>
#include <ctype.h>
#include <getopt.h>
#include <errno.h>
#include <grp.h>
#include <pwd.h>
#include <sched.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>

#include "usrgrp.h"

/* Parse a "user[:group[:supplemental...]]" specification into ug.
 *
 * A leading ':' marks every field as a numeric ID (TOK_ID); otherwise the
 * fields are names (TOK_NAME). Empty fields are recorded as TOK_NONE. No
 * database lookup happens here; resolution is done by usrgrp_resolve().
 *
 * Each entry's tok points into a private copy of arg stored in ug->buf_tok,
 * so that buffer must outlive the entries (usrgrp_free() releases it).
 * ug->group and ug->supplemental are only written when the corresponding
 * fields are present in arg; callers presumably zero-initialise ug first
 * (NOTE(review): confirm against callers).
 *
 * Returns 0 on success, -1 on allocation failure. If the supplemental array
 * cannot be allocated, buf_tok is released and reset to NULL and
 * num_supplemental is reset to 0, so a later usrgrp_free() stays safe.
 */
int usrgrp_parse(struct users_groups *ug, const char *arg) {
  enum tok_type t = TOK_NAME;
  char *copy;
  char *scan;
  ssize_t n_toks;
  ssize_t i;

  copy = strdup(arg);
  if (copy == NULL)
    return -1;

  ug->buf_tok = copy;
  if (*copy == ':') {
    t = TOK_ID;
    copy++;
  }

  /* Split in place: strsep() overwrites each ':' with '\0', so afterwards the
   * fields lie back to back in the buffer and can be walked with strlen(). */
  for (n_toks = 0, scan = copy; strsep(&scan, ":"); n_toks++);
  if (n_toks > 2) {
    ug->num_supplemental = n_toks - 2;
    ug->supplemental = malloc(ug->num_supplemental * sizeof *ug->supplemental);
    if (ug->supplemental == NULL) {
      /* Leave no dangling buffer and no count without an array behind:
       * otherwise usrgrp_free() would free buf_tok a second time. */
      free(ug->buf_tok);
      ug->buf_tok = NULL;
      ug->num_supplemental = 0;
      return -1;
    }
  }

  scan = copy;
  for (i = 0; i < n_toks; i++) {
    size_t len = strlen(scan);
    struct sys_entry entry = {
      .tok = scan,
      .tok_type = len ? t : TOK_NONE,
    };

    /* Field 0 is the user, field 1 the primary group, the rest are
     * supplemental groups. */
    if (i == 0)
      ug->user = entry;
    else if (i == 1)
      ug->group = entry;
    else
      ug->supplemental[i - 2] = entry;
    scan += len + 1;
  }

  return 0;
}

/* Record a numeric uid as the resolved target user.
 *
 * The user entry becomes a resolved TOK_ID entry for nid with no primary
 * group ((gid_t) -1), and username/home/shell are cleared to NULL. When
 * lookup is true, the passwd database is also consulted via getpwuid(). If
 * an entry exists, its primary gid and strdup'd copies of the name, home
 * directory and shell are stored.
 *
 * chpst itself does not look numeric IDs up; the lookup path is kept so that
 * build regressions in it are caught.
 */
void usrgrp_resolve_uid(struct users_groups *ug, uid_t nid, bool lookup) {
  struct sys_entry *user = &ug->user;
  struct passwd *pw;

  user->tok_type = TOK_ID;
  user->resolved = true;
  user->uid = nid;
  user->user_gid = (gid_t) -1;
  ug->username = NULL;
  ug->home = NULL;
  ug->shell = NULL;

  if (lookup && (pw = getpwuid(nid)) != NULL) {
    user->user_gid = pw->pw_gid;
    ug->username = strdup(pw->pw_name);
    ug->home = strdup(pw->pw_dir);
    ug->shell = strdup(pw->pw_shell);
  }
}

/* Resolve ug->user according to its token type.
 *
 * TOK_NONE: the user is marked unresolved with uid/user_gid = -1.
 * TOK_NAME: the name is looked up with getpwnam(). On success the uid and
 *           primary gid are recorded, along with strdup'd copies of the
 *           name, home directory and shell.
 * TOK_ID:   the token must be a plain non-negative decimal uid that fits in
 *           uid_t and is not the (uid_t) -1 sentinel. It is recorded via
 *           usrgrp_resolve_uid() without a passwd lookup.
 *
 * Returns 0 on success or when no user was requested. Otherwise it returns
 * an errno value: the one reported by getpwnam(), ENOENT for an unknown
 * user, or EINVAL for a malformed numeric id. A diagnostic is written to
 * stderr on failure.
 */
static int resolve_user(struct users_groups *ug) {
  struct sys_entry *entry = &ug->user;
  struct passwd *password;
  char *end;
  int rc = 0;
  long nid;

  switch (entry->tok_type) {
  case TOK_NONE:
    entry->uid = (uid_t) -1;
    entry->user_gid = (gid_t) -1;
    entry->resolved = false;
    break;
  case TOK_NAME:
    errno = 0;
    password = getpwnam(entry->tok);
    if (password) {
      entry->uid = password->pw_uid;
      entry->user_gid = password->pw_gid;
      entry->resolved = true;
      ug->username = strdup(password->pw_name);
      ug->home = strdup(password->pw_dir);
      ug->shell = strdup(password->pw_shell);
    } else {
      /* getpwnam() leaves errno at 0 when the name simply doesn't exist. */
      if (errno != 0) {
        rc = errno;
        fprintf(stderr, "getpwnam(\"%s\"): %s\n", entry->tok, strerror(rc));
      } else {
        rc = ENOENT;
        fprintf(stderr, "no such user: %s\n", entry->tok);
      }
      entry->uid = (uid_t) -1;
      entry->user_gid = (gid_t) -1;
      entry->resolved = false;
      ug->username = ug->home = ug->shell = NULL;
    }
    break;
  case TOK_ID:
    /* Strict parse: reject empty input, trailing garbage, overflow, negative
     * values, values that don't fit uid_t and the (uid_t) -1 "unset"
     * sentinel. A bad id must be an error, not a silently skipped user. */
    errno = 0;
    nid = strtol(entry->tok, &end, 10);
    if (end != entry->tok && *end == '\0' && errno == 0 &&
        nid >= 0 && (long) (uid_t) nid == nid &&
        (uid_t) nid != (uid_t) -1) {
      usrgrp_resolve_uid(ug, nid, false);
    } else {
      rc = EINVAL;
      fprintf(stderr, "invalid user id: %s\n", entry->tok);
      entry->uid = (uid_t) -1;
      entry->user_gid = (gid_t) -1;
      entry->resolved = false;
    }
    break;
  }

  return rc;
}

/* Resolve a single group entry according to its token type.
 *
 * TOK_NONE: the entry is marked unresolved with gid = -1.
 * TOK_NAME: the name is looked up with getgrnam() and its gid recorded.
 * TOK_ID:   the token must be a plain non-negative decimal gid that fits in
 *           gid_t and is not the (gid_t) -1 sentinel.
 *
 * Returns 0 on success or when no group was requested. Otherwise it returns
 * an errno value: the one reported by getgrnam(), ENOENT for an unknown
 * group, or EINVAL for a malformed numeric id. A diagnostic is written to
 * stderr on failure.
 */
static int resolve_group(struct sys_entry *entry) {
  struct group *group;
  char *end;
  int rc = 0;
  long nid;

  switch (entry->tok_type) {
  case TOK_NONE:
    entry->gid = (gid_t) -1;
    entry->resolved = false;
    break;
  case TOK_NAME:
    errno = 0;
    group = getgrnam(entry->tok);
    if (group) {
      entry->gid = group->gr_gid;
      entry->resolved = true;
    } else {
      /* getgrnam() leaves errno at 0 when the name simply doesn't exist. */
      if (errno != 0) {
        rc = errno;
        fprintf(stderr, "getgrnam(\"%s\"): %s\n", entry->tok, strerror(rc));
      } else {
        rc = ENOENT;
        fprintf(stderr, "no such group: %s\n", entry->tok);
      }
      entry->gid = (gid_t) -1;
      entry->resolved = false;
    }
    break;
  case TOK_ID:
    /* Strict parse: reject empty input, trailing garbage, overflow, negative
     * values, values that don't fit gid_t and the (gid_t) -1 "unset"
     * sentinel. A bad id must be an error, not a silently skipped group. */
    errno = 0;
    nid = strtol(entry->tok, &end, 10);
    if (end != entry->tok && *end == '\0' && errno == 0 &&
        nid >= 0 && (long) (gid_t) nid == nid &&
        (gid_t) nid != (gid_t) -1) {
      entry->gid = nid;
      entry->resolved = true;
    } else {
      rc = EINVAL;
      fprintf(stderr, "invalid group id: %s\n", entry->tok);
      entry->gid = (gid_t) -1;
      entry->resolved = false;
    }
    break;
  }

  return rc;
}

/* Map a token type to a short static label for diagnostic output:
 * "NONE", "NAME" or "ID", and "?" for any unrecognised value. */
static const char *tok_type_name(enum tok_type tok_type) {
  if (tok_type == TOK_NONE)
    return "NONE";
  if (tok_type == TOK_NAME)
    return "NAME";
  if (tok_type == TOK_ID)
    return "ID";
  return "?";
}

/* Print one user entry to out as "tok:uid:user_gid:TYPE:RESOLVED\n".
 *
 * A missing token prints as an empty string. The ids are cast to int so the
 * %d conversions match their arguments. uid_t/gid_t are unsigned on common
 * platforms, so the (uid_t) -1 / (gid_t) -1 "unset" sentinel still shows
 * up as -1 rather than invoking a format/argument mismatch.
 */
static void print_user(FILE *out, struct sys_entry *entry) {
  fprintf(out, "%s:%d:%d:%s:%s\n",
          entry->tok ? entry->tok : "",
          (int) entry->uid, (int) entry->user_gid,
          tok_type_name(entry->tok_type),
          entry->resolved ? "RESOLVED" : "");
}

/* Print one group entry to out as "tok:gid:TYPE:RESOLVED\n".
 *
 * A missing token prints as an empty string. The gid is cast to int so the
 * %d conversion matches its argument. gid_t is unsigned on common platforms,
 * so the (gid_t) -1 "unset" sentinel still shows up as -1.
 */
static void print_group(FILE *out, struct sys_entry *entry) {
  fprintf(out, "%s:%d:%s:%s\n",
          entry->tok ? entry->tok : "",
          (int) entry->gid,
          tok_type_name(entry->tok_type),
          entry->resolved ? "RESOLVED" : "");
}

/* Dump ug to out for debugging. The output starts with a "what:" heading,
 * followed by one indented line each for the user, the primary group and
 * every supplemental group. */
void usrgrp_print(FILE *out, const char *what, struct users_groups *ug) {
  int idx = 0;

  fprintf(out, "%s:\n  user: ", what);
  print_user(out, &ug->user);

  fprintf(out, "  group: ");
  print_group(out, &ug->group);

  while (idx < ug->num_supplemental) {
    fprintf(out, "  supplemental: ");
    print_group(out, &ug->supplemental[idx]);
    idx++;
  }
}

/* Resolve every entry of a users_groups specification.
 *
 * Resolves the user, the primary group and each explicit supplemental group,
 * counting every entry that fails. If no group was requested and the user's
 * primary gid is known (it was looked up by name), that gid becomes the
 * group. In that case, if a username is available and no supplemental
 * groups were given, the user's memberships from getgrouplist() become the
 * supplemental list; the primary group itself is excluded from that list.
 *
 * Returns the number of errors (0 on full success). A getgrouplist() or
 * allocation failure while building the default supplemental list adds one
 * more error.
 */
int usrgrp_resolve(struct users_groups *ug) {
  int errors = 0;
  int i;

  if (resolve_user(ug))
    errors++;
  if (resolve_group(&ug->group))
    errors++;
  for (i = 0; i < ug->num_supplemental; i++) {
    if (resolve_group(&ug->supplemental[i]))
      errors++;
  }

  /* Use user's group if another one wasn't requested. */
  if (ug->group.tok_type == TOK_NONE &&
      ug->user.resolved == true &&
      ug->user.user_gid != (gid_t) -1) {
    ug->group.tok_type = TOK_ID;
    ug->group.gid = ug->user.user_gid;
    ug->group.resolved = true;

    /* And add its supplemental groups */
    if (ug->username && ug->num_supplemental == 0) {
      int n_groups = 0;
      gid_t *groups;

      /* Size probe with zero capacity. The call reports "too small" (-1),
       * and on glibc it stores the required number of groups in n_groups. */
      getgrouplist(ug->username, ug->group.gid, NULL, &n_groups);
      if (n_groups <= 0)
        return errors + 1;
      groups = malloc(n_groups * sizeof *groups);
      if (groups == NULL)
        return errors + 1;
      if (getgrouplist(ug->username, ug->group.gid, groups, &n_groups) == -1) {
        free(groups);
        return errors + 1;
      }

      /* Take the count from n_groups, not from the return value: glibc
       * returns the count, but BSD-derived libcs return 0 on success. */
      ug->supplemental = malloc(n_groups * sizeof ug->supplemental[0]);
      if (ug->supplemental == NULL) {
        free(groups);
        return errors + 1;
      }

      for (i = 0; i < n_groups; i++) {
        if (groups[i] == ug->group.gid)
          continue;
        ug->supplemental[ug->num_supplemental] = (struct sys_entry) {
          .tok = NULL,
          .gid = groups[i],
          .tok_type = TOK_ID,
          .resolved = true,
        };
        ug->num_supplemental++;
      }
      free(groups);
    }
  }

  return errors;
}

/* Release the heap storage owned by ug: the strdup'd username, home and
 * shell, the token buffer and the supplemental array.
 *
 * The freed pointers are reset to NULL and num_supplemental to 0, so a
 * repeated call (or a later print/resolve of the emptied struct) cannot
 * double-free or walk a freed array. The tok pointers inside the user and
 * group entries point into the freed token buffer and must not be used
 * afterwards.
 */
void usrgrp_free(struct users_groups *ug) {
  /* free(NULL) is a no-op, so no guards are needed. */
  free(ug->home);
  free(ug->shell);
  free(ug->username);
  free(ug->buf_tok);
  free(ug->supplemental);

  ug->home = ug->shell = ug->username = NULL;
  ug->buf_tok = NULL;
  ug->supplemental = NULL;
  ug->num_supplemental = 0;
}

int usrgrp_from_env(struct users_groups *ug,
                    const char *uid, const char *gid, const char *gidlist) {
  char *scan, *copy;
  ssize_t n_toks, toks;
  int i;
  long nid;
  struct sys_entry *entry;

  if (uid && *uid) {
    entry = &ug->user;
    toks = sscanf(uid, "%ld", &nid);
    if (toks != 1)
      goto fail;
    entry->tok_type = TOK_ID;
    entry->uid = nid;
    entry->resolved = true;
    entry->user_gid = (gid_t) -1;
    ug->username = ug->home = ug->shell = NULL;
  }
  if (gid && *gid) {
    entry = &ug->group;
    toks = sscanf(gid, "%ld", &nid);
    if (toks != 1)
      goto fail;
    entry->tok_type = TOK_ID;
    entry->gid = nid;
    entry->resolved = true;
  }
  if (gidlist && *gidlist) {
    if ((copy = strdup(gidlist)) == NULL)
      goto fail;
    for (n_toks = 0, scan = copy; strsep(&scan, ","); n_toks++);
    ug->num_supplemental = n_toks;
    ug->supplemental = malloc(ug->num_supplemental * sizeof(struct sys_entry));
    if (ug->supplemental == NULL) {
      goto fails;
    }
    scan = copy;
    for(i = 0; i < n_toks; i++) {
      entry = ug->supplemental + i;
      toks = sscanf(scan, "%ld", &nid);
      if (toks != 1)
        goto fails;
      entry->tok_type = TOK_ID;
      entry->gid = nid;
      entry->resolved = true;
      scan += strlen(scan) + 1;
    }
  fails:
    free(copy);
  }
  return 0;

fail:
  return -1;
}
