数位dp踩坑

2022-10-17,,

前言


数位dp是什么?以前总觉得这个概念很高大上,最近闲的没事,学了一下发现确实挺神奇的。

从一道简单题说起


一个数字,如果包含'4'或者'62',它是不吉利的。给定m,n,0<m≤n<10^6,统计[m,n]范围内吉利数的个数。

这题的数据范围比较小,只有1e6,理论上暴力也是可以解的。但数位dp的题目数据范围通常很大,往往达到1e18甚至更大,暴力法o(n)显然会tle。这个时候需要一种时间复杂度近似o(logn)的算法。仔细思考一下,其实可以使用排除法。从高位到低位依次排除0~1e6中不符合条件的数。

举个例子,1~999999中不包含4的数,步骤如下:
1.先排除6位数中最高位是4的数,即400000~499999,只需要判断最高位,就排除了10万个数。
2.接着排除次高位是4的数(此时默认最高位不是4),比如最高位是1时可以排除140000~149999,共1万个。注意,首位可以是0,此时相当于考虑000000~099999(初学者可能有点不理解,看完后面一些不考虑前导0的题就懂了)。
3.同理,继续排除4位数,3位数,直到结束。

数位dp 是对数字的“位”进行的和计数有关的dp,数的每一个位称为数位,一个数有个位、十位、百位、千位……数位dp用来解决和数字操作有关的问题,比如某区间内数字和,特定数字问题等。这些问题的数字范围通常很大,无法暴力解决,必须用接近o(logn)的算法。通过dp对“数位”操作,记录算过的区间的状态,用于后续计算快速筛选数字。

那么这道题用数位dp怎么实现呢?数位dp一般有两种方法,一种是dp预处理乱搞,另一种是记忆化dfs。前者不适合模板化,而后者效率高,模板易记,可以快速上手。

数位dp模板


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[20][20];
int a[20];
ll n,m;
// pos:当前到第几位 pre:上一位数字是什么 lead:是否有前导0 limit:是否有上界限制 (参数通常有pos和limit,是否有lead看题目要求,此外还有可能增加其他参数)
ll dfs(int pos,int pre,bool lead,bool limit){
    if(pos == 0) return 1; //搜完了,返回1.这里通常是返回1,但有些题有限制条件,返回时要判断是否符合限制条件.
    if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre]; //记忆化
    int up = limit? a[pos] : 9;  //是否有上界,如果有上界那么不能超过上界.
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]);
        else{
            if() continue;
            ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]);
        }
    }
    if(!lead && !limit) dp[pos][pre] = ans;
    return ans;
}
ll solve(ll x){
    int len = 0;
    while(x){   //拆位,a[i]表示这个数的第i位
        a[++len] = x % 10;
        x /= 10;
    }
    memset(dp,-1,sizeof(dp)); //初始化dp数组
    return dfs(len,0,true,true); //从高位向低位枚举
}
int main()
{
    cin >> n >> m;
    cout << solve(m) - solve(n - 1) << endl;
    return 0;
}

这里解答几个初学者常见的疑问

  • 为什么dfs的时候要控制上界(limit)?
    举一个简单的例子,我们要求这个区间[25,628]内满足某种条件的数的个数,这样我们枚举的数肯定不能超过628,因为628是上界。我们从高位往低位枚举,百位的所有可能是0,1,2,3,4,5,6。当百位枚举了1,十位可以枚举0-9,此时相当于十位的枚举没有限制(最大可以到9)。但是,如果百位枚举了6,那么十位只能枚举0-2,因为不能超过628。同理,如果百位枚举了6,十位枚举了2,那么个位只能枚举0-8了。
  • 前导零的影响(lead)?
    为什么百位可以枚举0呢?道理很简单,百位等于0时,相当于此时我们枚举的数是一个两位数,同理,百位和十位都等于0,相当于我们在枚举一个一位数。不过这样会带来前导0的影响,在某些题目,比如:windy数,这种考虑相邻两个数的某种关系的题目的时候,会让记忆化搜索出错,这个时候参数lead就非常重要了,我们通过增加参数来消除影响。

下面我们通过模板来秒掉这道题目吧

