#! /usr/bin/python
#
# (c) 2005 Samuel Tardieu <sam@rfc1149.net>
# Released under the GNU General Public License version 2
#
# This program must be launched shortly after midnight every day (or
# at least before 10PM).
#
# Version: 1.2
#
# Arguments:
#   - name of your access file
#   - regular expression describing your feed name (according to your
#     RewriteBase)
#   - maximum number of times you allow a client to fetch the RSS file
#     every day
#   - the .htaccess file to modify
#
# Example: rsscache.py /home/log/apache/access.log '^/blog/feed'
#                      96 /home/sam/blog/.htaccess
# will limit clients to an average of one access every 15 minutes. To be
# on the safe side, you may want to use 100 instead of 96.
#
# The .htaccess file must have the form:
#
# ...
#   # rssabuse section
#   [automatically generated part]
#   RewriteRule ^(.*)$ http://my.rsscache.com/mywebsite/mypath/$1 [R,L]
#
# For example:
#
#  RewriteEngine on
#  RewriteBase /blog
#  # rssabuse section
#  RewriteCond %{REMOTE_ADDR} 0.0.0.0  [replaced later by this script]
#  RewriteRule ^(feed.*)$ http://my.rsscache.com/www.rfc1149.net/blog/$1 [R,L]
#
# All the RewriteCond directives immediately following the
# '# rssabuse section' line will be replaced by appropriate REMOTE_ADDR or
# REMOTE_HOST matches, OR-ed together.
#
# Note: the locale must match the one used when generating the logs (or
# the abbreviated month names will be incorrect).

import os, re, sre, sys, tempfile, time

def date_string_yesterday ():
    """Return yesterday's date in Apache log format, e.g. '09/Oct/2005'.

    The previous calendar day is obtained by asking time.mktime to
    normalize "today's day number minus one" at noon, so the result is
    right whatever the time of day the script runs, including on days
    with a daylight-saving change. The month abbreviation depends on the
    current locale."""
    today = time.localtime ()
    # Day 0 of a month is normalized by mktime to the last day of the
    # previous month (and January 0 to December 31); noon keeps the
    # computation away from DST transitions. isdst=-1 lets mktime decide.
    yesterday = time.mktime ((today[0], today[1], today[2] - 1,
                              12, 0, 0, 0, 0, -1))
    return time.strftime ('%d/%b/%Y', time.localtime (yesterday))

def count_feed_access (fd, regexp, date):
    """Count, per client host, the requests of a given day whose URI
    matches regexp.

    fd is any iterable of access-log lines whose space-separated fields
    are: host, two ignored fields, '[day:time', timezone, method, URI,
    rest of the line. regexp is matched at the start of the URI, and date
    is a day in Apache format such as '10/Oct/2005'. Lines with too few
    fields (blank lines, or a last line still being written) are skipped
    instead of aborting the whole run.

    Return a dictionary mapping each host to its number of matching
    requests."""
    r = re.compile (regexp)
    hosts = {}
    for l in fd:
        fields = l.split (' ', 7)
        if len (fields) < 8: continue    # Malformed or truncated line
        host, when, uri = fields[0], fields[3], fields[6]
        # when looks like '[10/Oct/2005:13:55:36': keep only the day part
        if when[1:].split (':', 1)[0] == date and r.match (uri):
            hosts[host] = hosts.get (host, 0) + 1
    return hosts

def look_for_abusers (hosts, max):
    """Select the hosts whose daily count exceeds the allowed maximum.

    hosts maps each host to its number of requests, as built by
    count_feed_access; max is the largest count still tolerated. Return
    the list of offending hosts, in the dictionary's iteration order."""
    offenders = []
    for host, count in hosts.items ():
        if count > max:
            offenders.append (host)
    return offenders

# Dotted-decimal IPv4 address, e.g. '192.168.0.1'.
_ipv4 = re.compile (r'^[\d.]+$')
# IPv6 address, including IPv4-mapped forms such as '::ffff:192.168.0.1'.
# At least one colon is required: host names never contain one, whereas a
# hex-only name such as 'deadbeef' would otherwise pass for an address.
_ipv6 = re.compile (r'^[\da-fA-F.]*:[\da-fA-F:.]*$')

def is_ip_address (host):
    """Return True if host is a literal IPv4 or IPv6 address rather than
    a host name."""
    return bool (_ipv4.search (host) or _ipv6.search (host))

def rewritecond_line (host):
    """Build the RewriteCond directive singling out one client.

    Literal addresses (as recognized by is_ip_address) are tested
    against %{REMOTE_ADDR}; anything else is treated as a host name and
    tested against %{REMOTE_HOST}. The '=' prefix asks mod_rewrite for a
    plain string comparison instead of a regex match."""
    if not is_ip_address (host):
        return 'RewriteCond %%{REMOTE_HOST} =%s' % host
    return 'RewriteCond %%{REMOTE_ADDR} =%s' % host
    
