[utils] Align traverse_obj() with yt-dlp
Thanks Grub4k for these:
* traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc
* traverse `set` key for transformations/filters, `re.Match` group names, from
776995bc10
, etc
* traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174
* always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170
This commit is contained in:
parent
47214e46d8
commit
825a40744b
2 changed files with 23 additions and 23 deletions
|
@ -20,7 +20,7 @@ import xml.etree.ElementTree
|
||||||
from youtube_dl.utils import (
|
from youtube_dl.utils import (
|
||||||
age_restricted,
|
age_restricted,
|
||||||
args_to_str,
|
args_to_str,
|
||||||
encode_base_n,
|
base_url,
|
||||||
caesar,
|
caesar,
|
||||||
clean_html,
|
clean_html,
|
||||||
clean_podcast_url,
|
clean_podcast_url,
|
||||||
|
@ -29,10 +29,12 @@ from youtube_dl.utils import (
|
||||||
detect_exe_version,
|
detect_exe_version,
|
||||||
determine_ext,
|
determine_ext,
|
||||||
dict_get,
|
dict_get,
|
||||||
|
encode_base_n,
|
||||||
encode_compat_str,
|
encode_compat_str,
|
||||||
encodeFilename,
|
encodeFilename,
|
||||||
escape_rfc3986,
|
escape_rfc3986,
|
||||||
escape_url,
|
escape_url,
|
||||||
|
expand_path,
|
||||||
extract_attributes,
|
extract_attributes,
|
||||||
ExtractorError,
|
ExtractorError,
|
||||||
find_xpath_attr,
|
find_xpath_attr,
|
||||||
|
@ -51,6 +53,7 @@ from youtube_dl.utils import (
|
||||||
js_to_json,
|
js_to_json,
|
||||||
LazyList,
|
LazyList,
|
||||||
limit_length,
|
limit_length,
|
||||||
|
lowercase_escape,
|
||||||
merge_dicts,
|
merge_dicts,
|
||||||
mimetype2ext,
|
mimetype2ext,
|
||||||
month_by_name,
|
month_by_name,
|
||||||
|
@ -66,17 +69,16 @@ from youtube_dl.utils import (
|
||||||
parse_resolution,
|
parse_resolution,
|
||||||
parse_bitrate,
|
parse_bitrate,
|
||||||
pkcs1pad,
|
pkcs1pad,
|
||||||
read_batch_urls,
|
|
||||||
sanitize_filename,
|
|
||||||
sanitize_path,
|
|
||||||
sanitize_url,
|
|
||||||
expand_path,
|
|
||||||
prepend_extension,
|
prepend_extension,
|
||||||
replace_extension,
|
read_batch_urls,
|
||||||
remove_start,
|
remove_start,
|
||||||
remove_end,
|
remove_end,
|
||||||
remove_quotes,
|
remove_quotes,
|
||||||
|
replace_extension,
|
||||||
rot47,
|
rot47,
|
||||||
|
sanitize_filename,
|
||||||
|
sanitize_path,
|
||||||
|
sanitize_url,
|
||||||
shell_quote,
|
shell_quote,
|
||||||
smuggle_url,
|
smuggle_url,
|
||||||
str_or_none,
|
str_or_none,
|
||||||
|
@ -93,10 +95,8 @@ from youtube_dl.utils import (
|
||||||
unified_timestamp,
|
unified_timestamp,
|
||||||
unsmuggle_url,
|
unsmuggle_url,
|
||||||
uppercase_escape,
|
uppercase_escape,
|
||||||
lowercase_escape,
|
|
||||||
url_basename,
|
url_basename,
|
||||||
url_or_none,
|
url_or_none,
|
||||||
base_url,
|
|
||||||
urljoin,
|
urljoin,
|
||||||
urlencode_postdata,
|
urlencode_postdata,
|
||||||
urshift,
|
urshift,
|
||||||
|
@ -1586,6 +1586,11 @@ Line 1
|
||||||
'dict': {},
|
'dict': {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# define a pukka Iterable
|
||||||
|
def iter_range(stop):
|
||||||
|
for from_ in range(stop):
|
||||||
|
yield from_
|
||||||
|
|
||||||
# Test base functionality
|
# Test base functionality
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
|
self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
|
||||||
msg='allow tuple path')
|
msg='allow tuple path')
|
||||||
|
@ -1602,13 +1607,13 @@ Line 1
|
||||||
# Test Ellipsis behavior
|
# Test Ellipsis behavior
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
|
||||||
(item for item in _TEST_DATA.values() if item not in (None, {})),
|
(item for item in _TEST_DATA.values() if item not in (None, {})),
|
||||||
msg='`...` should give all non discarded values')
|
msg='`...` should give all non-discarded values')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
|
||||||
msg='`...` selection for dicts should select all values')
|
msg='`...` selection for dicts should select all values')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
|
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
|
||||||
['https://www.example.com/0', 'https://www.example.com/1'],
|
['https://www.example.com/0', 'https://www.example.com/1'],
|
||||||
msg='nested `...` queries should work')
|
msg='nested `...` queries should work')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4),
|
||||||
msg='`...` query result should be flattened')
|
msg='`...` query result should be flattened')
|
||||||
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
|
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
|
||||||
msg='`...` should accept iterables')
|
msg='`...` should accept iterables')
|
||||||
|
@ -1618,7 +1623,7 @@ Line 1
|
||||||
[_TEST_DATA['urls']],
|
[_TEST_DATA['urls']],
|
||||||
msg='function as query key should perform a filter based on (key, value)')
|
msg='function as query key should perform a filter based on (key, value)')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
|
||||||
msg='exceptions in the query function should be catched')
|
msg='exceptions in the query function should be caught')
|
||||||
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
|
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
|
||||||
msg='function key should accept iterables')
|
msg='function key should accept iterables')
|
||||||
if __debug__:
|
if __debug__:
|
||||||
|
@ -1706,7 +1711,7 @@ Line 1
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
|
||||||
msg='remove empty values when dict key')
|
msg='remove empty values when dict key')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
|
||||||
msg='use `default` when dict key and `default`')
|
msg='use `default` when dict key and a default')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
|
||||||
msg='remove empty values when nested dict key fails')
|
msg='remove empty values when nested dict key fails')
|
||||||
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
|
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
|
||||||
|
@ -1768,7 +1773,7 @@ Line 1
|
||||||
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
|
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
|
||||||
'str', msg='accept matching `expected_type` type')
|
'str', msg='accept matching `expected_type` type')
|
||||||
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
|
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
|
||||||
None, msg='reject non matching `expected_type` type')
|
None, msg='reject non-matching `expected_type` type')
|
||||||
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
|
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
|
||||||
'0', msg='transform type using type function')
|
'0', msg='transform type using type function')
|
||||||
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
|
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
|
||||||
|
@ -1780,7 +1785,7 @@ Line 1
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
|
||||||
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
|
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
|
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
|
||||||
1, msg='expected_type should not filter non final dict values')
|
1, msg='expected_type should not filter non-final dict values')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
|
||||||
{0: {0: 100}}, msg='expected_type should transform deep dict values')
|
{0: {0: 100}}, msg='expected_type should transform deep dict values')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
|
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
|
||||||
|
@ -1838,7 +1843,7 @@ Line 1
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
|
||||||
_traverse_string=True), 'sr',
|
_traverse_string=True), 'sr',
|
||||||
msg='`slice` should result in string if `traverse_string`')
|
msg='`slice` should result in string if `traverse_string`')
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'),
|
||||||
_traverse_string=True), 'str',
|
_traverse_string=True), 'str',
|
||||||
msg='function should result in string if `traverse_string`')
|
msg='function should result in string if `traverse_string`')
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
||||||
|
|
|
@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT):
|
||||||
|
|
||||||
|
|
||||||
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
|
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
|
||||||
if isinstance(key_or_keys, (list, tuple)):
|
exp = (lambda x: x or None) if skip_false_values else IDENTITY
|
||||||
for key in key_or_keys:
|
return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default)
|
||||||
if key not in d or d[key] is None or skip_false_values and not d[key]:
|
|
||||||
continue
|
|
||||||
return d[key]
|
|
||||||
return default
|
|
||||||
return d.get(key_or_keys, default)
|
|
||||||
|
|
||||||
|
|
||||||
def try_call(*funcs, **kwargs):
|
def try_call(*funcs, **kwargs):
|
||||||
|
|
Loading…
Reference in a new issue