from numarrayall import *

# Save the implementations brought in by the star import above under
# "numarray_"-prefixed names.  The wrappers defined further down in this
# module rebind the plain names (nonzero, ones, put, take, zeros), so
# these aliases are the only way left to reach the originals.
numarray_nonzero = nonzero
numarray_ones = ones
numarray_put = put
numarray_take = take
numarray_zeros = zeros

def nonzero(a):
    """Return the indices at which `a` is not equal to zero, Numeric style.

    The result is either a single index array or a tuple of index
    arrays.

    This is the Numeric compatible flavour of nonzero() provided by
    numarray.numeric.  The only difference from numarray.nonzero() is
    the treatment of 1D arrays: numarray.nonzero() always hands back a
    tuple, whereas Numeric.nonzero() hands back a tuple only when the
    array does not have exactly one dimension.
    """
    index_arrays = numarray_nonzero(a)
    # A one-element result is unwrapped to the bare index array; any
    # other length (zero, or two and more) is passed through untouched.
    if len(index_arrays) == 1:
        return index_arrays[0]
    return index_arrays

def take(a, indices, axis=0):
    """Numeric compatible take(), delegating to numarray's take().

    `a` is used directly only when it is already a contiguous NumArray;
    anything else is first converted with array().  The call is then
    forwarded with the given `indices` and `axis` and with
    clipmode=WRAP (presumably so that out-of-range indices wrap around
    -- confirm against numarray's take() documentation).
    """
    usable_as_is = isinstance(a, NumArray) and a.iscontiguous()
    if not usable_as_is:
        a = array(a)
    return numarray_take(a, indices, axis=axis, clipmode=WRAP)

def put(a, ind, v):
    """put(a, ind, v) results in a[n] = v[n] for all n in ind.

    If v is shorter than ind it will be repeated as necessary.
    In particular v can be a scalar or length 1 array.

    The routine put is the equivalent of the following (although the
    loop is in C for speed):

        ind = array(ind, copy=0)
        v = array(v, copy=0).astype(a.typecode())
        for i in ind: a.flat[i] = v[i]

    a must be a contiguous Numeric array.

    This wrapper forwards to numarray's put(), operating on a.flat.
    Nothing is returned; `a` is modified through its flat view.
    """
    flat_target = a.flat
    numarray_put(flat_target, ind, v)

def ones(shape, typecode='l', savespace=0):
    """Numeric compatible ones(): forward to numarray's ones().

    `shape` and `typecode` (default 'l') are passed straight through.
    `savespace` is accepted so that Numeric-style calls keep working,
    but it is not used.
    """
    result = numarray_ones(shape, typecode)
    return result

def zeros(shape, typecode='l', savespace=0):
    """Numeric compatible zeros(): forward to numarray's zeros().

    `shape` and `typecode` (default 'l') are passed straight through.
    `savespace` is accepted so that Numeric-style calls keep working,
    but it is not used.
    """
    result = numarray_zeros(shape, typecode)
    return result

# Expose the NumArray class under the name `arraytype` (presumably the
# name Numeric-era callers expect -- TODO confirm against Numeric's API).
arraytype = NumArray
