Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8c9f1c9

Browse files
committedJun 1, 2022
Improved handling of stdout objects that don't have a 'buffer' attribute.
For instance, when using `renderer_print_formatted_text` in a Jupyter Notebook.
1 parent cdbaef7 commit 8c9f1c9

File tree

6 files changed

+24
-37
lines changed

6 files changed

+24
-37
lines changed
 

‎src/prompt_toolkit/contrib/ssh/server.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ async def _interact(self) -> None:
8585

8686
term = self._chan.get_terminal_type()
8787

88-
self._output = Vt100_Output(
89-
self.stdout, self._get_size, term=term, write_binary=False
90-
)
88+
self._output = Vt100_Output(self.stdout, self._get_size, term=term)
89+
9190
with create_pipe_input() as self._input:
9291
with create_app_session(input=self._input, output=self._output) as session:
9392
self.app_session = session

‎src/prompt_toolkit/contrib/telnet/server.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ def size_received(rows: int, columns: int) -> None:
169169

170170
def ttype_received(ttype: str) -> None:
171171
"""TelnetProtocolParser 'ttype_received' callback"""
172-
self.vt100_output = Vt100_Output(
173-
self.stdout, get_size, term=ttype, write_binary=False
174-
)
172+
self.vt100_output = Vt100_Output(self.stdout, get_size, term=ttype)
175173
self._ready.set()
176174

177175
self.parser = TelnetProtocolParser(data_received, size_received, ttype_received)

‎src/prompt_toolkit/output/flush_stdout.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
__all__ = ["flush_stdout"]
88

99

10-
def flush_stdout(stdout: TextIO, data: str, write_binary: bool) -> None:
10+
def flush_stdout(stdout: TextIO, data: str) -> None:
11+
# If the IO object has an `encoding` and `buffer` attribute, it means that
12+
# we can access the underlying BinaryIO object and write into it in binary
13+
# mode. This is preferred if possible.
14+
# NOTE: When used in a Jupyter notebook, don't write binary.
15+
# `ipykernel.iostream.OutStream` has an `encoding` attribute, but not
16+
# a `buffer` attribute, so we can't write binary in it.
17+
has_binary_io = hasattr(stdout, "encoding") and hasattr(stdout, "buffer")
18+
1119
try:
1220
# Ensure that `stdout` is made blocking when writing into it.
1321
# Otherwise, when uvloop is activated (which makes stdout
@@ -20,14 +28,8 @@ def flush_stdout(stdout: TextIO, data: str, write_binary: bool) -> None:
2028
# My Arch Linux installation of july 2015 reported 'ANSI_X3.4-1968'
2129
# for sys.stdout.encoding in xterm.
2230
out: IO[bytes]
23-
if write_binary:
24-
if hasattr(stdout, "buffer"):
25-
out = stdout.buffer
26-
else:
27-
# IO[bytes] was given to begin with.
28-
# (Used in the unit tests, for instance.)
29-
out = cast(IO[bytes], stdout)
30-
out.write(data.encode(stdout.encoding or "utf-8", "replace"))
31+
if has_binary_io:
32+
stdout.buffer.write(data.encode(stdout.encoding or "utf-8", "replace"))
3133
else:
3234
stdout.write(data)
3335

‎src/prompt_toolkit/output/plain_text.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,10 @@ class PlainTextOutput(Output):
2323
formatting.)
2424
"""
2525

26-
def __init__(self, stdout: TextIO, write_binary: bool = True) -> None:
26+
def __init__(self, stdout: TextIO) -> None:
2727
assert all(hasattr(stdout, a) for a in ("write", "flush"))
2828

29-
if write_binary:
30-
assert hasattr(stdout, "encoding")
31-
3229
self.stdout: TextIO = stdout
33-
self.write_binary = write_binary
3430
self._buffer: List[str] = []
3531

3632
def fileno(self) -> int:
@@ -58,7 +54,7 @@ def flush(self) -> None:
5854

5955
data = "".join(self._buffer)
6056
self._buffer = []
61-
flush_stdout(self.stdout, data, write_binary=self.write_binary)
57+
flush_stdout(self.stdout, data)
6258

6359
def erase_screen(self) -> None:
6460
pass

‎src/prompt_toolkit/output/vt100.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,6 @@ class Vt100_Output(Output):
400400
:param get_size: A callable which returns the `Size` of the output terminal.
401401
:param stdout: Any object with has a `write` and `flush` method + an 'encoding' property.
402402
:param term: The terminal environment variable. (xterm, xterm-256color, linux, ...)
403-
:param write_binary: Encode the output before writing it. If `True` (the
404-
default), the `stdout` object is supposed to expose an `encoding` attribute.
405403
"""
406404

