子序列中的 k 种字母25-01-19

题目 #

题目描述 #

一个序列的子序列是指从原序列中选取部分元素,并保持它们在原序列中的相对顺序所形成的新序列。这意味着在子序列中,元素的相对顺序与它们在原序列中的相对顺序保持一致,但不一定要求是连续的。注意,原序列也可以视为自己的子序列。

现有一个长度为 n 的仅由小写字母组成的字符串 s,求 s 有多少个子序列恰好包含 k 种字母。

输入描述 #

输入仅包含一组测试数据。测试数据包含两行:

第一行包含两个整数 n 和 k(1 ≤ n ≤ 1000,1 ≤ k ≤ 26),用空格隔开,表示字符串的长度和符合要求的子序列的所需长度。

第二行是一个由小写字母组成的字符串,长度为 n。

输出描述 #

对于输入的测试数据,输出一个整数,表示计算得到的答案。

输入示例 #
6 5
eecbad
输出示例 #
3
提示信息 #

s有两个子序列"ecbad"满足要求(重复也算),同时s自身也满足要求,所以答案是3; 注意:由于答案会比较大,所以需要将答案对(10^9 + 7)取模以后再输出。

解答 #

1 最基础的做法 #

  1. 求出字符串中每一个出现的字母的频率,求这个是首先判断有多少种类的字符串,每一种字符串分别有多少个。(由于不同位置的、相同类别的字符串算是不同的子串,因此aabc的ab子串,就有两种(第一个a和第二个位置的a))
  2. 将不同种类的字母转为一个数组,准备遍历这个数组ARRAY,来实现要求:不同的K个字母的子序列。
  3. 使用dfs进行遍历ARRAY,使用一个临时变量数组记录当前的长度,当遍历长度为K时,当前分支停止,开始计算,由于每一个字符有一定的数量,例如a有3个,则其分布在 ai,aj,aka_i,a_j,a_k 三个位置,也就是三个位置都可以取,或者也可以不取,但是至少要取一个,不然就没有a这个字母的出现了,因此对于aaa来说,其全部的出现或者不出现为 232^3 ,但是要去除掉全部都不出现的情况,因此最终的结果是 2312^3-1 ,对于每一种类型的字母都是一样的。
  4. 由于指数的运算增长很快,因此需要边进行计算,边进行取模的运行。由于最后的结果是相乘然后求和,因此在过程中进行取模不会影响最终的结果。在这个过程中,使用快速幂的方式加速求解。快速幂是一种高效计算 abmodca^b \mod c 的算法,通过二分法减少乘法运算的次数,时间复杂度为 O(logb)O(\log b)

快速幂的计算

public static long pow(long base,int exp,int MOD){
  long result=1;
  while(exp>0){
    if(exp%2==1){
      result=(result*base)%MOD;
		}
    base=(base*base)%MOD
    exp=exp/2;
  }
  return result;
}

完整的代码实现如下:

import java.io.*;
import java.util.*;
import java.math.BigInteger;
 
public class Main{
    static int MOD=1000000007;
    static long res=0;
    static int oriK;
    public static void dfs(int index,int[] cnt,List<Character> list,List<Character> path){
        if(path.size()==oriK){
            long temp=1;
            for(char ch:path){
                temp = (temp * (lowmod(2, cnt[ch-'a'], MOD) - 1)) % MOD;
            }
            res=(int)((res+temp)%MOD);
            return;
        }
        for(int i=index;i<list.size();i++){
            path.add(list.get(i));
            dfs(i+1,cnt,list,path);
            path.remove(path.size()-1);
        }
    }
    public static long lowmod(long base,int exp,int mod){
        long tmp=1;
        while(exp>0){
            if(exp%2==1){
                tmp=(tmp*base)%mod;
            }
            base=(base*base)%mod;
            exp/=2;
        }
        return tmp;
    }
    public static void main(String[] args) throws IOException{
        BufferedReader in=new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out=new PrintWriter(System.out);
        String line=in.readLine();
        String[] tempArr=line.split(" ");
        int n=Integer.parseInt(tempArr[0]);
        int k=Integer.parseInt(tempArr[1]);
        oriK=k;
        line=in.readLine();
        HashMap<Character,Integer> map=new HashMap<>();
        int[] cnt=new int[26];
        for(char ch:line.toCharArray()){
            map.put(ch,map.getOrDefault(ch,0)+1);
            cnt[ch-'a']++;
        }
        List<Character> list=new ArrayList<>(map.keySet());
        dfs(0,cnt,list,new ArrayList<Character>());
        out.println(res);
        out.flush();
    }
}
/**************************************************************
    Problem: 1028
    User: odCYZ6plxwtdtFD_XTmXoIzEjmnQ [kamaCoder97850]
    Language: Java
    Result: 时间超限
****************************************************************/

