Initial commit

This commit is contained in:
2020-05-08 14:39:22 +01:00
commit 57828567af
1662 changed files with 248701 additions and 0 deletions

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016, 2018-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2018 Nick Drozd <nicholasdrozd@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import os
import sys
from astroid import builder
from astroid import MANAGER
from astroid.bases import BUILTINS
DATA_DIR = os.path.join("testdata", "python{}".format(sys.version_info[0]))
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), DATA_DIR, "data")
def find(name):
return os.path.normpath(os.path.join(os.path.dirname(__file__), DATA_DIR, name))
def build_file(path, modname=None):
return builder.AstroidBuilder().file_build(find(path), modname)
class SysPathSetup:
def setUp(self):
sys.path.insert(0, find(""))
def tearDown(self):
del sys.path[0]
datadir = find("")
for key in list(sys.path_importer_cache):
if key.startswith(datadir):
del sys.path_importer_cache[key]
class AstroidCacheSetupMixin:
"""Mixin for handling the astroid cache problems.
When clearing the astroid cache, some tests fails due to
cache inconsistencies, where some objects had a different
builtins object referenced.
This saves the builtins module and makes sure to add it
back to the astroid_cache after the tests finishes.
The builtins module is special, since some of the
transforms for a couple of its objects (str, bytes etc)
are executed only once, so astroid_bootstrapping will be
useless for retrieving the original builtins module.
"""
@classmethod
def setup_class(cls):
cls._builtins = MANAGER.astroid_cache.get(BUILTINS)
@classmethod
def teardown_class(cls):
if cls._builtins:
MANAGER.astroid_cache[BUILTINS] = cls._builtins

View File

@@ -0,0 +1,96 @@
from __future__ import unicode_literals
from pyhashxx import Hashxx
import unittest
class TestHashBytes(unittest.TestCase):
def test_empty_string(self):
h = Hashxx()
h.update(b'')
self.assertEqual(h.digest(), 46947589)
def test_one_string(self):
h = Hashxx()
h.update(b'hello')
self.assertEqual(h.digest(), 4211111929)
h = Hashxx()
h.update(b'goodbye')
self.assertEqual(h.digest(), 2269043192)
def test_multiple_strings(self):
h = Hashxx()
h.update(b'hello')
h.update(b'goodbye')
self.assertEqual(h.digest(), 4110974955)
def test_tuple(self):
# Tuples shouldn't affect the hash, they should be equivalent to hashing
# each part in a separate update
h = Hashxx()
h.update((b'hello',b'goodbye'))
self.assertEqual(h.digest(), 4110974955)
def test_seeds(self):
h = Hashxx(seed=0)
h.update(b'hello')
self.assertEqual(h.digest(), 4211111929)
h = Hashxx(seed=1)
h.update(b'hello')
self.assertEqual(h.digest(), 4244634537)
h = Hashxx(seed=2)
h.update(b'hello')
self.assertEqual(h.digest(), 4191738725)
def hash_value(self, val, seed=0):
h = Hashxx(seed=seed)
h.update(val)
return h.digest()
def test_incremental(self):
# Make sure incrementally computed results match those
# computed all at once
hello_hash = self.hash_value(b'hello')
hello_world_hash = self.hash_value(b'helloworld')
h = Hashxx()
h.update(b'hello')
self.assertEqual(h.digest(), hello_hash)
h.update(b'world')
self.assertEqual(h.digest(), hello_world_hash)
def test_simultaneous(self):
# Ensure that interleaved updates still give same results as
# independent
h1 = Hashxx()
h2 = Hashxx()
h1.update(b'he')
h2.update(b'goo')
h1.update(b'll')
h2.update(b'db')
h1.update(b'o')
h2.update(b'ye')
self.assertEqual(h1.digest(), self.hash_value(b'hello'))
self.assertEqual(h2.digest(), self.hash_value(b'goodbye'))
def test_bad_seed(self):
self.assertRaises(TypeError, Hashxx, seed="badseed")
def test_bad_arg(self):
h = Hashxx()
self.assertRaises(TypeError, h.update, [1,2,3])
def test_no_args(self):
h = Hashxx()
self.assertRaises(TypeError, h.update)
def test_no_unicode(self):
h = Hashxx()
self.assertRaises(TypeError, h.update, 'hello')

View File

