API
This page describes the functions and methods available in ttax.
Module contents
- class ttax.base_class.BatchIndexing(tt)
Bases: object
- class ttax.base_class.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])
Bases: ttax.base_class.TTBase
Represents a TT-Tensor object as a list of TT-cores.
Each TT-core has shape (r_l, n, r_r), where
r_l, r_r are TT-ranks
n is the corresponding mode size of the TT-Tensor
- property is_tt_matrix
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims
Get the number of batch dimensions of a batch of TT-Tensors.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape
Get the tuple representing the shape of the TT-Tensor. For a batch, the batch shape is not included.
- Returns
TT-Tensor shape
- Return type
tuple
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- property shape
Get the tuple representing the shape of the TT-Tensor. For a batch, the batch shape is included.
- Returns
TT-Tensor shape with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]
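The core layout above can be illustrated with plain NumPy. This is a minimal sketch, not ttax code: `make_tt_cores`, `raw_tensor_shape` and `tt_ranks` are hypothetical helpers that read the mode sizes and TT-ranks directly off a list of (r_l, n, r_r) arrays, mirroring what the raw_tensor_shape and tt_ranks properties are documented to return.

```python
import numpy as np


def make_tt_cores(shape, ranks, seed=0):
    """Hypothetical helper: random TT-cores, the k-th of shape (r_k, n_k, r_{k+1})."""
    rng = np.random.default_rng(seed)
    return [rng.standard_normal((ranks[k], n, ranks[k + 1]))
            for k, n in enumerate(shape)]


def raw_tensor_shape(cores):
    # The mode size n is the middle axis of every (r_l, n, r_r) core.
    return tuple(core.shape[1] for core in cores)


def tt_ranks(cores):
    # ndim + 1 ranks: the left rank of every core plus the right rank of the last one.
    return [core.shape[0] for core in cores] + [cores[-1].shape[-1]]


cores = make_tt_cores(shape=(4, 5, 6), ranks=[1, 3, 2, 1])
print(raw_tensor_shape(cores))  # (4, 5, 6)
print(tt_ranks(cores))          # [1, 3, 2, 1]
```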
- class ttax.base_class.TTBase
Bases: object
Represents the base for both TT-Tensor and TT-Matrix (TT-object). Includes some basic routines and properties.
- property axis_dim
Get the position of the mode axis in a TT-core. It can differ depending on the batch shape.
- Returns
index
- Return type
int
- property batch_loc
Provides batch indexing for the TT-object by wrapping it in a special BatchIndexing class with an overloaded __getitem__ method.
Example
tt.batch_loc[1, :, :]
- property batch_shape
Get the list representing the batch shape of the TT-object.
- Returns
batch shape
- Return type
list
- property dtype
The dtype of the elements of the TT-object.
- Returns
dtype of elements
- Return type
dtype
- property ndim
Get the number of dimensions of the TT-object.
- Returns
number of dimensions
- Return type
int
- property tt_ranks
Get the TT-ranks of the TT-object; there are ndim + 1 of them. The first and the last TT-ranks are equal to 1.
- Returns
TT-ranks
- Return type
list
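The relation between batch_shape, num_batch_dims, axis_dim and tt_ranks can be sketched in NumPy. The layout here is an assumption, not taken from the ttax source: batch axes are taken to come first, so each core of a batch of TT-Tensors has shape (*batch_shape, r_l, n, r_r) and the mode axis moves right by one position per batch dimension.

```python
import numpy as np

# Assumed layout: batch axes first, then (r_l, n, r_r).
batch_shape = (7, 2)
ranks = [1, 3, 1]
mode_sizes = [4, 5]
cores = [np.zeros(batch_shape + (ranks[k], n, ranks[k + 1]))
         for k, n in enumerate(mode_sizes)]

num_batch_dims = cores[0].ndim - 3          # 3 trailing axes per TT-core
found_batch_shape = list(cores[0].shape[:num_batch_dims])
axis_dim = num_batch_dims + 1               # position of n inside a core
found_ranks = [c.shape[num_batch_dims] for c in cores] + [cores[-1].shape[-1]]

print(num_batch_dims, found_batch_shape, axis_dim, found_ranks)
# 2 [7, 2] 3 [1, 3, 1]
```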
- class ttax.base_class.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])
Bases: ttax.base_class.TTBase
Represents a TT-Matrix object as a list of TT-cores.
Each TT-core has shape (r_l, n_l, n_r, r_r), where
r_l, r_r are TT-ranks, just as for a TT-Tensor
n_l, n_r are the corresponding row and column mode sizes of the TT-Matrix
- property is_tt_matrix
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims
Get the number of batch dimensions of a batch of TT-Matrices.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape
Get the row and column shapes of the TT-Matrix. For a batch, the batch shape is not included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), it returns (2, 6), (3, 7).
- Returns
TT-Matrix shapes
- Return type
list, list
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- property shape
Get the tuple representing the shape of the underlying dense matrix. For a batch, the batch shape is included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21).
- Returns
TT-Matrix shape in dense form with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]
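The row and column shapes and the dense shape from the example above can be read directly off the core shapes. A short NumPy sketch (the helper name is hypothetical, not part of ttax) that reproduces the documented (2, 6), (3, 7) and (12, 21):

```python
from math import prod

import numpy as np


def tt_matrix_shapes(cores):
    # Each TT-Matrix core has shape (r_l, n_l, n_r, r_r).
    row_shape = tuple(core.shape[1] for core in cores)
    col_shape = tuple(core.shape[2] for core in cores)
    return row_shape, col_shape


cores = [np.zeros((1, 2, 3, 5)), np.zeros((5, 6, 7, 1))]
rows, cols = tt_matrix_shapes(cores)
dense_shape = (prod(rows), prod(cols))  # row and column mode sizes multiply out
print(rows, cols, dense_shape)  # (2, 6) (3, 7) (12, 21)
```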
- ttax.ops.add(tt_a, tt_b)
Returns a TT-object corresponding to the elementwise sum tt_a + tt_b. The shapes of tt_a and tt_b should coincide. Broadcasting over batches is supported, e.g. you can add a tensor train with batch size 7 and a tensor train with batch size 1:
ttax.ops.add(tt_batch, tt_single.batch_loc[np.newaxis])
where tt_single.batch_loc[np.newaxis] creates a singleton batch dimension.
- Parameters
tt_a (TT-Tensor or TT-Matrix) – first argument
tt_b (TT-Tensor or TT-Matrix) – second argument
- Return type
TT-Tensor or TT-Matrix
- Returns
tt_a + tt_b
- Raises
ValueError – if the shapes of the arguments do not coincide
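A common way to add two tensor trains, and a plausible model for what add computes, is to stack the cores block-wise: the first cores are concatenated along the right rank, the last cores along the left rank, and the middle cores are placed block-diagonally, so the TT-ranks of the sum are the sums of the TT-ranks. The NumPy sketch below (hypothetical helpers `full` and `tt_add`, unbatched TT-Tensors only) checks this against dense addition; it is not the ttax implementation.

```python
import numpy as np


def full(cores):
    # Contract neighboring rank axes, then drop the boundary ranks of size 1.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_add(a, b):
    """Sum of two TT-Tensors with equal shapes via block-structured cores."""
    d = len(a)
    out = []
    for k, (core_a, core_b) in enumerate(zip(a, b)):
        if d == 1:
            out.append(core_a + core_b)  # both cores are (1, n, 1)
        elif k == 0:
            out.append(np.concatenate([core_a, core_b], axis=2))  # [A B]
        elif k == d - 1:
            out.append(np.concatenate([core_a, core_b], axis=0))  # [A; B]
        else:
            ra_l, n, ra_r = core_a.shape
            rb_l, _, rb_r = core_b.shape
            core = np.zeros((ra_l + rb_l, n, ra_r + rb_r),
                            dtype=np.result_type(core_a, core_b))
            core[:ra_l, :, :ra_r] = core_a  # block-diagonal placement
            core[ra_l:, :, ra_r:] = core_b
            out.append(core)
    return out


rng = np.random.default_rng(1)
a = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
b = [rng.standard_normal(s) for s in [(1, 3, 3), (3, 4, 2), (2, 5, 1)]]
total = tt_add(a, b)
print([c.shape for c in total])  # [(1, 3, 5), (5, 4, 5), (5, 5, 1)]
print(np.allclose(full(total), full(a) + full(b)))  # True
```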
- ttax.ops.are_batches_broadcastable(tt_a, tt_b)
Checks whether the batches of two TT-objects are compatible: returns True if they are and False otherwise. For broadcasting to be possible, the batch sizes should either be equal or at least one of them should be equal to 1.
- Parameters
tt_a (TT-Tensor or TT-Matrix) – first argument to check
tt_b (TT-Tensor or TT-Matrix) – second argument to check
- Returns
the result of the broadcasting check
- Return type
bool
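The broadcasting rule described above can be written as a small predicate on batch shapes. This is a sketch under one assumption that the text does not settle: both batch shapes are taken to have the same number of dimensions. The function name is hypothetical.

```python
def are_batches_broadcastable_sketch(batch_a, batch_b):
    """Hypothetical check: batch shapes of the same length whose sizes
    are pairwise equal, or where one of the pair is 1, are broadcastable."""
    if len(batch_a) != len(batch_b):
        return False
    return all(x == y or x == 1 or y == 1 for x, y in zip(batch_a, batch_b))


print(are_batches_broadcastable_sketch((7,), (1,)))      # True
print(are_batches_broadcastable_sketch((7,), (3,)))      # False
print(are_batches_broadcastable_sketch((7, 2), (1, 2)))  # True
```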
- ttax.ops.are_shapes_equal(tt_a, tt_b)
Checks whether two TT-objects have equal shapes: returns True if they do and False otherwise. The arguments should be either both TT-Tensors or both TT-Matrices. Only the tensor shapes are compared; the TT-ranks may differ.
- Parameters
tt_a (TT-Tensor or TT-Matrix) – first argument to check
tt_b (TT-Tensor or TT-Matrix) – second argument to check
- Returns
tensor_check – the result of the shape check
- Return type
bool
- ttax.ops.flat_inner(a, b)
Calculate the inner product of the given TT-Tensors or TT-Matrices wrapped with WrappedTT.
- Parameters
a (WrappedTT) – first argument
b (WrappedTT) – second argument
- Returns
the result of the inner product
- Return type
WrappedTT
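An inner product of two tensor trains can be computed core by core, carrying a small (r_a, r_b) matrix instead of forming the dense tensors. The NumPy sketch below (hypothetical helper names, unbatched TT-Tensors only) illustrates this standard technique and checks it against the dense sum; it is not taken from the ttax source.

```python
import numpy as np


def full(cores):
    # Contract neighboring rank axes, then drop the boundary ranks of size 1.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def flat_inner_sketch(a, b):
    """Inner product of two TT-Tensors without forming the dense tensors."""
    acc = np.ones((1, 1))  # running contraction, shape (r_a, r_b)
    for core_a, core_b in zip(a, b):
        # acc[p, q] * core_a[p, n, r] * core_b[q, n, s], summed over p, q, n.
        acc = np.einsum('pq,pnr,qns->rs', acc, core_a, core_b)
    return acc[0, 0]


rng = np.random.default_rng(0)
a = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
b = [rng.standard_normal(s) for s in [(1, 3, 4), (4, 4, 2), (2, 5, 1)]]
print(np.isclose(flat_inner_sketch(a, b), np.sum(full(a) * full(b))))  # True
```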
- ttax.ops.full(tt: ttax.base_class.TTBase) → jax._src.numpy.lax_numpy.array
Converts a TT-Tensor or TT-Matrix into dense format.
- ttax.ops.full_tt_matrix(tt: ttax.base_class.TTMatrix) → jax._src.numpy.lax_numpy.array
Converts a TT-Matrix into a regular matrix.
- ttax.ops.full_tt_tensor(tt: ttax.base_class.TT) → jax._src.numpy.lax_numpy.array
Converts a TT-Tensor into a regular tensor.
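A dense conversion of the kind full_tt_tensor and full_tt_matrix perform can be sketched by contracting neighboring TT-cores over their shared rank axis. This NumPy version is an illustration with hypothetical names, not the ttax implementation; for TT-Matrices it assumes the dense row index is the row-major combination of the n_l indices (and likewise for columns), which is consistent with the (12, 21) example in the TTMatrix.shape description.

```python
from math import prod

import numpy as np


def full_tt_tensor_sketch(cores):
    # Chain the (r_l, n, r_r) cores over their shared rank axes.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]  # drop the boundary ranks, both equal to 1


def full_tt_matrix_sketch(cores):
    # Chain the (r_l, n_l, n_r, r_r) cores; axes become (m1, n1, m2, n2, ...).
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    res = res[0, ..., 0]
    d = len(cores)
    # Put all row axes before all column axes, then flatten each group.
    perm = list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2))
    rows = prod(core.shape[1] for core in cores)
    cols = prod(core.shape[2] for core in cores)
    return res.transpose(perm).reshape(rows, cols)


rng = np.random.default_rng(0)
A, B = rng.standard_normal((2, 3)), rng.standard_normal((4, 5))
# A TT-Matrix with TT-ranks all equal to 1 is the Kronecker product of its factors.
M = full_tt_matrix_sketch([A[None, :, :, None], B[None, :, :, None]])
print(M.shape, np.allclose(M, np.kron(A, B)))  # (8, 15) True
```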
- ttax.ops.is_tt_matrix(arg) → bool
Determine whether the object is a TT-Matrix or a WrappedTT with an underlying TT-Matrix.
- Returns
True if TT-Matrix or WrappedTT(TT-Matrix), False otherwise
- Return type
bool
- ttax.ops.is_tt_object(arg) → bool
Determine whether the object is a TT-Tensor, a TT-Matrix, or a WrappedTT around one of them.
- Returns
True if TT-object, False otherwise
- Return type
bool
- ttax.ops.is_tt_tensor(arg) → bool
Determine whether the object is a TT-Tensor or a WrappedTT with an underlying TT-Tensor.
- Returns
True if TT-Tensor or WrappedTT(TT-Tensor), False otherwise
- Return type
bool
- ttax.ops.matmul(a, b)
Calculate the matrix product of the given TT-Matrices wrapped with WrappedTT.
- Parameters
a (WrappedTT) – first argument
b (WrappedTT) – second argument
- Returns
the result of the matrix multiplication
- Return type
WrappedTT
- ttax.ops.multiply(a, b)
Calculate the elementwise product of two TT-Tensors or two TT-Matrices, or the product of a TT-object and a scalar. The arguments may or may not be wrapped with WrappedTT.
- Parameters
a (Union[float, TT-object]) – first argument
b (Union[float, TT-object]) – second argument
- Returns
the result of the elementwise product
- Return type
TT-object
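Elementwise multiplication of tensor trains is usually done core by core: for every mode index the two cores' rank slices are combined with a Kronecker product, so TT-ranks multiply. The NumPy sketch below uses the same letter pattern as the TTEinsum in the to_function example ('aib,cid->acibd') and also shows that scaling by a scalar only needs one core to be rescaled. Helper names are hypothetical and only unbatched TT-Tensors are handled; this is not the ttax implementation.

```python
import numpy as np


def full(cores):
    # Contract neighboring rank axes, then drop the boundary ranks of size 1.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_multiply_sketch(a, b):
    """Elementwise product: Kronecker product of the cores in the rank indices."""
    out = []
    for core_a, core_b in zip(a, b):
        ra_l, n, ra_r = core_a.shape
        rb_l, _, rb_r = core_b.shape
        core = np.einsum('aib,cid->acibd', core_a, core_b)
        # Merge (a, c) into the left rank and (b, d) into the right rank.
        out.append(core.reshape(ra_l * rb_l, n, ra_r * rb_r))
    return out


def tt_scale_sketch(cores, alpha):
    # Multiplying by a scalar only needs one core to be rescaled.
    return [alpha * cores[0]] + list(cores[1:])


rng = np.random.default_rng(0)
a = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
b = [rng.standard_normal(s) for s in [(1, 3, 4), (4, 4, 2), (2, 5, 1)]]
prod_ab = tt_multiply_sketch(a, b)
print([c.shape for c in prod_ab])  # [(1, 3, 8), (8, 4, 6), (6, 5, 1)]
print(np.allclose(full(prod_ab), full(a) * full(b)))  # True
print(np.allclose(full(tt_scale_sketch(a, 2.5)), 2.5 * full(a)))  # True
```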
- ttax.ops.multiply_by_scalar(a, b)
Returns the result of multiplying a TT-object (TTTensOrMat or WrappedTT) by a scalar. Takes two arguments, one of which is a TT-object and the other a scalar; the order of the arguments does not matter.
- Returns
the result of multiplication by a scalar
- Return type
TTTensOrMat
- ttax.ops.tt_tt_multiply(a, b)
- ttax.ops.tt_vmap(num_batch_args=None)
Decorator that makes a function support batches of TT-inputs.
- Parameters
num_batch_args (int or None) –
The number of arguments that are batches of TT-objects.
If None, the function is mapped over all arguments.
If an integer n, the function is mapped over the first n arguments only.
- Returns
Decorator
- Comments:
The function is vmapped num_batch_dims times, since multidimensional batches are supported. The number of batch dimensions to map over is given by the num_batch_dims property and should be the same for all arguments the function is mapped over. Arguments that are not batched should be excluded from mapping via num_batch_args.
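The phrase "vmapped num_batch_dims times" means the function is mapped over one leading batch axis per batch dimension. The sketch below spells that out with explicit Python loops instead of jax.vmap, so it only models the semantics, not the performance. It ignores num_batch_args, assumes batch axes come first in every core, and its names are hypothetical.

```python
import numpy as np


def full(cores):
    # Dense conversion of a single (unbatched) TT-Tensor.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def batch_map_sketch(func, num_batch_dims):
    """Loop-based stand-in for wrapping func in jax.vmap num_batch_dims times."""
    def mapped(cores):
        if num_batch_dims == 0:
            return func(cores)
        inner = batch_map_sketch(func, num_batch_dims - 1)
        # Peel off the leading batch axis of every core and map over it.
        return np.stack([inner([core[i] for core in cores])
                         for i in range(cores[0].shape[0])])
    return mapped


rng = np.random.default_rng(0)
batch = [rng.standard_normal((7, 2) + s) for s in [(1, 4, 3), (3, 5, 1)]]
dense = batch_map_sketch(full, num_batch_dims=2)(batch)
print(dense.shape)  # (7, 2, 4, 5)
```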
- ttax.decompositions.orthogonalize(tt, left_to_right=True)
Orthogonalize the TT-cores of a TT-object.
- Parameters
tt (TT-Tensor or TT-Matrix) – TT-object whose TT-cores will be orthogonalized
left_to_right (bool) – direction of orthogonalization: True for left to right, False for right to left
- Returns
TT-object with orthogonalized TT-cores
- Return type
TT-Tensor or TT-Matrix
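Left-to-right orthogonalization is commonly done with a QR sweep: each core is reshaped to an (r_l * n, r_r) matrix, replaced by its orthonormal factor, and the triangular factor is pushed into the next core, so the represented tensor does not change. A NumPy sketch of that standard technique for unbatched TT-Tensors (hypothetical names; ttax may differ in details):

```python
import numpy as np


def full(cores):
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def orthogonalize_left_to_right_sketch(cores):
    """QR sweep: all cores except the last become left-orthogonal."""
    cores = [core.copy() for core in cores]
    for k in range(len(cores) - 1):
        r_l, n, r_r = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_l * n, r_r))
        cores[k] = q.reshape(r_l, n, q.shape[1])
        # Push the triangular factor into the next core.
        cores[k + 1] = np.tensordot(r, cores[k + 1], axes=([1], [0]))
    return cores


rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
ortho = orthogonalize_left_to_right_sketch(cores)
print(np.allclose(full(ortho), full(cores)))  # True
for core in ortho[:-1]:
    u = core.reshape(-1, core.shape[-1])
    print(np.allclose(u.T @ u, np.eye(u.shape[1])))  # True
```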
- ttax.decompositions.round(tt, max_tt_rank=None, epsilon=None)
Tensor Train rounding procedure; returns a TT-object with smaller TT-ranks.
- Parameters
tt (TT-Tensor or TT-Matrix) – argument whose TT-ranks will be reduced
max_tt_rank (int or list of ints) –
If a number, it defines the maximal TT-rank of the result.
If a list of numbers, its length should be d+1 (where d is the number of dimensions) and max_tt_rank[i] defines the maximal (i+1)-th TT-rank of the result.
The following two versions are equivalent:
max_tt_rank = r
max_tt_rank = [1] + [r] * (d-1) + [1]
epsilon (float or None) –
If the TT-ranks are not restricted (max_tt_rank=None), then the result is guaranteed to be epsilon-close to tt in terms of the relative Frobenius error:
||res - tt||_F / ||tt||_F <= epsilon
If the TT-ranks are restricted, providing a loose epsilon may reduce the TT-ranks of the result. E.g.
round(tt, max_tt_rank=100, epsilon=0.9)
will probably return a TT-Tensor with TT-ranks close to 1, not 100. Note that providing a nontrivial (not None) epsilon makes the TT-ranks of the result depend on the data, which prevents you from using jax.jit to speed up the computations.
- Returns
TT-object with reduced TT-ranks
- Return type
TT-Tensor or TT-Matrix
- Raises
ValueError if max_tt_rank is less than 0, if max_tt_rank is neither a number nor a vector of length d + 1 (where d is the number of dimensions of the input tensor), or if epsilon is less than 0.
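The max_tt_rank branch of rounding can be sketched with the standard TT-rounding algorithm: a right-to-left QR sweep followed by a left-to-right sweep of truncated SVDs. This NumPy sketch (hypothetical names, unbatched TT-Tensors, no epsilon handling) is a reference for the idea, not the ttax code. The example pads a tensor train with TT-ranks [1, 2, 2, 1] with zeros up to [1, 4, 4, 1] and checks that rounding back to rank 2 recovers it.

```python
import numpy as np


def full(cores):
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res[0, ..., 0]


def round_sketch(cores, max_tt_rank):
    """TT-rounding to a fixed maximal TT-rank (no epsilon handling)."""
    cores = [core.copy() for core in cores]
    d = len(cores)
    # 1) Right-to-left QR sweep: all cores except the first become right-orthogonal.
    for k in range(d - 1, 0, -1):
        r_l, n, r_r = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_l, n * r_r).T)
        cores[k] = q.T.reshape(q.shape[1], n, r_r)
        cores[k - 1] = np.tensordot(cores[k - 1], r.T, axes=([2], [0]))
    # 2) Left-to-right sweep of truncated SVDs.
    for k in range(d - 1):
        r_l, n, r_r = cores[k].shape
        u, s, vt = np.linalg.svd(cores[k].reshape(r_l * n, r_r), full_matrices=False)
        rank = min(max_tt_rank, s.size)
        cores[k] = u[:, :rank].reshape(r_l, n, rank)
        # Push the kept singular values and right vectors into the next core.
        cores[k + 1] = np.tensordot(s[:rank, None] * vt[:rank], cores[k + 1],
                                    axes=([1], [0]))
    return cores


rng = np.random.default_rng(0)
small = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 2), (2, 5, 1)]]
# Zero-padding the TT-ranks to [1, 4, 4, 1] does not change the tensor.
padded = [np.pad(small[0], ((0, 0), (0, 0), (0, 2))),
          np.pad(small[1], ((0, 2), (0, 0), (0, 2))),
          np.pad(small[2], ((0, 2), (0, 0), (0, 0)))]
rounded = round_sketch(padded, max_tt_rank=2)
print([c.shape for c in rounded])  # [(1, 3, 2), (2, 4, 2), (2, 5, 1)]
print(np.allclose(full(rounded), full(small)))  # True
```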
Utils for compiling functions defined as einsum strings.
Here we use the notion of a tt_einsum, which is similar to an einsum string but has more structure.
The basic element of a tt_einsum is a tt_einsum_core: a list of three elements that defines an einsum string for a single TT-core. The first element holds the indices for the left TT-rank, the second holds the indices for the main dimensions of the resulting TT-core, and the last holds the indices for the right TT-rank.
A tt_einsum consists of a list of input cores and an output core, each defined as a tt_einsum_core.
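For the tt_einsum used in the TTEinsum and to_function examples (inputs [['a', 'i', 'b'], ['c', 'i', 'd']], output ['ac', 'i', 'bd']), the corresponding vanilla einsum string is obtained by joining the letters of each tt_einsum_core. The helper below is hypothetical and only mirrors the string format shown in the change_input example; it is not TTEinsum.to_vanilla_einsum itself.

```python
import numpy as np


def to_vanilla_einsum_sketch(inputs, output):
    """Hypothetical helper: join tt_einsum_cores into one einsum string."""
    # Each core is [left-rank letters, mode letters, right-rank letters].
    lhs = ','.join(''.join(core) for core in inputs)
    return lhs + '->' + ''.join(output)


inputs = [['a', 'i', 'b'], ['c', 'i', 'd']]
output = ['ac', 'i', 'bd']
rule = to_vanilla_einsum_sketch(inputs, output)
print(rule)  # aib,cid->acibd

# On one pair of TT-cores, the letter groups 'ac' and 'bd' of the output core
# can then be flattened into single left and right TT-rank axes.
core_x = np.ones((2, 3, 4))
core_y = np.ones((5, 3, 6))
res = np.einsum(rule, core_x, core_y)
print(res.shape, res.reshape(2 * 5, 3, 4 * 6).shape)  # (2, 5, 3, 4, 6) (10, 3, 24)
```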
- class ttax.compile.TTEinsum(inputs, output, how_to_apply, order='left-to-right')
Bases: object
A class that holds an einsum rule; it is needed for fusion to work.
- apply_mapping(mapping: Dict[str, str])
Rename letters according to the given mapping.
- Returns
new TTEinsum with renamed letters
- Return type
TTEinsum
- change_input(input_idx: int, new_inputs: List)
Replace the argument with index input_idx by new_inputs.
- E.g.
tt_einsum = TTEinsum(inputs=[['a', 'i', 'b'], ['c', 'i', 'd']], output=['ac', 'i', 'bd'], how_to_apply='independent')
tt_einsum.change_input(0, [['e', 'i', 'f'], ['g', 'i', 'h']])
print(tt_einsum.to_vanilla_einsum())
will print 'eif,gih,cid->acibd'
- resolve_i_or_ij(is_tt_matrix)
Return a version of the TTEinsum with I_OR_IJ changed to either 'i' or 'ij'.
- to_distinct_letters(distinct_from)
Rename letters to make them distinct from the letters used in distinct_from.
- to_vanilla_einsum()
Build a regular einsum string.
- class ttax.compile.WrappedTT(tt: ttax.base_class.TT, tt_inputs=None, tt_einsum=None)
Bases: object
A class that wraps a TT-object; it is needed for fusion to work.
The base TT-object class can only contain jnp.array objects, so that it can be passed into a jitted function. To fuse two functions together, however, we need to track which operation created a TT-object, so while fusing ops we wrap TT-objects in this class.
- property axis_dim
Get the position of the mode axis in the underlying TT-cores. It can differ depending on the batch shape.
- Returns
index
- Return type
int
- property batch_loc
Provides batch indexing for the underlying TT-object by wrapping it in a special BatchIndexing class with an overloaded __getitem__ method.
Example
tt.batch_loc[1, :, :]
- property batch_shape
Get the list representing the batch shape of the underlying TT-object.
- Returns
batch shape
- Return type
list
- property dtype
The dtype of the elements of the underlying TT-object.
- Returns
dtype of elements
- Return type
dtype
- property is_tt_matrix
Determine whether the underlying TT-object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property ndim
Get the number of dimensions of the underlying TT-object.
- Returns
number of dimensions
- Return type
int
- property num_batch_dims
Get the number of batch dimensions of the underlying TT-object.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape
Get the shape of the underlying TT-object, as given by its own raw_tensor_shape property.
- Returns
shape
- Return type
list
- property shape
Get the tuple representing the shape of the underlying TT-object. For a batch, the batch shape is included.
- Returns
shape with batch shape
- Return type
tuple
- property tt_cores
Get the list of TT-cores of the underlying TT-object.
- Returns
TT-cores
- Return type
list
- property tt_ranks
Get the TT-ranks of the underlying TT-object; there are ndim + 1 of them. The first and the last TT-ranks are equal to 1.
- Returns
TT-ranks
- Return type
list
- ttax.compile.apply_single_mapping(strings, mapping)
Apply a letter mapping to a list of strings.
- ttax.compile.compile_cumulative(tt_einsum: ttax.compile.TTEinsum) → Callable
- ttax.compile.compile_independent(tt_einsum: ttax.compile.TTEinsum) → Callable
- ttax.compile.fuse(func)
Fuse a composite function to make it faster.
Example:
Consider f(a, b, c) = <a * b, c> = sum_{i_1, …, i_d} a[i_1, …, i_d] b[i_1, …, i_d] c[i_1, …, i_d], as computed by ttax.flat_inner(a * b, c).
Function f can be suboptimal for some inputs. For example, if a and b have large TT-ranks and c has a low TT-rank, implementing the same operation as
ttax.flat_inner(a * c, b)
would be much more efficient.
fuse automates such optimizations. You can build an efficient implementation of this function for any inputs by doing
faster_f = ttax.fuse(f)
Finally, don't forget that in JAX, to get good speed you need to wrap your highest-level function in jit, e.g.
faster_f = jax.jit(faster_f)
Now, by using faster_f(a, b, c) instead of f(a, b, c), you can achieve much faster execution for any inputs.
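To see why the order of operations in the example above matters, the NumPy sketch below evaluates both orderings with hypothetical stand-ins for multiply and flat_inner (unbatched TT-Tensors, not ttax code). It assumes, as is standard for tensor trains, that the elementwise product multiplies TT-ranks; it only shows that the two orderings agree while the intermediate TT-ranks differ, not how fuse itself chooses an ordering.

```python
import numpy as np


def random_tt(shape, ranks, rng):
    return [rng.standard_normal((ranks[k], n, ranks[k + 1]))
            for k, n in enumerate(shape)]


def multiply_sketch(a, b):
    # Hypothetical stand-in for elementwise multiply: TT-ranks multiply.
    return [np.einsum('aib,cid->acibd', x, y).reshape(
                x.shape[0] * y.shape[0], x.shape[1], x.shape[2] * y.shape[2])
            for x, y in zip(a, b)]


def flat_inner_sketch(a, b):
    # Hypothetical stand-in for flat_inner: core-by-core contraction.
    acc = np.ones((1, 1))
    for x, y in zip(a, b):
        acc = np.einsum('pq,pnr,qns->rs', acc, x, y)
    return acc[0, 0]


def max_rank(cores):
    return max(core.shape[2] for core in cores)


rng = np.random.default_rng(0)
shape = (4,) * 5
a = random_tt(shape, [1, 8, 8, 8, 8, 1], rng)
b = random_tt(shape, [1, 8, 8, 8, 8, 1], rng)
c = random_tt(shape, [1, 2, 2, 2, 2, 1], rng)

slow = flat_inner_sketch(multiply_sketch(a, b), c)  # intermediate TT-rank 8 * 8
fast = flat_inner_sketch(multiply_sketch(a, c), b)  # intermediate TT-rank 8 * 2
print(max_rank(multiply_sketch(a, b)), max_rank(multiply_sketch(a, c)))  # 64 16
print(np.isclose(slow, fast))  # True
```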
- ttax.compile.to_function(tt_einsum: ttax.compile.TTEinsum) → Callable
Compile a TT-einsum into a function.
Example:
from ttax.compile import TTEinsum, to_function

def multiply(a, b):
    tt_einsum = TTEinsum(
        inputs=[['a', 'i', 'b'], ['c', 'i', 'd']],
        output=['ac', 'i', 'bd'],
        how_to_apply='independent')
    func = to_function(tt_einsum)
    return func(a, b)
- ttax.compile.unwrap_tt(arg)
Unwraps the argument if it is a WrappedTT; otherwise returns the argument unchanged.
- Parameters
arg (WrappedTT or TTTensOrMat) – argument to unwrap
- Return type
TTTensOrMat
- Returns
unwrapped argument
- ttax.random_.matrix(rng, shape, tt_rank=2, batch_shape=None, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)
Generate a random TT-Matrix of the given shape and TT-rank.
- Parameters
rng (JAX PRNG key) – random state, described by two unsigned 32-bit integers
shape (array) – desired tensor shape
tt_rank (int or array) – desired TT-ranks of the TT-Matrix: a single number for equal TT-ranks or an array specifying all TT-ranks
batch_shape (array) – desired batch shape of the TT-Matrix
dtype (dtype) – type of the elements of the TT-Matrix
- Returns
generated TT-Matrix
- Return type
TTMatrix
- ttax.random_.tensor(rng, shape, tt_rank=2, batch_shape=None, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)
Generate a random TT-Tensor of the given shape and TT-rank.
- Parameters
rng (JAX PRNG key) – random state, described by two unsigned 32-bit integers
shape (array) – desired tensor shape
tt_rank (int or array) – desired TT-ranks of the TT-Tensor: a single number for equal TT-ranks or an array specifying all TT-ranks
batch_shape (array) – desired batch shape of the TT-Tensor
dtype (dtype) – type of the elements of the TT-Tensor
- Returns
generated TT-Tensor
- Return type
TT
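The role of tt_rank and batch_shape can be sketched with NumPy's random generator standing in for a JAX PRNG key. This is an assumption-laden illustration, not the ttax implementation: the helper name is hypothetical, batch axes are assumed to come first, and ttax's actual initialization distribution and scaling are not reproduced.

```python
import numpy as np


def random_tt_tensor_sketch(seed, shape, tt_rank=2, batch_shape=()):
    """Hypothetical helper: standard-normal TT-cores for a (batch of) TT-Tensor(s)."""
    d = len(shape)
    if np.isscalar(tt_rank):
        ranks = [1] + [tt_rank] * (d - 1) + [1]  # equal inner TT-ranks
    else:
        ranks = list(tt_rank)                    # all d + 1 TT-ranks given
    rng = np.random.default_rng(seed)
    # Assumed layout: batch axes first, then (r_l, n, r_r).
    return [rng.standard_normal(tuple(batch_shape) + (ranks[k], n, ranks[k + 1]))
            for k, n in enumerate(shape)]


cores = random_tt_tensor_sketch(0, shape=(3, 4, 5), tt_rank=2)
print([c.shape for c in cores])  # [(1, 3, 2), (2, 4, 2), (2, 5, 1)]
batch = random_tt_tensor_sketch(0, shape=(3, 4, 5), tt_rank=[1, 3, 2, 1], batch_shape=(7,))
print([c.shape for c in batch])  # [(7, 1, 3, 3), (7, 3, 4, 2), (7, 2, 5, 1)]
```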
- class ttax.riemannian.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])¶
Bases:
ttax.base_class.TTBaseRepresents a TT-Tensor object as a list of TT-cores.
TT-Tensor cores take form (r_l, n, r_r), where
r_l, r_r are TT-ranks
n makes TT-Tensor shape
- property is_tt_matrix¶
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims¶
Get the number of batch dimensions for batch of TT-Tensors.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape¶
Get the tuple representing the shape of TT-Tensor. In batch case does not include the shape of the batch.
- Returns
TT-Tensor shape
- Return type
tuple
- replace(**updates)¶
“Returns a new object replacing the specified fields with new values.
- property shape¶
Get the tuple representing the shape of TT-Tensor. In batch case includes the shape of the batch.
- Returns
TT-Tensor shape with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]¶
- class ttax.riemannian.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])¶
Bases:
ttax.base_class.TTBase
Represents a TT-Matrix object as a list of TT-cores.
Each TT-core has shape (r_l, n_l, n_r, r_r), where
r_l, r_r are the TT-ranks, just as for a TT-Tensor
n_l, n_r are the row and column mode sizes, which make up the left (row) and right (column) shapes of the TT-Matrix
- property is_tt_matrix¶
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims¶
Get the number of batch dimensions for a batch of TT-Matrices.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape¶
Get the left (row) and right (column) shapes of the TT-Matrix. For a batch, the batch shape is not included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), returns (2, 6), (3, 7).
- Returns
TT-Matrix shapes
- Return type
list, list
- replace(**updates)¶
Return a new object with the specified fields replaced by new values.
- property shape¶
Get the shape of the underlying dense tensor viewed as a matrix, as a tuple. For a batch, the batch shape is included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21), since 2 * 6 = 12 and 3 * 7 = 21.
- Returns
TT-Matrix shape in dense form with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]¶
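The shape arithmetic above can be checked with a small NumPy sketch that builds the dense matrix from TT-Matrix cores of shape (r_l, n_l, n_r, r_r). The helpers are hypothetical, and the row ordering (row modes flattened in C order, as in a Kronecker product) is an assumed convention rather than a statement about ttax internals.

```python
import numpy as np

def ttm_raw_shape(cores):
    """Row and column mode sizes of a TT-Matrix with cores (r_l, n_l, n_r, r_r)."""
    return tuple(c.shape[1] for c in cores), tuple(c.shape[2] for c in cores)

def ttm_full(cores):
    """Dense matrix of shape (prod(n_l), prod(n_r)) from TT-Matrix cores."""
    rows, cols = ttm_raw_shape(cores)
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    # res now has axes (1, m_1, k_1, m_2, k_2, ..., m_d, k_d, 1).
    res = res.reshape(res.shape[1:-1])
    d = len(cores)
    # Gather all row axes first, then all column axes.
    res = res.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))
    return res.reshape(int(np.prod(rows)), int(np.prod(cols)))

# The example from the text: cores of shapes (1, 2, 3, 5) and (5, 6, 7, 1).
rng = np.random.default_rng(0)
cores = [rng.standard_normal((1, 2, 3, 5)), rng.standard_normal((5, 6, 7, 1))]
print(ttm_raw_shape(cores))   # ((2, 6), (3, 7))
print(ttm_full(cores).shape)  # (12, 21)
```

With this ordering, a rank-1 TT-Matrix built from factors A and B is exactly np.kron(A, B).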
- ttax.riemannian.deltas_to_tangent(deltas: List[jax._src.numpy.lax_numpy.ndarray], tt: Union[ttax.base_class.TT, ttax.base_class.TTMatrix]) → Union[ttax.base_class.TT, ttax.base_class.TTMatrix]¶
Convert the deltas representation of a tangent-space vector to a TT-object. Takes a list [dP1, …, dPd] and returns dP1 V2 … Vd + U1 dP2 V3 … Vd + … + U1 … Ud-1 dPd, where Uk are the left-orthogonal and Vk the right-orthogonal TT-cores of tt.
This function is hard to use correctly because the deltas must obey the so-called gauge conditions; if they do not, the function silently returns an incorrect result. For this reason it is not imported in __init__.
- Parameters
deltas – a list of deltas (essentially TT-cores) obeying the gauge conditions.
tt (TT-Tensor or TT-Matrix) – the point at which the tangent space is taken; the deltas represent an element of its tangent space.
- Returns
TT-object constructed from the deltas; it lies in the tangent space at tt.
- Return type
TT-Tensor or TT-Matrix
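The gauge conditions are commonly stated as follows. This is an assumption about the exact convention ttax uses. For every k < d, the left unfolding of dPk (reshaped to (r_{k-1} n_k, r_k)) must be orthogonal to the left unfolding of the left-orthogonal core Uk; dPd is unconstrained. The NumPy sketch below enforces the condition on a single core by subtracting the component of the delta that lies in the column space of Uk. The helper names are hypothetical.

```python
import numpy as np

def left_unfold(core):
    r_l, n, r_r = core.shape
    return core.reshape(r_l * n, r_r)

def enforce_gauge(delta, u_core):
    """Remove from delta its component in the column space of left_unfold(u_core).

    Assumes left_unfold(u_core) has orthonormal columns (a left-orthogonal core).
    """
    u = left_unfold(u_core)
    dlt = left_unfold(delta)
    dlt = dlt - u @ (u.T @ dlt)
    return dlt.reshape(delta.shape)

rng = np.random.default_rng(0)
r_l, n, r_r = 2, 4, 3
# Build a left-orthogonal core from a reduced QR decomposition.
q, _ = np.linalg.qr(rng.standard_normal((r_l * n, r_r)))
u_core = q.reshape(r_l, n, r_r)
delta = enforce_gauge(rng.standard_normal((r_l, n, r_r)), u_core)
print(np.allclose(left_unfold(u_core).T @ left_unfold(delta), 0))  # True
```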
- ttax.riemannian.orthogonalize(tt, left_to_right=True)¶
Orthogonalize the TT-cores of a TT-object.
- Parameters
tt (TT-Tensor or TT-Matrix) – TT-object whose TT-cores will be orthogonalized
left_to_right (bool) – the direction of orthogonalization: True for left to right, False for right to left
- Returns
TT-object with orthogonalized TT-cores
- Return type
TT-Tensor or TT-Matrix
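A common way to implement left-to-right orthogonalization is a sweep of QR decompositions. Each core's left unfolding is replaced by its Q factor, and the R factor is pushed into the next core, so the dense tensor does not change. The NumPy sketch below applies this to a plain list of non-batch cores. It illustrates the technique, not ttax's own code, which also handles batches and TT-Matrices. The right-to-left direction mirrors it, working from the last core with QR of the transposed right unfoldings.

```python
import numpy as np

def orthogonalize_left_to_right(cores):
    """Make cores 0..d-2 left-orthogonal with a QR sweep; the dense tensor is unchanged."""
    cores = [c.copy() for c in cores]
    for k in range(len(cores) - 1):
        r_l, n, r_r = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_l * n, r_r))
        new_rank = q.shape[1]  # min(r_l * n, r_r): redundant ranks shrink in this sketch
        cores[k] = q.reshape(r_l, n, new_rank)
        # Push the triangular factor into the next core.
        cores[k + 1] = np.tensordot(r, cores[k + 1], axes=([1], [0]))
    return cores

def tt_full(cores):
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res.reshape(res.shape[1:-1])

rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 3, 4), (4, 5, 4), (4, 3, 1)]]
ortho = orthogonalize_left_to_right(cores)
print(np.allclose(tt_full(ortho), tt_full(cores)))  # True
```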
- ttax.riemannian.project(what, where)¶
Project the TT-object(s) what onto the tangent space of the TT-object where.
project(what, x) = P_x(what)
project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))
- Complexity:
O(d r_where^3 n) for orthogonalizing the TT-cores of where
+ O(batch_size d r_what r_where n (r_what + r_where))
d is the number of TT-cores: what.ndim
r_what is the largest TT-rank of what: max(what.tt_ranks)
r_where is the largest TT-rank of where
n is the size of the axis dimension of what and where, e.g. for a tensor of size 4 x 4 x 4, n is 4; for a 27 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4), n is 12
- Parameters
what (TT-Tensor or TT-Matrix) – object to project; for a batch, returns a batch with the projection of each individual tensor
where (TT-Tensor or TT-Matrix) – the point onto whose tangent space to project
- Returns
TT-object with TT-ranks equal to 2 * where.tt_ranks (the first and last TT-ranks stay 1)
- Return type
TT-Tensor or TT-Matrix
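For d = 2 a TT-Tensor is a rank-r matrix X = U S V^T, and the tangent-space projection has the well-known closed form P_X(Z) = U U^T Z + Z V V^T - U U^T Z V V^T, where U and V hold orthonormal bases of the column and row spaces of X. The NumPy sketch below implements only this special case to illustrate P_x, its idempotence, and the doubling of ranks noted above. It is not the algorithm ttax uses on TT-cores.

```python
import numpy as np

def project_rank_r(z, u, v):
    """Orthogonal projection of z onto the tangent space at X = U S V^T.

    u: (m, r) and v: (n, r) must have orthonormal columns.
    """
    pu = u @ u.T
    pv = v @ v.T
    return pu @ z + z @ pv - pu @ z @ pv

rng = np.random.default_rng(0)
m, n, r = 6, 5, 2
x = rng.standard_normal((m, r)) @ rng.standard_normal((r, n))  # rank-r point
u_full, s, vt = np.linalg.svd(x)
u, v = u_full[:, :r], vt[:r].T
z = rng.standard_normal((m, n))
pz = project_rank_r(z, u, v)
print(np.linalg.matrix_rank(pz) <= 2 * r)         # True: rank at most 2r
print(np.allclose(project_rank_r(pz, u, v), pz))  # True: P is idempotent
```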
- ttax.riemannian.tangent_to_deltas(tangent_element: Union[ttax.base_class.TT, ttax.base_class.TTMatrix]) → List[jax._src.numpy.lax_numpy.ndarray]¶
Convert an element of the tangent space to the deltas representation. Tangent space elements (outputs of ttax.project) look like: dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd. This function takes an element of the tangent space and converts it to the list of deltas [dP1, ..., dPd].
- Parameters
tangent_element (TT-Tensor or TT-Matrix) – a result of ttax.project
- Returns
list of delta-cores
- Return type
list
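One way to see why a tangent element has doubled TT-ranks, and where the deltas sit inside it, is to write its cores as block matrices: G1 = [dP1, U1], Gk = [[Vk, 0], [dPk, Uk]] for middle cores, and Gd = [[Vd], [dPd]]. Multiplying them out gives exactly dP1 V2 … Vd + … + U1 … Ud-1 dPd. The NumPy sketch below checks this for d = 3 and reads the deltas back as sub-blocks. This block layout is an illustration of the formula; ttax's internal layout may differ. The U and V cores are random because the identity is purely algebraic, whereas in a real tangent element they would be the orthogonalized cores of the point.

```python
import numpy as np

def tt_full(cores):
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=([-1], [0]))
    return res.reshape(res.shape[1:-1])

def block_cores(us, vs, deltas):
    """Assemble rank-2r cores for d = 3 from U, V and delta cores (illustrative layout)."""
    r = us[0].shape[-1]
    g1 = np.concatenate([deltas[0], us[0]], axis=2)  # (1, n1, 2r)
    g2 = np.zeros((2 * r, us[1].shape[1], 2 * r))
    g2[:r, :, :r] = vs[1]
    g2[r:, :, :r] = deltas[1]
    g2[r:, :, r:] = us[1]                            # top-right block stays 0
    g3 = np.concatenate([vs[2], deltas[2]], axis=0)  # (2r, n3, 1)
    return [g1, g2, g3]

def blocks_to_deltas(cores, r):
    """Read the deltas back out as sub-blocks of the cores (same layout)."""
    g1, g2, g3 = cores
    return [g1[:, :, :r], g2[r:, :, :r], g3[r:, :, :]]

rng = np.random.default_rng(0)
r, n = 2, 3
shapes = [(1, n, r), (r, n, r), (r, n, 1)]
us = [rng.standard_normal(s) for s in shapes]  # us[2] is unused by the formula
vs = [rng.standard_normal(s) for s in shapes]  # vs[0] is unused by the formula
deltas = [rng.standard_normal(s) for s in shapes]
cores = block_cores(us, vs, deltas)
expected = (tt_full([deltas[0], vs[1], vs[2]])
            + tt_full([us[0], deltas[1], vs[2]])
            + tt_full([us[0], us[1], deltas[2]]))
print(np.allclose(tt_full(cores), expected))  # True
```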
- ttax.riemannian.tt_vmap(num_batch_args=None)¶
Decorator that makes a function support batches of TT-objects as inputs.
- Parameters
num_batch_args (int or None) –
The number of arguments that are batches of TT-objects.
If None, the function is mapped over all arguments.
If an integer, it specifies how many leading arguments to map over, e.g. num_batch_args=n means that the function is mapped over the first n arguments.
- Returns
Decorator
- Comments:
The function is vmapped num_batch_dims times, so multidimensional batches are supported. The number of batch dimensions to map over is given by the num_batch_dims property and must be the same for all arguments the function is mapped over. Arguments that should not be mapped over must be excluded with num_batch_args.
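The idea of mapping once per batch dimension can be sketched without ttax. The hypothetical helper map_batch below applies f over the leading num_batch_dims axes of its first num_batch_args arguments using plain Python loops, stacks the results, and passes the remaining arguments through unchanged. In JAX the same nesting is obtained by wrapping a function in jax.vmap num_batch_dims times. This sketch mirrors only the semantics, not tt_vmap's code.

```python
import numpy as np

def map_batch(f, num_batch_dims, num_batch_args=None):
    """Map f over the leading num_batch_dims axes of its first num_batch_args arguments."""
    def mapped(*args):
        k = len(args) if num_batch_args is None else num_batch_args
        batched, rest = args[:k], args[k:]
        if num_batch_dims == 0:
            return f(*args)
        # Peel off one batch axis, recurse on each slice, then stack the results.
        inner = map_batch(f, num_batch_dims - 1, k)
        outs = [inner(*[a[i] for a in batched], *rest)
                for i in range(batched[0].shape[0])]
        return np.stack(outs)
    return mapped

def scaled_dot(x, y, scale):
    return scale * np.dot(x, y)

xs = np.ones((2, 3, 4))  # batch shape (2, 3), vectors of length 4
ys = np.full((2, 3, 4), 2.0)
out = map_batch(scaled_dot, num_batch_dims=2, num_batch_args=2)(xs, ys, 0.5)
print(out.shape)  # (2, 3)
print(out[0, 0])  # 4.0
```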
- class ttax.utils.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])¶
Bases:
ttax.base_class.TTBase
Represents a TT-Tensor object as a list of TT-cores.
Each TT-core has shape (r_l, n, r_r), where
r_l, r_r are the left and right TT-ranks
n is the size of the corresponding tensor mode
- property is_tt_matrix¶
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims¶
Get the number of batch dimensions for a batch of TT-Tensors.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape¶
Get the shape of the TT-Tensor as a tuple. For a batch, the batch shape is not included.
- Returns
TT-Tensor shape
- Return type
tuple
- replace(**updates)¶
Return a new object with the specified fields replaced by new values.
- property shape¶
Get the shape of the TT-Tensor as a tuple. For a batch, the batch shape is included.
- Returns
TT-Tensor shape with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]¶
- class ttax.utils.TTBase¶
Bases:
object
Represents the base for both TT-Tensor and TT-Matrix (TT-object). Includes some basic routines and properties.
- property axis_dim¶
Get the position of the mode axis in each TT-core. It can differ depending on the batch shape.
- Returns
index
- Return type
int
- property batch_loc¶
Represents batch indexing for the TT-object. Wraps the TT-object in a special BatchIndexing class with an overloaded __getitem__ method.
Example
tt.batch_loc[1, :, :]
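Assuming that batch dimensions precede the rank and mode axes in every core (which would explain why axis_dim depends on the batch shape), batch indexing can be pictured as applying the same index to the leading axes of each core. The class CoresBatchIndexing below is a hypothetical NumPy sketch of that idea, not ttax's BatchIndexing. The relation axis_dim = num_batch_dims + 1 is likewise an assumption.

```python
import numpy as np

class CoresBatchIndexing:
    """Hypothetical sketch: index the batch axes of every TT-core at once."""

    def __init__(self, cores):
        self.cores = cores

    def __getitem__(self, index):
        if not isinstance(index, tuple):
            index = (index,)
        # Apply the same batch index to the leading axes of each core;
        # the trailing Ellipsis keeps the rank and mode axes whole.
        return [core[index + (Ellipsis,)] for core in self.cores]

# Batch shape (2, 3, 4): cores have shape (2, 3, 4, r_l, n, r_r).
batch_shape = (2, 3, 4)
cores = [np.zeros(batch_shape + s) for s in [(1, 5, 2), (2, 6, 1)]]
axis_dim = len(batch_shape) + 1  # mode axis comes right after the left rank axis
sub = CoresBatchIndexing(cores)[1, :, :]
print([c.shape for c in sub])  # [(3, 4, 1, 5, 2), (3, 4, 2, 6, 1)]
```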
- property batch_shape¶
Get the batch shape of the TT-object as a list.
- Returns
batch shape
- Return type
list
- property dtype¶
Represents the dtype of the elements of the TT-object.
- Returns
dtype of elements
- Return type
dtype
- property ndim¶
Get the number of dimensions of the TT-object.
- Returns
dimensions number
- Return type
int
- property tt_ranks¶
Get the TT-ranks of the TT-object; there are ndim + 1 of them. The first and the last TT-rank are equal to 1.
- Returns
TT-ranks
- Return type
list
- class ttax.utils.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])¶
Bases:
ttax.base_class.TTBase
Represents a TT-Matrix object as a list of TT-cores.
Each TT-core has shape (r_l, n_l, n_r, r_r), where
r_l, r_r are the TT-ranks, just as for a TT-Tensor
n_l, n_r are the row and column mode sizes, which make up the left (row) and right (column) shapes of the TT-Matrix
- property is_tt_matrix¶
Determine whether the object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property num_batch_dims¶
Get the number of batch dimensions for a batch of TT-Matrices.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape¶
Get the left (row) and right (column) shapes of the TT-Matrix. For a batch, the batch shape is not included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), returns (2, 6), (3, 7).
- Returns
TT-Matrix shapes
- Return type
list, list
- replace(**updates)¶
Return a new object with the specified fields replaced by new values.
- property shape¶
Get the shape of the underlying dense tensor viewed as a matrix, as a tuple. For a batch, the batch shape is included.
For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21), since 2 * 6 = 12 and 3 * 7 = 21.
- Returns
TT-Matrix shape in dense form with batch shape
- Return type
tuple
- tt_cores: List[jax._src.numpy.lax_numpy.array]¶
- class ttax.utils.WrappedTT(tt: ttax.base_class.TT, tt_inputs=None, tt_einsum=None)¶
Bases:
object
A class that wraps a TT-object; it is needed for fusion to work.
The base TT-object class may contain only jnp.array objects, so that it can be passed into jitted functions. To fuse two functions together, however, we need to track which operation created a TT-object, so while fusing ops we wrap TT-objects in this class to keep track of that.
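A minimal sketch of the wrapping idea follows. The plain TT-object holds only arrays, so it stays jit-friendly, while the wrapper records, outside of it, the inputs and the per-core einsum that produced it. Attribute lookups are forwarded to the wrapped object. The classes PlainTT and WrappedSketch and the helper hadamard are hypothetical, modeled loosely on the constructor signature WrappedTT(tt, tt_inputs=None, tt_einsum=None). The fusion step itself, which would combine recorded einsums, is not shown.

```python
from dataclasses import dataclass
from typing import List

import numpy as np

@dataclass
class PlainTT:
    """Stand-in for a base TT-object: it holds nothing but arrays."""
    tt_cores: List[np.ndarray]

    @property
    def ndim(self):
        return len(self.tt_cores)

class WrappedSketch:
    """Hypothetical wrapper recording which operation produced a TT-object."""

    def __init__(self, tt, tt_inputs=None, tt_einsum=None):
        self.tt = tt                # the plain, array-only TT-object
        self.tt_inputs = tt_inputs  # wrapped operands of the producing op
        self.tt_einsum = tt_einsum  # einsum spec applied to each pair of cores

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself:
        # forward them (tt_cores, ndim, ...) to the underlying TT-object.
        return getattr(self.tt, name)

def hadamard(a, b, spec="aib,cid->acibd"):
    """Elementwise product of two TT-Tensors that remembers its inputs and einsum."""
    cores = []
    for ca, cb in zip(a.tt_cores, b.tt_cores):
        out = np.einsum(spec, ca, cb)  # shape (ra_l, rb_l, n, ra_r, rb_r)
        cores.append(out.reshape(ca.shape[0] * cb.shape[0], ca.shape[1],
                                 ca.shape[2] * cb.shape[2]))
    return WrappedSketch(PlainTT(cores), tt_inputs=[a, b], tt_einsum=spec)

rng = np.random.default_rng(0)
a = WrappedSketch(PlainTT([rng.standard_normal((1, 4, 2)), rng.standard_normal((2, 4, 1))]))
b = WrappedSketch(PlainTT([rng.standard_normal((1, 4, 3)), rng.standard_normal((3, 4, 1))]))
c = hadamard(a, b)
print(c.ndim, [x.shape for x in c.tt_cores])  # 2 [(1, 4, 6), (6, 4, 1)]
```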
- property axis_dim¶
Get the position of the mode axis in the TT-cores of the underlying TT-object. It can differ depending on the batch shape.
- Returns
index
- Return type
int
- property batch_loc¶
Represents batch indexing for the underlying TT-object. Wraps the TT-object in a special BatchIndexing class with an overloaded __getitem__ method.
Example
tt.batch_loc[1, :, :]
- property batch_shape¶
Get the batch shape of the underlying TT-object as a list.
- Returns
batch shape
- Return type
list
- property dtype¶
Represents the dtype of the elements of the underlying TT-object.
- Returns
dtype of elements
- Return type
dtype
- property is_tt_matrix¶
Determine whether the underlying TT-object is a TT-Matrix.
- Returns
True if TT-Matrix, False if TT-Tensor
- Return type
bool
- property ndim¶
Get the number of dimensions of the underlying TT-object.
- Returns
dimensions number
- Return type
int
- property num_batch_dims¶
Get the number of batch dimensions of the underlying TT-object.
- Returns
number of batch dimensions
- Return type
int
- property raw_tensor_shape¶
Get the raw shape of the underlying TT-object, as given by its raw_tensor_shape property.
- Returns
shape
- Return type
list
- property shape¶
Get the tuple representing the shape of underlying TT-object. In batch case includes the shape of the batch.
- Returns
shape of the underlying TT-object with batch shape
- Return type
tuple
- property tt_cores¶
Get the list of TT-cores of underlying TT-object.
- Returns
TT-cores
- Return type
list
- property tt_ranks¶
Get the TT-ranks of the underlying TT-object; there are ndim + 1 of them. The first and the last TT-rank are equal to 1.
- Returns
TT-ranks
- Return type
list
- ttax.utils.is_tt_matrix(arg) → bool¶
Determine whether the object is a TT-Matrix or a WrappedTT wrapping a TT-Matrix.
- Returns
True if TT-Matrix or WrappedTT(TT-Matrix), False otherwise
- Return type
bool
- ttax.utils.is_tt_object(arg) → bool¶
Determine whether the object is a TT-Tensor, a TT-Matrix, or a WrappedTT wrapping one of them.
- Returns
True if TT-object, False otherwise
- Return type
bool
- ttax.utils.is_tt_tensor(arg) → bool¶
Determine whether the object is a TT-Tensor or a WrappedTT wrapping a TT-Tensor.
- Returns
True if TT-Tensor or WrappedTT(TT-Tensor), False otherwise
- Return type
bool