# Authors:
# Endi S. Dewata <edewata@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the Lesser GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# Copyright (C) 2013 Red Hat, Inc.
# All rights reserved.
#
from __future__ import absolute_import
from __future__ import print_function
import functools
import os
import re
import shutil
import traceback
import pki
import pki.util
DEFAULT_VERSION = '10.0.0'
UPGRADE_DIR = pki.SHARE_DIR + '/upgrade'
BACKUP_DIR = pki.LOG_DIR + '/upgrade'
SYSTEM_TRACKER = pki.CONF_DIR + '/pki.version'
verbose = False
@functools.total_ordering
[docs]class Version(object):
def __init__(self, obj):
if isinstance(obj, str):
# parse <version>-<release>
pos = obj.find('-')
if pos > 0:
self.version = obj[0:pos]
elif pos < 0:
self.version = obj
else:
raise Exception('Invalid version number: ' + obj)
# parse <major>.<minor>.<patch>
match = re.match(r'^(\d+)\.(\d+)\.(\d+)$', self.version)
if match is None:
raise Exception('Invalid version number: ' + self.version)
self.major = int(match.group(1))
self.minor = int(match.group(2))
self.patch = int(match.group(3))
elif isinstance(obj, Version):
self.major = obj.major
self.minor = obj.minor
self.patch = obj.patch
else:
raise Exception('Unsupported version type: ' + str(type(obj)))
# release is ignored in comparisons
def __eq__(self, other):
return (self.major == other.major and
self.minor == other.minor and
self.patch == other.patch)
def __lt__(self, other):
if self.major < other.major:
return True
if self.major == other.major and self.minor < other.minor:
return True
if (self.major == other.major and
self.minor == other.minor and
self.patch < other.patch):
return True
return False
# not hashable
__hash__ = None
def __repr__(self):
return self.version
[docs]class PKIUpgradeTracker(object):
def __init__(self, name, filename, delimiter='=',
version_key='PKI_VERSION',
index_key='PKI_UPGRADE_INDEX'):
self.name = name
self.filename = filename
self.version_key = version_key
self.index_key = index_key
# properties must be read and written immediately to avoid
# interfering with scriptlets that update the same file
self.properties = pki.PropertyFile(filename, delimiter)
[docs] def remove(self):
if verbose:
print('Removing ' + self.name + ' tracker.')
self.remove_version()
self.remove_index()
[docs] def set(self, version):
if verbose:
print('Setting ' + self.name + ' tracker to version ' +
str(version) + '.')
self.set_version(version)
self.remove_index()
[docs] def show(self):
print(self.name + ':')
version = self.get_version()
print(' Configuration version: ' + str(version))
index = self.get_index()
if index > 0:
print(' Last completed scriptlet: ' + str(index))
print()
[docs] def get_index(self):
self.properties.read()
index = self.properties.get(self.index_key)
if index:
return int(index)
return 0
[docs] def set_index(self, index):
self.properties.read()
# find index
i = self.properties.index(self.index_key)
if i >= 0:
# if already exists, update index
self.properties.set(self.index_key, str(index))
else:
# find version
i = self.properties.index(self.version_key)
if i >= 0:
# if version exists, add index after version
self.properties.set(self.index_key, str(index), index=i + 1)
else:
# otherwise, add index at the end separated by a blank line
# if last line is not empty, append empty line
length = len(self.properties.lines)
if length > 0 and self.properties.lines[length - 1] != '':
self.properties.insert_line(length, '')
length += 1
# add index
self.properties.set(self.index_key, str(index), index=length)
self.properties.write()
[docs] def remove_index(self):
self.properties.read()
self.properties.remove(self.index_key)
self.properties.write()
[docs] def get_version(self):
self.properties.read()
version = self.properties.get(self.version_key)
if version:
return Version(version)
return Version(DEFAULT_VERSION)
[docs] def set_version(self, version):
self.properties.read()
# find version
i = self.properties.index(self.version_key)
if i >= 0:
# if already exists, update version
self.properties.set(self.version_key, str(version))
else:
# otherwise, add version at the end separated by a blank line
# if last line is not empty, append empty line
length = len(self.properties.lines)
if length > 0 and self.properties.lines[length - 1] != '':
self.properties.insert_line(length, '')
length += 1
# add version
self.properties.set(self.version_key, str(version), index=length)
self.properties.write()
[docs] def remove_version(self):
self.properties.read()
self.properties.remove(self.version_key)
self.properties.write()
@functools.total_ordering
[docs]class PKIUpgradeScriptlet(object):
def __init__(self):
self.version = None
self.index = None
self.last = False
self.message = None
self.upgrader = None
[docs] def get_backup_dir(self):
return BACKUP_DIR + '/' + str(self.version) + '/' + str(self.index)
[docs] def can_upgrade(self):
# A scriptlet can run if the version matches the tracker and
# the index is the next to be executed.
tracker = self.upgrader.get_tracker()
return self.version == tracker.get_version() and \
self.index == tracker.get_index() + 1
[docs] def update_tracker(self):
# Increment the index in the tracker. If it's the last scriptlet
# in this version, update the tracker version.
tracker = self.upgrader.get_tracker()
self.backup(tracker.filename)
if not self.last:
tracker.set_index(self.index)
else:
tracker.remove_index()
tracker.set_version(self.version.next)
[docs] def upgrade_system(self):
# Callback method to upgrade the system.
pass
[docs] def init(self):
backup_dir = self.get_backup_dir()
if os.path.exists(backup_dir):
# remove old backup dir
shutil.rmtree(backup_dir)
# create backup dir
os.makedirs(backup_dir)
[docs] def upgrade(self):
try:
if not self.can_upgrade():
if verbose:
print('Skipping system.')
return
if verbose:
print('Upgrading system.')
self.upgrade_system()
self.update_tracker()
except Exception as e:
if verbose:
traceback.print_exc()
else:
print('ERROR: %s' % e)
message = 'Failed upgrading system.'
if self.upgrader.silent:
print(message)
else:
result = pki.read_text(
message + ' Continue (Yes/No)',
options=['Y', 'N'], default='Y', delimiter='?',
case_sensitive=False).lower()
if result == 'y':
return
raise pki.PKIException('Upgrade failed: %s' % e, e)
[docs] def revert(self):
backup_dir = self.get_backup_dir()
if not os.path.exists(backup_dir):
return
oldfiles = backup_dir + '/oldfiles'
if os.path.exists(oldfiles):
# restore all backed up files
for sourcepath, _, filenames in os.walk(oldfiles):
# unused item _ for dirnames
destpath = sourcepath[len(oldfiles):]
if destpath == '':
destpath = '/'
if not os.path.isdir(destpath):
if verbose:
print('Restoring ' + destpath)
pki.util.copydirs(sourcepath, destpath)
for filename in filenames:
sourcefile = os.path.join(sourcepath, filename)
targetfile = os.path.join(destpath, filename)
if verbose:
print('Restoring ' + targetfile)
pki.util.copyfile(sourcefile, targetfile)
newfiles = backup_dir + '/newfiles'
if os.path.exists(newfiles):
# get paths that did not exist before upgrade
paths = []
with open(newfiles, 'r') as f:
for path in f:
path = path.strip('\n')
paths.append(path)
# remove paths in reverse order
paths.reverse()
for path in paths:
if not os.path.exists(path):
continue
if verbose:
print('Deleting ' + path)
if os.path.isfile(path):
os.remove(path)
else:
shutil.rmtree(path)
[docs] def backup(self, path):
backup_dir = self.get_backup_dir()
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
if os.path.exists(path):
# if path exists, keep a copy
oldfiles = backup_dir + '/oldfiles'
if not os.path.exists(oldfiles):
os.mkdir(oldfiles)
dest = oldfiles + path
sourceparent = os.path.dirname(path)
destparent = os.path.dirname(dest)
pki.util.copydirs(sourceparent, destparent)
if os.path.isfile(path):
if verbose:
print('Saving ' + path)
# do not overwrite initial backup
pki.util.copyfile(path, dest, overwrite=False)
else:
for sourcepath, _, filenames in os.walk(path):
relpath = sourcepath[len(path):]
destpath = dest + relpath
if verbose:
print('Saving ' + sourcepath)
pki.util.copydirs(sourcepath, destpath)
for filename in filenames:
sourcefile = os.path.join(sourcepath, filename)
targetfile = os.path.join(destpath, filename)
if verbose:
print('Saving ' + sourcefile)
# do not overwrite initial backup
pki.util.copyfile(sourcefile, targetfile,
overwrite=False)
else:
# otherwise, record the name
if verbose:
print('Recording ' + path)
with open(backup_dir + '/newfiles', 'a') as f:
f.write(path + '\n')
def __eq__(self, other):
return self.version == other.version and self.index == other.index
def __lt__(self, other):
if self.version < other.version:
return True
return self.version == other.version and self.index < other.index
# not hashable
__hash__ = None
[docs]class PKIUpgrader(object):
def __init__(self, upgrade_dir=UPGRADE_DIR, version=None, index=None,
silent=False):
self.upgrade_dir = upgrade_dir
self.version = version
self.index = index
self.silent = silent
if version and not os.path.exists(self.version_dir(version)):
raise pki.PKIException(
'Invalid scriptlet version: ' + str(version))
self.system_tracker = None
[docs] def version_dir(self, version):
return os.path.join(self.upgrade_dir, str(version))
[docs] def all_versions(self):
all_versions = []
if os.path.exists(self.upgrade_dir):
for version in os.listdir(self.upgrade_dir):
version = Version(version)
all_versions.append(version)
all_versions.sort()
return all_versions
[docs] def versions(self):
current_version = self.get_current_version()
target_version = self.get_target_version()
current_versions = []
for version in self.all_versions():
# skip old versions
if version >= current_version:
current_versions.append(version)
current_versions.sort()
versions = []
for index, version in enumerate(current_versions):
# link versions
if index < len(current_versions) - 1:
version.next = current_versions[index + 1]
else:
version.next = target_version
# if no scriptlet version is specified, add all versions to the list
# if scriptlet version is specified, add only that version to the
# list
if not self.version or str(version) == self.version:
versions.append(version)
return versions
[docs] def scriptlets(self, version):
scriptlets = []
version_dir = self.version_dir(version)
if not os.path.exists(version_dir):
return scriptlets
filenames = os.listdir(version_dir)
for filename in filenames:
# parse <index>-<classname>
try:
i = filename.index('-')
except ValueError as e:
raise pki.PKIException(
'Invalid scriptlet name: ' + filename,
e)
index = int(filename[0:i])
classname = filename[i + 1:]
if self.index and index != self.index:
continue
# load scriptlet class
variables = {}
absname = os.path.join(version_dir, filename)
with open(absname, 'r') as f:
bytecode = compile(f.read(), absname, 'exec')
exec(bytecode, variables) # pylint: disable=W0122
# create scriptlet object
scriptlet = variables[classname]()
scriptlet.upgrader = self
scriptlet.version = version
scriptlet.index = index
scriptlet.last = index == len(filenames)
scriptlets.append(scriptlet)
# sort scriptlets based on index
scriptlets.sort()
return scriptlets
[docs] def get_tracker(self):
if self.system_tracker:
tracker = self.system_tracker
else:
tracker = PKIUpgradeTracker(
'system',
SYSTEM_TRACKER,
delimiter=': ',
version_key='Configuration-Version',
index_key='Scriptlet-Index')
self.system_tracker = tracker
return tracker
[docs] def get_current_version(self):
tracker = self.get_tracker()
return tracker.get_version()
[docs] def get_target_version(self):
return Version(pki.implementation_version())
[docs] def is_complete(self):
current_version = self.get_current_version()
target_version = self.get_target_version()
return current_version == target_version
[docs] def upgrade_version(self, version):
print('Upgrading from version ' + str(version) + ' to ' +
str(version.next) + ':')
scriptlets = self.scriptlets(version)
if len(scriptlets) == 0:
print('No upgrade scriptlets.')
self.set_tracker(version.next)
return
# execute scriptlets
for scriptlet in scriptlets:
message = str(scriptlet.index) + '. ' + scriptlet.message
if self.silent:
print(message)
else:
result = pki.read_text(
message + ' (Yes/No)',
options=['Y', 'N'],
default='Y',
case_sensitive=False).lower()
if result == 'n':
raise pki.PKIException('Upgrade canceled.')
try:
scriptlet.init()
scriptlet.upgrade()
except pki.PKIException:
raise
except Exception as e: # pylint: disable=W0703
print()
message = 'Upgrade failed: %s' % e
if verbose:
traceback.print_exc()
else:
print(e)
print()
result = pki.read_text(
'Continue (Yes/No)',
options=['Y', 'N'],
default='Y',
delimiter='?',
case_sensitive=False).lower()
if result == 'n':
raise pki.PKIException(message, e)
[docs] def upgrade(self):
versions = self.versions()
for version in versions:
self.upgrade_version(version)
print()
if self.is_complete():
print('Upgrade complete.')
else:
self.show_tracker()
print('Upgrade incomplete.')
[docs] def revert_version(self, version):
print('Reverting to version ' + str(version) + ':')
scriptlets = self.scriptlets(version)
scriptlets.reverse()
for scriptlet in scriptlets:
message = str(scriptlet.index) + '. ' + scriptlet.message
if self.silent:
print(message)
else:
result = pki.read_text(
message + ' (Yes/No)',
options=['Y', 'N'], default='Y',
case_sensitive=False).lower()
if result == 'n':
raise pki.PKIException('Revert canceled.')
try:
scriptlet.revert()
except pki.PKIException:
raise
except Exception as e: # pylint: disable=W0703
print()
message = 'Revert failed: %s' % e
if verbose:
traceback.print_exc()
else:
print(e)
print()
result = pki.read_text(
'Continue (Yes/No)', options=['Y', 'N'],
default='Y', delimiter='?', case_sensitive=False).lower()
if result == 'n':
raise pki.PKIException(message, e)
self.set_tracker(version)
[docs] def revert(self):
current_version = self.get_current_version()
versions = self.all_versions()
versions.reverse()
# find the first version smaller than the current version
for version in versions:
if version >= current_version:
continue
self.revert_version(version)
return
print('Unable to revert from version ' + str(current_version) + '.')
[docs] def show_tracker(self):
tracker = self.get_tracker()
tracker.show()
[docs] def status(self):
self.show_tracker()
if self.is_complete():
print('Upgrade complete.')
else:
print('Upgrade incomplete.')
[docs] def set_tracker(self, version):
tracker = self.get_tracker()
tracker.set(version)
print('Tracker has been set to version ' + str(version) + '.')
[docs] def reset_tracker(self):
target_version = self.get_target_version()
self.set_tracker(target_version)
[docs] def remove_tracker(self):
tracker = self.get_tracker()
tracker.remove()
print('Tracker has been removed.')