import java.util.*;

/**
 * Reads a test-case count {@code T}, then {@code T} pairs of longs {@code n m}, and prints
 * {@link #solve(long, long)} for each pair on its own line.
 */
public class Main {

    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int testCount = in.nextInt();
            // Collect all answers and print once; a println per case is slow when T is large.
            StringBuilder out = new StringBuilder();
            for (int cas = 0; cas < testCount; cas++) {
                long n = in.nextLong();
                long m = in.nextLong();
                out.append(solve(n, m)).append(System.lineSeparator());
            }
            System.out.print(out);
        }
    }

    /**
     * Computes the answer for one {@code (n, m)} test case.
     *
     * <p>If {@code m >= n / 2} (integer division) the result is {@code n(n-1)/2}. Otherwise it is
     * the sum over {@code k = 1..m} of {@code 2(n - 2k) + 1}, which in closed form is
     * {@code m(2n - 2m - 1)}. The two branches agree at {@code m = n / 2}.
     *
     * <p>NOTE(review): this formula matches the maximum number of inversions reachable from the
     * identity permutation of size {@code n} with at most {@code m} swaps (swap positions
     * {@code k} and {@code n+1-k}). That reading is inferred from the math, not stated in this
     * file; confirm against the problem statement.
     *
     * @param n the size parameter (presumably non-negative; not validated here)
     * @param m the operation budget (presumably non-negative; not validated here)
     * @return the computed value, exact whenever it fits in a {@code long}
     */
    static long solve(long n, long m) {
        if (m >= n / 2) {
            return pairCount(n);
        }
        // Closed form of sum_{k=1..m} (2(n - 2k) + 1); equal to m*(2n-3) - 2m(m-1).
        return m * (2 * n - 2 * m - 1);
    }

    /**
     * Returns {@code n(n-1)/2}.
     *
     * <p>The even factor is halved before multiplying, so the product is the exact result.
     * Computing {@code n * (n - 1)} first would overflow for n around 3.04e9 and above, even
     * though the halved value still fits.
     */
    private static long pairCount(long n) {
        return (n % 2 == 0) ? (n / 2) * (n - 1) : n * ((n - 1) / 2);
    }
}