#include <cstdio>

// Evaluates the closed-form answer for one query (n, m).
//
// If m >= (n + 1) / 2 (i.e. m is at least ceil(n / 2) for non-negative n),
// the result is n * (n - 1) / 2. Otherwise it is
//     m * (m - 1) + 2 * m * (n - 2 * m) + m * m,
// which is algebraically equal to m * (2 * n - 2 * m - 1).
//
// NOTE(review): this formula matches "maximum number of inversions obtainable
// in a sequence of length n using at most m swaps" -- confirm against the
// original problem statement.
// NOTE(review): all arithmetic is done in long long; the products overflow
// once n exceeds roughly 3e9 -- assumes the input limits stay below that.
static long long compute_answer(long long n, long long m) {
    if (m >= (n + 1) / 2) {
        return n * (n - 1) / 2;
    }
    return m * (m - 1) + 2 * m * (n - 2 * m) + m * m;
}

// Reads one test case ("n m") from stdin and prints its answer on its own line.
// If the two integers cannot be read, nothing is printed; this avoids
// evaluating the formula on uninitialized values.
void solve() {
    long long n = 0;
    long long m = 0;
    if (std::scanf("%lld %lld", &n, &m) != 2) {
        return;
    }
    std::printf("%lld\n", compute_answer(n, m));
}

// Reads the number of test cases T, then handles T test cases in order.
// A missing/unparsable T, or a non-positive T, results in no output.
int main() {
    int T = 0;
    if (std::scanf("%d", &T) != 1) {
        return 0;
    }
    while (T-- > 0) {
        solve();
    }
    return 0;
}