Esempio n. 1
0
    def train(self,
              model,
              data,
              num_epochs=5,
              resume=False,
              dev_data=None,
              optimizer=None,
              teacher_forcing_ratio=0):
        """Train ``model`` on ``data`` and hand the trained model back.

        Args:
            model (seq2seq.models): model to train; when ``resume`` is true it
                is replaced by the model stored in the latest checkpoint.
            data (seq2seq.dataset.dataset.Dataset): training dataset.
            num_epochs (int, optional): number of epochs to run (default 5).
            resume (bool, optional): continue from the latest checkpoint found
                under ``self.expt_dir`` (default False).
            dev_data (seq2seq.dataset.dataset.Dataset, optional): development
                dataset forwarded to ``_train_epochs`` (default None).
            optimizer (seq2seq.optim.Optimizer, optional): optimizer used for a
                fresh run; defaults to ``Optimizer(optim.Adam(...),
                max_grad_norm=5)``. Not consulted when resuming.
            teacher_forcing_ratio (float, optional): teacher forcing ratio
                forwarded to ``_train_epochs`` (default 0).

        Returns:
            model (seq2seq.models): the model that was trained.
        """
        if not resume:
            # Fresh run: start counting from the first epoch.
            first_epoch, first_step = 1, 0
            if optimizer is None:
                optimizer = Optimizer(optim.Adam(model.parameters()),
                                      max_grad_norm=5)
            self.optimizer = optimizer
        else:
            checkpoint = Checkpoint.load(
                Checkpoint.get_latest_checkpoint(self.expt_dir))
            model = checkpoint.model
            self.optimizer = checkpoint.optimizer

            # Workaround: rebuild the wrapped optimizer so that it tracks the
            # parameters of the freshly loaded model, reusing the settings of
            # the first saved parameter group.
            saved_optim = self.optimizer.optimizer
            settings = saved_optim.param_groups[0]
            for key in ('params', 'initial_lr'):
                settings.pop(key, None)
            self.optimizer.optimizer = saved_optim.__class__(
                model.parameters(), **settings)

            first_epoch = checkpoint.epoch
            first_step = checkpoint.step

        summary = "Optimizer: %s, Scheduler: %s" % (
            self.optimizer.optimizer, self.optimizer.scheduler)
        self.logger.info(summary)

        self._train_epochs(data, model, num_epochs, first_epoch, first_step,
                           dev_data=dev_data,
                           teacher_forcing_ratio=teacher_forcing_ratio)
        return model
Esempio n. 2
0
def main():
    """Train or evaluate the capsule network according to ``opt``.

    When ``opt.is_train`` is set, the model is trained for epochs
    ``start_epoch .. opt.n_epochs - 1`` and a checkpoint is written to
    ``opt.save_folder`` after every epoch. With ``opt.resume`` the model and
    optimizer come from the latest checkpoint and training continues with the
    epoch following the saved one. Otherwise ``run_test`` is called.
    """
    train_loader, test_loader = get_mnist_data('../%s' % opt.dataset,
                                               opt.batch_size)
    model = CapsuleNetwork(opt)
    if opt.cuda:
        model = model.cuda()

    if not opt.is_train:
        run_test(model, test_loader)
        return

    if opt.resume:
        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
            opt.save_folder)
        resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
        model = resume_checkpoint.model
        optimizer = resume_checkpoint.optimizer
        # The checkpoint stores the last finished epoch; continue after it.
        start_epoch = resume_checkpoint.epoch + 1
    else:
        start_epoch = 0
        optimizer = Adam(model.parameters())

    for epoch in range(start_epoch, opt.n_epochs):
        train(epoch, model, train_loader, test_loader, optimizer)
        Checkpoint(model=model, optimizer=optimizer,
                   epoch=epoch).save(opt.save_folder)
Esempio n. 3
0
def load_model():
    """Load a checkpoint and return ``(model, input_vocab, output_vocab)``.

    The checkpoint named by ``FLAGS.load_checkpoint`` is used when that flag
    is set; otherwise the latest checkpoint under ``FLAGS.expt_dir`` is
    loaded.
    """
    requested = FLAGS.load_checkpoint
    if requested is None:
        path = Checkpoint.get_latest_checkpoint(FLAGS.expt_dir)
    else:
        path = os.path.join(FLAGS.expt_dir,
                            Checkpoint.CHECKPOINT_DIR_NAME,
                            requested)
    logging.info("loading checkpoint from {}".format(path))
    loaded = Checkpoint.load(path)
    # The vocab objects are vocab classes with members stoi and itos.
    return (loaded.model, loaded.input_vocab, loaded.output_vocab)
Esempio n. 4
0
    def train(self,encoder, decoder, n_epochs, train_data, dev_data,
                resume, optimizer, log_file):
        """
        ------------------------------------------------------------------------
        Train the encoder/decoder pair and return them as a tuple.

        Args:
            encoder:                  Encoder to train (replaced by the
                                      checkpointed one when resuming).
            decoder:                  Decoder to train (replaced by the
                                      checkpointed one when resuming).
            n_epochs (int):           Number of epochs to train the model.
            train_data (Composition): Data passed to ``train_epochs`` for
                                      training.
            dev_data (Composition):   Data passed to ``train_epochs`` as the
                                      development set.
            resume (bool):            If true, load last checkpoint.
            optimizer:                Optimizer for a fresh run; when None an
                                      ``Adam`` optimizer with ``lr=1e-3`` over
                                      both modules is created. Not used when
                                      resuming.
            log_file:                 Forwarded unchanged to ``train_epochs``.
        ------------------------------------------------------------------------
        """
        if not resume:
            self.optimizer = optimizer
            if self.optimizer is None:
                encoder_params = list(encoder.parameters())
                decoder_params = list(decoder.parameters())
                self.optimizer = Adam(encoder_params + decoder_params, lr=1e-3)
            self.scheduler = LambdaLR(self.optimizer,decay)
            first_epoch, first_step = 1, 0
        else:
            saved = Checkpoint.load(
                Checkpoint.get_latest_checkpoint(self.exp_dir))
            # Restore the modules and every piece of training state that the
            # checkpoint carries.
            encoder, decoder = saved.encoder, saved.decoder
            first_epoch, first_step = saved.epoch, saved.step
            self.scheduler = saved.scheduler
            self.optimizer = saved.optimizer
            self.samp_rate = saved.samp_rate
            self.KL_rate   = saved.KL_rate
            self.free_bits = saved.free_bits
        self.vocab_size = decoder.vocab_size

        last_epoch = first_epoch + n_epochs
        self.train_epochs(encoder, decoder, first_epoch, first_step,
                          train_data, dev_data, last_epoch, log_file)
        return encoder,decoder
