# 
# keyderiv.rb
# 
# Copyright 2005-2006 Tero Hasu
#
# Licensed under the same terms as Ruby itself.
# 
# This is an implementation of a key derivation function, for deriving
# fixed-size symmetric keys from passwords.
# 
# This code is based on PKCS #5 v2.0: Password-Based Cryptography
# Standard. We chose to implement the PBKDF2 key derivation function,
# which was recommended for new applications.
# 
# When encrypting data, one has to create a random salt, and store the
# salt so that the same key can be derived again for decryption. The
# salt may be generated entirely at random, except when the use case
# requires something different (one might be concerned about
# interactions between multiple uses of the same key, as stated in the
# specification).
# 
# Note that this library depends on the OpenSSL Ruby bindings; they
# are included in the standard distribution, I believe. At least
# Debian has them.
# 

require 'openssl'

module KeyDeriv
  VERSION = [1, 1].freeze

  module_function

  # A random function, implemented by making an OpenSSL API call.
  # Can be used e.g. to generate salt.
  # size:: Required size of output (in octets).
  def openssl_random(size)
    OpenSSL::Random.random_bytes(size)
  end

  # A random function, implemented by reading data from /dev/random.
  # Can be used e.g. to generate salt, but of course only on platforms
  # that have the /dev/random device.
  #
  # Note that /dev/random may block while the system's pool of random
  # data is empty, which could be a problem if lots of keys must be
  # derived quickly; /dev/urandom would be a non-blocking alternative.
  #
  # size:: Required size of output (in octets).
  #
  # Raises a RuntimeError if the device reports end-of-file before
  # +size+ octets have been read.
  def dev_random(size)
    result = String.new
    # Binary mode: the device yields raw octets, not text.
    File.open('/dev/random', 'rb') do |input|
      while result.bytesize < size
        # A single read may return fewer octets than requested, so keep
        # asking for whatever is still missing.
        data = input.read(size - result.bytesize)
        raise "unexpected end of file while reading /dev/random" unless data
        result << data
      end
    end
    result
  end

  # A pseudorandom function. Always yields the same result when
  # hashing the same data with the same key. Returns an hlen-octet
  # string (the raw HMAC-SHA-1 digest, not hex-encoded).
  # key:: The HMAC key.
  # data:: The data to authenticate.
  def hmac_sha_1(key, data)
    hmac = OpenSSL::HMAC.new(key, OpenSSL::Digest::SHA1.new)
    hmac.update(data)
    # HMAC supports "hexdigest" also.
    hmac.digest
  end

  # Returns the length (in octets) of the output of "hmac_sha_1".
  def hmac_sha_1_len
    20
  end

  # A pseudorandom function. Always yields the same result when
  # given the same arguments. Returns an hlen-octet string.
  # key:: The HMAC "key".
  # data:: The HMAC "data".
  def prf(key, data)
    # Ruby does have a facility for generating pseudorandom numbers,
    # but the problem with Kernel#srand is that it sets the seed in
    # a runtime-wide manner, which we do not want. Also, we ideally
    # want an algorithm that takes the key and the data as separate
    # parameters, instead of us, say, concatenating them. This is
    # why we are using HMAC-SHA-1.
    hmac_sha_1(key, data)
  end

  # Returns the length (in octets) of "prf" output.
  def hlen
    hmac_sha_1_len
  end

  # The exclusive OR of the first c iterates of the pseudorandom
  # function prf applied to the password p and the concatenation of
  # the salt s and the block index i. Returns an hlen-octet string
  # (all zero octets when c is less than 1).
  # p:: Password.
  # s:: Salt.
  # c:: Iteration count.
  # i:: Block index (encoded as a 32-bit big-endian integer).
  def f(p, s, c, i)
    # Re-pack the salt as raw octets first: appending the binary block
    # index to a salt that carries a text encoding and contains
    # non-ASCII characters would otherwise fail with an encoding
    # compatibility error.
    data = [s].pack('a*') + encode_int32(i)
    uxor = "\000" * hlen
    c.times do
      data = prf(p, data)
      uxor = string_xor(uxor, data)
    end
    uxor
  end

  # Does a pairwise exclusive OR of the individual bytes of the
  # passed strings (each pair of bytes at a particular index),
  # returning the result as a string. The passed strings must be of
  # the same length in bytes; a RuntimeError is raised otherwise.
  def string_xor(s1, s2)
    # Work on arrays of octet values: indexing a String yields a
    # one-character String (not an integer) on Ruby 1.9 and later.
    octets1 = s1.unpack('C*')
    octets2 = s2.unpack('C*')
    unless octets1.size == octets2.size
      raise "strings must be of the same length"
    end
    octets1.zip(octets2).map { |a, b| a ^ b }.pack('C*')
  end

  # Returns the four-octet encoding of the specified (assumed to be
  # 32-bit) integer, most significant octet first.
  def encode_int32(i)
    [i].pack('N')
  end

  # A PBKDF2 key derivation function. Returns the derived key as an
  # octet string of exactly dklen octets.
  # p:: Password.
  # s:: Salt. Should be at least 8 octets long, according to the
  #     standard.
  # c:: Iteration count. A minimum of 1000 iterations is recommended
  #     by the standard. Must be at least 1.
  # dklen:: Intended length (in octets) of the key to be derived.
  #         Must not be negative; 0 yields an empty string.
  #
  # Raises ArgumentError if c is less than 1 or dklen is negative, and
  # RuntimeError if dklen exceeds (2**32 - 1) * hlen.
  def derive_key(p, s, c, dklen)
    # With no iterations every block would be all zero octets, i.e.
    # the "key" would not depend on the password at all.
    raise ArgumentError, "iteration count must be at least 1" if c < 1
    raise ArgumentError, "derived key length must not be negative" if dklen < 0

    ## Step 1.
    # The block index is encoded in 32 bits, which bounds the length.
    if dklen > ((2**32 - 1) * hlen)
      raise "derived key too long"
    end

    ## Step 2.
    # Number of hlen-octet blocks in the derived key.
    l = (dklen.to_f / hlen).ceil

    ## Step 3.
    # Compute the key blocks (indices start at 1) and concatenate
    # them. We need no more than l blocks.
    dk = String.new
    (1..l).each { |i| dk << f(p, s, c, i) }

    ## Steps 4 and 5.
    # Keep all of the first (l - 1) blocks and only as much of the
    # last block as is needed to output exactly dklen octets.
    dk[0, dklen]
  end
