This week, I’ve been digging into the details of how best to carry out co-iteration when some of the levels are unordered. The previous thinking was that, for a conjunctive merge, we could make the following changes:
1. Advancing iterators: advance only the ordered levels' iterators, since the unordered iterators can instead do `locate(PKM1, min_ik)`.
2. Dereferencing iterators to get PKs: since we have locate, we can directly get the PK for an unordered level with the `locate()` function.
However, change 1 is not as straightforward as it first appears. Say we compute A ∩ B, where A has 1000 non-zero elements and B has 1. If A is ordered and B is not, we have to iterate over the entirety of A, when in reality we should be able to exit early. The dereferencing change (2), though, can definitely be implemented and should be made.
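To make this concrete, here is a minimal, self-contained sketch of a conjunctive merge where only the ordered side's iterator advances and the unordered side is probed with a `locate()` lookup. The types here are toys of my own invention, not the library's actual levels or API:

```cpp
// Toy sketch (hypothetical types, not the real library): conjunctive merge of
// an ordered level A with an unordered (hashed) level B. We advance only A's
// iterator and call locate() on B, so B never needs an ordered iterator.
#include <cassert>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

// (IK, PK) entries of the ordered level, sorted by IK.
using OrderedLevel = std::vector<std::pair<int, int>>;
// IK -> PK map standing in for a hashed level's locate() capability.
using HashedLevel = std::unordered_map<int, int>;

// locate(): O(1) lookup of the PK for a given IK, or nullopt if absent.
inline std::optional<int> locate(const HashedLevel& lvl, int ik)
{
    auto it = lvl.find(ik);
    return it != lvl.end() ? std::optional<int>(it->second) : std::nullopt;
}

// Intersection driven by the ordered side: note that we must walk *all* of
// `a` even if `b` has a single element -- the early-exit problem above.
std::vector<std::tuple<int, int, int>>
conjunctive_merge(const OrderedLevel& a, const HashedLevel& b)
{
    std::vector<std::tuple<int, int, int>> out;  // (ik, pk_a, pk_b)
    for (const auto& [ik, pk_a] : a)             // advance ordered iterator only
        if (auto pk_b = locate(b, ik))           // unordered side: locate()
            out.emplace_back(ik, pk_a, *pk_b);
    return out;
}
```

Even in this toy version, the loop over `a` runs to completion regardless of how small `b` is, which is exactly the inefficiency described above.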
Other improvements I have made in the code are:
Add `BaseTraits::I i` as an unused parameter in `hashed::iter_helper`:

```cpp
iteration_helper iter_helper([[maybe_unused]] typename BaseTraits::I i,
                             typename BaseTraits::PKM1 pkm1)
```
In `Coiterate::coiteration_helper`, I changed the initialization of the `iterators` member from

```cpp
std::tuple<typename levels::levelcapabilities::iteration_helper::iterator...> it) noexcept
```

to

```cpp
std::tuple<typename levels::iteration_helper::iterator...> it) noexcept
```

i.e. we remove `levelcapabilities` from the namespace. This was required since the `iteration_helper` for the hashed level is not implemented as part of the `LevelCapabilities` namespace, so this change allows `Coiterate` to be defined with a set of levels that includes a hashed level.
Rename `Coiterate::coiteration_helper::iterator.locate` to `deref_PKs`, since that better describes what it is actually doing:
```cpp
template <class iter>
inline auto deref_PKs(iter i) const noexcept
{
    return (std::get<0>(*i) == min_ik)
               ? std::optional<std::tuple_element_t<1, decltype(*i)>>(std::get<1>(*i))
               : std::nullopt;
}
```
Also, in the actual `get_PKs` function, we now use `locate` if the iterator has a locate function; otherwise we dereference:
```cpp
/**
 * @brief Return tuple of PKs from each level.
 * @details If the level is ordered, return the PK from the iterator using
 * dereferencing `*iter`. If the level is unordered, return the PK from
 * the iterator using `iter.locate()`.
 */
inline auto get_PKs() const noexcept
```

with the unordered branch calling

```cpp
? args.locate(m_coiterHelper.m_pkm1, min_ik)
```
The only issue is that this produces a compiler error.
Besides these improvements, I started reviewing the MergeLattice implementation inside the existing taco compiler. That code operates at run time, and the implementation uses a builder class to construct a MergeLattice.
At a high level, the MergeLattice should take in a tuple of levels that are “merge points” on the merge lattice. In addition, it should take in an index expression that dictates how the levels are merged. The index expression in taco is written with statements like `expr = C(i, j) = A(i,j) + B(i,j);`, whereas we would want to define an arbitrarily complex index expression… I still have to do some more reading to get an idea of how this part is implemented.
Internally, given the index expression, the MergeLattice will be able to determine which levels are co-iterated. Moreover, it must construct the `F` function that is passed to `Coiterate`.
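One conceivable way to get a compile-time index expression, rather than taco's run-time one, is expression templates. The sketch below is purely hypothetical (none of these names exist in the codebase); it only shows that the conjunction/disjunction structure can be captured in the type so an analysis like "which levels are co-iterated" or the synthesis of `F` could walk the same tree:

```cpp
// Hypothetical static index-expression tree via expression templates.
#include <cassert>

template <int Id>
struct Access { };                  // a tensor access, e.g. A(i, j)

template <class L, class R>
struct Conj { };                    // multiplication: L * R

template <class L, class R>
struct Disj { };                    // addition: L + R

template <class L, class R>
constexpr auto operator*(L, R) { return Conj<L, R>{}; }

template <class L, class R>
constexpr auto operator+(L, R) { return Disj<L, R>{}; }

// Count tensor accesses in an expression -- a stand-in for the real analyses
// (co-iterated levels, construction of F) that would traverse the same tree.
template <int Id>
constexpr int num_accesses(Access<Id>) { return 1; }

template <class L, class R>
constexpr int num_accesses(Conj<L, R>) { return num_accesses(L{}) + num_accesses(R{}); }

template <class L, class R>
constexpr int num_accesses(Disj<L, R>) { return num_accesses(L{}) + num_accesses(R{}); }
```

With this, `Access<0>{} + Access<1>{} * Access<2>{}` yields the type `Disj<Access<0>, Conj<Access<1>, Access<2>>>`, entirely at compile time.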
Overall, we would also like the following higher-level functions:
Construct a union (disjunction) over lattice points
Construct an intersection (conjunction) over lattice points
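Those two builders can be sketched with the rules taco uses: a conjunction takes the cross product of the two operands' lattice points (both sides must be non-zero), while a disjunction additionally keeps each operand's own points (either side may be non-zero alone). The representation below is an assumption for illustration, not the eventual API:

```cpp
// Toy merge-lattice builders. A LatticePoint is just the set of level names
// co-iterated when that point is active; a MergeLattice is a list of points.
#include <cassert>
#include <set>
#include <string>
#include <vector>

using LatticePoint = std::set<std::string>;
using MergeLattice = std::vector<LatticePoint>;

// Conjunction: cross product of points -- every point of `a` merged with
// every point of `b`, since both operands must be non-zero.
MergeLattice conjunction(const MergeLattice& a, const MergeLattice& b)
{
    MergeLattice out;
    for (const auto& pa : a)
        for (const auto& pb : b)
        {
            LatticePoint p = pa;
            p.insert(pb.begin(), pb.end());
            out.push_back(p);
        }
    return out;
}

// Disjunction: the cross-product points, then the points of each operand
// alone, since either side may be non-zero by itself.
MergeLattice disjunction(const MergeLattice& a, const MergeLattice& b)
{
    MergeLattice out = conjunction(a, b);
    out.insert(out.end(), a.begin(), a.end());
    out.insert(out.end(), b.begin(), b.end());
    return out;
}
```

For example, `disjunction({{b}}, conjunction({{c}}, {{d}}))` produces the points {b, c, d}, {b}, and {c, d}, ordered from the top of the lattice down.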
A merge lattice is constructed per index iterator in a tensor operation. For example, say we have:
Aij = (Bij + Cij) @ Dij
To set index i for A, we have to iterate over B_i, C_i and D_i.
To set index j for A, we have to iterate over B_*j, C_*j, and D_*j.
Each index constitutes a merge lattice that we need to construct in order to then call `Coiterate`. We want to extract the operators “+” and “*” to determine addition and multiplication, where addition is converted to a disjunction and multiplication is converted to a conjunction.
For example, the expression:
A_i = b_i + c_i d_i
has both a conjunction and a disjunction. We proceed by:
Create the leaves of the merge lattice from the tensor access rules(?)
Create the merge lattice for c_i d_i by computing the conjunctive merge (c_i ∧ d_i); a_i is computed from this sub-expression if this lattice point is reached
Create merge lattice for b_i since there is no other conjunctive merge with it.
Create the upper-most merge lattice point for the disjunctive merge (b_i) ∨ (c_i ∧ d_i)
So the merge lattice starts with the whole expression as a disjunction among conjunctions.
Co-iteration then traverses through the lattice points, trimming away the parts that are no longer necessary for co-iteration.
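My reading of that trimming step, sketched below under the same toy representation as before (sets of level names, which is an assumption, not the real design): co-iteration starts at the top point {b, c, d}; once a tensor's iterator is exhausted, we drop to the largest remaining point that excludes it. Exhausting c (or d) kills the conjunction c_i d_i and leaves only {b}; exhausting b leaves {c, d}:

```cpp
// Toy sketch of lattice trimming during co-iteration.
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

using LatticePoint = std::set<std::string>;
using MergeLattice = std::vector<LatticePoint>;  // ordered top to bottom

// Return the first (largest) lattice point containing none of the exhausted
// tensors, i.e. the sub-lattice in which co-iteration continues.
std::optional<LatticePoint>
next_point(const MergeLattice& lattice, const std::set<std::string>& exhausted)
{
    for (const auto& p : lattice)
    {
        bool usable = true;
        for (const auto& t : exhausted)
            if (p.count(t)) { usable = false; break; }
        if (usable)
            return p;
    }
    return std::nullopt;  // every point needs an exhausted tensor: we are done
}
```

For the lattice {b, c, d} > {c, d} > {b} of the example above, exhausting c jumps straight to {b}, skipping the now-useless {c, d} point.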
What is the general input for a merge lattice? What is calling it? Will we have to implement, for example, an iteration graph?
Do we implement the “LatticePoint”?
How do we expect the “tensor operations” to be represented?