Esempio n. 5
0
def run_test(model, test_loader):
    """Evaluate the latest checkpointed model on ``test_loader``.

    The ``model`` argument is replaced by the model stored in the most recent
    checkpoint under ``opt.save_folder``. For every sample the predicted and
    target labels are printed, followed by the overall error rate. When
    ``opt.vis`` is set, one random input of each batch is shown next to its
    reconstruction in an OpenCV window.

    Args:
        model: placeholder model; overwritten by the checkpointed model.
        test_loader: iterable yielding ``(data, target)`` batches.
    """
    latest_checkpoint_path = Checkpoint.get_latest_checkpoint(opt.save_folder)
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model

    model.eval()
    num_error = 0
    num_data = 0
    for batch_id, (data, target) in enumerate(test_loader):
        data = Variable(data)
        if opt.cuda:
            data = data.cuda()

        output, _, recon = model(data)
        # Vector magnitude along dim 2, turned into a distribution over dim 1;
        # the arg-max is taken as the predicted label.
        out_mag = torch.sqrt((output**2).sum(2))
        out_mag = F.softmax(out_mag, dim=1)
        _, max_idx = out_mag.max(dim=1)

        for idx in range(data.size(0)):
            # Convert once and reuse for both the report and the comparison.
            estimated = max_idx[idx].data.cpu().numpy()
            # A single pre-formatted string prints identically under the
            # Python 2 print statement and the Python 3 print function.
            print("(batch_index, sample_index, estimated, target) :  %s %s %s %s"
                  % (batch_id, idx, estimated, target[idx]))
            if estimated != target[idx]:
                num_error = num_error + 1
            num_data = num_data + 1
        if opt.vis:
            idx = random.randint(0, data.size(0) - 1)
            show_recon = recon[idx].data.cpu().numpy().reshape(28, 28)
            show_data = data[idx].data.cpu().numpy().reshape(28, 28)

            cv2.namedWindow("recon", cv2.WINDOW_NORMAL)
            cv2.imshow("recon", np.concatenate((show_data, show_recon),
                                               axis=1))
            cv2.waitKey(1)

    if num_data == 0:
        # An empty loader previously crashed with ZeroDivisionError.
        print('test error :  n/a (no test samples)')
        return
    print('test error :  %s' % (float(num_error) / float(num_data)))
Esempio n. 6
0
def main(_):
    """Train the GAN configured by ``FLAGS`` and log progress to TensorBoard.

    Builds the dataset, noise batcher, model, loss and trainer, restores the
    state registered with the checkpoint, then runs
    ``FLAGS.epoch * 10000`` update steps. Samples are written every
    ``FLAGS.sample_step`` iterations, checkpoints every ``FLAGS.save_step``
    iterations and once more at the end.

    Args:
        _: unused positional argument supplied by the caller.
    """
    checkpoint = Checkpoint(FLAGS.checkpoint_dir)
    utils.exists_or_mkdir(FLAGS.sample_dir)
    utils.exists_or_mkdir(FLAGS.log_dir)
    summaryWriter = tensorboardX.SummaryWriter(log_dir = FLAGS.log_dir)

    logger.info('[Params] lr:%f, size:%d, dataset:%s, av_gen:%d, n_disc:%d'%
                (FLAGS.learning_rate, FLAGS.output_size, FLAGS.dataset, int(FLAGS.use_averaged_gen), FLAGS.n_discriminator))

    #dataset
    z_shape = (FLAGS.z_dim,)
    image_size = (FLAGS.output_size, FLAGS.output_size)
    image_shape = (3,) + image_size

    ds = dataset.datasets.from_name(name=FLAGS.dataset, data_folder=FLAGS.data_folder,
                                    output_size=image_size)

    batch = batch_gen.BatchWithNoise(ds, batch_size=FLAGS.batch_size, z_shape=z_shape,num_workers=10)

    #initialize device
    device = utils.get_torch_device()

    #model
    nn_model = models.model_factory.create_model(FLAGS.model_name,
                                                 device=device,
                                                 image_shape=image_shape,
                                                 z_shape=z_shape,
                                                 use_av_gen=FLAGS.use_averaged_gen,
                                                 g_tanh=False)
    nn_model.register_checkpoint(checkpoint)

    loss = gan_loss.js_loss()
    lambd = lambda_scheduler.ThresholdAnnealing(1000., threshold=loss.lambda_switch_level, min_switch_step=FLAGS.lambda_switch_steps, verbose=True)
    checkpoint.register('lambda', lambd, True)

    trainer = Trainer(model=nn_model, batch=batch, loss=loss, lr=FLAGS.learning_rate,
                      reg='gp', lambd=lambd)
    trainer.sub_batches = FLAGS.batch_per_update
    trainer.register_checkpoint(checkpoint)

    # Everything that should be restored must be registered before this call.
    it_start = checkpoint.load(FLAGS.checkpoint_it_to_load)

    trainer.update_lr()

    ##========================= LOAD CONTEXT ================================##
    context_path = os.path.join(FLAGS.checkpoint_dir, 'context.npz')

    # Reuse the stored sample seed so sample grids stay comparable between
    # runs; regenerate it when it is missing or has the wrong shape.
    sample_seed = None
    if os.path.exists(context_path):
        sample_seed = np.load(context_path)['z']
        if sample_seed.shape[0] != FLAGS.sample_size or sample_seed.shape[1] != FLAGS.z_dim:
            sample_seed = None
            logger.info('Invalid sample seed')
        else:
            logger.info('Sample seed loaded')

    if sample_seed is None:
        sample_seed = batch.sample_z(FLAGS.sample_size).data.numpy()
        np.savez(context_path, z = sample_seed)

    ##========================= TRAIN MODELS ================================##
    batches_per_epoch = 10000

    # A fresh run (no restored iterations) starts with a discriminator warm-up.
    bLambdaSwitched = (it_start == 0)
    # Sliding window holding the "discriminator too good" flag of the last
    # 20 updates.
    n_too_good_d = []

    number_of_iterations = FLAGS.epoch*batches_per_epoch
    # Defined up front so the final save also works when the loop body never
    # runs (number_of_iterations == 0).
    iter_counter = it_start
    try:
        for it in range(number_of_iterations):
            start_time = time.time()
            iter_counter = it + it_start

            # updates the discriminator
            if bLambdaSwitched:
                #if lambda was switched we want to keep discriminator optimal
                logger.info('[!] Warming up discriminator')
                d_iter = 25
            else:
                d_iter = FLAGS.n_discriminator

            errD, s, errG, b_too_good_D = trainer.update(d_iter, 1)

            summaryWriter.add_scalar('d_loss', errD, iter_counter)
            summaryWriter.add_scalar('slope', s, iter_counter)
            summaryWriter.add_scalar('g_loss', errG, iter_counter)
            summaryWriter.add_scalar('loss', errD + float(lambd) * s**2, iter_counter)
            summaryWriter.add_scalar('lambda', float(lambd), iter_counter)

            #updating lambda
            n_too_good_d.append(b_too_good_D)
            if len(n_too_good_d) > 20:
                del n_too_good_d[0]

            bLambdaSwitched = lambd.update(errD)
            # Force a switch when more than 10 flags in the window are set.
            if not bLambdaSwitched and sum(n_too_good_d) > 10:
                bLambdaSwitched = lambd.switch()

            end_time = time.time()

            iter_time = end_time - start_time

            logger.info("[%2d/%2d] time: %4.4f, d_loss: %.8f, s: %.4f, g_loss: %.8f" % (iter_counter, it_start + number_of_iterations, iter_time, errD, s, errG))

            if np.mod(iter_counter, FLAGS.sample_step) == 0 and it > 0:
                n = int(np.sqrt(FLAGS.sample_size))

                img = trainer.sample(sample_seed)
                img = img.data.cpu()

                img_tb = utils.image_to_tensorboard(torchvision.utils.make_grid(img, n))
                summaryWriter.add_image('samples',img_tb, iter_counter)

                # img is already a detached CPU tensor at this point.
                utils.save_images(img.numpy(), [n, n], './{}/train_{:02d}.png'.format(FLAGS.sample_dir, iter_counter))

            if np.mod(iter_counter, FLAGS.save_step) == 0 and it > 0:
                checkpoint.save(iter_counter)

        checkpoint.save(iter_counter)
    finally:
        # Flush and release the event file even when training is aborted.
        summaryWriter.close()
