"""
Py-opredflag.
Copyright (C) 2023 BobDotCom
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import asyncio
import copy
import json
import os
import sys
from datetime import datetime, timezone
from typing import Any, Literal, cast
import aiohttp
import semver
from async_lru import alru_cache
from .enums import Compatibility, VersionComparison
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
else:
from typing import NotRequired, TypedDict
# Only ``Updater`` is exported by ``from ... import *``; the helpers and
# TypedDicts in this module are not part of the star-import surface.
__all__ = ("Updater",)
class FileVersion(TypedDict):
    """An oprf version file version entry.

    Describes one tracked file, both in the local version file and in the
    remote ``versions.json``.
    """

    # Semantic version string of the file, or ``None`` when no version has
    # been recorded (compared as "unknown", which always triggers a fetch).
    version: str | None
    # Path of the file: joined onto the updater's ``directory`` locally, or
    # appended to the repository/branch URL remotely.
    path: str
class VersionDataFile(TypedDict):
    """A version data file.

    In-memory form of the local version file, with ``last_updated`` parsed
    into a :class:`~datetime.datetime`.
    """

    # Format version of the file itself; only version 1 is accepted.
    version: int
    # Timestamp the updater refreshes when entries change, or periodically
    # according to the ``update_timestamp_after`` setting.
    last_updated: datetime
    # Maps each key to a single file entry, or to a list of entries when the
    # key covers several files.
    data: dict[str, FileVersion | list[FileVersion]]
class VersionDataFileJson(TypedDict):
    """The json serializable version of a VersionDataFile type."""

    # Format version of the file itself.
    version: int
    # ISO 8601 timestamp string (parsed with ``datetime.fromisoformat``).
    last_updated: str
    # Same shape as the ``data`` entry of a VersionDataFile.
    data: dict[str, FileVersion | list[FileVersion]]
class UpdaterData(TypedDict):
    """Output data entry from the updater.

    One record per checked file, collected under "fetched", "skipped" or
    "up-to-date" and rendered by ``format_data``.
    """

    # Key of the entry in the version file.
    key: str
    # Local path of the file this entry refers to.
    path: str
    # Locally recorded version, or the string "null" when none was recorded.
    old_version: str
    # Remote version ("null" if unset); omitted for up-to-date entries.
    new_version: NotRequired[str]
    # Why the file was skipped; only present on skipped entries.
    reason: NotRequired[str]
    # True when the key maps to several files, so the path is shown in output.
    multi_key: bool
def load_version_data_file(file_path: str) -> VersionDataFile:
    """Load a version data file from a path.

    The raw JSON is first migrated and validated by
    :func:`update_version_data_file`; ``last_updated`` is then parsed into a
    :class:`~datetime.datetime`.

    Parameters
    ----------
    file_path:
        Path of the JSON version file to read.

    Returns
    -------
    VersionDataFile
        The parsed version data. ``last_updated`` is always timezone-aware:
        a timestamp stored without a UTC offset is interpreted as UTC.

    Raises
    ------
    ValueError, TypeError
        If the file content is not valid JSON or fails validation.
    """
    with open(file_path, encoding="utf-8") as f_obj:
        version_data = update_version_data_file(json.load(f_obj))
    last_updated = datetime.fromisoformat(version_data["last_updated"])
    if last_updated.tzinfo is None:
        # The updater subtracts this value from ``datetime.now(timezone.utc)``;
        # mixing naive and aware datetimes raises TypeError, so a hand-written
        # timestamp without an offset is treated as UTC.
        last_updated = last_updated.replace(tzinfo=timezone.utc)
    new_version_data: VersionDataFile = {
        **version_data,
        "last_updated": last_updated,
    }
    return new_version_data
def dump_version_data_file(version_data: VersionDataFile, file_path: str) -> None:
    """Write version data to ``file_path`` as JSON indented by two spaces.

    ``last_updated`` is serialized as an ISO 8601 string; every other entry is
    written as-is. The ``version_data`` mapping passed in is left untouched.
    """
    serializable = dict(version_data)
    serializable["last_updated"] = version_data["last_updated"].isoformat()
    with open(file_path, "w", encoding="utf-8") as out:
        json.dump(serializable, out, indent=2)
def update_version_data_file(version_data: dict[Any, Any]) -> VersionDataFileJson:
    """Migrate raw version-file JSON to the current format and validate it.

    Migrations are chained: each step upgrades the data by exactly one format
    version and re-enters this function, so a file several versions behind is
    upgraded one step at a time until it is current.

    Parameters
    ----------
    version_data:
        The decoded JSON content of a version file.

    Returns
    -------
    VersionDataFileJson
        The validated data in the current (version 1) format.

    Raises
    ------
    ValueError
        If the format version is unsupported or ``last_updated`` is not an
        ISO 8601 string.
    TypeError
        If ``data`` is not a dict.
    """
    file_format = version_data.get("version")

    if file_format is None:
        # Unversioned ("version 0") files are a bare mapping of entries: wrap
        # them in the version 1 envelope, stamp the epoch as the last update
        # time, and run them through validation.
        epoch = datetime.fromtimestamp(0, timezone.utc)
        return update_version_data_file(
            {
                "version": 1,
                "last_updated": epoch.isoformat(),
                "data": version_data,
            }
        )

    if file_format != 1:
        raise ValueError(
            f"Expected file version 1, got {file_format!r}. Please update your version file."
        )

    # Current format: nothing to migrate, only sanity-check the fields.
    timestamp = version_data.get("last_updated")
    try:
        # Non-strings raise TypeError, malformed strings raise ValueError.
        datetime.fromisoformat(timestamp)  # type: ignore
    except (ValueError, TypeError) as e:
        raise ValueError(
            "Couldn't parse version file 'last_updated' value. "
            f"Expected an ISO 8601 string, but got: {timestamp!r}"
        ) from e

    payload = version_data.get("data")
    if not isinstance(payload, dict):
        raise TypeError(
            "Couldn't parse version file 'data' value. "
            f"Expected a dict, got {type(payload).__name__}: {payload!r}"
        )

    # The format version was checked above, so the shape now matches.
    return cast(VersionDataFileJson, version_data)
def format_data(data: UpdaterData) -> str:
    """Format :class:`UpdaterData` into a human-readable string.

    The result looks like ``key old[->new][ (reason)][ (path)]``. The path is
    only shown for keys that map to several files.

    Values are inserted verbatim; they are never interpreted as a format
    template, so keys or paths containing ``{`` / ``}`` are safe.

    Parameters
    ----------
    data:
        Data to format

    Returns
    -------
    str
        The formatted data
    """
    text = f"{data['key']} {data['old_version']}"
    if "new_version" in data:
        text += f"->{data['new_version']}"
    if "reason" in data:
        text += f" ({data['reason']})"
    if data["multi_key"]:
        text += f" ({data['path']})"
    return text
def compare_versions(first: str | None, second: str | None) -> VersionComparison:
    """Compare two semantic version strings.

    Reports how ``first`` relates to ``second``, including which component
    (major, minor or patch) is the most significant one that differs.

    Parameters
    ----------
    first:
        The first version
    second:
        The second version

    Returns
    -------
    :class:`~.VersionComparison`
        The comparison result: ``UNKNOWN`` if either version is ``None``,
        ``EQUAL`` if they compare equal, ``NEWER_*``/``OLDER_*`` for the most
        significant differing component, or plain ``NEWER``/``OLDER`` when
        only the pre-release part differs.
    """
    if first is None or second is None:
        return VersionComparison.UNKNOWN

    left = semver.Version.parse(first)
    right = semver.Version.parse(second)
    if left == right:
        return VersionComparison.EQUAL

    is_newer = left > right
    higher, lower = (left, right) if is_newer else (right, left)
    # Walk components from most to least significant; the first one where
    # the higher version is ahead decides how big the jump is.
    for component, newer_result, older_result in (
        ("major", VersionComparison.NEWER_MAJOR, VersionComparison.OLDER_MAJOR),
        ("minor", VersionComparison.NEWER_MINOR, VersionComparison.OLDER_MINOR),
        ("patch", VersionComparison.NEWER_PATCH, VersionComparison.OLDER_PATCH),
    ):
        if getattr(higher, component) > getattr(lower, component):
            return newer_result if is_newer else older_result
    # Same major.minor.patch: only the pre-release part differs.
    return VersionComparison.NEWER if is_newer else VersionComparison.OLDER
[docs]
class Updater:
"""The base updater class which handles the whole process.
Parameters
----------
directory:
Local root directory
version_json:
Location of local versions.json file
repository:
Location of OpRedFlag asset GitHub repository, in User/Repo format
branch:
The branch of the repository to use
include:
Files to update, separated by commas
exclude:
Files to skip, separated by commas
compatibility: :class:`~.Compatibility`
Compatibility level, will only allow updates of this level or lower
strict:
Fail if local file versions are newer than remote
update_timestamp_after:
Automatically update the `last_updated` key in the version file after this amount of days,
even when no changes have been made. Set to 0 for always or -1 for never.
"""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
directory: str = ".",
version_json: str = "oprf-versions.json",
repository: str = "Op-Redflag/OpRedFlag",
branch: str = "master",
include: str = "*",
exclude: str = "",
compatibility: Compatibility = Compatibility.MINOR,
strict: bool = False,
update_timestamp_after: int = 30,
):
"""Initialize the updater."""
# pylint: disable=too-many-arguments,too-many-positional-arguments
self.directory = directory
self.version_json = os.path.join(directory, version_json)
self.repository = repository
self.branch = branch
self.include = include
self.exclude = exclude
self.compatibility = compatibility
self.strict = strict
self.update_timestamp_after = update_timestamp_after
self.session: aiohttp.ClientSession | None = None
self._remote_version_data: dict[str, FileVersion] | None = None
self.data: dict[
Literal["fetched", "skipped", "up-to-date"], list[UpdaterData]
] = {
"fetched": [],
"skipped": [],
"up-to-date": [],
}
self.pending_write_files: dict[str, str] = {}
self.local_version_data = load_version_data_file(self.version_json)
self.original_version_data = copy.deepcopy(self.local_version_data)
def write_files(self) -> None:
""":meta private: Write all scheduled files."""
for filepath, value in self.pending_write_files.items():
with open(filepath, "w", encoding="utf-8") as f_obj:
f_obj.write(value)
@property
def remote_version_data(self) -> dict[str, FileVersion]:
""":meta private: Version data fetched from the remote repository, raises if unset."""
if self._remote_version_data is None:
raise RuntimeError("Version unset")
return self._remote_version_data
@remote_version_data.setter
def remote_version_data(self, value: dict[str, FileVersion]) -> None:
self._remote_version_data = value
def build_remote_url(self, path: str) -> str:
""":meta private: Build a URL to fetch a file from."""
return (
f"https://raw.githubusercontent.com/{self.repository}/{self.branch}/{path}"
)
def save_version_data(self) -> None:
""":meta private: Save the local versions file."""
# If we're actually updating, we need to update the timestamp.
if self.original_version_data != self.local_version_data:
self.local_version_data["last_updated"] = datetime.now(timezone.utc)
# Also update the timestamp if it's been long enough
if self.update_timestamp_after >= 0:
time_since_update = (
datetime.now(timezone.utc) - self.local_version_data["last_updated"]
)
if time_since_update.days >= self.update_timestamp_after:
self.local_version_data["last_updated"] = datetime.now(timezone.utc)
dump_version_data_file(self.local_version_data, self.version_json)
@alru_cache(ttl=30, typed=True)
async def fetch_remote_version_data(self) -> None:
""":meta private: Fetch remote version data."""
if self.session is None:
raise RuntimeError("Session unset")
async with self.session.get(
self.build_remote_url("versions.json"),
) as response:
self.remote_version_data = await response.json(content_type="text/plain")
@alru_cache(ttl=30, typed=True)
async def fetch_file(self, path: str) -> str:
""":meta private: Fetch a file."""
if self.session is None:
raise RuntimeError("Session unset")
async with self.session.get(
self.build_remote_url(path),
) as response:
response.raise_for_status()
return await response.text()
async def update_file(
self, key: str, data: FileVersion, multi_key: bool = False
) -> None:
""":meta private: Update a file by key.
Parameters
----------
key:
The oprf-versions.json key for this file
data:
The local data we have saved for this file
multi_key:
If there are multiple files for this key, specifies path in script output
"""
async def fetch_data() -> None:
self.pending_write_files[os.path.join(self.directory, data["path"])] = (
await self.fetch_file(self.remote_version_data[key]["path"])
)
self.data["fetched"].append(
UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
new_version=self.remote_version_data[key]["version"] or "null",
multi_key=multi_key,
)
)
data["version"] = self.remote_version_data[key]["version"]
match compare_versions(
self.remote_version_data[key]["version"], data["version"]
):
case VersionComparison.NEWER_MAJOR:
# Remote is a major version bump ahead of us
if self.compatibility == Compatibility.MAJOR:
await fetch_data()
else:
self.data["skipped"].append(
UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
new_version=self.remote_version_data[key]["version"]
or "null",
reason="Major version newer than local",
multi_key=multi_key,
)
)
case VersionComparison.NEWER_MINOR:
# Remote is a minor version bump ahead of us
if self.compatibility in (Compatibility.MAJOR, Compatibility.MINOR):
await fetch_data()
else:
self.data["skipped"].append(
UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
new_version=self.remote_version_data[key]["version"]
or "null",
reason="Minor version newer than local",
multi_key=multi_key,
)
)
case VersionComparison.NEWER_PATCH:
# Remote is a patch version bump ahead of us
if self.compatibility in (
Compatibility.MAJOR,
Compatibility.MINOR,
Compatibility.PATCH,
):
await fetch_data()
else:
self.data["skipped"].append(
UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
new_version=self.remote_version_data[key]["version"]
or "null",
reason="Patch version newer than local",
multi_key=multi_key,
)
)
case VersionComparison.NEWER | VersionComparison.UNKNOWN:
# Remote is a pre-release ahead of us, or we don't have a saved version yet
await fetch_data()
# pylint: disable=line-too-long
case (
VersionComparison.OLDER_MAJOR
| VersionComparison.OLDER_MINOR
| VersionComparison.OLDER_PATCH
| VersionComparison.OLDER
): # noqa: E501
data_obj = UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
new_version=self.remote_version_data[key]["version"] or "null",
reason="Newer than remote",
multi_key=multi_key,
)
if self.strict:
raise RuntimeError(format_data(data_obj))
self.data["skipped"].append(data_obj)
case VersionComparison.EQUAL:
self.data["up-to-date"].append(
UpdaterData(
key=key,
path=data["path"],
old_version=data["version"] or "null",
multi_key=multi_key,
)
)
def get_keys(self) -> list[str]:
""":meta private: Get keys."""
if self.include == "*":
keys_to_check = list(self.local_version_data["data"].keys())
elif self.include == "":
keys_to_check = []
else:
keys_to_check = self.include.split(",")
if self.exclude != "":
for k in self.exclude.split(","):
keys_to_check.remove(k)
return keys_to_check
[docs]
async def run(self) -> list[str]:
"""Run the update.
Returns
-------
list[str]
The script output, split by newlines
"""
self.session = aiohttp.ClientSession()
try:
await self.fetch_remote_version_data()
keys_to_check = self.get_keys()
# TODO: When upgraded to 3.11, use asyncio.TaskGroup for main runner
# https://docs.python.org/3.11/library/asyncio-task.html#task-groups
# Old:
# coros = []
# coros.append()
# await asyncio.gather(coros)
# New:
# async with asyncio.TaskGroup() as tg:
# tg.create_task()
coros = []
for k, val in {
k: self.local_version_data["data"][k] for k in keys_to_check
}.items():
if isinstance(val, list):
for data_part in val:
coros.append(self.update_file(k, data_part, True))
else:
coros.append(self.update_file(k, val))
await asyncio.gather(*coros)
output = []
for key, values in self.data.items():
if key == "fetched":
if len(values) > 0:
print("\n".join(map(format_data, values)))
continue
if len(values) > 0:
output.append(f"{key.title()}:")
for item in values:
output.append(f"\t{format_data(item)}")
self.write_files()
self.save_version_data()
return output
except BaseException: # pylint: disable=try-except-raise
# We don't actually want to catch exceptions, we're just using this for the "finally" block
raise
finally:
await self.session.close()