Coverage for src/prisma/cli/cli.py: 98%

58 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-28 15:17 +0000

1import os 

2import sys 

3import logging 

4import contextlib 

5from typing import List, Iterator, NoReturn, Optional 

6 

7import click 

8 

9from . import prisma 

10from .. import _sync_http as http 

11from .utils import error 

12from ..utils import DEBUG 

13from .custom import cli 

14from ..generator import Generator 

15 

__all__ = (
    'main',
    'setup_logging',
)

# Logger named after this module.
log: logging.Logger = logging.getLogger(__name__)

19 

20 

21# TODO: switch base cli to click as well to support autocomplete 

22 

23 

def main(
    args: Optional[List[str]] = None,
    use_handler: bool = True,
    do_cleanup: bool = True,
) -> NoReturn:
    """Run the ``prisma`` command line entry point and terminate the process.

    ``args`` defaults to ``sys.argv``, so ``args[0]`` is the program name and
    dispatch is decided by what follows it:

    * ``py ...``   -> ``cli.main`` (from ``.custom``) with ``prog_name='prisma py'``.
    * any other    -> ``prisma.run`` receives ``args[1:]``; its return value is
      passed to ``sys.exit``.
    * nothing      -> ``Generator.invoke()``, but only when the
      ``PRISMA_GENERATOR_INVOCATION`` environment variable is set; otherwise an
      error plus a usage hint is printed and the process exits with status 1.

    Args:
        args: argument vector to dispatch on; ``None`` means ``sys.argv``.
        use_handler: forwarded to ``setup_logging``.
        do_cleanup: forwarded to ``cleanup``.

    Raises:
        SystemExit: always; ``SystemExit(0)`` if no branch exited earlier. Both
            context managers have already exited when it is raised.
    """
    argv = sys.argv if args is None else args

    with setup_logging(use_handler), cleanup(do_cleanup):
        if len(argv) <= 1:
            # A bare invocation is only meaningful when spawned as a generator.
            if not os.environ.get('PRISMA_GENERATOR_INVOCATION'):
                error(
                    'This command is only intended to be invoked internally. '
                    'Please run the following instead:',
                    exit_=False,
                )
                for hint in ('prisma <command>', 'e.g.', 'prisma generate'):
                    click.echo(hint)
                sys.exit(1)
            Generator.invoke()
        elif argv[1] == 'py':
            cli.main(argv[2:], prog_name='prisma py')
        else:
            sys.exit(prisma.run(argv[1:]))

    # mypy does not recognise sys.exit as a NoReturn for some reason
    raise SystemExit(0)

52 

53 

@contextlib.contextmanager
def setup_logging(use_handler: bool = True) -> Iterator[None]:
    """Configure root logging for the duration of the ``with`` block.

    When the imported ``DEBUG`` flag is truthy:

    * the root logger is set to ``logging.DEBUG``;
    * the ``DEBUG`` environment variable is set to
      ``'prisma:GeneratorProcess'`` unless it already has a value. The Prisma
      CLI binary reads this variable.

    When ``DEBUG`` is falsy, the root logger is set to ``logging.INFO``.

    Args:
        use_handler: when True, attach a ``StreamHandler`` that formats records
            as ``[LEVEL  ] logger.name: message`` to the root logger.

    On exit, whether normal or via an exception, this context manager:

    * detaches and closes the handler it added, if any;
    * removes the ``DEBUG`` environment variable if this call set it;
    * restores the root logger's previous level.

    As a result, none of this process-global state outlives the context.
    """
    logger = logging.getLogger()
    previous_level = logger.level
    handler: Optional[logging.Handler] = None
    # Only undo an environment change that this call itself made.
    set_debug_env = False

    try:
        if DEBUG:
            logger.setLevel(logging.DEBUG)

            # the prisma CLI binary uses the DEBUG environment variable
            if os.environ.get('DEBUG') is None:
                os.environ['DEBUG'] = 'prisma:GeneratorProcess'
                set_debug_env = True
            else:
                log.debug('Not overriding the DEBUG environment variable.')
        else:
            logger.setLevel(logging.INFO)

        if use_handler:
            fmt = logging.Formatter(
                '[{levelname:<7}] {name}: {message}',
                style='{',
            )
            handler = logging.StreamHandler()
            handler.setFormatter(fmt)
            logger.addHandler(handler)

        yield
    finally:
        if handler is not None:
            # Detach first so no record is routed to an already-closed handler.
            logger.removeHandler(handler)
            handler.close()
        if set_debug_env:
            os.environ.pop('DEBUG', None)
        logger.setLevel(previous_level)

85 

86 

@contextlib.contextmanager
def cleanup(do_cleanup: bool = True) -> Iterator[None]:
    """Close ``http.client`` (from ``_sync_http``) when the ``with`` block ends.

    Args:
        do_cleanup: when False, the client is not touched at all.

    The close runs whether the block finishes normally or by raising
    (including ``SystemExit``). Exceptions from the block are never suppressed.
    """
    with contextlib.ExitStack() as stack:
        if do_cleanup:
            # Resolve ``http.client`` at exit time rather than at entry.
            stack.callback(lambda: http.client.close())
        yield

94 

95 

if __name__ == '__main__':
    # Support executing this module directly; dispatch on the process argv.
    main(args=sys.argv)