end

# Self-test, executed only when this file is run directly as a script.
# Prints sample output of the helper functions and checks derive_key
# against known PBKDF2 HMAC-SHA1 test vectors, raising on a mismatch.
if $0 == __FILE__
  include KeyDeriv

  # Prints +size+ octets obtained by calling +func+, labelled +desc+.
  def testrand(desc, func, size)
    puts "%d bytes of random data from %s: %s" %
      [size, desc, func.call(size).inspect]
  end
  testrand("OpenSSL-based random function",
           proc {|x| openssl_random(x)}, 8)
  # Not every platform has the device; skip rather than abort the
  # remaining tests.
  if File.exist?('/dev/random')
    testrand("/dev/random-based random function",
             proc {|x| dev_random(x)}, 8)
  end

  # Prints the four-octet encoding of the integer +i+.
  def testenc(i)
    puts "encoded 32-bit integer #{i} is #{encode_int32(i).inspect}"
  end
  testenc(0)
  testenc(1)
  testenc(2**8)
  testenc(2**16)
  testenc(2**24)
  testenc(2**24 + 3)
  testenc(2**32 - 1)

  # Prints the digest computed by +func+ for the given key and data.
  def hmac_test(desc, func, key, data)
    digest = func.call(key, data)
    puts "%d-byte HMAC of data %s with key %s using %s is %s" %
      [digest.size, data.inspect, key.inspect, desc, digest.inspect]
  end
  hmac_test("HMAC-SHA-1", proc {|x,y| hmac_sha_1(x,y)},
            'secret key', 'secret data')

  # Returns the octets of +s+ as space-separated two-digit lowercase
  # hexadecimal numbers.
  def to_hex s
    # Unpack to integers first: indexing a String yields a String (not
    # an integer) on Ruby 1.9 and later, which "%02x" cannot format.
    s.unpack('C*').map { |octet| "%02x" % octet }.join(" ")
  end

  ## PBKDF2 HMAC-SHA1 test vectors.
  ## Grabbed these from somewhere on the net, forget where.
  # Derives an l-octet key from password p and salt s with c rounds,
  # raising unless its hex form equals +expected+.
  def dertest(p, s, c, l, expected)
    key = derive_key(p, s, c, l)
    hex = to_hex(key)
    unless (hex == expected)
      raise "unexpected output: was #{hex}, expected #{expected}"
    end
    puts "PBKDF2 HMAC-SHA1 derived %s-octet key from %s with %d rounds and salt %s is %s" %
      [l, p.inspect, c, s.inspect, key.inspect]
  end
  dertest("password", "ATHENA.MIT.EDUraeburn", 1, 16,
          "cd ed b5 28 1b b2 f8 01 56 5a 11 22 b2 56 35 15")
  dertest("password", "ATHENA.MIT.EDUraeburn", 1, 32,
          %w{cd ed b5 28 1b b2 f8 01 56 5a 11 22 b2 56 35 15
             0a d1 f7 a0 4b b9 f3 a3 33 ec c0 e2 e1 f7 08 37}.join(" "))
  dertest("password", "ATHENA.MIT.EDUraeburn", 2, 16,
          '01 db ee 7f 4a 9e 24 3e 98 8b 62 c7 3c da 93 5d')
  dertest("password", "ATHENA.MIT.EDUraeburn", 2, 32,
          %w{01 db ee 7f 4a 9e 24 3e 98 8b 62 c7 3c da 93 5d
             a0 53 78 b9 32 44 ec 8f 48 a9 9e 61 ad 79 9d 86}.join(" "))
  dertest("password", "ATHENA.MIT.EDUraeburn", 1200, 16,
          '5c 08 eb 61 fd f7 1e 4e 4e c3 cf 6b a1 f5 51 2b')
  dertest("password", "ATHENA.MIT.EDUraeburn", 1200, 32,
          '5c 08 eb 61 fd f7 1e 4e 4e c3 cf 6b a1 f5 51 2b a7 e5 2d db c5 e5 14 2f 70 8a 31 e2 e6 2b 1e 13')
  dertest("password",
          [0x12, 0x34, 0x56, 0x78, 0x78, 0x56, 0x34, 0x12].pack("C*"),
          5, 16,
          'd1 da a7 86 15 f2 87 e6 a1 c8 b1 20 d7 06 2a 49')
  dertest("password",
          [0x12, 0x34, 0x56, 0x78, 0x78, 0x56, 0x34, 0x12].pack("C*"),
          5, 32,
          'd1 da a7 86 15 f2 87 e6 a1 c8 b1 20 d7 06 2a 49 3f 98 d2 03 e6 be 49 a6 ad f4 fa 57 4b 6e 64 ee')
  dertest("X" * 64, "pass phrase equals block size", 1200, 16,
          '13 9c 30 c0 96 6b c3 2b a5 5f db f2 12 53 0a c9')
  dertest("X" * 64, "pass phrase equals block size", 1200, 32,
          '13 9c 30 c0 96 6b c3 2b a5 5f db f2 12 53 0a c9 c5 ec 59 f1 a4 52 f5 cc 9a d9 40 fe a0 59 8e d1')
  dertest("X" * 65, "pass phrase exceeds block size", 1200, 16,
          '9c ca d6 d4 68 77 0c d5 1b 10 e6 a6 87 21 be 61')
  dertest("X" * 65, "pass phrase exceeds block size", 1200, 32,
          '9c ca d6 d4 68 77 0c d5 1b 10 e6 a6 87 21 be 61 1a 8b 4d 28 26 01 db 3b 36 be 92 46 91 5e c8 2a')
end
