diff --git a/scripts/snap_pac.py b/scripts/snap_pac.py index ba67beb..54f0cd5 100755 --- a/scripts/snap_pac.py +++ b/scripts/snap_pac.py @@ -135,17 +135,27 @@ def get_snapper_configs(conf_file): return line[1].lstrip("\"").split() -def get_pre_number(snapshot_type, prefile): - if snapshot_type == "pre": - pre_number = None - else: - try: - pre_number = prefile.read_text() - except FileNotFoundError: - raise FileNotFoundError(f"prefile {prefile} not found. Ensure you have run the pre snapshot first.") +class Prefile: + """Handles reading and writing of pre snapshot number.""" + def __init__(self, snapper_config, snapshot_type): + self.file = Path(tempfile.gettempdir()) / f"snap-pac-pre_{snapper_config}" + self.snapshot_type = snapshot_type + + def read(self): + if self.snapshot_type == "pre": + pre_number = None else: - prefile.unlink() - return pre_number + try: + pre_number = self.file.read_text() + except FileNotFoundError: + raise FileNotFoundError(f"prefile {self.file} not found. Ensure you have run the pre snapshot first.") + else: + self.file.unlink() + return pre_number + + def write(self, num): + if self.snapshot_type == "pre": + self.file.write_text(num) def check_skip(): @@ -177,18 +187,15 @@ if __name__ == "__main__": config_processor = ConfigProcessor(snap_pac_ini, snapshot_type) snapper_configs = get_snapper_configs(snapper_conf_file) chroot = os.stat("/") != os.stat("/proc/1/root/.") - tmpdir = Path(tempfile.gettempdir()) for snapper_config in snapper_configs: data = config_processor(snapper_config) if data["snapshot"]: - prefile = tmpdir / f"snap-pac-pre_{snapper_config}" - pre_number = get_pre_number(snapshot_type, prefile) + prefile = Prefile(snapper_config, snapshot_type) + pre_number = prefile.read() snapper_cmd = SnapperCmd(snapper_config, snapshot_type, data["cleanup_algorithm"], data["description"], chroot, pre_number, data["userdata"]) num = snapper_cmd() logging.info(f"==> {snapper_config}: {num}") - - if snapshot_type == "pre": - prefile.write_text(num) + prefile.write(num) diff --git a/tests/test_script.py b/tests/test_script.py index adbe07a..f36dd8e 100644 --- a/tests/test_script.py +++ b/tests/test_script.py @@ -4,15 +4,7 @@ import os import pytest -from scripts.snap_pac import SnapperCmd, ConfigProcessor, check_skip, get_pre_number, get_snapper_configs - - -@pytest.fixture -def prefile(): - with tempfile.NamedTemporaryFile("w", delete=False) as f: - f.write("1234") - name = f.name - return Path(name) +from scripts.snap_pac import check_skip, ConfigProcessor, get_snapper_configs, Prefile, SnapperCmd @pytest.mark.parametrize("snapper_cmd, actual_cmd", [ @@ -108,14 +100,19 @@ def test_config_processor(section, command, packages, snapshot_type, result): assert config_processor(section) == result -def test_get_pre_number_pre(prefile): - assert get_pre_number("pre", prefile) is None +def test_prefile_read_none(): + prefile = Prefile("root", "pre") + assert prefile.read() is None -def test_get_pre_number_post(prefile): - assert get_pre_number("post", prefile) == "1234" +def test_prefile_read(): + prefile = Prefile("root", "pre") + prefile.write("1234") + prefile = Prefile("root", "post") + assert prefile.read() == "1234" def test_no_prefile(): + prefile = Prefile("foo-pre-file-not-found", "post") with pytest.raises(FileNotFoundError): - get_pre_number("post", Path("/tmp/foo-pre-file-not-found")) + prefile.read()