Commit 5b45cd2c authored by jpic ∞'s avatar jpic ∞ 💾
Browse files

Support async callable objects

parent 7ea9b050
......@@ -193,12 +193,19 @@ class Callable(Importable):
def for_callback(cls, cb):
return getattr(cb, 'cli2', cls(cb.__name__, cb))
@property
def is_async(self):
__call__ = getattr(self.target, '__call__', None)
if __call__ and inspect.iscoroutinefunction(__call__):
return True
return inspect.iscoroutinefunction(self.target)
def __call__(self, *args, **kwargs):
req_args = self.required_args
if len(args) < len(req_args):
raise Cli2ArgsException(self, args)
if inspect.iscoroutinefunction(self.target):
if self.is_async:
target = sync.async_to_sync(self.target)
else:
target = self.target
......@@ -216,6 +223,7 @@ class Callable(Importable):
def required_args(self):
if self.is_module:
return []
try:
argspec = inspect.getfullargspec(self.target)
"""
......@@ -228,9 +236,13 @@ class Callable(Importable):
del argspec.args[0]
if argspec.defaults:
return argspec.args[:-len(argspec.defaults)]
args = argspec.args[:-len(argspec.defaults)]
else:
return argspec.args
args = argspec.args
if args and args[0] in ('self', 'cls'):
return args[1:]
return args
except TypeError:
# catch builtins that don't provide a signature
# TODO: parse first line of inspect.getdoc() for builtin signature?
......
......@@ -18,3 +18,41 @@ def test_command_color_default():
def foo():
pass
assert cli2.Callable('foo', foo).color == cli2.YELLOW
def test_callable_object():
result = cli2.Callable.factory(
'cli2.console_script.ConsoleScript.singleton')
assert result.target == cli2.ConsoleScript.singleton
assert not result.required_args
assert [*result.get_callables()]
def test_boundmethod():
result = cli2.Callable.factory(
'cli2.console_script.ConsoleScript.singleton.result_handler')
assert result.target == cli2.ConsoleScript.singleton.result_handler
assert result.required_args == ['result']
assert not [*result.get_callables()]
def test_callable_function():
from cli2.cli import run
result = cli2.Callable.factory('cli2.cli.run')
assert result.target == run
assert result.required_args
assert not [*result.get_callables()]
def test_callable_function_async():
async def foo():
pass
assert cli2.Callable('foo', foo).is_async
def test_callable_object_async():
class Foo:
async def __call__(self):
pass
assert cli2.Callable('foo', Foo()).is_async
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment