/*
* Based on code parts from pegwit program written by George Barwood.
* This code is in the public domain; do with it what you wish.
*
**/
#if defined(__DEBUG__)
#include <cstdio>
#endif
#include "string.h"
#include "stdlib.h"
#include "CP_RSA.h"
// Forward declarations of the file-local number-theory helpers defined further down.
static vlong modexp( const vlong & x, const vlong & e, const vlong & m ); // x**e mod m; m must be odd
static vlong gcd( const vlong &X, const vlong &Y ); // greatest common divisor
static vlong modinv( const vlong &a, const vlong &m ); // modular inverse
// VLONG.CPP -----------------------------------
// Growable array of unsigned "units" holding a magnitude, least significant unit first.
// Provides storage allocation and index checking for vlong_value.
class flex_unit
{
public:
  unsigned * a; // array of units (owned; released with delete[] in the destructor)
  unsigned z;   // units allocated
  unsigned n;   // used units (read-only); set() trims zero units from the top
  flex_unit();
  ~flex_unit();
  void clear(); // set n to zero (storage is kept)
  unsigned get( unsigned i ) const; // get ith unsigned (0 when i >= n)
  void set( unsigned i, unsigned x ); // set ith unsigned, growing/normalising n
  void reserve( unsigned x ); // storage hint
  // Time critical routine: *this = (x*y) mod 2**keep
  void fast_mul( flex_unit &x, flex_unit &y, unsigned n );
private:
  // `a` is an owning raw pointer. A member-wise copy would make two objects
  // delete[] the same buffer, so copying is disabled. These are declared but
  // never defined (C++98 idiom), and any copy is rejected at compile or link time.
  flex_unit( const flex_unit & );
  flex_unit & operator=( const flex_unit & );
};
// Unsigned magnitude of a vlong plus a sharing counter used by vlong's
// copy-on-write scheme (vlong bumps `share` on copy and duplicates in docopy()).
class vlong_value : public flex_unit
{
public:
  // Number of *additional* vlong objects referring to this value (0 == single owner).
  unsigned share;

  vlong_value();

  // ---- queries ----
  int is_zero() const;                 // non-zero when the value is 0
  int test( unsigned i ) const;        // bit i (1 or 0)
  unsigned bits() const;               // position of the highest set bit + 1 (0 for zero)
  int cf( vlong_value& x ) const;      // compare with x: -1, 0 or +1
  operator unsigned();                 // Unsafe conversion to unsigned (low unit only)

