https://www.acmicpc.net/problem/9556

 

9556번: 수학 숙제

각 테스트 케이스마다 조건을 만족하는 N자리 정수의 개수를 1,000,000,007로 나눈 값을 출력한다.

www.acmicpc.net

매겨진 난이도가 높지만 아이디어가 어려운 문제는 아니다. 다음의 필요한 사전 지식을 안다면 쉽게 해결할 수 있다.

1. 분할 정복을 이용한 거듭제곱 - 티어 올린 주범인 듯 하다. 알면 쉬움.

2. 페르마 소정리 - 페르마 소정리 자체를 알지 못해도, 거듭제곱이 mod n에서 주기성을 가진다는 사실을 이해하면 충분하다.

 

우선은 1, 2, 3, 4, 5, 6의 배수인지 여부가 문제 정답에 큰 기여를 한다는 사실은 자명하다. 따라서 이들의 최소공배수를 구하자면 60이다. 이 60을 주기로 하여 반복된 패턴이 나타남을 예상할 수 있다. 필자는 최소공배수를 120으로 착각하여 풀이하였으나, 충분히 작은 공배수라면 어떤 것을 사용하여도 무관하다. 후술할 풀이에서는 주기를 120으로 사용한다.

 

120을 주기로 답에 포함이 되는 수인지 여부가 반복된다는 사실을 관찰하였다. 따라서 문제에서 제시된 10n 미만의 자연수 또는 0에 과연 몇 개의 주기가 들어가는지 확인하기 위하여 10n을 120으로 나누어보자. n의 값이 1, 2, 3,...이 됨에 따라 10n120×0+1,120×0+10,120×0+100,120×8+40,120×83+40,...으로 표현됨을 확인할 수 있다.

 

즉, n이 3 이상이면 10nmod120=40이 성립한다는 뜻이다. 이를 수학적 귀납법으로 증명할 수 있다. 10n=120x+40을 만족하는 정수 x가 존재한다고 가정하자. 이때 10n+1=10×(120x+40)=1200x+400=120×(10x+3)+40이므로, 10n+1 역시 120으로 나눈 나머지가 40이다. n=3일 때부터 해당 가정이 성립하므로 수학적 귀납법에 의하여 n이 3 이상일 때 10nmod120은 40이다.

 

또 한 가지 위의 과정에서 관찰할 수 있는 사실은, 10n을 120으로 나눈 몫이 an이라 할 때, an+1=10an+3이 성립한다는 사실이다. 만약 an을 구할 수 있다면 0 이상 120 미만의 정수 중 입력 조건을 만족하는 정수의 수 s, 0 이상 40 미만의 정수 중 입력 조건을 만족하는 정수의 수 t에 대하여 정답은 s×an+t임을 쉽게 알 수 있다. 이러한 논리 흐름을 풀이에 적용하기 위하여 an을 빠르게 구해야 하며 이는 행렬 곱셈을 활용하여 해결 가능하다.

 

an+1=10an+3 꼴의 점화식을 계산하는 것은 다음 식을 통해서도 가능하다. 성립 여부는 우변을 직접 계산하면 점화식에 의해 좌변과 같이 정리할 수 있음을 통해 확인할 수 있다.

[an+11]=[10301]×[an1]

 

따라서 다음과 같이 행렬 거듭제곱에 관한 식으로 표현하면 분할 정복을 활용하여 쉽게 4 이상의 n에 대해 an을 계산할 수 있다.

[an1]=[10301]n3×[a31]

 

이후에는 위에서 서술한 바와 같이 s×an+t를 출력하면 된다. 더불어 해당 풀이는 n이 4 이상일 때 구현하기 쉬우므로, n이 3 이하일 때는 brute force를 활용하여 답을 구하도록 하면 깔끔하게 문제를 해결할 수 있다.

 

#include<bits/stdc++.h>
using namespace std;
using ll=long long;

const ll MOD = 1000000007;

ll ipow(ll a,ll n,const int mod){
    if(n==1)return a;
    if(n==2)return 1ll*a*a%mod;
    ll res=ipow(a*a%mod,n/2,mod);
    if(n%2)res=res*a%mod;
    return res;
}

struct matrix{
    ll a,b;
    ll c,d;
    matrix operator*(const matrix t){
        return {(a*t.a%MOD+b*t.c%MOD)%MOD,(a*t.b%MOD+b*t.d%MOD)%MOD,(c*t.a%MOD+d*t.c%MOD)%MOD,(c*t.b%MOD+d*t.d%MOD)%MOD};
    }
};

matrix mpow(matrix a,ll n){
    if(n==0)return {1,0,0,1};
    if(n==1)return a;
    if(n==2)return a*a;
    matrix r=mpow(a*a,n/2);
    if(n%2)r=r*a;
    return r;
}

void tc(){
    ll n;
    scanf("%lld",&n);
    if(n<=3){
        int x[7];
        for(int i=1;i<=6;i++){
            scanf("%1d",&x[i]);
        }
        int cnt=0;
        for(int i=0;i<(n==1?10:(n==2?100:1000));i++){
            int c=0;
            for(int j=1;j<=6;j++){
                c+=((x[j]==0&&i%j!=0)||(x[j]==1&&i%j==0)||(x[j]==2));
            }
            cnt+=(c==6);
        }
        printf("%d\n",cnt);
        return;
    }
    int arr[120]={1,};
    for(int i=0;i<120;i++)arr[i]=1;
    for(int i=1;i<=6;i++){
        int x;
        scanf("%1d",&x);
        if(x==0){
            for(int j=0;j<120;j++)if(j%i==0)arr[j]=0;
        }
        if(x==1){
            for(int j=0;j<120;j++)if(j%i!=0)arr[j]=0;
        }
    }
    matrix m = mpow({10,3,0,1},n-3);
    ll r = (m.a*8+m.b)%MOD;
    ll x1=0, x2=0;
    for(int i=0;i<120;i++){
        x1+=arr[i];
    }
    for(int i=0;i<40;i++){
        x2+=arr[i];
    }
    printf("%lld\n",(x1*r%MOD+x2)%MOD);
}

int32_t main(void){
    int TC;
    scanf("%d",&TC);
    while(TC--){
        tc();
    }
}

+ Recent posts