diff --git a/gneiss/tests/test_util.py b/gneiss/tests/test_util.py
index 9f6655f..425d750 100644
--- a/gneiss/tests/test_util.py
+++ b/gneiss/tests/test_util.py
@@ -14,7 +14,8 @@
 from gneiss.util import match, match_tips, design_formula
 from gneiss.util import (rename_internal_nodes,
                          _type_cast_to_float, block_diagonal, band_diagonal,
-                         split_balance, check_internal_nodes)
+                         split_balance, check_internal_nodes,
+                         _xarray_match_tips)
 from biom import Table
 from patsy import dmatrix
 import numpy.testing as npt
@@ -592,6 +593,44 @@ def test_band_diagonal(self):
         npt.assert_allclose(res, exp, rtol=1e-4, atol=1e-4)
 
 
+class TestMatchXarray(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_data_array_match(self):
+        import xarray as xr
+        data = np.array(
+            [[[0, 0, 1, 1], [2, 3, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]],
+             [[0, 0, 1, 2], [2, 3, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]],
+             [[0, 0, 1, 3], [2, 3, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]]]
+        )
+        table = xr.DataArray(
+            data,
+            dims=['monte_carlo_samples', 'samples', 'features'],
+            coords=[[0, 1, 2],
+                    ['s1', 's2', 's3', 's4'],
+                    ['a', 'b', 'c', 'd']]
+        )
+        tree = TreeNode.read([u"(((b,a)f, c),d)r;"])
+
+        exp_tree = tree
+        data = np.array(
+            [[[0, 0, 1, 1], [3, 2, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]],
+             [[0, 0, 1, 2], [3, 2, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]],
+             [[0, 0, 1, 3], [3, 2, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]]]
+        )
+        exp_table = xr.DataArray(
+            data,
+            dims=['monte_carlo_samples', 'samples', 'features'],
+            coords=[[0, 1, 2],
+                    ['s1', 's2', 's3', 's4'],
+                    ['b', 'a', 'c', 'd']]
+        )
+        res_table, res_tree = _xarray_match_tips(table, tree)
+        xr.testing.assert_equal(exp_table, res_table)
+        self.assertEqual(str(exp_tree), str(res_tree))
+
+
 class TestSplitBalance(unittest.TestCase):
 
     def setUp(self):
diff --git a/gneiss/util.py b/gneiss/util.py
index 165c6e4..c410a85 100644
--- a/gneiss/util.py
+++ b/gneiss/util.py
@@ -232,6 +232,41 @@ def sort_f(x):
     return _table, _tree
 
 
+def _xarray_match_tips(data_array, tree, dim='features'):
+    """ Match on xarray Dataset or DataArray object.
+
+    Parameters
+    ----------
+    data_array : xarray.DataArray or xarray.Dataset
+        Data with a coordinate named `dim` holding the feature names.
+    tree : TreeNode
+        Tree whose tip names are matched against that coordinate.
+    dim : str, optional
+        Dimension of `data_array` to match on.
+
+    Returns
+    -------
+    xarray.DataArray or xarray.Dataset
+        `data_array` reindexed along `dim` to the tip names of the
+        returned tree, in tip order.
+    TreeNode
+        Result of shearing `tree` to the shared names, then
+        bifurcating and pruning it.
+    """
+    tips = [x.name for x in tree.tips()]
+    common_tips = list(set(tips) & set(data_array[dim].values))
+
+    _tree = tree.shear(names=common_tips)
+    _tree.bifurcate()
+    _tree.prune()
+    # Reindexing on the sheared tree's tips both drops the unmatched
+    # features and puts the remaining ones in tip order, so no separate
+    # selection step is needed.
+    sorted_features = [n.name for n in _tree.tips()]
+    data_array = data_array.reindex(indexers={dim: sorted_features})
+    return data_array, _tree
+
+
 def _dense_match_tips(table, tree):
     """ Match on dense pandas dataframes. """
     tips = [x.name for x in tree.tips()]