yukicoder No.1300 Sum of Inversions

問題

解法
  1. 空の数列に、A_iの値の昇順(値が同じ場合はiの昇順)に元の位置に挿入していくことを考える。
  2. 挿入した値を(A_i, A_j, A_k)の組のA_iとして扱う場合、挿入したA_iについて選択可能な(A_j, A_k)の全ての組の数c_2A_j+A_kの総和v_2がわかれば、このA_iに対するA_i+A_j+A_kの総和は、A_i \times c_2 + v_2と計算できる。
  3. 挿入した値をA_jとして扱う場合、挿入したA_jについて選択可能なA_kの数c_1A_kの総和v_1がわかれば、このA_jに対するA_j+A_kの総和は、A_j \times c_1 + v_1と計算できる。
  4. 上記3.のc_1v_1は、BITにA_kの情報を設定しておくことで求めることができる。
  5. 上記2.のc_2v_2は、BITに上記3.の計算結果を設定しておくことで求めることができる。
具体例

サンプル2の入力値:「4, 2, 3, 1」で4を挿入時。

  • BIT1(値)の内容・・・・-, 2, 3, 1 →挿入位置より後ろの合計により、v_1=6
  • BIT1(件数)の内容・・・-, 1, 1, 1 →挿入位置より後ろの合計により、c_1=3
  • BIT2(値)の内容・・・・-, 3, 4, 0 →挿入位置より後ろの合計により、v_2=73(2,1)4(3,1)由来)
  • BIT2(件数)の内容・・・-, 1, 1, 0 →挿入位置より後ろの合計により、c_2=2
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;

public class Main {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        String[] sa = br.readLine().split(" ");
        Obj[] arr = new Obj[n];
        for (int i = 0; i < n; i++) {
            Obj o = new Obj();
            o.i = i + 1;
            o.a = Integer.parseInt(sa[i]);
            arr[i] = o;
        }
        br.close();

        // aの昇順、aが同じならiの昇順
        Arrays.sort(arr, (o1, o2) -> {
            if (o1.a != o2.a) {
                return o1.a - o2.a;
            } else {
                return o1.i - o2.i;
            }
        });

        int mod = 998244353;
        BIT bit1 = new BIT(n + 1);
        BIT bit1c = new BIT(n + 1);
        BIT bit2 = new BIT(n + 1);
        BIT bit2c = new BIT(n + 1);
        long ans = 0;
        for (int i = 0; i < n; i++) {
            Obj o = arr[i];
            // i
            long v2 = bit2.sum(n) - bit2.sum(o.i);
            long c2 = bit2c.sum(n) - bit2c.sum(o.i);
            c2 %= mod;
            long val = o.a * c2 + v2;
            ans += val;
            ans %= mod;

            // j
            long v1 = bit1.sum(n) - bit1.sum(o.i);
            long c1 = bit1c.sum(n) - bit1c.sum(o.i);
            c1 %= mod;
            long p = o.a * c1 + v1;
            bit2.add(o.i, p % mod);
            bit2c.add(o.i, c1);

            // k
            bit1.add(o.i, o.a);
            bit1c.add(o.i, 1);
        }
        System.out.println(ans);
    }

    static class Obj {
        int i, a;
    }

    // 以下ライブラリ

    static class BIT {
        int n;
        long[] arr;

        public BIT(int n) {
            this.n = n;
            arr = new long[n + 1];
        }

        void add(int idx, long val) {
            for (int i = idx; i <= n; i += i & -i) {
                arr[i] += val;
            }
        }

        long sum(int idx) {
            long sum = 0;
            for (int i = idx; i > 0; i -= i & -i) {
                sum += arr[i];
            }
            return sum;
        }
    }
}