@@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
55from flint.types.fmpz_poly cimport any_as_fmpz_poly
66from flint.types.fmpz_poly cimport fmpz_poly
77from flint.types.nmod cimport any_as_nmod_ctx
8- from flint.types.nmod cimport nmod
8+ from flint.types.nmod cimport nmod, nmod_ctx
99
1010from flint.flintlib.nmod_vec cimport *
1111from flint.flintlib.nmod_poly cimport *
1212from flint.flintlib.nmod_poly_factor cimport *
1313from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14- from flint.flintlib.ulong_extras cimport n_gcdinv
14+ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime
1515
1616from flint.utils.flint_exceptions import DomainError
1717
1818
19- cdef any_as_nmod_poly(obj, nmod_t mod):
20- cdef nmod_poly r
21- cdef mp_limb_t v
22- # XXX: should check that modulus is the same here, and not all over the place
23- if typecheck(obj, nmod_poly):
19+ _nmod_poly_ctx_cache = {}
20+
21+
22+ cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23+ """ Convert an int to an nmod_ctx."""
24+ if typecheck(obj, nmod_poly_ctx):
2425 return obj
25- if any_as_nmod(& v, obj, mod):
26- r = nmod_poly.__new__ (nmod_poly)
27- nmod_poly_init(r.val, mod.n)
28- nmod_poly_set_coeff_ui(r.val, 0 , v)
29- return r
30- x = any_as_fmpz_poly(obj)
31- if x is not NotImplemented :
32- r = nmod_poly.__new__ (nmod_poly)
33- nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
34- fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
35- return r
26+ if typecheck(obj, int ):
27+ ctx = _nmod_poly_ctx_cache.get(obj)
28+ if ctx is None :
29+ ctx = nmod_poly_ctx(obj)
30+ _nmod_poly_ctx_cache[obj] = ctx
31+ return ctx
3632 return NotImplemented
3733
38- cdef nmod_poly_set_list(nmod_poly_t poly, list val):
39- cdef long i, n
40- cdef nmod_t mod
41- cdef mp_limb_t v
42- nmod_init(& mod, nmod_poly_modulus(poly)) # XXX
43- n = PyList_GET_SIZE(val)
44- nmod_poly_fit_length(poly, n)
45- for i from 0 <= i < n:
46- c = val[i]
47- if any_as_nmod(& v, val[i], mod):
48- nmod_poly_set_coeff_ui(poly, i, v)
49- else :
50- raise TypeError (" unsupported coefficient in list" )
34+
35+ cdef class nmod_poly_ctx:
36+ """
37+ Context object for creating :class:`~.nmod_poly` initalised
38+ with modulus :math:`N`.
39+
40+ >>> nmod_ctx(17)
41+ nmod_ctx(17)
42+
43+ """
44+ def __init__ (self , mod ):
45+ cdef mp_limb_t m
46+ m = mod
47+ nmod_init(& self .mod, m)
48+ self .ctx = nmod_ctx(mod)
49+ self ._is_prime = n_is_prime(m)
50+
51+ cdef int any_as_nmod(self , mp_limb_t * val, obj) except - 1 :
52+ return self .ctx.any_as_nmod(val, obj)
53+
54+ cdef any_as_nmod_poly(self , obj):
55+ cdef nmod_poly r
56+ cdef mp_limb_t v
57+ # XXX: should check that modulus is the same here, and not all over the place
58+ if typecheck(obj, nmod_poly):
59+ return obj
60+ if self .ctx.any_as_nmod(& v, obj):
61+ r = nmod_poly.__new__ (nmod_poly)
62+ nmod_poly_init(r.val, self .mod.n)
63+ nmod_poly_set_coeff_ui(r.val, 0 , v)
64+ return r
65+ x = any_as_fmpz_poly(obj)
66+ if x is not NotImplemented :
67+ r = nmod_poly.__new__ (nmod_poly)
68+ nmod_poly_init(r.val, self .mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
69+ fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
70+ return r
71+ return NotImplemented
72+
73+ cdef nmod_poly_set_list(self , nmod_poly_t poly, list val):
74+ cdef long i, n
75+ cdef mp_limb_t v
76+ n = PyList_GET_SIZE(val)
77+ nmod_poly_fit_length(poly, n)
78+ for i from 0 <= i < n:
79+ c = val[i]
80+ if self .any_as_nmod(& v, val[i]):
81+ nmod_poly_set_coeff_ui(poly, i, v)
82+ else :
83+ raise TypeError (" unsupported coefficient in list" )
84+
5185
5286cdef class nmod_poly(flint_poly):
5387 """
@@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly):
79113 def __dealloc__ (self ):
80114 nmod_poly_clear(self .val)
81115
82- def __init__ (self , val = None , ulong mod = 0 ):
116+ def __init__ (self , val = None , mod = 0 ):
83117 cdef ulong m2
84118 cdef mp_limb_t v
119+ cdef nmod_poly_ctx ctx
120+
85121 if typecheck(val, nmod_poly):
86122 m2 = nmod_poly_modulus((< nmod_poly> val).val)
87123 if m2 != mod:
88124 raise ValueError (" different moduli!" )
89125 nmod_poly_init(self .val, m2)
90126 nmod_poly_set(self .val, (< nmod_poly> val).val)
127+ self .ctx = (< nmod_poly> val).ctx
91128 else :
92129 if mod == 0 :
93130 raise ValueError (" a nonzero modulus is required" )
94- nmod_poly_init(self .val, mod)
131+ ctx = any_as_nmod_poly_ctx(mod)
132+ if ctx is NotImplemented :
133+ raise TypeError (" cannot create nmod_poly_ctx from input of type %s " , type (mod))
134+
135+ self .ctx = ctx
136+ nmod_poly_init(self .val, ctx.mod.n)
95137 if typecheck(val, fmpz_poly):
96138 fmpz_poly_get_nmod_poly(self .val, (< fmpz_poly> val).val)
97139 elif typecheck(val, list ):
98- nmod_poly_set_list(self .val, val)
99- elif any_as_nmod(& v, val, self .val.mod ):
140+ ctx. nmod_poly_set_list(self .val, val)
141+ elif ctx. any_as_nmod(& v, val):
100142 nmod_poly_fit_length(self .val, 1 )
101143 nmod_poly_set_coeff_ui(self .val, 0 , v)
102144 else :
@@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly):
178220 cdef mp_limb_t v
179221 if i < 0 :
180222 raise ValueError (" cannot assign to index < 0 of polynomial" )
181- if any_as_nmod(& v, x, self .val.mod ):
223+ if self .ctx. any_as_nmod(& v, x):
182224 nmod_poly_set_coeff_ui(self .val, i, v)
183225 else :
184226 raise TypeError (" cannot set element of type %s " % type (x))
@@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly):
291333 9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
292334 """
293335 cdef nmod_poly res
294- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
336+ other = self .ctx.any_as_nmod_poly(other )
295337 if other is NotImplemented :
296338 raise TypeError (" cannot convert input to nmod_poly" )
297339 res = nmod_poly.__new__ (nmod_poly)
@@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly):
316358 147*x^3 + 159*x^2 + 4*x + 7
317359 """
318360 cdef nmod_poly res
319- g = any_as_nmod_poly(other, self .val.mod )
361+ g = self .ctx.any_as_nmod_poly(other )
320362 if g is NotImplemented :
321363 raise TypeError (f" cannot convert {other = } to nmod_poly" )
322364
323- h = any_as_nmod_poly(modulus, self .val.mod )
365+ h = self . any_as_nmod_poly(modulus)
324366 if h is NotImplemented :
325367 raise TypeError (f" cannot convert {modulus = } to nmod_poly" )
326368
@@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly):
334376
335377 def __call__ (self , other ):
336378 cdef mp_limb_t c
337- if any_as_nmod(& c, other, self .val.mod ):
379+ if self .ctx. any_as_nmod(& c, other):
338380 v = nmod(0 , self .modulus())
339381 (< nmod> v).val = nmod_poly_evaluate_nmod(self .val, c)
340382 return v
341- t = any_as_nmod_poly(other, self .val.mod )
383+ t = self .ctx.any_as_nmod_poly(other )
342384 if t is not NotImplemented :
343385 r = nmod_poly.__new__ (nmod_poly)
344386 nmod_poly_init_preinv((< nmod_poly> r).val, self .val.mod.n, self .val.mod.ninv)
@@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly):
369411
370412 def _add_ (s , t ):
371413 cdef nmod_poly r
372- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
414+ t = s.ctx.any_as_nmod_poly(t )
373415 if t is NotImplemented :
374416 return t
375417 if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly):
395437 return r
396438
397439 def __sub__ (s , t ):
398- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
440+ t = s.ctx.any_as_nmod_poly(t )
399441 if t is NotImplemented :
400442 return t
401443 return s._sub_(t)
402444
403445 def __rsub__ (s , t ):
404- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
446+ t = s. any_as_nmod_poly(t)
405447 if t is NotImplemented :
406448 return t
407449 return t._sub_(s)
408450
409451 def _mul_ (s , t ):
410452 cdef nmod_poly r
411- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
453+ t = s. any_as_nmod_poly(t)
412454 if t is NotImplemented :
413455 return t
414456 if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly):
425467 return s._mul_(t)
426468
427469 def __truediv__ (s , t ):
428- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
470+ t = s. any_as_nmod_poly(t)
429471 if t is NotImplemented :
430472 return t
431473 res, r = s._divmod_(t)
@@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly):
434476 return res
435477
436478 def __rtruediv__ (s , t ):
437- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
479+ t = s. any_as_nmod_poly(t)
438480 if t is NotImplemented :
439481 return t
440482 res, r = t._divmod_(s)
@@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly):
454496 return r
455497
456498 def __floordiv__ (s , t ):
457- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
499+ t = s. any_as_nmod_poly(t)
458500 if t is NotImplemented :
459501 return t
460502 return s._floordiv_(t)
461503
462504 def __rfloordiv__ (s , t ):
463- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
505+ t = s. any_as_nmod_poly(t)
464506 if t is NotImplemented :
465507 return t
466508 return t._floordiv_(s)
@@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly):
479521 return P, Q
480522
481523 def __divmod__ (s , t ):
482- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
524+ t = s. any_as_nmod_poly(t)
483525 if t is NotImplemented :
484526 return t
485527 return s._divmod_(t)
486528
487529 def __rdivmod__ (s , t ):
488- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
530+ t = s. any_as_nmod_poly(t)
489531 if t is NotImplemented :
490532 return t
491533 return t._divmod_(s)
@@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly):
534576 if e < 0 :
535577 raise ValueError (" Exponent must be non-negative" )
536578
537- modulus = any_as_nmod_poly(modulus, ( < nmod_poly > self ).val.mod )
579+ modulus = self .ctx.any_as_nmod_poly(modulus )
538580 if modulus is NotImplemented :
539581 raise TypeError (" cannot convert input to nmod_poly" )
540582
@@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly):
556598
557599 # To optimise powering, we precompute the inverse of the reverse of the modulus
558600 if mod_rev_inv is not None :
559- mod_rev_inv = any_as_nmod_poly(mod_rev_inv, ( < nmod_poly > self ).val.mod )
601+ mod_rev_inv = self . any_as_nmod_poly(mod_rev_inv)
560602 if mod_rev_inv is NotImplemented :
561603 raise TypeError (f" Cannot interpret {mod_rev_inv} as a polynomial" )
562604 else :
@@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly):
585627
586628 """
587629 cdef nmod_poly res
588- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
630+ other = self . any_as_nmod_poly(other)
589631 if other is NotImplemented :
590632 raise TypeError (" cannot convert input to nmod_poly" )
591633 if self .val.mod.n != (< nmod_poly> other).val.mod.n:
@@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly):
597639
598640 def xgcd (self , other ):
599641 cdef nmod_poly res1, res2, res3
600- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
642+ other = self . any_as_nmod_poly(other)
601643 if other is NotImplemented :
602644 raise TypeError (" cannot convert input to fmpq_poly" )
603645 res1 = nmod_poly.__new__ (nmod_poly)
0 commit comments