2179. Count Good Triplets in an Array
You are given two 0-indexed arrays nums1 and nums2 of length n, both of which are permutations of [0, 1, ..., n - 1]. A good triplet is a set of 3 distinct values that appear in increasing order by position in both nums1 and nums2. In other words, if pos1_v is the index of the value v in nums1 and pos2_v is the index of v in nums2, then a good triplet is a set (x, y, z) where 0 <= x, y, z <= n - 1, such that pos1_x < pos1_y < pos1_z and pos2_x < pos2_y < pos2_z.
Return the total number of good triplets.
Example 1:
Input: nums1 = [2,0,1,3], nums2 = [0,1,2,3]
Output: 1
Explanation:
There are 4 triplets (x, y, z) such that pos1_x < pos1_y < pos1_z. They are (2,0,1), (2,0,3), (2,1,3), and (0,1,3).
Of those triplets, only (0,1,3) satisfies pos2_x < pos2_y < pos2_z. Hence, there is only 1 good triplet.
Example 2:
Input: nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
Output: 4
Explanation: The 4 good triplets are (4,0,3), (4,0,2), (4,1,3), and (4,1,2).
- There are multiple approaches to this problem. Let us explore them in order of difficulty.
1. Brute Force:
- Enumerate every index triple i < j < k in nums1 (so pos1_x < pos1_y < pos1_z holds automatically) and check whether the same three values also satisfy pos2_x < pos2_y < pos2_z.
- The TC of this solution is O(n^3). Generating and storing all triplets from both lists and then intersecting them also costs O(n^3) space, but checking each triplet on the fly with a position map for nums2 needs only O(n) extra space.
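A minimal sketch of the brute force that checks each triplet as it is generated instead of storing them, so it needs only O(n) extra space for the nums2 position map. It assumes both inputs are permutations of [0, n - 1]; the function name good_triplets_brute_force is hypothetical.

```python
from itertools import combinations
from typing import List


def good_triplets_brute_force(nums1: List[int], nums2: List[int]) -> int:
    """Count good triplets by checking every index triple of nums1 (O(n^3))."""
    n = len(nums1)
    # pos2[v] = index of value v in nums2 (valid because nums2 is a permutation)
    pos2 = [0] * n
    for idx, v in enumerate(nums2):
        pos2[v] = idx
    count = 0
    # combinations() yields index triples with i < j < k, so the nums1 order holds
    for i, j, k in combinations(range(n), 3):
        if pos2[nums1[i]] < pos2[nums1[j]] < pos2[nums1[k]]:
            count += 1
    return count


print(good_triplets_brute_force([2, 0, 1, 3], [0, 1, 2, 3]))        # 1
print(good_triplets_brute_force([4, 0, 1, 3, 2], [4, 1, 0, 2, 3]))  # 4
```

This is mainly useful as a reference to test the faster approaches against on small inputs.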
2. Position Map + Sorted List (Binary Search)
- We need to follow these steps -
- Build a position map for nums2 so we can easily look up where each value from nums1 appears in nums2.
- Loop through each element nums1[i], and for each:
- Find its position in nums2
- Count how many elements we’ve already seen (before i) that appear before it in nums2.
- This is done using bisect.bisect() on a sorted list of seen positions (seen_positions)
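- Count how many elements that come after i in nums1 also come after it in nums2. Of the n - 1 - pos_in_nums2 elements that follow it in nums2, i - count_before have already been seen (they lie before i in nums1), so count_after = (n - 1 - pos_in_nums2) - (i - count_before).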
- This ensures the triplet will be increasing in both arrays
- Multiply count_before * count_after to get the number of triplets where nums1[i] is in the middle.
- Insert the current position (from nums2) into seen_positions for the next round.
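To make these steps concrete, here is a hypothetical helper, middle_contributions, that reports (value, count_before, count_after) for each element of nums1. It assumes permutations of [0, n - 1] and computes count_after as (n - 1 - pos) - (i - count_before): of the n - 1 - pos elements that follow pos in nums2, the i - count_before already-seen ones are excluded.

```python
import bisect
from typing import List, Tuple


def middle_contributions(nums1: List[int], nums2: List[int]) -> List[Tuple[int, int, int]]:
    """For each nums1[i], return (value, count_before, count_after)."""
    n = len(nums1)
    index_in_nums2 = [0] * n
    for idx, v in enumerate(nums2):
        index_in_nums2[v] = idx
    seen_positions: List[int] = []  # sorted nums2-positions of nums1[0..i-1]
    rows = []
    for i, v in enumerate(nums1):
        pos = index_in_nums2[v]
        # Seen elements (before i in nums1) that also sit before pos in nums2
        count_before = bisect.bisect(seen_positions, pos)
        # Elements after pos in nums2, minus the ones already seen before i
        count_after = (n - 1 - pos) - (i - count_before)
        bisect.insort(seen_positions, pos)
        rows.append((v, count_before, count_after))
    return rows


rows = middle_contributions([2, 0, 1, 3], [0, 1, 2, 3])
print(rows)                                # [(2, 0, 1), (0, 0, 2), (1, 1, 1), (3, 3, 0)]
print(sum(b * a for _, b, a in rows))      # 1
```

Only the value 1 has elements on both sides (0 before it, 3 after it), which is why Example 1 has exactly one good triplet.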
from typing import List
import bisect
class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        res = 0
        # Map each value in nums2 to its index
        index_in_nums2 = [0] * n
        for i in range(n):
            index_in_nums2[nums2[i]] = i
        seen_positions = []  # Sorted positions (in nums2) of the elements of nums1 seen so far
        for i in range(n):
            # Find the corresponding index of nums1[i] in nums2
            pos_in_nums2 = index_in_nums2[nums1[i]]
            # Count seen elements whose nums2 position is smaller (binary search)
            count_before = bisect.bisect(seen_positions, pos_in_nums2)
            # Insert current position in sorted order
            bisect.insort(seen_positions, pos_in_nums2)
            # Elements after pos_in_nums2 in nums2 that are not yet seen,
            # i.e. they also come after i in nums1
            count_after = (n - 1 - pos_in_nums2) - (i - count_before)
            # Multiply the number of valid 'before' and 'after' elements
            res += count_before * count_after
        return res
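TC is O(n^2) in the worst case: each bisect.bisect call is O(log n), but bisect.insort into a Python list has to shift the elements after the insertion point, which is O(n) per insertion (the shift is a fast memory move, so this is quick in practice)
SC is O(n)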
3. Best approach: Fenwick tree (Binary Indexed Tree)
1. Goal
Count index triples i < j < k in nums1 such that the values nums1[i], nums1[j], nums1[k] also appear in this same order in nums2
2. Key Idea
Map the problem to:
Count increasing triplets in a transformed array
(i.e., after mapping nums1's values to their positions in nums2)
arr[i] = position of nums1[i] in nums2
Then, count triplets where:
arr[i] < arr[j] < arr[k] and i < j < k
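A small sketch of this reduction on Example 2, assuming both inputs are permutations of [0, n - 1]. The helper names to_positions and count_increasing_triplets are hypothetical, and the triplet counter is an O(n^3) reference used only to show that the transformed count matches the expected answer.

```python
from itertools import combinations
from typing import List


def to_positions(nums1: List[int], nums2: List[int]) -> List[int]:
    """arr[i] = index of nums1[i] in nums2."""
    pos_in_nums2 = [0] * len(nums2)
    for idx, v in enumerate(nums2):
        pos_in_nums2[v] = idx
    return [pos_in_nums2[v] for v in nums1]


def count_increasing_triplets(arr: List[int]) -> int:
    """Count i < j < k with arr[i] < arr[j] < arr[k] by brute force."""
    # combinations() keeps elements in index order, so (a, b, c) has i < j < k
    return sum(1 for a, b, c in combinations(arr, 3) if a < b < c)


arr = to_positions([4, 0, 1, 3, 2], [4, 1, 0, 2, 3])
print(arr)                             # [0, 2, 1, 4, 3]
print(count_increasing_triplets(arr))  # 4
```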
- We use a Fenwick Tree (Binary Indexed Tree) to efficiently:
- Count how many numbers less than arr[i] have appeared so far (to the left)
- Deduce how many numbers greater than arr[i] will appear later (to the right)
- For each element at index i in arr:
- left = count of numbers < arr[i] before i → use tree.query(arr[i] - 1)
- right = count of numbers > arr[i] after i
- → values greater than arr[i] in the whole array = n - 1 - arr[i] (arr is a permutation of 0..n-1)
- → of those, already seen to the left of i = i - left
- → so right = (n - 1 - arr[i]) - (i - left)
- Total triplets contributed by arr[i] as the middle element: res += left * right
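The derivation of right can be sanity-checked without a Fenwick tree. This O(n^2) sketch (the helper name left_right_counts is hypothetical) counts left directly, derives right with the formula above, and asserts that it equals a direct count of larger values to the right. It assumes arr is a permutation of [0, n - 1], here the transformed array from Example 2.

```python
from typing import List, Tuple


def left_right_counts(arr: List[int]) -> List[Tuple[int, int]]:
    """Return (left, right) for each index, checking the right formula."""
    n = len(arr)
    out = []
    for i, v in enumerate(arr):
        # left: smaller values to the left of i
        left = sum(1 for j in range(i) if arr[j] < v)
        # right via the formula: all larger values minus those already seen
        right_formula = (n - 1 - v) - (i - left)
        # right by direct counting, to confirm the formula
        right_direct = sum(1 for k in range(i + 1, n) if arr[k] > v)
        assert right_formula == right_direct
        out.append((left, right_formula))
    return out


pairs = left_right_counts([0, 2, 1, 4, 3])
print(pairs)                                # [(0, 4), (1, 2), (1, 2), (3, 0), (3, 0)]
print(sum(lft * rgt for lft, rgt in pairs)) # 4
```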
from typing import List
class FenwickTree:
    def __init__(self, size):
        self.tree = [0] * (size + 1)
    def update(self, index, delta):
        index += 1  # 1-based indexing
        while index < len(self.tree):
            self.tree[index] += delta
            index += index & -index
    def query(self, index):
        # Prefix sum over positions 0..index (returns 0 when index is -1)
        index += 1  # 1-based indexing
        res = 0
        while index > 0:
            res += self.tree[index]
            index -= index & -index
        return res
class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        # Map each number to its index in nums2
        pos_in_nums2 = [0] * n
        for i, val in enumerate(nums2):
            pos_in_nums2[val] = i
        # Create array where nums1[i] is mapped to its position in nums2
        arr = [0] * n
        for i, val in enumerate(nums1):
            arr[i] = pos_in_nums2[val]
        tree = FenwickTree(n)
        res = 0
        for i, val in enumerate(arr):
            left = tree.query(val - 1)  # Count of elements < val so far
            right = (n - 1 - val) - (i - left)  # Remaining elements > val to the right
            res += left * right
            tree.update(val, 1)  # Mark val as seen
        return res
TC is O(n log n)
SC is O(n)
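One way to gain confidence in the Fenwick version is to cross-check it against brute force on random permutations. This sketch is self-contained, so it re-implements both: brute_force and fenwick_count are hypothetical names, and fenwick_count inlines the same prefix-sum and update loops as the FenwickTree class instead of importing it.

```python
import random
from itertools import combinations
from typing import List


def brute_force(nums1: List[int], nums2: List[int]) -> int:
    pos2 = {v: i for i, v in enumerate(nums2)}
    # combinations() keeps nums1 order, so only the nums2 order needs checking
    return sum(1 for a, b, c in combinations(nums1, 3) if pos2[a] < pos2[b] < pos2[c])


def fenwick_count(nums1: List[int], nums2: List[int]) -> int:
    n = len(nums1)
    pos2 = [0] * n
    for i, v in enumerate(nums2):
        pos2[v] = i
    tree = [0] * (n + 1)  # 1-based Fenwick array over positions 0..n-1
    res = 0
    for i, v in enumerate(nums1):
        p = pos2[v]
        # left = number of seen positions in 0..p-1 (tree indices 1..p)
        left, idx = 0, p
        while idx > 0:
            left += tree[idx]
            idx -= idx & -idx
        right = (n - 1 - p) - (i - left)
        res += left * right
        # mark position p as seen (tree index p + 1)
        idx = p + 1
        while idx <= n:
            tree[idx] += 1
            idx += idx & -idx
    return res


random.seed(0)
for _ in range(200):
    n = random.randint(3, 8)
    a = random.sample(range(n), n)  # random permutation of 0..n-1
    b = random.sample(range(n), n)
    assert fenwick_count(a, b) == brute_force(a, b)
print("all random tests passed")
```

Small sizes keep the O(n^3) reference fast while still exercising many orderings.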