Example no. 1
def main():
    utils.ensure_exists(output_directory)

    date_of_last_cached_leaderboard = get_date_of_last_cached_leaderboard()
    log.info('date of the last cached leaderboard is %s', date_of_last_cached_leaderboard)

    date_of_last_goko_leaderboard = datetime.date.today()

    one_day_delta = datetime.timedelta(1)
    date = date_of_last_cached_leaderboard + one_day_delta

    while date <= date_of_last_goko_leaderboard:
        log.info('Processing %s', date)

        if date == date_of_last_goko_leaderboard:
            log.info('scraping from goko')
            status = run_scrape_function_with_retries(scrape_leaderboard_from_goko, date)
        else:
            log.info('scraping from councilroom')
            status = run_scrape_function_with_retries(scrape_leaderboard_from_councilroom, date)

            if status != 200 and date <= datetime.date(2013, 1, 1):
                log.info('scraping from bggdl')
                status = run_scrape_function_with_retries(scrape_leaderboard_from_bggdl, date)

        if status == 200:
            pass
        elif status == 404:
            log.warning('file not found, so we will assume that it does not exist, and go to the next day')
        else:
            log.warning('Unexpected status of %d, please try again later', status)
            break

        date += one_day_delta
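
Every example in this listing calls an ensure_exists / utils.ensure_exists helper whose definition is not shown. As a reading aid, here is a minimal hypothetical sketch, assuming it simply wraps os.makedirs and hands the path back (the return value is chained in Examples no. 10 and 28, e.g. ensure_exists(...) + "/model.ckpt"). Note that Example no. 26 passes a file path rather than a directory, so that codebase's version presumably creates the parent directory instead.

import os

def ensure_exists(path):
    # Hypothetical sketch: create the directory (and any parents) if it is
    # missing, do nothing if it already exists (Python 3.3+; on Python 2,
    # guard with os.path.exists before calling os.makedirs).
    os.makedirs(path, exist_ok=True)
    # Return the path so callers can chain it, e.g. ensure_exists(d) + "/model.ckpt".
    return path
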
Example no. 2
def run_directory(index):
    directory = '{sim_dir}/{index}'.format(
        sim_dir=simulation_store_directory(), index=index)

    utils.ensure_exists(directory)

    return directory
Example no. 3
def scrape_games():
    parser = utils.incremental_date_range_cmd_line_parser()
    utils.ensure_exists('static/scrape_data')
    os.chdir('static/scrape_data')

    args = parser.parse_args()
    last_month = ''

    for cur_date in utils.daterange(datetime.date(2010, 10, 15), 
                                    datetime.date.today()):
        str_date = time.strftime("%Y%m%d", cur_date.timetuple())
        if not utils.includes_day(args, str_date):
            if DEBUG:
                print 'skipping', str_date, 'because not in cmd line arg daterange'
            continue
        mon = time.strftime("%b%y", cur_date.timetuple())
        if mon != last_month:
            print
            print mon, cur_date.day*"  ",
            sys.stdout.flush()
            last_month = mon
        ret = scrape_date(str_date, cur_date, passive=args.passive)
        if ret==DOWNLOADED:
            print 'o',
        elif ret==REPACKAGED:
            print 'O',
        elif ret==ERROR:
            print '!',
        elif ret==MISSING:
            print '_',
        else:
            print '.',
        sys.stdout.flush()
    print
    os.chdir('../..')
Example no. 4
    def get_structure(self, pid, dst=None, parser=None):
        dirname = '%s/%s' % (self.pdb_dir, pid)

        if not dst:
            dst = '%s/%s' % (dirname, pid)
        if not os.path.exists(dst) or os.path.getsize(dst) < 100:
            ensure_exists(dirname)
            download_pdb(pid, dst)

        if not parser:
            parser = PDBParser(QUIET=True)

        def creator(parser=parser):
            try:
                ret = parser.get_structure(pid, file=dst)
            except ValueError as e:  # assume it's a .cif
                if PARSE_CIF:
                    parser = MMCIFParser(QUIET=True)
                    ret = parser.get_structure(pid, dst)
                else:
                    raise e
            finally:
                self.freemem()
            return ret

        if self.cache:  # warn: Leaky Code!
            return self.cache.get(key=pid, createfunc=creator)
        elif os.path.getsize(dst) > consts.PDB_SIZE_LIMIT:
            raise ValueError('file size exceeds %s' % consts.PDB_SIZE_LIMIT)
        else:
            return creator()
Example no. 5
    def train(self, x, y, epochs, batch_size=128, sample_interval=50, sample_path="samples/unknown"):
        """
        Trains the GAN.
        :param x: The training data.
        :param y: The labels for the training data.
        :param epochs: The number of epochs to train.
        :param batch_size: The number of samples per training batch.
        :param sample_interval: How often to save sample images.
        :param sample_path: Where to save sample images.
        """
        ensure_exists(sample_path)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, x.shape[0], half_batch)
            imgs = x[idx]
            labels = y[idx]

            # Generate a half batch of new images
            noise = np.random.normal(0, 1, (half_batch, self.noise_size))
            gen_imgs = self.generator.predict({"noise": noise, "label": labels})

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch({"image": imgs, "label": labels}, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch({"image": gen_imgs, "label": labels},
                                                            np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator on random labels
            noise = np.random.normal(0, 1, (batch_size, self.noise_size))
            idx = np.random.randint(0, x.shape[0], batch_size)
            labels = y[idx]

            g_loss = self.combined.train_on_batch({"noise": noise, "label": labels}, valid_y)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if sample_interval and epoch % sample_interval == 0:
                self.save_sample(epoch, sample_path)
Example no. 6
    def save_summary(self, path):
        path = path.rstrip('/')
        ensure_exists(path)
        with open(f'{path}/summary.txt', 'w') as f:
            def write_to_summary_file(text):
                f.write(f"{text}\n")

            self.generator.summary(print_fn=write_to_summary_file)
            self.discriminator.summary(print_fn=write_to_summary_file)
        tf.keras.utils.plot_model(self.discriminator, to_file=f"{path}/d.png")
        tf.keras.utils.plot_model(self.generator, to_file=f"{path}/g.png")
Example no. 7
 def save(self, path):
     """Saves the GAN to a folder."""
     path = path.rstrip('/')
     ensure_exists(path)
     with open(f"{path}/config.json", 'w') as f:
         json.dump({"k": float(self.k.numpy())}, f)
     self.generator.save_weights(f"{path}/g.h5", save_format='h5')
     self.discriminator.save_weights(f"{path}/d.h5", save_format='h5')
     try:
         self.save_summary(path)
     except:
         pass
Example no. 8
def main(args, log):
    BEEN_PARSED_KEY = 'day_analyzed'

    if args.incremental:
        log.info("Performing incremental parsing from %s to %s",
                 args.startdate, args.enddate)
    else:
        log.info("Performing non-incremental (re)parsing from %s to %s",
                 args.startdate, args.enddate)

    connection = pymongo.MongoClient()
    db = connection.test  # RT: changed.
    raw_games = db.raw_games
    raw_games.ensure_index('game_date')

    utils.ensure_exists('parsed_out')

    day_status_col = db.day_status
    days = day_status_col.find({'raw_games_loaded': True})

    for day in days:
        year_month_day = day['_id']

        if not utils.includes_day(args, year_month_day):
            log.debug(
                "Raw games for %s available in the database but not in date range, skipping",
                year_month_day)
            continue

        if BEEN_PARSED_KEY not in day:
            day[BEEN_PARSED_KEY] = False
            day_status_col.save(day)

        if day[BEEN_PARSED_KEY] and args.incremental:
            log.debug(
                "Raw games for %s have been parsed, and we're running incrementally, skipping",
                year_month_day)
            continue

        try:
            log.info("Parsing %s", year_month_day)
            convert_to_json(log, raw_games, year_month_day)
            day[BEEN_PARSED_KEY] = True
            day_status_col.save(day)
        except ParseTurnHeaderError as e:
            log.error("ParseTurnHeaderError occurred while parsing %s: %s",
                      year_month_day, e)
            return
        except Exception as e:
            log.error("Exception occurred while parsing %s: %s",
                      year_month_day, e)
            return
Example no. 9
def read_usage():
    """
    Read the usage information from ~/cluster-load/info. Information is expected to
    be updated through another process.
    The function creates ~/cluster-load/info if it doesn't exist.
    """

    # Because the file transfer might change files as they're being read, this
    # could fail. Retry until success or until num_retries attempts have been made.

    num_retries = 10

    info_path = os.path.expanduser('~/cluster-load/info/')
    utils.ensure_exists(info_path)

    for retries in range(num_retries):
        try:
            cmd = 'cat {path}* 2>/dev/null'.format(path=info_path)
            output_lines = subprocess.check_output(
                cmd, shell=True).decode('utf8').split('\n')

            num_tries = retries + 1

            if num_tries > 1:
                print('warning: cpu usage read succeeded on attempt {x}'.
                      format(x=num_tries))
            break

        except subprocess.CalledProcessError as e:
            if (retries == num_retries - 1):
                raise (e)
            else:
                continue

    line_parts = [
        line.split() for line in output_lines if line[:7] == '10.0.0.'
    ]

    cpu_usage = {
        line[0]: string_to_rounded_int(line[1])
        for line in line_parts
    }

    temp = {
        line[0]: int(float(line[2]) * 100 / settings.pi_max_temp)
        for line in line_parts
    }

    return cpu_usage, temp
Example no. 10
def dump_chpt(eval_batcher,
              hps,
              model,
              sess,
              saver,
              eval_loss_best,
              early_stop=False):
    dump_model = False
    # Run evals on development set and print their perplexity.
    previous_losses = [eval_loss_best]
    eval_losses = []
    eval_accuracies = []
    stop_flag = False
    while True:
        batch = eval_batcher.next_batch()
        if not batch[0]:
            eval_batcher.reset()
            break
        eval_inputs, eval_conditions, eval_targets = \
            data.prepare_dis_pretraining_batch(batch)
        eval_inputs = np.split(eval_inputs, 2)[0]
        eval_conditions = np.split(eval_conditions, 2)[0]
        eval_targets = np.split(eval_targets, 2)[0]
        eval_results = model.run_one_batch(sess,
                                           eval_inputs,
                                           eval_conditions,
                                           eval_targets,
                                           update=False)
        eval_losses.append(eval_results["loss"])
        eval_accuracies.append(eval_results["accuracy"])

    eval_loss = sum(eval_losses) / len(eval_losses)
    eval_accuracy = sum(eval_accuracies) / len(eval_accuracies)
    previous_losses.append(eval_loss)
    sys.stdout.flush()
    threshold = 10
    if eval_loss > 0.99 * previous_losses[-2]:
        sess.run(
            model.learning_rate.assign(
                tf.maximum(
                    hps.learning_rate_decay_factor * model.learning_rate,
                    1e-4)))
    if len(previous_losses) > threshold and \
            eval_loss > max(previous_losses[-threshold-1:-1]) and \
            eval_loss_best < min(previous_losses[-threshold:]):
        if early_stop:
            stop_flag = True
        else:
            stop_flag = False
            print("Proper time to stop...")
    if eval_loss < eval_loss_best:
        dump_model = True
        eval_loss_best = eval_loss
    # Save checkpoint and zero timer and loss.
    if dump_model:
        checkpoint_path = ensure_exists(
            join_path(hps.model_dir, "discriminator")) + "/model.ckpt"
        saver.save(sess, checkpoint_path, global_step=model.global_step)
        print("Saving the checkpoint to %s" % checkpoint_path)
    return eval_accuracy, eval_loss, stop_flag, eval_loss_best
Example no. 11
def main():
    utils.ensure_exists(output_directory)

    one_day_delta = datetime.timedelta(1)
    date = get_date_of_last_cached_leaderboard() + one_day_delta
    date_of_current_isotropic_leaderboard = get_date_of_current_isotropic_leaderboard()
    success = True

    while success and date <= date_of_current_isotropic_leaderboard:
        print date

        if date == date_of_current_isotropic_leaderboard:
            success = scrape_leaderboard_from_isotropic(date)
        else:
            success = scrape_leaderboard_from_online_cache(date)

        date += one_day_delta
Example no. 12
def main(args, log):
    BEEN_PARSED_KEY = 'day_analyzed'

    if args.incremental:
        log.info("Performing incremental parsing from %s to %s", args.startdate, args.enddate)
    else:
        log.info("Performing non-incremental (re)parsing from %s to %s", args.startdate, args.enddate)

    connection = pymongo.Connection()
    db = connection.test
    raw_games = db.raw_games
    raw_games.ensure_index('game_date')

    utils.ensure_exists('parsed_out')

    day_status_col = db.day_status
    days = day_status_col.find({'raw_games_loaded': True})

    for day in days:
        year_month_day = day['_id']

        if not utils.includes_day(args, year_month_day):
            log.debug("Raw games for %s available in the database but not in date range, skipping", year_month_day)
            continue

        if BEEN_PARSED_KEY not in day:
            day[BEEN_PARSED_KEY] = False
            day_status_col.save(day)

        if day[BEEN_PARSED_KEY] and args.incremental:
            log.debug("Raw games for %s have been parsed, and we're running incrementally, skipping", year_month_day)
            continue

        try:
            log.info("Parsing %s", year_month_day)
            convert_to_json(log, raw_games, year_month_day)
            day[BEEN_PARSED_KEY] = True
            day_status_col.save(day)
        except ParseTurnHeaderError, e:
            log.error("ParseTurnHeaderError occurred while parsing %s: %s", year_month_day, e)
            return
        except Exception, e:
            log.error("Exception occurred while parsing %s: %s", year_month_day, e)
            return
Example no. 13
 def save(self, path):
     """Saves the GAN to a folder."""
     path = path.rstrip('/')
     ensure_exists(path)
     config = {
         "rows": self.img_rows,
         "cols": self.img_cols,
         "chans": self.channels,
         "noise_size": self.noise_size
     }
     with open(f"{path}/config.json", 'w') as f:
         json.dump(config, f)
     self.generator.save(f"{path}/g.h5")
     self.discriminator.save(f"{path}/d.h5")
     try:
         self.save_summary(path)
     except:
         pass
Example no. 14
 def init_files(self, filename):
     data_dir = self.pdb_dir
     with open(filename) as f:
         ids = f.read().split(',')
     if not os.path.exists(data_dir):
         os.mkdir(data_dir)
     for pid in ids:
         pid_dir = "%s/%s" % (data_dir, pid)
         dst = '%s/%s' % (pid_dir, pid)
         old_dst = '%s/%s.pdb' % (pid_dir, pid)
         ensure_exists(pid_dir)
         if os.path.exists(old_dst):
             print("moving %s->%s" % (old_dst, dst))
             shutil.move(old_dst, dst)
         if not os.path.exists(dst) or os.path.getsize(dst) < 100:
             download_pdb(pid, dst)
         else:
             print("found %s\t%10s Bytes" % (dst, os.path.getsize(dst)))
Example no. 15
def main():
    utils.ensure_exists(output_directory)

    date_of_last_cached_leaderboard = get_date_of_last_cached_leaderboard()
    print 'date of the last cached leaderboard is', date_of_last_cached_leaderboard

    date_of_current_isotropic_leaderboard = get_date_of_current_isotropic_leaderboard()
    if date_of_current_isotropic_leaderboard is None:
        print 'could not determine the date of the current isotropic leaderboard, so please try again later'
        return
    print 'date of the current isotropic leaderboard is', date_of_current_isotropic_leaderboard

    one_day_delta = datetime.timedelta(1)
    date = date_of_last_cached_leaderboard + one_day_delta

    while date <= date_of_current_isotropic_leaderboard:
        print
        print date

        if date == date_of_current_isotropic_leaderboard:
            print 'scraping from isotropic'
            status = run_scrape_function_with_retries(scrape_leaderboard_from_isotropic, date)
        else:
            print 'scraping from councilroom'
            status = run_scrape_function_with_retries(scrape_leaderboard_from_councilroom, date)

            if status != 200:
                print 'scraping from bggdl'
                status = run_scrape_function_with_retries(scrape_leaderboard_from_bggdl, date)

        if status == 200:
            pass
        elif status == 404:
            print 'file not found, so we will assume that it does not exist, and go to the next day'
        else:
            print 'please try again later'
            break

        date += one_day_delta
Example no. 16
def main():
    utils.ensure_exists(output_directory)

    date_of_last_cached_leaderboard = get_date_of_last_cached_leaderboard()
    log.info('date of the last cached leaderboard is %s', date_of_last_cached_leaderboard)

    date_of_current_isotropic_leaderboard = get_date_of_current_isotropic_leaderboard()
    if date_of_current_isotropic_leaderboard is None:
        log.warning('could not determine the date of the current isotropic leaderboard, so please try again later')
        return
    log.info('date of the current isotropic leaderboard is %s', date_of_current_isotropic_leaderboard)

    one_day_delta = datetime.timedelta(1)
    date = date_of_last_cached_leaderboard + one_day_delta

    while date <= date_of_current_isotropic_leaderboard:
        log.info('Processing %s', date)

        if date == date_of_current_isotropic_leaderboard:
            log.info('scraping from isotropic')
            status = run_scrape_function_with_retries(scrape_leaderboard_from_isotropic, date)
        else:
            log.info('scraping from councilroom')
            status = run_scrape_function_with_retries(scrape_leaderboard_from_councilroom, date)

            if status != 200:
                log.info('scraping from bggdl')
                status = run_scrape_function_with_retries(scrape_leaderboard_from_bggdl, date)

        if status == 200:
            pass
        elif status == 404:
            log.warning('file not found, so we will assume that it does not exist, and go to the next day')
        else:
            log.warning('Unexpected status of %d, please try again later', status)
            break

        date += one_day_delta
Example no. 17
def scrape_games():
    parser = utils.incremental_date_range_cmd_line_parser()
    utils.ensure_exists('static/scrape_data')
    os.chdir('static/scrape_data')

    args = parser.parse_args()
    last_month = ''
    
    yesterday = datetime.date.today() - datetime.timedelta(days=1)
    #Goko updates logs in real time; wait a day so the list is finalized.

    for cur_date in utils.daterange(default_startdate, yesterday, reverse=True):
        str_date = time.strftime("%Y%m%d", cur_date.timetuple())
        if not utils.includes_day(args, str_date):
            if DEBUG:
                print 'skipping', str_date, 'because not in cmd line arg daterange'
            continue
        mon = time.strftime("%b%y", cur_date.timetuple())
        if mon != last_month:
            print
            print mon, cur_date.day*"  ",
            sys.stdout.flush()
            last_month = mon
        ret = scrape_date(str_date, cur_date, passive=args.passive)
        if ret==DOWNLOADED:
            print 'o',
        elif ret==REPACKAGED:
            print 'O',
        elif ret==ERROR:
            print '!',
        elif ret==MISSING:
            print '_',
        else:
            print '.',
        sys.stdout.flush()
    print
    os.chdir('../..')
Example no. 18
# taken from
# http://stackoverflow.com/questions/1060279/iterating-through-a-range-of-dates-in-python

import datetime
import time
import os
import urllib

import utils

parser = utils.IncrementalDateRangeCmdLineParser()

# if the size of the game log is less than this assume we got an error page
SMALL_FILE_SIZE = 5000

utils.ensure_exists('static/scrape_data')
os.chdir('static/scrape_data')

# maybe I should just adopt the isotropic format for consistency?
ISOTROPIC_FORMAT = '%(year)d%(month)02d/%(day)02d/all.tar.bz2'
COUNCILROOM_FORMAT = '%(year)d%(month)02d%(day)02d/%(year)d%(month)02d%(day)02d.all.tar.bz2'


def FormatDate(fmt, date):
    return fmt % {
        'year': date.year,
        'month': date.month,
        'day': date.day
    }
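
Examples no. 3, 17, and 25 iterate over days with a utils.daterange helper that is also not shown; the comment at the top of this example links the Stack Overflow question it appears to be based on. A minimal sketch under that assumption, with a reverse flag added only because Example no. 17 passes reverse=True:

import datetime

def daterange(start_date, end_date, reverse=False):
    # Yield each date in the half-open range [start_date, end_date),
    # newest first when reverse is set.
    days = (end_date - start_date).days
    steps = range(days - 1, -1, -1) if reverse else range(days)
    for n in steps:
        yield start_date + datetime.timedelta(n)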

Example no. 19
    def train(self,
              x,
              epochs,
              batch_size=128,
              sample_interval=50,
              sample_path="samples/unknown",
              starting_epoch=0,
              save_interval=5000):
        ensure_exists(sample_path)

        half_batch = int(batch_size / 2)
        exp_replay = []

        epoch = starting_epoch
        while epoch < epochs:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, x.shape[0], half_batch)
            imgs = x[idx]

            # Generate a half batch of new images
            noise = np.random.normal(0, 1, (half_batch, self.noise_size))
            gen_imgs = self.generator.predict(noise)

            # save one random generated image for experience replay
            r_idx = np.random.randint(0, half_batch)
            exp_replay.append(gen_imgs[r_idx])

            # Train the discriminator
            # If we have enough points, do experience replay
            if len(exp_replay) == half_batch:
                generated_images = np.array(exp_replay)
                d_loss_replay = self.discriminator.train_on_batch(
                    generated_images, np.zeros((half_batch, 1)))
                exp_replay = []
            d_loss_real = self.discriminator.train_on_batch(
                imgs, np.array([0.9] * half_batch))
            d_loss_fake = self.discriminator.train_on_batch(
                gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator on random labels
            noise = np.random.normal(0, 1, (batch_size, self.noise_size))
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.save_sample(epoch, sample_path)

            # save model
            if epoch % save_interval == 0:
                self.save(f"temp/{epoch}")
            epoch += 1
Example no. 20
def simulation_store_directory():
    directory = '{root_dir}/simulations'.format(root_dir=settings.root_dir)

    utils.ensure_exists(directory)

    return directory
Example no. 21
    days.sort()
    for year_month_day in days:
        if not utils.includes_day(args, year_month_day):
            continue

        if args.incremental and os.path.exists("parsed_out/%s-0.json" % year_month_day):
            print "skipping", year_month_day, "because already done"
            continue

        try:
            print "trying", year_month_day
            convert_to_json(year_month_day)
        except ParseTurnHeaderError, e:
            print e
            return


# def profilemain():
#     import hotshot, hotshot.stats
#     prof = hotshot.Profile("parse_game.prof")
#     prof.runcall(t)
#     prof.close()
#     stats = hotshot.stats.load("parse_game.prof")
#     stats.strip_dirs()
#     stats.sort_stats('time', 'calls')
#     stats.print_stats(20)

if __name__ == "__main__":
    utils.ensure_exists("parsed_out")
    main()
Example no. 22
    for year_month_day in days:
        if not utils.includes_day(args, year_month_day):
            continue
            
        if args.incremental and os.path.exists(
            'parsed_out/%s-0.json' % year_month_day):
            print 'skipping', year_month_day, 'because already done'
            continue        

        try:
            print 'trying', year_month_day
            convert_to_json(year_month_day)
        except ParseTurnHeaderError, e:
            print e
            return
    
# def profilemain():
#     import hotshot, hotshot.stats
#     prof = hotshot.Profile("parse_game.prof")
#     prof.runcall(t)
#     prof.close()
#     stats = hotshot.stats.load("parse_game.prof")
#     stats.strip_dirs()
#     stats.sort_stats('time', 'calls')
#     stats.print_stats(20)

if __name__ == '__main__':
    utils.ensure_exists('parsed_out')
    main()
    
Example no. 23
tf.app.flags.DEFINE_float('gan_lr', 0.0005,
                          'learning rate for the gen in GAN training')

FLAGS = tf.app.flags.FLAGS

assert FLAGS.mode in [
    "pretrain_gen", "pretrain_dis", "train_gan", "decode", "test"
]

if FLAGS.mode == "train_gan":
    FLAGS.single_pass = False

if FLAGS.min_dec_steps > FLAGS.max_dec_steps / 2:
    FLAGS.min_dec_steps = int(FLAGS.max_dec_steps / 2)

ensure_exists(FLAGS.model_dir)


def pretrain_generator(model, batcher, sess, val_batcher, model_saver,
                       model_dir, val_saver, val_dir):
    """Repeatedly runs training iterations, logging loss to screen and writing
    summaries"""
    print("starting run_training")
    best_loss = None  # will hold the best loss achieved so far
    best_loss = get_best_loss_from_chpt(val_dir)
    # get the val loss score
    coverage_loss = None
    hps = model.hps
    # this is where checkpoints of best models are saved
    running_avg_loss = 0
    # the eval job keeps a smoother, running average loss to tell it when to
Example no. 24
import os
import time
from datetime import date
import utils 

import sys

utils.ensure_exists('static/status')

cmds = [
    ('python scrape.py', False),             # downloads gamelogs from isotropic
    ('python parse_game.py', True),        # parses data into useable format
    ('python load_parsed_data.py', False),  # loads data into database
    ('python analyze.py', False),            # produces data for graphs
    ('python goals.py', False),
    ('python count_buys.py', False)
]

extra_args = sys.argv[1:]

# should think about how to parallelize this for multiprocessor machines.
while True:
    for cmd, spittable in cmds:
        status_fn = (date.today().isoformat() + '-' +
                     time.strftime('%H:%M:%S') +
                     '-' + cmd.replace(' ', '_') + '.txt')
        cmd = cmd + ' ' + ' '.join(extra_args) + ' 2>&1 | tee -a ' + status_fn
        print cmd
        os.system(cmd)
        os.system('mv %s static/status' % status_fn)
    print 'sleeping'
Example no. 25
        os.stat(fn).st_size <= SMALL_FILE_SIZE):
        print 'removing small existing file', fn
        os.unlink(fn)


class MyURLOpener(urllib.FancyURLopener):

    def http_error_default(self, *args, **kwargs):
        urllib.URLopener.http_error_default(self, *args, **kwargs)


if __name__ == '__main__':
    parser = utils.incremental_date_range_cmd_line_parser()
    args = parser.parse_args()

    utils.ensure_exists('static/scrape_data')
    os.chdir('static/scrape_data')

    for cur_date in utils.daterange(datetime.date(2010, 10, 15),
                                    datetime.date.today()):
        str_date = time.strftime("%Y%m%d", cur_date.timetuple())
        if not utils.includes_day(args, str_date):
            print 'skipping', str_date, 'because not in cmd line arg daterange'
            continue
        directory = str_date
        print str_date
        games_short_name = str_date + '.all.tar.bz2'
        saved_games_bundle = directory + '/' + games_short_name
        if utils.at_least_as_big_as(saved_games_bundle, SMALL_FILE_SIZE):
            print 'skipping because exists', str_date, saved_games_bundle, \
                'and not small (size=', os.stat(saved_games_bundle).st_size, ')'
Example no. 26

if __name__ == "__main__":

    INPUT_FILE = sys.argv[1]

    log.debug(f"input file: {INPUT_FILE}")

    OUTPUT_FILE: Text = os.path.join(
        "/tmp/vim",
        basename(dirname(INPUT_FILE)),
        re.sub(
            r"^(.*)\.(?:r?md|m(?:ark)?down)$",
            r"\1.html",
            basename(INPUT_FILE),
            # flags must be passed by keyword; the 4th positional arg of re.sub is count
            flags=re.IGNORECASE | re.MULTILINE,
        ),
    )

    log.debug(f"output file: {OUTPUT_FILE}")

    ensure_exists(OUTPUT_FILE)

    with open(OUTPUT_FILE, "w", encoding="utf-8") as output:
        cmd = PandocCmd(INPUT_FILE)
        output.write(cmd.execute())
        print(f"Cmd: {' '.join(cmd.args)}")
        print(f'Output: {output.name}')

# vim:foldmethod=manual:
Example no. 27
import os
import time
from datetime import date, timedelta
import utils

import sys

utils.ensure_exists('static/status')

cmds = [
    'python scrape.py',
    'python parse_game.py',
    'python load_parsed_data.py ',
    'python analyze.py',
    'python goals.py',
    'python count_buys.py',
    'python run_trueskill.py',
    'python optimal_card_ratios.py',
    'python goal_stats.py',
    'python scrape_leaderboard.py',
    'python load_leaderboard.py',
]

extra_args = sys.argv[1:]

# should think about how to parallelize this for multiprocessor machines.
while True:
    for cmd in cmds:
        month_ago = (date.today() - timedelta(days=30)).strftime('%Y%m%d')
        fmt_dict = {'month_ago': month_ago}
        cmd = (cmd % fmt_dict)
Example no. 28
def main(argv):
    tf.set_random_seed(111)  # a seed value for randomness

    # Create a batcher object that will create minibatches of data
    # TODO change to pass number

    # --------------- building graph ---------------
    hparam_gen = [
        'mode',
        'model_dir',
        'adagrad_init_acc',
        'steps_per_checkpoint',
        'batch_size',
        'beam_size',
        'cov_loss_wt',
        'coverage',
        'emb_dim',
        'rand_unif_init_mag',
        'gen_vocab_file',
        'gen_vocab_size',
        'hidden_dim',
        'gen_lr',
        'gen_max_gradient',
        'max_dec_steps',
        'max_enc_steps',
        'min_dec_steps',
        'trunc_norm_init_std',
        'single_pass',
        'log_root',
        'data_path',
    ]

    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gen:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gen = namedtuple("HParams4Gen", hps_dict.keys())(**hps_dict)

    print("Building vocabulary for generator ...")
    gen_vocab = Vocab(join_path(hps_gen.data_path, hps_gen.gen_vocab_file),
                      hps_gen.gen_vocab_size)

    hparam_dis = [
        'mode',
        'vocab_type',
        'model_dir',
        'dis_vocab_size',
        'steps_per_checkpoint',
        'learning_rate_decay_factor',
        'dis_vocab_file',
        'num_class',
        'layer_size',
        'conv_layers',
        'max_steps',
        'kernel_size',
        'early_stop',
        'pool_size',
        'pool_layers',
        'dis_max_gradient',
        'batch_size',
        'dis_lr',
        'lr_decay_factor',
        'cell_type',
        'max_enc_steps',
        'max_dec_steps',
        'single_pass',
        'data_path',
        'num_models',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_dis:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_dis = namedtuple("HParams4Dis", hps_dict.keys())(**hps_dict)
    if hps_gen.gen_vocab_file == hps_dis.dis_vocab_file:
        hps_dis = hps_dis._replace(vocab_type="word")
        hps_dis = hps_dis._replace(layer_size=hps_gen.emb_dim)
        hps_dis = hps_dis._replace(dis_vocab_size=hps_gen.gen_vocab_size)
    else:
        hps_dis = hps_dis._replace(max_enc_steps=hps_dis.max_enc_steps * 2)
        hps_dis = hps_dis._replace(max_dec_steps=hps_dis.max_dec_steps * 2)
    if FLAGS.mode == "train_gan":
        hps_gen = hps_gen._replace(batch_size=hps_gen.batch_size *
                                   hps_dis.num_models)

    if FLAGS.mode != "pretrain_dis":
        with tf.variable_scope("generator"):
            generator = PointerGenerator(hps_gen, gen_vocab)
            print("Building generator graph ...")
            gen_decoder_scope = generator.build_graph()

    if FLAGS.mode != "pretrain_gen":
        print("Building vocabulary for discriminator ...")
        dis_vocab = Vocab(join_path(hps_dis.data_path, hps_dis.dis_vocab_file),
                          hps_dis.dis_vocab_size)
    if FLAGS.mode in ['train_gan', 'pretrain_dis']:
        with tf.variable_scope("discriminator"), tf.device("/gpu:0"):
            discriminator = Seq2ClassModel(hps_dis)
            print("Building discriminator graph ...")
            discriminator.build_graph()

    hparam_gan = [
        'mode',
        'model_dir',
        'gan_iter',
        'gan_gen_iter',
        'gan_dis_iter',
        'gan_lr',
        'rollout_num',
        'sample_num',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gan:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gan = namedtuple("HParams4GAN", hps_dict.keys())(**hps_dict)
    hps_gan = hps_gan._replace(mode="train_gan")
    if FLAGS.mode == 'train_gan':
        with tf.device("/gpu:0"):
            print("Creating rollout...")
            rollout = Rollout(generator, 0.8, gen_decoder_scope)

    # --------------- initializing variables ---------------
    all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) + \
        tf.get_collection_ref(tf.GraphKeys.WEIGHTS) + \
        tf.get_collection_ref(tf.GraphKeys.BIASES)
    sess = tf.Session(config=utils.get_config())
    sess.run(tf.variables_initializer(all_variables))
    if FLAGS.mode == "pretrain_gen":
        val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        model_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator'))
        print("Restoring the generator model from the latest checkpoint...")
        gen_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])
        gen_dir = ensure_exists(join_path(FLAGS.model_dir, "generator"))
        # gen_dir = ensure_exists(FLAGS.model_dir)
        # temp_saver = tf.train.Saver(
        #     var_list=[v for v in all_variables if "generator" in v.name and "Adagrad" not in v.name])
        ckpt_path = utils.load_ckpt(gen_saver, sess, gen_dir)
        print('going to restore embeddings from checkpoint')
        if not ckpt_path:
            emb_path = join_path(FLAGS.model_dir, "generator", "init_embed")
            if emb_path:
                generator.saver.restore(
                    sess,
                    tf.train.get_checkpoint_state(
                        emb_path).model_checkpoint_path)
                print(
                    colored(
                        "successfully restored embeddings form %s" % emb_path,
                        'green'))
            else:
                print(
                    colored("failed to restore embeddings form %s" % emb_path,
                            'red'))

    elif FLAGS.mode in ["decode", "train_gan"]:
        print("Restoring the generator model from the best checkpoint...")
        # val_dir is used below by load_ckpt and later by get_best_loss_from_chpt,
        # but was only defined in the pretrain_gen branch; mirror that path here.
        val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        dec_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir))
        gan_val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir,
                      FLAGS.val_dir))
        gan_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_val_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        utils.load_ckpt(dec_saver, sess, val_dir,
                        (FLAGS.mode in ["train_gan", "decode"]))

    if FLAGS.mode in ["pretrain_dis", "train_gan"]:
        dis_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "discriminator" in v.name])
        dis_dir = ensure_exists(join_path(FLAGS.model_dir, 'discriminator'))
        ckpt = utils.load_ckpt(dis_saver, sess, dis_dir)
        if not ckpt:
            if hps_dis.vocab_type == "word":
                discriminator.init_emb(
                    sess, join_path(FLAGS.model_dir, "generator",
                                    "init_embed"))
            else:
                discriminator.init_emb(
                    sess,
                    join_path(FLAGS.model_dir, "discriminator", "init_embed"))

    # --------------- train models ---------------
    if FLAGS.mode != "pretrain_dis":
        gen_batcher_train = GenBatcher("train",
                                       gen_vocab,
                                       hps_gen,
                                       single_pass=hps_gen.single_pass)
        decoder = Decoder(sess, generator, gen_vocab)
        gen_batcher_val = GenBatcher("val",
                                     gen_vocab,
                                     hps_gen,
                                     single_pass=True)
        val_saver = tf.train.Saver(
            max_to_keep=10,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])

    if FLAGS.mode != "pretrain_gen":
        dis_val_batch_size = hps_dis.batch_size * hps_dis.num_models \
            if hps_dis.mode == "train_gan" else hps_dis.batch_size * hps_dis.num_models * 2
        dis_batcher_val = DisBatcher(
            hps_dis.data_path,
            "eval",
            gen_vocab,
            dis_vocab,
            dis_val_batch_size,
            single_pass=True,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )

    if FLAGS.mode == "pretrain_gen":
        # reload the generator checkpoint and continue pretraining
        print('Going to pretrain the generator')
        try:
            pretrain_generator(generator, gen_batcher_train, sess,
                               gen_batcher_val, gen_saver, model_dir,
                               val_saver, val_dir)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "pretrain_dis":
        print('Going to pretrain the discriminator')
        dis_batcher = DisBatcher(
            hps_dis.data_path,
            "decode",
            gen_vocab,
            dis_vocab,
            hps_dis.batch_size * hps_dis.num_models,
            single_pass=hps_dis.single_pass,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )
        try:
            pretrain_discriminator(sess, discriminator, dis_batcher_val,
                                   dis_vocab, dis_batcher, dis_saver)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "train_gan":
        gen_best_loss = get_best_loss_from_chpt(val_dir)
        gen_global_step = 0
        print('Going to tune the two using Gan')
        for i_gan in range(hps_gan.gan_iter):
            # Train the generator for one step
            g_losses = []
            current_speed = []
            for it in range(hps_gan.gan_gen_iter):
                start_time = time.time()
                batch = gen_batcher_train.next_batch()

                # generate samples
                enc_states, dec_in_state, n_samples, n_targets_padding_mask = decoder.mc_generate(
                    batch, include_start_token=True, s_num=hps_gan.sample_num)
                # get rewards for the samples
                n_rewards = rollout.get_reward(sess, gen_vocab, dis_vocab,
                                               batch, enc_states, dec_in_state,
                                               n_samples, hps_gan.rollout_num,
                                               discriminator)

                # fine tune the generator
                n_sample_targets = [samples[:, 1:] for samples in n_samples]
                n_targets_padding_mask = [
                    padding_mask[:, 1:]
                    for padding_mask in n_targets_padding_mask
                ]
                n_samples = [samples[:, :-1] for samples in n_samples]
                # sample_target_padding_mask = pad_sample(sample_target, gen_vocab, hps_gen)
                n_samples = [
                    np.where(
                        np.less(samples, hps_gen.gen_vocab_size), samples,
                        np.array([[gen_vocab.word2id(data.UNKNOWN_TOKEN)] *
                                  hps_gen.max_dec_steps] * hps_gen.batch_size))
                    for samples in n_samples
                ]
                results = generator.run_gan_batch(sess, batch, n_samples,
                                                  n_sample_targets,
                                                  n_targets_padding_mask,
                                                  n_rewards)

                gen_global_step = results["global_step"]

                # for visualization
                g_loss = results["loss"]
                if not math.isnan(g_loss):
                    g_losses.append(g_loss)
                else:
                    print(colored('a nan in gan loss', 'red'))
                current_speed.append(time.time() - start_time)

            # Test
            # if FLAGS.gan_gen_iter and (i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1):
            if i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1:
                print('Going to test the generator.')
                current_speed = sum(current_speed) / (len(current_speed) *
                                                      hps_gen.batch_size)
                average_g_loss = sum(g_losses) / len(g_losses)
                # one more process should be opened for the evaluation
                eval_loss, gen_best_loss = save_ckpt(
                    sess, generator, gen_best_loss, gan_dir, gan_saver,
                    gen_batcher_val, gan_val_dir, gan_val_saver,
                    gen_global_step)

                if eval_loss:
                    print("\nDashboard for " +
                          colored("GAN Generator", 'green') + " updated %s, "
                          "finished steps:\t%s\n"
                          "\tBatch size:\t%s\n"
                          "\tVocabulary size:\t%s\n"
                          "\tCurrent speed:\t%.4f seconds/article\n"
                          "\tAverage training loss:\t%.4f; "
                          "eval loss:\t%.4f" % (
                              datetime.datetime.now().strftime(
                                  "on %m-%d at %H:%M"),
                              gen_global_step,
                              FLAGS.batch_size,
                              hps_gen.gen_vocab_size,
                              current_speed,
                              average_g_loss.item(),
                              eval_loss.item(),
                          ))

            # Train the discriminator
            print('Going to train the discriminator.')
            dis_best_loss = 1000
            dis_losses = []
            dis_accuracies = []
            for d_gan in range(hps_gan.gan_dis_iter):
                batch = gen_batcher_train.next_batch()
                enc_states, dec_in_state, k_samples_words, _ = decoder.mc_generate(
                    batch, s_num=hps_gan.sample_num)
                # shuould first tanslate to words to avoid unk
                articles_oovs = batch.art_oovs
                for samples_words in k_samples_words:
                    dec_batch_words = batch.target_batch
                    conditions_words = batch.enc_batch_extend_vocab
                    if hps_dis.vocab_type == "char":
                        samples = gen_vocab2dis_vocab(samples_words, gen_vocab,
                                                      articles_oovs, dis_vocab,
                                                      hps_dis.max_dec_steps,
                                                      STOP_DECODING)
                        dec_batch = gen_vocab2dis_vocab(
                            dec_batch_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_dec_steps, STOP_DECODING)
                        conditions = gen_vocab2dis_vocab(
                            conditions_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_enc_steps, PAD_TOKEN)
                    else:
                        samples = samples_words
                        dec_batch = dec_batch_words
                        conditions = conditions_words
                        # the unknown in target

                    inputs = np.concatenate([samples, dec_batch], 0)
                    conditions = np.concatenate([conditions, conditions], 0)

                    targets = [[1, 0] for _ in samples] + [[0, 1]
                                                           for _ in dec_batch]
                    targets = np.array(targets)
                    # randomize the samples
                    assert len(inputs) == len(conditions) == len(
                        targets
                    ), "lengths of the inputs, conditions and targets should be the same."
                    indices = np.random.permutation(len(inputs))
                    inputs = np.split(inputs[indices], 2)
                    conditions = np.split(conditions[indices], 2)
                    targets = np.split(targets[indices], 2)
                    assert len(inputs) % 2 == 0, "the length should be even"

                    results = discriminator.run_one_batch(
                        sess, inputs[0], conditions[0], targets[0])
                    dis_accuracies.append(results["accuracy"].item())
                    dis_losses.append(results["loss"].item())

                    results = discriminator.run_one_batch(
                        sess, inputs[1], conditions[1], targets[1])
                    dis_accuracies.append(results["accuracy"].item())

                ave_dis_acc = sum(dis_accuracies) / len(dis_accuracies)
                if d_gan == hps_gan.gan_dis_iter - 1:
                    if (sum(dis_losses) / len(dis_losses)) < dis_best_loss:
                        dis_best_loss = sum(dis_losses) / len(dis_losses)
                        checkpoint_path = ensure_exists(
                            join_path(hps_dis.model_dir,
                                      "discriminator")) + "/model.ckpt"
                        dis_saver.save(sess,
                                       checkpoint_path,
                                       global_step=results["global_step"])
                    print_dashboard("GAN Discriminator",
                                    results["global_step"].item(),
                                    hps_dis.batch_size, hps_dis.dis_vocab_size,
                                    results["loss"].item(), 0.00, 0.00, 0.00)
                    print("Average training accuracy: \t%.4f" % ave_dis_acc)

                if ave_dis_acc > 0.9:
                    break

    # --------------- decoding samples ---------------
    elif FLAGS.mode == "decode":
        print('Going to decode from the generator.')
        decoder.bs_decode(gen_batcher_train)
        print("Finished decoding..")
        # decode for generating corpus for discriminator

    sess.close()
Example no. 29
            # and since these are shown, they are 1 indexed
            show_turn_id = '%s-show-turn-%d' % (player, turn_no + 1)
            ret += '<div id="%s"></div>' % turn_id
            ret += '<a name="%s"></a><a href="#%s">%s</a>' % (
                show_turn_id, show_turn_id, cur_match.group())

            contents = contents[cur_match.end():]
        else:
            break
    before_end = contents.find('</html')
    ret = ret + contents[:before_end]
    ret += '<div id="end-game">\n'
    ret += '</div>&nbsp;<br>\n' * 10
    contents = contents[before_end:]
    return ret + contents


# def profilemain():
#     import hotshot, hotshot.stats
#     prof = hotshot.Profile("parse_game.prof")
#     prof.runcall(t)
#     prof.close()
#     stats = hotshot.stats.load("parse_game.prof")
#     stats.strip_dirs()
#     stats.sort_stats('time', 'calls')
#     stats.print_stats(20)

if __name__ == '__main__':
    utils.ensure_exists('parsed_out')
    main()
Example no. 30
    def train(self, x, epochs, k_lambda, gamma, batch_size=16, sample_interval=50, sample_path="samples/unknown",
              starting_epoch=0, save_interval=5000):
        ensure_exists(sample_path)

        epoch = starting_epoch
        steps_per_epoch = len(x) // batch_size
        decay_every = 16000
        initial_lr = 0.0001

        step = tf.Variable(0)

        lr = initial_lr * pow(0.5, epoch // decay_every)
        print(f"Initialized initial learning rate at {lr}")
        self.optimizer = tf.train.AdamOptimizer(lr)

        # set up train operations
        # dataset
        def datagen():
            while True:
                idx = np.random.randint(0, x.shape[0], batch_size)
                imgs = x[idx]
                yield imgs

        dataset = tf.data.Dataset.from_generator(datagen, tf.float32, tf.TensorShape(
            (batch_size, self.img_rows, self.img_cols, self.channels))).repeat()

        # generator
        noise = tf.random_uniform((x.shape[0], batch_size), -1.0, 1.0)
        fake = self.generator(noise)
        g_loss = self.l1loss(fake, self.discriminator(fake))

        # discriminator
        real = dataset.get_next()
        d_loss_real = self.l1loss(real, self.discriminator(real))
        d_loss_gen = g_loss
        d_loss = d_loss_real - self.k * d_loss_gen

        # training
        g_opt = self.optimizer.minimize(g_loss, var_list=self.generator.trainable_variables)
        d_opt = self.optimizer.minimize(d_loss, global_step=step, var_list=self.discriminator.trainable_variables)

        # updating
        balance = gamma * d_loss_real - g_loss
        measure = d_loss_real + tf.abs(balance)
        with tf.control_dependencies([g_opt, d_opt]):
            k_update = tf.assign(self.k, tf.clip_by_value(self.k + k_lambda * balance, 0, 1))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            while epoch < epochs:
                result = sess.run({
                    "k_update": k_update,
                    "m": measure,
                    "g_loss": g_loss,
                    "d_loss_real": d_loss_real,
                    "d_loss_fake": d_loss_gen,
                    "d_loss": d_loss,
                    "k": self.k
                })

                # ---------------------
                #  LR Decay
                # ---------------------
                if epoch % decay_every == 0:
                    lr = initial_lr * pow(0.5, epoch // decay_every)
                    print(f"Decaying learning rate to {lr}")
                    self.optimizer = tf.train.AdamOptimizer(lr)

                # ---------------------
                #  Status report
                # ---------------------
                # calculate convergence factor
                # Plot the progress
                print("%d [M: %f] [D loss: %f = %f + %f] [G loss: %f] [K: %f]" % (
                    epoch, result['m'], result['d_loss'], result['d_loss_real'], result['d_loss_fake'],
                    result['g_loss'], result['k']))

                # If at save interval => save generated image samples
                if epoch % sample_interval == 0:
                    self.save_sample(epoch, sample_path, x, sess)

                # save model
                if epoch % save_interval == 0:
                    self.save(f"temp/{epoch}")
                epoch += 1
Example no. 31
    def train(self, x, y, epochs, batch_size=128, sample_interval=50, sample_path="samples/unknown", starting_epoch=0,
              save_interval=5000):
        """
        Trains the GAN.
        :param x: The training data.
        :param y: The labels for the training data.
        :param epochs: The number of epochs to train.
        :param batch_size: The number of samples per training batch.
        :param sample_interval: How often to save sample images.
        :param sample_path: Where to save sample images.
        :param starting_epoch: The epoch to resume counting from.
        :param save_interval: How often (in epochs) to save the model.
        """
        ensure_exists(sample_path)

        half_batch = int(batch_size / 2)
        exp_replay = []

        epoch = starting_epoch
        while epoch < epochs:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, x.shape[0], half_batch)
            imgs = x[idx]
            labels = y[idx]

            # Generate a half batch of new images
            noise = np.random.normal(0, 1, (half_batch, self.noise_size))
            gen_imgs = self.generator.predict({"noise": noise, "label": labels})

            # save one random generated image for experience replay
            r_idx = np.random.randint(0, half_batch)
            exp_replay.append((gen_imgs[r_idx], labels[r_idx]))

            # Train the discriminator
            # If we have enough points, do experience replay
            if len(exp_replay) == half_batch:
                generated_images = np.array([p[0] for p in exp_replay])
                labels = np.array([p[1] for p in exp_replay])
                d_loss_replay = self.discriminator.train_on_batch([generated_images, labels],
                                                                  np.zeros((half_batch, 1)))
                exp_replay = []
            d_loss_real = self.discriminator.train_on_batch({"image": imgs, "label": labels}, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch({"image": gen_imgs, "label": labels},
                                                            np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator on random labels
            noise = np.random.normal(0, 1, (batch_size, self.noise_size))
            idx = np.random.randint(0, x.shape[0], batch_size)
            labels = y[idx]

            g_loss = self.combined.train_on_batch({"noise": noise, "label": labels}, valid_y)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.save_sample(epoch, sample_path)

            # save model
            if epoch % save_interval == 0:
                self.save(f"temp/{epoch}")
            epoch += 1
Example no. 32
def trainNet(net,
             opt,
             cri,
             sch,
             cp_dir,
             epoch_range,
             training_set,
             val_set,
             batch_size_big,
             rep=1):
    tb_dir = os.path.join(cp_dir, 'tb')
    utils.ensure_exists(tb_dir)
    fout = open(os.path.join(cp_dir, 'train.log'), 'a')

    for epoch in epoch_range:
        net.train()
        running_loss = 0
        start_time = time.time()
        tl = len(training_set)

        for e in range(rep):
            for i, data in enumerate(training_set, 0):

                com = data['com'].float().cuda()
                org = data['org'].float().cuda()

                opt.zero_grad()
                if com.shape[0] > batch_size_big:
                    com_big = com[:batch_size_big, :, :, :]
                    org_big = org[:batch_size_big, :, :, :]
                    ret_big = net(com_big)
                    loss_big = cri(ret_big, org_big)

                    _, _, h, w = com.shape
                    new_h = 32
                    new_w = 32
                    top = 8 * np.random.randint(0, (h - new_h) // 8)
                    left = 8 * np.random.randint(0, (w - new_w) // 8)
                    com_small = com[batch_size_big:, :, top:top + new_h,
                                    left:left + new_w]
                    org_small = org[batch_size_big:, :, top:top + new_h,
                                    left:left + new_w]

                    ret_small = net(com_small)
                    loss_small = cri(ret_small, org_small)
                    loss = loss_big + loss_small
                    loss.backward()

                    nn.utils.clip_grad_norm_(net.parameters(), 5)
                    opt.step()
                else:
                    ret = net(com)
                    loss = cri(ret, org)
                    loss.backward()
                    nn.utils.clip_grad_norm_(net.parameters(), 5)
                    opt.step()

                running_loss += loss.item()
                if i % 100 == 0:
                    print('[Running epoch %2d, batch %4d] loss: %.3f' %
                          (epoch + 1, i + 1, \
                           10000 * running_loss / (e * tl + i + 1),
                           ), end='\n')
                else:
                    print('[Running epoch %2d, batch %4d] loss: %.3f' %
                          (epoch + 1, i + 1, \
                           10000 * running_loss / (e * tl + i + 1),
                           ), end='\r')

        if not (epoch + 1) % 1:
            timestamp = time.time()
            print('[timestamp %d, epoch %2d] loss: %.3f, time: %6ds        ' %
                  (timestamp, epoch + 1, 10000 * running_loss /
                   ((i + 1) * rep), timestamp - start_time),
                  end='\n')
            with torch.no_grad():
                p_psnr = utils.evalPsnr(net, val_set, fout=fout)

            save_model(net, opt,
                       os.path.join(cp_dir,
                                    str(epoch + 1) + '_withopt'))
            torch.save(net.state_dict(), os.path.join(cp_dir, str(epoch + 1)))
            sch.step()
            print('cur_lr: %.5f' % sch.get_lr()[0])
Example no. 33
def trainNet(net,
             opt,
             cri,
             sch,
             cp_dir,
             epoch_range,
             training_set,
             val_set,
             batch_size_big,
             rep=1):
    tb_dir = os.path.join(cp_dir, 'tb')
    utils.ensure_exists(tb_dir)
    fout = open(os.path.join(cp_dir, 'train.log'), 'a')

    for epoch in epoch_range:
        net.train()
        running_loss = 0
        start_time = time.time()
        tl = len(training_set)

        for e in range(rep):
            for i, data in enumerate(training_set, 0):

                com = data['com'].float().cuda()
                c_2 = data['com_2'].float().cuda()
                c_4 = data['com_4'].float().cuda()
                org = data['org'].float().cuda()
                o_2 = data['org_2'].float().cuda()
                o_4 = data['org_4'].float().cuda()

                com_pair = (c_4, c_2, com)
                org_pair = (o_4, o_2, org)

                opt.zero_grad()
                ret = net(com_pair)
                loss, MSE4, MSE2, MSEp, MSEd = cri(ret, org_pair)

                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), 10)
                opt.step()

                running_loss += loss.item()
                if i % 100 == 0:
                    print('[Running epoch %2d, batch %4d] loss: %.3f' %
                          (epoch + 1, i + 1, \
                           10000 * running_loss / (e * tl + i + 1),
                           ), end='\n')
                else:
                    print('[Running epoch %2d, batch %4d] loss: %.3f' %
                          (epoch + 1, i + 1, \
                           10000 * running_loss / (e * tl + i + 1),
                           ), end='\r')

        if not (epoch + 1) % 1:
            timestamp = time.time()
            print('[timestamp %d, epoch %2d] loss: %.3f, time: %6ds        ' %
                  (timestamp, epoch + 1, 10000 * running_loss /
                   ((i + 1) * rep), timestamp - start_time),
                  end='\n')
            with torch.no_grad():
                p_psnr = utils.evalPsnr(net, val_set, fout=fout)

            save_model(net, opt,
                       os.path.join(cp_dir,
                                    str(epoch + 1) + '_withopt'))
            torch.save(net.state_dict(), os.path.join(cp_dir, str(epoch + 1)))
            sch.step()
            print('cur_lr: %.5f' % sch.get_lr()[0])
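
Examples no. 32 and 33 also rely on a save_model(net, opt, path) helper that is not shown. A hypothetical sketch, assuming it simply bundles the network and optimizer state dicts with torch.save so the '<epoch>_withopt' checkpoints written above can be used to resume training:

import torch

def save_model(net, opt, path):
    # Hypothetical helper: persist both the network weights and the optimizer
    # state in one file; the caller already saves the bare state_dict separately.
    torch.save({'net': net.state_dict(), 'opt': opt.state_dict()}, path)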