Zadanie

Martin má kocúra Ferka. Ferko je ale mierne obézny kocúr, takže musí začať držať diétu. Martin počul o jednej veľmi efektívnej diéte pre kocúrov - pórová diéta. V tejto diéte môžu mačky jesť iba pór.

Bohužiaľ prestup na túto diétu Ferkovi nepomohol. To preto, že mu Martin dával príliš veľa póru. Na internetovom fóre Kŕmenie Super Parťákov sa dozvedel, že na každej škatuli od póru sú čísla, ktoré určujú, koľko póru má dať svojmu miláčikovi. No a keďže je to pór, tak sa to zisťuje pOR-om. Pomôžete Martinovi zistiť, koľko póru má dať Ferkovi?

Úloha

Máme pole čísiel \(P\). Zaujíma nás, koľko je v ňom takých súvislých úsekov, ktoré majú bitový OR rovný \(k\).

Formát vstupu

V prvom riadku vstupu dostanete dve čísla - \(n\) - počet prvkov v poli \(P\) a \(k\) - bitový OR ktorý sa snažíme docieliť.

Ďalej bude nasledovať jeden riadok, na ktorom bude \(n\) nezáporných celých čísiel.

V jednotlivých sadách platia nasledujúce obmedzenia:

Sada 1 2 3 4
\(1 \leq n \leq\) \(100\) \(10\,000\) \(10^5\) \(5 \cdot 10^5\)
\(1 \leq k \leq\) \(1\,000\) \(10^6\) \(10^9\) \(10^{9}\)
\(1 \leq P_i \leq\) \(1\,000\) \(10^6\) \(10^9\) \(10^{9}\)

Formát výstupu

Vypíšte počet úsekov poľa \(P\), ktoré majú bitový OR rovný \(k\)

Príklady

Input:

2 1
1 1

Output:

3

Úseky, ktoré majú bitový OR \(1\) sú: [0, 0], [0, 1], [1, 1]

Input:

10 1
1 1 1 1 1 1 1 1 1 1

Output:

55

Každý súvislý podúsek má bitový OR 1 (podúsekov je 55)

Input:

5 7
1 2 2 1 1

Output:

0

V žiadnom podúseku nie je bitový OR 7

Bruteforce

Najjednoduchšie riešenie je, že spravíme OR každého súvislého intervalu. Toto riešenie má časovú zložitosť \(O(n^3)\), pretože máme \(n^2\) intervalov a na väčšine z nich musíme spraviť OR rádovo \(n\) čísiel. Toto riešenie má pamäťovú zložitosť \(O(n)\), pretože si potrebujeme zapamätať celý vstup dĺžky \(n\) a zopár pomocných premenných.

O niečo lepšie riešenie

Povedzme, že máme vypočítaný OR intervalu \(P[i, j]\), ktorý sa rovná \(o_{i,j}\). Potom OR intervalu \(P[i, j + 1]\) vypočítame: \(o_{i, j+1} = o_{i, j} \space\) OR \(\space P[j + 1]\). Časová zložitosť tohto riešenia je \(O(n^2)\), keďže musíme prejsť každý interval, ktorých je rádovo \(n^2\). Toto riešenie má pamäťovú zložitosť \(O(n)\), pretože si potrebujeme zapamätať celý vstup dĺžky \(n\) a zopár pomocných premenných.

Zaujímavé myšlienky

Povedzme, že máme vypočítaný OR intervalu \(P[i, j]\), ktorý je \(o_{i, j}\). Vieme z tejto informácie zistiť \(o_{i + 1, j}\)?

Áno, vieme! Len si potrebujeme pamätať zopár extra vecí intervale \(P[i, j]\). Potrebujeme si pamätať, koľko krát boli jednotlivé bity “zapnuté” v intervale \(P[i, j]\). Pamätajme si tútu informáciu v poli \(BITS\). Keď potom chceme vedieť \(o_{i + 1, j}\), tak od každej hodnoty v \(BITS[i]\) odpočítame 1, ak ten bit bol zapnutý v čísle \(P[i]\). \(o_{i + 1, j}\) zistíme tak, že za každú hodnotu \(BITS\), ktorá je väčšia ako 1 pripočítame príslušnú mocninu 2. Pole \(BITS\) má veľkosť rádovo \(\log(\max(P[i]))\), ale keďže \(P[i] \leq 10^9\), čiže \(\log (\max(P[i])) \leq 31\) tak môžeme povedať, že pole \(BITS\) má konštantnú veľkosť.

Poďme sa teraz pozrieť na jednotlivé čísla v poli. Konkrétne o každom čísle chceme vedieť, či môže byť v niektorom intervale, ktorý má OR presne \(k\). Keď má \(P[i]\) zapnutý bit \(b\), ktorý je vypnutý v \(k\), tak nemôže byť v žiadnom intervale, ktorý má OR presne \(k\). \(P[i]\) nám teda rozdelí pole \(P\) na 2 časti, pričom žiadny interval, ktorý začína v prvej časti a končí v druhej časti nemá bitový OR \(k\).

Niečo málo o 2 bežcoch

Ak už viete, ako funguje metóda 2 bežcov, môžete túto časť preskočiť

Optimálne riešenie využíva metódu dvoch bežcov. Tá spočíva v tom, že si spravíme 2 premenné, ktoré budú ukazovať na rôzne prvky v poli. Títo dvaja bežci môžu bežať oproti sebe (na začiatku ich nastavíme, aby jeden ukazoval na začiatok a druhý na koniec poľa) alebo rovnakým smerom (na začiatku obaja ukazujú na začiatok poľa) - toto je prípad tejto úlohy.