dp[pos][sta]表示第pos位前一位是6(sta = 1),或者不是6(sta = 0)时满足条件的数的数目,这里由于数的前一位是否是6会对答案构成影响,所以只需要在统计时把答案分成两类。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m;
ll dp[20][2];
int a[20];
int dfs(int pos,int pre,int sta,bool limit){
    if(pos == 0) return 1;
    if(!limit && dp[pos][sta] != -1) return dp[pos][sta];
    int up = limit?a[pos] : 9;
    int sum = 0;
    for(int i = 0;i <= up;i++){
        if(i == 4) continue;
        if(pre == 6 && i == 2) continue;
        sum += dfs(pos - 1,i,i == 6,limit && i == a[pos]);
    }
    if(!limit) dp[pos][sta] = sum;
    return sum;
}
int solve(int x){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    return dfs(len,-1,0,true);
}
int main(){
    memset(dp,-1,sizeof(dp));
    while(~scanf("%d %d",&m,&n) && n && m){
        cout << solve(n) - solve(m - 1) << endl;
    }
    return 0;
}

蓝桥杯 k好数

这题定义了一个叫"k好数"的概念,即该数在k进制下任意相邻两位的数不能相邻(即相差不能等于1),然后要统计l位k进制中k好数的数目。这道题的坑点在于要考虑前导0的影响,比如4进制下的两位数,11,20是k好数,但是32不是k好数。但如果我们求的是4进制下的一位数,问题就来了,如果没有lead这个参数,1这个数就不会被统计,因为传参时默认前一位是0,而0和1相邻,这个时候只需要加一个判断即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll mod = 1e9 + 7;
ll dp[105][105];
int k,l;
ll dfs(int pos,int pre,bool lead,bool limit){
    if(pos == 0) return 1;
    if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre];
    int up = k - 1;
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == k - 1);
        else{
            if(abs(i - pre) == 1) continue;
            ans = (ans + dfs(pos - 1,i,lead && i == 0,limit && i == k - 1)) % mod;
        }
    }
    if(!lead && !limit) dp[pos][pre] = ans;
    return ans;
}
ll solve(int len){
    memset(dp,-1,sizeof(dp));
    return dfs(len,0,true,true);
}
int main()
{
    cin >> k >> l;
    cout << (solve(l) - solve(l - 1) + mod ) % mod << endl;
    return 0;
}

b-number hdu3652

这题定义了一个叫b数的东西,这个数含"13"这个串并且可以被13整除。难点在于怎么保存被13整除这个状态,其实很简单,再加一个参数tot记录余数,每次枚举%一下13,当pos=0时判断即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[15][10][13][2];
int a[15];
ll n;
ll dfs(int pos,int pre,int tot,int ok,bool limit){
    if(pos == 0) return ok == 1 && tot == 0;
    if(!limit && dp[pos][pre][tot][ok] != -1) return dp[pos][pre][tot][ok];
    int up = limit? a[pos] : 9;
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(pre == 1 && i == 3) ans += dfs(pos - 1,i,(tot * 10 + i) % 13,1,limit && i == a[pos]);
        else ans += dfs(pos - 1,i,(tot * 10 + i) % 13,ok,limit && i == a[pos]);
    }
    if(!limit) dp[pos][pre][tot][ok] = ans;
    return ans;
}
ll solve(ll x){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    return dfs(len,0,0,false,true);
}
int main()
{
    memset(dp,-1,sizeof(dp));
    while(~scanf("%lld",&n)) cout << solve(n) << endl;
    return 0;
}

花神的数论题

这道题挺有意思的,要统计的是每个数二进制下1的个数sum(i),然后求sum(1)到sum(n)的累乘。
    我们要转换一下思维,n上限是1e18,意味着这个数的二进制最多有50位,即最多就50个1,我们分别统计二进制下有一个1,两个1,三个1……五十个1的数的个数,然后用快速幂进行优化,问题就迎刃而解了。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e7 + 7;
