2518. Number of Great Partitions
Description
You are given an array `nums` consisting of positive integers and an integer `k`.
Partition the array into two ordered groups such that each element is in exactly one group. A partition is called great if the sum of the elements of each group is greater than or equal to `k`.
Return the number of distinct great partitions. Since the answer may be very large, return it modulo 10^9 + 7.
Two partitions are considered distinct if some element `nums[i]` is in different groups in the two partitions.
Example 1:
Input: nums = [1,2,3,4], k = 4 Output: 6 Explanation: The great partitions are: ([1,2,3], [4]), ([1,3], [2,4]), ([1,4], [2,3]), ([2,3], [1,4]), ([2,4], [1,3]) and ([4], [1,2,3]).
Example 2:
Input: nums = [3,3,3], k = 4 Output: 0 Explanation: There are no great partitions for this array.
Example 3:
Input: nums = [6,6], k = 2 Output: 2 Explanation: We can either put nums[0] in the first partition or in the second partition. The great partitions will be ([6], [6]) and ([6], [6]).
Constraints:
1 <= nums.length, k <= 1000
1 <= nums[i] <= 10^9
Solution
number-of-great-partitions.py
class Solution:
    def countPartitions(self, nums: List[int], k: int) -> int:
        """Count ordered two-group partitions in which each group sums to >= k.

        Complement counting: a partition fails only if group 1 or group 2
        sums to less than k. When sum(nums) >= 2k those two events are
        disjoint (both groups cannot be small at once), so

            answer = 2^n - 2 * (#subsets with sum < k)   (mod 10^9 + 7).

        The subset count is computed with an iterative 0/1 knapsack over the
        sums 0..k-1. This avoids a top-down recursion that is ~n frames deep,
        which raises RecursionError for n = 1000 under CPython's default
        recursion limit.

        Args:
            nums: the array of positive integers to split.
            k: the minimum sum each group must reach.

        Returns:
            The number of great partitions modulo 10^9 + 7.
        """
        M = 10 ** 9 + 7
        n = len(nums)
        if k <= 0:
            # Any group (even an empty one) reaches a non-positive threshold.
            return pow(2, n, M)
        if sum(nums) < 2 * k:
            # Two disjoint groups each summing to >= k need a total >= 2k.
            return 0
        # ways[s] = number of subsets of the elements seen so far whose sum is exactly s (s < k).
        ways = [0] * k
        ways[0] = 1
        for x in nums:
            # Walk sums downward so each element is used at most once; an
            # element >= k can never keep a subset below k, so its range is empty.
            for s in range(k - 1, x - 1, -1):
                ways[s] = (ways[s] + ways[s - x]) % M
        small = sum(ways) % M
        # Python's % is non-negative for a positive modulus, so no "+ M" is needed.
        return (pow(2, n, M) - 2 * small) % M
number-of-great-partitions.cpp
class Solution {
public:
    // Number of elements in the current input (set by countPartitions).
    int N;
    // Minimum required sum of each group (set by countPartitions).
    int K;
    // Modulus required by the problem statement.
    int MOD = 1000000007;
    // cache[index][s] memoises dp(nums, index, s); -1 means "not computed yet".
    // Sized for the problem constraints nums.length <= 1000 and k <= 1000.
    int cache[1001][1001];

    /// Computes a^b mod k by binary exponentiation.
    /// @param a base (reduced into [0, k), so negative values are accepted)
    /// @param b exponent; b <= 0 yields 1 % k
    /// @param k modulus, must be positive
    /// @return a^b mod k, in the range [0, k)
    int powmod(int a, int b, int k) {
        // 64-bit intermediates: the product of two residues below k can exceed INT_MAX.
        long long result = 1 % k;
        long long base = ((a % k) + k) % k;
        while (b > 0) {
            if (b & 1) result = result * base % k;
            base = base * base % k;
            b >>= 1;
        }
        return static_cast<int>(result);
    }

    /// Counts subsets of nums[index..N) that keep the running sum s strictly below K.
    /// @param nums the input array
    /// @param index first element still to be decided
    /// @param s sum of the elements chosen so far
    /// @return number of such subsets, modulo MOD
    int dp(vector<int>& nums, int index, int s) {
        // Check the sum before touching the memo so that s is only ever used as an
        // index while s < K (<= 1000), keeping cache accesses in bounds.
        if (s >= K) return 0;
        if (index >= N) return 1;
        int& memo = cache[index][s];
        if (memo != -1) return memo;
        const int skip = dp(nums, index + 1, s);
        // nums[index] may be up to 1e9: compare in 64 bits. Once the sum reaches K,
        // no extension can be "small", so that branch contributes 0.
        const int take = static_cast<long long>(s) + nums[index] < K
                             ? dp(nums, index + 1, s + nums[index])
                             : 0;
        return memo = static_cast<int>((static_cast<long long>(skip) + take) % MOD);
    }

    /// Counts ordered two-group partitions in which each group sums to >= k.
    /// Uses complement counting: answer = 2^N - 2 * (#subsets with sum < k), which
    /// is valid because, when total >= 2k, the two groups cannot both be small.
    /// @param nums positive integers to split
    /// @param k minimum sum each group must reach
    /// @return number of great partitions modulo 1e9 + 7
    int countPartitions(vector<int>& nums, int k) {
        const long long total = accumulate(nums.begin(), nums.end(), 0LL);
        // Two disjoint groups each summing to >= k need a total >= 2k.
        if (total < 2LL * k) return 0;
        memset(cache, -1, sizeof(cache));
        N = static_cast<int>(nums.size());
        K = k;
        const long long all = powmod(2, N, MOD);
        // Reduce modulo MOD *before* subtracting: 2 * dp can exceed all + MOD, and
        // in C++ a negative left operand of % produces a negative result.
        const long long invalid = 2LL * dp(nums, 0, 0) % MOD;
        return static_cast<int>((all - invalid + MOD) % MOD);
    }
};