Dobrý rule of thumb je, že keď potrebujeme niečo vedieť o podúseku poľa, tak hýbeme bežcami rovnakým smerom. Keď niečo potrebujeme vedieť o dvojici prvkov poľa, tak hýbeme bežcami oproti sebe.

Podľa toho, aký je momentálny podúsek hýbeme 2 bežcami. Keď je podúsek príliš “malý”, tak pohneme bežcom, ktorý ukazuje na koniec podúseku. Naopak, keď je podúsek príliš “veľký”, tak pohneme bežcom, ktorý ukazuje na začiatok podúseku.

Optimálne riešenie

Teraz to už iba celé musíme dať dokopy. Najskôr si pole \(P\) rozdelíme na časti podľa čísiel, ktoré majú zapnutý niektorý bit \(b\), ktorý je vypnutý v \(k\). Keď už ho takto máme rozdelené, tak ideme pre každú časť zistiť, koľko je v nej intervalov, ktoré majú OR \(k\).

To spravíme pomocou 2 bežcov. 1. bežec bude ukazovať na začiatok intervalov, ktoré majú OR \(k\), 2. bežec bude ukazovať na koniec najkratšieho intervalu, ktorý začína na mieste 1. bežca a má OR \(k\). Všetky dlhšie intervaly až po koniec daného úseku majú OR \(k\) a teda ich započítame - je ich toľko, koľko je rozdiel 2. bežca a konca intervalu. Keď už sme našli úsek, ktorý má OR \(k\), tak posunieme 1. bežca. Pri posúvaní využijem výpočet OR-u popísaný vyššie. 2. bežca posúvame, až dokým nenájdeme začiatok najkratšieho úseku, ktorý má OR \(k\) (niekedy ho vôbec nemusíme posunúť). Intervaly budeme zarátavať vtedy, keď pôjdeme pohnúť 1. bežcom. Keďže interval je jednoznačne určený svojím začiatkom a 1. bežcom prejdeme cez všetky možné začiatky intervalov, tak každý interval, ktorý vyhovuje zarátame raz.

Toto riešenie má časovú zložitosť \(O(n\log(\max(P[i])))\), keďže musíme prejsť celým poľom a na každý prvok môžu bežcovia byť najviac 2-krát. \(\log(\max(P[i]))\) je tam kvôli spracovávaniu OR čísiel poľa. Keďže \(\log (\max(P[i])) \leq 31\), tak môžeme povedať, že časová zložitosť je \(O(n)\). Pamäťovú zložitosť má toto riešenie \(O(n)\), pretože si potrebujeme zapamätať celý vstup dĺžky \(n\).

#include <iostream>
#include <vector>

using namespace std;

typedef long long ll;


// funkcia na spocitanie intervalov v useku P[begin, end] (end uz nepatri useku).
// Pouziva 2 pointers pristup.
ll count(ll *arr, ll begin, ll end, ll k){
    ll ans = 0;
    ll OR = 0;
    vector<ll> bits(31, 0);
    ll p2 = begin;
    for (ll p1 = begin; p1 < end; p1++){
        while (OR != k){
            if (p2 >= end){
                break;
            }
            for (int i = 0; i < 31; i++){
                bits[i] += ((arr[p2] & (1 << i)) > 0);
                OR |= (bits[i] > 0) << i;
            }
            p2++;
        }
        if (OR == k){
            ans += end - p2 + 1;
        }
        
        OR = 0;
        for (int i = 0; i < 31; i++){
            bits[i] -= ((arr[p1] & (1 << i)) > 0);
            OR |= (bits[i] > 0) << i;
        }
    }
    return ans;
}

int main(){
    ll n, k;
    cin >> n >> k;
    ll arr[n];
    for (ll i = 0; i < n; i++){
        cin >> arr[i];
    }

    ll ans = 0;
    ll begin = 0;
    ll end = 0;
    for (ll i = 0; i < n; i++){
        ll num = arr[i] | k;
        if (num != k){
            end = i;
            ans += count(arr, begin, end, k);
            begin = i + 1;
        }
    }
    // nesmieme zabudnut na posledny usek 
    ans += count(arr, begin, n, k);

    cout << ans << "\n";
}
arr = []

# funkcia na spocitanie intervalov v useku P[begin, end] (end uz nepatri useku).
# Pouziva 2 pointers pristup.
def count(begin, end, k):
    ans = 0
    result_OR = 0
    bits = []
    for _ in range(31):
        bits.append(0)
    
    p2 = begin
    for p1 in range(begin, end):
        while(result_OR != k):
            if (p2 >= end):
                break
            for i in range(31):
                bits[i] += ((arr[p2] & (1 << i)) > 0) 
                result_OR |= (bits[i] > 0) << i
            p2 += 1
        
        if result_OR == k:
            ans += end - p2 + 1
        
        result_OR = 0
        for i in range(31):
            bits[i] -= ((arr[p1] & (1 << i)) > 0)
            result_OR |= (bits[i] > 0) << i
    
    return ans

n, k = map(int, input().split())
arr = list(map(int, input().split()))

ans = 0
begin = 0
end = 0
for i in range(n):
    num = arr[i] | k
    if (num != k):
        end = i
        ans += count(begin, end, k)
        begin = i + 1

# nesmieme zabudnut na posledny usek    
ans += count(begin, n, k)
print(ans)

Diskusia

Tu môžte voľne diskutovať o riešení, deliť sa o svoje kusy kódu a podobne.

Pre pridávanie komentárov sa musíš prihlásiť.