bihariforces's blog

By bihariforces, history, 15 months ago, In English

How would we implement a multiset data structure that supports insertion, deletion, and querying the GCD of all values currently present in the multiset?

insert(5)
GCD() -> 5
insert(10)
GCD() -> 5
remove(5)
GCD() -> 10

This question is not from an existing problem; I was just wondering whether such a data structure is possible, i.e. one that supports deletion for an operation like GCD, which has no inverse (you cannot "subtract" a value back out of a GCD). I believe this article would be helpful: https://cp-algorithms.com/data_structures/deleting_in_log_n.html, but I can't see how to apply it here, or whether some other method would work for large numbers up to 1e18.

  • Vote: I like it
  • +3
  • Vote: I do not like it

»
15 months ago, # |
  Vote: I like it +9 Vote: I do not like it

Use a treap: just modify the "upd_cnt()" function so that it also merges the GCD values of the node and its subtrees.

»
15 months ago, # |
  Vote: I like it 0 Vote: I do not like it

You can use a dynamic segment tree, which supports these operations in O(log n). (Each GCD computation is itself O(log V) in the worst case, but in practice it is much faster than that, so we treat it as O(1).)

»
15 months ago, # |
Rev. 2   Vote: I like it -8 Vote: I do not like it

A splay tree also works here.

#include <bits/stdc++.h>

/// Splay-tree multiset with cached per-subtree aggregates.
///
/// Keys of type `Data` are ordered by `Less`; equal keys share one node whose
/// `cnt` is the multiplicity. Every entry `aggs[k]` is a binary function, and
/// each node caches in `res[k]` that function folded over its subtree in key
/// (in-order) order, so e.g. `aggs[0] = gcd` yields the GCD of the multiset.
/// Folds start from `Data()`, so each aggregator is expected to be associative
/// with `Data()` as its identity (true for gcd on non-negative integers,
/// since gcd(0, x) == x).
///
/// NOTE(review): spt declares no destructor, so nodes still in the tree are
/// not freed when the tree object goes away.
template<typename Data, typename Less=std::less<Data>> 
struct spt{
    #define DATACOUT 1

    // Two keys are equal when neither orders before the other under Less.
    bool equals(const Data& d1, const Data& d2) const{
        return !Less()(d1, d2) && !Less()(d2, d1);
    }

    // Adapter around Less; its result is used as a child index (0 = left, 1 = right).
    bool lesswrapper(const Data& d1, const Data& d2) const{
        return Less()(d1, d2) > 0;
    }

    struct node{
        Data d;                   // key stored in this node
        std::map<int, Data> res;  // res[k]: aggs[k] folded over this whole subtree
        int cnt, sz;              // multiplicity of d / total multiplicity of the subtree
        struct node *fa, *c[2];   // parent and children (c[0] left, c[1] right)
        node():d(Data()), cnt(0), sz(0), fa(nullptr){
            c[0] = nullptr;
            c[1] = nullptr;
        }
        // A node never frees its children: the tree unlinks them
        // (deleteonce / del) before deleting a node.
        ~node() = default;
    };

    node* root;
    std::map<int, Data(*)(Data, Data)> aggs;  // aggregator id -> binary combine function

    // Folds `agg` over p copies of x by repeated squaring; Data() when p <= 0.
    Data power(const Data& x, long long p, Data(*agg)(Data, Data)){
        if(p <= 0ll) return Data();
        if(p == 1ll) return x;
        Data half = power(x, p/2, agg);
        Data both = agg(half, half);
        if(p % 2 == 1){
            return agg(both, x);
        }
        return both;
    }

    spt():root(nullptr){}

    // Frees a single node that has already been unlinked from the tree.
    void deleteonce(node* cur){
        for(int i = 0; i <= 1; ++i) cur->c[i] = nullptr;
        cur->cnt = 0;
        cur->sz = 0;
        delete cur;
    }

    // Side of its parent that x hangs on (false = left, true = right); x must have a parent.
    bool get(node* x) { 
        return x == (x->fa->c[1]);
    }

    // Recomputes x->sz and x->res from x's own key/count and its children's caches.
    // Aggregates are combined in key order (left, x, right) so non-commutative
    // aggregators agree with kth(); null aggregators are skipped.
    void update(node* x){
        if(!x) return;
        node* l = x->c[0];
        node* r = x->c[1];
        x->sz = x->cnt + (l ? l->sz : 0) + (r ? r->sz : 0);
        for(auto [k, agg]: aggs){
            if(!agg) continue;
            Data acc = power(x->d, x->cnt, agg);
            if(l) acc = agg(l->res[k], acc);
            if(r) acc = agg(acc, r->res[k]);
            x->res[k] = acc;
        }
    }

    // Rotates x above its parent (preserving key order) and refreshes the old
    // parent and then x. Does not touch `root`; splay() sets it afterwards.
    void rotate(node* x){
        node *y = x->fa;
        node *z = y->fa;
        int f = get(x);
        y->c[f] = x->c[!f];
        if(x->c[!f]){
            x->c[!f]->fa = y;
        } 
        x->c[!f] = y;
        y->fa = x;
        x->fa = z;
        if(z){
            z->c[z->c[1]==y] = x;
        }
        update(y);
        update(x);
    }

    // Moves x to the top with zig-zig / zig-zag rotations and makes it `root`.
    void splay(node *x){
        if(!x) return;
        for(node* f = x->fa; f = x->fa, f; rotate(x)){
            if(f->fa) rotate(get(x) == get(f)?f:x);
        }
        root = x;
    }

    // Adds `times` copies of val (bumping an existing node's count or creating
    // a new leaf) and splays that node to the root.
    void insert(const Data& val, int times=1){
        if(!root){
            root = new node;
            root->cnt = times;
            root->d = val;
            update(root);
            return;
        }
        node *cur = root, *f=nullptr;
        while(1){
            if(equals(cur->d, val)){
                cur->cnt+=times;
                update(cur);
                update(f);
                splay(cur);
                break;
            }
            f = cur;
            cur = cur->c[lesswrapper(cur->d, val)];
            if(!cur){
                cur = new node;
                cur->cnt+=times;
                cur->fa=f;
                cur->d=val;
                f->c[lesswrapper(f->d, val)]=cur;
                update(cur);
                update(f);
                splay(cur);
                break;
            }
        }
    }

    // Number of stored elements (with multiplicity) strictly less than val,
    // splaying val's node to the root; -1 if val is absent or the tree is empty.
    int rk(const Data& val){
        if(!root) return -1;
        int res = 0;
        node *cur = root;
        while(cur){
            if(lesswrapper(val, cur->d)){
                cur = cur->c[0];
            }else{
                res += cur->c[0]?cur->c[0]->sz:0;
                if(equals(val, cur->d)){
                    splay(cur);
                    return res;
                }
                res += cur->cnt;
                cur = cur->c[1];
            }
        }
        return -1;
    }

    // Total number of stored elements, counting multiplicity.
    int sz(){
        return root?root->sz:0;
    }

    // Finds the node holding the k-th smallest element (0-based, counting
    // multiplicity) and, for every aggregator, its fold over the first k+1
    // elements in key order. The found node is splayed to the root.
    // Returns {nullptr, {}} when k is out of range.
    std::pair<node*, std::map<int, Data>> kth(int k){
        k++;  // switch to a 1-based rank
        std::map<int, Data> ans;
        if(!root || k <= 0 || k > sz()){
            return std::make_pair(nullptr, ans);
        }
        node* cur = root;
        while(cur){
            int leftsz = cur->c[0] ? cur->c[0]->sz : 0;
            if(k <= leftsz){
                cur = cur->c[0];
                continue;
            }
            // The whole left subtree precedes the target.
            if(cur->c[0]){
                for(auto [key, agg]: aggs){
                    if(agg) ans[key] = agg(ans[key], cur->c[0]->res[key]);
                }
            }
            int remain = k - leftsz;  // copies of cur->d still needed (>= 1)
            if(remain <= cur->cnt){
                for(auto [key, agg]: aggs){
                    if(agg) ans[key] = agg(ans[key], power(cur->d, remain, agg));
                }
                splay(cur);
                return std::make_pair(cur, ans);
            }
            // Every copy of cur->d precedes the target; continue to the right.
            for(auto [key, agg]: aggs){
                if(agg) ans[key] = agg(ans[key], power(cur->d, cur->cnt, agg));
            }
            k = remain - cur->cnt;
            cur = cur->c[1];
        }
        return std::make_pair(nullptr, ans);
    }

    // Fold of every aggregator over the whole multiset (empty map when empty).
    std::map<int, Data> aggall(){
        int splay_sz = sz();
        if(!splay_sz) return std::map<int, Data>{};
        return kth(splay_sz-1).second;
    }

    // {fold of aggs[id] over the multiset, true}, or {Data(), false} when the
    // tree is empty or id has no (non-null) aggregator. The presence flag is
    // computed with find() so nothing is inserted before it is checked.
    std::pair<Data, bool> aggall(int id){
        std::map<int, Data> m = aggall();
        auto it = m.find(id);
        if(it == m.end()) return std::make_pair(Data(), false);
        return std::make_pair(it->second, true);
    }

    // Splays and returns the in-order predecessor of the root node,
    // or nullptr if there is none (including an empty tree).
    node* pre(){
        if(!root) return nullptr;
        node* cur = root->c[0];
        if(!cur) return cur;
        while(cur->c[1]) cur = cur->c[1];
        splay(cur);
        return cur;
    }

    // Splays and returns the in-order successor of the root node,
    // or nullptr if there is none (including an empty tree).
    node* nxt(){
        if(!root) return nullptr;
        node* cur = root->c[1];
        if(!cur) return cur;
        while(cur->c[0]) cur = cur->c[0];
        splay(cur);
        return cur;
    }

    // Splays and returns the first node whose key is >= val, or nullptr.
    node* lb(const Data& val){//lower_bound
        node* ans = nullptr;
        node* cur = root;
        while(cur){
            if(!lesswrapper(cur->d, val)){
                ans = cur;
                cur = cur->c[0];
            }else{
                cur = cur->c[1];
            }
        }
        splay(ans);
        return ans;
    }

    // Splays and returns the first node whose key is > val, or nullptr.
    node* ub(const Data& val){//upper_bound
        node* ans = nullptr;
        node* cur = root;
        while(cur){
            if(lesswrapper(val, cur->d)){
                ans = cur;
                cur = cur->c[0];
            }else{
                cur = cur->c[1];
            }
        }
        splay(ans);
        return ans;
    }

    // Splays and returns the node with the largest key (nullptr if empty).
    node* rightmost(){
        node *cur = root, *f = cur;
        while(cur){
            f = cur;
            cur = cur->c[1];
        }
        splay(f);
        return f;
    }

    // Largest node whose key is < x (splayed), or nullptr.
    node* pre(const Data& x){
        node* p = lb(x);
        if(p) return pre();
        return rightmost();
    }

    // Smallest node whose key is > x (splayed), or nullptr.
    node* nxt(const Data& x){
        return ub(x);
    }

    // Removes `times` copies of val; the whole node goes away when
    // times == -1 or times >= its count. No-op if val is absent.
    void del(const Data& val, int times=1){
        if(rk(val) == -1) return;  // also splays val's node to the root
        if(root && root->cnt > times && times != -1){
            root->cnt -= times;
            update(root);
            return;
        }
        if(!root->c[0] && !root->c[1]){
            delete root;
            root = nullptr;
            return;
        }
        node* cur = root;
        if(!root->c[0]){
            root = root->c[1];
            root->fa = nullptr;
            deleteonce(cur);
            return;
        }
        if(!root->c[1]){
            root = root->c[0];
            root->fa = nullptr;
            deleteonce(cur);
            return;
        }
        // Both children exist: splaying the predecessor leaves cur as its
        // right child with an empty left subtree, so cur's right subtree can
        // be re-hung directly under the predecessor.
        cur = root;
        node* x = pre();
        cur->c[1]->fa = x;
        if(x) x->c[1] = cur->c[1];
        deleteonce(cur);
        update(root);
    }

    // Alias for del().
    void erase(const Data& val, int times=1){
        del(val, times);
    }

    // Iterator that relies on its node being the current root: ++ / -- splay
    // the successor / predecessor of the root. Any other splaying operation
    // on the tree during iteration invalidates it.
    struct iter{
        node* t;
        spt* s;
        iter(node* t, spt* s): t(t), s(s) {}
        bool operator!=(iter rhs) {return t != rhs.t;}
        node* operator->() {return t;}
        node& operator*() {return *t;}
        void operator++() {t = s->nxt();}
        void operator--() {t = s->pre();}
    };

    iter begin(){
        kth(0);
        return iter(root, this);
    }

    iter end(){
        return iter(nullptr, this);
    }

    iter rbegin(){
        kth(sz()-1);
        return iter(root, this);
    }

    iter rend(){
        return iter(nullptr, this);
    }

    // All stored elements (each repeated by its multiplicity), largest first.
    std::vector<Data> r2l(){
        std::vector<Data> res;
        res.reserve(sz());
        for(auto it = rbegin(); it != rend(); --it){
            res.insert(res.end(), static_cast<std::size_t>(it->cnt), it->d);
        }
        return res;
    }

    // All stored elements (each repeated by its multiplicity), smallest first.
    // Nodes are visited by reference; copying them would also copy their caches.
    std::vector<Data> l2r(){
        std::vector<Data> res;
        res.reserve(sz());
        for(const auto& x: *this){
            res.insert(res.end(), static_cast<std::size_t>(x.cnt), x.d);
        }
        return res;
    }

    #if DATACOUT
    // Pre-order dump of the subtree rooted at cur.
    void debug(node* cur){
        if(cur){
            std::cout << "cur->d==" << cur->d << ", cur->cnt/sz==" << cur->cnt << "/" << cur->sz << ", cur==" << cur << ", cur->fa==" << cur->fa << ", cur->c=={" << cur->c[0] << ", " << cur->c[1] << "}\n";
            debug(cur->c[0]);
            debug(cur->c[1]);
        }
    }

    void debug(){
        debug(root);
        std::cout << "\n";
    }

    void debugres(const std::map<int, Data>& f){
        for(auto [k, d]: f){
            std::cout << "k==" << k << ", d==" << d << "\n";
        }
    }

    void debugv(const std::vector<Data>& v){
        for(std::size_t i = 0; i < v.size(); ++i){
            std::cout << "v[" << i << "]==" << v[i] << "\n";
        }
    }
    #endif
};

#define DEBUG 1

// Local-testing hook. With DEBUG set, stdin is redirected to the file at `p`
// (a failure is reported on stderr); otherwise the path is ignored and
// iostreams are switched to unsynchronized, untied mode for speed.
void debug(const char* p){
    #if DEBUG
    if(!std::freopen(p, "r", stdin)){
        std::perror(p);
    }
    #else
    (void)p;
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    #endif      
}

// Multiset answering GCD queries; values may be as large as 1e18, so keys are long long.
spt<long long> splay;
int main(void){
    debug("test1.txt");
    // Wrap std::gcd in a captureless lambda: taking the address of a standard
    // library function template is not portable, while the lambda converts to
    // the function-pointer type that `aggs` stores.
    splay.aggs[0] = [](long long a, long long b) -> long long { return std::gcd(a, b); };
    int q;
    if(!(std::cin >> q)) return 0;
    for(int i = 0; i < q; ++i){
        char ch;
        long long x;
        if(!(std::cin >> ch >> x)) break;  // stop on truncated input
        if(ch == '+') splay.insert(x);
        else if(ch == '-') splay.erase(x);
        std::cout << splay.aggall(0).first << "\n";
    }
}

/*
7
+ 7
- 5
+ 14
+ 21
+ 3
- 14
- 3
*/
»
15 months ago, # |
  Vote: I like it +9 Vote: I do not like it

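There is no need for a dynamic tree: just use a fixed-size segment tree that is large enough for all insertions, with GCD as the operation, plus a map that tracks where in the segment tree each of your elements is. To delete an element, set its position back to 0, since gcd(0, x) = x.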