407405
# For the error messages. Only display "Output is not a terminal" once per
@@ -413,19 +411,14 @@ def __init__(
413411
stdout: TextIO,
414412
get_size: Callable[[], Size],
415413
term: Optional[str] = None,
416-
write_binary: bool = True,
417414
default_color_depth: Optional[ColorDepth] = None,
418415
enable_bell: bool = True,
419416
) -> None:
420417

421418
assert all(hasattr(stdout, a) for a in ("write", "flush"))
422419

423-
if write_binary:
424-
assert hasattr(stdout, "encoding")
425-
426420
self._buffer: List[str] = []
427421
self.stdout: TextIO = stdout
428-
self.write_binary = write_binary
429422
self.default_color_depth = default_color_depth
430423
self._get_size = get_size
431424
self.term = term
@@ -699,7 +692,7 @@ def flush(self) -> None:
699692
data = "".join(self._buffer)
700693
self._buffer = []
701694

702-
flush_stdout(self.stdout, data, write_binary=self.write_binary)
695+
flush_stdout(self.stdout, data)
703696

704697
def ask_for_cpr(self) -> None:
705698
"""

‎tests/test_print_formatted_text.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
class _Capture:
1414
"Emulate an stdout object."
15-
encoding = "utf-8"
1615

1716
def __init__(self):
1817
self._data = []
@@ -22,7 +21,7 @@ def write(self, data):
2221

2322
@property
2423
def data(self):
25-
return b"".join(self._data)
24+
return "".join(self._data)
2625

2726
def flush(self):
2827
pass
@@ -40,15 +39,15 @@ def fileno(self):
4039
def test_print_formatted_text():
4140
f = _Capture()
4241
pt_print([("", "hello"), ("", "world")], file=f)
43-
assert b"hello" in f.data
44-
assert b"world" in f.data
42+
assert "hello" in f.data
43+
assert "world" in f.data
4544

4645

4746
@pytest.mark.skipif(is_windows(), reason="Doesn't run on Windows yet.")
4847
def test_print_formatted_text_backslash_r():
4948
f = _Capture()
5049
pt_print("hello\r\n", file=f)
51-
assert b"hello" in f.data
50+
assert "hello" in f.data
5251

5352

5453
@pytest.mark.skipif(is_windows(), reason="Doesn't run on Windows yet.")
@@ -70,8 +69,8 @@ def test_formatted_text_with_style():
7069
# NOTE: We pass the default (8bit) color depth, so that the unit tests
7170
# don't start failing when environment variables change.
7271
pt_print(tokens, style=style, file=f, color_depth=ColorDepth.DEFAULT)
73-
assert b"\x1b[0;38;5;197mHello" in f.data
74-
assert b"\x1b[0;38;5;83;3mworld" in f.data
72+
assert "\x1b[0;38;5;197mHello" in f.data
73+
assert "\x1b[0;38;5;83;3mworld" in f.data
7574

7675

7776
@pytest.mark.skipif(is_windows(), reason="Doesn't run on Windows yet.")
@@ -87,5 +86,5 @@ def test_html_with_style():
8786

8887
assert (
8988
f.data
90-
== b"\x1b[0m\x1b[?7h\x1b[0;32mhello\x1b[0m \x1b[0;1mworld\x1b[0m\r\n\x1b[0m"
89+
== "\x1b[0m\x1b[?7h\x1b[0;32mhello\x1b[0m \x1b[0;1mworld\x1b[0m\r\n\x1b[0m"
9190
)

0 commit comments

Comments
 (0)
Please sign in to comment.