diff --git a/drydock_provisioner/cli/design/commands.py b/drydock_provisioner/cli/design/commands.py index 8dca5b35..baedba1e 100644 --- a/drydock_provisioner/cli/design/commands.py +++ b/drydock_provisioner/cli/design/commands.py @@ -16,6 +16,7 @@ Contains commands related to designs """ import click +import json from drydock_provisioner.cli.design.actions import DesignList from drydock_provisioner.cli.design.actions import DesignShow @@ -36,14 +37,15 @@ def design(): @click.pass_context def design_create(ctx, base_design=None): """Create a design.""" - click.echo(DesignCreate(ctx.obj['CLIENT'], base_design).invoke()) + click.echo( + json.dumps(DesignCreate(ctx.obj['CLIENT'], base_design).invoke())) @design.command(name='list') @click.pass_context def design_list(ctx): """List designs.""" - click.echo(DesignList(ctx.obj['CLIENT']).invoke()) + click.echo(json.dumps(DesignList(ctx.obj['CLIENT']).invoke())) @design.command(name='show') @@ -54,4 +56,4 @@ def design_show(ctx, design_id): if not design_id: ctx.fail('The design id must be specified by --design-id') - click.echo(DesignShow(ctx.obj['CLIENT'], design_id).invoke()) + click.echo(json.dumps(DesignShow(ctx.obj['CLIENT'], design_id).invoke())) diff --git a/drydock_provisioner/cli/part/commands.py b/drydock_provisioner/cli/part/commands.py index 215bce74..2a90a3a4 100644 --- a/drydock_provisioner/cli/part/commands.py +++ b/drydock_provisioner/cli/part/commands.py @@ -17,6 +17,7 @@ Contains commands related to parts of designs. """ import click +import json from drydock_provisioner.cli.part.actions import PartList from drydock_provisioner.cli.part.actions import PartShow @@ -50,10 +51,11 @@ def part_create(ctx, file=None): file_contents = file_input.read() # here is where some potential validation could be done on the input file click.echo( - PartCreate( - ctx.obj['CLIENT'], - design_id=ctx.obj['DESIGN_ID'], - in_file=file_contents).invoke()) + json.dumps( + PartCreate( + ctx.obj['CLIENT'], + design_id=ctx.obj['DESIGN_ID'], + in_file=file_contents).invoke())) @part.command(name='list') @@ -61,7 +63,9 @@ def part_create(ctx, file=None): def part_list(ctx): """List parts of a design.""" click.echo( - PartList(ctx.obj['CLIENT'], design_id=ctx.obj['DESIGN_ID']).invoke()) + json.dumps( + PartList(ctx.obj['CLIENT'], design_id=ctx.obj['DESIGN_ID']) + .invoke())) @part.command(name='show') @@ -78,9 +82,10 @@ def part_show(ctx, source, kind, key): ctx.fail('The key must be specified by --key') click.echo( - PartShow( - ctx.obj['CLIENT'], - design_id=ctx.obj['DESIGN_ID'], - kind=kind, - key=key, - source=source).invoke()) + json.dumps( + PartShow( + ctx.obj['CLIENT'], + design_id=ctx.obj['DESIGN_ID'], + kind=kind, + key=key, + source=source).invoke())) diff --git a/drydock_provisioner/cli/task/commands.py b/drydock_provisioner/cli/task/commands.py index 7fc96df6..dfb5e76c 100644 --- a/drydock_provisioner/cli/task/commands.py +++ b/drydock_provisioner/cli/task/commands.py @@ -15,6 +15,7 @@ Contains commands related to tasks against designs """ import click +import json from drydock_provisioner.cli.task.actions import TaskList from drydock_provisioner.cli.task.actions import TaskShow @@ -58,16 +59,17 @@ def task_create(ctx, ctx.fail('Error: Action must be specified using --action') click.echo( - TaskCreate( - ctx.obj['CLIENT'], - design_id=design_id, - action_name=action, - node_names=[x.strip() for x in node_names.split(',')] - if node_names else [], - rack_names=[x.strip() for x in rack_names.split(',')] - if rack_names else [], - node_tags=[x.strip() for x in node_tags.split(',')] - if node_tags else []).invoke()) + json.dumps( + TaskCreate( + ctx.obj['CLIENT'], + design_id=design_id, + action_name=action, + node_names=[x.strip() for x in node_names.split(',')] + if node_names else [], + rack_names=[x.strip() for x in rack_names.split(',')] + if rack_names else [], + node_tags=[x.strip() for x in node_tags.split(',')] + if node_tags else []).invoke())) @task.command(name='list') @@ -75,7 +77,7 @@ def task_create(ctx, def task_list(ctx): """ List tasks. """ - click.echo(TaskList(ctx.obj['CLIENT']).invoke()) + click.echo(json.dumps(TaskList(ctx.obj['CLIENT']).invoke())) @task.command(name='show') @@ -87,4 +89,5 @@ def task_show(ctx, task_id=None): if not task_id: ctx.fail('The task id must be specified by --task-id') - click.echo(TaskShow(ctx.obj['CLIENT'], task_id=task_id).invoke()) + click.echo( + json.dumps(TaskShow(ctx.obj['CLIENT'], task_id=task_id).invoke())) diff --git a/drydock_provisioner/config.py b/drydock_provisioner/config.py index 4b0a0ba9..d103de86 100644 --- a/drydock_provisioner/config.py +++ b/drydock_provisioner/config.py @@ -28,7 +28,6 @@ package. It is assumed that: * This module is only used in the context of sample file generation. """ -import collections import importlib import os import pkgutil @@ -41,9 +40,8 @@ IGNORED_MODULES = ('drydock', 'config') class DrydockConfig(object): - """ - Initialize all the core options - """ + """Initialize all the core options.""" + # Default options options = [ cfg.IntOpt( @@ -52,6 +50,12 @@ class DrydockConfig(object): help= 'Polling interval in seconds for checking subtask or downstream status' ), + cfg.IntOpt( + 'leader_grace_period', + default=300, + help= + 'How long a leader has to check-in before leaderhsip can be usurped, in seconds' + ), ] # Logging options @@ -76,6 +80,13 @@ class DrydockConfig(object): help='Logger name for API server logging'), ] + # Database options + database_options = [ + cfg.StrOpt( + 'database_connect_string', + help='The URI database connect string.'), + ] + # Enabled plugins plugin_options = [ cfg.MultiStrOpt( @@ -93,7 +104,7 @@ class DrydockConfig(object): default= 'drydock_provisioner.drivers.node.maasdriver.driver.MaasNodeDriver', help='Module path string of the Node driver to enable'), - # TODO Network driver not yet implemented + # TODO(sh8121att) Network driver not yet implemented cfg.StrOpt( 'network_driver', default=None, @@ -149,6 +160,8 @@ class DrydockConfig(object): self.conf.register_opts(DrydockConfig.options) self.conf.register_opts(DrydockConfig.logging_options, group='logging') self.conf.register_opts(DrydockConfig.plugin_options, group='plugins') + self.conf.register_opts( + DrydockConfig.database_options, group='database') self.conf.register_opts( DrydockConfig.timeout_options, group='timeouts') self.conf.register_opts( @@ -164,7 +177,8 @@ def list_opts(): 'DEFAULT': DrydockConfig.options, 'logging': DrydockConfig.logging_options, 'plugins': DrydockConfig.plugin_options, - 'timeouts': DrydockConfig.timeout_options + 'timeouts': DrydockConfig.timeout_options, + 'database': DrydockConfig.database_options, } package_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/drydock_provisioner/control/base.py b/drydock_provisioner/control/base.py index 678569e8..dc84d635 100644 --- a/drydock_provisioner/control/base.py +++ b/drydock_provisioner/control/base.py @@ -125,6 +125,37 @@ class DrydockRequestContext(object): self.external_marker = '' self.policy_engine = None + @classmethod + def from_dict(cls, d): + """Instantiate a context from a dictionary of values. + + This is only used to deserialize persisted instances, so we + will trust the dictionary keys are exactly the correct fields + + :param d: Dictionary of instance values + """ + i = DrydockRequestContext() + + for k, v in d.items(): + setattr(i, k, v) + + return i + + def to_dict(self): + return { + 'log_level': self.log_level, + 'user': self.user, + 'user_id': self.user_id, + 'user_domain_id': self.user_domain_id, + 'roles': self.roles, + 'project_id': self.project_id, + 'project_domain_id': self.project_domain_id, + 'is_admin_project': self.is_admin_project, + 'authenticated': self.authenticated, + 'request_id': self.request_id, + 'external_marker': self.external_marker, + } + def set_log_level(self, level): if level in ['error', 'info', 'debug']: self.log_level = level diff --git a/drydock_provisioner/drivers/node/maasdriver/driver.py b/drydock_provisioner/drivers/node/maasdriver/driver.py index 0371b5ff..bb1af0a0 100644 --- a/drydock_provisioner/drivers/node/maasdriver/driver.py +++ b/drydock_provisioner/drivers/node/maasdriver/driver.py @@ -1955,8 +1955,8 @@ class MaasTaskRunner(drivers.DriverTaskRunner): 'mount_options': p.mount_options, } self.logger.debug( - "Mounting partition %s on %s" % (p.name, - p.mountpoint)) + "Mounting partition %s on %s" % + (p.name, p.mountpoint)) part.mount(**mount_opts) self.logger.debug( diff --git a/drydock_provisioner/error.py b/drydock_provisioner/error.py index dfd811a3..a534f414 100644 --- a/drydock_provisioner/error.py +++ b/drydock_provisioner/error.py @@ -22,6 +22,10 @@ class StateError(Exception): pass +class TaskNotFoundError(StateError): + pass + + class OrchestratorError(Exception): pass diff --git a/drydock_provisioner/objects/__init__.py b/drydock_provisioner/objects/__init__.py index f95a7324..dc9ad1d1 100644 --- a/drydock_provisioner/objects/__init__.py +++ b/drydock_provisioner/objects/__init__.py @@ -31,6 +31,7 @@ def register_all(): importlib.import_module('drydock_provisioner.objects.site') importlib.import_module('drydock_provisioner.objects.promenade') importlib.import_module('drydock_provisioner.objects.rack') + importlib.import_module('drydock_provisioner.objects.task') # Utility class for calculating inheritance diff --git a/drydock_provisioner/objects/hostprofile.py b/drydock_provisioner/objects/hostprofile.py index 19956a8c..8a1e52b1 100644 --- a/drydock_provisioner/objects/hostprofile.py +++ b/drydock_provisioner/objects/hostprofile.py @@ -252,30 +252,14 @@ class HostInterface(base.DrydockObject): getattr(j, 'network_link', None), getattr(i, 'network_link', None)) - s = [ - x for x in getattr(i, 'hardware_slaves', []) - if ("!" + x - ) not in getattr(j, 'hardware_slaves', []) - ] + m.hardware_slaves = objects.Utils.merge_lists( + getattr(j, 'hardware_slaves', []), + getattr(i, 'hardware_slaves', [])) - s.extend([ - x for x in getattr(j, 'hardware_slaves', []) - if not x.startswith("!") - ]) + m.networks = objects.Utils.merge_lists( + getattr(j, 'networks', []), + getattr(i, 'networks', [])) - m.hardware_slaves = s - - n = [ - x for x in getattr(i, 'networks', []) - if ("!" + x) not in getattr(j, 'networks', []) - ] - - n.extend([ - x for x in getattr(j, 'networks', []) - if not x.startswith("!") - ]) - - m.networks = n m.source = hd_fields.ModelSource.Compiled effective_list.append(m) @@ -332,7 +316,7 @@ class HostVolumeGroup(base.DrydockObject): self.physical_devices.append(pv) def is_sys(self): - """Is this the VG for root and/or boot?""" + """Check if this is the VG for root and/or boot.""" for lv in getattr(self, 'logical_volumes', []): if lv.is_sys(): return True @@ -577,7 +561,7 @@ class HostPartition(base.DrydockObject): return self.name def is_sys(self): - """Is this partition for root and/or boot?""" + """Check if this is the partition for root and/or boot.""" if self.mountpoint is not None and self.mountpoint in ['/', '/boot']: return True return False @@ -707,7 +691,7 @@ class HostVolume(base.DrydockObject): return self.name def is_sys(self): - """Is this LV for root and/or boot?""" + """Check if this is the LV for root and/or boot.""" if self.mountpoint is not None and self.mountpoint in ['/', '/boot']: return True return False diff --git a/drydock_provisioner/objects/task.py b/drydock_provisioner/objects/task.py index 769fe540..deeedb3a 100644 --- a/drydock_provisioner/objects/task.py +++ b/drydock_provisioner/objects/task.py @@ -14,12 +14,17 @@ """Models for representing asynchronous tasks.""" import uuid -import datetime +import json + +from datetime import datetime + +from drydock_provisioner import objects import drydock_provisioner.error as errors - import drydock_provisioner.objects.fields as hd_fields +from drydock_provisioner.control.base import DrydockRequestContext + class Task(object): """Asynchronous Task. @@ -30,29 +35,42 @@ class Task(object): :param parent_task_id: Optional UUID4 ID of the parent task to this task :param node_filter: Optional instance of TaskNodeFilter limiting the set of nodes this task will impact + :param context: instance of DrydockRequestContext representing the request context the + task is executing under + :param statemgr: instance of AppState used to access the database for state management """ - def __init__(self, **kwargs): - context = kwargs.get('context', None) + def __init__(self, + action=None, + design_ref=None, + parent_task_id=None, + node_filter=None, + context=None, + statemgr=None): + self.statemgr = statemgr self.task_id = uuid.uuid4() self.status = hd_fields.TaskStatus.Requested self.subtask_id_list = [] self.result = TaskStatus() - self.action = kwargs.get('action', hd_fields.OrchestratorAction.Noop) - self.design_ref = kwargs.get('design_ref', None) - self.parent_task_id = kwargs.get('parent_task_id', None) + self.action = action or hd_fields.OrchestratorAction.Noop + self.design_ref = design_ref + self.parent_task_id = parent_task_id self.created = datetime.utcnow() - self.node_filter = kwargs.get('node_filter', None) + self.node_filter = node_filter self.created_by = None self.updated = None self.terminated = None self.terminated_by = None - self.context = context + self.request_context = context if context is not None: self.created_by = context.user + @classmethod + def obj_name(cls): + return cls.__name__ + def get_id(self): return self.task_id @@ -68,29 +86,107 @@ class Task(object): def get_result(self): return self.result - def add_result_message(self, **kwargs): - """Add a message to result details.""" - self.result.add_message(**kwargs) + def success(self): + """Encounter a result that causes at least partial success.""" + if self.result.status in [hd_fields.TaskResult.Failure, + hd_fields.TaskResult.PartialSuccess]: + self.result.status = hd_fields.TaskResult.PartialSuccess + else: + self.result.status = hd_fields.TaskResult.Success - def register_subtask(self, subtask_id): + def failure(self): + """Encounter a result that causes at least partial failure.""" + if self.result.status in [hd_fields.TaskResult.Success, + hd_fields.TaskResult.PartialSuccess]: + self.result.status = hd_fields.TaskResult.PartialSuccess + else: + self.result.status = hd_fields.TaskResult.Failure + + def register_subtask(self, subtask): + """Register a task as a subtask to this task. + + :param subtask: objects.Task instance + """ if self.status in [hd_fields.TaskStatus.Terminating]: raise errors.OrchestratorError("Cannot add subtask for parent" " marked for termination") - self.subtask_id_list.append(subtask_id) + if self.statemgr.add_subtask(self.task_id, subtask.task_id): + self.subtask_id_list.append(subtask.task_id) + subtask.parent_task_id = self.task_id + subtask.save() + else: + raise errors.OrchestratorError("Error adding subtask.") + + def save(self): + """Save this task's current state to the database.""" + if not self.statemgr.put_task(self): + raise errors.OrchestratorError("Error saving task.") def get_subtasks(self): return self.subtask_id_list def add_status_msg(self, **kwargs): - self.result.add_status_msg(**kwargs) + msg = self.result.add_status_msg(**kwargs) + self.statemgr.post_result_message(self.task_id, msg) + + def to_db(self, include_id=True): + """Convert this instance to a dictionary for use persisting to a db. + + include_id=False can be used for doing an update where the primary key + of the table shouldn't included in the values set + + :param include_id: Whether to include task_id in the dictionary + """ + _dict = { + 'parent_task_id': + self.parent_task_id.bytes + if self.parent_task_id is not None else None, + 'subtask_id_list': [x.bytes for x in self.subtask_id_list], + 'result_status': + self.result.status, + 'result_message': + self.result.message, + 'result_reason': + self.result.reason, + 'result_error_count': + self.result.error_count, + 'status': + self.status, + 'created': + self.created, + 'created_by': + self.created_by, + 'updated': + self.updated, + 'design_ref': + self.design_ref, + 'request_context': + json.dumps(self.request_context.to_dict()) + if self.request_context is not None else None, + 'action': + self.action, + 'terminated': + self.terminated, + 'terminated_by': + self.terminated_by, + } + + if include_id: + _dict['task_id'] = self.task_id.bytes + + return _dict def to_dict(self): + """Convert this instance to a dictionary. + + Intended for use in JSON serialization + """ return { 'Kind': 'Task', 'apiVersion': 'v1', 'task_id': str(self.task_id), 'action': self.action, - 'parent_task': str(self.parent_task_id), + 'parent_task_id': str(self.parent_task_id), 'design_ref': self.design_ref, 'status': self.status, 'result': self.result.to_dict(), @@ -103,20 +199,54 @@ class Task(object): 'terminated_by': self.terminated_by, } + @classmethod + def from_db(cls, d): + """Create an instance from a DB-based dictionary. + + :param d: Dictionary of instance data + """ + i = Task() + + i.task_id = uuid.UUID(bytes=d.get('task_id')) + + if d.get('parent_task_id', None) is not None: + i.parent_task_id = uuid.UUID(bytes=d.get('parent_task_id')) + + if d.get('subtask_id_list', None) is not None: + for t in d.get('subtask_id_list'): + i.subtask_id_list.append(uuid.UUID(bytes=t)) + + simple_fields = [ + 'status', 'created', 'created_by', 'design_ref', 'action', + 'terminated', 'terminated_by' + ] + + for f in simple_fields: + setattr(i, f, d.get(f, None)) + + # Deserialize the request context for this task + if i.request_context is not None: + i.request_context = DrydockRequestContext.from_dict( + i.request_context) + + return i + class TaskStatus(object): """Status/Result of this task's execution.""" def __init__(self): - self.details = { - 'errorCount': 0, - 'messageList': [] - } + self.error_count = 0 + self.message_list = [] self.message = None self.reason = None self.status = hd_fields.ActionResult.Incomplete + @classmethod + def obj_name(cls): + return cls.__name__ + def set_message(self, msg): self.message = msg @@ -126,16 +256,24 @@ class TaskStatus(object): def set_status(self, status): self.status = status - def add_status_msg(self, msg=None, error=None, ctx_type=None, ctx=None, **kwargs): + def add_status_msg(self, + msg=None, + error=None, + ctx_type=None, + ctx=None, + **kwargs): if msg is None or error is None or ctx_type is None or ctx is None: - raise ValueError('Status message requires fields: msg, error, ctx_type, ctx') + raise ValueError( + 'Status message requires fields: msg, error, ctx_type, ctx') new_msg = TaskStatusMessage(msg, error, ctx_type, ctx, **kwargs) - self.details.messageList.append(new_msg) + self.message_list.append(new_msg) if error: - self.details.errorCount = self.details.errorCount + 1 + self.error_count = self.error_count + 1 + + return new_msg def to_dict(self): return { @@ -146,8 +284,8 @@ class TaskStatus(object): 'reason': self.reason, 'status': self.status, 'details': { - 'errorCount': self.details.errorCount, - 'messageList': [x.to_dict() for x in self.details.messageList], + 'errorCount': self.error_count, + 'messageList': [x.to_dict() for x in self.message_list], } } @@ -163,6 +301,10 @@ class TaskStatusMessage(object): self.ts = datetime.utcnow() self.extra = kwargs + @classmethod + def obj_name(cls): + return cls.__name__ + def to_dict(self): _dict = { 'message': self.message, @@ -175,3 +317,36 @@ class TaskStatusMessage(object): _dict.update(self.extra) return _dict + + def to_db(self): + """Convert this instance to a dictionary appropriate for the DB.""" + return { + 'message': self.message, + 'error': self.error, + 'context': self.ctx, + 'context_type': self.ctx_type, + 'ts': self.ts, + 'extra': json.dumps(self.extra), + } + + @classmethod + def from_db(cls, d): + """Create instance from DB-based dictionary. + + :param d: dictionary of values + """ + i = TaskStatusMessage( + d.get('message', None), + d.get('error'), + d.get('context_type'), d.get('context')) + if 'extra' in d: + i.extra = d.get('extra') + i.ts = d.get('ts', None) + + return i + + +# Emulate OVO object registration +setattr(objects, Task.obj_name(), Task) +setattr(objects, TaskStatus.obj_name(), TaskStatus) +setattr(objects, TaskStatusMessage.obj_name(), TaskStatusMessage) diff --git a/drydock_provisioner/statemgmt/__init__.py b/drydock_provisioner/statemgmt/__init__.py index b810a72a..a96a99ac 100644 --- a/drydock_provisioner/statemgmt/__init__.py +++ b/drydock_provisioner/statemgmt/__init__.py @@ -11,239 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from datetime import datetime -from datetime import timezone -from threading import Lock +"""Module for managing external data access. -import uuid - -import drydock_provisioner.objects as objects -import drydock_provisioner.objects.task as tasks - -from drydock_provisioner.error import DesignError, StateError - - -class DesignState(object): - def __init__(self): - self.designs = {} - self.designs_lock = Lock() - - self.promenade = {} - self.promenade_lock = Lock() - - self.builds = [] - self.builds_lock = Lock() - - self.tasks = [] - self.tasks_lock = Lock() - - self.bootdata = {} - self.bootdata_lock = Lock() - - return - - # TODO Need to lock a design base or change once implementation - # has started - def get_design(self, design_id): - if design_id not in self.designs.keys(): - - raise DesignError("Design ID %s not found" % (design_id)) - - return objects.SiteDesign.obj_from_primitive(self.designs[design_id]) - - def post_design(self, site_design): - if site_design is not None: - my_lock = self.designs_lock.acquire(blocking=True, timeout=10) - if my_lock: - design_id = site_design.id - if design_id not in self.designs.keys(): - self.designs[design_id] = site_design.obj_to_primitive() - else: - self.designs_lock.release() - raise StateError("Design ID %s already exists" % design_id) - self.designs_lock.release() - return True - raise StateError("Could not acquire lock") - else: - raise DesignError("Design change must be a SiteDesign instance") - - def put_design(self, site_design): - if site_design is not None: - my_lock = self.designs_lock.acquire(blocking=True, timeout=10) - if my_lock: - design_id = site_design.id - if design_id not in self.designs.keys(): - self.designs_lock.release() - raise StateError("Design ID %s does not exist" % design_id) - else: - self.designs[design_id] = site_design.obj_to_primitive() - self.designs_lock.release() - return True - raise StateError("Could not acquire lock") - else: - raise DesignError("Design base must be a SiteDesign instance") - - def get_current_build(self): - latest_stamp = 0 - current_build = None - - for b in self.builds: - if b.build_id > latest_stamp: - latest_stamp = b.build_id - current_build = b - - return deepcopy(current_build) - - def get_build(self, build_id): - for b in self.builds: - if b.build_id == build_id: - return b - - return None - - def post_build(self, site_build): - if site_build is not None and isinstance(site_build, SiteBuild): - my_lock = self.builds_lock.acquire(block=True, timeout=10) - if my_lock: - exists = [ - b for b in self.builds if b.build_id == site_build.build_id - ] - - if len(exists) > 0: - self.builds_lock.release() - raise DesignError("Already a site build with ID %s" % - (str(site_build.build_id))) - self.builds.append(deepcopy(site_build)) - self.builds_lock.release() - return True - raise StateError("Could not acquire lock") - else: - raise DesignError("Design change must be a SiteDesign instance") - - def put_build(self, site_build): - if site_build is not None and isinstance(site_build, SiteBuild): - my_lock = self.builds_lock.acquire(block=True, timeout=10) - if my_lock: - buildid = site_build.buildid - for b in self.builds: - if b.buildid == buildid: - b.merge_updates(site_build) - self.builds_lock.release() - return True - self.builds_lock.release() - return False - raise StateError("Could not acquire lock") - else: - raise DesignError("Design change must be a SiteDesign instance") - - def get_task(self, task_id): - for t in self.tasks: - if t.get_id() == task_id or str(t.get_id()) == task_id: - return deepcopy(t) - return None - - def post_task(self, task): - if task is not None and isinstance(task, tasks.Task): - my_lock = self.tasks_lock.acquire(blocking=True, timeout=10) - if my_lock: - task_id = task.get_id() - matching_tasks = [ - t for t in self.tasks if t.get_id() == task_id - ] - if len(matching_tasks) > 0: - self.tasks_lock.release() - raise StateError("Task %s already created" % task_id) - - self.tasks.append(deepcopy(task)) - self.tasks_lock.release() - return True - else: - raise StateError("Could not acquire lock") - else: - raise StateError("Task is not the correct type") - - def put_task(self, task, lock_id=None): - if task is not None and isinstance(task, tasks.Task): - my_lock = self.tasks_lock.acquire(blocking=True, timeout=10) - if my_lock: - task_id = task.get_id() - t = self.get_task(task_id) - if t.lock_id is not None and t.lock_id != lock_id: - self.tasks_lock.release() - raise StateError("Task locked for updates") - - task.lock_id = lock_id - self.tasks = [ - i if i.get_id() != task_id else deepcopy(task) - for i in self.tasks - ] - - self.tasks_lock.release() - return True - else: - raise StateError("Could not acquire lock") - else: - raise StateError("Task is not the correct type") - - def lock_task(self, task_id): - my_lock = self.tasks_lock.acquire(blocking=True, timeout=10) - if my_lock: - lock_id = uuid.uuid4() - for t in self.tasks: - if t.get_id() == task_id and t.lock_id is None: - t.lock_id = lock_id - self.tasks_lock.release() - return lock_id - self.tasks_lock.release() - return None - else: - raise StateError("Could not acquire lock") - - def unlock_task(self, task_id, lock_id): - my_lock = self.tasks_lock.acquire(blocking=True, timeout=10) - if my_lock: - for t in self.tasks: - if t.get_id() == task_id and t.lock_id == lock_id: - t.lock_id = None - self.tasks_lock.release() - return True - self.tasks_lock.release() - return False - else: - raise StateError("Could not acquire lock") - - def post_promenade_part(self, part): - my_lock = self.promenade_lock.acquire(blocking=True, timeout=10) - if my_lock: - if self.promenade.get(part.target, None) is not None: - self.promenade[part.target].append(part.obj_to_primitive()) - else: - self.promenade[part.target] = [part.obj_to_primitive()] - self.promenade_lock.release() - return None - else: - raise StateError("Could not acquire lock") - - def get_promenade_parts(self, target): - parts = self.promenade.get(target, None) - - if parts is not None: - return [ - objects.PromenadeConfig.obj_from_primitive(p) for p in parts - ] - else: - # Return an empty list just to play nice with extend - return [] - - def set_bootdata_key(self, hostname, design_id, data_key): - my_lock = self.bootdata_lock.acquire(blocking=True, timeout=10) - if my_lock: - self.bootdata[hostname] = {'design_id': design_id, 'key': data_key} - self.bootdata_lock.release() - return None - else: - raise StateError("Could not acquire lock") - - def get_bootdata_key(self, hostname): - return self.bootdata.get(hostname, None) +Includes database access for persisting Drydock data as +well as functionality resolve design references. +""" diff --git a/drydock_provisioner/statemgmt/db/tables.py b/drydock_provisioner/statemgmt/db/tables.py index 4527757e..88b99b09 100644 --- a/drydock_provisioner/statemgmt/db/tables.py +++ b/drydock_provisioner/statemgmt/db/tables.py @@ -1,12 +1,18 @@ """Definitions for Drydock database tables.""" -from sqlalchemy import Table, Column, MetaData -from sqlalchemy.types import Boolean, DateTime, String, Integer, JSON, BLOB +from sqlalchemy.schema import Table, Column, Sequence +from sqlalchemy.types import Boolean, DateTime, String, Integer from sqlalchemy.dialects import postgresql as pg -metadata = MetaData() -class Tasks(Table): +class ExtendTable(Table): + def __new__(cls, metadata): + self = super().__new__(cls, cls.__tablename__, metadata, + *cls.__schema__) + return self + + +class Tasks(ExtendTable): """Table for persisting Tasks.""" __tablename__ = 'tasks' @@ -16,6 +22,8 @@ class Tasks(Table): Column('parent_task_id', pg.BYTEA(16)), Column('subtask_id_list', pg.ARRAY(pg.BYTEA(16))), Column('result_status', String(32)), + Column('result_message', String(128)), + Column('result_reason', String(128)), Column('result_error_count', Integer), Column('status', String(32)), Column('created', DateTime), @@ -28,51 +36,37 @@ class Tasks(Table): Column('terminated_by', String(16)) ] - def __init__(self): - super().__init__( - Tasks.__tablename__, - metadata, - *Tasks.__schema__) - -class ResultMessage(Table): +class ResultMessage(ExtendTable): """Table for tracking result/status messages.""" __tablename__ = 'result_message' __schema__ = [ - Column('task_id', pg.BYTEA(16), primary_key=True), - Column('sequence', Integer, autoincrement='auto', primary_key=True), + Column('sequence', Integer, primary_key=True), + Column('task_id', pg.BYTEA(16)), Column('message', String(128)), Column('error', Boolean), + Column('context', String(32)), + Column('context_type', String(16)), + Column('ts', DateTime), Column('extra', pg.JSON) ] - def __init__(self): - super().__init__( - ResultMessage.__tablename__, - metadata, - *ResultMessage.__schema__) - -class ActiveInstance(Table): +class ActiveInstance(ExtendTable): """Table to organize multiple orchestrator instances.""" __tablename__ = 'active_instance' __schema__ = [ - Column('identity', pg.BYTEA(16), primary_key=True), - Column('last_ping', DateTime) + Column('dummy_key', Integer, primary_key=True), + Column('identity', pg.BYTEA(16)), + Column('last_ping', DateTime), ] - def __init__(self): - super().__init__( - ActiveInstance.__tablename__, - metadata, - *ActiveInstance.__schema__) - -class BuildData(Table): +class BuildData(ExtendTable): """Table persisting node build data.""" __tablename__ = 'build_data' @@ -82,9 +76,3 @@ class BuildData(Table): Column('task_id', pg.BYTEA(16)), Column('message', String(128)), ] - - def __init__(self): - super().__init__( - BuildData.__tablename__, - metadata, - *BuildData.__schema__) diff --git a/drydock_provisioner/statemgmt/state.py b/drydock_provisioner/statemgmt/state.py new file mode 100644 index 00000000..f83d1394 --- /dev/null +++ b/drydock_provisioner/statemgmt/state.py @@ -0,0 +1,359 @@ +# Copyright 2017 AT&T Intellectual Property. All other rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Access methods for managing external data access and persistence.""" + +import logging + +from sqlalchemy import create_engine +from sqlalchemy import MetaData +from sqlalchemy import sql + +import drydock_provisioner.objects as objects + +from .db import tables + +from drydock_provisioner import config + +from drydock_provisioner.error import DesignError +from drydock_provisioner.error import StateError + + +class DrydockState(object): + def __init__(self): + self.logger = logging.getLogger( + config.config_mgr.conf.logging.global_logger_name) + + self.db_engine = create_engine( + config.config_mgr.conf.database.database_connect_string) + self.db_metadata = MetaData() + + self.tasks_tbl = tables.Tasks(self.db_metadata) + self.result_message_tbl = tables.ResultMessage(self.db_metadata) + self.active_instance_tbl = tables.ActiveInstance(self.db_metadata) + self.build_data_tbl = tables.BuildData(self.db_metadata) + + return + + # TODO(sh8121att) Need to lock a design base or change once implementation + # has started + def get_design(self, design_id): + if design_id not in self.designs.keys(): + + raise DesignError("Design ID %s not found" % (design_id)) + + return objects.SiteDesign.obj_from_primitive(self.designs[design_id]) + + def post_design(self, site_design): + if site_design is not None: + my_lock = self.designs_lock.acquire(blocking=True, timeout=10) + if my_lock: + design_id = site_design.id + if design_id not in self.designs.keys(): + self.designs[design_id] = site_design.obj_to_primitive() + else: + self.designs_lock.release() + raise StateError("Design ID %s already exists" % design_id) + self.designs_lock.release() + return True + raise StateError("Could not acquire lock") + else: + raise DesignError("Design change must be a SiteDesign instance") + + def put_design(self, site_design): + if site_design is not None: + my_lock = self.designs_lock.acquire(blocking=True, timeout=10) + if my_lock: + design_id = site_design.id + if design_id not in self.designs.keys(): + self.designs_lock.release() + raise StateError("Design ID %s does not exist" % design_id) + else: + self.designs[design_id] = site_design.obj_to_primitive() + self.designs_lock.release() + return True + raise StateError("Could not acquire lock") + else: + raise DesignError("Design base must be a SiteDesign instance") + + def get_current_build(self): + latest_stamp = 0 + current_build = None + + for b in self.builds: + if b.build_id > latest_stamp: + latest_stamp = b.build_id + current_build = b + + return deepcopy(current_build) + + def get_build(self, build_id): + for b in self.builds: + if b.build_id == build_id: + return b + + return None + + def post_build(self, site_build): + if site_build is not None and isinstance(site_build, SiteBuild): + my_lock = self.builds_lock.acquire(block=True, timeout=10) + if my_lock: + exists = [ + b for b in self.builds if b.build_id == site_build.build_id + ] + + if len(exists) > 0: + self.builds_lock.release() + raise DesignError("Already a site build with ID %s" % + (str(site_build.build_id))) + self.builds.append(deepcopy(site_build)) + self.builds_lock.release() + return True + raise StateError("Could not acquire lock") + else: + raise DesignError("Design change must be a SiteDesign instance") + + def put_build(self, site_build): + if site_build is not None and isinstance(site_build, SiteBuild): + my_lock = self.builds_lock.acquire(block=True, timeout=10) + if my_lock: + buildid = site_build.buildid + for b in self.builds: + if b.buildid == buildid: + b.merge_updates(site_build) + self.builds_lock.release() + return True + self.builds_lock.release() + return False + raise StateError("Could not acquire lock") + else: + raise DesignError("Design change must be a SiteDesign instance") + + def get_tasks(self): + """Get all tasks in the database.""" + try: + conn = self.db_engine.connect() + query = sql.select([self.tasks_tbl]) + rs = conn.execute(query) + + task_list = [objects.Task.from_db(dict(r)) for r in rs] + + self._assemble_tasks(task_list=task_list) + + conn.close() + + return task_list + except Exception as ex: + self.logger.error("Error querying task list: %s" % str(ex)) + return [] + + def get_task(self, task_id): + """Query database for task matching task_id. + + :param task_id: uuid.UUID of a task_id to query against + """ + try: + conn = self.db_engine.connect() + query = self.tasks_tbl.select().where( + self.tasks_tbl.c.task_id == task_id.bytes) + rs = conn.execute(query) + + r = rs.fetchone() + + task = objects.Task.from_db(dict(r)) + + self.logger.debug("Assembling result messages for task %s." % str(task.task_id)) + self._assemble_tasks(task_list=[task]) + + conn.close() + + return task + + except Exception as ex: + self.logger.error("Error querying task %s: %s" % (str(task_id), + str(ex)), exc_info=True) + return None + + def post_result_message(self, task_id, msg): + """Add a result message to database attached to task task_id. + + :param task_id: uuid.UUID ID of the task the msg belongs to + :param msg: instance of objects.TaskStatusMessage + """ + try: + conn = self.db_engine.connect() + query = self.result_message_tbl.insert().values(task_id=task_id.bytes, **(msg.to_db())) + conn.execute(query) + conn.close() + return True + except Exception as ex: + self.logger.error("Error inserting result message for task %s: %s" % (str(task_id), str(ex))) + return False + + def _assemble_tasks(self, task_list=None): + """Attach all the appropriate result messages to the tasks in the list. + + :param task_list: a list of objects.Task instances to attach result messages to + """ + if task_list is None: + return None + + conn = self.db_engine.connect() + query = sql.select([self.result_message_tbl]).where( + self.result_message_tbl.c.task_id == sql.bindparam( + 'task_id')).order_by(self.result_message_tbl.c.sequence.asc()) + query.compile(self.db_engine) + + for t in task_list: + rs = conn.execute(query, task_id=t.task_id.bytes) + error_count = 0 + for r in rs: + msg = objects.TaskStatusMessage.from_db(dict(r)) + if msg.error: + error_count = error_count + 1 + t.result.message_list.append(msg) + t.result.error_count = error_count + + conn.close() + + def post_task(self, task): + """Insert a task into the database. + + Does not insert attached result messages + + :param task: instance of objects.Task to insert into the database. + """ + try: + conn = self.db_engine.connect() + query = self.tasks_tbl.insert().values(**(task.to_db(include_id=True))) + conn.execute(query) + conn.close() + return True + except Exception as ex: + self.logger.error("Error inserting task %s: %s" % + (str(task.task_id), str(ex))) + return False + + def put_task(self, task): + """Update a task in the database. + + :param task: objects.Task instance to reference for update values + """ + try: + conn = self.db_engine.connect() + query = self.tasks_tbl.update( + **(task.to_db(include_id=False))).where( + self.tasks_tbl.c.task_id == task.task_id.bytes) + rs = conn.execute(query) + if rs.rowcount == 1: + conn.close() + return True + else: + conn.close() + return False + except Exception as ex: + self.logger.error("Error updating task %s: %s" % + (str(task.task_id), str(ex))) + return False + + def add_subtask(self, task_id, subtask_id): + """Add new task to subtask list. + + :param task_id: uuid.UUID parent task ID + :param subtask_id: uuid.UUID new subtask ID + """ + query_string = sql.text("UPDATE tasks " + "SET subtask_id_list = array_append(subtask_id_list, :new_subtask) " + "WHERE task_id = :task_id").execution_options(autocommit=True) + + try: + conn = self.db_engine.connect() + rs = conn.execute(query_string, new_subtask=subtask_id.bytes, task_id=task_id.bytes) + rc = rs.rowcount + conn.close() + if rc == 1: + return True + else: + return False + except Exception as ex: + self.logger.error("Error appending subtask %s to task %s: %s" + % (str(subtask_id), str(task_id), str(ex))) + return False + + def claim_leadership(self, leader_id): + """Claim active instance status for leader_id. + + Attempt to claim leadership for leader_id. If another leader_id already has leadership + and has checked-in within the configured interval, this claim fails. If the last check-in + of an active instance is outside the configured interval, this claim will overwrite the + current active instance and succeed. If leadership has not been claimed, this call will + succeed. + + All leadership claims by an instance should use the same leader_id + + :param leader_id: a uuid.UUID instance identifying the instance to be considered active + """ + query_string = sql.text("INSERT INTO active_instance (dummy_key, identity, last_ping) " + "VALUES (1, :instance_id, timezone('UTC', now())) " + "ON CONFLICT (dummy_key) DO UPDATE SET " + "identity = :instance_id " + "WHERE active_instance.last_ping < (now() - interval '%d seconds')" + % (config.config_mgr.conf.default.leader_grace_period)).execution_options(autocommit=True) + + try: + conn = self.db_engine.connect() + rs = conn.execute(query_string, instance_id=leader_id.bytes) + rc = rs.rowcount + conn.close() + if rc == 1: + return True + else: + return False + except Exception as ex: + self.logger.error("Error executing leadership claim: %s" % str(ex)) + return False + + def post_promenade_part(self, part): + my_lock = self.promenade_lock.acquire(blocking=True, timeout=10) + if my_lock: + if self.promenade.get(part.target, None) is not None: + self.promenade[part.target].append(part.obj_to_primitive()) + else: + self.promenade[part.target] = [part.obj_to_primitive()] + self.promenade_lock.release() + return None + else: + raise StateError("Could not acquire lock") + + def get_promenade_parts(self, target): + parts = self.promenade.get(target, None) + + if parts is not None: + return [ + objects.PromenadeConfig.obj_from_primitive(p) for p in parts + ] + else: + # Return an empty list just to play nice with extend + return [] + + def set_bootdata_key(self, hostname, design_id, data_key): + my_lock = self.bootdata_lock.acquire(blocking=True, timeout=10) + if my_lock: + self.bootdata[hostname] = {'design_id': design_id, 'key': data_key} + self.bootdata_lock.release() + return None + else: + raise StateError("Could not acquire lock") + + def get_bootdata_key(self, hostname): + return self.bootdata.get(hostname, None) diff --git a/setup.py b/setup.py index 6db3447a..0917d50d 100644 --- a/setup.py +++ b/setup.py @@ -29,20 +29,28 @@ setup( author_email='sh8121@att.com', license='Apache 2.0', packages=[ - 'drydock_provisioner', 'drydock_provisioner.objects', - 'drydock_provisioner.ingester', 'drydock_provisioner.ingester.plugins', - 'drydock_provisioner.statemgmt', 'drydock_provisioner.orchestrator', - 'drydock_provisioner.control', 'drydock_provisioner.drivers', + 'drydock_provisioner', + 'drydock_provisioner.objects', + 'drydock_provisioner.ingester', + 'drydock_provisioner.ingester.plugins', + 'drydock_provisioner.statemgmt', + 'drydock_provisioner.orchestrator', + 'drydock_provisioner.control', + 'drydock_provisioner.drivers', 'drydock_provisioner.drivers.oob', 'drydock_provisioner.drivers.oob.pyghmi_driver', 'drydock_provisioner.drivers.oob.manual_driver', 'drydock_provisioner.drivers.node', 'drydock_provisioner.drivers.node.maasdriver', 'drydock_provisioner.drivers.node.maasdriver.models', - 'drydock_provisioner.control', 'drydock_provisioner.cli', - 'drydock_provisioner.cli.design', 'drydock_provisioner.cli.part', - 'drydock_provisioner.cli.task', 'drydock_provisioner.drydock_client', - 'drydock_provisioner.statemgmt.db','drydock_provisioner.cli.node' + 'drydock_provisioner.control', + 'drydock_provisioner.cli', + 'drydock_provisioner.cli.design', + 'drydock_provisioner.cli.part', + 'drydock_provisioner.cli.task', + 'drydock_provisioner.drydock_client', + 'drydock_provisioner.statemgmt.db', + 'drydock_provisioner.cli.node', ], entry_points={ 'oslo.config.opts': @@ -52,4 +60,4 @@ setup( 'console_scripts': 'drydock = drydock_provisioner.cli.commands:drydock' }, - ) +) diff --git a/tests/integration/bs_psql.sh b/tests/integration/bs_psql.sh new file mode 100755 index 00000000..1fa4b723 --- /dev/null +++ b/tests/integration/bs_psql.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +sudo docker run --rm -dp 5432:5432 --name 'psql_integration' postgres:9.5 +sleep 15 + +psql -h localhost -c "create user drydock with password 'drydock';" postgres postgres +psql -h localhost -c "create database drydock;" postgres postgres + +export DRYDOCK_DB_URL="postgresql+psycopg2://drydock:drydock@localhost:5432/drydock" +alembic upgrade head + diff --git a/tests/integration/test_postgres_leadership.py b/tests/integration/test_postgres_leadership.py new file mode 100644 index 00000000..af049fd9 --- /dev/null +++ b/tests/integration/test_postgres_leadership.py @@ -0,0 +1,71 @@ +import pytest + +import logging +import uuid +import time + +from oslo_config import cfg + +import drydock_provisioner.objects as objects +import drydock_provisioner.config as config + +from drydock_provisioner.statemgmt.state import DrydockState + + +class TestPostgres(object): + def test_claim_leadership(self, setup): + """Test that a node can claim leadership. + + First test claiming leadership with an empty table, simulating startup + Second test that an immediate follow-up claim is denied + Third test that a usurping claim after the grace period succeeds + """ + ds = DrydockState() + + first_leader = uuid.uuid4() + second_leader = uuid.uuid4() + + print("Claiming leadership for %s" % str(first_leader.bytes)) + crown = ds.claim_leadership(first_leader) + + assert crown == True + + print("Claiming leadership for %s" % str(second_leader.bytes)) + crown = ds.claim_leadership(second_leader) + + assert crown == False + + time.sleep(20) + + print( + "Claiming leadership for %s after 20s" % str(second_leader.bytes)) + crown = ds.claim_leadership(second_leader) + + assert crown == True + + @pytest.fixture(scope='module') + def setup(self): + objects.register_all() + logging.basicConfig() + + req_opts = { + 'default': [cfg.IntOpt('leader_grace_period')], + 'database': [cfg.StrOpt('database_connect_string')], + 'logging': [ + cfg.StrOpt('global_logger_name', default='drydock'), + ] + } + + for k, v in req_opts.items(): + config.config_mgr.conf.register_opts(v, group=k) + + config.config_mgr.conf([]) + config.config_mgr.conf.set_override( + name="database_connect_string", + group="database", + override= + "postgresql+psycopg2://drydock:drydock@localhost:5432/drydock") + config.config_mgr.conf.set_override( + name="leader_grace_period", group="default", override=15) + + return diff --git a/tests/integration/test_postgres_results.py b/tests/integration/test_postgres_results.py new file mode 100644 index 00000000..dd6f4f66 --- /dev/null +++ b/tests/integration/test_postgres_results.py @@ -0,0 +1,114 @@ +import pytest + +import logging +import uuid +import time + +from oslo_config import cfg + +from sqlalchemy import sql +from sqlalchemy import create_engine + +from drydock_provisioner import objects +import drydock_provisioner.config as config + +from drydock_provisioner.control.base import DrydockRequestContext +from drydock_provisioner.statemgmt.state import DrydockState + + +class TestPostgres(object): + + def test_result_message_insert(self, populateddb, drydockstate): + """Test that a result message for a task can be added.""" + msg1 = objects.TaskStatusMessage('Error 1', True, 'node', 'node1') + msg2 = objects.TaskStatusMessage('Status 1', False, 'node', 'node1') + + result = drydockstate.post_result_message(populateddb.task_id, msg1) + assert result + result = drydockstate.post_result_message(populateddb.task_id, msg2) + assert result + + task = drydockstate.get_task(populateddb.task_id) + + assert task.result.error_count == 1 + + assert len(task.result.message_list) == 2 + + @pytest.fixture(scope='function') + def populateddb(self, cleandb): + """Add dummy task to test against.""" + task = objects.Task( + action='prepare_site', design_ref='http://test.com/design') + + q1 = sql.text('INSERT INTO tasks ' \ + '(task_id, created, action, design_ref) ' \ + 'VALUES (:task_id, :created, :action, :design_ref)').execution_options(autocommit=True) + + engine = create_engine( + config.config_mgr.conf.database.database_connect_string) + conn = engine.connect() + + conn.execute( + q1, + task_id=task.task_id.bytes, + created=task.created, + action=task.action, + design_ref=task.design_ref) + + conn.close() + + return task + + @pytest.fixture(scope='session') + def drydockstate(self): + return DrydockState() + + @pytest.fixture(scope='function') + def cleandb(self, setup): + q1 = sql.text('TRUNCATE TABLE tasks').execution_options( + autocommit=True) + q2 = sql.text('TRUNCATE TABLE result_message').execution_options( + autocommit=True) + q3 = sql.text('TRUNCATE TABLE active_instance').execution_options( + autocommit=True) + q4 = sql.text('TRUNCATE TABLE build_data').execution_options( + autocommit=True) + + engine = create_engine( + config.config_mgr.conf.database.database_connect_string) + conn = engine.connect() + + conn.execute(q1) + conn.execute(q2) + conn.execute(q3) + conn.execute(q4) + + conn.close() + return + + @pytest.fixture(scope='module') + def setup(self): + objects.register_all() + logging.basicConfig() + + req_opts = { + 'default': [cfg.IntOpt('leader_grace_period')], + 'database': [cfg.StrOpt('database_connect_string')], + 'logging': [ + cfg.StrOpt('global_logger_name', default='drydock'), + ] + } + + for k, v in req_opts.items(): + config.config_mgr.conf.register_opts(v, group=k) + + config.config_mgr.conf([]) + config.config_mgr.conf.set_override( + name="database_connect_string", + group="database", + override= + "postgresql+psycopg2://drydock:drydock@localhost:5432/drydock") + config.config_mgr.conf.set_override( + name="leader_grace_period", group="default", override=15) + + return diff --git a/tests/integration/test_postgres_tasks.py b/tests/integration/test_postgres_tasks.py new file mode 100644 index 00000000..80921340 --- /dev/null +++ b/tests/integration/test_postgres_tasks.py @@ -0,0 +1,140 @@ +import pytest + +import logging +import uuid +import time + +from oslo_config import cfg + +from sqlalchemy import sql +from sqlalchemy import create_engine + +from drydock_provisioner import objects +import drydock_provisioner.config as config + +from drydock_provisioner.control.base import DrydockRequestContext +from drydock_provisioner.statemgmt.state import DrydockState + + +class TestPostgres(object): + def test_task_insert(self, cleandb, drydockstate): + """Test that a task can be inserted into the database.""" + ctx = DrydockRequestContext() + ctx.user = 'sh8121' + ctx.external_marker = str(uuid.uuid4()) + + task = objects.Task( + action='deploy_node', + design_ref='http://foo.bar/design', + context=ctx) + + result = drydockstate.post_task(task) + + assert result == True + + def test_subtask_append(self, cleandb, drydockstate): + """Test that the atomic subtask append method works.""" + + task = objects.Task(action='deploy_node', design_ref='http://foobar/design') + subtask = objects.Task(action='deploy_node', design_ref='http://foobar/design', parent_task_id=task.task_id) + + drydockstate.post_task(task) + drydockstate.post_task(subtask) + drydockstate.add_subtask(task.task_id, subtask.task_id) + + test_task = drydockstate.get_task(task.task_id) + + assert subtask.task_id in test_task.subtask_id_list + + def test_task_select(self, populateddb, drydockstate): + """Test that a task can be selected.""" + result = drydockstate.get_task(populateddb.task_id) + + assert result is not None + assert result.design_ref == populateddb.design_ref + + def test_task_list(self, populateddb, drydockstate): + """Test getting a list of all tasks.""" + + result = drydockstate.get_tasks() + + assert len(result) == 1 + + @pytest.fixture(scope='function') + def populateddb(self, cleandb): + """Add dummy task to test against.""" + task = objects.Task( + action='prepare_site', design_ref='http://test.com/design') + + q1 = sql.text('INSERT INTO tasks ' \ + '(task_id, created, action, design_ref) ' \ + 'VALUES (:task_id, :created, :action, :design_ref)').execution_options(autocommit=True) + + engine = create_engine( + config.config_mgr.conf.database.database_connect_string) + conn = engine.connect() + + conn.execute( + q1, + task_id=task.task_id.bytes, + created=task.created, + action=task.action, + design_ref=task.design_ref) + + conn.close() + + return task + + @pytest.fixture(scope='session') + def drydockstate(self): + return DrydockState() + + @pytest.fixture(scope='function') + def cleandb(self, setup): + q1 = sql.text('TRUNCATE TABLE tasks').execution_options( + autocommit=True) + q2 = sql.text('TRUNCATE TABLE result_message').execution_options( + autocommit=True) + q3 = sql.text('TRUNCATE TABLE active_instance').execution_options( + autocommit=True) + q4 = sql.text('TRUNCATE TABLE build_data').execution_options( + autocommit=True) + + engine = create_engine( + config.config_mgr.conf.database.database_connect_string) + conn = engine.connect() + + conn.execute(q1) + conn.execute(q2) + conn.execute(q3) + conn.execute(q4) + + conn.close() + return + + @pytest.fixture(scope='module') + def setup(self): + objects.register_all() + logging.basicConfig() + + req_opts = { + 'default': [cfg.IntOpt('leader_grace_period')], + 'database': [cfg.StrOpt('database_connect_string')], + 'logging': [ + cfg.StrOpt('global_logger_name', default='drydock'), + ] + } + + for k, v in req_opts.items(): + config.config_mgr.conf.register_opts(v, group=k) + + config.config_mgr.conf([]) + config.config_mgr.conf.set_override( + name="database_connect_string", + group="database", + override= + "postgresql+psycopg2://drydock:drydock@localhost:5432/drydock") + config.config_mgr.conf.set_override( + name="leader_grace_period", group="default", override=15) + + return diff --git a/tox.ini b/tox.ini index 415e38c8..7c4cd1dc 100644 --- a/tox.ini +++ b/tox.ini @@ -30,6 +30,13 @@ commands= py.test \ tests/unit/{posargs} +[testenv:integration] +setenv= + PYTHONWARNING=all +commands= + py.test \ + tests/integration/{posargs} + [testenv:genconfig] commands = oslo-config-generator --config-file=etc/drydock/drydock-config-generator.conf