该方法超过时间限制的原因是还没对dfs的递归过程进行优化,只是考虑到了MOD的影响因素。

2 优化dfs #

在这段dfs的代码中

public static void dfs(int index,int[] cnt,List<Character> list,List<Character> path){
        if(path.size()==oriK){
            long temp=1;
            for(char ch:path){
                temp = (temp * (lowmod(2, cnt[ch-'a'], MOD) - 1)) % MOD;
            }
            res=(int)((res+temp)%MOD);
            return;
        }
        for(int i=index;i<list.size();i++){
            path.add(list.get(i));
            dfs(i+1,cnt,list,path);
            path.remove(path.size()-1);
        }
    }

有几个点可以优化:

  1. 现在是等到最后才开始计算全部的内容,这个会导致开销比较大,可以在dfs的过程中直接进行计算,这样最后就没有计算开销了
  2. 在for循环中,当前只考虑了一个限制,就是i<list.size(),但其实在计算到数组靠后位置时,如果前面不选的字母过多的话,后面其实字母不够用,直接删除剪枝。
    public static void dfs(int index,int[] cnt,long tmp,int deep){
        if(deep>=oriK){
            res=(int)((res+tmp)%MOD);
            return;
        }
        deep++;
        for(int i=index;i<list.size()&&list.size()-i>oriK-deep;i++){
            long t=tmp;
            tmp=(tmp*(lowmod(2L,cnt[list.get(i)-'a'],MOD)-1)%MOD)%MOD;
            dfs(i+1,cnt,tmp,deep);
            tmp=t;
            // dfs(i+1,cnt,tmp,chooseNum);
        }
    }

在这段代码中,dfs(i+1,cnt,tmp,deep);只需要在for循环中进行一遍即可,只需要遍历选择当前字母list[i]的情况,如果不选择当前的字母,在进行下一次循环的过程中,相当于是进行了一个没有选择该字母的情况,相当于是原始的tmp进行循环。

最终的代码实现:

import java.io.*;
import java.util.*;
import java.math.BigInteger;
 
public class Main{
    static int MOD=1000000007;
    static long res=0;
    static int oriK;
    static List<Character> list;
    public static void dfs(int index,int[] cnt,long tmp,int deep){
        if(deep>=oriK){
            res=(int)((res+tmp)%MOD);
            return;
        }
        deep++;
        for(int i=index;i<list.size()&&list.size()-i>oriK-deep;i++){
            long t=tmp;
            tmp=(tmp*(lowmod(2L,cnt[list.get(i)-'a'],MOD)-1)%MOD)%MOD;
            dfs(i+1,cnt,tmp,deep);
            tmp=t;
            // dfs(i+1,cnt,tmp,chooseNum);
        }
    }
    public static long lowmod(long base,int exp,int mod){
        long tmp=1;
        while(exp>0){
            if(exp%2==1){
                tmp=(tmp*base)%mod;
            }
            base=(base*base)%mod;
            exp/=2;
        }
        return tmp;
    }
    public static void main(String[] args) throws IOException{
        BufferedReader in=new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out=new PrintWriter(System.out);
        String line=in.readLine();
        String[] tempArr=line.split(" ");
        int n=Integer.parseInt(tempArr[0]);
        int k=Integer.parseInt(tempArr[1]);
        oriK=k;
        line=in.readLine();
        HashMap<Character,Integer> map=new HashMap<>();
        int[] cnt=new int[26];
        for(char ch:line.toCharArray()){
            map.put(ch,map.getOrDefault(ch,0)+1);
            cnt[ch-'a']++;
        }
        list=new ArrayList<>(map.keySet());
        dfs(0,cnt,1L,0);
        out.println(res);
        out.flush();
    }
}