ll dp[51][51][51];
int a[51];
ll ans[51];
ll n;
ll poww(ll x,ll y,ll p){
    ll ans = 1;
    while(y){
        if(y & 1) ans = ans * x % p;
        x = x * x % p;
        y >>= 1;
    }
    return ans;
}
//位置 当前统计到的1的个数 目标要统计的1的个数 上界限制
ll dfs(int pos,int tmp,int tot,bool limit){
    if(pos == 0) return tmp == tot;
    if(!limit && dp[pos][tmp][tot] != -1) return dp[pos][tmp][tot];
    int up = limit? a[pos] : 1;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        sum += dfs(pos - 1,tmp + (i == 1),tot,limit && a[pos] == i);
    }
    if(!limit) dp[pos][tmp][tot] = sum;
    return sum;
}
ll solve(ll x){
    int len = 0;
    ll sum = 1;
    while(x){ //拆出二进制的每一位
        a[++len] = x & 1;
        x >>= 1;
    }
    memset(dp,-1,sizeof(dp));
    //这里是一个很巧妙的优化 分别统计二进制中1的个数为i的数的个数,然后快速幂优化
    for(int i = 1;i <= len;i++){
        ans[i] = dfs(len,0,i,true);
        sum = sum * poww(i,ans[i],mod) % mod;
    }
    return sum;
}
int main()
{
    cin >> n;
    cout << solve(n) << endl;
    return 0;
}

洛谷2602 数字计数

这道题大意是求某区间内数字0-9出现的次数。我大概的想法是,分别统计区间内存在1个1,2个1,3个1……的数的个数,存在1个2,2个2,……,1个9,2个9,……的数的个数,然后求累加,当然题解有更好的方法,以下仅供参考(注意前导0的处理,否则统计0的时候会出错)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[15][15][15];
int a[15];
ll ans[10][15];
ll total[10];
ll n,m;
int num;
//前导0很关键!!
ll dfs(int pos,int now,int tot,bool lead,bool limit){
    if(pos == 0) return now == tot;
    if(!limit && !lead && dp[pos][now][tot] != -1) return dp[pos][now][tot];
    int up = limit? a[pos] : 9;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        if(lead && i == 0) sum += dfs(pos - 1,now,tot,true,limit && i == a[pos]);
        else sum += dfs(pos - 1,now + (i == num),tot,lead && i == 0,limit && i == a[pos]);
    }
    if(!limit && !lead) dp[pos][now][tot] = sum;
    return sum;
}
ll solve(ll x,bool ok){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    memset(dp,-1,sizeof(dp));
    memset(ans,0,sizeof(ans));
    for(int i = 0;i < 10;i++){
        num = i;
        for(int j = 1;j <= len;j++){
            ans[i][j] = dfs(len,0,j,true,true);
            if(ok) total[i] += ans[i][j] * j;
            else total[i] -= ans[i][j] * j;
        }
    }
}
int main()
{
    cin >> n >> m;
    solve(m,1);
    solve(n - 1,0);
    for(int i = 0;i < 10;i++) cout << total[i] << " ";
    return 0;
}

洛谷4999 烦人的数学作业

题目的大意是求l-r区间内每个数的各位数字和,假设这个数共len位,那么它的各位数字之和不会超过9*len(因为每位数字最大是9),我们只需要从小到大枚举统计即可,本质思想和上面的题目是一样的。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
int t;
ll n,m;
ll dp[20][180][180],ans[180];
int a[20];
//位置 当前的各位和 目标和 上界限制
ll dfs(int pos,int now,int tot,bool limit){
    if(pos == 0) return now == tot;
    if(!limit && dp[pos][now][tot] != -1) return dp[pos][now][tot];
    int up = limit? a[pos] : 9;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        sum += dfs(pos - 1,now + i,tot,limit && a[pos] == i);
    }
    if(!limit) dp[pos][now][tot] = sum;
    return sum;
}
ll solve(ll x){
    int len = 0;
    ll sum = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    //统计各位和为1-9*len的所有情况
    for(int i = 1;i <= 9 * len;i++){
        ans[i] = dfs(len,0,i,true);
        sum = (sum + i * (ans[i] % mod)) % mod;
    }
    return sum;
}
int main()
{
    scanf("%d",&t);
    memset(dp,-1,sizeof(dp));
    while(t--){
        scanf("%lld %lld",&n,&m);
        //这里减法要处理!!!
        cout << (solve(m) - solve(n - 1) + mod) % mod << endl;
    }
    return 0;
}

蒟蒻的第一篇blog,有错还请各位神犇指出~~

《数位dp踩坑.doc》

下载本文的Word格式文档,以方便收藏与打印。