def rewritecond_string (abusers):
    """Return the RewriteCond lines blocking every host in abusers.

    The lines are joined with ' [OR]\\n' so that any single abuser
    triggers the RewriteRule that follows them. If there are no abusers,
    return one condition that never matches, so that nobody is
    redirected. The '=' makes it an exact string comparison; an
    unanchored regex such as '0.0.0.0' ('.' matching any character)
    could match some real addresses. The string is not terminated by a
    final newline."""
    if not abusers: return 'RewriteCond %{REMOTE_ADDR} =0.0.0.0'
    return ' [OR]\n'.join ([rewritecond_line (h) for h in abusers])

# Comment line of the .htaccess file after which the RewriteCond
# directives are managed by this script.
_marker = '# rssabuse section'
_markerlen = len (_marker)

# Prefix of the directives replaced right after the marker.
_rewritecond = 'RewriteCond '
_rewritecondlen = len (_rewritecond)

def generate_file (fdin, fdout, abusers):
    """Copy the .htaccess content read from fdin to fdout, replacing the
    RewriteCond directives that immediately follow the marker line with
    the ones generated for abusers.

    fdin must be an iterator over lines (such as an open file), since it
    is consumed by several successive loops. Leading whitespace is
    ignored when recognizing the marker and the RewriteCond lines.

    If there is no marker, or if the RewriteCond lines after it are not
    followed by another line (normally the RewriteRule), fdout receives
    an unchanged copy of the input."""
    for l in fdin:       # Copy everything up-to (and including) the marker
        fdout.write (l)
        if l.lstrip ().startswith (_marker): break
    else: return         # No marker found: the copy is already complete
    skipped = []
    for l in fdin:       # Skip existing RewriteCond directives
        if not l.lstrip ().startswith (_rewritecond): break
        skipped.append (l)
    else:
        # No RewriteRule after the conditions: write the old ones back
        # rather than truncating the file.
        fdout.writelines (skipped)
        return
    fdout.write (rewritecond_string (abusers) + '\n')
    fdout.write (l)      # Write already read RewriteRule line
    for l in fdin:       # Copy rest of file
        fdout.write (l)

def rename_safely (oldfd, new):
    """Replace the file named new by the content of oldfd, atomically
    when possible.

    oldfd must be a readable file object with a name attribute (such as
    a tempfile.NamedTemporaryFile) whose content has been flushed. Three
    strategies are tried in turn:
      1. rename oldfd's file over new (atomic, same filesystem only);
      2. copy the content next to new and rename the copy over new
         (atomic, needs write access to new's directory);
      3. overwrite new in place (not atomic, last resort).
    The permission bits of new are copied to the replacement file.
    Temporary files are private (mode 0600), so without this a renamed
    .htaccess could become unreadable by the web server."""
    import stat
    try:
        mode = stat.S_IMODE (os.stat (new).st_mode)
    except EnvironmentError:  # Target missing or not stat-able
        mode = None

    def _replace_with (path):
        """Give path the permissions of new, then rename it over new."""
        if mode is not None: os.chmod (path, mode)
        os.rename (path, new)

    try:
        _replace_with (oldfd.name)
        return
    except EnvironmentError:  # Possible cross-device link
        pass
    localname = '%s.temp%d' % (new, os.getpid())
    try:                 # Try to copy the file in the same dir as the target
        oldfd.seek (0)
        out = open (localname, 'wb')
        try:
            out.write (oldfd.read())
        finally:
            out.close ()
        _replace_with (localname)
        return
    except EnvironmentError:  # No right to create file, will overwrite...
        try: os.remove (localname)   # Do not leave a partial copy behind
        except EnvironmentError: pass
    oldfd.seek (0)
    out = open (new, 'wb')
    try:
        out.write (oldfd.read())
    finally:
        out.close ()

def stats (abusers, hosts, date):
    """Print a short report of the hosts blocked for the given day.

    abusers is the list of blocked hosts, hosts the per-host counts they
    are looked up in, and date the day the counts refer to."""
    if abusers:
        print ("rsscache.py: current abusers for %s (now blocked):" % date)
        for host in abusers:
            print ("   %s (%d access)" % (host, hosts[host]))
    else:
        print ("rsscache.py: no abusers on %s" % date)
    
if __name__ == '__main__':
    # Command line: access_log feed_regexp max_per_day htaccess (see the
    # header comment); extra arguments are ignored.
    if len (sys.argv) < 5:
        sys.stderr.write ('usage: %s access_log feed_regexp max_per_day '
                          'htaccess\n' % sys.argv[0])
        sys.exit (2)
    access_file, regexp, max_count, htaccess = sys.argv[1:5]
    max_count = int (max_count)
    yesterday = date_string_yesterday ()
    log = open (access_file)
    try:
        hosts = count_feed_access (log, regexp, yesterday)
    finally:
        log.close ()
    abusers = look_for_abusers (hosts, max_count)
    # Ascending number of accesses; this is the order stats reports them in
    abusers.sort (key=lambda h: hosts[h])
    fdout = tempfile.NamedTemporaryFile ()
    fdin = open (htaccess)
    try:
        generate_file (fdin, fdout, abusers)
    finally:
        fdin.close ()
    fdout.flush ()
    rename_safely (fdout, htaccess)
    try:
        fdout.close ()
    except EnvironmentError:
        # If rename_safely moved the temporary file over htaccess, it no
        # longer exists under its old name and cannot be deleted.
        pass
    stats (abusers, hosts, yesterday)