@@ -0,0 +1,35 @@
from __future__ import unicode_literals
from pyhashxx import hashxx, Hashxx
import unittest
class TestOneShot(unittest.TestCase):
# The shorthand should be equivalent to this simple function:
def hash_value(self, val, seed=0):
h = Hashxx(seed=seed)
h.update(val)
return h.digest()
def test_empty_string(self):
self.assertEqual(hashxx(b''), self.hash_value(b''))
def test_string(self):
self.assertEqual(hashxx(b'hello'), self.hash_value(b'hello'))
def test_seeds(self):
self.assertNotEqual(hashxx(b'hello', seed=0), hashxx(b'hello', seed=1))
self.assertEqual(hashxx(b'hello', seed=0), self.hash_value(b'hello', seed=0))
self.assertEqual(hashxx(b'hello', seed=1), self.hash_value(b'hello', seed=1))
self.assertEqual(hashxx(b'hello', seed=2), self.hash_value(b'hello', seed=2))
def test_bad_arg(self):
self.assertRaises(TypeError, hashxx, [1, 2, 3])
def test_bad_seed(self):
self.assertRaises(TypeError, hashxx, seed="badseed")
def test_no_args(self):
self.assertRaises(TypeError, hashxx)
def test_no_unicode(self):
self.assertRaises(TypeError, hashxx, 'hello')

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreFromNumericTest(unittest.TestCase):
"""
Test the numpy core fromnumeric brain module
"""
numpy_functions = (("sum", "[1, 2]"),)
def _inferred_numpy_func_call(self, func_name, *func_args):
node = builder.extract_node(
"""
import numpy as np
func = np.{:s}
func({:s})
""".format(
func_name, ",".join(func_args)
)
)
return node.infer()
def test_numpy_function_calls_inferred_as_ndarray(self):
"""
Test that calls to numpy functions are inferred as numpy.ndarray
"""
licit_array_types = (".ndarray",)
for func_ in self.numpy_functions:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format(func_[0]),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,64 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreFunctionBaseTest(unittest.TestCase):
"""
Test the numpy core numeric brain module
"""
numpy_functions = (
("linspace", "1, 100"),
("logspace", "1, 100"),
("geomspace", "1, 100"),
)
def _inferred_numpy_func_call(self, func_name, *func_args):
node = builder.extract_node(
"""
import numpy as np
func = np.{:s}
func({:s})
""".format(
func_name, ",".join(func_args)
)
)
return node.infer()
def test_numpy_function_calls_inferred_as_ndarray(self):
"""
Test that calls to numpy functions are inferred as numpy.ndarray
"""
licit_array_types = (".ndarray",)
for func_ in self.numpy_functions:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format(func_[0]),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,181 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreMultiarrayTest(unittest.TestCase):
"""
Test the numpy core multiarray brain module
"""
numpy_functions_returning_array = (
("array", "[1, 2]"),
("bincount", "[1, 2]"),
("busday_count", "('2011-01', '2011-02')"),
("busday_offset", "'2012-03', -1, roll='forward'"),
("concatenate", "([1, 2], [1, 2])"),
("datetime_as_string", "['2012-02', '2012-03']"),
("dot", "[1, 2]", "[1, 2]"),
("empty_like", "[1, 2]"),
("inner", "[1, 2]", "[1, 2]"),
("is_busday", "['2011-07-01', '2011-07-02', '2011-07-18']"),
("lexsort", "(('toto', 'tutu'), ('riri', 'fifi'))"),
("packbits", "np.array([1, 2])"),
("ravel_multi_index", "np.array([[1, 2], [2, 1]])", "(3, 4)"),
("unpackbits", "np.array([[1], [2], [3]], dtype=np.uint8)"),
("vdot", "[1, 2]", "[1, 2]"),
("where", "[True, False]", "[1, 2]", "[2, 1]"),
("empty", "[1, 2]"),
("zeros", "[1, 2]"),
)
numpy_functions_returning_bool = (
("can_cast", "np.int32, np.int64"),
("may_share_memory", "np.array([1, 2])", "np.array([3, 4])"),
("shares_memory", "np.array([1, 2])", "np.array([3, 4])"),
)
numpy_functions_returning_dtype = (
# ("min_scalar_type", "10"), # Not yet tested as it returns np.dtype
# ("result_type", "'i4'", "'c8'"), # Not yet tested as it returns np.dtype
)
numpy_functions_returning_none = (("copyto", "([1, 2], [1, 3])"),)
numpy_functions_returning_tuple = (
(
"unravel_index",
"[22, 33, 44]",
"(6, 7)",
), # Not yet tested as is returns a tuple
)
def _inferred_numpy_func_call(self, func_name, *func_args):
node = builder.extract_node(
"""
import numpy as np
func = np.{:s}
func({:s})
""".format(
func_name, ",".join(func_args)
)
)
return node.infer()
def test_numpy_function_calls_inferred_as_ndarray(self):
"""
Test that calls to numpy functions are inferred as numpy.ndarray
"""
for func_ in self.numpy_functions_returning_array:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_[0]
),
)
self.assertTrue(
inferred_values[-1].pytype() == ".ndarray",
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
def test_numpy_function_calls_inferred_as_bool(self):
"""
Test that calls to numpy functions are inferred as bool
"""
for func_ in self.numpy_functions_returning_bool:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_[0]
),
)
self.assertTrue(
inferred_values[-1].pytype() == "builtins.bool",
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
def test_numpy_function_calls_inferred_as_dtype(self):
"""
Test that calls to numpy functions are inferred as numpy.dtype
"""
for func_ in self.numpy_functions_returning_dtype:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_[0]
),
)
self.assertTrue(
inferred_values[-1].pytype() == "numpy.dtype",
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
def test_numpy_function_calls_inferred_as_none(self):
"""
Test that calls to numpy functions are inferred as None
"""
for func_ in self.numpy_functions_returning_none:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_[0]
),
)
self.assertTrue(
inferred_values[-1].pytype() == "builtins.NoneType",
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
def test_numpy_function_calls_inferred_as_tuple(self):
"""
Test that calls to numpy functions are inferred as tuple
"""
for func_ in self.numpy_functions_returning_tuple:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_[0]
),
)
self.assertTrue(
inferred_values[-1].pytype() == "builtins.tuple",
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,65 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreNumericTest(unittest.TestCase):
"""
Test the numpy core numeric brain module
"""
numpy_functions = (
("zeros_like", "[1, 2]"),
("full_like", "[1, 2]", "4"),
("ones_like", "[1, 2]"),
("ones", "[1, 2]"),
)
def _inferred_numpy_func_call(self, func_name, *func_args):
node = builder.extract_node(
"""
import numpy as np
func = np.{:s}
func({:s})
""".format(
func_name, ",".join(func_args)
)
)
return node.infer()
def test_numpy_function_calls_inferred_as_ndarray(self):
"""
Test that calls to numpy functions are inferred as numpy.ndarray
"""
licit_array_types = (".ndarray",)
for func_ in self.numpy_functions:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(*func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format(func_[0]),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
func_[0], inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,348 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2017-2020 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2017-2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
from astroid import nodes
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainCoreNumericTypesTest(unittest.TestCase):
"""
Test of all the missing types defined in numerictypes module.
"""
all_types = [
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
"float96",
"complex64",
"complex128",
"complex192",
"timedelta64",
"datetime64",
"unicode_",
"str_",
"bool_",
"bool8",
"byte",
"int8",
"bytes0",
"bytes_",
"cdouble",
"cfloat",
"character",
"clongdouble",
"clongfloat",
"complexfloating",
"csingle",
"double",
"flexible",
"floating",
"half",
"inexact",
"int0",
"longcomplex",
"longdouble",
"longfloat",
"short",
"signedinteger",
"single",
"singlecomplex",
"str0",
"ubyte",
"uint",
"uint0",
"uintc",
"uintp",
"ulonglong",
"unsignedinteger",
"ushort",
"void0",
]
def _inferred_numpy_attribute(self, attrib):
node = builder.extract_node(
"""
import numpy.core.numerictypes as tested_module
missing_type = tested_module.{:s}""".format(
attrib
)
)
return next(node.value.infer())
def test_numpy_core_types(self):
"""
Test that all defined types have ClassDef type.
"""
for typ in self.all_types:
with self.subTest(typ=typ):
inferred = self._inferred_numpy_attribute(typ)
self.assertIsInstance(inferred, nodes.ClassDef)
def test_generic_types_have_methods(self):
"""
Test that all generic derived types have specified methods
"""
generic_methods = [
"all",
"any",
"argmax",
"argmin",
"argsort",
"astype",
"base",
"byteswap",
"choose",
"clip",
"compress",
"conj",
"conjugate",
"copy",
"cumprod",
"cumsum",
"data",
"diagonal",
"dtype",
"dump",
"dumps",
"fill",
"flags",
"flat",
"flatten",
"getfield",
"imag",
"item",
"itemset",
"itemsize",
"max",
"mean",
"min",
"nbytes",
"ndim",
"newbyteorder",
"nonzero",
"prod",
"ptp",
"put",
"ravel",
"real",
"repeat",
"reshape",
"resize",
"round",
"searchsorted",
"setfield",
"setflags",
"shape",
"size",
"sort",
"squeeze",
"std",
"strides",
"sum",
"swapaxes",
"take",
"tobytes",
"tofile",
"tolist",
"tostring",
"trace",
"transpose",
"var",
"view",
]
for type_ in (
"bool_",
"bytes_",
"character",
"complex128",
"complex192",
"complex64",
"complexfloating",
"datetime64",
"flexible",
"float16",
"float32",
"float64",
"float96",
"floating",
"generic",
"inexact",
"int16",
"int32",
"int32",
"int64",
"int8",
"integer",
"number",
"signedinteger",
"str_",
"timedelta64",
"uint16",
"uint32",
"uint32",
"uint64",
"uint8",
"unsignedinteger",
"void",
):
with self.subTest(typ=type_):
inferred = self._inferred_numpy_attribute(type_)
for meth in generic_methods:
with self.subTest(meth=meth):
self.assertTrue(meth in {m.name for m in inferred.methods()})
def test_generic_types_have_attributes(self):
"""
Test that all generic derived types have specified attributes
"""
generic_attr = [
"base",
"data",
"dtype",
"flags",
"flat",
"imag",
"itemsize",
"nbytes",
"ndim",
"real",
"size",
"strides",
]
for type_ in (
"bool_",
"bytes_",
"character",
"complex128",
"complex192",
"complex64",
"complexfloating",
"datetime64",
"flexible",
"float16",
"float32",
"float64",
"float96",
"floating",
"generic",
"inexact",
"int16",
"int32",
"int32",
"int64",
"int8",
"integer",
"number",
"signedinteger",
"str_",
"timedelta64",
"uint16",
"uint32",
"uint32",
"uint64",
"uint8",
"unsignedinteger",
"void",
):
with self.subTest(typ=type_):
inferred = self._inferred_numpy_attribute(type_)
for attr in generic_attr:
with self.subTest(attr=attr):
self.assertNotEqual(len(inferred.getattr(attr)), 0)
def test_number_types_have_unary_operators(self):
"""
Test that number types have unary operators
"""
unary_ops = ("__neg__",)
for type_ in (
"float64",
"float96",
"floating",
"int16",
"int32",
"int32",
"int64",
"int8",
"integer",
"number",
"signedinteger",
"uint16",
"uint32",
"uint32",
"uint64",
"uint8",
"unsignedinteger",
):
with self.subTest(typ=type_):
inferred = self._inferred_numpy_attribute(type_)
for attr in unary_ops:
with self.subTest(attr=attr):
self.assertNotEqual(len(inferred.getattr(attr)), 0)
def test_array_types_have_unary_operators(self):
"""
Test that array types have unary operators
"""
unary_ops = ("__neg__", "__invert__")
for type_ in ("ndarray",):
with self.subTest(typ=type_):
inferred = self._inferred_numpy_attribute(type_)
for attr in unary_ops:
with self.subTest(attr=attr):
self.assertNotEqual(len(inferred.getattr(attr)), 0)
def test_datetime_astype_return(self):
"""
Test that the return of astype method of the datetime object
is inferred as a ndarray.
PyCQA/pylint#3332
"""
node = builder.extract_node(
"""
import numpy as np
import datetime
test_array = np.datetime64(1, 'us')
test_array.astype(datetime.datetime)
"""
)
licit_array_types = ".ndarray"
inferred_values = list(node.infer())
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format("datetime64.astype"),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
"datetime64.astype", inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,279 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
from astroid import nodes, bases
from astroid import util
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainCoreUmathTest(unittest.TestCase):
"""
Test of all members of numpy.core.umath module
"""
one_arg_ufunc = (
"arccos",
"arccosh",
"arcsin",
"arcsinh",
"arctan",
"arctanh",
"cbrt",
"conj",
"conjugate",
"cosh",
"deg2rad",
"exp2",
"expm1",
"fabs",
"frexp",
"isfinite",
"isinf",
"log",
"log1p",
"log2",
"logical_not",
"modf",
"negative",
"positive",
"rad2deg",
"reciprocal",
"rint",
"sign",
"signbit",
"spacing",
"square",
"tan",
"tanh",
"trunc",
)
two_args_ufunc = (
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"copysign",
"divide",
"divmod",
"equal",
"float_power",
"floor_divide",
"fmax",
"fmin",
"fmod",
"gcd",
"greater",
"heaviside",
"hypot",
"lcm",
"ldexp",
"left_shift",
"less",
"logaddexp",
"logaddexp2",
"logical_and",
"logical_or",
"logical_xor",
"maximum",
"minimum",
"nextafter",
"not_equal",
"power",
"remainder",
"right_shift",
"subtract",
"true_divide",
)
all_ufunc = one_arg_ufunc + two_args_ufunc
constants = ("e", "euler_gamma")
def _inferred_numpy_attribute(self, func_name):
node = builder.extract_node(
"""
import numpy.core.umath as tested_module
func = tested_module.{:s}
func""".format(
func_name
)
)
return next(node.infer())
def test_numpy_core_umath_constants(self):
"""
Test that constants have Const type.
"""
for const in self.constants:
with self.subTest(const=const):
inferred = self._inferred_numpy_attribute(const)
self.assertIsInstance(inferred, nodes.Const)
def test_numpy_core_umath_constants_values(self):
"""
Test the values of the constants.
"""
exact_values = {"e": 2.718281828459045, "euler_gamma": 0.5772156649015329}
for const in self.constants:
with self.subTest(const=const):
inferred = self._inferred_numpy_attribute(const)
self.assertEqual(inferred.value, exact_values[const])
def test_numpy_core_umath_functions(self):
"""
Test that functions have FunctionDef type.
"""
for func in self.all_ufunc:
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
self.assertIsInstance(inferred, bases.Instance)
def test_numpy_core_umath_functions_one_arg(self):
"""
Test the arguments names of functions.
"""
exact_arg_names = [
"self",
"x",
"out",
"where",
"casting",
"order",
"dtype",
"subok",
]
for func in self.one_arg_ufunc:
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
self.assertEqual(
inferred.getattr("__call__")[0].argnames(), exact_arg_names
)
def test_numpy_core_umath_functions_two_args(self):
"""
Test the arguments names of functions.
"""
exact_arg_names = [
"self",
"x1",
"x2",
"out",
"where",
"casting",
"order",
"dtype",
"subok",
]
for func in self.two_args_ufunc:
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
self.assertEqual(
inferred.getattr("__call__")[0].argnames(), exact_arg_names
)
def test_numpy_core_umath_functions_kwargs_default_values(self):
"""
Test the default values for keyword arguments.
"""
exact_kwargs_default_values = [None, True, "same_kind", "K", None, True]
for func in self.one_arg_ufunc + self.two_args_ufunc:
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
default_args_values = [
default.value
for default in inferred.getattr("__call__")[0].args.defaults
]
self.assertEqual(default_args_values, exact_kwargs_default_values)
def _inferred_numpy_func_call(self, func_name, *func_args):
node = builder.extract_node(
"""
import numpy as np
func = np.{:s}
func()
""".format(
func_name
)
)
return node.infer()
def test_numpy_core_umath_functions_return_type(self):
"""
Test that functions which should return a ndarray do return it
"""
ndarray_returning_func = [
f for f in self.all_ufunc if f not in ("frexp", "modf")
]
for func_ in ndarray_returning_func:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(func_))
self.assertTrue(
len(inferred_values) == 1
or len(inferred_values) == 2
and inferred_values[-1].pytype() is util.Uninferable,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values[-1].pytype(), func_
),
)
self.assertTrue(
inferred_values[0].pytype() == ".ndarray",
msg="Illicit type for {:s} ({})".format(
func_, inferred_values[-1].pytype()
),
)
def test_numpy_core_umath_functions_return_type_tuple(self):
"""
Test that functions which should return a pair of ndarray do return it
"""
ndarray_returning_func = ("frexp", "modf")
for func_ in ndarray_returning_func:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_numpy_func_call(func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred values ({}) for {:s}".format(
inferred_values, func_
),
)
self.assertTrue(
inferred_values[-1].pytype() == "builtins.tuple",
msg="Illicit type for {:s} ({})".format(
func_, inferred_values[-1].pytype()
),
)
self.assertTrue(
len(inferred_values[0].elts) == 2,
msg="{} should return a pair of values. That's not the case.".format(
func_
),
)
for array in inferred_values[-1].elts:
effective_infer = [m.pytype() for m in array.inferred()]
self.assertTrue(
".ndarray" in effective_infer,
msg=(
"Each item in the return of {} "
"should be inferred as a ndarray and not as {}".format(
func_, effective_infer
)
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,173 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2017-2020 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2017-2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainNdarrayTest(unittest.TestCase):
"""
Test that calls to numpy functions returning arrays are correctly inferred
"""
ndarray_returning_ndarray_methods = (
"__abs__",
"__add__",
"__and__",
"__array__",
"__array_wrap__",
"__copy__",
"__deepcopy__",
"__eq__",
"__floordiv__",
"__ge__",
"__gt__",
"__iadd__",
"__iand__",
"__ifloordiv__",
"__ilshift__",
"__imod__",
"__imul__",
"__invert__",
"__ior__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__ixor__",
"__le__",
"__lshift__",
"__lt__",
"__matmul__",
"__mod__",
"__mul__",
"__ne__",
"__neg__",
"__or__",
"__pos__",
"__pow__",
"__rshift__",
"__sub__",
"__truediv__",
"__xor__",
"all",
"any",
"argmax",
"argmin",
"argpartition",
"argsort",
"astype",
"byteswap",
"choose",
"clip",
"compress",
"conj",
"conjugate",
"copy",
"cumprod",
"cumsum",
"diagonal",
"dot",
"flatten",
"getfield",
"max",
"mean",
"min",
"newbyteorder",
"prod",
"ptp",
"ravel",
"repeat",
"reshape",
"round",
"searchsorted",
"squeeze",
"std",
"sum",
"swapaxes",
"take",
"trace",
"transpose",
"var",
"view",
)
def _inferred_ndarray_method_call(self, func_name):
node = builder.extract_node(
"""
import numpy as np
test_array = np.ndarray((2, 2))
test_array.{:s}()
""".format(
func_name
)
)
return node.infer()
def _inferred_ndarray_attribute(self, attr_name):
node = builder.extract_node(
"""
import numpy as np
test_array = np.ndarray((2, 2))
test_array.{:s}
""".format(
attr_name
)
)
return node.infer()
def test_numpy_function_calls_inferred_as_ndarray(self):
"""
Test that some calls to numpy functions are inferred as numpy.ndarray
"""
licit_array_types = ".ndarray"
for func_ in self.ndarray_returning_ndarray_methods:
with self.subTest(typ=func_):
inferred_values = list(self._inferred_ndarray_method_call(func_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format(func_),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
func_, inferred_values[-1].pytype()
),
)
def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
"""
Test that some numpy ndarray attributes are inferred as numpy.ndarray
"""
licit_array_types = ".ndarray"
for attr_ in ("real", "imag"):
with self.subTest(typ=attr_):
inferred_values = list(self._inferred_ndarray_attribute(attr_))
self.assertTrue(
len(inferred_values) == 1,
msg="Too much inferred value for {:s}".format(attr_),
)
self.assertTrue(
inferred_values[-1].pytype() in licit_array_types,
msg="Illicit type for {:s} ({})".format(
attr_, inferred_values[-1].pytype()
),
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,115 @@
# -*- encoding=utf-8 -*-
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 hippo91 <guillaume.peillex@gmail.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
try:
import numpy # pylint: disable=unused-import
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
from astroid import builder
from astroid import nodes
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainRandomMtrandTest(unittest.TestCase):
"""
Test of all the functions of numpy.random.mtrand module.
"""
#  Map between functions names and arguments names and default values
all_mtrand = {
"beta": (["a", "b", "size"], [None]),
"binomial": (["n", "p", "size"], [None]),
"bytes": (["length"], []),
"chisquare": (["df", "size"], [None]),
"choice": (["a", "size", "replace", "p"], [None, True, None]),
"dirichlet": (["alpha", "size"], [None]),
"exponential": (["scale", "size"], [1.0, None]),
"f": (["dfnum", "dfden", "size"], [None]),
"gamma": (["shape", "scale", "size"], [1.0, None]),
"geometric": (["p", "size"], [None]),
"get_state": ([], []),
"gumbel": (["loc", "scale", "size"], [0.0, 1.0, None]),
"hypergeometric": (["ngood", "nbad", "nsample", "size"], [None]),
"laplace": (["loc", "scale", "size"], [0.0, 1.0, None]),
"logistic": (["loc", "scale", "size"], [0.0, 1.0, None]),
"lognormal": (["mean", "sigma", "size"], [0.0, 1.0, None]),
"logseries": (["p", "size"], [None]),
"multinomial": (["n", "pvals", "size"], [None]),
"multivariate_normal": (["mean", "cov", "size"], [None]),
"negative_binomial": (["n", "p", "size"], [None]),
"noncentral_chisquare": (["df", "nonc", "size"], [None]),
"noncentral_f": (["dfnum", "dfden", "nonc", "size"], [None]),
"normal": (["loc", "scale", "size"], [0.0, 1.0, None]),
"pareto": (["a", "size"], [None]),
"permutation": (["x"], []),
"poisson": (["lam", "size"], [1.0, None]),
"power": (["a", "size"], [None]),
"rand": (["args"], []),
"randint": (["low", "high", "size", "dtype"], [None, None, "l"]),
"randn": (["args"], []),
"random_integers": (["low", "high", "size"], [None, None]),
"random_sample": (["size"], [None]),
"rayleigh": (["scale", "size"], [1.0, None]),
"seed": (["seed"], [None]),
"set_state": (["state"], []),
"shuffle": (["x"], []),
"standard_cauchy": (["size"], [None]),
"standard_exponential": (["size"], [None]),
"standard_gamma": (["shape", "size"], [None]),
"standard_normal": (["size"], [None]),
"standard_t": (["df", "size"], [None]),
"triangular": (["left", "mode", "right", "size"], [None]),
"uniform": (["low", "high", "size"], [0.0, 1.0, None]),
"vonmises": (["mu", "kappa", "size"], [None]),
"wald": (["mean", "scale", "size"], [None]),
"weibull": (["a", "size"], [None]),
"zipf": (["a", "size"], [None]),
}
def _inferred_numpy_attribute(self, func_name):
node = builder.extract_node(
"""
import numpy.random.mtrand as tested_module
func = tested_module.{:s}
func""".format(
func_name
)
)
return next(node.infer())
def test_numpy_random_mtrand_functions(self):
"""
Test that all functions have FunctionDef type.
"""
for func in self.all_mtrand:
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
self.assertIsInstance(inferred, nodes.FunctionDef)
def test_numpy_random_mtrand_functions_signature(self):
"""
Test the arguments names and default values.
"""
for (
func,
(exact_arg_names, exact_kwargs_default_values),
) in self.all_mtrand.items():
with self.subTest(func=func):
inferred = self._inferred_numpy_attribute(func)
self.assertEqual(inferred.argnames(), exact_arg_names)
default_args_values = [
default.value for default in inferred.args.defaults
]
self.assertEqual(default_args_values, exact_kwargs_default_values)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,732 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2006-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2014-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014-2015 Google, Inc.
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2015 Florian Bruhin <me@the-compiler.org>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2017 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2017 Łukasz Rogalski <rogalski.91@gmail.com>
# Copyright (c) 2018 Ville Skyttä <ville.skytta@iki.fi>
# Copyright (c) 2018 brendanator <brendan.maginnis@gmail.com>
# Copyright (c) 2018 Anthony Sottile <asottile@umich.edu>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 Hugo van Kemenade <hugovk@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
"""tests for the astroid builder and rebuilder module"""
import builtins
import collections
import os
import socket
import sys
import unittest
import pytest
from astroid import builder
from astroid import exceptions
from astroid import manager
from astroid import nodes
from astroid import test_utils
from astroid import util
from . import resources
MANAGER = manager.AstroidManager()
BUILTINS = builtins.__name__
PY38 = sys.version_info[:2] >= (3, 8)
class FromToLineNoTest(unittest.TestCase):
def setUp(self):
self.astroid = resources.build_file("data/format.py")
def test_callfunc_lineno(self):
stmts = self.astroid.body
# on line 4:
# function('aeozrijz\
# earzer', hop)
discard = stmts[0]
self.assertIsInstance(discard, nodes.Expr)
self.assertEqual(discard.fromlineno, 4)
self.assertEqual(discard.tolineno, 5)
callfunc = discard.value
self.assertIsInstance(callfunc, nodes.Call)
self.assertEqual(callfunc.fromlineno, 4)
self.assertEqual(callfunc.tolineno, 5)
name = callfunc.func
self.assertIsInstance(name, nodes.Name)
self.assertEqual(name.fromlineno, 4)
self.assertEqual(name.tolineno, 4)
strarg = callfunc.args[0]
self.assertIsInstance(strarg, nodes.Const)
if hasattr(sys, "pypy_version_info"):
lineno = 4
else:
lineno = 5 if not PY38 else 4
self.assertEqual(strarg.fromlineno, lineno)
self.assertEqual(strarg.tolineno, lineno)
namearg = callfunc.args[1]
self.assertIsInstance(namearg, nodes.Name)
self.assertEqual(namearg.fromlineno, 5)
self.assertEqual(namearg.tolineno, 5)
# on line 10:
# fonction(1,
# 2,
# 3,
# 4)
discard = stmts[2]
self.assertIsInstance(discard, nodes.Expr)
self.assertEqual(discard.fromlineno, 10)
self.assertEqual(discard.tolineno, 13)
callfunc = discard.value
self.assertIsInstance(callfunc, nodes.Call)
self.assertEqual(callfunc.fromlineno, 10)
self.assertEqual(callfunc.tolineno, 13)
name = callfunc.func
self.assertIsInstance(name, nodes.Name)
self.assertEqual(name.fromlineno, 10)
self.assertEqual(name.tolineno, 10)
for i, arg in enumerate(callfunc.args):
self.assertIsInstance(arg, nodes.Const)
self.assertEqual(arg.fromlineno, 10 + i)
self.assertEqual(arg.tolineno, 10 + i)
def test_function_lineno(self):
stmts = self.astroid.body
# on line 15:
# def definition(a,
# b,
# c):
# return a + b + c
function = stmts[3]
self.assertIsInstance(function, nodes.FunctionDef)
self.assertEqual(function.fromlineno, 15)
self.assertEqual(function.tolineno, 18)
return_ = function.body[0]
self.assertIsInstance(return_, nodes.Return)
self.assertEqual(return_.fromlineno, 18)
self.assertEqual(return_.tolineno, 18)
self.skipTest(
"FIXME http://bugs.python.org/issue10445 "
"(no line number on function args)"
)
def test_decorated_function_lineno(self):
astroid = builder.parse(
"""
@decorator
def function(
arg):
print (arg)
""",
__name__,
)
function = astroid["function"]
# XXX discussable, but that's what is expected by pylint right now
self.assertEqual(function.fromlineno, 3)
self.assertEqual(function.tolineno, 5)
self.assertEqual(function.decorators.fromlineno, 2)
self.assertEqual(function.decorators.tolineno, 2)
def test_class_lineno(self):
stmts = self.astroid.body
# on line 20:
# class debile(dict,
# object):
# pass
class_ = stmts[4]
self.assertIsInstance(class_, nodes.ClassDef)
self.assertEqual(class_.fromlineno, 20)
self.assertEqual(class_.tolineno, 22)
self.assertEqual(class_.blockstart_tolineno, 21)
pass_ = class_.body[0]
self.assertIsInstance(pass_, nodes.Pass)
self.assertEqual(pass_.fromlineno, 22)
self.assertEqual(pass_.tolineno, 22)
def test_if_lineno(self):
stmts = self.astroid.body
# on line 20:
# if aaaa: pass
# else:
# aaaa,bbbb = 1,2
# aaaa,bbbb = bbbb,aaaa
if_ = stmts[5]
self.assertIsInstance(if_, nodes.If)
self.assertEqual(if_.fromlineno, 24)
self.assertEqual(if_.tolineno, 27)
self.assertEqual(if_.blockstart_tolineno, 24)
self.assertEqual(if_.orelse[0].fromlineno, 26)
self.assertEqual(if_.orelse[1].tolineno, 27)
def test_for_while_lineno(self):
for code in (
"""
for a in range(4):
print (a)
break
else:
print ("bouh")
""",
"""
while a:
print (a)
break
else:
print ("bouh")
""",
):
astroid = builder.parse(code, __name__)
stmt = astroid.body[0]
self.assertEqual(stmt.fromlineno, 2)
self.assertEqual(stmt.tolineno, 6)
self.assertEqual(stmt.blockstart_tolineno, 2)
self.assertEqual(stmt.orelse[0].fromlineno, 6) # XXX
self.assertEqual(stmt.orelse[0].tolineno, 6)
def test_try_except_lineno(self):
astroid = builder.parse(
"""
try:
print (a)
except:
pass
else:
print ("bouh")
""",
__name__,
)
try_ = astroid.body[0]
self.assertEqual(try_.fromlineno, 2)
self.assertEqual(try_.tolineno, 7)
self.assertEqual(try_.blockstart_tolineno, 2)
self.assertEqual(try_.orelse[0].fromlineno, 7) # XXX
self.assertEqual(try_.orelse[0].tolineno, 7)
hdlr = try_.handlers[0]
self.assertEqual(hdlr.fromlineno, 4)
self.assertEqual(hdlr.tolineno, 5)
self.assertEqual(hdlr.blockstart_tolineno, 4)
def test_try_finally_lineno(self):
astroid = builder.parse(
"""
try:
print (a)
finally:
print ("bouh")
""",
__name__,
)
try_ = astroid.body[0]
self.assertEqual(try_.fromlineno, 2)
self.assertEqual(try_.tolineno, 5)
self.assertEqual(try_.blockstart_tolineno, 2)
self.assertEqual(try_.finalbody[0].fromlineno, 5) # XXX
self.assertEqual(try_.finalbody[0].tolineno, 5)
def test_try_finally_25_lineno(self):
astroid = builder.parse(
"""
try:
print (a)
except:
pass
finally:
print ("bouh")
""",
__name__,
)
try_ = astroid.body[0]
self.assertEqual(try_.fromlineno, 2)
self.assertEqual(try_.tolineno, 7)
self.assertEqual(try_.blockstart_tolineno, 2)
self.assertEqual(try_.finalbody[0].fromlineno, 7) # XXX
self.assertEqual(try_.finalbody[0].tolineno, 7)
def test_with_lineno(self):
astroid = builder.parse(
"""
from __future__ import with_statement
with file("/tmp/pouet") as f:
print (f)
""",
__name__,
)
with_ = astroid.body[1]
self.assertEqual(with_.fromlineno, 3)
self.assertEqual(with_.tolineno, 4)
self.assertEqual(with_.blockstart_tolineno, 3)
class BuilderTest(unittest.TestCase):
def setUp(self):
self.builder = builder.AstroidBuilder()
def test_data_build_null_bytes(self):
with self.assertRaises(exceptions.AstroidSyntaxError):
self.builder.string_build("\x00")
def test_data_build_invalid_x_escape(self):
with self.assertRaises(exceptions.AstroidSyntaxError):
self.builder.string_build('"\\x1"')
def test_missing_newline(self):
"""check that a file with no trailing new line is parseable"""
resources.build_file("data/noendingnewline.py")
def test_missing_file(self):
with self.assertRaises(exceptions.AstroidBuildingError):
resources.build_file("data/inexistant.py")
def test_inspect_build0(self):
"""test astroid tree build from a living object"""
builtin_ast = MANAGER.ast_from_module_name(BUILTINS)
# just check type and object are there
builtin_ast.getattr("type")
objectastroid = builtin_ast.getattr("object")[0]
self.assertIsInstance(objectastroid.getattr("__new__")[0], nodes.FunctionDef)
# check open file alias
builtin_ast.getattr("open")
# check 'help' is there (defined dynamically by site.py)
builtin_ast.getattr("help")
# check property has __init__
pclass = builtin_ast["property"]
self.assertIn("__init__", pclass)
self.assertIsInstance(builtin_ast["None"], nodes.Const)
self.assertIsInstance(builtin_ast["True"], nodes.Const)
self.assertIsInstance(builtin_ast["False"], nodes.Const)
self.assertIsInstance(builtin_ast["Exception"], nodes.ClassDef)
self.assertIsInstance(builtin_ast["NotImplementedError"], nodes.ClassDef)
def test_inspect_build1(self):
time_ast = MANAGER.ast_from_module_name("time")
self.assertTrue(time_ast)
self.assertEqual(time_ast["time"].args.defaults, [])
def test_inspect_build3(self):
self.builder.inspect_build(unittest)
def test_inspect_build_type_object(self):
builtin_ast = MANAGER.ast_from_module_name(BUILTINS)
inferred = list(builtin_ast.igetattr("object"))
self.assertEqual(len(inferred), 1)
inferred = inferred[0]
self.assertEqual(inferred.name, "object")
inferred.as_string() # no crash test
inferred = list(builtin_ast.igetattr("type"))
self.assertEqual(len(inferred), 1)
inferred = inferred[0]
self.assertEqual(inferred.name, "type")
inferred.as_string() # no crash test
def test_inspect_transform_module(self):
# ensure no cached version of the time module
MANAGER._mod_file_cache.pop(("time", None), None)
MANAGER.astroid_cache.pop("time", None)
def transform_time(node):
if node.name == "time":
node.transformed = True
MANAGER.register_transform(nodes.Module, transform_time)
try:
time_ast = MANAGER.ast_from_module_name("time")
self.assertTrue(getattr(time_ast, "transformed", False))
finally:
MANAGER.unregister_transform(nodes.Module, transform_time)
def test_package_name(self):
"""test base properties and method of an astroid module"""
datap = resources.build_file("data/__init__.py", "data")
self.assertEqual(datap.name, "data")
self.assertEqual(datap.package, 1)
datap = resources.build_file("data/__init__.py", "data.__init__")
self.assertEqual(datap.name, "data")
self.assertEqual(datap.package, 1)
datap = resources.build_file("data/tmp__init__.py", "data.tmp__init__")
self.assertEqual(datap.name, "data.tmp__init__")
self.assertEqual(datap.package, 0)
def test_yield_parent(self):
"""check if we added discard nodes as yield parent (w/ compiler)"""
code = """
def yiell(): #@
yield 0
if noe:
yield more
"""
func = builder.extract_node(code)
self.assertIsInstance(func, nodes.FunctionDef)
stmt = func.body[0]
self.assertIsInstance(stmt, nodes.Expr)
self.assertIsInstance(stmt.value, nodes.Yield)
self.assertIsInstance(func.body[1].body[0], nodes.Expr)
self.assertIsInstance(func.body[1].body[0].value, nodes.Yield)
def test_object(self):
obj_ast = self.builder.inspect_build(object)
self.assertIn("__setattr__", obj_ast)
def test_newstyle_detection(self):
data = """
class A:
"old style"
class B(A):
"old style"
class C(object):
"new style"
class D(C):
"new style"
__metaclass__ = type
class E(A):
"old style"
class F:
"new style"
"""
mod_ast = builder.parse(data, __name__)
self.assertTrue(mod_ast["A"].newstyle)
self.assertTrue(mod_ast["B"].newstyle)
self.assertTrue(mod_ast["E"].newstyle)
self.assertTrue(mod_ast["C"].newstyle)
self.assertTrue(mod_ast["D"].newstyle)
self.assertTrue(mod_ast["F"].newstyle)
def test_globals(self):
data = """
CSTE = 1
def update_global():
global CSTE
CSTE += 1
def global_no_effect():
global CSTE2
print (CSTE)
"""
astroid = builder.parse(data, __name__)
self.assertEqual(len(astroid.getattr("CSTE")), 2)
self.assertIsInstance(astroid.getattr("CSTE")[0], nodes.AssignName)
self.assertEqual(astroid.getattr("CSTE")[0].fromlineno, 2)
self.assertEqual(astroid.getattr("CSTE")[1].fromlineno, 6)
with self.assertRaises(exceptions.AttributeInferenceError):
astroid.getattr("CSTE2")
with self.assertRaises(exceptions.InferenceError):
next(astroid["global_no_effect"].ilookup("CSTE2"))
def test_socket_build(self):
astroid = self.builder.module_build(socket)
# XXX just check the first one. Actually 3 objects are inferred (look at
# the socket module) but the last one as those attributes dynamically
# set and astroid is missing this.
for fclass in astroid.igetattr("socket"):
self.assertIn("connect", fclass)
self.assertIn("send", fclass)
self.assertIn("close", fclass)
break
def test_gen_expr_var_scope(self):
data = "l = list(n for n in range(10))\n"
astroid = builder.parse(data, __name__)
# n unavailable outside gen expr scope
self.assertNotIn("n", astroid)
# test n is inferable anyway
n = test_utils.get_name_node(astroid, "n")
self.assertIsNot(n.scope(), astroid)
self.assertEqual([i.__class__ for i in n.infer()], [util.Uninferable.__class__])
def test_no_future_imports(self):
mod = builder.parse("import sys")
self.assertEqual(set(), mod.future_imports)
def test_future_imports(self):
mod = builder.parse("from __future__ import print_function")
self.assertEqual({"print_function"}, mod.future_imports)
def test_two_future_imports(self):
mod = builder.parse(
"""
from __future__ import print_function
from __future__ import absolute_import
"""
)
self.assertEqual({"print_function", "absolute_import"}, mod.future_imports)
def test_inferred_build(self):
code = """
class A: pass
A.type = "class"
def A_assign_type(self):
print (self)
A.assign_type = A_assign_type
"""
astroid = builder.parse(code)
lclass = list(astroid.igetattr("A"))
self.assertEqual(len(lclass), 1)
lclass = lclass[0]
self.assertIn("assign_type", lclass.locals)
self.assertIn("type", lclass.locals)
def test_augassign_attr(self):
builder.parse(
"""
class Counter:
v = 0
def inc(self):
self.v += 1
""",
__name__,
)
# TODO: Check self.v += 1 generate AugAssign(AssAttr(...)),
# not AugAssign(GetAttr(AssName...))
def test_inferred_dont_pollute(self):
code = """
def func(a=None):
a.custom_attr = 0
def func2(a={}):
a.custom_attr = 0
"""
builder.parse(code)
nonetype = nodes.const_factory(None)
# pylint: disable=no-member; Infers two potential values
self.assertNotIn("custom_attr", nonetype.locals)
self.assertNotIn("custom_attr", nonetype.instance_attrs)
nonetype = nodes.const_factory({})
self.assertNotIn("custom_attr", nonetype.locals)
self.assertNotIn("custom_attr", nonetype.instance_attrs)
def test_asstuple(self):
code = "a, b = range(2)"
astroid = builder.parse(code)
self.assertIn("b", astroid.locals)
code = """
def visit_if(self, node):
node.test, body = node.tests[0]
"""
astroid = builder.parse(code)
self.assertIn("body", astroid["visit_if"].locals)
def test_build_constants(self):
"""test expected values of constants after rebuilding"""
code = """
def func():
return None
return
return 'None'
"""
astroid = builder.parse(code)
none, nothing, chain = [ret.value for ret in astroid.body[0].body]
self.assertIsInstance(none, nodes.Const)
self.assertIsNone(none.value)
self.assertIsNone(nothing)
self.assertIsInstance(chain, nodes.Const)
self.assertEqual(chain.value, "None")
def test_not_implemented(self):
node = builder.extract_node(
"""
NotImplemented #@
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.Const)
self.assertEqual(inferred.value, NotImplemented)
class FileBuildTest(unittest.TestCase):
def setUp(self):
self.module = resources.build_file("data/module.py", "data.module")
def test_module_base_props(self):
"""test base properties and method of an astroid module"""
module = self.module
self.assertEqual(module.name, "data.module")
self.assertEqual(module.doc, "test module for astroid\n")
self.assertEqual(module.fromlineno, 0)
self.assertIsNone(module.parent)
self.assertEqual(module.frame(), module)
self.assertEqual(module.root(), module)
self.assertEqual(module.file, os.path.abspath(resources.find("data/module.py")))
self.assertEqual(module.pure_python, 1)
self.assertEqual(module.package, 0)
self.assertFalse(module.is_statement)
self.assertEqual(module.statement(), module)
self.assertEqual(module.statement(), module)
def test_module_locals(self):
"""test the 'locals' dictionary of an astroid module"""
module = self.module
_locals = module.locals
self.assertIs(_locals, module.globals)
keys = sorted(_locals.keys())
should = [
"MY_DICT",
"NameNode",
"YO",
"YOUPI",
"__revision__",
"global_access",
"modutils",
"four_args",
"os",
"redirect",
]
should.sort()
self.assertEqual(keys, sorted(should))
def test_function_base_props(self):
"""test base properties and method of an astroid function"""
module = self.module
function = module["global_access"]
self.assertEqual(function.name, "global_access")
self.assertEqual(function.doc, "function test")
self.assertEqual(function.fromlineno, 11)
self.assertTrue(function.parent)
self.assertEqual(function.frame(), function)
self.assertEqual(function.parent.frame(), module)
self.assertEqual(function.root(), module)
self.assertEqual([n.name for n in function.args.args], ["key", "val"])
self.assertEqual(function.type, "function")
def test_function_locals(self):
"""test the 'locals' dictionary of an astroid function"""
_locals = self.module["global_access"].locals
self.assertEqual(len(_locals), 4)
keys = sorted(_locals.keys())
self.assertEqual(keys, ["i", "key", "local", "val"])
def test_class_base_props(self):
"""test base properties and method of an astroid class"""
module = self.module
klass = module["YO"]
self.assertEqual(klass.name, "YO")
self.assertEqual(klass.doc, "hehe\n haha")
self.assertEqual(klass.fromlineno, 25)
self.assertTrue(klass.parent)
self.assertEqual(klass.frame(), klass)
self.assertEqual(klass.parent.frame(), module)
self.assertEqual(klass.root(), module)
self.assertEqual(klass.basenames, [])
self.assertTrue(klass.newstyle)
def test_class_locals(self):
"""test the 'locals' dictionary of an astroid class"""
module = self.module
klass1 = module["YO"]
locals1 = klass1.locals
keys = sorted(locals1.keys())
assert_keys = ["__init__", "__module__", "__qualname__", "a"]
self.assertEqual(keys, assert_keys)
klass2 = module["YOUPI"]
locals2 = klass2.locals
keys = locals2.keys()
assert_keys = [
"__init__",
"__module__",
"__qualname__",
"class_attr",
"class_method",
"method",
"static_method",
]
self.assertEqual(sorted(keys), assert_keys)
def test_class_instance_attrs(self):
module = self.module
klass1 = module["YO"]
klass2 = module["YOUPI"]
self.assertEqual(list(klass1.instance_attrs.keys()), ["yo"])
self.assertEqual(list(klass2.instance_attrs.keys()), ["member"])
def test_class_basenames(self):
module = self.module
klass1 = module["YO"]
klass2 = module["YOUPI"]
self.assertEqual(klass1.basenames, [])
self.assertEqual(klass2.basenames, ["YO"])
def test_method_base_props(self):
"""test base properties and method of an astroid method"""
klass2 = self.module["YOUPI"]
# "normal" method
method = klass2["method"]
self.assertEqual(method.name, "method")
self.assertEqual([n.name for n in method.args.args], ["self"])
self.assertEqual(method.doc, "method\n test")
self.assertEqual(method.fromlineno, 48)
self.assertEqual(method.type, "method")
# class method
method = klass2["class_method"]
self.assertEqual([n.name for n in method.args.args], ["cls"])
self.assertEqual(method.type, "classmethod")
# static method
method = klass2["static_method"]
self.assertEqual(method.args.args, [])
self.assertEqual(method.type, "staticmethod")
def test_method_locals(self):
"""test the 'locals' dictionary of an astroid method"""
method = self.module["YOUPI"]["method"]
_locals = method.locals
keys = sorted(_locals)
# ListComp variables are not accessible outside
self.assertEqual(len(_locals), 3)
self.assertEqual(keys, ["autre", "local", "self"])
def test_unknown_encoding(self):
with self.assertRaises(exceptions.AstroidSyntaxError):
resources.build_file("data/invalid_encoding.py")
def test_module_build_dunder_file():
"""Test that module_build() can work with modules that have the *__file__* attribute"""
module = builder.AstroidBuilder().module_build(collections)
assert module.path[0] == collections.__file__
@pytest.mark.skipif(
sys.version_info[:2] >= (3, 8),
reason=(
"The builtin ast module does not fail with a specific error "
"for syntax error caused by invalid type comments."
),
)
def test_parse_module_with_invalid_type_comments_does_not_crash():
node = builder.parse(
"""
# op {
# name: "AssignAddVariableOp"
# input_arg {
# name: "resource"
# type: DT_RESOURCE
# }
# input_arg {
# name: "value"
# type_attr: "dtype"
# }
# attr {
# name: "dtype"
# type: "type"
# }
# is_stateful: true
# }
a, b = 2
"""
)
assert isinstance(node, nodes.Module)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,263 @@
# Copyright (c) 2015-2016, 2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
import builtins
from astroid import builder
from astroid import exceptions
from astroid import helpers
from astroid import manager
from astroid import raw_building
from astroid import test_utils
from astroid import util
class TestHelpers(unittest.TestCase):
def setUp(self):
builtins_name = builtins.__name__
astroid_manager = manager.AstroidManager()
self.builtins = astroid_manager.astroid_cache[builtins_name]
self.manager = manager.AstroidManager()
def _extract(self, obj_name):
return self.builtins.getattr(obj_name)[0]
def _build_custom_builtin(self, obj_name):
proxy = raw_building.build_class(obj_name)
proxy.parent = self.builtins
return proxy
def assert_classes_equal(self, cls, other):
self.assertEqual(cls.name, other.name)
self.assertEqual(cls.parent, other.parent)
self.assertEqual(cls.qname(), other.qname())
def test_object_type(self):
pairs = [
("1", self._extract("int")),
("[]", self._extract("list")),
("{1, 2, 3}", self._extract("set")),
("{1:2, 4:3}", self._extract("dict")),
("type", self._extract("type")),
("object", self._extract("type")),
("object()", self._extract("object")),
("lambda: None", self._build_custom_builtin("function")),
("len", self._build_custom_builtin("builtin_function_or_method")),
("None", self._build_custom_builtin("NoneType")),
("import sys\nsys#@", self._build_custom_builtin("module")),
]
for code, expected in pairs:
node = builder.extract_node(code)
objtype = helpers.object_type(node)
self.assert_classes_equal(objtype, expected)
def test_object_type_classes_and_functions(self):
ast_nodes = builder.extract_node(
"""
def generator():
yield
class A(object):
def test(self):
self #@
@classmethod
def cls_method(cls): pass
@staticmethod
def static_method(): pass
A #@
A() #@
A.test #@
A().test #@
A.cls_method #@
A().cls_method #@
A.static_method #@
A().static_method #@
generator() #@
"""
)
from_self = helpers.object_type(ast_nodes[0])
cls = next(ast_nodes[1].infer())
self.assert_classes_equal(from_self, cls)
cls_type = helpers.object_type(ast_nodes[1])
self.assert_classes_equal(cls_type, self._extract("type"))
instance_type = helpers.object_type(ast_nodes[2])
cls = next(ast_nodes[2].infer())._proxied
self.assert_classes_equal(instance_type, cls)
expected_method_types = [
(ast_nodes[3], "function"),
(ast_nodes[4], "method"),
(ast_nodes[5], "method"),
(ast_nodes[6], "method"),
(ast_nodes[7], "function"),
(ast_nodes[8], "function"),
(ast_nodes[9], "generator"),
]
for node, expected in expected_method_types:
node_type = helpers.object_type(node)
expected_type = self._build_custom_builtin(expected)
self.assert_classes_equal(node_type, expected_type)
@test_utils.require_version(minver="3.0")
def test_object_type_metaclasses(self):
module = builder.parse(
"""
import abc
class Meta(metaclass=abc.ABCMeta):
pass
meta_instance = Meta()
"""
)
meta_type = helpers.object_type(module["Meta"])
self.assert_classes_equal(meta_type, module["Meta"].metaclass())
meta_instance = next(module["meta_instance"].infer())
instance_type = helpers.object_type(meta_instance)
self.assert_classes_equal(instance_type, module["Meta"])
@test_utils.require_version(minver="3.0")
def test_object_type_most_derived(self):
node = builder.extract_node(
"""
class A(type):
def __new__(*args, **kwargs):
return type.__new__(*args, **kwargs)
class B(object): pass
class C(object, metaclass=A): pass
# The most derived metaclass of D is A rather than type.
class D(B , C): #@
pass
"""
)
metaclass = node.metaclass()
self.assertEqual(metaclass.name, "A")
obj_type = helpers.object_type(node)
self.assertEqual(metaclass, obj_type)
def test_inference_errors(self):
node = builder.extract_node(
"""
from unknown import Unknown
u = Unknown #@
"""
)
self.assertEqual(helpers.object_type(node), util.Uninferable)
def test_object_type_too_many_types(self):
node = builder.extract_node(
"""
from unknown import Unknown
def test(x):
if x:
return lambda: None
else:
return 1
test(Unknown) #@
"""
)
self.assertEqual(helpers.object_type(node), util.Uninferable)
def test_is_subtype(self):
ast_nodes = builder.extract_node(
"""
class int_subclass(int):
pass
class A(object): pass #@
class B(A): pass #@
class C(A): pass #@
int_subclass() #@
"""
)
cls_a = ast_nodes[0]
cls_b = ast_nodes[1]
cls_c = ast_nodes[2]
int_subclass = ast_nodes[3]
int_subclass = helpers.object_type(next(int_subclass.infer()))
base_int = self._extract("int")
self.assertTrue(helpers.is_subtype(int_subclass, base_int))
self.assertTrue(helpers.is_supertype(base_int, int_subclass))
self.assertTrue(helpers.is_supertype(cls_a, cls_b))
self.assertTrue(helpers.is_supertype(cls_a, cls_c))
self.assertTrue(helpers.is_subtype(cls_b, cls_a))
self.assertTrue(helpers.is_subtype(cls_c, cls_a))
self.assertFalse(helpers.is_subtype(cls_a, cls_b))
self.assertFalse(helpers.is_subtype(cls_a, cls_b))
def test_is_subtype_supertype_mro_error(self):
cls_e, cls_f = builder.extract_node(
"""
class A(object): pass
class B(A): pass
class C(A): pass
class D(B, C): pass
class E(C, B): pass #@
class F(D, E): pass #@
"""
)
self.assertFalse(helpers.is_subtype(cls_e, cls_f))
self.assertFalse(helpers.is_subtype(cls_e, cls_f))
with self.assertRaises(exceptions._NonDeducibleTypeHierarchy):
helpers.is_subtype(cls_f, cls_e)
self.assertFalse(helpers.is_supertype(cls_f, cls_e))
def test_is_subtype_supertype_unknown_bases(self):
cls_a, cls_b = builder.extract_node(
"""
from unknown import Unknown
class A(Unknown): pass #@
class B(A): pass #@
"""
)
with self.assertRaises(exceptions._NonDeducibleTypeHierarchy):
helpers.is_subtype(cls_a, cls_b)
with self.assertRaises(exceptions._NonDeducibleTypeHierarchy):
helpers.is_supertype(cls_a, cls_b)
def test_is_subtype_supertype_unrelated_classes(self):
cls_a, cls_b = builder.extract_node(
"""
class A(object): pass #@
class B(object): pass #@
"""
)
self.assertFalse(helpers.is_subtype(cls_a, cls_b))
self.assertFalse(helpers.is_subtype(cls_b, cls_a))
self.assertFalse(helpers.is_supertype(cls_a, cls_b))
self.assertFalse(helpers.is_supertype(cls_b, cls_a))
def test_is_subtype_supertype_classes_no_type_ancestor(self):
cls_a = builder.extract_node(
"""
class A(object): #@
pass
"""
)
builtin_type = self._extract("type")
self.assertFalse(helpers.is_supertype(builtin_type, cls_a))
self.assertFalse(helpers.is_subtype(cls_a, builtin_type))
def test_is_subtype_supertype_classes_metaclasses(self):
cls_a = builder.extract_node(
"""
class A(type): #@
pass
"""
)
builtin_type = self._extract("type")
self.assertTrue(helpers.is_supertype(builtin_type, cls_a))
self.assertTrue(helpers.is_subtype(cls_a, builtin_type))
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,357 @@
# Copyright (c) 2007-2013 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2010 Daniel Harding <dharding@gmail.com>
# Copyright (c) 2014-2016, 2018-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 Hugo van Kemenade <hugovk@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
"""tests for the astroid variable lookup capabilities
"""
import functools
import unittest
from astroid import builder
from astroid import exceptions
from astroid import nodes
from astroid import scoped_nodes
from . import resources
class LookupTest(resources.SysPathSetup, unittest.TestCase):
def setUp(self):
super(LookupTest, self).setUp()
self.module = resources.build_file("data/module.py", "data.module")
self.module2 = resources.build_file("data/module2.py", "data.module2")
self.nonregr = resources.build_file("data/nonregr.py", "data.nonregr")
def test_limit(self):
code = """
l = [a
for a,b in list]
a = 1
b = a
a = None
def func():
c = 1
"""
astroid = builder.parse(code, __name__)
# a & b
a = next(astroid.nodes_of_class(nodes.Name))
self.assertEqual(a.lineno, 2)
self.assertEqual(len(astroid.lookup("b")[1]), 1)
self.assertEqual(len(astroid.lookup("a")[1]), 1)
b = astroid.locals["b"][0]
stmts = a.lookup("a")[1]
self.assertEqual(len(stmts), 1)
self.assertEqual(b.lineno, 6)
b_infer = b.infer()
b_value = next(b_infer)
self.assertEqual(b_value.value, 1)
# c
self.assertRaises(StopIteration, functools.partial(next, b_infer))
func = astroid.locals["func"][0]
self.assertEqual(len(func.lookup("c")[1]), 1)
def test_module(self):
astroid = builder.parse("pass", __name__)
# built-in objects
none = next(astroid.ilookup("None"))
self.assertIsNone(none.value)
obj = next(astroid.ilookup("object"))
self.assertIsInstance(obj, nodes.ClassDef)
self.assertEqual(obj.name, "object")
self.assertRaises(
exceptions.InferenceError, functools.partial(next, astroid.ilookup("YOAA"))
)
# XXX
self.assertEqual(len(list(self.nonregr.ilookup("enumerate"))), 2)
def test_class_ancestor_name(self):
code = """
class A:
pass
class A(A):
pass
"""
astroid = builder.parse(code, __name__)
cls1 = astroid.locals["A"][0]
cls2 = astroid.locals["A"][1]
name = next(cls2.nodes_of_class(nodes.Name))
self.assertEqual(next(name.infer()), cls1)
### backport those test to inline code
def test_method(self):
method = self.module["YOUPI"]["method"]
my_dict = next(method.ilookup("MY_DICT"))
self.assertTrue(isinstance(my_dict, nodes.Dict), my_dict)
none = next(method.ilookup("None"))
self.assertIsNone(none.value)
self.assertRaises(
exceptions.InferenceError, functools.partial(next, method.ilookup("YOAA"))
)
def test_function_argument_with_default(self):
make_class = self.module2["make_class"]
base = next(make_class.ilookup("base"))
self.assertTrue(isinstance(base, nodes.ClassDef), base.__class__)
self.assertEqual(base.name, "YO")
self.assertEqual(base.root().name, "data.module")
def test_class(self):
klass = self.module["YOUPI"]
my_dict = next(klass.ilookup("MY_DICT"))
self.assertIsInstance(my_dict, nodes.Dict)
none = next(klass.ilookup("None"))
self.assertIsNone(none.value)
obj = next(klass.ilookup("object"))
self.assertIsInstance(obj, nodes.ClassDef)
self.assertEqual(obj.name, "object")
self.assertRaises(
exceptions.InferenceError, functools.partial(next, klass.ilookup("YOAA"))
)
def test_inner_classes(self):
ddd = list(self.nonregr["Ccc"].ilookup("Ddd"))
self.assertEqual(ddd[0].name, "Ddd")
def test_loopvar_hiding(self):
astroid = builder.parse(
"""
x = 10
for x in range(5):
print (x)
if x > 0:
print ('#' * x)
""",
__name__,
)
xnames = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"]
# inside the loop, only one possible assignment
self.assertEqual(len(xnames[0].lookup("x")[1]), 1)
# outside the loop, two possible assignments
self.assertEqual(len(xnames[1].lookup("x")[1]), 2)
self.assertEqual(len(xnames[2].lookup("x")[1]), 2)
def test_list_comps(self):
astroid = builder.parse(
"""
print ([ i for i in range(10) ])
print ([ i for i in range(10) ])
print ( list( i for i in range(10) ) )
""",
__name__,
)
xnames = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "i"]
self.assertEqual(len(xnames[0].lookup("i")[1]), 1)
self.assertEqual(xnames[0].lookup("i")[1][0].lineno, 2)
self.assertEqual(len(xnames[1].lookup("i")[1]), 1)
self.assertEqual(xnames[1].lookup("i")[1][0].lineno, 3)
self.assertEqual(len(xnames[2].lookup("i")[1]), 1)
self.assertEqual(xnames[2].lookup("i")[1][0].lineno, 4)
def test_list_comp_target(self):
"""test the list comprehension target"""
astroid = builder.parse(
"""
ten = [ var for var in range(10) ]
var
"""
)
var = astroid.body[1].value
self.assertRaises(exceptions.NameInferenceError, var.inferred)
def test_dict_comps(self):
astroid = builder.parse(
"""
print ({ i: j for i in range(10) for j in range(10) })
print ({ i: j for i in range(10) for j in range(10) })
""",
__name__,
)
xnames = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "i"]
self.assertEqual(len(xnames[0].lookup("i")[1]), 1)
self.assertEqual(xnames[0].lookup("i")[1][0].lineno, 2)
self.assertEqual(len(xnames[1].lookup("i")[1]), 1)
self.assertEqual(xnames[1].lookup("i")[1][0].lineno, 3)
xnames = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "j"]
self.assertEqual(len(xnames[0].lookup("i")[1]), 1)
self.assertEqual(xnames[0].lookup("i")[1][0].lineno, 2)
self.assertEqual(len(xnames[1].lookup("i")[1]), 1)
self.assertEqual(xnames[1].lookup("i")[1][0].lineno, 3)
def test_set_comps(self):
astroid = builder.parse(
"""
print ({ i for i in range(10) })
print ({ i for i in range(10) })
""",
__name__,
)
xnames = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "i"]
self.assertEqual(len(xnames[0].lookup("i")[1]), 1)
self.assertEqual(xnames[0].lookup("i")[1][0].lineno, 2)
self.assertEqual(len(xnames[1].lookup("i")[1]), 1)
self.assertEqual(xnames[1].lookup("i")[1][0].lineno, 3)
def test_set_comp_closure(self):
astroid = builder.parse(
"""
ten = { var for var in range(10) }
var
"""
)
var = astroid.body[1].value
self.assertRaises(exceptions.NameInferenceError, var.inferred)
def test_generator_attributes(self):
tree = builder.parse(
"""
def count():
"test"
yield 0
iterer = count()
num = iterer.next()
"""
)
next_node = tree.body[2].value.func
gener = next_node.expr.inferred()[0]
self.assertIsInstance(gener.getattr("__next__")[0], nodes.FunctionDef)
self.assertIsInstance(gener.getattr("send")[0], nodes.FunctionDef)
self.assertIsInstance(gener.getattr("throw")[0], nodes.FunctionDef)
self.assertIsInstance(gener.getattr("close")[0], nodes.FunctionDef)
def test_explicit___name__(self):
code = """
class Pouet:
__name__ = "pouet"
p1 = Pouet()
class PouetPouet(Pouet): pass
p2 = Pouet()
class NoName: pass
p3 = NoName()
"""
astroid = builder.parse(code, __name__)
p1 = next(astroid["p1"].infer())
self.assertTrue(p1.getattr("__name__"))
p2 = next(astroid["p2"].infer())
self.assertTrue(p2.getattr("__name__"))
self.assertTrue(astroid["NoName"].getattr("__name__"))
p3 = next(astroid["p3"].infer())
self.assertRaises(exceptions.AttributeInferenceError, p3.getattr, "__name__")
def test_function_module_special(self):
astroid = builder.parse(
'''
def initialize(linter):
"""initialize linter with checkers in this package """
package_load(linter, __path__[0])
''',
"data.__init__",
)
path = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "__path__"][
0
]
self.assertEqual(len(path.lookup("__path__")[1]), 1)
def test_builtin_lookup(self):
self.assertEqual(scoped_nodes.builtin_lookup("__dict__")[1], ())
intstmts = scoped_nodes.builtin_lookup("int")[1]
self.assertEqual(len(intstmts), 1)
self.assertIsInstance(intstmts[0], nodes.ClassDef)
self.assertEqual(intstmts[0].name, "int")
# pylint: disable=no-member; Infers two potential values
self.assertIs(intstmts[0], nodes.const_factory(1)._proxied)
def test_decorator_arguments_lookup(self):
code = """
def decorator(value):
def wrapper(function):
return function
return wrapper
class foo:
member = 10 #@
@decorator(member) #This will cause pylint to complain
def test(self):
pass
"""
member = builder.extract_node(code, __name__).targets[0]
it = member.infer()
obj = next(it)
self.assertIsInstance(obj, nodes.Const)
self.assertEqual(obj.value, 10)
self.assertRaises(StopIteration, functools.partial(next, it))
def test_inner_decorator_member_lookup(self):
code = """
class FileA:
def decorator(bla):
return bla
@__(decorator)
def funcA():
return 4
"""
decname = builder.extract_node(code, __name__)
it = decname.infer()
obj = next(it)
self.assertIsInstance(obj, nodes.FunctionDef)
self.assertRaises(StopIteration, functools.partial(next, it))
def test_static_method_lookup(self):
code = """
class FileA:
@staticmethod
def funcA():
return 4
class Test:
FileA = [1,2,3]
def __init__(self):
print (FileA.funcA())
"""
astroid = builder.parse(code, __name__)
it = astroid["Test"]["__init__"].ilookup("FileA")
obj = next(it)
self.assertIsInstance(obj, nodes.ClassDef)
self.assertRaises(StopIteration, functools.partial(next, it))
def test_global_delete(self):
code = """
def run2():
f = Frobble()
class Frobble:
pass
Frobble.mumble = True
del Frobble
def run1():
f = Frobble()
"""
astroid = builder.parse(code, __name__)
stmts = astroid["run2"].lookup("Frobbel")[1]
self.assertEqual(len(stmts), 0)
stmts = astroid["run1"].lookup("Frobbel")[1]
self.assertEqual(len(stmts), 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2006, 2009-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2013 AndroWiiid <androwiiid@gmail.com>
# Copyright (c) 2014-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2017 Chris Philip <chrisp533@gmail.com>
# Copyright (c) 2017 Hugo <hugovk@users.noreply.github.com>
# Copyright (c) 2017 ioanatia <ioanatia@users.noreply.github.com>
# Copyright (c) 2018 Ville Skyttä <ville.skytta@iki.fi>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 Hugo van Kemenade <hugovk@users.noreply.github.com>
# Copyright (c) 2020 Anubhav <35621759+anubh-v@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import os
import platform
import site
import sys
import unittest
import pkg_resources
import six
import time
import astroid
from astroid import exceptions
from astroid import manager
from . import resources
BUILTINS = six.moves.builtins.__name__
def _get_file_from_object(obj):
if platform.python_implementation() == "Jython":
return obj.__file__.split("$py.class")[0] + ".py"
return obj.__file__
class AstroidManagerTest(
resources.SysPathSetup, resources.AstroidCacheSetupMixin, unittest.TestCase
):
def setUp(self):
super(AstroidManagerTest, self).setUp()
self.manager = manager.AstroidManager()
def test_ast_from_file(self):
filepath = unittest.__file__
ast = self.manager.ast_from_file(filepath)
self.assertEqual(ast.name, "unittest")
self.assertIn("unittest", self.manager.astroid_cache)
def test_ast_from_file_cache(self):
filepath = unittest.__file__
self.manager.ast_from_file(filepath)
ast = self.manager.ast_from_file("unhandledName", "unittest")
self.assertEqual(ast.name, "unittest")
self.assertIn("unittest", self.manager.astroid_cache)
def test_ast_from_file_astro_builder(self):
filepath = unittest.__file__
ast = self.manager.ast_from_file(filepath, None, True, True)
self.assertEqual(ast.name, "unittest")
self.assertIn("unittest", self.manager.astroid_cache)
def test_ast_from_file_name_astro_builder_exception(self):
self.assertRaises(
exceptions.AstroidBuildingError, self.manager.ast_from_file, "unhandledName"
)
def test_ast_from_string(self):
filepath = unittest.__file__
dirname = os.path.dirname(filepath)
modname = os.path.basename(dirname)
with open(filepath, "r") as file:
data = file.read()
ast = self.manager.ast_from_string(data, modname, filepath)
self.assertEqual(ast.name, "unittest")
self.assertEqual(ast.file, filepath)
self.assertIn("unittest", self.manager.astroid_cache)
def test_do_not_expose_main(self):
obj = self.manager.ast_from_module_name("__main__")
self.assertEqual(obj.name, "__main__")
self.assertEqual(obj.items(), [])
def test_ast_from_module_name(self):
ast = self.manager.ast_from_module_name("unittest")
self.assertEqual(ast.name, "unittest")
self.assertIn("unittest", self.manager.astroid_cache)
def test_ast_from_module_name_not_python_source(self):
ast = self.manager.ast_from_module_name("time")
self.assertEqual(ast.name, "time")
self.assertIn("time", self.manager.astroid_cache)
self.assertEqual(ast.pure_python, False)
def test_ast_from_module_name_astro_builder_exception(self):
self.assertRaises(
exceptions.AstroidBuildingError,
self.manager.ast_from_module_name,
"unhandledModule",
)
def _test_ast_from_old_namespace_package_protocol(self, root):
origpath = sys.path[:]
paths = [
resources.find("data/path_{}_{}".format(root, index))
for index in range(1, 4)
]
sys.path.extend(paths)
try:
for name in ("foo", "bar", "baz"):
module = self.manager.ast_from_module_name("package." + name)
self.assertIsInstance(module, astroid.Module)
finally:
sys.path = origpath
def test_ast_from_namespace_pkgutil(self):
self._test_ast_from_old_namespace_package_protocol("pkgutil")
def test_ast_from_namespace_pkg_resources(self):
self._test_ast_from_old_namespace_package_protocol("pkg_resources")
def test_implicit_namespace_package(self):
data_dir = os.path.dirname(resources.find("data/namespace_pep_420"))
contribute = os.path.join(data_dir, "contribute_to_namespace")
for value in (data_dir, contribute):
sys.path.insert(0, value)
try:
module = self.manager.ast_from_module_name("namespace_pep_420.module")
self.assertIsInstance(module, astroid.Module)
self.assertEqual(module.name, "namespace_pep_420.module")
var = next(module.igetattr("var"))
self.assertIsInstance(var, astroid.Const)
self.assertEqual(var.value, 42)
finally:
for _ in range(2):
sys.path.pop(0)
def test_namespace_package_pth_support(self):
pth = "foogle_fax-0.12.5-py2.7-nspkg.pth"
site.addpackage(resources.RESOURCE_PATH, pth, [])
pkg_resources._namespace_packages["foogle"] = []
try:
module = self.manager.ast_from_module_name("foogle.fax")
submodule = next(module.igetattr("a"))
value = next(submodule.igetattr("x"))
self.assertIsInstance(value, astroid.Const)
with self.assertRaises(exceptions.AstroidImportError):
self.manager.ast_from_module_name("foogle.moogle")
finally:
del pkg_resources._namespace_packages["foogle"]
sys.modules.pop("foogle")
def test_nested_namespace_import(self):
pth = "foogle_fax-0.12.5-py2.7-nspkg.pth"
site.addpackage(resources.RESOURCE_PATH, pth, [])
pkg_resources._namespace_packages["foogle"] = ["foogle.crank"]
pkg_resources._namespace_packages["foogle.crank"] = []
try:
self.manager.ast_from_module_name("foogle.crank")
finally:
del pkg_resources._namespace_packages["foogle"]
sys.modules.pop("foogle")
def test_namespace_and_file_mismatch(self):
filepath = unittest.__file__
ast = self.manager.ast_from_file(filepath)
self.assertEqual(ast.name, "unittest")
pth = "foogle_fax-0.12.5-py2.7-nspkg.pth"
site.addpackage(resources.RESOURCE_PATH, pth, [])
pkg_resources._namespace_packages["foogle"] = []
try:
with self.assertRaises(exceptions.AstroidImportError):
self.manager.ast_from_module_name("unittest.foogle.fax")
finally:
del pkg_resources._namespace_packages["foogle"]
sys.modules.pop("foogle")
def _test_ast_from_zip(self, archive):
origpath = sys.path[:]
sys.modules.pop("mypypa", None)
archive_path = resources.find(archive)
sys.path.insert(0, archive_path)
try:
module = self.manager.ast_from_module_name("mypypa")
self.assertEqual(module.name, "mypypa")
end = os.path.join(archive, "mypypa")
self.assertTrue(
module.file.endswith(end), "%s doesn't endswith %s" % (module.file, end)
)
finally:
# remove the module, else after importing egg, we don't get the zip
if "mypypa" in self.manager.astroid_cache:
del self.manager.astroid_cache["mypypa"]
del self.manager._mod_file_cache[("mypypa", None)]
if archive_path in sys.path_importer_cache:
del sys.path_importer_cache[archive_path]
sys.path = origpath
def test_ast_from_module_name_egg(self):
self._test_ast_from_zip(
os.path.sep.join(["data", os.path.normcase("MyPyPa-0.1.0-py2.5.egg")])
)
def test_ast_from_module_name_zip(self):
self._test_ast_from_zip(
os.path.sep.join(["data", os.path.normcase("MyPyPa-0.1.0-py2.5.zip")])
)
def test_zip_import_data(self):
"""check if zip_import_data works"""
filepath = resources.find("data/MyPyPa-0.1.0-py2.5.zip/mypypa")
ast = self.manager.zip_import_data(filepath)
self.assertEqual(ast.name, "mypypa")
def test_zip_import_data_without_zipimport(self):
"""check if zip_import_data return None without zipimport"""
self.assertEqual(self.manager.zip_import_data("path"), None)
def test_file_from_module(self):
"""check if the unittest filepath is equals to the result of the method"""
self.assertEqual(
_get_file_from_object(unittest),
self.manager.file_from_module_name("unittest", None).location,
)
def test_file_from_module_name_astro_building_exception(self):
"""check if the method raises an exception with a wrong module name"""
self.assertRaises(
exceptions.AstroidBuildingError,
self.manager.file_from_module_name,
"unhandledModule",
None,
)
def test_ast_from_module(self):
ast = self.manager.ast_from_module(unittest)
self.assertEqual(ast.pure_python, True)
ast = self.manager.ast_from_module(time)
self.assertEqual(ast.pure_python, False)
def test_ast_from_module_cache(self):
"""check if the module is in the cache manager"""
ast = self.manager.ast_from_module(unittest)
self.assertEqual(ast.name, "unittest")
self.assertIn("unittest", self.manager.astroid_cache)
def test_ast_from_class(self):
ast = self.manager.ast_from_class(int)
self.assertEqual(ast.name, "int")
self.assertEqual(ast.parent.frame().name, BUILTINS)
ast = self.manager.ast_from_class(object)
self.assertEqual(ast.name, "object")
self.assertEqual(ast.parent.frame().name, BUILTINS)
self.assertIn("__setattr__", ast)
def test_ast_from_class_with_module(self):
"""check if the method works with the module name"""
ast = self.manager.ast_from_class(int, int.__module__)
self.assertEqual(ast.name, "int")
self.assertEqual(ast.parent.frame().name, BUILTINS)
ast = self.manager.ast_from_class(object, object.__module__)
self.assertEqual(ast.name, "object")
self.assertEqual(ast.parent.frame().name, BUILTINS)
self.assertIn("__setattr__", ast)
def test_ast_from_class_attr_error(self):
"""give a wrong class at the ast_from_class method"""
self.assertRaises(
exceptions.AstroidBuildingError, self.manager.ast_from_class, None
)
def testFailedImportHooks(self):
def hook(modname):
if modname == "foo.bar":
return unittest
raise exceptions.AstroidBuildingError()
with self.assertRaises(exceptions.AstroidBuildingError):
self.manager.ast_from_module_name("foo.bar")
self.manager.register_failed_import_hook(hook)
self.assertEqual(unittest, self.manager.ast_from_module_name("foo.bar"))
with self.assertRaises(exceptions.AstroidBuildingError):
self.manager.ast_from_module_name("foo.bar.baz")
del self.manager._failed_import_hooks[0]
class BorgAstroidManagerTC(unittest.TestCase):
def test_borg(self):
"""test that the AstroidManager is really a borg, i.e. that two different
instances has same cache"""
first_manager = manager.AstroidManager()
built = first_manager.ast_from_module_name(BUILTINS)
second_manager = manager.AstroidManager()
second_built = second_manager.ast_from_module_name(BUILTINS)
self.assertIs(built, second_built)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,331 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014-2016, 2018-2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2015 Florian Bruhin <me@the-compiler.org>
# Copyright (c) 2015 Radosław Ganczarek <radoslaw@ganczarek.in>
# Copyright (c) 2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2018 Mario Corchero <mcorcherojim@bloomberg.net>
# Copyright (c) 2018 Mario Corchero <mariocj89@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2019 Hugo van Kemenade <hugovk@users.noreply.github.com>
# Copyright (c) 2019 markmcclain <markmcclain@users.noreply.github.com>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
"""
unit tests for module modutils (module manipulation utilities)
"""
import distutils.version
import email
import os
import sys
import unittest
import xml
from xml import etree
from xml.etree import ElementTree
import tempfile
import shutil
import astroid
from astroid.interpreter._import import spec
from astroid import modutils
from . import resources
def _get_file_from_object(obj):
return modutils._path_from_filename(obj.__file__)
class ModuleFileTest(unittest.TestCase):
package = "mypypa"
def tearDown(self):
for k in list(sys.path_importer_cache):
if "MyPyPa" in k:
del sys.path_importer_cache[k]
def test_find_zipped_module(self):
found_spec = spec.find_spec(
[self.package], [resources.find("data/MyPyPa-0.1.0-py2.5.zip")]
)
self.assertEqual(found_spec.type, spec.ModuleType.PY_ZIPMODULE)
self.assertEqual(
found_spec.location.split(os.sep)[-3:],
["data", "MyPyPa-0.1.0-py2.5.zip", self.package],
)
def test_find_egg_module(self):
found_spec = spec.find_spec(
[self.package], [resources.find("data/MyPyPa-0.1.0-py2.5.egg")]
)
self.assertEqual(found_spec.type, spec.ModuleType.PY_ZIPMODULE)
self.assertEqual(
found_spec.location.split(os.sep)[-3:],
["data", "MyPyPa-0.1.0-py2.5.egg", self.package],
)
def test_find_distutils_submodules_in_virtualenv(self):
found_spec = spec.find_spec(["distutils", "version"])
self.assertEqual(found_spec.location, distutils.version.__file__)
class LoadModuleFromNameTest(unittest.TestCase):
""" load a python module from it's name """
def test_knownValues_load_module_from_name_1(self):
self.assertEqual(modutils.load_module_from_name("sys"), sys)
def test_knownValues_load_module_from_name_2(self):
self.assertEqual(modutils.load_module_from_name("os.path"), os.path)
def test_raise_load_module_from_name_1(self):
self.assertRaises(
ImportError, modutils.load_module_from_name, "os.path", use_sys=0
)
class GetModulePartTest(unittest.TestCase):
"""given a dotted name return the module part of the name"""
def test_knownValues_get_module_part_1(self):
self.assertEqual(
modutils.get_module_part("astroid.modutils"), "astroid.modutils"
)
def test_knownValues_get_module_part_2(self):
self.assertEqual(
modutils.get_module_part("astroid.modutils.get_module_part"),
"astroid.modutils",
)
def test_knownValues_get_module_part_3(self):
"""relative import from given file"""
self.assertEqual(
modutils.get_module_part("node_classes.AssName", modutils.__file__),
"node_classes",
)
def test_knownValues_get_compiled_module_part(self):
self.assertEqual(modutils.get_module_part("math.log10"), "math")
self.assertEqual(modutils.get_module_part("math.log10", __file__), "math")
def test_knownValues_get_builtin_module_part(self):
self.assertEqual(modutils.get_module_part("sys.path"), "sys")
self.assertEqual(modutils.get_module_part("sys.path", "__file__"), "sys")
def test_get_module_part_exception(self):
self.assertRaises(
ImportError, modutils.get_module_part, "unknown.module", modutils.__file__
)
class ModPathFromFileTest(unittest.TestCase):
""" given an absolute file path return the python module's path as a list """
def test_knownValues_modpath_from_file_1(self):
self.assertEqual(
modutils.modpath_from_file(ElementTree.__file__),
["xml", "etree", "ElementTree"],
)
def test_raise_modpath_from_file_Exception(self):
self.assertRaises(Exception, modutils.modpath_from_file, "/turlututu")
def test_import_symlink_with_source_outside_of_path(self):
with tempfile.NamedTemporaryFile() as tmpfile:
linked_file_name = "symlinked_file.py"
try:
os.symlink(tmpfile.name, linked_file_name)
self.assertEqual(
modutils.modpath_from_file(linked_file_name), ["symlinked_file"]
)
finally:
os.remove(linked_file_name)
def test_import_symlink_both_outside_of_path(self):
with tempfile.NamedTemporaryFile() as tmpfile:
linked_file_name = os.path.join(tempfile.gettempdir(), "symlinked_file.py")
try:
os.symlink(tmpfile.name, linked_file_name)
self.assertRaises(
ImportError, modutils.modpath_from_file, linked_file_name
)
finally:
os.remove(linked_file_name)
def test_load_from_module_symlink_on_symlinked_paths_in_syspath(self):
# constants
tmp = tempfile.gettempdir()
deployment_path = os.path.join(tmp, "deployment")
path_to_include = os.path.join(tmp, "path_to_include")
real_secret_path = os.path.join(tmp, "secret.py")
symlink_secret_path = os.path.join(path_to_include, "secret.py")
# setup double symlink
# /tmp/deployment
# /tmp/path_to_include (symlink to /tmp/deployment)
# /tmp/secret.py
# /tmp/deployment/secret.py (points to /tmp/secret.py)
try:
os.mkdir(deployment_path)
self.addCleanup(shutil.rmtree, deployment_path)
os.symlink(deployment_path, path_to_include)
self.addCleanup(os.remove, path_to_include)
except OSError:
pass
with open(real_secret_path, "w"):
pass
os.symlink(real_secret_path, symlink_secret_path)
self.addCleanup(os.remove, real_secret_path)
# add the symlinked path to sys.path
sys.path.append(path_to_include)
self.addCleanup(sys.path.pop)
# this should be equivalent to: import secret
self.assertEqual(modutils.modpath_from_file(symlink_secret_path), ["secret"])
class LoadModuleFromPathTest(resources.SysPathSetup, unittest.TestCase):
def test_do_not_load_twice(self):
modutils.load_module_from_modpath(["data", "lmfp", "foo"])
modutils.load_module_from_modpath(["data", "lmfp"])
# pylint: disable=no-member; just-once is added by a test file dynamically.
self.assertEqual(len(sys.just_once), 1)
del sys.just_once
class FileFromModPathTest(resources.SysPathSetup, unittest.TestCase):
"""given a mod path (i.e. splited module / package name), return the
corresponding file, giving priority to source file over precompiled file
if it exists"""
def test_site_packages(self):
filename = _get_file_from_object(modutils)
result = modutils.file_from_modpath(["astroid", "modutils"])
self.assertEqual(os.path.realpath(result), os.path.realpath(filename))
def test_std_lib(self):
path = modutils.file_from_modpath(["os", "path"]).replace(".pyc", ".py")
self.assertEqual(
os.path.realpath(path),
os.path.realpath(os.path.__file__.replace(".pyc", ".py")),
)
def test_builtin(self):
self.assertIsNone(modutils.file_from_modpath(["sys"]))
def test_unexisting(self):
self.assertRaises(ImportError, modutils.file_from_modpath, ["turlututu"])
def test_unicode_in_package_init(self):
# file_from_modpath should not crash when reading an __init__
# file with unicode characters.
modutils.file_from_modpath(["data", "unicode_package", "core"])
class GetSourceFileTest(unittest.TestCase):
def test(self):
filename = _get_file_from_object(os.path)
self.assertEqual(
modutils.get_source_file(os.path.__file__), os.path.normpath(filename)
)
def test_raise(self):
self.assertRaises(modutils.NoSourceFile, modutils.get_source_file, "whatever")
class StandardLibModuleTest(resources.SysPathSetup, unittest.TestCase):
"""
return true if the module may be considered as a module from the standard
library
"""
def test_datetime(self):
# This is an interesting example, since datetime, on pypy,
# is under lib_pypy, rather than the usual Lib directory.
self.assertTrue(modutils.is_standard_module("datetime"))
def test_builtins(self):
self.assertFalse(modutils.is_standard_module("__builtin__"))
self.assertTrue(modutils.is_standard_module("builtins"))
def test_builtin(self):
self.assertTrue(modutils.is_standard_module("sys"))
self.assertTrue(modutils.is_standard_module("marshal"))
def test_nonstandard(self):
self.assertFalse(modutils.is_standard_module("astroid"))
def test_unknown(self):
self.assertFalse(modutils.is_standard_module("unknown"))
def test_4(self):
self.assertTrue(modutils.is_standard_module("hashlib"))
self.assertTrue(modutils.is_standard_module("pickle"))
self.assertTrue(modutils.is_standard_module("email"))
self.assertTrue(modutils.is_standard_module("io"))
self.assertFalse(modutils.is_standard_module("StringIO"))
self.assertTrue(modutils.is_standard_module("unicodedata"))
def test_custom_path(self):
datadir = resources.find("")
if any(datadir.startswith(p) for p in modutils.EXT_LIB_DIRS):
self.skipTest("known breakage of is_standard_module on installed package")
self.assertTrue(modutils.is_standard_module("data.module", (datadir,)))
self.assertTrue(
modutils.is_standard_module("data.module", (os.path.abspath(datadir),))
)
def test_failing_edge_cases(self):
# using a subpackage/submodule path as std_path argument
self.assertFalse(modutils.is_standard_module("xml.etree", etree.__path__))
# using a module + object name as modname argument
self.assertTrue(modutils.is_standard_module("sys.path"))
# this is because only the first package/module is considered
self.assertTrue(modutils.is_standard_module("sys.whatever"))
self.assertFalse(modutils.is_standard_module("xml.whatever", etree.__path__))
class IsRelativeTest(unittest.TestCase):
def test_knownValues_is_relative_1(self):
self.assertTrue(modutils.is_relative("utils", email.__path__[0]))
def test_knownValues_is_relative_3(self):
self.assertFalse(modutils.is_relative("astroid", astroid.__path__[0]))
class GetModuleFilesTest(unittest.TestCase):
def test_get_module_files_1(self):
package = resources.find("data/find_test")
modules = set(modutils.get_module_files(package, []))
expected = [
"__init__.py",
"module.py",
"module2.py",
"noendingnewline.py",
"nonregr.py",
]
self.assertEqual(modules, {os.path.join(package, x) for x in expected})
def test_get_all_files(self):
"""test that list_all returns all Python files from given location
"""
non_package = resources.find("data/notamodule")
modules = modutils.get_module_files(non_package, [], list_all=True)
self.assertEqual(modules, [os.path.join(non_package, "file.py")])
def test_load_module_set_attribute(self):
del xml.etree.ElementTree
del sys.modules["xml.etree.ElementTree"]
m = modutils.load_module_from_modpath(["xml", "etree", "ElementTree"])
self.assertTrue(hasattr(xml, "etree"))
self.assertTrue(hasattr(xml.etree, "ElementTree"))
self.assertTrue(m is xml.etree.ElementTree)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,681 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2016-2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2016 Derek Gustafson <degustaf@gmail.com>
# Copyright (c) 2017 Łukasz Rogalski <rogalski.91@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import builtins
import unittest
import xml
import pytest
import astroid
from astroid import builder, util
from astroid import exceptions
from astroid import MANAGER
from astroid import test_utils
from astroid import objects
BUILTINS = MANAGER.astroid_cache[builtins.__name__]
class InstanceModelTest(unittest.TestCase):
def test_instance_special_model(self):
ast_nodes = builder.extract_node(
"""
class A:
"test"
def __init__(self):
self.a = 42
a = A()
a.__class__ #@
a.__module__ #@
a.__doc__ #@
a.__dict__ #@
""",
module_name="fake_module",
)
cls = next(ast_nodes[0].infer())
self.assertIsInstance(cls, astroid.ClassDef)
self.assertEqual(cls.name, "A")
module = next(ast_nodes[1].infer())
self.assertIsInstance(module, astroid.Const)
self.assertEqual(module.value, "fake_module")
doc = next(ast_nodes[2].infer())
self.assertIsInstance(doc, astroid.Const)
self.assertEqual(doc.value, "test")
dunder_dict = next(ast_nodes[3].infer())
self.assertIsInstance(dunder_dict, astroid.Dict)
attr = next(dunder_dict.getitem(astroid.Const("a")).infer())
self.assertIsInstance(attr, astroid.Const)
self.assertEqual(attr.value, 42)
@pytest.mark.xfail(reason="Instance lookup cannot override object model")
def test_instance_local_attributes_overrides_object_model(self):
# The instance lookup needs to be changed in order for this to work.
ast_node = builder.extract_node(
"""
class A:
@property
def __dict__(self):
return []
A().__dict__
"""
)
inferred = next(ast_node.infer())
self.assertIsInstance(inferred, astroid.List)
self.assertEqual(inferred.elts, [])
class BoundMethodModelTest(unittest.TestCase):
def test_bound_method_model(self):
ast_nodes = builder.extract_node(
"""
class A:
def test(self): pass
a = A()
a.test.__func__ #@
a.test.__self__ #@
"""
)
func = next(ast_nodes[0].infer())
self.assertIsInstance(func, astroid.FunctionDef)
self.assertEqual(func.name, "test")
self_ = next(ast_nodes[1].infer())
self.assertIsInstance(self_, astroid.Instance)
self.assertEqual(self_.name, "A")
class UnboundMethodModelTest(unittest.TestCase):
def test_unbound_method_model(self):
ast_nodes = builder.extract_node(
"""
class A:
def test(self): pass
t = A.test
t.__class__ #@
t.__func__ #@
t.__self__ #@
t.im_class #@
t.im_func #@
t.im_self #@
"""
)
cls = next(ast_nodes[0].infer())
self.assertIsInstance(cls, astroid.ClassDef)
unbound_name = "function"
self.assertEqual(cls.name, unbound_name)
func = next(ast_nodes[1].infer())
self.assertIsInstance(func, astroid.FunctionDef)
self.assertEqual(func.name, "test")
self_ = next(ast_nodes[2].infer())
self.assertIsInstance(self_, astroid.Const)
self.assertIsNone(self_.value)
self.assertEqual(cls.name, next(ast_nodes[3].infer()).name)
self.assertEqual(func, next(ast_nodes[4].infer()))
self.assertIsNone(next(ast_nodes[5].infer()).value)
class ClassModelTest(unittest.TestCase):
def test_priority_to_local_defined_values(self):
ast_node = builder.extract_node(
"""
class A:
__doc__ = "first"
A.__doc__ #@
"""
)
inferred = next(ast_node.infer())
self.assertIsInstance(inferred, astroid.Const)
self.assertEqual(inferred.value, "first")
def test_class_model_correct_mro_subclasses_proxied(self):
ast_nodes = builder.extract_node(
"""
class A(object):
pass
A.mro #@
A.__subclasses__ #@
"""
)
for node in ast_nodes:
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.BoundMethod)
self.assertIsInstance(inferred._proxied, astroid.FunctionDef)
self.assertIsInstance(inferred.bound, astroid.ClassDef)
self.assertEqual(inferred.bound.name, "type")
def test_class_model(self):
ast_nodes = builder.extract_node(
"""
class A(object):
"test"
class B(A): pass
class C(A): pass
A.__module__ #@
A.__name__ #@
A.__qualname__ #@
A.__doc__ #@
A.__mro__ #@
A.mro() #@
A.__bases__ #@
A.__class__ #@
A.__dict__ #@
A.__subclasses__() #@
""",
module_name="fake_module",
)
module = next(ast_nodes[0].infer())
self.assertIsInstance(module, astroid.Const)
self.assertEqual(module.value, "fake_module")
name = next(ast_nodes[1].infer())
self.assertIsInstance(name, astroid.Const)
self.assertEqual(name.value, "A")
qualname = next(ast_nodes[2].infer())
self.assertIsInstance(qualname, astroid.Const)
self.assertEqual(qualname.value, "fake_module.A")
doc = next(ast_nodes[3].infer())
self.assertIsInstance(doc, astroid.Const)
self.assertEqual(doc.value, "test")
mro = next(ast_nodes[4].infer())
self.assertIsInstance(mro, astroid.Tuple)
self.assertEqual([cls.name for cls in mro.elts], ["A", "object"])
called_mro = next(ast_nodes[5].infer())
self.assertEqual(called_mro.elts, mro.elts)
bases = next(ast_nodes[6].infer())
self.assertIsInstance(bases, astroid.Tuple)
self.assertEqual([cls.name for cls in bases.elts], ["object"])
cls = next(ast_nodes[7].infer())
self.assertIsInstance(cls, astroid.ClassDef)
self.assertEqual(cls.name, "type")
cls_dict = next(ast_nodes[8].infer())
self.assertIsInstance(cls_dict, astroid.Dict)
subclasses = next(ast_nodes[9].infer())
self.assertIsInstance(subclasses, astroid.List)
self.assertEqual([cls.name for cls in subclasses.elts], ["B", "C"])
class ModuleModelTest(unittest.TestCase):
def test_priority_to_local_defined_values(self):
ast_node = astroid.parse(
"""
__file__ = "mine"
"""
)
file_value = next(ast_node.igetattr("__file__"))
self.assertIsInstance(file_value, astroid.Const)
self.assertEqual(file_value.value, "mine")
def test__path__not_a_package(self):
ast_node = builder.extract_node(
"""
import sys
sys.__path__ #@
"""
)
with self.assertRaises(exceptions.InferenceError):
next(ast_node.infer())
def test_module_model(self):
ast_nodes = builder.extract_node(
"""
import xml
xml.__path__ #@
xml.__name__ #@
xml.__doc__ #@
xml.__file__ #@
xml.__spec__ #@
xml.__loader__ #@
xml.__cached__ #@
xml.__package__ #@
xml.__dict__ #@
"""
)
path = next(ast_nodes[0].infer())
self.assertIsInstance(path, astroid.List)
self.assertIsInstance(path.elts[0], astroid.Const)
self.assertEqual(path.elts[0].value, xml.__path__[0])
name = next(ast_nodes[1].infer())
self.assertIsInstance(name, astroid.Const)
self.assertEqual(name.value, "xml")
doc = next(ast_nodes[2].infer())
self.assertIsInstance(doc, astroid.Const)
self.assertEqual(doc.value, xml.__doc__)
file_ = next(ast_nodes[3].infer())
self.assertIsInstance(file_, astroid.Const)
self.assertEqual(file_.value, xml.__file__.replace(".pyc", ".py"))
for ast_node in ast_nodes[4:7]:
inferred = next(ast_node.infer())
self.assertIs(inferred, astroid.Uninferable)
package = next(ast_nodes[7].infer())
self.assertIsInstance(package, astroid.Const)
self.assertEqual(package.value, "xml")
dict_ = next(ast_nodes[8].infer())
self.assertIsInstance(dict_, astroid.Dict)
class FunctionModelTest(unittest.TestCase):
def test_partial_descriptor_support(self):
bound, result = builder.extract_node(
"""
class A(object): pass
def test(self): return 42
f = test.__get__(A(), A)
f #@
f() #@
"""
)
bound = next(bound.infer())
self.assertIsInstance(bound, astroid.BoundMethod)
self.assertEqual(bound._proxied._proxied.name, "test")
result = next(result.infer())
self.assertIsInstance(result, astroid.Const)
self.assertEqual(result.value, 42)
def test___get__has_extra_params_defined(self):
node = builder.extract_node(
"""
def test(self): return 42
test.__get__
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.BoundMethod)
args = inferred.args.args
self.assertEqual(len(args), 2)
self.assertEqual([arg.name for arg in args], ["self", "type"])
@test_utils.require_version(minver="3.8")
def test__get__and_positional_only_args(self):
node = builder.extract_node(
"""
def test(self, a, b, /, c): return a + b + c
test.__get__(test)(1, 2, 3)
"""
)
inferred = next(node.infer())
assert inferred is util.Uninferable
@pytest.mark.xfail(reason="Descriptors cannot infer what self is")
def test_descriptor_not_inferrring_self(self):
# We can't infer __get__(X, Y)() when the bounded function
# uses self, because of the tree's parent not being propagating good enough.
result = builder.extract_node(
"""
class A(object):
x = 42
def test(self): return self.x
f = test.__get__(A(), A)
f() #@
"""
)
result = next(result.infer())
self.assertIsInstance(result, astroid.Const)
self.assertEqual(result.value, 42)
def test_descriptors_binding_invalid(self):
ast_nodes = builder.extract_node(
"""
class A: pass
def test(self): return 42
test.__get__()() #@
test.__get__(2, 3, 4) #@
"""
)
for node in ast_nodes:
with self.assertRaises(exceptions.InferenceError):
next(node.infer())
@pytest.mark.xfail(reason="Relying on path copy")
def test_descriptor_error_regression(self):
"""Make sure the following code does
node cause an exception"""
node = builder.extract_node(
"""
class MyClass:
text = "MyText"
def mymethod1(self):
return self.text
def mymethod2(self):
return self.mymethod1.__get__(self, MyClass)
cl = MyClass().mymethod2()()
cl #@
"""
)
[const] = node.inferred()
assert const.value == "MyText"
def test_function_model(self):
ast_nodes = builder.extract_node(
'''
def func(a=1, b=2):
"""test"""
func.__name__ #@
func.__doc__ #@
func.__qualname__ #@
func.__module__ #@
func.__defaults__ #@
func.__dict__ #@
func.__globals__ #@
func.__code__ #@
func.__closure__ #@
''',
module_name="fake_module",
)
name = next(ast_nodes[0].infer())
self.assertIsInstance(name, astroid.Const)
self.assertEqual(name.value, "func")
doc = next(ast_nodes[1].infer())
self.assertIsInstance(doc, astroid.Const)
self.assertEqual(doc.value, "test")
qualname = next(ast_nodes[2].infer())
self.assertIsInstance(qualname, astroid.Const)
self.assertEqual(qualname.value, "fake_module.func")
module = next(ast_nodes[3].infer())
self.assertIsInstance(module, astroid.Const)
self.assertEqual(module.value, "fake_module")
defaults = next(ast_nodes[4].infer())
self.assertIsInstance(defaults, astroid.Tuple)
self.assertEqual([default.value for default in defaults.elts], [1, 2])
dict_ = next(ast_nodes[5].infer())
self.assertIsInstance(dict_, astroid.Dict)
globals_ = next(ast_nodes[6].infer())
self.assertIsInstance(globals_, astroid.Dict)
for ast_node in ast_nodes[7:9]:
self.assertIs(next(ast_node.infer()), astroid.Uninferable)
@test_utils.require_version(minver="3.0")
def test_empty_return_annotation(self):
ast_node = builder.extract_node(
"""
def test(): pass
test.__annotations__
"""
)
annotations = next(ast_node.infer())
self.assertIsInstance(annotations, astroid.Dict)
self.assertEqual(len(annotations.items), 0)
@test_utils.require_version(minver="3.0")
def test_builtin_dunder_init_does_not_crash_when_accessing_annotations(self):
ast_node = builder.extract_node(
"""
class Class:
@classmethod
def class_method(cls):
cls.__init__.__annotations__ #@
"""
)
inferred = next(ast_node.infer())
self.assertIsInstance(inferred, astroid.Dict)
self.assertEqual(len(inferred.items), 0)
@test_utils.require_version(minver="3.0")
def test_annotations_kwdefaults(self):
ast_node = builder.extract_node(
"""
def test(a: 1, *args: 2, f:4='lala', **kwarg:3)->2: pass
test.__annotations__ #@
test.__kwdefaults__ #@
"""
)
annotations = next(ast_node[0].infer())
self.assertIsInstance(annotations, astroid.Dict)
self.assertIsInstance(
annotations.getitem(astroid.Const("return")), astroid.Const
)
self.assertEqual(annotations.getitem(astroid.Const("return")).value, 2)
self.assertIsInstance(annotations.getitem(astroid.Const("a")), astroid.Const)
self.assertEqual(annotations.getitem(astroid.Const("a")).value, 1)
self.assertEqual(annotations.getitem(astroid.Const("args")).value, 2)
self.assertEqual(annotations.getitem(astroid.Const("kwarg")).value, 3)
self.assertEqual(annotations.getitem(astroid.Const("f")).value, 4)
kwdefaults = next(ast_node[1].infer())
self.assertIsInstance(kwdefaults, astroid.Dict)
# self.assertEqual(kwdefaults.getitem('f').value, 'lala')
@test_utils.require_version(minver="3.8")
def test_annotation_positional_only(self):
ast_node = builder.extract_node(
"""
def test(a: 1, b: 2, /, c: 3): pass
test.__annotations__ #@
"""
)
annotations = next(ast_node.infer())
self.assertIsInstance(annotations, astroid.Dict)
self.assertIsInstance(annotations.getitem(astroid.Const("a")), astroid.Const)
self.assertEqual(annotations.getitem(astroid.Const("a")).value, 1)
self.assertEqual(annotations.getitem(astroid.Const("b")).value, 2)
self.assertEqual(annotations.getitem(astroid.Const("c")).value, 3)
class GeneratorModelTest(unittest.TestCase):
def test_model(self):
ast_nodes = builder.extract_node(
"""
def test():
"a"
yield
gen = test()
gen.__name__ #@
gen.__doc__ #@
gen.gi_code #@
gen.gi_frame #@
gen.send #@
"""
)
name = next(ast_nodes[0].infer())
self.assertEqual(name.value, "test")
doc = next(ast_nodes[1].infer())
self.assertEqual(doc.value, "a")
gi_code = next(ast_nodes[2].infer())
self.assertIsInstance(gi_code, astroid.ClassDef)
self.assertEqual(gi_code.name, "gi_code")
gi_frame = next(ast_nodes[3].infer())
self.assertIsInstance(gi_frame, astroid.ClassDef)
self.assertEqual(gi_frame.name, "gi_frame")
send = next(ast_nodes[4].infer())
self.assertIsInstance(send, astroid.BoundMethod)
class ExceptionModelTest(unittest.TestCase):
def test_valueerror_py3(self):
ast_nodes = builder.extract_node(
"""
try:
x[42]
except ValueError as err:
err.args #@
err.__traceback__ #@
err.message #@
"""
)
args = next(ast_nodes[0].infer())
self.assertIsInstance(args, astroid.Tuple)
tb = next(ast_nodes[1].infer())
self.assertIsInstance(tb, astroid.Instance)
self.assertEqual(tb.name, "traceback")
with self.assertRaises(exceptions.InferenceError):
next(ast_nodes[2].infer())
def test_syntax_error(self):
ast_node = builder.extract_node(
"""
try:
x[42]
except SyntaxError as err:
err.text #@
"""
)
inferred = next(ast_node.infer())
assert isinstance(inferred, astroid.Const)
def test_oserror(self):
ast_nodes = builder.extract_node(
"""
try:
raise OSError("a")
except OSError as err:
err.filename #@
err.filename2 #@
err.errno #@
"""
)
expected_values = ["", "", 0]
for node, value in zip(ast_nodes, expected_values):
inferred = next(node.infer())
assert isinstance(inferred, astroid.Const)
assert inferred.value == value
def test_import_error(self):
ast_nodes = builder.extract_node(
"""
try:
raise ImportError("a")
except ImportError as err:
err.name #@
err.path #@
"""
)
for node in ast_nodes:
inferred = next(node.infer())
assert isinstance(inferred, astroid.Const)
assert inferred.value == ""
def test_exception_instance_correctly_instantiated(self):
ast_node = builder.extract_node(
"""
try:
raise ImportError("a")
except ImportError as err:
err #@
"""
)
inferred = next(ast_node.infer())
assert isinstance(inferred, astroid.Instance)
cls = next(inferred.igetattr("__class__"))
assert isinstance(cls, astroid.ClassDef)
class DictObjectModelTest(unittest.TestCase):
def test__class__(self):
ast_node = builder.extract_node("{}.__class__")
inferred = next(ast_node.infer())
self.assertIsInstance(inferred, astroid.ClassDef)
self.assertEqual(inferred.name, "dict")
def test_attributes_inferred_as_methods(self):
ast_nodes = builder.extract_node(
"""
{}.values #@
{}.items #@
{}.keys #@
"""
)
for node in ast_nodes:
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.BoundMethod)
def test_wrapper_objects_for_dict_methods_python3(self):
ast_nodes = builder.extract_node(
"""
{1:1, 2:3}.values() #@
{1:1, 2:3}.keys() #@
{1:1, 2:3}.items() #@
"""
)
values = next(ast_nodes[0].infer())
self.assertIsInstance(values, objects.DictValues)
self.assertEqual([elt.value for elt in values.elts], [1, 3])
keys = next(ast_nodes[1].infer())
self.assertIsInstance(keys, objects.DictKeys)
self.assertEqual([elt.value for elt in keys.elts], [1, 2])
items = next(ast_nodes[2].infer())
self.assertIsInstance(items, objects.DictItems)
class LruCacheModelTest(unittest.TestCase):
def test_lru_cache(self):
ast_nodes = builder.extract_node(
"""
import functools
class Foo(object):
@functools.lru_cache()
def foo():
pass
f = Foo()
f.foo.cache_clear #@
f.foo.__wrapped__ #@
f.foo.cache_info() #@
"""
)
cache_clear = next(ast_nodes[0].infer())
self.assertIsInstance(cache_clear, astroid.BoundMethod)
wrapped = next(ast_nodes[1].infer())
self.assertIsInstance(wrapped, astroid.FunctionDef)
self.assertEqual(wrapped.name, "foo")
cache_info = next(ast_nodes[2].infer())
self.assertIsInstance(cache_info, astroid.Instance)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,535 @@
# Copyright (c) 2015-2016, 2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
from astroid import bases
from astroid import builder
from astroid import exceptions
from astroid import nodes
from astroid import objects
from astroid import test_utils
class ObjectsTest(unittest.TestCase):
def test_frozenset(self):
node = builder.extract_node(
"""
frozenset({1: 2, 2: 3}) #@
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, objects.FrozenSet)
self.assertEqual(inferred.pytype(), "%s.frozenset" % bases.BUILTINS)
itered = inferred.itered()
self.assertEqual(len(itered), 2)
self.assertIsInstance(itered[0], nodes.Const)
self.assertEqual([const.value for const in itered], [1, 2])
proxied = inferred._proxied
self.assertEqual(inferred.qname(), "%s.frozenset" % bases.BUILTINS)
self.assertIsInstance(proxied, nodes.ClassDef)
class SuperTests(unittest.TestCase):
def test_inferring_super_outside_methods(self):
ast_nodes = builder.extract_node(
"""
class Module(object):
pass
class StaticMethod(object):
@staticmethod
def static():
# valid, but we don't bother with it.
return super(StaticMethod, StaticMethod) #@
# super outside methods aren't inferred
super(Module, Module) #@
# no argument super is not recognised outside methods as well.
super() #@
"""
)
in_static = next(ast_nodes[0].value.infer())
self.assertIsInstance(in_static, bases.Instance)
self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
module_level = next(ast_nodes[1].infer())
self.assertIsInstance(module_level, bases.Instance)
self.assertEqual(in_static.qname(), "%s.super" % bases.BUILTINS)
no_arguments = next(ast_nodes[2].infer())
self.assertIsInstance(no_arguments, bases.Instance)
self.assertEqual(no_arguments.qname(), "%s.super" % bases.BUILTINS)
def test_inferring_unbound_super_doesnt_work(self):
node = builder.extract_node(
"""
class Test(object):
def __init__(self):
super(Test) #@
"""
)
unbounded = next(node.infer())
self.assertIsInstance(unbounded, bases.Instance)
self.assertEqual(unbounded.qname(), "%s.super" % bases.BUILTINS)
def test_use_default_inference_on_not_inferring_args(self):
ast_nodes = builder.extract_node(
"""
class Test(object):
def __init__(self):
super(Lala, self) #@
super(Test, lala) #@
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, bases.Instance)
self.assertEqual(first.qname(), "%s.super" % bases.BUILTINS)
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, bases.Instance)
self.assertEqual(second.qname(), "%s.super" % bases.BUILTINS)
@test_utils.require_version(minver="3.0")
def test_no_arguments_super(self):
ast_nodes = builder.extract_node(
"""
class First(object): pass
class Second(First):
def test(self):
super() #@
@classmethod
def test_classmethod(cls):
super() #@
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, objects.Super)
self.assertIsInstance(first.type, bases.Instance)
self.assertEqual(first.type.name, "Second")
self.assertIsInstance(first.mro_pointer, nodes.ClassDef)
self.assertEqual(first.mro_pointer.name, "Second")
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, objects.Super)
self.assertIsInstance(second.type, nodes.ClassDef)
self.assertEqual(second.type.name, "Second")
self.assertIsInstance(second.mro_pointer, nodes.ClassDef)
self.assertEqual(second.mro_pointer.name, "Second")
def test_super_simple_cases(self):
ast_nodes = builder.extract_node(
"""
class First(object): pass
class Second(First): pass
class Third(First):
def test(self):
super(Third, self) #@
super(Second, self) #@
# mro position and the type
super(Third, Third) #@
super(Third, Second) #@
super(Fourth, Fourth) #@
class Fourth(Third):
pass
"""
)
# .type is the object which provides the mro.
# .mro_pointer is the position in the mro from where
# the lookup should be done.
# super(Third, self)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, objects.Super)
self.assertIsInstance(first.type, bases.Instance)
self.assertEqual(first.type.name, "Third")
self.assertIsInstance(first.mro_pointer, nodes.ClassDef)
self.assertEqual(first.mro_pointer.name, "Third")
# super(Second, self)
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, objects.Super)
self.assertIsInstance(second.type, bases.Instance)
self.assertEqual(second.type.name, "Third")
self.assertIsInstance(first.mro_pointer, nodes.ClassDef)
self.assertEqual(second.mro_pointer.name, "Second")
# super(Third, Third)
third = next(ast_nodes[2].infer())
self.assertIsInstance(third, objects.Super)
self.assertIsInstance(third.type, nodes.ClassDef)
self.assertEqual(third.type.name, "Third")
self.assertIsInstance(third.mro_pointer, nodes.ClassDef)
self.assertEqual(third.mro_pointer.name, "Third")
# super(Third, second)
fourth = next(ast_nodes[3].infer())
self.assertIsInstance(fourth, objects.Super)
self.assertIsInstance(fourth.type, nodes.ClassDef)
self.assertEqual(fourth.type.name, "Second")
self.assertIsInstance(fourth.mro_pointer, nodes.ClassDef)
self.assertEqual(fourth.mro_pointer.name, "Third")
# Super(Fourth, Fourth)
fifth = next(ast_nodes[4].infer())
self.assertIsInstance(fifth, objects.Super)
self.assertIsInstance(fifth.type, nodes.ClassDef)
self.assertEqual(fifth.type.name, "Fourth")
self.assertIsInstance(fifth.mro_pointer, nodes.ClassDef)
self.assertEqual(fifth.mro_pointer.name, "Fourth")
def test_super_infer(self):
node = builder.extract_node(
"""
class Super(object):
def __init__(self):
super(Super, self) #@
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, objects.Super)
reinferred = next(inferred.infer())
self.assertIsInstance(reinferred, objects.Super)
self.assertIs(inferred, reinferred)
def test_inferring_invalid_supers(self):
ast_nodes = builder.extract_node(
"""
class Super(object):
def __init__(self):
# MRO pointer is not a type
super(1, self) #@
# MRO type is not a subtype
super(Super, 1) #@
# self is not a subtype of Bupper
super(Bupper, self) #@
class Bupper(Super):
pass
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, objects.Super)
with self.assertRaises(exceptions.SuperError) as cm:
first.super_mro()
self.assertIsInstance(cm.exception.super_.mro_pointer, nodes.Const)
self.assertEqual(cm.exception.super_.mro_pointer.value, 1)
for node, invalid_type in zip(ast_nodes[1:], (nodes.Const, bases.Instance)):
inferred = next(node.infer())
self.assertIsInstance(inferred, objects.Super, node)
with self.assertRaises(exceptions.SuperError) as cm:
inferred.super_mro()
self.assertIsInstance(cm.exception.super_.type, invalid_type)
def test_proxied(self):
node = builder.extract_node(
"""
class Super(object):
def __init__(self):
super(Super, self) #@
"""
)
inferred = next(node.infer())
proxied = inferred._proxied
self.assertEqual(proxied.qname(), "%s.super" % bases.BUILTINS)
self.assertIsInstance(proxied, nodes.ClassDef)
def test_super_bound_model(self):
ast_nodes = builder.extract_node(
"""
class First(object):
def method(self):
pass
@classmethod
def class_method(cls):
pass
class Super_Type_Type(First):
def method(self):
super(Super_Type_Type, Super_Type_Type).method #@
super(Super_Type_Type, Super_Type_Type).class_method #@
@classmethod
def class_method(cls):
super(Super_Type_Type, Super_Type_Type).method #@
super(Super_Type_Type, Super_Type_Type).class_method #@
class Super_Type_Object(First):
def method(self):
super(Super_Type_Object, self).method #@
super(Super_Type_Object, self).class_method #@
"""
)
# Super(type, type) is the same for both functions and classmethods.
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, nodes.FunctionDef)
self.assertEqual(first.name, "method")
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, bases.BoundMethod)
self.assertEqual(second.bound.name, "First")
self.assertEqual(second.type, "classmethod")
third = next(ast_nodes[2].infer())
self.assertIsInstance(third, nodes.FunctionDef)
self.assertEqual(third.name, "method")
fourth = next(ast_nodes[3].infer())
self.assertIsInstance(fourth, bases.BoundMethod)
self.assertEqual(fourth.bound.name, "First")
self.assertEqual(fourth.type, "classmethod")
# Super(type, obj) can lead to different attribute bindings
# depending on the type of the place where super was called.
fifth = next(ast_nodes[4].infer())
self.assertIsInstance(fifth, bases.BoundMethod)
self.assertEqual(fifth.bound.name, "First")
self.assertEqual(fifth.type, "method")
sixth = next(ast_nodes[5].infer())
self.assertIsInstance(sixth, bases.BoundMethod)
self.assertEqual(sixth.bound.name, "First")
self.assertEqual(sixth.type, "classmethod")
def test_super_getattr_single_inheritance(self):
ast_nodes = builder.extract_node(
"""
class First(object):
def test(self): pass
class Second(First):
def test2(self): pass
class Third(Second):
test3 = 42
def __init__(self):
super(Third, self).test2 #@
super(Third, self).test #@
# test3 is local, no MRO lookup is done.
super(Third, self).test3 #@
super(Third, self) #@
# Unbounds.
super(Third, Third).test2 #@
super(Third, Third).test #@
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, bases.BoundMethod)
self.assertEqual(first.bound.name, "Second")
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, bases.BoundMethod)
self.assertEqual(second.bound.name, "First")
with self.assertRaises(exceptions.InferenceError):
next(ast_nodes[2].infer())
fourth = next(ast_nodes[3].infer())
with self.assertRaises(exceptions.AttributeInferenceError):
fourth.getattr("test3")
with self.assertRaises(exceptions.AttributeInferenceError):
next(fourth.igetattr("test3"))
first_unbound = next(ast_nodes[4].infer())
self.assertIsInstance(first_unbound, nodes.FunctionDef)
self.assertEqual(first_unbound.name, "test2")
self.assertEqual(first_unbound.parent.name, "Second")
second_unbound = next(ast_nodes[5].infer())
self.assertIsInstance(second_unbound, nodes.FunctionDef)
self.assertEqual(second_unbound.name, "test")
self.assertEqual(second_unbound.parent.name, "First")
def test_super_invalid_mro(self):
node = builder.extract_node(
"""
class A(object):
test = 42
class Super(A, A):
def __init__(self):
super(Super, self) #@
"""
)
inferred = next(node.infer())
with self.assertRaises(exceptions.AttributeInferenceError):
next(inferred.getattr("test"))
def test_super_complex_mro(self):
ast_nodes = builder.extract_node(
"""
class A(object):
def spam(self): return "A"
def foo(self): return "A"
@staticmethod
def static(self): pass
class B(A):
def boo(self): return "B"
def spam(self): return "B"
class C(A):
def boo(self): return "C"
class E(C, B):
def __init__(self):
super(E, self).boo #@
super(C, self).boo #@
super(E, self).spam #@
super(E, self).foo #@
super(E, self).static #@
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, bases.BoundMethod)
self.assertEqual(first.bound.name, "C")
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, bases.BoundMethod)
self.assertEqual(second.bound.name, "B")
third = next(ast_nodes[2].infer())
self.assertIsInstance(third, bases.BoundMethod)
self.assertEqual(third.bound.name, "B")
fourth = next(ast_nodes[3].infer())
self.assertEqual(fourth.bound.name, "A")
static = next(ast_nodes[4].infer())
self.assertIsInstance(static, nodes.FunctionDef)
self.assertEqual(static.parent.scope().name, "A")
def test_super_data_model(self):
ast_nodes = builder.extract_node(
"""
class X(object): pass
class A(X):
def __init__(self):
super(A, self) #@
super(A, A) #@
super(X, A) #@
"""
)
first = next(ast_nodes[0].infer())
thisclass = first.getattr("__thisclass__")[0]
self.assertIsInstance(thisclass, nodes.ClassDef)
self.assertEqual(thisclass.name, "A")
selfclass = first.getattr("__self_class__")[0]
self.assertIsInstance(selfclass, nodes.ClassDef)
self.assertEqual(selfclass.name, "A")
self_ = first.getattr("__self__")[0]
self.assertIsInstance(self_, bases.Instance)
self.assertEqual(self_.name, "A")
cls = first.getattr("__class__")[0]
self.assertEqual(cls, first._proxied)
second = next(ast_nodes[1].infer())
thisclass = second.getattr("__thisclass__")[0]
self.assertEqual(thisclass.name, "A")
self_ = second.getattr("__self__")[0]
self.assertIsInstance(self_, nodes.ClassDef)
self.assertEqual(self_.name, "A")
third = next(ast_nodes[2].infer())
thisclass = third.getattr("__thisclass__")[0]
self.assertEqual(thisclass.name, "X")
selfclass = third.getattr("__self_class__")[0]
self.assertEqual(selfclass.name, "A")
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.super_mro()], expected_mro)
def test_super_mro(self):
ast_nodes = builder.extract_node(
"""
class A(object): pass
class B(A): pass
class C(A): pass
class E(C, B):
def __init__(self):
super(E, self) #@
super(C, self) #@
super(B, self) #@
super(B, 1) #@
super(1, B) #@
"""
)
first = next(ast_nodes[0].infer())
self.assertEqualMro(first, ["C", "B", "A", "object"])
second = next(ast_nodes[1].infer())
self.assertEqualMro(second, ["B", "A", "object"])
third = next(ast_nodes[2].infer())
self.assertEqualMro(third, ["A", "object"])
fourth = next(ast_nodes[3].infer())
with self.assertRaises(exceptions.SuperError):
fourth.super_mro()
fifth = next(ast_nodes[4].infer())
with self.assertRaises(exceptions.SuperError):
fifth.super_mro()
def test_super_yes_objects(self):
ast_nodes = builder.extract_node(
"""
from collections import Missing
class A(object):
def __init__(self):
super(Missing, self) #@
super(A, Missing) #@
"""
)
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, bases.Instance)
second = next(ast_nodes[1].infer())
self.assertIsInstance(second, bases.Instance)
def test_super_invalid_types(self):
node = builder.extract_node(
"""
import collections
class A(object):
def __init__(self):
super(A, collections) #@
"""
)
inferred = next(node.infer())
with self.assertRaises(exceptions.SuperError):
inferred.super_mro()
with self.assertRaises(exceptions.SuperError):
inferred.super_mro()
def test_super_properties(self):
node = builder.extract_node(
"""
class Foo(object):
@property
def dict(self):
return 42
class Bar(Foo):
@property
def dict(self):
return super(Bar, self).dict
Bar().dict
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.Const)
self.assertEqual(inferred.value, 42)
def test_super_qname(self):
"""Make sure a Super object generates a qname
equivalent to super.__qname__
"""
# See issue 533
code = """
class C:
def foo(self): return super()
C().foo() #@
"""
super_obj = next(builder.extract_node(code).infer())
self.assertEqual(super_obj.qname(), "super")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,278 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2017 Łukasz Rogalski <rogalski.91@gmail.com>
# Copyright (c) 2018 Nick Drozd <nicholasdrozd@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import contextlib
import unittest
import pytest
import sys
import astroid
from astroid import extract_node
from astroid.test_utils import require_version
from astroid import InferenceError
from astroid import nodes
from astroid import util
from astroid.node_classes import AssignName, Const, Name, Starred
@contextlib.contextmanager
def _add_transform(manager, node, transform, predicate=None):
manager.register_transform(node, transform, predicate)
try:
yield
finally:
manager.unregister_transform(node, transform, predicate)
class ProtocolTests(unittest.TestCase):
def assertConstNodesEqual(self, nodes_list_expected, nodes_list_got):
self.assertEqual(len(nodes_list_expected), len(nodes_list_got))
for node in nodes_list_got:
self.assertIsInstance(node, Const)
for node, expected_value in zip(nodes_list_got, nodes_list_expected):
self.assertEqual(expected_value, node.value)
def assertNameNodesEqual(self, nodes_list_expected, nodes_list_got):
self.assertEqual(len(nodes_list_expected), len(nodes_list_got))
for node in nodes_list_got:
self.assertIsInstance(node, Name)
for node, expected_name in zip(nodes_list_got, nodes_list_expected):
self.assertEqual(expected_name, node.name)
def test_assigned_stmts_simple_for(self):
assign_stmts = extract_node(
"""
for a in (1, 2, 3): #@
pass
for b in range(3): #@
pass
"""
)
for1_assnode = next(assign_stmts[0].nodes_of_class(AssignName))
assigned = list(for1_assnode.assigned_stmts())
self.assertConstNodesEqual([1, 2, 3], assigned)
for2_assnode = next(assign_stmts[1].nodes_of_class(AssignName))
self.assertRaises(InferenceError, list, for2_assnode.assigned_stmts())
@require_version(minver="3.0")
def test_assigned_stmts_starred_for(self):
assign_stmts = extract_node(
"""
for *a, b in ((1, 2, 3), (4, 5, 6, 7)): #@
pass
"""
)
for1_starred = next(assign_stmts.nodes_of_class(Starred))
assigned = next(for1_starred.assigned_stmts())
assert isinstance(assigned, astroid.List)
assert assigned.as_string() == "[1, 2]"
def _get_starred_stmts(self, code):
assign_stmt = extract_node("{} #@".format(code))
starred = next(assign_stmt.nodes_of_class(Starred))
return next(starred.assigned_stmts())
def _helper_starred_expected_const(self, code, expected):
stmts = self._get_starred_stmts(code)
self.assertIsInstance(stmts, nodes.List)
stmts = stmts.elts
self.assertConstNodesEqual(expected, stmts)
def _helper_starred_expected(self, code, expected):
stmts = self._get_starred_stmts(code)
self.assertEqual(expected, stmts)
def _helper_starred_inference_error(self, code):
assign_stmt = extract_node("{} #@".format(code))
starred = next(assign_stmt.nodes_of_class(Starred))
self.assertRaises(InferenceError, list, starred.assigned_stmts())
@require_version(minver="3.0")
def test_assigned_stmts_starred_assnames(self):
self._helper_starred_expected_const("a, *b = (1, 2, 3, 4) #@", [2, 3, 4])
self._helper_starred_expected_const("*a, b = (1, 2, 3) #@", [1, 2])
self._helper_starred_expected_const("a, *b, c = (1, 2, 3, 4, 5) #@", [2, 3, 4])
self._helper_starred_expected_const("a, *b = (1, 2) #@", [2])
self._helper_starred_expected_const("*b, a = (1, 2) #@", [1])
self._helper_starred_expected_const("[*b] = (1, 2) #@", [1, 2])
@require_version(minver="3.0")
def test_assigned_stmts_starred_yes(self):
# Not something iterable and known
self._helper_starred_expected("a, *b = range(3) #@", util.Uninferable)
# Not something inferrable
self._helper_starred_expected("a, *b = balou() #@", util.Uninferable)
# In function, unknown.
self._helper_starred_expected(
"""
def test(arg):
head, *tail = arg #@""",
util.Uninferable,
)
# These cases aren't worth supporting.
self._helper_starred_expected(
"a, (*b, c), d = (1, (2, 3, 4), 5) #@", util.Uninferable
)
@require_version(minver="3.0")
def test_assign_stmts_starred_fails(self):
# Too many starred
self._helper_starred_inference_error("a, *b, *c = (1, 2, 3) #@")
# This could be solved properly, but it complicates needlessly the
# code for assigned_stmts, without offering real benefit.
self._helper_starred_inference_error(
"(*a, b), (c, *d) = (1, 2, 3), (4, 5, 6) #@"
)
def test_assigned_stmts_assignments(self):
assign_stmts = extract_node(
"""
c = a #@
d, e = b, c #@
"""
)
simple_assnode = next(assign_stmts[0].nodes_of_class(AssignName))
assigned = list(simple_assnode.assigned_stmts())
self.assertNameNodesEqual(["a"], assigned)
assnames = assign_stmts[1].nodes_of_class(AssignName)
simple_mul_assnode_1 = next(assnames)
assigned = list(simple_mul_assnode_1.assigned_stmts())
self.assertNameNodesEqual(["b"], assigned)
simple_mul_assnode_2 = next(assnames)
assigned = list(simple_mul_assnode_2.assigned_stmts())
self.assertNameNodesEqual(["c"], assigned)
@require_version(minver="3.6")
def test_assigned_stmts_annassignments(self):
annassign_stmts = extract_node(
"""
a: str = "abc" #@
b: str #@
"""
)
simple_annassign_node = next(annassign_stmts[0].nodes_of_class(AssignName))
assigned = list(simple_annassign_node.assigned_stmts())
self.assertEqual(1, len(assigned))
self.assertIsInstance(assigned[0], Const)
self.assertEqual(assigned[0].value, "abc")
empty_annassign_node = next(annassign_stmts[1].nodes_of_class(AssignName))
assigned = list(empty_annassign_node.assigned_stmts())
self.assertEqual(1, len(assigned))
self.assertIs(assigned[0], util.Uninferable)
def test_sequence_assigned_stmts_not_accepting_empty_node(self):
def transform(node):
node.root().locals["__all__"] = [node.value]
manager = astroid.MANAGER
with _add_transform(manager, astroid.Assign, transform):
module = astroid.parse(
"""
__all__ = ['a']
"""
)
module.wildcard_import_names()
def test_not_passing_uninferable_in_seq_inference(self):
class Visitor:
def visit(self, node):
for child in node.get_children():
child.accept(self)
visit_module = visit
visit_assign = visit
visit_binop = visit
visit_list = visit
visit_const = visit
visit_name = visit
def visit_assignname(self, node):
for _ in node.infer():
pass
parsed = extract_node(
"""
a = []
x = [a*2, a]*2*2
"""
)
parsed.accept(Visitor())
@pytest.mark.skipif(
sys.version_info[:2] < (3, 8), reason="needs assignment expressions"
)
def test_named_expr_inference():
code = """
if (a := 2) == 2:
a #@
# Test a function call
def test():
return 24
if (a := test()):
a #@
# Normal assignments in sequences
{ (a:= 4) } #@
[ (a:= 5) ] #@
# Something more complicated
def test(value=(p := 24)): return p
[ y:= test()] #@
# Priority assignment
(x := 1, 2)
x #@
"""
ast_nodes = extract_node(code)
node = next(ast_nodes[0].infer())
assert isinstance(node, nodes.Const)
assert node.value == 2
node = next(ast_nodes[1].infer())
assert isinstance(node, nodes.Const)
assert node.value == 24
node = next(ast_nodes[2].infer())
assert isinstance(node, nodes.Set)
assert isinstance(node.elts[0], nodes.Const)
assert node.elts[0].value == 4
node = next(ast_nodes[3].infer())
assert isinstance(node, nodes.List)
assert isinstance(node.elts[0], nodes.Const)
assert node.elts[0].value == 5
node = next(ast_nodes[4].infer())
assert isinstance(node, nodes.List)
assert isinstance(node.elts[0], nodes.Const)
assert node.elts[0].value == 24
node = next(ast_nodes[5].infer())
assert isinstance(node, nodes.Const)
assert node.value == 1
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,412 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010, 2013-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2012 FELD Boris <lothiraldan@gmail.com>
# Copyright (c) 2013-2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jared Garst <jgarst@users.noreply.github.com>
# Copyright (c) 2017, 2019 Łukasz Rogalski <rogalski.91@gmail.com>
# Copyright (c) 2017 Hugo <hugovk@users.noreply.github.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
from textwrap import dedent
import unittest
from astroid import nodes
from astroid.node_classes import Assign, Expr, YieldFrom, Name, Const
from astroid.builder import AstroidBuilder, extract_node
from astroid.scoped_nodes import ClassDef, FunctionDef
from astroid.test_utils import require_version
class Python3TC(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.builder = AstroidBuilder()
@require_version("3.4")
def test_starred_notation(self):
astroid = self.builder.string_build("*a, b = [1, 2, 3]", "test", "test")
# Get the star node
node = next(next(next(astroid.get_children()).get_children()).get_children())
self.assertTrue(isinstance(node.assign_type(), Assign))
@require_version("3.4")
def test_yield_from(self):
body = dedent(
"""
def func():
yield from iter([1, 2])
"""
)
astroid = self.builder.string_build(body)
func = astroid.body[0]
self.assertIsInstance(func, FunctionDef)
yieldfrom_stmt = func.body[0]
self.assertIsInstance(yieldfrom_stmt, Expr)
self.assertIsInstance(yieldfrom_stmt.value, YieldFrom)
self.assertEqual(yieldfrom_stmt.as_string(), "yield from iter([1, 2])")
@require_version("3.4")
def test_yield_from_is_generator(self):
body = dedent(
"""
def func():
yield from iter([1, 2])
"""
)
astroid = self.builder.string_build(body)
func = astroid.body[0]
self.assertIsInstance(func, FunctionDef)
self.assertTrue(func.is_generator())
@require_version("3.4")
def test_yield_from_as_string(self):
body = dedent(
"""
def func():
yield from iter([1, 2])
value = yield from other()
"""
)
astroid = self.builder.string_build(body)
func = astroid.body[0]
self.assertEqual(func.as_string().strip(), body.strip())
# metaclass tests
@require_version("3.4")
def test_simple_metaclass(self):
astroid = self.builder.string_build("class Test(metaclass=type): pass")
klass = astroid.body[0]
metaclass = klass.metaclass()
self.assertIsInstance(metaclass, ClassDef)
self.assertEqual(metaclass.name, "type")
@require_version("3.4")
def test_metaclass_error(self):
astroid = self.builder.string_build("class Test(metaclass=typ): pass")
klass = astroid.body[0]
self.assertFalse(klass.metaclass())
@require_version("3.4")
def test_metaclass_imported(self):
astroid = self.builder.string_build(
dedent(
"""
from abc import ABCMeta
class Test(metaclass=ABCMeta): pass"""
)
)
klass = astroid.body[1]
metaclass = klass.metaclass()
self.assertIsInstance(metaclass, ClassDef)
self.assertEqual(metaclass.name, "ABCMeta")
@require_version("3.4")
def test_metaclass_multiple_keywords(self):
astroid = self.builder.string_build(
"class Test(magic=None, metaclass=type): pass"
)
klass = astroid.body[0]
metaclass = klass.metaclass()
self.assertIsInstance(metaclass, ClassDef)
self.assertEqual(metaclass.name, "type")
@require_version("3.4")
def test_as_string(self):
body = dedent(
"""
from abc import ABCMeta
class Test(metaclass=ABCMeta): pass"""
)
astroid = self.builder.string_build(body)
klass = astroid.body[1]
self.assertEqual(
klass.as_string(), "\n\nclass Test(metaclass=ABCMeta):\n pass\n"
)
@require_version("3.4")
def test_old_syntax_works(self):
astroid = self.builder.string_build(
dedent(
"""
class Test:
__metaclass__ = type
class SubTest(Test): pass
"""
)
)
klass = astroid["SubTest"]
metaclass = klass.metaclass()
self.assertIsNone(metaclass)
@require_version("3.4")
def test_metaclass_yes_leak(self):
astroid = self.builder.string_build(
dedent(
"""
# notice `ab` instead of `abc`
from ab import ABCMeta
class Meta(metaclass=ABCMeta): pass
"""
)
)
klass = astroid["Meta"]
self.assertIsNone(klass.metaclass())
@require_version("3.4")
def test_parent_metaclass(self):
astroid = self.builder.string_build(
dedent(
"""
from abc import ABCMeta
class Test(metaclass=ABCMeta): pass
class SubTest(Test): pass
"""
)
)
klass = astroid["SubTest"]
self.assertTrue(klass.newstyle)
metaclass = klass.metaclass()
self.assertIsInstance(metaclass, ClassDef)
self.assertEqual(metaclass.name, "ABCMeta")
@require_version("3.4")
def test_metaclass_ancestors(self):
astroid = self.builder.string_build(
dedent(
"""
from abc import ABCMeta
class FirstMeta(metaclass=ABCMeta): pass
class SecondMeta(metaclass=type):
pass
class Simple:
pass
class FirstImpl(FirstMeta): pass
class SecondImpl(FirstImpl): pass
class ThirdImpl(Simple, SecondMeta):
pass
"""
)
)
classes = {"ABCMeta": ("FirstImpl", "SecondImpl"), "type": ("ThirdImpl",)}
for metaclass, names in classes.items():
for name in names:
impl = astroid[name]
meta = impl.metaclass()
self.assertIsInstance(meta, ClassDef)
self.assertEqual(meta.name, metaclass)
@require_version("3.4")
def test_annotation_support(self):
astroid = self.builder.string_build(
dedent(
"""
def test(a: int, b: str, c: None, d, e,
*args: float, **kwargs: int)->int:
pass
"""
)
)
func = astroid["test"]
self.assertIsInstance(func.args.varargannotation, Name)
self.assertEqual(func.args.varargannotation.name, "float")
self.assertIsInstance(func.args.kwargannotation, Name)
self.assertEqual(func.args.kwargannotation.name, "int")
self.assertIsInstance(func.returns, Name)
self.assertEqual(func.returns.name, "int")
arguments = func.args
self.assertIsInstance(arguments.annotations[0], Name)
self.assertEqual(arguments.annotations[0].name, "int")
self.assertIsInstance(arguments.annotations[1], Name)
self.assertEqual(arguments.annotations[1].name, "str")
self.assertIsInstance(arguments.annotations[2], Const)
self.assertIsNone(arguments.annotations[2].value)
self.assertIsNone(arguments.annotations[3])
self.assertIsNone(arguments.annotations[4])
astroid = self.builder.string_build(
dedent(
"""
def test(a: int=1, b: str=2):
pass
"""
)
)
func = astroid["test"]
self.assertIsInstance(func.args.annotations[0], Name)
self.assertEqual(func.args.annotations[0].name, "int")
self.assertIsInstance(func.args.annotations[1], Name)
self.assertEqual(func.args.annotations[1].name, "str")
self.assertIsNone(func.returns)
@require_version("3.4")
def test_kwonlyargs_annotations_supper(self):
node = self.builder.string_build(
dedent(
"""
def test(*, a: int, b: str, c: None, d, e):
pass
"""
)
)
func = node["test"]
arguments = func.args
self.assertIsInstance(arguments.kwonlyargs_annotations[0], Name)
self.assertEqual(arguments.kwonlyargs_annotations[0].name, "int")
self.assertIsInstance(arguments.kwonlyargs_annotations[1], Name)
self.assertEqual(arguments.kwonlyargs_annotations[1].name, "str")
self.assertIsInstance(arguments.kwonlyargs_annotations[2], Const)
self.assertIsNone(arguments.kwonlyargs_annotations[2].value)
self.assertIsNone(arguments.kwonlyargs_annotations[3])
self.assertIsNone(arguments.kwonlyargs_annotations[4])
@require_version("3.4")
def test_annotation_as_string(self):
code1 = dedent(
"""
def test(a, b: int = 4, c=2, f: 'lala' = 4) -> 2:
pass"""
)
code2 = dedent(
"""
def test(a: typing.Generic[T], c: typing.Any = 24) -> typing.Iterable:
pass"""
)
for code in (code1, code2):
func = extract_node(code)
self.assertEqual(func.as_string(), code)
@require_version("3.5")
def test_unpacking_in_dicts(self):
code = "{'x': 1, **{'y': 2}}"
node = extract_node(code)
self.assertEqual(node.as_string(), code)
keys = [key for (key, _) in node.items]
self.assertIsInstance(keys[0], nodes.Const)
self.assertIsInstance(keys[1], nodes.DictUnpack)
@require_version("3.5")
def test_nested_unpacking_in_dicts(self):
code = "{'x': 1, **{'y': 2, **{'z': 3}}}"
node = extract_node(code)
self.assertEqual(node.as_string(), code)
@require_version("3.5")
def test_unpacking_in_dict_getitem(self):
node = extract_node("{1:2, **{2:3, 3:4}, **{5: 6}}")
for key, expected in ((1, 2), (2, 3), (3, 4), (5, 6)):
value = node.getitem(nodes.Const(key))
self.assertIsInstance(value, nodes.Const)
self.assertEqual(value.value, expected)
@require_version("3.6")
def test_format_string(self):
code = "f'{greetings} {person}'"
node = extract_node(code)
self.assertEqual(node.as_string(), code)
@require_version("3.6")
def test_underscores_in_numeral_literal(self):
pairs = [("10_1000", 101000), ("10_000_000", 10000000), ("0x_FF_FF", 65535)]
for value, expected in pairs:
node = extract_node(value)
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.Const)
self.assertEqual(inferred.value, expected)
@require_version("3.6")
def test_async_comprehensions(self):
async_comprehensions = [
extract_node(
"async def f(): return __([i async for i in aiter() if i % 2])"
),
extract_node(
"async def f(): return __({i async for i in aiter() if i % 2})"
),
extract_node(
"async def f(): return __((i async for i in aiter() if i % 2))"
),
extract_node(
"async def f(): return __({i: i async for i in aiter() if i % 2})"
),
]
non_async_comprehensions = [
extract_node("async def f(): return __({i: i for i in iter() if i % 2})")
]
for comp in async_comprehensions:
self.assertTrue(comp.generators[0].is_async)
for comp in non_async_comprehensions:
self.assertFalse(comp.generators[0].is_async)
@require_version("3.7")
def test_async_comprehensions_outside_coroutine(self):
# When async and await will become keywords, async comprehensions
# will be allowed outside of coroutines body
comprehensions = [
"[i async for i in aiter() if condition(i)]",
"[await fun() async for fun in funcs]",
"{await fun() async for fun in funcs}",
"{fun: await fun() async for fun in funcs}",
"[await fun() async for fun in funcs if await smth]",
"{await fun() async for fun in funcs if await smth}",
"{fun: await fun() async for fun in funcs if await smth}",
"[await fun() async for fun in funcs]",
"{await fun() async for fun in funcs}",
"{fun: await fun() async for fun in funcs}",
"[await fun() async for fun in funcs if await smth]",
"{await fun() async for fun in funcs if await smth}",
"{fun: await fun() async for fun in funcs if await smth}",
]
for comp in comprehensions:
node = extract_node(comp)
self.assertTrue(node.generators[0].is_async)
@require_version("3.6")
def test_async_comprehensions_as_string(self):
func_bodies = [
"return [i async for i in aiter() if condition(i)]",
"return [await fun() for fun in funcs]",
"return {await fun() for fun in funcs}",
"return {fun: await fun() for fun in funcs}",
"return [await fun() for fun in funcs if await smth]",
"return {await fun() for fun in funcs if await smth}",
"return {fun: await fun() for fun in funcs if await smth}",
"return [await fun() async for fun in funcs]",
"return {await fun() async for fun in funcs}",
"return {fun: await fun() async for fun in funcs}",
"return [await fun() async for fun in funcs if await smth]",
"return {await fun() async for fun in funcs if await smth}",
"return {fun: await fun() async for fun in funcs if await smth}",
]
for func_body in func_bodies:
code = dedent(
"""
async def f():
{}""".format(
func_body
)
)
func = extract_node(code)
self.assertEqual(func.as_string().strip(), code.strip())
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2013 AndroWiiid <androwiiid@gmail.com>
# Copyright (c) 2014-2016, 2018-2019 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2018 Anthony Sottile <asottile@umich.edu>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import platform
import unittest
import _io
from astroid.builder import AstroidBuilder
from astroid.raw_building import (
attach_dummy_node,
build_module,
build_class,
build_function,
build_from_import,
)
from astroid import test_utils
class RawBuildingTC(unittest.TestCase):
def test_attach_dummy_node(self):
node = build_module("MyModule")
attach_dummy_node(node, "DummyNode")
self.assertEqual(1, len(list(node.get_children())))
def test_build_module(self):
node = build_module("MyModule")
self.assertEqual(node.name, "MyModule")
self.assertEqual(node.pure_python, False)
self.assertEqual(node.package, False)
self.assertEqual(node.parent, None)
def test_build_class(self):
node = build_class("MyClass")
self.assertEqual(node.name, "MyClass")
self.assertEqual(node.doc, None)
def test_build_function(self):
node = build_function("MyFunction")
self.assertEqual(node.name, "MyFunction")
self.assertEqual(node.doc, None)
def test_build_function_args(self):
args = ["myArgs1", "myArgs2"]
node = build_function("MyFunction", args)
self.assertEqual("myArgs1", node.args.args[0].name)
self.assertEqual("myArgs2", node.args.args[1].name)
self.assertEqual(2, len(node.args.args))
def test_build_function_defaults(self):
defaults = ["defaults1", "defaults2"]
node = build_function(name="MyFunction", args=None, defaults=defaults)
self.assertEqual(2, len(node.args.defaults))
def test_build_function_posonlyargs(self):
node = build_function(name="MyFunction", posonlyargs=["a", "b"])
self.assertEqual(2, len(node.args.posonlyargs))
def test_build_from_import(self):
names = ["exceptions, inference, inspector"]
node = build_from_import("astroid", names)
self.assertEqual(len(names), len(node.names))
@unittest.skipIf(platform.python_implementation() == "PyPy", "Only affects CPython")
@test_utils.require_version(minver="3.0")
def test_io_is__io(self):
# _io module calls itself io. This leads
# to cyclic dependencies when astroid tries to resolve
# what io.BufferedReader is. The code that handles this
# is in astroid.raw_building.imported_member, which verifies
# the true name of the module.
builder = AstroidBuilder()
module = builder.inspect_build(_io)
buffered_reader = module.getattr("BufferedReader")[0]
self.assertEqual(buffered_reader.root().name, "io")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,328 @@
# Copyright (c) 2006-2008, 2010-2014 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2007 Marien Zwart <marienz@gentoo.org>
# Copyright (c) 2013-2014 Google, Inc.
# Copyright (c) 2014-2016, 2018-2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2014 Eevee (Alex Munroe) <amunroe@yelp.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2018 Nick Drozd <nicholasdrozd@gmail.com>
# Copyright (c) 2018 Anthony Sottile <asottile@umich.edu>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import sys
import unittest
import textwrap
from astroid import MANAGER, Instance, nodes
from astroid.bases import BUILTINS
from astroid.builder import AstroidBuilder, extract_node
from astroid import exceptions
from astroid.raw_building import build_module
from astroid.manager import AstroidManager
from astroid.test_utils import require_version
from astroid import transforms
from . import resources
try:
import numpy # pylint: disable=unused-import
except ImportError:
HAS_NUMPY = False
else:
HAS_NUMPY = True
class NonRegressionTests(resources.AstroidCacheSetupMixin, unittest.TestCase):
def setUp(self):
sys.path.insert(0, resources.find("data"))
MANAGER.always_load_extensions = True
def tearDown(self):
MANAGER.always_load_extensions = False
sys.path.pop(0)
sys.path_importer_cache.pop(resources.find("data"), None)
def brainless_manager(self):
manager = AstroidManager()
# avoid caching into the AstroidManager borg since we get problems
# with other tests :
manager.__dict__ = {}
manager._failed_import_hooks = []
manager.astroid_cache = {}
manager._mod_file_cache = {}
manager._transform = transforms.TransformVisitor()
return manager
def test_module_path(self):
man = self.brainless_manager()
mod = man.ast_from_module_name("package.import_package_subpackage_module")
package = next(mod.igetattr("package"))
self.assertEqual(package.name, "package")
subpackage = next(package.igetattr("subpackage"))
self.assertIsInstance(subpackage, nodes.Module)
self.assertTrue(subpackage.package)
self.assertEqual(subpackage.name, "package.subpackage")
module = next(subpackage.igetattr("module"))
self.assertEqual(module.name, "package.subpackage.module")
def test_package_sidepackage(self):
manager = self.brainless_manager()
assert "package.sidepackage" not in MANAGER.astroid_cache
package = manager.ast_from_module_name("absimp")
self.assertIsInstance(package, nodes.Module)
self.assertTrue(package.package)
subpackage = next(package.getattr("sidepackage")[0].infer())
self.assertIsInstance(subpackage, nodes.Module)
self.assertTrue(subpackage.package)
self.assertEqual(subpackage.name, "absimp.sidepackage")
def test_living_property(self):
builder = AstroidBuilder()
builder._done = {}
builder._module = sys.modules[__name__]
builder.object_build(build_module("module_name", ""), Whatever)
@unittest.skipIf(not HAS_NUMPY, "Needs numpy")
def test_numpy_crash(self):
"""test don't crash on numpy"""
# a crash occurred somewhere in the past, and an
# InferenceError instead of a crash was better, but now we even infer!
builder = AstroidBuilder()
data = """
from numpy import multiply
multiply(1, 2, 3)
"""
astroid = builder.string_build(data, __name__, __file__)
callfunc = astroid.body[1].value.func
inferred = callfunc.inferred()
self.assertEqual(len(inferred), 1)
@require_version("3.0")
def test_nameconstant(self):
# used to fail for Python 3.4
builder = AstroidBuilder()
astroid = builder.string_build("def test(x=True): pass")
default = astroid.body[0].args.args[0]
self.assertEqual(default.name, "x")
self.assertEqual(next(default.infer()).value, True)
def test_recursion_regression_issue25(self):
builder = AstroidBuilder()
data = """
import recursion as base
_real_Base = base.Base
class Derived(_real_Base):
pass
def run():
base.Base = Derived
"""
astroid = builder.string_build(data, __name__, __file__)
# Used to crash in _is_metaclass, due to wrong
# ancestors chain
classes = astroid.nodes_of_class(nodes.ClassDef)
for klass in classes:
# triggers the _is_metaclass call
klass.type # pylint: disable=pointless-statement
def test_decorator_callchain_issue42(self):
builder = AstroidBuilder()
data = """
def test():
def factory(func):
def newfunc():
func()
return newfunc
return factory
@test()
def crash():
pass
"""
astroid = builder.string_build(data, __name__, __file__)
self.assertEqual(astroid["crash"].type, "function")
def test_filter_stmts_scoping(self):
builder = AstroidBuilder()
data = """
def test():
compiler = int()
class B(compiler.__class__):
pass
compiler = B()
return compiler
"""
astroid = builder.string_build(data, __name__, __file__)
test = astroid["test"]
result = next(test.infer_call_result(astroid))
self.assertIsInstance(result, Instance)
base = next(result._proxied.bases[0].infer())
self.assertEqual(base.name, "int")
def test_ancestors_patching_class_recursion(self):
node = AstroidBuilder().string_build(
textwrap.dedent(
"""
import string
Template = string.Template
class A(Template):
pass
class B(A):
pass
def test(x=False):
if x:
string.Template = A
else:
string.Template = B
"""
)
)
klass = node["A"]
ancestors = list(klass.ancestors())
self.assertEqual(ancestors[0].qname(), "string.Template")
def test_ancestors_yes_in_bases(self):
# Test for issue https://bitbucket.org/logilab/astroid/issue/84
# This used to crash astroid with a TypeError, because an Uninferable
# node was present in the bases
node = extract_node(
"""
def with_metaclass(meta, *bases):
class metaclass(meta):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
import lala
class A(with_metaclass(object, lala.lala)): #@
pass
"""
)
ancestors = list(node.ancestors())
self.assertEqual(len(ancestors), 1)
self.assertEqual(ancestors[0].qname(), "{}.object".format(BUILTINS))
def test_ancestors_missing_from_function(self):
# Test for https://www.logilab.org/ticket/122793
node = extract_node(
"""
def gen(): yield
GEN = gen()
next(GEN)
"""
)
self.assertRaises(exceptions.InferenceError, next, node.infer())
def test_unicode_in_docstring(self):
# Crashed for astroid==1.4.1
# Test for https://bitbucket.org/logilab/astroid/issues/273/
# In a regular file, "coding: utf-8" would have been used.
node = extract_node(
"""
from __future__ import unicode_literals
class MyClass(object):
def method(self):
"With unicode : %s "
instance = MyClass()
"""
% "\u2019"
)
next(node.value.infer()).as_string()
def test_binop_generates_nodes_with_parents(self):
node = extract_node(
"""
def no_op(*args):
pass
def foo(*args):
def inner(*more_args):
args + more_args #@
return inner
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.Tuple)
self.assertIsNotNone(inferred.parent)
self.assertIsInstance(inferred.parent, nodes.BinOp)
def test_decorator_names_inference_error_leaking(self):
node = extract_node(
"""
class Parent(object):
@property
def foo(self):
pass
class Child(Parent):
@Parent.foo.getter
def foo(self): #@
return super(Child, self).foo + ['oink']
"""
)
inferred = next(node.infer())
self.assertEqual(inferred.decoratornames(), {".Parent.foo.getter"})
def test_ssl_protocol(self):
node = extract_node(
"""
import ssl
ssl.PROTOCOL_TLSv1
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.Const)
def test_recursive_property_method(self):
node = extract_node(
"""
class APropert():
@property
def property(self):
return self
APropert().property
"""
)
next(node.infer())
def test_uninferable_string_argument_of_namedtuple(self):
node = extract_node(
"""
import collections
collections.namedtuple('{}'.format("a"), '')()
"""
)
next(node.infer())
def test_regression_inference_of_self_in_lambda(self):
code = """
class A:
@b(lambda self: __(self))
def d(self):
pass
"""
node = extract_node(code)
inferred = next(node.infer())
assert isinstance(inferred, Instance)
assert inferred.qname() == ".A"
class Whatever:
a = property(lambda x: x, lambda x: x)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,243 @@
# Copyright (c) 2015-2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
from __future__ import print_function
import contextlib
import time
import unittest
from astroid import builder
from astroid import nodes
from astroid import parse
from astroid import transforms
@contextlib.contextmanager
def add_transform(manager, node, transform, predicate=None):
manager.register_transform(node, transform, predicate)
try:
yield
finally:
manager.unregister_transform(node, transform, predicate)
class TestTransforms(unittest.TestCase):
def setUp(self):
self.transformer = transforms.TransformVisitor()
def parse_transform(self, code):
module = parse(code, apply_transforms=False)
return self.transformer.visit(module)
def test_function_inlining_transform(self):
def transform_call(node):
# Let's do some function inlining
inferred = next(node.infer())
return inferred
self.transformer.register_transform(nodes.Call, transform_call)
module = self.parse_transform(
"""
def test(): return 42
test() #@
"""
)
self.assertIsInstance(module.body[1], nodes.Expr)
self.assertIsInstance(module.body[1].value, nodes.Const)
self.assertEqual(module.body[1].value.value, 42)
def test_recursive_transforms_into_astroid_fields(self):
# Test that the transformer walks properly the tree
# by going recursively into the _astroid_fields per each node.
def transform_compare(node):
# Let's check the values of the ops
_, right = node.ops[0]
# Assume they are Consts and they were transformed before
# us.
return nodes.const_factory(node.left.value < right.value)
def transform_name(node):
# Should be Consts
return next(node.infer())
self.transformer.register_transform(nodes.Compare, transform_compare)
self.transformer.register_transform(nodes.Name, transform_name)
module = self.parse_transform(
"""
a = 42
b = 24
a < b
"""
)
self.assertIsInstance(module.body[2], nodes.Expr)
self.assertIsInstance(module.body[2].value, nodes.Const)
self.assertFalse(module.body[2].value.value)
def test_transform_patches_locals(self):
def transform_function(node):
assign = nodes.Assign()
name = nodes.AssignName()
name.name = "value"
assign.targets = [name]
assign.value = nodes.const_factory(42)
node.body.append(assign)
self.transformer.register_transform(nodes.FunctionDef, transform_function)
module = self.parse_transform(
"""
def test():
pass
"""
)
func = module.body[0]
self.assertEqual(len(func.body), 2)
self.assertIsInstance(func.body[1], nodes.Assign)
self.assertEqual(func.body[1].as_string(), "value = 42")
def test_predicates(self):
def transform_call(node):
inferred = next(node.infer())
return inferred
def should_inline(node):
return node.func.name.startswith("inlineme")
self.transformer.register_transform(nodes.Call, transform_call, should_inline)
module = self.parse_transform(
"""
def inlineme_1():
return 24
def dont_inline_me():
return 42
def inlineme_2():
return 2
inlineme_1()
dont_inline_me()
inlineme_2()
"""
)
values = module.body[-3:]
self.assertIsInstance(values[0], nodes.Expr)
self.assertIsInstance(values[0].value, nodes.Const)
self.assertEqual(values[0].value.value, 24)
self.assertIsInstance(values[1], nodes.Expr)
self.assertIsInstance(values[1].value, nodes.Call)
self.assertIsInstance(values[2], nodes.Expr)
self.assertIsInstance(values[2].value, nodes.Const)
self.assertEqual(values[2].value.value, 2)
def test_transforms_are_separated(self):
# Test that the transforming is done at a separate
# step, which means that we are not doing inference
# on a partially constructed tree anymore, which was the
# source of crashes in the past when certain inference rules
# were used in a transform.
def transform_function(node):
if node.decorators:
for decorator in node.decorators.nodes:
inferred = next(decorator.infer())
if inferred.qname() == "abc.abstractmethod":
return next(node.infer_call_result())
return None
manager = builder.MANAGER
with add_transform(manager, nodes.FunctionDef, transform_function):
module = builder.parse(
"""
import abc
from abc import abstractmethod
class A(object):
@abc.abstractmethod
def ala(self):
return 24
@abstractmethod
def bala(self):
return 42
"""
)
cls = module["A"]
ala = cls.body[0]
bala = cls.body[1]
self.assertIsInstance(ala, nodes.Const)
self.assertEqual(ala.value, 24)
self.assertIsInstance(bala, nodes.Const)
self.assertEqual(bala.value, 42)
def test_transforms_are_called_for_builtin_modules(self):
# Test that transforms are called for builtin modules.
def transform_function(node):
name = nodes.AssignName()
name.name = "value"
node.args.args = [name]
return node
manager = builder.MANAGER
predicate = lambda node: node.root().name == "time"
with add_transform(manager, nodes.FunctionDef, transform_function, predicate):
builder_instance = builder.AstroidBuilder()
module = builder_instance.module_build(time)
asctime = module["asctime"]
self.assertEqual(len(asctime.args.args), 1)
self.assertIsInstance(asctime.args.args[0], nodes.AssignName)
self.assertEqual(asctime.args.args[0].name, "value")
def test_builder_apply_transforms(self):
def transform_function(node):
return nodes.const_factory(42)
manager = builder.MANAGER
with add_transform(manager, nodes.FunctionDef, transform_function):
astroid_builder = builder.AstroidBuilder(apply_transforms=False)
module = astroid_builder.string_build("""def test(): pass""")
# The transform wasn't applied.
self.assertIsInstance(module.body[0], nodes.FunctionDef)
def test_transform_crashes_on_is_subtype_of(self):
# Test that we don't crash when having is_subtype_of
# in a transform, as per issue #188. This happened
# before, when the transforms weren't in their own step.
def transform_class(cls):
if cls.is_subtype_of("django.db.models.base.Model"):
return cls
return cls
self.transformer.register_transform(nodes.ClassDef, transform_class)
self.parse_transform(
"""
# Change environ to automatically call putenv() if it exists
import os
putenv = os.putenv
try:
# This will fail if there's no putenv
putenv
except NameError:
pass
else:
import UserDict
"""
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,126 @@
# Copyright (c) 2008-2010, 2013 LOGILAB S.A. (Paris, FRANCE) <contact@logilab.fr>
# Copyright (c) 2014 Google, Inc.
# Copyright (c) 2015-2016, 2018 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Dave Baum <dbaum@google.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
import unittest
from astroid import builder
from astroid import InferenceError
from astroid import nodes
from astroid import node_classes
from astroid import util as astroid_util
class InferenceUtil(unittest.TestCase):
def test_not_exclusive(self):
module = builder.parse(
"""
x = 10
for x in range(5):
print (x)
if x > 0:
print ('#' * x)
""",
__name__,
__file__,
)
xass1 = module.locals["x"][0]
assert xass1.lineno == 2
xnames = [n for n in module.nodes_of_class(nodes.Name) if n.name == "x"]
assert len(xnames) == 3
assert xnames[1].lineno == 6
self.assertEqual(node_classes.are_exclusive(xass1, xnames[1]), False)
self.assertEqual(node_classes.are_exclusive(xass1, xnames[2]), False)
def test_if(self):
module = builder.parse(
"""
if 1:
a = 1
a = 2
elif 2:
a = 12
a = 13
else:
a = 3
a = 4
"""
)
a1 = module.locals["a"][0]
a2 = module.locals["a"][1]
a3 = module.locals["a"][2]
a4 = module.locals["a"][3]
a5 = module.locals["a"][4]
a6 = module.locals["a"][5]
self.assertEqual(node_classes.are_exclusive(a1, a2), False)
self.assertEqual(node_classes.are_exclusive(a1, a3), True)
self.assertEqual(node_classes.are_exclusive(a1, a5), True)
self.assertEqual(node_classes.are_exclusive(a3, a5), True)
self.assertEqual(node_classes.are_exclusive(a3, a4), False)
self.assertEqual(node_classes.are_exclusive(a5, a6), False)
def test_try_except(self):
module = builder.parse(
"""
try:
def exclusive_func2():
"docstring"
except TypeError:
def exclusive_func2():
"docstring"
except:
def exclusive_func2():
"docstring"
else:
def exclusive_func2():
"this one redefine the one defined line 42"
"""
)
f1 = module.locals["exclusive_func2"][0]
f2 = module.locals["exclusive_func2"][1]
f3 = module.locals["exclusive_func2"][2]
f4 = module.locals["exclusive_func2"][3]
self.assertEqual(node_classes.are_exclusive(f1, f2), True)
self.assertEqual(node_classes.are_exclusive(f1, f3), True)
self.assertEqual(node_classes.are_exclusive(f1, f4), False)
self.assertEqual(node_classes.are_exclusive(f2, f4), True)
self.assertEqual(node_classes.are_exclusive(f3, f4), True)
self.assertEqual(node_classes.are_exclusive(f3, f2), True)
self.assertEqual(node_classes.are_exclusive(f2, f1), True)
self.assertEqual(node_classes.are_exclusive(f4, f1), False)
self.assertEqual(node_classes.are_exclusive(f4, f2), True)
def test_unpack_infer_uninferable_nodes(self):
node = builder.extract_node(
"""
x = [A] * 1
f = [x, [A] * 2]
f
"""
)
inferred = next(node.infer())
unpacked = list(node_classes.unpack_infer(inferred))
self.assertEqual(len(unpacked), 3)
self.assertTrue(all(elt is astroid_util.Uninferable for elt in unpacked))
def test_unpack_infer_empty_tuple(self):
node = builder.extract_node(
"""
()
"""
)
inferred = next(node.infer())
with self.assertRaises(InferenceError):
list(node_classes.unpack_infer(inferred))
if __name__ == "__main__":
unittest.main()