  // ---- in-place updates ----
  void shl();                          // *this <<= 1
  void shr();                          // *this >>= 1
  void shr( unsigned n );              // *this >>= n
  void add( vlong_value& x );          // *this += x
  void subtract( vlong_value& x );     // *this -= x (callers ensure *this >= x)
  void init( unsigned x );             // *this = x
  void copy( vlong_value& x );         // *this = x (deep copy of the units)
  void mul( vlong_value& x, vlong_value& y );                      // *this = x*y
  void divide( vlong_value& x, vlong_value& y, vlong_value& rem ); // *this = x/y, rem = x%y
};
// Unit i of the value; units at or above n are implicitly zero.
unsigned flex_unit::get( unsigned i ) const
{
  return i < n ? a[i] : 0;
}
void flex_unit::clear()
{
n = 0;
}
// Starts as the value zero with no storage allocated.
flex_unit::flex_unit()
  : a( 0 ), z( 0 ), n( 0 )
{
}
// Wipes every allocated unit ("burn") before freeing it, so numbers such as
// key material do not linger in released heap memory.
flex_unit::~flex_unit()
{
  // The buffer is deleted right after, so a plain zeroing loop is a dead store
  // that optimizers may delete. Writing through a volatile pointer keeps it.
  volatile unsigned * burn = a;
  unsigned i=z;
  while (i) { i-=1; burn[i] = 0; }
  delete [] a;
}
void flex_unit::reserve( unsigned x )
{
if (x > z)
{
unsigned * na = new unsigned[x];
for (unsigned i=0;i<n;i+=1) na[i] = a[i];
delete [] a;
a = na;
z = x;
}
}
// Stores x as unit i, keeping n normalised (no zero unit at the top).
void flex_unit::set( unsigned i, unsigned x )
{
  if ( i >= n )
  {
    // Writing zero above the top unit changes nothing.
    if ( x == 0 )
      return;
    reserve( i+1 );
    for (unsigned gap = n; gap < i; gap += 1)
      a[gap] = 0;
    a[i] = x;
    n = i+1;
    return;
  }
  a[i] = x;
  if ( x == 0 )
  {
    // the top unit may have become zero: trim
    while ( n && a[n-1] == 0 )
      n -= 1;
  }
}
// Macros for doing double precision multiply
// Each unsigned is treated as two half-width digits, so a half*half product fits
// in one unsigned. fast_mul() assembles full-width products from these pieces.
#define BPU ( 8*sizeof(unsigned) ) // Number of bits in an unsigned
#define lo(x) ( (x) & ((1<<(BPU/2))-1) ) // lower half of unsigned
#define hi(x) ( (x) >> (BPU/2) ) // upper half
#define lh(x) ( (x) << (BPU/2) ) // make upper half (moves the low half up; the old high half is discarded)
// *this = (x*y) % (2**keep): schoolbook multiplication that computes only the
// low `keep` bits of the product.
// *this must not alias x or y, because its units are zeroed before x and y are read.
void flex_unit::fast_mul( flex_unit &x, flex_unit &y, unsigned keep )
{
  unsigned i, j, limit = (keep+BPU-1)/BPU; // size of result in words
  reserve(limit);
  for (i=0; i<limit; i+=1) a[i] = 0;
  unsigned xmax = x.n; if (xmax>limit) xmax = limit;
  for (i=0; i<xmax; i+=1)
  {
    unsigned m = x.a[i];
    unsigned c = 0; // carry
    unsigned ymax = i+y.n; if (ymax>limit) ymax = limit;
    for ( j=i; j<ymax; j+=1 )
    {
      // This is the critical loop
      // Machine dependent code could help here
      // c:a[j] = a[j] + c + m*y.a[j-i];
      unsigned w, v = a[j], p = y.a[j-i];
      v += c; c = ( v < c );
      w = lo(p)*lo(m); v += w; c += ( v < w );
      w = lo(p)*hi(m); c += hi(w); w = lh(w); v += w; c += ( v < w );
      w = hi(p)*lo(m); c += hi(w); w = lh(w); v += w; c += ( v < w );
      c += hi(p) * hi(m);
      a[j] = v;
    }
    // propagate the final carry, but never beyond the kept words
    while ( c && j<limit )
    {
      a[j] += c;
      c = a[j] < c;
      j += 1;
    }
  }
  // eliminate unwanted bits. The mask is built from an unsigned 1: with a signed 1,
  // keep == BPU-1 gives (1<<31)-1, i.e. INT_MIN-1, which is signed overflow (UB).
  keep %= BPU;
  if (keep) a[limit-1] &= (1u<<keep)-1u;
  // calculate n (drop zero top words)
  while (limit && a[limit-1]==0) limit-=1;
  n = limit;
}
// Least significant unit only; any higher units are silently dropped.
vlong_value::operator unsigned()
{
  return this->get(0);
}
// Non-zero when no units are in use, i.e. the value is 0.
int vlong_value::is_zero() const
{
  return !n;
}
// Returns bit i of the magnitude (bit 0 = least significant) as 1 or 0.
int vlong_value::test( unsigned i ) const
{
  // 1u keeps the mask unsigned. Shifting a signed 1 into the sign bit
  // (i % BPU == BPU-1) is UB in C++11 and implementation-defined before C++20.
  return ( get(i/BPU) & (1u<<(i%BPU)) ) != 0;
}
// Number of significant bits: index of the highest set bit + 1, or 0 for zero.
unsigned vlong_value::bits() const
{
  unsigned count = n*BPU;
  while ( count != 0 && !test(count-1) )
    count -= 1;
  return count;
}
// Magnitude comparison: +1 if *this > x, -1 if *this < x, 0 if equal.
// Relies on both values being normalised, so a longer value is larger.
int vlong_value::cf( vlong_value& x ) const
{
  if ( n != x.n )
    return n > x.n ? +1 : -1;
  for (unsigned i = n; i-- > 0; )
  {
    const unsigned mine = get(i);
    const unsigned theirs = x.get(i);
    if ( mine != theirs )
      return mine > theirs ? +1 : -1;
  }
  return 0;
}
// *this <<= 1, growing by one unit when the top bit carries out.
void vlong_value::shl()
{
  const unsigned words = n; // snapshot: set() below may grow n
  unsigned in_bit = 0;
  for (unsigned i = 0; i <= words; i += 1)
  {
    const unsigned w = get(i);
    set( i, (w << 1) + in_bit );
    in_bit = w >> (BPU-1);
  }
}
// *this >>= 1, walking from the top unit down so each unit receives the
// low bit of the unit above it.
void vlong_value::shr()
{
  unsigned in_bits = 0;
  for (unsigned i = n; i != 0; )
  {
    i -= 1;
    const unsigned w = get(i);
    set( i, (w >> 1) + in_bits );
    in_bits = w << (BPU-1);
  }
}
// Logical right shift in place: *this >>= x.
void vlong_value::shr( unsigned x )
{
  // Split the shift into whole units (delta) and a residual bit count (x).
  unsigned delta = x/BPU; x %= BPU;
  // Works in place: unit i is built only from units i+delta and i+delta+1,
  // which lie at or above i and so have not been overwritten yet.
  // n is re-read each iteration because set() trims it once top units become zero.
  for (unsigned i=0;i<n;i+=1)
  {
    unsigned u = get(i+delta);
    if (x) // skip the carry-in when x == 0: a shift by BPU bits would be undefined
    {
      u >>= x;
      u += get(i+delta+1) << (BPU-x);
    }
    set(i,u);
  }
}
// *this += x (magnitudes). Also correct when x is *this: unit i of both
// operands is read before it is written.
void vlong_value::add( vlong_value & x )
{
  const unsigned top = n > x.n ? n : x.n;
  reserve( top );
  unsigned carry = 0;
  // one extra unit to absorb the final carry
  for (unsigned k = 0; k < top + 1; k += 1)
  {
    unsigned sum = get(k) + carry;
    carry = ( sum < carry );
    const unsigned addend = x.get(k);
    sum += addend;
    carry += ( sum < addend );
    set( k, sum );
  }
}
// *this -= x (magnitudes). Assumes *this >= x; only the current n units are
// processed, so a borrow out of the top unit is dropped.
void vlong_value::subtract( vlong_value & x )
{
  unsigned borrow = 0;
  const unsigned limit = n;
  for (unsigned k = 0; k < limit; k += 1)
  {
    const unsigned take = x.get(k) + borrow;
    // take wrapped to 0 (x unit all ones plus borrow): this unit is unchanged
    // and the borrow carries on to the next unit.
    if ( take < borrow )
      continue;
    const unsigned cur = get(k);
    const unsigned diff = cur - take;
    borrow = diff > cur;
    set( k, diff );
  }
}
void vlong_value::init( unsigned x )
{
clear();
set(0,x);
}
// Deep copy of x's units into *this. The copy runs from the top unit down, so
// the first set() reserves the full size in one step.
void vlong_value::copy( vlong_value& x )
{
  clear();
  for (unsigned i = x.n; i > 0; i -= 1)
    set( i-1, x.get(i-1) );
}
// Zero value with a single owner (no additional sharers).
vlong_value::vlong_value()
  : share( 0 )
{
}
// *this = x*y. The product never needs more than bits(x)+bits(y) bits, so
// fast_mul's truncation keeps the exact result.
void vlong_value::mul( vlong_value& x, vlong_value& y )
{
  const unsigned keep = x.bits() + y.bits();
  fast_mul( x, y, keep );
}
// Shift-and-subtract long division of magnitudes: *this = x / y, rem = x % y.
// Throws a const char* message if y is zero.
void vlong_value::divide( vlong_value& x, vlong_value& y, vlong_value& rem )
{
  // With y == 0 the loops below never terminate: m stays zero however far it is
  // shifted, and the main loop keeps adding s. Fail explicitly instead of hanging.
  if ( y.is_zero() )
    throw "vlong: division by zero";
  init(0);
  rem.copy(x);
  vlong_value m,s;   // m = y << k, s = 1 << k
  m.copy(y);
  s.init(1);
  // align the shifted divisor with the dividend
  while ( rem.cf(m) > 0 )
  {
    m.shl();
    s.shl();
  }
  while ( rem.cf(y) >= 0 )
  {
    while ( rem.cf(m) < 0 )
    {
      m.shr();
      s.shr();
    }
    rem.subtract( m );
    add( s );
  }
}
// Implementation of vlong
// Replaces the magnitude with the n units in a[] (least significant first).
// The sign flag is left untouched.
void vlong::load( unsigned * a, unsigned n )
{
  docopy();          // never modify a value another vlong is sharing
  value->clear();
  unsigned i = 0;
  while ( i < n )
  {
    value->set( i, a[i] );
    i += 1;
  }
}
// Writes exactly n units of the magnitude into a[], least significant first.
// Units above the value's length are written as 0.
void vlong::store( unsigned * a, unsigned n ) const
{
  for (unsigned k = 0; k != n; ++k)
    a[k] = value->get(k);
}
unsigned vlong::get_nunits() const
{
return value->n;
}
// Bit length of the magnitude (the sign is ignored).
unsigned vlong::bits() const
{
  const unsigned count = value->bits();
  return count;
}
void vlong::docopy()
{
if ( value->share )
{
value->share -= 1;
vlong_value * nv = new vlong_value;
nv->copy(*value);
value = nv;
}
}
// Signed comparison: -1 if *this < x, 0 if equal, +1 if *this > x.
// Zero counts as non-negative whatever its sign flag says.
int vlong::cf( const vlong x ) const
{
  int neg = negative && !value->is_zero();
  int x_neg = x.negative && !x.value->is_zero();
  if ( neg != x_neg )
    return neg ? -1 : +1;
  int mag = value->cf( *x.value );
  // Both negative: the larger magnitude is the smaller number (-5 < -3).
  return neg ? -mag : mag;
}
// Non-negative vlong holding x, with its own (unshared) value.
vlong::vlong (unsigned x)
{
  negative = 0;
  value = new vlong_value;
  value->init( x );
}
vlong::vlong ( const vlong& x ) // copy constructor
{
  // Cheap copy: refer to the same vlong_value and record one more sharer.
  // A physical copy happens only on the first write (docopy).
  value = x.value;
  value->share += 1;
  negative = x.negative;
}
// Shares x's value (copy-on-write) and releases the current one.
vlong& vlong::operator =(const vlong& x)
{
  // Take the new reference BEFORE dropping the old one. When x and *this hold
  // the same vlong_value (e.g. a = a), releasing first could delete the value
  // that is about to be used.
  x.value->share += 1;
  if ( value->share ) value->share -=1; else delete value;
  value = x.value;
  negative = x.negative;
  return *this;
}
// Drops this reference; the last holder frees the shared value.
vlong::~vlong()
{
  if ( value->share == 0 )
    delete value;
  else
    value->share -= 1;
}
vlong::operator unsigned () // conversion to unsigned
{
  // Low unit of the magnitude; the sign flag is not consulted.
  const unsigned low = *value;
  return low;
}
// Signed addition in sign-magnitude form.
vlong& vlong::operator +=(const vlong& x)
{
  if ( negative != x.negative && value->cf( *x.value ) < 0 )
  {
    // Opposite signs and |x| > |*this|: evaluate x + (*this) instead, so the
    // magnitude subtraction below always takes the smaller from the larger.
    vlong tmp = *this;
    *this = x;
    *this += tmp;
    return *this;
  }
  docopy();
  if ( negative == x.negative )
    value->add( *x.value );      // same sign: magnitudes add, sign kept
  else
    value->subtract( *x.value ); // |*this| >= |x|: sign of *this kept
  return *this;
}
// Signed subtraction in sign-magnitude form.
vlong& vlong::operator -=(const vlong& x)
{
  if ( negative == x.negative && value->cf( *x.value ) < 0 )
  {
    // Same sign and |x| > |*this|: compute x - (*this) and flip the sign.
    vlong tmp = *this;
    *this = x;
    *this -= tmp;
    negative = 1 - negative;
    return *this;
  }
  docopy();
  if ( negative != x.negative )
    value->add( *x.value );      // opposite signs: magnitudes add
  else
    value->subtract( *x.value ); // |*this| >= |x|
  return *this;
}
// x + y; the result shares x's value until += detaches it.
vlong operator +( const vlong& x, const vlong& y )
{
  vlong sum( x );
  sum += y;
  return sum;
}
// x - y; the result shares x's value until -= detaches it.
vlong operator -( const vlong& x, const vlong& y )
{
  vlong difference( x );
  difference -= y;
  return difference;
}
// x * y: multiply magnitudes; the result is negative when exactly one operand is.
vlong operator *( const vlong& x, const vlong& y )
{
  vlong product;
  product.negative = x.negative ^ y.negative;
  product.value->mul( *x.value, *y.value );
  return product;
}
// x / y: quotient of the magnitudes (truncated), negative when exactly one operand is.
vlong operator /( const vlong& x, const vlong& y )
{
  vlong quotient;
  vlong_value remainder; // produced by divide() but not needed here
  quotient.negative = x.negative ^ y.negative;
  quotient.value->divide( *x.value, *y.value, remainder );
  return quotient;
}
#if defined(__DEBUG__)
// Debug dump: the value's byte size, then its units in hex in storage order
// (least significant unit first).
void print_vlong( const vlong_value & v, const char *name )
{
  // sizeof yields size_t, which "%d" does not match (UB in printf). Print it
  // as unsigned with an explicit cast, which is portable to pre-C99 runtimes.
  printf("%s value(%u): ", name, (unsigned)(v.n * sizeof(unsigned int)));
  for (unsigned i = 0; i < v.n; ++i)
  {
    printf("%08X", v.a[i]);
  }
  printf("\n");
}
#endif
// x % y: remainder of the magnitudes, carrying the dividend's sign.
vlong operator %( const vlong& x, const vlong& y )
{
  vlong remainder;
  vlong_value quotient; // discarded
  quotient.divide( *x.value, *y.value, *remainder.value );
  // operator/ truncates the magnitude and gives the quotient the sign x^y.
  // Giving the remainder x's sign therefore keeps x == (x/y)*y + x%y.
  remainder.negative = x.negative;
  return remainder;
}
// Euclid's algorithm: greatest common divisor of X and Y.
static vlong gcd( const vlong &X, const vlong &Y )
{
  vlong a = X, b = Y;
  while ( b != (vlong)0 )
  {
    vlong r = a % b;
    a = b;
    b = r;
  }
  return a;
}
static vlong modinv( const vlong &a, const vlong &m ) // modular inverse
// returns i in range 1..m-1 such that i*a = 1 mod m
// a must be in range 1..m-1
{
  // Extended Euclid on (m, a), tracking only the coefficient of a.
  vlong prev = 0, cur = 1; // coefficients belonging to b and c respectively
  vlong b = m, c = a;
  while ( c != (vlong)0 )
  {
    vlong q = b / c;
    vlong r = b - q*c;
    b = c;
    c = r;
    vlong next = prev - cur*q;
    prev = cur;
    cur = next;
  }
  // bring a negative coefficient into range
  if ( prev < (vlong)0 )
    prev += m;
  return prev;
}
// Montgomery modular exponentiation for one fixed modulus.
class monty
{
private:
  vlong R;    // radix: smallest power of two >= m
  vlong R1;   // modinv(R-m, m), i.e. R^-1 mod m
  vlong m;    // the modulus
  vlong n1;   // R - modinv(m, R), i.e. -m^-1 mod R
  vlong T, k; // work registers for mul()
  unsigned N; // bits for R (R == 2**N)
  void mul( vlong &x, const vlong &y ); // Montgomery product x = x*y*R^-1 mod m
public:
  vlong exp( const vlong &x, const vlong &e );
  monty( const vlong &M );
};
// Precomputes the Montgomery constants for modulus M (M must be odd so that
// it is invertible modulo the power-of-two R).
monty::monty( const vlong &M )
{
  m = M;
  // R = smallest power of two >= M, N = log2(R)
  N = 0;
  R = 1;
  while ( R < M )
  {
    R += R;
    N += 1;
  }
  R1 = modinv( R-m, m ); // R-m is congruent to R mod m
  n1 = R - modinv( m, R );
}
// Montgomery product (REDC): x = x*y*R^-1 mod m, using n1 == -m^-1 mod R.
// fast_mul(..., keep) keeps only the low `keep` bits: keep == N implements "% R",
// and keep == 2N holds products of two values below R exactly.
// x and y may be the same object (exp() calls mul(t, t)); y is only read
// before x is overwritten.
// NOTE(review): x.value is written directly without docopy(). This assumes x
// does not share its vlong_value with another vlong; confirm for any new caller.
void monty::mul( vlong &x, const vlong &y )
{
  // T = x*y;
  T.value->fast_mul( *x.value, *y.value, N*2 );
  // k = ( T * n1 ) % R;
  k.value->fast_mul( *T.value, *n1.value, N );
  // x = ( T + k*m ) / R;
  x.value->fast_mul( *k.value, *m.value, N*2 );
  x += T;
  x.value->shr( N );
  // final conditional subtraction brings the result below m
  if (x>=m) x -= m;
}
// x**e mod m by right-to-left binary exponentiation in Montgomery form.
vlong monty::exp( const vlong &x, const vlong &e )
{
  // R-m is R reduced mod m (the constructor picks R < 2m), i.e. Montgomery 1;
  // t is x in Montgomery form.
  vlong result = R-m, t = ( x * R ) % m;
  const unsigned bits = e.value->bits();
  // A zero exponent has no bits, so the loop is skipped and 1 mod m is
  // returned. The earlier "break when i == bits" form never fired for bits == 0
  // and did not terminate.
  for (unsigned i = 0; i < bits; i += 1)
  {
    if ( e.value->test(i) )
      mul( result, t );
    // square for the next bit; the squaring after the top bit is never used
    if ( i + 1 < bits )
      mul( t, t );
  }
  return ( result * R1 ) % m; // leave Montgomery form
}
// x**e mod m via a one-shot Montgomery context (m must be odd).
static vlong modexp( const vlong & x, const vlong & e, const vlong & m )
{
  return monty( m ).exp( x, e );
}
// RSA.CPP -----------------------------------
// Textbook RSA encryption: plain**e mod m. Meaningful only for plain < m;
// debug builds print a warning otherwise.
vlong public_key::encrypt( const vlong& plain )
{
#if defined(__DEBUG__)
  if ( plain >= m ) {
    printf("ERROR: plain too big for this key\n");
  }
#endif
  const vlong cipher = modexp( plain, e, m );
  return cipher;
}
// RSA decryption using the Chinese remainder theorem with primes p and q.
vlong private_key::decrypt( const vlong& cipher )
{
  // Calculate values for performing decryption
  // These could be cached, but the calculation is quite fast
  vlong d = modinv( e, (p-(vlong)1)*(q-(vlong)1) );
  // modinv requires its first argument in 1..q-1; p % q satisfies that even when p > q
  vlong u = modinv( p % q, q );
  vlong dp = d % (p-(vlong)1);
  vlong dq = d % (q-(vlong)1);
  // Apply chinese remainder theorem
  vlong a = modexp( cipher % p, dp, p );
  vlong b = modexp( cipher % q, dq, q );
  // Garner recombination: result = a + p * (u*(b-a) mod q).
  // Reduce a mod q first, so a single "+ q" always makes the difference
  // non-negative (without this, b-a stays negative when p > 2q).
  vlong a_q = a % q;
  if ( b < a_q ) b += q;
  return a + p * ( ((b-a_q)*u) % q );
}
void vlong_pair_2_str (char *me_str,vlong &m,vlong &e)
{
const char *hex_str = "0123456789ABCDEF";
char tmp_str[MAX_CRYPT_BITS/2+1];
unsigned int x;
unsigned int me_len = 0;
unsigned int i;
unsigned int j;
vlong m1 = m;
vlong e1 = e;
vlong zero = 0;
i = 0;
while (m1 != zero)
{
x = m1 % (vlong) 16;
m1 = m1 / (vlong) 16;
tmp_str[i++] = hex_str[x];
}
for (j=0; j < i; j++)
me_str[me_len++] = tmp_str[i-1-j];
me_str[me_len++] = '#';
i = 0;
while (e1 != zero)
{
x = e1 %(vlong)16;
e1 = e1 / (vlong)16;
tmp_str[i++] = hex_str[x];
}
for (j=0; j < i; j++)
me_str[me_len++] = tmp_str[i-1-j];
me_str[me_len] = 0;
}
// Value of one hexadecimal digit ('0'-'9', 'A'-'F', 'a'-'f'), or -1 for any other character.
static int hex_digit_value (char c)
{
  if (c >= '0' && c <= '9') return c - '0';
  if (c >= 'A' && c <= 'F') return c - 'A' + 10;
  if (c >= 'a' && c <= 'f') return c - 'a' + 10;
  return -1;
}