Esempio n. 7
0
def _print_file_list(header, files):
    """Print ``header`` followed by one indented line per entry of ``files``."""
    print(header)
    for f in files:
        print(f"    {f}")


def main():
    """Entry point: log in to Canvas, process every eligible course and
    print a summary of what changed.

    Courses are skipped when they are listed in ``config.IGNORED_CANVAS_ID``
    or started more than ``config.RETAIN_COURSE_DAYS`` days ago (0 disables
    the age check). The checkpoint is dumped after processing, including
    after a keyboard interrupt.
    """
    # exit() is injected by the site module for interactive use; sys.exit()
    # is the form meant for programs.
    import sys

    global checkpoint

    colorama.init()
    print("Thank you for using canvas_grab!")
    print(
        f"You are using version {VERSION}. If you have any questions, please file an issue at {Fore.BLUE}https://github.com/skyzh/canvas_grab/issues{Style.RESET_ALL}"
    )
    print(
        f"You may review {Fore.GREEN}README(_zh-hans).md{Style.RESET_ALL} and {Fore.GREEN}LICENSE{Style.RESET_ALL} shipped with this release"
    )
    config.load_config()
    if config.ENABLE_VIDEO:
        print(
            "Note: You've enabled video download. You should install the required tools yourself."
        )
        print(
            f"      This is an experimental functionality and takes up large amount of bandwidth. {Fore.RED}Use at your own risk.{Style.RESET_ALL}"
        )
    canvas = Canvas(config.API_URL, config.API_KEY)

    try:
        print(f'{Fore.BLUE}Logging in...{Style.RESET_ALL}')
        print(
            f"{Fore.GREEN}Logged in to {config.API_URL} as {canvas.get_current_user()}{Style.RESET_ALL}"
        )
    except canvasapi.exceptions.InvalidAccessToken:
        # Reset the style so the terminal does not stay red afterwards.
        print(
            f"{Fore.RED}Invalid access token, please check your config.API_KEY in config file{Style.RESET_ALL}"
        )
        if is_windows():
            # for windows double-click user
            input()
        sys.exit()

    # Only load() is expected to report a missing file, so the checkpoint
    # object exists even when nothing could be loaded.
    checkpoint = Checkpoint(config.CHECKPOINT_FILE)
    try:
        checkpoint.load()
    except FileNotFoundError:
        print(f"{Fore.RED}No checkpoint found{Style.RESET_ALL}")

    courses = [
        course for course in canvas.get_courses() if hasattr(course, "name")
    ]
    if config.WHITELIST_CANVAS_ID:
        print(f"{Fore.BLUE}Whilelist mode enabled{Style.RESET_ALL}")
        courses = [
            course for course in courses
            if course.id in config.WHITELIST_CANVAS_ID
        ]
    try:
        for course in courses:
            if course.start_at:
                # Compare timezone-aware datetimes; dropping the offset and
                # subtracting local "now" skews the age by the UTC offset.
                started = datetime.strptime(course.start_at,
                                            r'%Y-%m-%dT%H:%M:%S%z')
                delta = -(started - datetime.now(tz=started.tzinfo)).days
            else:
                delta = 0
            if course.id in config.IGNORED_CANVAS_ID:
                print(
                    f"{Fore.CYAN}Explicitly Ignored Course: {course.course_code}{Style.RESET_ALL}"
                )
            elif config.RETAIN_COURSE_DAYS != 0 and delta > config.RETAIN_COURSE_DAYS:
                print(
                    f"{Fore.CYAN}Outdated Course: {course.course_code}{Style.RESET_ALL}"
                )
            else:
                try:
                    process_course(course)
                except KeyboardInterrupt:
                    raise
                except canvasapi.exceptions.Unauthorized as e:
                    print(
                        f"{Fore.RED}An error occoured when processing this course (unauthorized): {e}{Style.RESET_ALL}"
                    )
                except canvasapi.exceptions.ResourceDoesNotExist as e:
                    print(
                        f"{Fore.RED}An error occoured when processing this course (resourse not exist): {e}{Style.RESET_ALL}"
                    )
        if config.SCAN_STALE_FILE:
            scan_stale_files(courses)
    except KeyboardInterrupt:
        print(
            f"{Fore.RED}Terminated due to keyboard interrupt.{Style.RESET_ALL}"
        )

    checkpoint.dump()

    if new_files_list:
        _print_file_list(
            f"{Fore.GREEN}{len(new_files_list)} new or updated files:{Style.RESET_ALL}",
            new_files_list)

    if updated_files_list:
        _print_file_list(
            f"{Fore.GREEN}{len(updated_files_list)} files have a more recent version on Canvas:{Style.RESET_ALL}",
            updated_files_list)

    if failure_file_list:
        _print_file_list(
            f"{Fore.YELLOW}{len(failure_file_list)} files are not downloaded:{Style.RESET_ALL}",
            failure_file_list)

    if not new_files_list and not updated_files_list:
        print("All files up to date")

    if config.ENABLE_VIDEO:
        print(
            f"{Fore.GREEN}{len(ffmpeg_commands)} videos resolved{Style.RESET_ALL}"
        )
        print(
            f"Please run the automatically-generated script {Fore.BLUE}download_video.(sh/ps1){Style.RESET_ALL} to download all videos."
        )
        # Both scripts receive the same list of commands.
        script_body = "\n".join(ffmpeg_commands)
        with open("download_video.sh", 'w') as file:
            file.write(script_body)
        with open("download_video.ps1", 'w') as file:
            file.write(script_body)

    if config.ALLOW_VERSION_CHECK:
        check_latest_version()

    print(f"{Fore.GREEN}Done.{Style.RESET_ALL}")

    if is_windows():
        # for windows double-click user
        input()
