题目 #
题目描述 #
一个序列的子序列是指从原序列中选取部分元素,并保持它们在原序列中的相对顺序所形成的新序列。这意味着在子序列中,元素的相对顺序与它们在原序列中的相对顺序保持一致,但不一定要求是连续的。注意,原序列也可以视为自己的子序列。
现有一个长度为 n 的仅由小写字母组成的字符串 s,求 s 有多少个子序列恰好包含 k 种字母。
输入描述 #
输入仅包含一组测试数据。测试数据包含两行:
第一行包含两个整数 n 和 k(1 ≤ n ≤ 1000,1 ≤ k ≤ 26),用空格隔开,表示字符串的长度和符合要求的子序列的所需长度。
第二行是一个由小写字母组成的字符串,长度为 n。
输出描述 #
对于输入的测试数据,输出一个整数,表示计算得到的答案。
输入示例 #
6 5
eecbad
输出示例 #
3
提示信息 #
s有两个子序列"ecbad"满足要求(重复也算),同时s自身也满足要求,所以答案是3; 注意:由于答案会比较大,所以需要将答案对(10^9 + 7)取模以后再输出。
解答 #
1 最基础的做法 #
- 求出字符串中每一个出现的字母的频率,求这个是首先判断有多少种类的字符串,每一种字符串分别有多少个。(由于不同位置的、相同类别的字符串算是不同的子串,因此
aabc的ab子串,就有两种(第一个a和第二个位置的a)) - 将不同种类的字母转为一个数组,准备遍历这个数组ARRAY,来实现要求:不同的K个字母的子序列。
- 使用dfs进行遍历ARRAY,使用一个临时变量数组记录当前的长度,当遍历长度为K时,当前分支停止,开始计算,由于每一个字符有一定的数量,例如a有3个,则其分布在 三个位置,也就是三个位置都可以取,或者也可以不取,但是至少要取一个,不然就没有a这个字母的出现了,因此对于aaa来说,其全部的出现或者不出现为 ,但是要去除掉全部都不出现的情况,因此最终的结果是 ,对于每一种类型的字母都是一样的。
- 由于指数的运算增长很快,因此需要边进行计算,边进行取模的运行。由于最后的结果是相乘然后求和,因此在过程中进行取模不会影响最终的结果。在这个过程中,使用快速幂的方式加速求解。快速幂是一种高效计算 的算法,通过二分法减少乘法运算的次数,时间复杂度为 。
快速幂的计算
public static long pow(long base,int exp,int MOD){
long result=1;
while(exp>0){
if(exp%2==1){
result=(result*base)%MOD;
}
base=(base*base)%MOD
exp=exp/2;
}
return result;
}
完整的代码实现如下:
import java.io.*;
import java.util.*;
import java.math.BigInteger;
public class Main{
static int MOD=1000000007;
static long res=0;
static int oriK;
public static void dfs(int index,int[] cnt,List<Character> list,List<Character> path){
if(path.size()==oriK){
long temp=1;
for(char ch:path){
temp = (temp * (lowmod(2, cnt[ch-'a'], MOD) - 1)) % MOD;
}
res=(int)((res+temp)%MOD);
return;
}
for(int i=index;i<list.size();i++){
path.add(list.get(i));
dfs(i+1,cnt,list,path);
path.remove(path.size()-1);
}
}
public static long lowmod(long base,int exp,int mod){
long tmp=1;
while(exp>0){
if(exp%2==1){
tmp=(tmp*base)%mod;
}
base=(base*base)%mod;
exp/=2;
}
return tmp;
}
public static void main(String[] args) throws IOException{
BufferedReader in=new BufferedReader(new InputStreamReader(System.in));
PrintWriter out=new PrintWriter(System.out);
String line=in.readLine();
String[] tempArr=line.split(" ");
int n=Integer.parseInt(tempArr[0]);
int k=Integer.parseInt(tempArr[1]);
oriK=k;
line=in.readLine();
HashMap<Character,Integer> map=new HashMap<>();
int[] cnt=new int[26];
for(char ch:line.toCharArray()){
map.put(ch,map.getOrDefault(ch,0)+1);
cnt[ch-'a']++;
}
List<Character> list=new ArrayList<>(map.keySet());
dfs(0,cnt,list,new ArrayList<Character>());
out.println(res);
out.flush();
}
}
/**************************************************************
Problem: 1028
User: odCYZ6plxwtdtFD_XTmXoIzEjmnQ [kamaCoder97850]
Language: Java
Result: 时间超限
****************************************************************/
该方法超过时间限制的原因是还没对dfs的递归过程进行优化,只是考虑到了MOD的影响因素。
2 优化dfs #
在这段dfs的代码中
public static void dfs(int index,int[] cnt,List<Character> list,List<Character> path){
if(path.size()==oriK){
long temp=1;
for(char ch:path){
temp = (temp * (lowmod(2, cnt[ch-'a'], MOD) - 1)) % MOD;
}
res=(int)((res+temp)%MOD);
return;
}
for(int i=index;i<list.size();i++){
path.add(list.get(i));
dfs(i+1,cnt,list,path);
path.remove(path.size()-1);
}
}
有几个点可以优化:
- 现在是等到最后才开始计算全部的内容,这个会导致开销比较大,可以在dfs的过程中直接进行计算,这样最后就没有计算开销了
- 在for循环中,当前只考虑了一个限制,就是
i<list.size(),但其实在计算到数组靠后位置时,如果前面不选的字母过多的话,后面其实字母不够用,直接删除剪枝。
public static void dfs(int index,int[] cnt,long tmp,int deep){
if(deep>=oriK){
res=(int)((res+tmp)%MOD);
return;
}
deep++;
for(int i=index;i<list.size()&&list.size()-i>oriK-deep;i++){
long t=tmp;
tmp=(tmp*(lowmod(2L,cnt[list.get(i)-'a'],MOD)-1)%MOD)%MOD;
dfs(i+1,cnt,tmp,deep);
tmp=t;
// dfs(i+1,cnt,tmp,chooseNum);
}
}
在这段代码中,dfs(i+1,cnt,tmp,deep);只需要在for循环中进行一遍即可,只需要遍历选择当前字母list[i]的情况,如果不选择当前的字母,在进行下一次循环的过程中,相当于是进行了一个没有选择该字母的情况,相当于是原始的tmp进行循环。
最终的代码实现:
import java.io.*;
import java.util.*;
import java.math.BigInteger;
public class Main{
static int MOD=1000000007;
static long res=0;
static int oriK;
static List<Character> list;
public static void dfs(int index,int[] cnt,long tmp,int deep){
if(deep>=oriK){
res=(int)((res+tmp)%MOD);
return;
}
deep++;
for(int i=index;i<list.size()&&list.size()-i>oriK-deep;i++){
long t=tmp;
tmp=(tmp*(lowmod(2L,cnt[list.get(i)-'a'],MOD)-1)%MOD)%MOD;
dfs(i+1,cnt,tmp,deep);
tmp=t;
// dfs(i+1,cnt,tmp,chooseNum);
}
}
public static long lowmod(long base,int exp,int mod){
long tmp=1;
while(exp>0){
if(exp%2==1){
tmp=(tmp*base)%mod;
}
base=(base*base)%mod;
exp/=2;
}
return tmp;
}
public static void main(String[] args) throws IOException{
BufferedReader in=new BufferedReader(new InputStreamReader(System.in));
PrintWriter out=new PrintWriter(System.out);
String line=in.readLine();
String[] tempArr=line.split(" ");
int n=Integer.parseInt(tempArr[0]);
int k=Integer.parseInt(tempArr[1]);
oriK=k;
line=in.readLine();
HashMap<Character,Integer> map=new HashMap<>();
int[] cnt=new int[26];
for(char ch:line.toCharArray()){
map.put(ch,map.getOrDefault(ch,0)+1);
cnt[ch-'a']++;
}
list=new ArrayList<>(map.keySet());
dfs(0,cnt,1L,0);
out.println(res);
out.flush();
}
}