Example 1
def train(args):

    configure(args['log_dir'])

    dial_data = get_dataloader(
        os.path.join(args['data_dir'], 'encoded_train_dialogue_pair.json'),
        os.path.join(args['data_dir'], 'vocabulary.json'), args['batch_size'])
    vocab = Vocabulary()
    vocab.load_vocab(os.path.join(args['data_dir'], 'vocabulary.json'))
    args['voca_size'] = len(vocab.word2idx)

    # Move the model to the GPU when one is available
    model = Seq2Seq(args)
    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

    criterion = nn.NLLLoss(ignore_index=vocab.get_idx('PADED'))

    min_valid_loss = float('inf')

    for epoch in range(args['epoches']):
        for batch_idx, (sour, sour_len, targ,
                        targ_len) in enumerate(dial_data):
            if torch.cuda.is_available():
                sour = sour.cuda()
                targ = targ.cuda()
            loss = train_batch(model, optimizer, criterion,
                               (sour, sour_len, targ, targ_len))

            logger.info('training loss:{}'.format(loss))
            log_value('CrossEntropy loss', loss,
                      epoch * len(dial_data) + batch_idx)

            if (batch_idx + epoch * len(dial_data)) % args['valid_step'] == 0:
                valid_loader = get_dataloader(
                    os.path.join(args['data_dir'],
                                 'encoded_valid_dialogue_pair.json'),
                    os.path.join(args['data_dir'], 'vocabulary.json'),
                    args['batch_size'])
                valid_loss = validate(model, valid_loader, criterion)

                log_value(
                    'valid loss', valid_loss,
                    int((batch_idx + epoch * len(dial_data)) /
                        args['valid_step']))
                logger.info('valid_step:{} valid_loss:{}'.format(
                    int((batch_idx + epoch * len(dial_data)) /
                        args['valid_step']), valid_loss))

                checkpoint = Checkpoint(model, optimizer, epoch, batch_idx)
                checkpoint.save(args['exp_dir'])
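
A minimal sketch of how the train() function above might be driven. The dictionary keys mirror the ones the function reads; the concrete paths and hyper-parameter values are placeholders, not taken from the original project.

# Hypothetical driver for train(); every path and value below is an assumption.
if __name__ == '__main__':
    args = {
        'data_dir': 'data',      # must contain encoded_*_dialogue_pair.json and vocabulary.json
        'log_dir': 'logs',
        'exp_dir': 'experiments',
        'batch_size': 32,
        'lr': 1e-3,
        'epoches': 10,           # key spelled exactly as train() expects
        'valid_step': 500,
    }
    train(args)
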
Example 2
def main(_):
    checkpoint = Checkpoint(FLAGS.checkpoint_dir)
    utils.exists_or_mkdir(FLAGS.sample_dir)
    utils.exists_or_mkdir(FLAGS.log_dir)
    # torch.utils.tensorboard.SummaryWriter(log_dir=FLAGS.log_dir) would work here as well
    summaryWriter = tensorboardX.SummaryWriter(log_dir=FLAGS.log_dir)

    logger.info('[Params] lr:%f, size:%d, dataset:%s, av_gen:%d, n_disc:%d'%
                (FLAGS.learning_rate, FLAGS.output_size, FLAGS.dataset, int(FLAGS.use_averaged_gen), FLAGS.n_discriminator))

    #dataset
    z_shape = (FLAGS.z_dim,)
    image_size = (FLAGS.output_size, FLAGS.output_size)
    image_shape = (3,) + image_size

    ds = dataset.datasets.from_name(name=FLAGS.dataset, data_folder=FLAGS.data_folder,
                                    output_size=image_size)

    batch = batch_gen.BatchWithNoise(ds, batch_size=FLAGS.batch_size, z_shape=z_shape,num_workers=10)

    #initialize device
    device = utils.get_torch_device()

    #model
    nn_model = models.model_factory.create_model(FLAGS.model_name,
                                                 device=device,
                                                 image_shape=image_shape,
                                                 z_shape=z_shape,
                                                 use_av_gen=FLAGS.use_averaged_gen,
                                                 g_tanh=False)
    nn_model.register_checkpoint(checkpoint)

    loss = gan_loss.js_loss()
    #lambd = lambda_scheduler.Constant(0.1)
    lambd = lambda_scheduler.ThresholdAnnealing(1000., threshold=loss.lambda_switch_level, min_switch_step=FLAGS.lambda_switch_steps, verbose=True)
    checkpoint.register('lambda', lambd, True)

    trainer = Trainer(model=nn_model, batch=batch, loss=loss, lr=FLAGS.learning_rate,
                      reg='gp', lambd=lambd)
    trainer.sub_batches = FLAGS.batch_per_update
    trainer.register_checkpoint(checkpoint)

    it_start = checkpoint.load(FLAGS.checkpoint_it_to_load)

    trainer.update_lr()

    ##========================= LOAD CONTEXT ================================##
    context_path = os.path.join(FLAGS.checkpoint_dir, 'context.npz')
    
    sample_seed = None
    if os.path.exists(context_path):
        sample_seed = np.load(context_path)['z']
        if sample_seed.shape[0] != FLAGS.sample_size or sample_seed.shape[1] != FLAGS.z_dim:
            sample_seed = None
            logger.info('Invalid sample seed')
        else:
            logger.info('Sample seed loaded')
    
    if sample_seed is None:
        sample_seed = batch.sample_z(FLAGS.sample_size).data.numpy()
        np.savez(context_path, z = sample_seed)

    ##========================= TRAIN MODELS ================================##
    batches_per_epoch = 10000
    total_time = 0

    bLambdaSwitched = (it_start == 0)
    n_too_good_d = []

    number_of_iterations = FLAGS.epoch*batches_per_epoch
    for it in range(number_of_iterations):
        start_time = time.time()
        iter_counter = it + it_start

        # updates the discriminator
        #if iter_counter < 25 or iter_counter % 500 == 0:
        #    d_iter = 20
        #else:
        #    d_iter = 5
        if bLambdaSwitched:
            #if lambda was switched we want to keep discriminator optimal
            logger.info('[!] Warming up discriminator')
            d_iter = 25
        else:
            d_iter = FLAGS.n_discriminator
        errD, s, errG, b_too_good_D = trainer.update(d_iter, 1)

        summaryWriter.add_scalar('d_loss', errD, iter_counter)
        summaryWriter.add_scalar('slope', s, iter_counter)
        summaryWriter.add_scalar('g_loss', errG, iter_counter)
        summaryWriter.add_scalar('loss', errD + float(lambd) * s**2, iter_counter)
        summaryWriter.add_scalar('lambda', float(lambd), iter_counter)

        #updating lambda
        n_too_good_d.append(b_too_good_D)
        if len(n_too_good_d) > 20:
            del n_too_good_d[0]               
                
        bLambdaSwitched = lambd.update(errD)
        if not bLambdaSwitched and sum(n_too_good_d) > 10:
            bLambdaSwitched = lambd.switch()

        end_time = time.time()

        iter_time = end_time - start_time
        total_time += iter_time

        logger.info("[%2d/%2d] time: %4.4f, d_loss: %.8f, s: %.4f, g_loss: %.8f" % (iter_counter, it_start + number_of_iterations, iter_time, errD, s, errG))

        if np.mod(iter_counter, FLAGS.sample_step) == 0 and it > 0:
            n = int(np.sqrt(FLAGS.sample_size))

            img = trainer.sample(sample_seed)
            img = img.data.cpu()

            img_tb = utils.image_to_tensorboard(torchvision.utils.make_grid(img, n))
            summaryWriter.add_image('samples', img_tb, iter_counter)

            utils.save_images(img.data.cpu().numpy(), [n, n], './{}/train_{:02d}.png'.format(FLAGS.sample_dir, iter_counter))

        if np.mod(iter_counter, FLAGS.save_step) == 0 and it > 0:
            checkpoint.save(iter_counter)

    checkpoint.save(iter_counter)
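
Example 2 relies on a Checkpoint object that exposes register(name, obj), save(iteration) and load(iteration). The stand-in below only sketches that assumed interface so the calls above are easier to follow; it is not the project's actual Checkpoint class.

# Minimal stand-in for the assumed Checkpoint interface used in Example 2.
# Registered objects are assumed to expose state_dict()/load_state_dict().
import os
import torch

class SimpleCheckpoint:
    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        self.registered = {}

    def register(self, name, obj, persistent=True):
        # Remember the object so its state is included in future saves
        if persistent:
            self.registered[name] = obj

    def save(self, iteration):
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        state = {name: obj.state_dict() for name, obj in self.registered.items()}
        path = os.path.join(self.checkpoint_dir, 'ckpt_{:08d}.pt'.format(iteration))
        torch.save({'iteration': iteration, 'state': state}, path)

    def load(self, iteration):
        path = os.path.join(self.checkpoint_dir, 'ckpt_{:08d}.pt'.format(iteration))
        if not os.path.exists(path):
            return 0  # nothing to restore; training starts from iteration 0
        payload = torch.load(path)
        for name, obj in self.registered.items():
            if name in payload['state']:
                obj.load_state_dict(payload['state'][name])
        return payload['iteration']
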
Example 3
class Processor(object):
    """Class for processing dump files from postgresql."""
    MILLION = 1024 * 1024  # bytes per "MB" as reported in the logs (strictly a mebibyte)

    def __init__(self):
        self.bytes_count = 0
        self.start_time = 0.0
        self.out_files = {}
        self.checkpoint = Checkpoint(config.VALUE_SET)
        self.init_time()

    def init_time(self):
        """Init time."""

        self.start_time = time.time()

    def add_bytes_count(self, count: int):
        """Add up bytes count."""

        self.bytes_count += count

    def split_if_necessary(self) -> None:
        """Check size of each storage file, called each batch
        close and open a new one to store if size exceeds max_split_size
        """

        # config.FILE_SPLIT_SIZE is in bytes; it is converted to MB only for the log message below
        for v in config.VALUE_SET:
            file_size = self.out_files[v].tell()
            if file_size >= config.FILE_SPLIT_SIZE:
                self.checkpoint.update_file_index(v)
                new_file = open(
                    self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
                self.add_table_head(new_file)
                self.out_files[v].close()
                self.out_files[v] = new_file
                logging.info('File size grows over {:.2f} MB, '
                             'store in new file `{}`...'.format(
                                 config.FILE_SPLIT_SIZE / self.MILLION,
                                 new_file.name))

    def process_line(self, line: str) -> None:
        """Process each line, does NOT verify the validness of
         lines (print them and ignores invalid ones without terminating)
         check if this line is recorded, and record the line.

        :param line: str, line to process ('\n' not included)
        """

        attributes = line.split('\t')
        try:
            # Check value in values to group by
            value = attributes[config.GROUP_BY_ATTR_INDEX]
            if value not in config.VALUE_SET:
                return
            row_count = int(attributes[config.INDEX_ROW_COUNT])
            # Check if line is already parsed and recorded
            if row_count <= self.checkpoint.row_count[value]:
                return
            # Keep attributes we're interested in
            data = [attributes[i] for i in config.RECORD_ATTR_INDEX_LIST]
            # Write to related file
            self.out_files[value].write('\t'.join(data))
            self.out_files[value].write('\n')
            # Update index
            self.checkpoint.row_count[value] = row_count
        except Exception as e:
            logging.warning(e)
            logging.warning("Invalid row: {}".format(attributes))

    @staticmethod
    def verify_file_schema(fp: TextIO) -> bool:
        """Verify the schema of data contained in a file.
        The dump files of postgresql should contain exactly one table each.
        """

        line = fp.readline()
        # Remember to return head of file
        fp.seek(0)
        if isinstance(line, bytes):
            line = str(line, encoding='utf-8')
        # Remove empty cells
        attributes = list(filter(None, line.split('\t')))
        # Check attribute count
        if len(attributes) != config.ATTR_COUNT:
            return False
        # Check validness of index attribute
        try:
            _ = int(attributes[config.INDEX_ROW_COUNT])
        except ValueError:
            return False
        return True

    @staticmethod
    def add_table_head(f: TextIO) -> None:
        """Add headings of table."""

        f.write('\t'.join(config.RECORD_ATTR_LIST))
        f.write('\n')

    def process_file(self, filename: str, is_old_file: bool = False) -> None:
        """Process a text file (ends with '.dat') or gzip file (ends with .gz).

        :param filename: str, name of file to process
        :param is_old_file: bool, whether this file has been processed before;
                if it has, batches that were already read are skipped.
        """

        # Check file type
        file_type = filename[filename.rfind('.'):]
        if file_type not in config.OPEN_FUNCS:
            logging.info('Fail to process `{}`: unsupported file type.'.format(
                filename))
            return
        # Open file according to its type
        fp = config.OPEN_FUNCS[file_type](filename)

        # Old file: needs to recover to the starting point
        if is_old_file and self.checkpoint.offset > 0:
            fp.seek(self.checkpoint.offset)
            logging.info('Time for seeking file offset: {:.2f} s'.format(
                time.time() - self.start_time))
            # This should be the start of processing
            self.init_time()
        else:
            # New files:
            # needs to verify whether this file contains the table we want
            if not self.verify_file_schema(fp):
                logging.info(
                    'Schema of `{}` doesn\'t fit; skip.'.format(filename))
                fp.close()
                return
            # Record current file
            self.checkpoint.current_file = filename

        logging.info('Start processing `{}`...'.format(filename))
        while True:
            self.checkpoint.offset = fp.tell()
            batch = fp.read(config.BATCH_SIZE)
            # Read to the end of the current line so a batch never stops mid-line
            line = fp.readline()
            if line:
                batch += line
            if not batch:
                break
            # Convert from bytes to str if needed
            if isinstance(batch, bytes):
                batch = str(batch, 'utf-8')
            # Parse batch
            for line in batch.splitlines():
                self.process_line(line)
            self.add_bytes_count(len(batch))
            # Split large files and change storage to new files
            if config.SPLIT:
                self.split_if_necessary()
        fp.close()

    def process_dir(self, dirname: str) -> None:
        """Recursively process files in given directory.

        :param dirname: str, directory of files to process
        """

        file_list = sorted(os.listdir(dirname))
        for name in file_list:
            # Full name of file
            name = os.path.join(dirname, name)
            # Check if this file is already processed
            if name in self.checkpoint.processed_files:
                continue
            if os.path.isfile(name):
                self.process_file(name)
                self.checkpoint.processed_files.add(name)
            elif os.path.isdir(name) and config.RECURSIVE:
                self.process_dir(name)

    def before_process(self) -> None:
        """Create directory if needed, and load records."""
        if not os.path.isdir(config.OUT_DIR):
            os.mkdir(config.OUT_DIR)
        # Load checkpoints from file
        if os.path.exists(config.RECORD_FILE):
            self.checkpoint.load(config.RECORD_FILE)
            logging.info('Checkpoint loaded from `{}`.'.format(
                config.RECORD_FILE))
        # Open files to write
        for v in config.VALUE_SET:
            f = open(self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
            # If it's a new file, add headings
            if f.tell() == 0:
                self.add_table_head(f)
            self.out_files[v] = f

    def process(self, dir_list: list) -> None:
        """Process list of directories / files"""
        try:
            # Prepare for processing
            self.before_process()
            # Recover from file processed last time
            if os.path.exists(self.checkpoint.current_file):
                logging.info('Reloading `{}` from last checkpoints...'.format(
                    self.checkpoint.current_file))
                self.process_file(self.checkpoint.current_file,
                                  is_old_file=True)
            if len(dir_list) == 0:
                logging.error(
                    'Please specify at least one directory or file to process.'
                )
            # Process each directory / file
            for dir_name in dir_list:
                if os.path.isdir(dir_name):
                    self.process_dir(dir_name)
                elif os.path.isfile(dir_name):
                    self.process_file(dir_name)
                else:
                    logging.warning(
                        '`{}` is not a directory / file; skip.'.format(
                            dir_name))
        # Manually stopped with Ctrl + C
        except KeyboardInterrupt:
            self.after_process(is_interrupted=True)
        # Other unknown exceptions...
        except Exception as e:
            logging.warning(e)
            self.after_process(is_interrupted=True)
        else:
            self.after_process(is_interrupted=False)

    def after_process(self, is_interrupted: bool) -> None:
        """Deal with opened files, useless files and save records."""
        # Close files, and remove files with zero contents
        head_len = len('\t'.join(config.RECORD_ATTR_LIST)) + 1
        for file in self.out_files.values():
            file.close()
            # Allow a small margin above the header size rather than comparing sizes exactly
            if os.path.getsize(file.name) <= head_len + 100:
                os.remove(file.name)
        # Handle interrupts
        if is_interrupted:
            self.checkpoint.save(config.RECORD_FILE)
            logging.info('Checkpoint saved in `{}`.'.format(
                config.RECORD_FILE))
        # Normal ending, remove record file
        elif os.path.exists(config.RECORD_FILE):
            os.remove(config.RECORD_FILE)
        # Analyse speed
        total_mb = self.bytes_count / self.MILLION
        total_time = time.time() - self.start_time
        avg_speed = total_mb / total_time
        logging.info(
            'Processed {:.2f} MB in {:.2f} s, {:.2f} MB/s on average.'.format(
                total_mb, total_time, avg_speed))
        exit(int(is_interrupted))
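
A possible entry point for the Processor class above; only Processor.process(dir_list) comes from the example itself, the command-line handling is an assumption.

# Hypothetical entry point; argument handling is an assumption.
import sys

if __name__ == '__main__':
    # e.g. python process_dumps.py /data/pg_dumps /data/extra_dump.dat.gz
    Processor().process(sys.argv[1:])
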
Example 4
def main(env_name, num_episodes, gamma, lam, kl_targ, batch_size, nprocs,
         policy_hid_list, valfunc_hid_list, gpu_pct, restore_path, animate,
         submit):
    """ Main training loop

    Args:
        env_name: OpenAI Gym environment name, e.g. 'Hopper-v1'
        num_episodes: maximum number of episodes to run
        gamma: reward discount factor (float)
        lam: lambda from Generalized Advantage Estimate
        kl_targ: D_KL target for policy update [D_KL(pi_old || pi_new)]
        batch_size: number of episodes per policy training batch
    """
    # killer = GracefulKiller()

    env, obs_dim, act_dim = init_osim(animate)
    env.seed(111 + mpi_util.rank)
    mpi_util.set_global_seeds(111 + mpi_util.rank)

    obs_dim += 1  # add 1 to obs dimension for time step feature (see run_episode())
    now = datetime.utcnow().strftime(
        "%b-%d_%H:%M:%S")  # create unique directories
    if mpi_util.rank == 0:
        #aigym_path = os.path.join('/tmp', env_name, now)
        #env = wrappers.Monitor(env, aigym_path, force=True)
        logger = Logger(logname=env_name, now=now)

    episode = 0

    checkpoint = Checkpoint("saves", now)
    # restore from checkpoint?
    if restore_path:
        (policy, val_func, scaler, episode, obs_dim, act_dim,
         kl_targ) = checkpoint.restore(restore_path)
    else:
        policy = Policy(obs_dim, act_dim, kl_targ)
        val_func = NNValueFunction(obs_dim)
        scaler = Scaler(obs_dim)

        if mpi_util.rank == 0:
            # run a few episodes (on node 0) of untrained policy to initialize scaler:
            trajectories = run_policy(env, policy, scaler, episodes=5)

            unscaled = np.concatenate(
                [t['unscaled_obs'] for t in trajectories])
            scaler.update(
                unscaled)  # update running statistics for scaling observations

        # broadcast policy weights, scaler, val_func
        (policy, scaler, val_func) = mpi_util.broadcast_policy_scaler_val(
            policy, scaler, val_func)

        if mpi_util.rank == 0:
            checkpoint.save(policy, val_func, scaler, episode)

    if animate:
        observes, actions, rewards, unscaled_obs = run_episode(env,
                                                               policy,
                                                               scaler,
                                                               animate=animate)
        exit(0)

    if submit:
        # Settings
        #remote_base = 'http://grader.crowdai.org:1729'
        remote_base = 'http://grader.crowdai.org:1730'
        token = 'a83412a94593cae3a491f3ee28ff44e1'

        client = Client(remote_base)

        # Create environment
        observation = client.env_create(token)
        step = 0.0
        observes, actions, rewards, unscaled_obs = [], [], [], []
        scale, offset = scaler.get()
        scale[-1] = 1.0  # don't scale time step feature
        offset[-1] = 0.0  # don't offset time step feature

        # Run a single step
        #
        # The grader runs 3 simulations of at most 1000 steps each. We stop after the last one
        while True:
            obs = np.array(observation).astype(np.float32).reshape((1, -1))
            print("OBSERVATION TYPE:", type(obs), obs.shape)
            print(obs)
            obs = np.append(obs, [[step]], axis=1)  # add time step feature
            unscaled_obs.append(obs)
            obs = (obs - offset) * scale  # center and scale observations
            observes.append(obs)

            action = policy.sample(obs).astype(np.float32).reshape((-1, 1))
            print("ACTION TYPE:", type(action), action.shape)
            print(action)
            actions.append(action)

            [observation, reward, done,
             info] = client.env_step(action.tolist())
            print("step:", step, "reward:", reward)

            if not isinstance(reward, float):
                reward = np.asscalar(reward)
            rewards.append(reward)
            step += 1e-3  # increment time step feature

            if done:
                print(
                    "================================== RESTARTING ================================="
                )
                observation = client.env_reset()
                step = 0.0
                observes, actions, rewards, unscaled_obs = [], [], [], []
                scale, offset = scaler.get()
                scale[-1] = 1.0  # don't scale time step feature
                offset[-1] = 0.0  # don't offset time step feature
                if not observation:
                    break

        client.submit()
        exit(0)

    ######

    worker_batch_size = int(batch_size / mpi_util.nworkers)  # HACK
    if (worker_batch_size * mpi_util.nworkers != batch_size):
        print("batch_size:", batch_size, " is not divisible by nworkers:",
              mpi_util.nworkers)
        exit(1)

    batch = 0
    while episode < num_episodes:
        if mpi_util.rank == 0 and batch > 0 and batch % 10 == 0:
            checkpoint.save(policy, val_func, scaler, episode)
        batch = batch + 1

        trajectories = run_policy(env,
                                  policy,
                                  scaler,
                                  episodes=worker_batch_size)
        trajectories = mpi_util.gather_trajectories(trajectories)

        if mpi_util.rank == 0:
            # concatenate trajectories into one list
            trajectories = list(itertools.chain.from_iterable(trajectories))
            print("did a batch of ", len(trajectories), " trajectories")
            print([t['rewards'].sum() for t in trajectories])

            episode += len(trajectories)
            add_value(trajectories,
                      val_func)  # add estimated values to episodes
            add_disc_sum_rew(trajectories,
                             gamma)  # calculated discounted sum of Rs
            add_gae(trajectories, gamma, lam)  # calculate advantage

            # concatenate all episodes into single NumPy arrays
            observes, actions, advantages, disc_sum_rew = build_train_set(
                trajectories)

            # add various stats to training log:
            logger.log({
                '_MeanReward':
                np.mean([t['rewards'].sum() for t in trajectories]),
                'Steps':
                np.sum([t['observes'].shape[0] for t in trajectories])
            })
            log_batch_stats(observes, actions, advantages, disc_sum_rew,
                            logger, episode)

            policy.update(observes, actions, advantages,
                          logger)  # update policy
            val_func.fit(observes, disc_sum_rew,
                         logger)  # update value function

            unscaled = np.concatenate(
                [t['unscaled_obs'] for t in trajectories])
            scaler.update(
                unscaled)  # update running statistics for scaling observations

            logger.write(
                display=True)  # write logger results to file and stdout

        # if mpi_util.rank == 0 and killer.kill_now:
        #     if input('Terminate training (y/[n])? ') == 'y':
        #         break
        #     killer.kill_now = False

        # broadcast policy weights, scaler, val_func
        (policy, scaler, val_func) = mpi_util.broadcast_policy_scaler_val(
            policy, scaler, val_func)

    if mpi_util.rank == 0: logger.close()
    policy.close_sess()
    if mpi_util.rank == 0: val_func.close_sess()
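
Example 4 takes all of its hyper-parameters as plain function arguments. A command-line wrapper along the following lines could feed them in; the flag names and default values are placeholders, not the project's actual CLI.

# Hypothetical CLI wrapper for main(); flag names and defaults are placeholders.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training loop from Example 4')
    parser.add_argument('--env_name', default='Hopper-v1')
    parser.add_argument('--num_episodes', type=int, default=1000)
    parser.add_argument('--gamma', type=float, default=0.995)
    parser.add_argument('--lam', type=float, default=0.98)
    parser.add_argument('--kl_targ', type=float, default=0.003)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--nprocs', type=int, default=1)
    parser.add_argument('--restore_path', default=None)
    parser.add_argument('--animate', action='store_true')
    parser.add_argument('--submit', action='store_true')
    cli = parser.parse_args()
    main(cli.env_name, cli.num_episodes, cli.gamma, cli.lam, cli.kl_targ,
         cli.batch_size, cli.nprocs, policy_hid_list=None, valfunc_hid_list=None,
         gpu_pct=1.0, restore_path=cli.restore_path, animate=cli.animate,
         submit=cli.submit)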