API

This page describes the functions and methods available in ttax.

Module contents

class ttax.base_class.BatchIndexing(tt)

Bases: object

class ttax.base_class.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Tensor object as a list of TT-cores.

Each TT-Tensor core has shape (r_l, n, r_r), where

  • r_l, r_r are the left and right TT-ranks of the core

  • n is the size of the corresponding mode of the TT-Tensor
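To make the core layout concrete, here is a minimal NumPy sketch (not ttax code) that contracts a list of cores of shape (r_l, n, r_r) into the dense tensor they represent. The helper name tt_to_dense and the chosen ranks and mode sizes are illustrative assumptions.

```python
import numpy as np


def tt_to_dense(cores):
    """Contract TT-cores of shape (r_l, n, r_r) into a dense array.

    Hypothetical helper for illustration; it is not part of ttax.
    """
    result = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Sum over the shared TT-rank: right rank of `result`, left rank of `core`.
        result = np.tensordot(result, core, axes=([-1], [0]))
    # The first and last TT-ranks equal 1, so drop those axes.
    return result[0, ..., 0]


rng = np.random.default_rng(0)
tt_ranks = [1, 2, 3, 1]   # ndim + 1 ranks, boundary ranks are 1
mode_sizes = [4, 5, 6]    # shape of the represented tensor
cores = [
    rng.standard_normal((tt_ranks[k], mode_sizes[k], tt_ranks[k + 1]))
    for k in range(len(mode_sizes))
]
dense = tt_to_dense(cores)
print(dense.shape)  # (4, 5, 6)
```

Each entry of the dense tensor is a product of matrices, one slice per core, which is why the storage cost grows with the TT-ranks rather than with the full tensor size.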

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Tensors.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the tuple representing the shape of the TT-Tensor. For a batch of TT-Tensors, the batch shape is not included.

Returns

TT-Tensor shape

Return type

tuple

replace(**updates)

Returns a new object replacing the specified fields with new values.

property shape

Get the tuple representing the shape of the TT-Tensor. For a batch of TT-Tensors, the batch shape is included.

Returns

TT-Tensor shape with batch shape

Return type

tuple

tt_cores: List[jax._src.numpy.lax_numpy.array]
class ttax.base_class.TTBase

Bases: object

Represents the base for both TT-Tensor and TT-Matrix (TT-object). Includes some basic routines and properties.

property axis_dim

Get the position of the mode axis in a TT-core. The position depends on the batch shape.

Returns

index

Return type

int

property batch_loc

Provides batch indexing for the TT-object. Wraps the TT-object in the special BatchIndexing class, which overloads the __getitem__ method.

Example

tt.batch_loc[1, :, :]
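As a rough illustration of what batch indexing means, the NumPy sketch below assumes that the batch dimensions are the leading axes of every TT-core (an assumption of this sketch, not something stated above) and applies the same index to each core. The helper batch_index is hypothetical and is not the BatchIndexing implementation.

```python
import numpy as np


def batch_index(cores, index):
    """Apply the same batch index to every core (illustrative helper, not ttax API)."""
    return [core[index] for core in cores]


batch_shape = (2, 3, 4)
core_shapes = [(1, 5, 2), (2, 6, 1)]  # (r_l, n, r_r) per core
# Assumption: batch axes come first, followed by the usual core axes.
cores = [np.zeros(batch_shape + shape) for shape in core_shapes]

# Same index as in tt.batch_loc[1, :, :]: pick element 1 along the first batch axis.
picked = batch_index(cores, (1, slice(None), slice(None)))
print([core.shape for core in picked])  # [(3, 4, 1, 5, 2), (3, 4, 2, 6, 1)]
```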

property batch_shape

Get the list representing the shape of the batch of TT-object.

Returns

batch shape

Return type

list

property dtype

Represents the dtype of elements in TT-object.

Returns

dtype of elements

Return type

dtype

property ndim

Get the number of dimensions of the TT-object.

Returns

number of dimensions

Return type

int

property tt_ranks

Get the TT-ranks of the TT-object, ndim + 1 values in total. The first and the last TT-ranks are equal to 1.

Returns

TT-ranks

Return type

list
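The relation between ndim and tt_ranks can be read off the core shapes directly. The following plain-Python sketch uses made-up core shapes without batch dimensions; it only illustrates the bookkeeping and does not call ttax.

```python
# Illustrative core shapes (r_l, n, r_r); no batch dimensions are assumed.
core_shapes = [(1, 4, 2), (2, 5, 3), (3, 6, 1)]

ndim = len(core_shapes)
# Left rank of the first core, followed by the right rank of every core.
tt_ranks = [core_shapes[0][0]] + [shape[-1] for shape in core_shapes]

print(ndim)      # 3
print(tt_ranks)  # [1, 2, 3, 1]
```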

class ttax.base_class.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Matrix object as a list of TT-cores.

Each TT-Matrix core has shape (r_l, n_l, n_r, r_r), where

  • r_l, r_r are TT-ranks, just as for a TT-Tensor

  • n_l, n_r are the left and right mode sizes, which form the row and column shapes of the TT-Matrix

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Matrices.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the lists representing the left and right shapes of the TT-Matrix. For a batch of TT-Matrices, the batch shape is not included.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), this returns (2, 6), (3, 7).

Returns

TT-Matrix shapes

Return type

list, list

replace(**updates)

Returns a new object replacing the specified fields with new values.

property shape

Get the tuple representing the shape of the underlying dense tensor viewed as a matrix. For a batch of TT-Matrices, the batch shape is included.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21).

Returns

TT-Matrix shape in dense form with batch shape

Return type

tuple
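The two examples above (raw shapes (2, 6), (3, 7) and dense shape (12, 21)) follow from the core shapes by simple arithmetic. A small standard-library sketch, with variable names chosen only for illustration:

```python
from math import prod

# Core shapes (r_l, n_l, n_r, r_r) from the example above.
core_shapes = [(1, 2, 3, 5), (5, 6, 7, 1)]

left_shape = tuple(shape[1] for shape in core_shapes)    # row mode sizes
right_shape = tuple(shape[2] for shape in core_shapes)   # column mode sizes
dense_shape = (prod(left_shape), prod(right_shape))      # shape as a matrix

print(left_shape, right_shape)  # (2, 6) (3, 7)
print(dense_shape)              # (12, 21)
```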

tt_cores: List[jax._src.numpy.lax_numpy.array]
class ttax.ops.TTEinsum(inputs, output, how_to_apply, order='left-to-right')

Bases: object

A class that holds an einsum rule; it is needed for fusion to work.

apply_mapping(mapping: Dict[str, str])

Rename letters according to the given mapping.

Returns

new TTEinsum with renamed letters

Return type

TTEinsum

change_input(input_idx: int, new_inputs: List)

Change argument input_idx into new_inputs.

E.g.

tt_einsum = TTEinsum(inputs=[['a', 'i', 'b'], ['c', 'i', 'd']], output=['ac', 'i', 'bd'], how_to_apply='independent')

tt_einsum.change_input(0, [['e', 'i', 'f'], ['g', 'i', 'h']])

print(tt_einsum.to_vanilla_einsum())

will print 'eif,gih,cid->acibd'
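The string in this example can be reproduced with a few lines of plain Python. The sketch below is not the TTEinsum implementation; join_einsum and replace_input are hypothetical helpers that only mimic the behaviour described here on lists of index strings.

```python
def join_einsum(inputs, output):
    """Join per-core index lists into a regular einsum string (illustrative helper)."""
    lhs = ",".join("".join(core) for core in inputs)
    return lhs + "->" + "".join(output)


def replace_input(inputs, input_idx, new_inputs):
    """Replace one input core by a list of new input cores (illustrative helper)."""
    return inputs[:input_idx] + new_inputs + inputs[input_idx + 1:]


inputs = [["a", "i", "b"], ["c", "i", "d"]]
output = ["ac", "i", "bd"]

before = join_einsum(inputs, output)
changed = replace_input(inputs, 0, [["e", "i", "f"], ["g", "i", "h"]])
after = join_einsum(changed, output)

print(before)  # aib,cid->acibd
print(after)   # eif,gih,cid->acibd
```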

resolve_i_or_ij(is_tt_matrix)

Return a version of TTEinsum with I_OR_IJ changed to either 'i' or 'ij'.

to_distinct_letters(distinct_from)

Rename letters to make them distinct from letters used in distinct_from.

to_vanilla_einsum()

Build regular einsum.

class ttax.ops.WrappedTT(tt: ttax.base_class.TT, tt_inputs=None, tt_einsum=None)

Bases: object

A class that wraps a TT-object; it is needed for fusion to work.

The base TT-object classes can only hold jnp.array objects, so that they can be passed into a jitted function. However, to fuse two functions together we need to track which operation created a TT-object, so during fusion TT-objects are wrapped with this class to keep track of that.

property axis_dim

Get the position of the mode axis in the underlying TT-core. The position depends on the batch shape.

Returns

index

Return type

int

property batch_loc

Provides batch indexing for the underlying TT-object. Wraps the TT-object in the special BatchIndexing class, which overloads the __getitem__ method.

Example

tt.batch_loc[1, :, :]

property batch_shape

Get the list representing the shape of the batch of underlying TT-object.

Returns

batch shape

Return type

list

property dtype

Represents the dtype of elements in underlying TT-object.

Returns

dtype of elements

Return type

dtype

property is_tt_matrix

Determine whether the underlying TT-object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property ndim

Get the number of dimensions of the underlying TT-object.

Returns

number of dimensions

Return type

int

property num_batch_dims

Get the number of batch dimensions for batch of underlying TT-object.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the shape of the underlying TT-object, as given by its raw_tensor_shape property.

Returns

shape

Return type

list

property shape

Get the tuple representing the shape of underlying TT-object. In batch case includes the shape of the batch.

Returns

TT-Tensor shape with batch shape

Return type

tuple

property tt_cores

Get the list of TT-cores of underlying TT-object.

Returns

TT-cores

Return type

list

property tt_ranks

Get the TT-ranks of the underlying TT-object, ndim + 1 values in total. The first and the last TT-ranks are equal to 1.

Returns

TT-ranks

Return type

list

ttax.ops.add(tt_a, tt_b)

Returns a TT-object corresponding to the elementwise sum tt_a + tt_b. The shapes of tt_a and tt_b must coincide. Broadcasting is supported, e.g. you can add a tensor train with batch size 7 and a tensor train with batch size 1:

tt_batch.add(tt_single.batch_loc[np.newaxis])

where tt_single.batch_loc[np.newaxis] creates a singleton batch dimension.

Parameters
  • tt_a (TT-Tensor or TT-Matrix) – first argument

  • tt_b (TT-Tensor or TT-Matrix) – second argument

Return type

TT-Tensor or TT-Matrix

Returns

tt_a + tt_b

Raises

ValueError – if the shapes of the arguments do not coincide
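For intuition, the sum of two tensor trains can be built by stacking cores block-wise, so the TT-ranks of the result are the sums of the input TT-ranks. The NumPy sketch below shows this standard construction for TT-Tensors without batch dimensions and with at least two cores; tt_add and tt_to_dense are hypothetical helpers, and ttax may implement add differently.

```python
import numpy as np


def tt_to_dense(cores):
    """Contract cores of shape (r_l, n, r_r) into a dense array (illustrative helper)."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_add(cores_a, cores_b):
    """Block-wise sum of two TT-Tensors with d >= 2 cores (sketch, not ttax code)."""
    d = len(cores_a)
    out = []
    for k, (a, b) in enumerate(zip(cores_a, cores_b)):
        if k == 0:
            core = np.concatenate([a, b], axis=2)   # row block [A B]
        elif k == d - 1:
            core = np.concatenate([a, b], axis=0)   # column block [A; B]
        else:
            ra_l, n, ra_r = a.shape
            rb_l, _, rb_r = b.shape
            # Block-diagonal core: A in the top-left block, B in the bottom-right.
            core = np.zeros((ra_l + rb_l, n, ra_r + rb_r), dtype=np.result_type(a, b))
            core[:ra_l, :, :ra_r] = a
            core[ra_l:, :, ra_r:] = b
        out.append(core)
    return out


def random_cores(rng, ranks, modes):
    return [rng.standard_normal((ranks[k], modes[k], ranks[k + 1])) for k in range(len(modes))]


rng = np.random.default_rng(0)
modes = [3, 4, 5]
cores_a = random_cores(rng, [1, 2, 3, 1], modes)
cores_b = random_cores(rng, [1, 4, 2, 1], modes)
cores_sum = tt_add(cores_a, cores_b)

print([core.shape for core in cores_sum])  # [(1, 3, 6), (6, 4, 5), (5, 5, 1)]
print(np.allclose(tt_to_dense(cores_sum), tt_to_dense(cores_a) + tt_to_dense(cores_b)))  # True
```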

ttax.ops.are_batches_broadcastable(tt_a, tt_b)

Checks whether the batches of two TT-objects are compatible: returns True if they are and False otherwise. For broadcasting to be possible, the batch sizes must either be equal or at least one of them must be 1.

Parameters
  • tt_a (TT-Tensor or TT-Matrix) – first argument to check

  • tt_b (TT-Tensor or TT-Matrix) – second argument to check

Returns

the result of broadcasting check

Return type

bool
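The rule described above can be sketched in plain Python on batch shapes alone. This is an illustration rather than the ttax implementation: the function name is hypothetical, and requiring both batch shapes to have the same number of dimensions is an assumption of the sketch.

```python
def batches_broadcastable(batch_shape_a, batch_shape_b):
    """Check the 'equal or one of them is 1' rule per batch dimension (sketch)."""
    # Assumption of this sketch: both batches have the same number of dimensions.
    if len(batch_shape_a) != len(batch_shape_b):
        return False
    return all(
        size_a == size_b or size_a == 1 or size_b == 1
        for size_a, size_b in zip(batch_shape_a, batch_shape_b)
    )


print(batches_broadcastable([7], [1]))  # True
print(batches_broadcastable([7], [7]))  # True
print(batches_broadcastable([7], [3]))  # False
```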

ttax.ops.are_shapes_equal(tt_a, tt_b)

Checks whether the shapes of two TT-objects are equal: returns True if they are and False otherwise. The arguments should be either both TT-Tensors or both TT-Matrices. They should have the same tensor shape, but their TT-ranks may differ.

Parameters
  • tt_a (TT-Tensor or TT-Matrix) – first argument to check

  • tt_b (TT-Tensor or TT-Matrix) – second argument to check

Returns

tensor_check - the result of shape check

Return type

bool

ttax.ops.flat_inner(a, b)

Calculate inner product of given TT-Tensors or TT-Matrices wrapped with WrappedTT.

Parameters
  • a (WrappedTT) – first argument

  • b (WrappedTT) – second argument

Returns

the result of inner product

Return type

WrappedTT

ttax.ops.full(tt: ttax.base_class.TTBase) → jax._src.numpy.lax_numpy.array

Converts a TT-Tensor or TT-Matrix into a dense format.

ttax.ops.full_tt_matrix(tt: ttax.base_class.TTMatrix) → jax._src.numpy.lax_numpy.array

Converts a TT-Matrix into a regular matrix.

ttax.ops.full_tt_tensor(tt: ttax.base_class.TT) → jax._src.numpy.lax_numpy.array

Converts a TT-Tensor into a regular tensor.

ttax.ops.is_tt_matrix(arg) → bool

Determine whether the object is a TT-Matrix or WrappedTT with underlying TT-Matrix.

Returns

True if TT-Matrix or WrappedTT(TT-Matrix), False otherwise

Return type

bool

ttax.ops.is_tt_object(arg) → bool

Determine whether the object is a TT-Tensor, TT-Matrix or WrappedTT with one of them.

Returns

True if TT-object, False otherwise

Return type

bool

ttax.ops.is_tt_tensor(arg) → bool

Determine whether the object is a TT-Tensor or WrappedTT with underlying TT-Tensor.

Returns

True if TT-Tensor or WrappedTT(TT-Tensor), False otherwise

Return type

bool

ttax.ops.matmul(a, b)

Calculate matrix multiplication of given TT-Matrices wrapped with WrappedTT.

Parameters
  • a (WrappedTT) – first argument

  • b (WrappedTT) – second argument

Returns

the result of matrix multiplication

Return type

WrappedTT

ttax.ops.multiply(a, b)

Calculate the elementwise product of two TT-Tensors or two TT-Matrices, or the product of a TT-object and a scalar. The arguments may or may not be wrapped in WrappedTT.

Parameters
  • a (Union[float, TT-object]) – first argument

  • b (Union[float, TT-object]) – second argument

Returns

the result of elementwise product

Return type

TT-object
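The elementwise product of two TT-Tensors can be formed core by core: each pair of cores is combined with the einsum rule 'aib,cid->acibd' and the rank indices are merged, so the TT-ranks multiply. The NumPy sketch below illustrates this for TT-Tensors without batch dimensions; tt_multiply and tt_to_dense are hypothetical helpers, not the ttax implementation.

```python
import numpy as np


def tt_to_dense(cores):
    """Contract cores of shape (r_l, n, r_r) into a dense array (illustrative helper)."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_multiply(cores_a, cores_b):
    """Elementwise product of two TT-Tensors, computed core by core (sketch)."""
    out = []
    for a, b in zip(cores_a, cores_b):
        ra_l, n, ra_r = a.shape
        rb_l, _, rb_r = b.shape
        # Pair up the rank indices of both cores for every value of the mode index i.
        core = np.einsum("aib,cid->acibd", a, b)
        # Merge (a, c) into the new left rank and (b, d) into the new right rank.
        out.append(core.reshape(ra_l * rb_l, n, ra_r * rb_r))
    return out


def random_cores(rng, ranks, modes):
    return [rng.standard_normal((ranks[k], modes[k], ranks[k + 1])) for k in range(len(modes))]


rng = np.random.default_rng(0)
modes = [3, 4, 5]
cores_a = random_cores(rng, [1, 2, 3, 1], modes)
cores_b = random_cores(rng, [1, 4, 2, 1], modes)
cores_prod = tt_multiply(cores_a, cores_b)

print([core.shape for core in cores_prod])  # [(1, 3, 8), (8, 4, 6), (6, 5, 1)]
print(np.allclose(tt_to_dense(cores_prod), tt_to_dense(cores_a) * tt_to_dense(cores_b)))  # True
```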

ttax.ops.multiply_by_scalar(a, b)

Returns the result of multiplying a TT-object (TTTensOrMat or WrappedTT) by a scalar. Takes two arguments, one of which is a TT-object and the other a scalar. The result does not depend on the order of the arguments.

Returns

the result of multiplication by scalar

Return type

TTTensOrMat

ttax.ops.to_function(tt_einsum: ttax.compile.TTEinsum) → Callable

Compile TT-einsum into a function.

Example:

def multiply(a, b):
    tt_einsum = TTEinsum(
        inputs=[['a', 'i', 'b'], ['c', 'i', 'd']],
        output=['ac', 'i', 'bd'],
        how_to_apply='independent'
    )
    func = to_function(tt_einsum)
    return func(a, b)

ttax.ops.tt_tt_multiply(a, b)

ttax.ops.tt_vmap(num_batch_args=None)

Decorator which makes a function support batch TT-inputs.

Parameters

num_batch_args (int or None) –

The number of arguments that are batches of TT-objects.

  • If None, then the function will be mapped over all arguments.

  • If an integer, specifies how many of the first arguments to map over, e.g. num_batch_args=n means that the function will be mapped over the first n arguments.

Returns

Decorator

Comments:

The function is vmapped num_batch_dims times, so multidimensional batches are supported. The number of batch dimensions to map over is given by the num_batch_dims property and must be the same for all arguments the function is mapped over. Otherwise, the arguments to map over should be specified via num_batch_args.

ttax.ops.unwrap_tt(arg)

Unwraps argument if it is of WrappedTT class, otherwise just returns the argument.

Parameters

arg (WrappedTT or TTTensOrMat) – argument to unwrap

Return type

TTTensOrMat

Returns

unwrapped argument

ttax.decompositions.orthogonalize(tt, left_to_right=True)

Orthogonalize TT-cores of a TT-object.

Parameters
  • tt (TT-Tensor or TT-Matrix) – TT-object whose TT-cores will be orthogonalized

  • left_to_right (bool) – the direction of orthogonalization, True for left to right and False for right to left

Returns

TT-object with orthogonalized TT-cores

Return type

TT-Tensor or TT-Matrix

ttax.decompositions.round(tt, max_tt_rank=None, epsilon=None)

Tensor Train rounding procedure, returns a TT-object with smaller TT-ranks.

Parameters
  • tt (TT-Tensor or TT-Matrix) – TT-object whose TT-ranks will be reduced

  • max_tt_rank (int or list of ints) –

    • If a number, then it defines the maximal TT-rank of the result.

    • If a list of numbers, then its length should be d+1 (where d is the number of dimensions) and max_tt_rank[i] defines the maximal (i+1)-th TT-rank of the result.

      The following two versions are equivalent

      • max_tt_rank = r

      • max_tt_rank = [1] + [r] * (d-1) + [1]

  • epsilon (float or None) –

    • If the TT-ranks are not restricted (max_tt_rank=None), then the result is guaranteed to be epsilon-close to tt in terms of relative Frobenius error:

      ||res - tt||_F / ||tt||_F <= epsilon

    • If the TT-ranks are restricted, providing a loose epsilon may reduce the TT-ranks of the result. E.g.

      round(tt, max_tt_rank=100, epsilon=0.9)

      will probably return a TT-Tensor with TT-ranks close to 1, not 100. Note that providing a nontrivial (i.e. not None) epsilon makes the TT-ranks of the result depend on the data, which prevents you from using jax.jit to speed up the computations.

Returns

TT-object with reduced TT-ranks

Return type

TT-Tensor or TT-Matrix

Raises

ValueError – if max_tt_rank is less than 0; if max_tt_rank is neither a number nor a list of length d + 1, where d is the number of dimensions of the input tensor; or if epsilon is less than 0.
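The equivalence between a scalar max_tt_rank and its list form, together with the documented error conditions on max_tt_rank, can be sketched in plain Python. The function expand_max_tt_rank is a hypothetical helper written for this illustration; it is not part of ttax and does not perform any rounding.

```python
def expand_max_tt_rank(max_tt_rank, d):
    """Turn max_tt_rank into a list of d + 1 rank bounds (illustrative helper)."""
    if isinstance(max_tt_rank, int):
        # A single number r is shorthand for [1] + [r] * (d - 1) + [1].
        ranks = [1] + [max_tt_rank] * (d - 1) + [1]
    else:
        ranks = list(max_tt_rank)
        if len(ranks) != d + 1:
            raise ValueError("max_tt_rank must be a number or a list of length d + 1")
    if any(rank < 0 for rank in ranks):
        raise ValueError("max_tt_rank must not be less than 0")
    return ranks


print(expand_max_tt_rank(3, d=4))                # [1, 3, 3, 3, 1]
print(expand_max_tt_rank([1, 3, 3, 3, 1], d=4))  # [1, 3, 3, 3, 1]
```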

Utils for compiling functions defined as einsum strings.

Here we use the notion of tt_einsum, which is similar to einsum strings, but with more structure.

The basic element of a tt_einsum is a tt_einsum_core: a list of three elements that defines an einsum string for a single TT-core. The first element holds the indices of the left TT-rank, the second the indices of the main dimensions of the resulting TT-core, and the last the indices of the right TT-rank.

A tt_einsum consists of a list of input cores and an output core, each defined as a tt_einsum_core.

ttax.compile.apply_single_mapping(strings, mapping)

Apply letter mapping to a list of strings.

ttax.compile.compile_cumulative(tt_einsum: ttax.compile.TTEinsum) → Callable

ttax.compile.compile_independent(tt_einsum: ttax.compile.TTEinsum) → Callable

ttax.compile.fuse(func)

Fuse a composite function to make it faster.

Example:

Consider f(a, b, c) = <a * b, c> = sum_{i_1, …, i_d} a[i_1, …, i_d] b[i_1, …, i_d] c[i_1, …, i_d], which is what ttax.flat_inner(a * b, c) computes.

This implementation of f can be suboptimal for some inputs. For example, if a and b have large TT-ranks and c has a low TT-rank, implementing the same operation as

ttax.flat_inner(a * c, b)

would be much more efficient.

fuse automates such optimizations. You can build an optimal implementation of this function for any inputs by doing

faster_f = ttax.fuse(f)

Finally, don't forget that to get good speed in JAX you need to wrap your highest-level function in jit, e.g.

faster_f = jax.jit(faster_f)

Now, by using faster_f(a, b, c) instead of f(a, b, c) you can achieve a much faster cumulative time for any inputs.
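The reordering that this example relies on is an algebraic identity: <a * b, c> and <a * c, b> are the same sum of triple products. The NumPy sketch below checks this on small dense arrays; it uses a dense stand-in for the inner product and does not call ttax, so it says nothing about timings.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, c = (rng.standard_normal((3, 4, 5)) for _ in range(3))


def flat_inner_dense(x, y):
    """Dense stand-in for a flat inner product: sum of elementwise products."""
    return np.sum(x * y)


f_original = flat_inner_dense(a * b, c)   # <a * b, c>
f_reordered = flat_inner_dense(a * c, b)  # <a * c, b>

print(np.isclose(f_original, f_reordered))  # True
```

In TT format the two orderings differ in cost because the elementwise product multiplies TT-ranks (the output core ['ac', 'i', 'bd'] carries both rank indices), so multiplying by the low-rank argument first keeps the intermediate result smaller.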

ttax.random_.matrix(rng, shape, tt_rank=2, batch_shape=None, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)

Generate a random TT-Matrix of the given shape and TT-rank.

Parameters
  • rng (random state is described by two unsigned 32-bit integers) – JAX PRNG key

  • shape (array) – desired tensor shape

  • tt_rank (single number for equal TT-ranks or array specifying all TT-ranks) – desired TT-ranks of TT-Matrix

  • batch_shape (array) – desired batch shape of TT-Matrix

  • dtype (dtype) – type of elements in TT-Matrix

Returns

generated TT-Matrix

Return type

TTMatrix

ttax.random_.tensor(rng, shape, tt_rank=2, batch_shape=None, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)

Generate a random TT-Tensor of the given shape and TT-rank.

Parameters
  • rng (random state is described by two unsigned 32-bit integers) – JAX PRNG key

  • shape (array) – desired tensor shape

  • tt_rank (single number for equal TT-ranks or array specifying all TT-ranks) – desired TT-ranks of TT-Tensor

  • batch_shape (array) – desired batch shape of TT-Tensor

  • dtype (dtype) – type of elements in TT-Tensor

Returns

generated TT-Tensor

Return type

TT

class ttax.riemannian.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Tensor object as a list of TT-cores.

TT-Tensor cores take form (r_l, n, r_r), where

  • r_l, r_r are TT-ranks

  • n makes TT-Tensor shape

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Tensors.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the tuple representing the shape of TT-Tensor. In batch case does not include the shape of the batch.

Returns

TT-Tensor shape

Return type

tuple

replace(**updates)

“Returns a new object replacing the specified fields with new values.

property shape

Get the tuple representing the shape of TT-Tensor. In batch case includes the shape of the batch.

Returns

TT-Tensor shape with batch shape

Return type

tuple

tt_cores: List[jax._src.numpy.lax_numpy.array]
class ttax.riemannian.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Matrix object as a list of TT-cores.

TT-Matrix cores take form (r_l, n_l, n_r, r_r), where

  • r_l, r_r are TT-ranks just as for TT-Tensor

  • n_l, n_r make left and right shapes of TT-Matrix as rows and cols

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Matrices.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the lists representing left and right shapes of TT-Matrix. In batch case does not include the shape of the batch.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), it returns (2, 6), (3, 7).

Returns

TT-Matrix shapes

Return type

list, list

replace(**updates)

Returns a new object, replacing the specified fields with new values.

property shape

Get the tuple representing the shape of underlying dense tensor as matrix. In batch case includes the shape of the batch.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21).

Returns

TT-Matrix shape in dense form with batch shape

Return type

tuple
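The relation between raw_tensor_shape and shape for the example cores (1, 2, 3, 5) and (5, 6, 7, 1) can be reproduced with plain shape arithmetic. This is a sketch that works on core shapes only and does not call ttax; the variable names are chosen for illustration.

```python
from math import prod

# Shapes of the two TT-Matrix cores from the example: (r_l, n_l, n_r, r_r).
core_shapes = [(1, 2, 3, 5), (5, 6, 7, 1)]

left = tuple(s[1] for s in core_shapes)    # row modes n_l of every core
right = tuple(s[2] for s in core_shapes)   # column modes n_r of every core

raw_shape = (left, right)                  # left and right shapes
dense_shape = (prod(left), prod(right))    # rows x cols of the dense matrix
```

The dense matrix has 2 * 6 rows and 3 * 7 columns, which is where (12, 21) comes from.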

tt_cores: List[jax._src.numpy.lax_numpy.array]
ttax.riemannian.deltas_to_tangent(deltas: List[jax._src.numpy.lax_numpy.ndarray], tt: Union[ttax.base_class.TT, ttax.base_class.TTMatrix]) → Union[ttax.base_class.TT, ttax.base_class.TTMatrix]

Converts deltas representation of tangent space vector to TT-object. Takes as input a list of [dP1, …, dPd] and returns dP1 V2 … Vd + U1 dP2 V3 … Vd + … + U1 … Ud-1 dPd.

This function is hard to use correctly because deltas must obey the so-called gauge conditions. If they do not, the function silently returns an incorrect result. That is why this function is not imported in __init__.

Parameters
  • deltas – a list of deltas (essentially TT-cores) obeying the gauge conditions.

  • tt (TT-Tensor or TT-Matrix) – object that defines the tangent space containing the tensor represented by deltas.

Returns

TT-object constructed from deltas, that is, an element of the tangent space at the point tt.

Return type

TT-Tensor or TT-Matrix
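The gauge conditions are not spelled out here. One common form in the Riemannian TT literature, assumed in this sketch, requires that for every core except the last, the left unfolding of dP_k is orthogonal to the left unfolding of the left-orthogonal core U_k. The helpers satisfies_gauge and enforce_gauge are hypothetical names, written in NumPy, and are not part of ttax.

```python
import numpy as np

def satisfies_gauge(delta, u_core, atol=1e-8):
    """Check U_k^T dP_k = 0 on the left unfoldings (assumed gauge condition)."""
    U = u_core.reshape(-1, u_core.shape[-1])
    D = delta.reshape(-1, delta.shape[-1])
    return np.allclose(U.T @ D, 0.0, atol=atol)

def enforce_gauge(delta, u_core):
    """Remove the component of delta lying in the column space of U_k."""
    U = u_core.reshape(-1, u_core.shape[-1])
    D = delta.reshape(-1, delta.shape[-1])
    return (D - U @ (U.T @ D)).reshape(delta.shape)

rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((8, 3)))
u_core = q.reshape(2, 4, 3)                  # left-orthogonal core U_k
raw_delta = rng.standard_normal((2, 4, 3))   # arbitrary, violates the gauge
delta = enforce_gauge(raw_delta, u_core)
```

A check like satisfies_gauge is a cheap guard before assembling a tangent vector, since a violation would otherwise go unnoticed.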

ttax.riemannian.orthogonalize(tt, left_to_right=True)

Orthogonalize TT-cores of a TT-object.

Parameters
  • tt (TT-Tensor or TT-Matrix) – TT-object whose TT-cores will be orthogonalized

  • left_to_right (bool) – the direction of orthogonalization, True for left to right and False for right to left

Returns

TT-object with orthogonalized TT-cores

Return type

TT-Tensor or TT-Matrix
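A minimal NumPy sketch of left-to-right orthogonalization for an unbatched TT-Tensor is given below. It uses the standard QR sweep and is not necessarily how ttax implements it; the function names orthogonalize_left_to_right and to_dense are assumptions made for illustration.

```python
import numpy as np

def orthogonalize_left_to_right(cores):
    """Sketch: make every core except the last left-orthogonal via QR."""
    cores = [np.asarray(c) for c in cores]
    for k in range(len(cores) - 1):
        r_l, n, r_r = cores[k].shape
        # QR of the left unfolding, a matrix of shape (r_l * n, r_r).
        q, r = np.linalg.qr(cores[k].reshape(r_l * n, r_r))
        cores[k] = q.reshape(r_l, n, q.shape[1])
        # Push the non-orthogonal factor into the next core.
        cores[k + 1] = np.einsum('ab,bnc->anc', r, cores[k + 1])
    return cores

def to_dense(cores):
    """Contract TT-cores into the full tensor."""
    res = cores[0]
    for c in cores[1:]:
        res = np.tensordot(res, c, axes=([-1], [0]))
    return res[0, ..., 0]

rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 4, 3), (3, 5, 3), (3, 6, 1)]]
ortho = orthogonalize_left_to_right(cores)
```

The sweep changes the cores but not the tensor they represent, which can be checked by comparing to_dense(cores) with to_dense(ortho).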

ttax.riemannian.project(what, where)

Project the TT-object(s) what onto the tangent space of the TT-object where.

project(what, x) = P_x(what)

project(batch_what, x) = batch(P_x(batch_what[0]), ..., P_x(batch_what[N]))

Complexity:

O(d r_where^3 m) for orthogonalizing the TT-cores of where + O(batch_size d r_what r_where n (r_what + r_where))

  • d is the number of TT-cores: what.ndim

  • r_what is the largest TT-rank of what: max(what.tt_ranks)

  • r_where is the largest TT-rank of where

  • n is the size of the axis dimension of what and where, e.g. for a tensor of size 4 x 4 x 4, n is 4; for a 27 x 64 matrix of raw shape (3, 3, 3) x (4, 4, 4), n is 12

Parameters
  • what (TT-Tensor or TT-Matrix) – object to project; for a batch, the result is a batch with the projection of each individual tensor

  • where (TT-Tensor or TT-Matrix) – point whose tangent space to project onto

Returns

TT-object with TT-ranks equal to 2 * where.tt_ranks (the first and the last TT-ranks stay equal to 1)

Return type

TT-Tensor or TT-Matrix

ttax.riemannian.tangent_to_deltas(tangent_element: Union[ttax.base_class.TT, ttax.base_class.TTMatrix]) → List[jax._src.numpy.lax_numpy.ndarray]

Convert an element of the tangent space to deltas representation. Tangent space elements (outputs of ttax.project) look like: dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd. This function takes as input an element of the tangent space and converts it to the list of deltas: [dP1, ..., dPd].

Parameters

tangent_element (TT-Tensor or TT-Matrix) – a result of ttax.project

Returns

list of delta-cores

Return type

list
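The sum dP1 V2 ... Vd + U1 dP2 V3 ... Vd + ... + U1 ... Ud-1 dPd can be stored as a single TT-object with doubled inner ranks by stacking blocks inside each core. The NumPy sketch below assumes the block layout commonly used in the Riemannian TT literature (first core [dP1, U1], middle cores [[Vk, 0], [dPk, Uk]], last core [[Vd], [dPd]]) and at least two cores; the layout ttax uses internally may differ, and both function names are hypothetical.

```python
import numpy as np

def deltas_to_tangent_cores(deltas, left, right):
    """Assemble block cores; left[k] stands for U_k, right[k] for V_k."""
    d = len(deltas)
    cores = []
    for k in range(d):
        r_l, n, r_r = deltas[k].shape
        if k == 0:
            core = np.concatenate([deltas[k], left[k]], axis=2)
        elif k == d - 1:
            core = np.concatenate([right[k], deltas[k]], axis=0)
        else:
            top = np.concatenate([right[k], np.zeros((r_l, n, r_r))], axis=2)
            bottom = np.concatenate([deltas[k], left[k]], axis=2)
            core = np.concatenate([top, bottom], axis=0)
        cores.append(core)
    return cores

def tangent_cores_to_deltas(cores):
    """Slice the dP_k block back out of every block core."""
    d = len(cores)
    deltas = []
    for k, core in enumerate(cores):
        r_l = 0 if k == 0 else core.shape[0] // 2
        r_r = core.shape[2] if k == d - 1 else core.shape[2] // 2
        deltas.append(core[r_l:, :, :r_r])
    return deltas

rng = np.random.default_rng(0)
shapes = [(1, 4, 2), (2, 5, 3), (3, 6, 1)]
left = [rng.standard_normal(s) for s in shapes]    # stand-ins for U_k
right = [rng.standard_normal(s) for s in shapes]   # stand-ins for V_k
deltas = [rng.standard_normal(s) for s in shapes]

tangent = deltas_to_tangent_cores(deltas, left, right)
recovered = tangent_cores_to_deltas(tangent)
```

Under this layout, extracting the deltas is pure slicing, and the inner ranks of the assembled object are exactly twice the inner ranks of the deltas.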

ttax.riemannian.tt_vmap(num_batch_args=None)

Decorator which makes a function support batch TT-inputs.

Parameters

num_batch_args (int or None) –

The number of arguments that are batches of TT-objects.

  • If None, then the function will be mapped over all arguments.

  • If an integer, specifies how many of the leading arguments to map over, e.g. num_batch_args=n means that the function will be mapped over the first n arguments.

Returns

Decorator

Comments:

The function is vmapped num_batch_dims times, so multidimensional batches are supported. The number of batch dimensions to map over is given by the num_batch_dims property and must be the same for all arguments that the function is mapped over. Otherwise, the arguments to map over should be restricted with num_batch_args.
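The meaning of num_batch_args can be illustrated with a toy decorator. The sketch below is a loop-based stand-in written with NumPy: it handles a single batch axis, works on plain arrays rather than TT-objects, and does not use jax.vmap. The name batch_map is hypothetical.

```python
import numpy as np

def batch_map(num_batch_args=None):
    """Hypothetical stand-in for a vmap-style decorator (one batch axis)."""
    def decorator(fn):
        def wrapped(*args):
            # None maps over all arguments, an integer over the first n.
            n = len(args) if num_batch_args is None else num_batch_args
            batched, shared = args[:n], args[n:]
            batch_size = len(batched[0])
            # Apply fn to the i-th slice of every batched argument.
            outs = [fn(*(a[i] for a in batched), *shared)
                    for i in range(batch_size)]
            return np.stack(outs)
        return wrapped
    return decorator

@batch_map(num_batch_args=1)
def scaled_sum(x, scale):
    return scale * x.sum()

# x is a batch of three vectors; scale is shared by all batch elements.
result = scaled_sum(np.arange(6.0).reshape(3, 2), 10.0)
```

Only the first argument is sliced along the batch axis here; the second is passed unchanged to every call.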

class ttax.utils.TT(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Tensor object as a list of TT-cores.

TT-Tensor cores take form (r_l, n, r_r), where

  • r_l, r_r are TT-ranks

  • n makes TT-Tensor shape

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Tensors.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the tuple representing the shape of TT-Tensor. In batch case does not include the shape of the batch.

Returns

TT-Tensor shape

Return type

tuple

replace(**updates)

Returns a new object, replacing the specified fields with new values.

property shape

Get the tuple representing the shape of TT-Tensor. In batch case includes the shape of the batch.

Returns

TT-Tensor shape with batch shape

Return type

tuple

tt_cores: List[jax._src.numpy.lax_numpy.array]
class ttax.utils.TTBase

Bases: object

Represents the base for both TT-Tensor and TT-Matrix (TT-object). Includes some basic routines and properties.

property axis_dim

Get the position of mode axis in TT-core. It could differ according to the batch shape.

Returns

index

Return type

int

property batch_loc

Represents the batch indexing for TT-object. Wraps TT-object by special BatchIndexing class with overloaded __getitem__ method.

Example

tt.batch_loc[1, :, :]

property batch_shape

Get the list representing the shape of the batch of TT-object.

Returns

batch shape

Return type

list

property dtype

Represents the dtype of elements in TT-object.

Returns

dtype of elements

Return type

dtype

property ndim

Get the number of dimensions of the TT-object.

Returns

dimensions number

Return type

int

property tt_ranks

Get the ndim + 1 TT-ranks of the TT-object. The first and the last TT-ranks are equal to 1.

Returns

TT-ranks

Return type

list
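How ndim and tt_ranks relate to the cores can be shown from core shapes alone. This sketch assumes unbatched TT-Tensor cores in the (r_l, n, r_r) layout and does not call ttax.

```python
# Shapes of three TT-Tensor cores: (r_l, n, r_r).
core_shapes = [(1, 4, 3), (3, 5, 2), (2, 6, 1)]

ndim = len(core_shapes)
# Left rank of every core, followed by the right rank of the last core.
tt_ranks = [s[0] for s in core_shapes] + [core_shapes[-1][-1]]
```

The list has ndim + 1 entries because d cores have d - 1 shared inner ranks plus the two boundary ranks.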

class ttax.utils.TTMatrix(tt_cores: List[jax._src.numpy.lax_numpy.array])

Bases: ttax.base_class.TTBase

Represents a TT-Matrix object as a list of TT-cores.

TT-Matrix cores take form (r_l, n_l, n_r, r_r), where

  • r_l, r_r are TT-ranks just as for TT-Tensor

  • n_l, n_r make left and right shapes of TT-Matrix as rows and cols

property is_tt_matrix

Determine whether the object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property num_batch_dims

Get the number of batch dimensions for batch of TT-Matrices.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the lists representing left and right shapes of TT-Matrix. In batch case does not include the shape of the batch.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), it returns (2, 6), (3, 7).

Returns

TT-Matrix shapes

Return type

list, list

replace(**updates)

Returns a new object, replacing the specified fields with new values.

property shape

Get the tuple representing the shape of underlying dense tensor as matrix. In batch case includes the shape of the batch.

For example, if the TT-Matrix cores have shapes (1, 2, 3, 5) and (5, 6, 7, 1), its shape is (12, 21).

Returns

TT-Matrix shape in dense form with batch shape

Return type

tuple

tt_cores: List[jax._src.numpy.lax_numpy.array]
class ttax.utils.WrappedTT(tt: ttax.base_class.TT, tt_inputs=None, tt_einsum=None)

Bases: object

A class that wraps a TT-object; it is needed for fusion to work.

The base TT-object class can hold only jnp.array objects so that it can be passed into a jitted function. However, to fuse two functions together we need to track which operation created a TT-object, so while fusing ops we wrap TT-objects with this class to track that.

property axis_dim

Get the position of mode axis in underlying TT-core. It could differ according to the batch shape.

Returns

index

Return type

int

property batch_loc

Represents the batch indexing for underlying TT-object. Wraps TT-object by special BatchIndexing class with overloaded __getitem__ method.

Example

tt.batch_loc[1, :, :]

property batch_shape

Get the list representing the shape of the batch of underlying TT-object.

Returns

batch shape

Return type

list

property dtype

Represents the dtype of elements in underlying TT-object.

Returns

dtype of elements

Return type

dtype

property is_tt_matrix

Determine whether the underlying TT-object is a TT-Matrix.

Returns

True if TT-Matrix, False if TT-Tensor

Return type

bool

property ndim

Get the number of dimensions of the underlying TT-object.

Returns

dimensions number

Return type

int

property num_batch_dims

Get the number of batch dimensions for batch of underlying TT-object.

Returns

number of batch dimensions

Return type

int

property raw_tensor_shape

Get the shape of the underlying TT-object, as given by its raw_tensor_shape.

Returns

shape

Return type

list

property shape

Get the tuple representing the shape of underlying TT-object. In batch case includes the shape of the batch.

Returns

TT-Tensor shape with batch shape

Return type

tuple

property tt_cores

Get the list of TT-cores of underlying TT-object.

Returns

TT-cores

Return type

list

property tt_ranks

Get the ndim + 1 TT-ranks of the underlying TT-object. The first and the last TT-ranks are equal to 1.

Returns

TT-ranks

Return type

list

ttax.utils.is_tt_matrix(arg) → bool

Determine whether the object is a TT-Matrix or WrappedTT with underlying TT-Matrix.

Returns

True if TT-Matrix or WrappedTT(TT-Matrix), False otherwise

Return type

bool

ttax.utils.is_tt_object(arg) → bool

Determine whether the object is a TT-Tensor, TT-Matrix or WrappedTT with one of them.

Returns

True if TT-object, False otherwise

Return type

bool

ttax.utils.is_tt_tensor(arg) → bool

Determine whether the object is a TT-Tensor or WrappedTT with underlying TT-Tensor.

Returns

True if TT-Tensor or WrappedTT(TT-Tensor), False otherwise

Return type

bool