// Parses "<m hex>#<e hex>" (the layout vlong_pair_2_str writes) into m and e.
// The LAST '#' is the separator. Throws a const char* message when the separator
// is missing or leading, a character is not a hex digit, or either value is zero.
void str_2_vlong_pair (const char *me_str,vlong &m,vlong &e)
{
  int i;
  int dash_pos = -1;
  m = 0;
  e = 0;
  int me_len = (int)strlen (me_str);
  for (i = me_len-1; i>=0; i--)
    if (me_str[i] == '#')
    {
      dash_pos = i;
      break;
    }
  if (dash_pos == 0)
    throw "Bad key: dash (`#') found at bad position";
  // dash_pos stays -1 when there is no '#'. Position 1 is a valid one-digit m.
  if (dash_pos < 0)
    throw "Bad key: no dash (`#') found ";
  for (i = 0; i<dash_pos; i++)
  {
    int digit = hex_digit_value (me_str[i]);
    if (digit < 0)
      throw "Bad key: bad value before `#'";
    m = m * (vlong)16 + (vlong)(unsigned)digit;
  }
  if (m == vlong(0))
    throw "Bad key: bad value before `#'";
  for (i = dash_pos+1; i<me_len; i++)
  {
    int digit = hex_digit_value (me_str[i]);
    if (digit < 0)
      throw "Bad key: bad value after `#'";
    e = e * (vlong)16 + (vlong)(unsigned)digit;
  }
  if (e == vlong(0))
    throw "Bad key: no value after `#'";
}
void private_key::MakeMeStr(char * me_str)
{
vlong_pair_2_str (me_str,m,e);
}
void private_key::MakePqStr(char * me_str)
{
vlong_pair_2_str (me_str,p,q);
}
void private_key::MakePq (const char *me_str)
{
str_2_vlong_pair (me_str,p,q);
{
m = p*q;
e = 50001; // must be odd since p-1 and q-1 are even
while ( gcd(p-(vlong)1,e) != (vlong)1 || gcd(q-(vlong)1,e) != (vlong)1 ) e += 2;
}
}
void public_key::MakeMe(const char *me_str)
{
str_2_vlong_pair (me_str,m,e);
}
CCryptoProviderRSA::CCryptoProviderRSA()
{
  // No explicit initialisation; members are default-constructed.
}
CCryptoProviderRSA::~CCryptoProviderRSA()
{
  // Nothing to release explicitly; members clean up after themselves.
}
// Copies size bytes from src to dst in reverse order: dst[k] = src[size-1-k].
// Used to convert between big-endian byte strings and little-endian unit order.
inline void _rmemcpy (char *dst,const char *src,size_t size)
{
  for (size_t k = 0; k < size; k += 1)
    dst[k] = src[size - 1 - k];
}
// Reports 0 for both the encryption and the decryption block size.
void CCryptoProviderRSA::GetBlockSize(int &enbs, int &debs)
{
  enbs = debs = 0;
}
void CCryptoProviderRSA::EncryptPortion(const char *pt, size_t pt_size, char *ct, size_t &ct_size)
{
vlong plain, cipher;
const size_t bytes_per_unit = BPU / 8;
size_t padding = (pt_size & 3) ? (4 - (pt_size & 3)) : 0;
char tmp[MAX_CRYPT_BITS/4];
// ensure big-endianness
_rmemcpy(tmp, pt, pt_size);
memset(tmp + pt_size, 0, padding);
plain.load((unsigned int*)tmp, (int)(pt_size+padding) / bytes_per_unit);
cipher = prkface.encrypt(plain);
ct_size = cipher.get_nunits() * bytes_per_unit;
// ensure big-endianness
cipher.store((unsigned int*)tmp, (int)ct_size / bytes_per_unit);
_rmemcpy(ct, tmp, ct_size);
}
void CCryptoProviderRSA::DecryptPortion(const char *ct, size_t ct_size, char *pt, size_t &pt_size)
{
vlong plain, cipher;
const size_t bytes_per_unit = BPU / 8;
char tmp[MAX_CRYPT_BITS/4];
// ensure big-endianness
_rmemcpy(tmp, ct, ct_size);
cipher.load((unsigned int*)tmp, (int)ct_size / bytes_per_unit);
plain = prkface.decrypt(cipher);
// ensure big-endianness
plain.store((unsigned int*)tmp, plain.get_nunits());
_rmemcpy(pt, tmp, pt_size);
}
void CCryptoProviderRSA::ImportPublicKey(const char *pk)
{
prkface.MakeMe(pk);
}
void CCryptoProviderRSA::ImportPrivateKey(const char *pk)
{
prkface.MakePq(pk);
}
void CCryptoProviderRSA::ExportPublicKey(char *pk)
{
prkface.MakeMeStr(pk);
}
void CCryptoProviderRSA::ExportPrivateKey(char *pk)
{
prkface.MakePqStr(pk);
}
#if defined(__DEBUG__)
// Debug helper: hex-dumps size bytes of buf on a single line.
void printbuf(const char * buf, int size)
{
  for (int k = 0; k < size; ++k)
    printf("%02X", buf[k] & 0x000000ff);
  printf("\n");
}
#endif
void CCryptoProviderRSA::Encrypt(const char *inbuf, size_t in_size,char *outbuf, size_t &out_size)
{
size_t i,cp_size;
char portbuf[MAX_CRYPT_BITS/8];
char cpbuf[MAX_CRYPT_BITS/4];
const char *inp = inbuf;
unsigned short lm;
// must ensure that any data block would be < key's modulus
// hence -1
unsigned int portion_len = (prkface.m.bits() - 1) / 8;
char prev_crypted[portion_len];
memset(&prev_crypted, 0, portion_len);
out_size = 0;
while(in_size)
{
size_t cur_size = in_size > portion_len ? portion_len : in_size;
for (i=0; i<cur_size; i++)
portbuf[i] = inp[i] ^ prev_crypted[i];
EncryptPortion(portbuf, cur_size, cpbuf, cp_size);
for (i=0; i<portion_len; i++)
prev_crypted[i] = i < cp_size ? cpbuf[i] : 0;
lm=cur_size;
memcpy (outbuf+out_size,&lm, sizeof(unsigned short)); out_size+=sizeof (unsigned short);
lm=(unsigned short)cp_size;
memcpy (outbuf+out_size,&lm, sizeof(unsigned short)); out_size+=sizeof (unsigned short);
memcpy (outbuf+out_size,cpbuf, cp_size); out_size+=cp_size;
inp+=cur_size;
in_size-=cur_size;
}
return;
}
void CCryptoProviderRSA::Decrypt(const char *inbuf, size_t in_size,char *outbuf, size_t &out_size)
{
size_t i, cp_size,pt_size;
char portbuf[MAX_CRYPT_BITS/8];
char cpbuf[MAX_CRYPT_BITS/4];
unsigned short lmi, lmo;
// must ensure that any data block would be < key's modulus
// hence -1
int portion_len = (prkface.m.bits() - 1) / 8;
char prev_crypted[portion_len];
memset(&prev_crypted, 0, portion_len);
const char *inp=inbuf;
out_size = 0;
while(in_size)
{
memcpy (&lmi,inp,sizeof (unsigned short)); inp += sizeof(unsigned short); in_size -= sizeof(unsigned short);
memcpy (&lmo,inp,sizeof (unsigned short)); inp += sizeof(unsigned short); in_size -= sizeof(unsigned short);
if (lmo>in_size)
break;
memcpy (cpbuf,inp,lmo);
cp_size = lmo;
pt_size = lmi;
DecryptPortion(cpbuf, cp_size, portbuf, pt_size);
if (lmi>pt_size)
lmi=(unsigned short)pt_size;
for (i=0; i<lmi; i++)
portbuf[i] ^= prev_crypted[i];
for (i=0; i<portion_len; i++)
prev_crypted[i] = i < cp_size ? cpbuf[i] : 0;
memcpy (outbuf+out_size,portbuf,lmi);
out_size += lmi;
inp+=lmo;
in_size-=lmo;
}
return;
}