明月不谙离恨苦
斜光到晓穿朱户
以前刷PAT的时候经常能做到二叉搜索树的题了,现在刷leetcode蛮难遇到的。
理解这道题能用BST做的关键在于理解插入顺序应该是倒序。
首先肯定的是,可以利用BST的特性,在树中每次插入结点时,都更新各个结点 “索引大于该结点且值小于该结点的个数”,即遍历每个位于该结点右侧的结点,修改右侧每个结点的“索引大于该结点且值小于该结点的个数”。
显然这样子的复杂度偏高,反正我超时了,也就是这个版本。
struct BSTNode {
BSTNode *left, *right;
int val;
int index;
explicit BSTNode(int x, int i) : val(x), index(i), left(nullptr), right(nullptr) {}
};
class Solution {
private:
void dfs(BSTNode *root, vector<int> &ans) {
if (root == nullptr)
return ;
ans[root->index]++;
dfs(root->left, ans);
dfs(root->right, ans);
}
BSTNode *insertx(BSTNode *root, int x, const int index, vector<int> &ans) {
if (root == nullptr)
return new BSTNode(x, index);
if (x >= root->val)
root->right = insertx(root->right, x, index, ans);
else {
ans[root->index]++;
dfs(root->right, ans);
root->left = insertx(root->left, x, index, ans);
}
return root;
}
public:
vector<int> countSmaller(vector<int>& nums) {
vector<int> ans(nums.size(), 0);
BSTNode *root = nullptr;
for (int i = 0; i < nums.size(); i++) {
root = insertx(root, nums[i], i, ans);
}
return ans;
}
};
简单来说“又臭又长又慢”。。。但是还算好理解的。
如何解决超时问题呢?
答:倒序遍历 nums
数组,这样在插入结点时,保证了树中已经安置的结点都是该结点右侧的结点,这样就不再需要比较其索引。递归插入每个新结点时,如果新结点大于当前根结点,且当前根结点肯定在新结点右侧,就可以直接加上根结点的 cnt
并加1,就能得到新结点“索引大于新结点且值小于新结点的个数”;反之,则将当前根结点的 cnt
加1,因为新结点要插入根结点的左儿子中。
然后每次插入都会计算出新结点“索引大于新结点且值小于新结点的个数”,也就是答案。
总结就是两点:
- 维护BST,BST每个结点会统计左儿子结点的个数
- 每次插入新结点时,就能得到新结点“索引大于新结点且值小于新结点的个数”,因为倒数插入此时树中存在的结点全都在新结点的右边
代码如下:
struct BSTNode {
BSTNode *left, *right;
int val;
int cnt;
explicit BSTNode(int x) : val(x), cnt(0), left(nullptr), right(nullptr) {}
};
class Solution {
private:
BSTNode *insertx(BSTNode *root, int x, const int index, vector<int> &ans) {
if (root == nullptr)
return new BSTNode(x);
if (x > root->val) {
ans[index] += root->cnt + 1;
root->right = insertx(root->right, x, index, ans);
}
else {
root->cnt++;
root->left = insertx(root->left, x, index, ans);
}
return root;
}
public:
vector<int> countSmaller(vector<int>& nums) {
vector<int> ans(nums.size(), 0);
BSTNode *root = nullptr;
for (int i = nums.size() - 1; i >= 0; i--) {
root = insertx(root, nums[i], i, ans);
}
return ans;
}
};