Esempio n. 8
0
class Processor(object):
    """Class for processing dump files from postgresql.

    Rows whose group-by attribute is in ``config.VALUE_SET`` are appended to
    one output file per value. Progress is tracked in a ``Checkpoint`` so an
    interrupted run can be resumed.
    """
    # Number of bytes in one MB, used for size and speed reports.
    MILLION = 1024 * 1024

    def __init__(self):
        self.bytes_count = 0
        self.start_time = 0.0
        # Maps each value of config.VALUE_SET to its open output file.
        self.out_files = {}
        self.checkpoint = Checkpoint(config.VALUE_SET)
        self.init_time()

    def init_time(self):
        """Reset the start time used for timing reports to now."""

        self.start_time = time.time()

    def add_bytes_count(self, count: int):
        """Add ``count`` to the running total of processed input size."""

        self.bytes_count += count

    def split_if_necessary(self) -> None:
        """Check size of each storage file, called each batch.

        When a file has reached ``config.FILE_SPLIT_SIZE`` it is closed and
        a new file (with table headings) takes its place.
        """

        for v in config.VALUE_SET:
            file_size = self.out_files[v].tell()
            if file_size >= config.FILE_SPLIT_SIZE:
                self.checkpoint.update_file_index(v)
                new_file = open(
                    self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
                self.add_table_head(new_file)
                self.out_files[v].close()
                self.out_files[v] = new_file
                logging.info('File size grows over {:.2f} MB, '
                             'store in new file `{}`...'.format(
                                 config.FILE_SPLIT_SIZE / self.MILLION,
                                 new_file.name))

    def process_line(self, line: str) -> None:
        """Record one tab-separated line if it is wanted and not yet stored.

        Invalid lines are logged and skipped without terminating. A line is
        ignored when its group-by value is not in ``config.VALUE_SET`` or
        when its row count is not greater than the one already recorded in
        the checkpoint for that value.

        :param line: str, line to process (trailing newline not included)
        """

        attributes = line.split('\t')
        try:
            # Check value in values to group by
            value = attributes[config.GROUP_BY_ATTR_INDEX]
            if value not in config.VALUE_SET:
                return
            row_count = int(attributes[config.INDEX_ROW_COUNT])
            # Check if line is already parsed and recorded
            if row_count <= self.checkpoint.row_count[value]:
                return
            # Keep attributes we're interested in
            data = [attributes[i] for i in config.RECORD_ATTR_INDEX_LIST]
            # Write to related file
            self.out_files[value].write('\t'.join(data) + '\n')
            # Update index
            self.checkpoint.row_count[value] = row_count
        except Exception as e:
            # Deliberately broad: a malformed row must not stop the run.
            logging.warning(e)
            logging.warning("Invalid row: {}".format(attributes))

    @staticmethod
    def verify_file_schema(fp: TextIO) -> bool:
        """Verify the schema of data contained in a file.

        The dump files of postgresql should contain exactly one table each.
        Only the first line is inspected, and the file position is reset to
        the start afterwards.

        :return: True if the first line has ``config.ATTR_COUNT`` non-empty
                cells and an integer at ``config.INDEX_ROW_COUNT``
        """

        line = fp.readline()
        # Remember to return head of file
        fp.seek(0)
        if isinstance(line, bytes):
            line = str(line, encoding='utf-8')
        # Remove empty cells
        attributes = list(filter(None, line.split('\t')))
        # Check attribute count
        if len(attributes) != config.ATTR_COUNT:
            return False
        # Check validness of index attribute
        try:
            int(attributes[config.INDEX_ROW_COUNT])
        except ValueError:
            return False
        return True

    @staticmethod
    def add_table_head(f: TextIO) -> None:
        """Write the tab-separated table headings as one line to ``f``."""

        f.write('\t'.join(config.RECORD_ATTR_LIST))
        f.write('\n')

    def process_file(self, filename: str, is_old_file: bool = False) -> None:
        """Process one input file whose extension is in ``config.OPEN_FUNCS``.

        :param filename: str, name of file to process
        :param is_old_file: bool, whether this file has been processed before
                if it has been, we should skip batches already read.
        """

        # splitext yields '' when there is no extension, whereas slicing at
        # rfind('.') == -1 would return the last character of the name.
        file_type = os.path.splitext(filename)[1]
        if file_type not in config.OPEN_FUNCS:
            logging.info('Fail to process `{}`: unsupported file type.'.format(
                filename))
            return
        # Open file according to its type
        fp = config.OPEN_FUNCS[file_type](filename)
        try:
            # Old file: needs to recover to the starting point
            if is_old_file and self.checkpoint.offset > 0:
                fp.seek(self.checkpoint.offset)
                logging.info('Time for seeking file offset: {:.2f} s'.format(
                    time.time() - self.start_time))
                # This should be the start of processing
                self.init_time()
            else:
                # New files:
                # needs to verify whether this file contains the table we want
                if not self.verify_file_schema(fp):
                    logging.info(
                        'Schema of `{}` doesn\'t fit; skip.'.format(filename))
                    return
                # Record current file
                self.checkpoint.current_file = filename

            logging.info('Start processing `{}`...'.format(filename))
            while True:
                # Remember where this batch starts so a resumed run can seek
                # back to it.
                self.checkpoint.offset = fp.tell()
                batch = fp.read(config.BATCH_SIZE)
                # Extend the batch to the end of the current line so that no
                # row is cut in half.
                line = fp.readline()
                if line:
                    batch += line
                # EOF
                if not batch:
                    break
                # Convert from bytes to str if needed
                if isinstance(batch, bytes):
                    batch = str(batch, 'utf-8')
                # Parse batch
                for line in batch.splitlines():
                    self.process_line(line)
                self.add_bytes_count(len(batch))
                # Split large files and change storage to new files
                if config.SPLIT:
                    self.split_if_necessary()
        finally:
            # Close the input even when processing is interrupted or fails.
            fp.close()

    def process_dir(self, dirname: str) -> None:
        """Process files in given directory, in sorted name order.

        Sub-directories are visited only when ``config.RECURSIVE`` is set.
        Files already listed in the checkpoint are skipped.

        :param dirname: str, directory of files to process
        """

        file_list = sorted(os.listdir(dirname))
        for name in file_list:
            # Full name of file
            name = os.path.join(dirname, name)
            # Check if this file is already processed
            if name in self.checkpoint.processed_files:
                continue
            if os.path.isfile(name):
                self.process_file(name)
                self.checkpoint.processed_files.add(name)
            elif os.path.isdir(name) and config.RECURSIVE:
                self.process_dir(name)

    def before_process(self) -> None:
        """Create directory if needed, load records and open output files."""
        # makedirs also creates missing parent directories.
        os.makedirs(config.OUT_DIR, exist_ok=True)
        # Load checkpoints from file
        if os.path.exists(config.RECORD_FILE):
            self.checkpoint.load(config.RECORD_FILE)
            logging.info('Checkpoint loaded from `{}`.'.format(
                config.RECORD_FILE))
        # Open files to write
        for v in config.VALUE_SET:
            f = open(self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
            # If it's a new file, add headings
            if f.tell() == 0:
                self.add_table_head(f)
            self.out_files[v] = f

    def process(self, dir_list: list) -> None:
        """Process list of directories / files.

        Always finishes through ``after_process``, which exits the
        interpreter.
        """
        try:
            # Prepare for processing
            self.before_process()
            # Recover from file processed last time
            current = self.checkpoint.current_file
            # Guard against an unset current file before touching the disk.
            if current and os.path.exists(current):
                logging.info('Reloading `{}` from last checkpoints...'.format(
                    current))
                self.process_file(current, is_old_file=True)
                # Mark it as done so a directory walk does not read it again.
                self.checkpoint.processed_files.add(current)
            if len(dir_list) == 0:
                logging.error(
                    'Please specify at least one directory or file to process.'
                )
            # Process each directory / file
            for dir_name in dir_list:
                if os.path.isdir(dir_name):
                    self.process_dir(dir_name)
                elif os.path.isfile(dir_name):
                    self.process_file(dir_name)
                else:
                    logging.warning(
                        '`{}` is not a directory / file; skip.'.format(
                            dir_name))
        # Ctrl + C manually stopped
        except KeyboardInterrupt:
            self.after_process(is_interrupted=True)
        # Other unknown exceptions...
        except Exception as e:
            # Keep the traceback; the message alone rarely locates the fault.
            logging.warning(e, exc_info=True)
            self.after_process(is_interrupted=True)
        else:
            self.after_process(is_interrupted=False)

    def after_process(self, is_interrupted: bool) -> None:
        """Deal with opened files, useless files and save records.

        Exits the interpreter with status 1 when interrupted, 0 otherwise.
        """
        # exit() is injected by the site module for interactive use;
        # sys.exit() is the form meant for programs.
        import sys

        # Close files, and remove files with zero contents
        head_len = len('\t'.join(config.RECORD_ATTR_LIST)) + 1
        for file in self.out_files.values():
            file.close()
            # Not strictly compare size: anything up to 100 bytes beyond the
            # heading line is treated as empty.
            if os.path.getsize(file.name) <= head_len + 100:
                os.remove(file.name)
        # Handle interrupts
        if is_interrupted:
            self.checkpoint.save(config.RECORD_FILE)
            logging.info('Checkpoint saved in `{}`.'.format(
                config.RECORD_FILE))
        # Normal ending, remove record file
        elif os.path.exists(config.RECORD_FILE):
            os.remove(config.RECORD_FILE)
        # Analyse speed
        total_mb = self.bytes_count / self.MILLION
        total_time = time.time() - self.start_time
        # A coarse clock can report zero elapsed time for very short runs.
        avg_speed = total_mb / total_time if total_time > 0 else 0.0
        logging.info(
            'Processed {:.2f} MB in {:.2f} s, {:.2f} MB/s on average.'.format(
                total_mb, total_time, avg_speed))
        sys.exit(int(is_interrupted))
Esempio n. 9
0
    def _train_gan(self):
        """
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return:
        """

        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info(
                "Discriminator or generator failed to load, training from start"
            )
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(
                discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1,
                           noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(
                            0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(
                            0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1 #[1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0 #[0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(
                        d_output_real, real_labels
                    )  # How good the discriminator is on real input
                    d_loss_fake = criterion(
                        d_output_fake1, fake_labels
                    )  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(),
                             g_loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1,
                                   len(self.data_loader), d_loss_real,
                                   d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(
                    self.discriminator.state_dict(),
                    discriminator_optimiser.state_dict(), epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()
                vis.plot_training(epoch)

                data, noise1, _ = iter(self.data_loader).__next__()
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)

                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")
Esempio n. 10
0
    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.

        Resumes from the "autoencoder" checkpoint when one can be loaded, then trains
        until ``self.config.MAX_AUTOENCODER_EPOCHS`` epochs have been completed. After
        every epoch a checkpoint is saved and the visualiser plots the progress.

        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the state dicts
          correctly (``self.load_state`` returned None).
        - The script needs to be re-queued because the NN has been trained for REQUEUE_EPOCHS
          (as reported by ``self.check_requeue``).
        :return: True if training was completed, False if training needs to continue.
        :rtype bool
        """

        criterion = nn.SmoothL1Loss()

        # NOTE(review): the optimiser is built over self.generator's parameters while the
        # loop below trains self.autoencoder — confirm both expose the same parameters.
        optimiser = optim.Adam(self.generator.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is None:
                # Without this guard a None epoch reaches `while epoch < ...` below and
                # raises TypeError; the documented contract is to report "not complete".
                LOG.info(
                    "Autoencoder checkpoint found, but its state could not be loaded")
                return False
            if epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            LOG.info(
                "Autoencoder training beginning from epoch {0}".format(epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0  # epochs finished during this invocation only
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(
                            data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    loss_value = loss.item()
                    vis.step_autoencoder(loss_value)

                    # Report progress; steps are shown 1-based so the last one reads N/N.
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1,
                                   self.config.MAX_AUTOENCODER_EPOCHS,
                                   step + 1, len(self.data_loader),
                                   loss_value))

                epoch += 1
                epochs_complete += 1

                # Persist model and optimiser state once per finished epoch.
                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = next(iter(self.data_loader))
                # Only move the preview batch to the GPU when CUDA is enabled, matching
                # the handling of training batches above.
                if self.config.USE_CUDA:
                    data = data.cuda()
                vis.test_autoencoder(epoch, self.autoencoder, data)

        LOG.info("Autoencoder training complete")
        return True  # Training complete
Esempio n. 11
0
    tgt = TargetField()
    max_len = 150

    def len_filter(example):
        """Accept an example whose src has at most max_len items and whose tgt has at most 3 * max_len."""
        if len(example.src) > max_len:
            return False
        return len(example.tgt) <= max_len * 3

    # Build the dev set from a TSV file whose columns map to the 'src' and 'tgt' fields;
    # len_filter is supplied as the example filter predicate.
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    # The checkpoint lives at <expt_dir>/<Checkpoint.CHECKPOINT_DIR_NAME>/<load_checkpoint>.
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    # Take the model and both vocabularies from the checkpoint so the fields use the
    # same vocabularies the checkpoint was saved with.
    seq2seq = checkpoint.model
    src.vocab = checkpoint.input_vocab
    tgt.vocab = checkpoint.output_vocab

    # Perplexity is built from a weight of 1 per target-vocabulary entry and the index of
    # the target padding token.
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    evaluator = Evaluator(loss=loss, batch_size=32)
    # NOTE(review): the result is named `accuracy`, but what Evaluator.test actually
    # returns is not visible here — confirm before relying on the name.
    accuracy = evaluator.test(seq2seq, dev)
    print(accuracy)
Esempio n. 12
0
def main():
    """Parse command-line options, build the EditNTS model and run evaluation or training.

    The run mode is chosen with ``--run_eval`` or ``--run_training``; when neither is
    given an error message is printed and nothing is run. ``--load_model`` replaces the
    freshly initialised network with the model stored in the given checkpoint.

    Relies on the module-level name ``dataset`` for the default data/store paths.
    """
    # torch.manual_seed(233)
    set_seed(233)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s [INFO] %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        type=str,
        dest='data_path',
        default='/home/ml/ydong26/data/EditNTS_data/editnet_data/%s/' %
        dataset,
        help='Path to train vocab_data')
    parser.add_argument('--store_dir',
                        action='store',
                        dest='store_dir',
                        default='/home/ml/ydong26/tmp_store/editNTS_%s' %
                        dataset,
                        help='Path to exp storage directory.')
    parser.add_argument('--vocab_path',
                        type=str,
                        dest='vocab_path',
                        default='../vocab_data/',
                        help='Path contains vocab, embedding, postag_set')
    parser.add_argument(
        '--load_model',
        type=str,
        dest='load_model',
        default=None,
        help='Path for loading pre-trained model for further training')

    parser.add_argument('--vocab_size',
                        dest='vocab_size',
                        default=30000,
                        type=int)
    parser.add_argument('--batch_size',
                        dest='batch_size',
                        default=32,
                        type=int)
    # type=int so a value given on the command line is an int like the default,
    # instead of arriving as a string.
    parser.add_argument('--max_seq_len',
                        dest='max_seq_len',
                        default=100,
                        type=int)

    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--hidden', type=int, default=200)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--device', type=int, default=1, help='select GPU')
    parser.add_argument('--test',
                        action='store_true',
                        default=False,
                        dest='test_enabled')
    parser.add_argument('--run_eval',
                        action='store_true',
                        default=False,
                        dest='run_eval')
    parser.add_argument('--run_training',
                        action='store_true',
                        default=False,
                        dest='run_training')

    #train_file = '/media/vocab_data/yue/TS/editnet_data/%s/train.df.filtered.pos'%dataset
    # test='/media/vocab_data/yue/TS/editnet_data/%s/test.df.pos' % args.dataset
    args = parser.parse_args()
    print(args)
    torch.cuda.set_device(args.device)

    # load vocab-related files and init vocab
    print('*' * 10)
    vocab = data.Vocab()
    vocab.add_vocab_from_file(args.vocab_path + 'vocab.txt', args.vocab_size)
    vocab.add_embedding(gloveFile=args.vocab_path + 'glove.6B.100d.txt')
    pos_vocab = data.POSvocab(args.vocab_path)  #load pos-tags embeddings
    print('*' * 10)

    print(args)
    print("generating config")
    hyperparams = collections.namedtuple(
        'hps',  #hyper=parameters
        [
            'vocab_size', 'embedding_dim', 'word_hidden_units',
            'sent_hidden_units', 'pretrained_embedding', 'word2id', 'id2word',
            'pos_vocab_size', 'pos_embedding_dim'
        ])
    hps = hyperparams(vocab_size=vocab.count,
                      embedding_dim=100,
                      word_hidden_units=args.hidden,
                      sent_hidden_units=args.hidden,
                      pretrained_embedding=vocab.embedding,
                      word2id=vocab.w2i,
                      id2word=vocab.i2w,
                      pos_vocab_size=pos_vocab.count,
                      pos_embedding_dim=30)

    print('init editNTS model')
    edit_net = EditNTS(hps, n_layers=1)
    edit_net.cuda()

    # Bound up front so the evaluation branch below never hits an undefined name when
    # --run_eval is used without --load_model.
    ckpt = None
    if args.load_model is not None:
        print("load edit_net for further training")
        ckpt_path = args.load_model
        ckpt = Checkpoint.load(ckpt_path)
        print("Epoch: {} | Step: {}".format(ckpt.epoch, ckpt.step))
        edit_net = ckpt.model
        edit_net.cuda()
        edit_net.train()

    if args.run_eval:
        print("Running Evaluation..")
        # ckpt is None here when no --load_model was supplied.
        eval_standalone(edit_net, args, vocab, ckpt)
    elif args.run_training:
        print("Running Training..")
        training(edit_net, args.epochs, args, vocab, test=args.test_enabled)
    else:
        print("ERROR: No running mode selected")
Esempio n. 13
0
def main(args):
    """Train the EncoderCNN / DecoderRNN pair on (images, captions, lengths) batches.

    When ``args['pretrained']`` is truthy, the latest checkpoint under
    ``args['exp_dir']`` is loaded and training resumes at its epoch and step; batches
    already consumed in the resumed epoch are skipped. A checkpoint is saved every
    ``args['valid_step']`` steps.

    Args:
        args (dict): run configuration. Keys read here: 'exp_dir', 'crop_size',
            'data_dir', 'raw_data_dir', 'batch_size', 'num_workers', 'shuffle',
            'pretrained', 'lr', 'num_epochs', 'log_step', 'valid_step'.
            'vocab_size' is written into it before the models are built.
    """

    configure(os.path.join(args['exp_dir'], 'log_dir'))

    transform = transforms.Compose([
        transforms.RandomCrop(args['crop_size']),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    data_loader = get_loader({
        'data_dir': args['data_dir'],
        'exp_dir': args['exp_dir'],
        'raw_data_dir': args['raw_data_dir'],
        'batch_size': args['batch_size'],
        'transform': transform,
        'num_workers': args['num_workers'],
        'shuffle': args['shuffle'],
        'mode': 'train'
    })

    #    valid_data_loader=get_loader({'data_dir' : args['data_dir'],
    #                             'raw_data_dir' : args['raw_data_dir'],
    #                             'batch_size' : int(args['batch_size']/4),
    #                             'transform' : transform,
    #                             'num_workers' : args['num_workers'],
    #                             'shuffle' : args['shuffle'],
    #                             'mode':'validate'})

    args['vocab_size'] = len(Vocabulary.load_vocab(args['exp_dir']))

    encoder = EncoderCNN(args).train()
    decoder = DecoderRNN(args).train()

    if args['pretrained']:
        checkpoint_path = Checkpoint.get_latest_checkpoint(args['exp_dir'])
        checkpoint = Checkpoint.load(checkpoint_path)
        encoder.load_state_dict(checkpoint.encoder)
        decoder.load_state_dict(checkpoint.decoder)
        step = checkpoint.step
        start_epoch = checkpoint.epoch
        omit = True  # skip batches already consumed in the resumed epoch

    else:
        step = 0
        start_epoch = 0
        omit = False

    encoder.to(device)
    decoder.to(device)

    criterion = nn.CrossEntropyLoss()
    # Only the decoder and the encoder's `linear` and `bn` sub-modules are optimised.
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    #    params=list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args['lr'])
    # NOTE(review): optimizer/scheduler state is not restored on resume, so the
    # learning-rate schedule restarts from args['lr'] — confirm this is intended.
    scheduler = StepLR(optimizer, step_size=40, gamma=0.1)
    #    optimizer=YFOptimizer(params)

    total_step = len(data_loader)
    min_valid_loss = float('inf')  # only used by the disabled validation code below

    for epoch in range(start_epoch, args['num_epochs']):
        for idx, (images, captions, leng) in enumerate(data_loader):

            if omit:
                # Number of batches of this epoch that the checkpointed step count
                # already covers.
                if idx < (step - total_step * epoch):
                    logger.info(
                        'idx:{},step:{}, epoch:{}, total_step:{}, diss:{}'.
                        format(idx, step, epoch, total_step,
                               step - total_step * epoch))
                    continue
                omit = False

            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, leng, batch_first=True)[0]

            features = encoder(images)
            outputs = decoder(features, captions, leng)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5)
            optimizer.step()

            log_value('loss', loss.item(), step)
            step += 1

            if step % args['log_step'] == 0:
                logger.info(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args['num_epochs'], idx, total_step,
                            loss.item(), np.exp(loss.item())))

            if step % args['valid_step'] == 0:
                #                valid_loss=validate(encoder.eval(),decoder,criterion,valid_data_loader)
                #                if valid_loss<min_valid_loss:
                #                    min_valid_loss=valid_loss
                Checkpoint(encoder, decoder, optimizer, epoch,
                           step).save(args['exp_dir'])

        # Advance the LR schedule once per epoch, after the optimizer updates. Stepping
        # the scheduler before any optimizer.step() skips the first value of the
        # schedule on PyTorch >= 1.1 and triggers a warning.
        scheduler.step()
Esempio n. 14
0
def interpolation(data_path,
                  checkpoint_path,
                  temp,
                  seconds,
                  name,
                  song_id1=None,
                  song_id2=None,
                  n_steps=20,
                  TR=True):
    """Decode a spherical interpolation between the latent codes of two compositions.

    Writes three pickle files in the working directory: ``begin.pickle`` and
    ``end.pickle`` with the post-processed endpoint compositions, and
    ``<name>.pickle`` with the reconstructed endpoints plus the decoded intermediate
    points. Each file holds a ``(rate, n_samples, score)`` tuple.

    Args:
        data_path: passed to ``CompactCompositions`` to load the compositions.
        checkpoint_path: checkpoint providing ``encoder`` and ``decoder``.
        temp: forwarded to the decoder for the interpolated points
            (presumably a sampling temperature — confirm against the decoder).
        seconds: accepted for interface compatibility; not used by the computation.
        name: base name of the output pickle holding the interpolation.
        song_id1: index of the start composition; drawn at random when None.
        song_id2: index of the end composition; drawn at random when None.
        n_steps: number of interpolation intervals; ``n_steps - 1`` intermediate
            points are decoded between the two endpoints.
        TR: selects ``postprocessing_with_TR`` (True) or ``postprocessing`` (False).
    """
    # Load encoder and decoder in evaluation mode.
    cp = Checkpoint.load(checkpoint_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = cp.encoder.to(device).eval()
    decoder = cp.decoder.to(device).eval()
    postproc = postprocessing_with_TR if TR else postprocessing
    rate = 24.0

    # Load data and pick the two endpoints. If no ids were given, draw random ones.
    comps = CompactCompositions(data_path)
    n_comps = len(comps)
    if song_id1 is None:
        song_id1 = np.random.randint(0, n_comps)
        print(song_id1)
    if song_id2 is None:
        song_id2 = np.random.randint(0, n_comps)

    begin = comps[song_id1]
    end = comps[song_id2]

    # Store the original endpoints for comparison.
    begin_score = postproc([begin.numpy()], 32)
    end_score = postproc([end.numpy()], 32)

    b_nsamps = 44100 * begin_score.shape[0] // 24
    e_nsamps = 44100 * end_score.shape[0] // 24
    with open('begin' + '.pickle', 'wb') as f1:
        pickle.dump((rate, b_nsamps, begin_score), f1, protocol=2)
    with open('end' + '.pickle', 'wb') as f2:
        pickle.dump((rate, e_nsamps, end_score), f2, protocol=2)

    # Inference only: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        # Each endpoint is fed as a batch of two identical items; only item 0 of
        # every decoded batch is kept below.
        A = torch.cat([begin.unsqueeze(0), begin.unsqueeze(0)])
        B = torch.cat([end.unsqueeze(0), end.unsqueeze(0)], 0)
        lat_1, _ = encoder(A)
        recon_1 = decoder(lat_1, temp=None, x=A, teacher_forcing=True)
        lat_2, _ = encoder(B)
        recon_2 = decoder(lat_2, temp=None, x=B, teacher_forcing=True)

        latents = [
            spherical_interpolation(lat_1, lat_2,
                                    float(t) / float(n_steps))
            for t in range(1, n_steps)
        ]
        decoded = [
            decoder(latent, temp, x=None, teacher_forcing=False)
            for latent in latents
        ]
        # Reconstructed start, intermediate points, reconstructed end — in order.
        # Concatenating a single list also works when there are no intermediate
        # points (n_steps == 1), where torch.cat on an empty list would fail.
        pieces = [recon_1[0].unsqueeze(0)]
        pieces.extend(out[0].unsqueeze(0) for out in decoded)
        pieces.append(recon_2[0].unsqueeze(0))
        steps = torch.cat(pieces, 0)

    steps = postproc(steps)
    n_samples = 44100 * steps.shape[0] // 24
    print(steps.shape)

    with open(name + '.pickle', 'wb') as f:
        pickle.dump((rate, n_samples, steps), f, protocol=2)
Esempio n. 15
0
# Prepare the sample output directory (judging by the helper's name, it creates the
# directory when missing — confirm in utils).
utils.exists_or_mkdir(FLAGS.sample_dir)

# Shapes handed to the model factory: a 1-D latent of length z_dim and a square image
# of side output_size prefixed with 3 (presumably colour channels first — confirm).
z_shape = (FLAGS.z_dim, )
image_size = (FLAGS.output_size, FLAGS.output_size)
image_shape = (3, ) + image_size

device = utils.get_torch_device()
nn_model = models.model_factory.create_model(FLAGS.model_name,
                                             device=device,
                                             image_shape=image_shape,
                                             z_shape=z_shape,
                                             use_av_gen=FLAGS.use_averaged_gen)
# `checkpoint` is created earlier in the script (not visible in this excerpt); the model
# is registered with it before loading.
nn_model.register_checkpoint(checkpoint)

# Abort when the requested checkpoint iteration cannot be loaded.
if not checkpoint.load(FLAGS.checkpoint_it_to_load):
    raise RuntimeError('Cannot load checkpoint')

now = datetime.datetime.now()
for i in range(FLAGS.n_samples):
    z = np.random.randn(FLAGS.sample_size, FLAGS.z_dim).astype(np.float32)
    z = torch.tensor(z, device=device)

    with torch.no_grad():
        if hasattr(nn_model, 'av_g_model'):
            nn_model.av_g_model.eval()
            gen_samples = nn_model.av_g_model(z)
        else:
            nn_model.g_model.eval()
            gen_samples = nn_model.g_model(z)
            nn_model.g_model.train()