예제 #1
0
파일: action.py 프로젝트: jihun-hong/xlsum
def validate(args, device_id, pt, step):
    """Evaluate one saved checkpoint on the validation split.

    Args:
        args: run configuration namespace. Entries listed in ``model_flags``
            are overwritten in place with the values stored in the checkpoint.
        device_id: device index forwarded to ``build_trainer``.
        pt: checkpoint path; when it is the empty string ``args.test_from``
            is used instead.
        step: training step associated with the checkpoint.

    Returns:
        The value of ``xent()`` on the statistics returned by
        ``trainer.validate``.
    """
    device = "cuda"
    if args.visible_gpus == '-1':
        device = "cpu"

    test_from = args.test_from if pt == '' else pt
    logger.info('Loading checkpoint from %s' % test_from)
    # Load tensors onto the CPU regardless of where they were saved.
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)

    # Restore the model-defining flags the checkpoint was trained with, so the
    # model built below matches the saved weights.
    saved_opt = vars(checkpoint['opt'])
    for key in saved_opt:
        if key in model_flags:
            setattr(args, key, saved_opt[key])
    print(args)

    xlnet_config = XLNetConfig.from_pretrained(args.config_path)
    model = Summarizer(args, device,
                       load_pretrained_bert=False,
                       bert_config=xlnet_config)
    model.load_cp(checkpoint)
    model.eval()

    valid_dataset = load_dataset(args, 'valid', shuffle=False)
    valid_iter = Dataloader(args, valid_dataset, args.batch_size, device,
                            shuffle=False, is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    return trainer.validate(valid_iter, step).xent()
예제 #2
0
파일: action.py 프로젝트: jihun-hong/xlsum
def multi_main(args):
    """Spawn one ``run`` process per GPU and wait for all of them to exit.

    Child ``device_id`` values are ``0 .. args.world_size - 1``. Every child
    receives the shared ``error_queue``, and its PID is registered with the
    ``ErrorHandler`` that watches that queue.
    """
    init_logger()

    ctx = torch.multiprocessing.get_context('spawn')

    # ErrorHandler listens on this queue for errors reported by the children.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    workers = []
    for device_id in range(args.world_size):
        worker = ctx.Process(target=run,
                             args=(args, device_id, error_queue),
                             daemon=True)
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d  " % worker.pid)
        error_handler.add_child(worker.pid)

    for worker in workers:
        worker.join()
예제 #3
0
파일: trainer.py 프로젝트: jihun-hong/xlsum
def build_trainer(args, device_id, model, optim):
    """Create a ``Trainer`` wired to a TensorBoard writer and report manager.

    Args:
        args: run configuration; reads ``accum_count``, ``world_size``,
            ``gpu_ranks``, ``model_path`` and ``report_every``.
        device_id: index into ``args.gpu_ranks``. A negative value selects
            CPU, in which case rank 0 and zero GPUs are used.
        model: model handed to the trainer; when truthy, its parameter count
            is logged.
        optim: optimizer handed to the trainer (``None`` is passed by the
            evaluation code paths).

    Returns:
        The constructed ``Trainer``.
    """
    grad_accum_count, n_gpu = args.accum_count, args.world_size

    # Resolve this process's GPU rank; a negative device id means CPU.
    if device_id < 0:
        gpu_rank, n_gpu = 0, 0
    else:
        gpu_rank = int(args.gpu_ranks[device_id])
    print('gpu_rank %d' % gpu_rank)

    # TensorBoard logs go next to the checkpoints.
    writer = SummaryWriter(args.model_path, comment="Unmt")
    report_manager = ReportMgr(args.report_every, start_time=-1,
                               tensorboard_writer=writer)

    trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank,
                      report_manager)

    if model:
        logger.info('* number of parameters: %d' % _tally_parameters(model))

    return trainer
예제 #4
0
파일: trainer.py 프로젝트: jihun-hong/xlsum
    def train(self, train_iter_fct, train_steps):
        """Train until the step counter exceeds ``train_steps``.

        Batches come from ``train_iter_fct()``; a fresh iterator is requested
        every time the previous one is exhausted. With ``n_gpu > 0`` this
        process only consumes batches whose position ``i`` satisfies
        ``i % n_gpu == gpu_rank``.

        Every ``grad_accum_count`` consumed batches trigger, in order:
        one ``_gradient_accumulation`` call, a training report, a checkpoint
        (rank 0 only, every ``save_checkpoint_steps`` steps) and a step
        increment.

        Args:
            train_iter_fct: zero-argument callable returning an iterable of
                batches, each exposing ``batch_size``.
            train_steps: last step number to run (inclusive).

        Returns:
            The running-total ``Statistics`` passed to
            ``_gradient_accumulation``.

        Raises:
            RuntimeError: if a full pass over a fresh iterator yields no batch
                for this rank. Without this check the loop would re-create
                iterators forever without advancing ``step``.
        """
        logger.info('Start training...')

        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0

        n_gpu = self.n_gpu
        gpu_rank = self.gpu_rank
        grad_accum_count = self.grad_accum_count

        # Iterable of training batches.
        train_iter = train_iter_fct()

        # Configure statistics report.
        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        # Training loop.
        while step <= train_steps:
            consumed_any = False
            for i, batch in enumerate(train_iter):
                # Round-robin sharding of batches across GPU processes.
                if n_gpu == 0 or i % n_gpu == gpu_rank:
                    consumed_any = True
                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum == grad_accum_count:
                        if n_gpu > 1:
                            # Sum the per-rank batch sizes so every rank
                            # normalises by the global count (assumes
                            # all_gather_list returns one entry per rank).
                            normalization = sum(
                                distributed.all_gather_list(normalization))

                        # Gradient accumulation for model.
                        self._gradient_accumulation(true_batchs, normalization,
                                                    total_stats, report_stats)

                        # Report statistics for training.
                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optim.learning_rate,
                            report_stats)

                        # Reset the accumulation buffers for the next step.
                        true_batchs = []
                        accum = 0
                        normalization = 0
                        if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0:
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            if not consumed_any:
                # A pass that gives this rank nothing cannot advance `step`;
                # fail loudly instead of spinning forever.
                raise RuntimeError(
                    'Training iterator yielded no batches for gpu_rank %d; '
                    'check the training data and world size.' % gpu_rank)
            # Epoch boundary: start a new pass over the training data.
            train_iter = train_iter_fct()

        return total_stats
예제 #5
0
파일: action.py 프로젝트: jihun-hong/xlsum
def wait_and_validate(args, device_id):
    """Validate and test the ``model_step_*.pt`` checkpoints in ``args.model_path``.

    With ``args.test_all`` set:

    - Every checkpoint is validated in modification-time order.
    - Scanning stops once more than 10 checkpoints have been evaluated after
      the one with the lowest cross-entropy so far.
    - The three best checkpoints are then passed to ``test``.

    Otherwise the function polls forever:

    - When the newest checkpoint is newer than the last one processed, it is
      validated and tested.
    - If the newest file is still empty, it waits 60 s and re-checks.
    - When nothing new is available, it waits 300 s between polls.

    Args:
        args: run configuration; reads ``model_path`` and ``test_all``.
        device_id: device index forwarded to ``validate`` and ``test``.
    """
    def _list_checkpoints():
        # Checkpoint paths ordered by modification time, oldest first.
        cp_files = sorted(
            glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
        cp_files.sort(key=os.path.getmtime)
        return cp_files

    def _step_of(cp):
        # '.../model_step_<N>.pt' -> N
        return int(cp.split('.')[-2].split('_')[-1])

    timestep = 0
    if args.test_all:
        xent_lst = []
        for i, cp in enumerate(_list_checkpoints()):
            xent = validate(args, device_id, cp, _step_of(cp))
            xent_lst.append((xent, cp))
            # Early stop: the best checkpoint is already >10 positions behind.
            best_idx = xent_lst.index(min(xent_lst))
            if i - best_idx > 10:
                break
        xent_lst = sorted(xent_lst, key=lambda x: x[0])[:3]
        logger.info('PPL %s' % str(xent_lst))
        for xent, cp in xent_lst:
            test(args, device_id, cp, _step_of(cp))
        return

    while True:
        cp_files = _list_checkpoints()
        if cp_files:
            cp = cp_files[-1]
            time_of_cp = os.path.getmtime(cp)
            if not os.path.getsize(cp) > 0:
                # Checkpoint is still being written; check again shortly.
                time.sleep(60)
                continue
            if time_of_cp > timestep:
                timestep = time_of_cp
                step = _step_of(cp)
                validate(args, device_id, cp, step)
                test(args, device_id, cp, step)

        cp_files = _list_checkpoints()
        if cp_files and os.path.getmtime(cp_files[-1]) > timestep:
            # A newer checkpoint appeared during evaluation: handle it now.
            continue
        # Nothing new to evaluate: wait before polling again rather than
        # re-scanning the directory in a tight loop.
        time.sleep(300)
예제 #6
0
    def output(self, step, num_steps, learning_rate, start):
        """Log a one-line training progress summary, then flush stdout.

        Args:
           step (int): current step.
           num_steps (int): total number of steps; when positive the step is
               shown as ``step/num_steps``.
           learning_rate (float): current learning rate.
           start (float): reference timestamp; ``time.time() - start`` is
               reported as the elapsed seconds.
        """
        elapsed = self.elapsed_time()
        progress = "%2d" % step
        if num_steps > 0:
            progress += "/%5d" % num_steps
        # The 1e-5 guards against division by zero right after start-up.
        message = "Step %s; xent: %4.2f; lr: %7.7f; %3.0f docs/s; %6.0f sec" % (
            progress, self.xent(), learning_rate,
            self.n_docs / (elapsed + 1e-5), time.time() - start)
        logger.info(message)
        sys.stdout.flush()
예제 #7
0
파일: action.py 프로젝트: jihun-hong/xlsum
def train(args, device_id):
    """Train a ``Summarizer``, optionally resuming from ``args.train_from``.

    Args:
        args: run configuration namespace. When resuming, entries listed in
            ``model_flags`` are overwritten with the checkpoint's values
            before the model is built.
        device_id: CUDA device index, or a negative value for CPU.
    """
    # Start logger.
    init_logger(args.log_file)

    # Configure training device.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Seed all RNGs once; torch.manual_seed also seeds every CUDA device, so
    # repeating this block after the CUDA setup is unnecessary.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Set CUDA device.
    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    def train_iter_fct():
        # A fresh, shuffled pass over the training split; the trainer calls
        # this again whenever the previous iterator is exhausted.
        return Dataloader(args,
                          load_dataset(args, 'train', shuffle=True),
                          args.batch_size,
                          device,
                          shuffle=True,
                          is_test=False)

    # Load the checkpoint (if any) and restore its model-defining flags
    # BEFORE building the model, so the architecture matches the saved
    # weights -- the same order `validate` uses.
    checkpoint = None
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])

    # Build the model.
    model = Summarizer(args, device, load_pretrained=True)
    if checkpoint is not None:
        model.load_cp(checkpoint)
    # A None checkpoint builds a fresh optimizer.
    optim = builder.build_optim(args, model, checkpoint)
    logger.info(model)

    # Train the model.
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
예제 #8
0
파일: trainer.py 프로젝트: jihun-hong/xlsum
    def _save(self, step):
        """Write a checkpoint for ``step`` into ``self.args.model_path``.

        The checkpoint dict has three entries:

        - ``'model'``: the model's ``state_dict``.
        - ``'opt'``: the run arguments.
        - ``'optim'``: the optimizer object itself (pickled as a whole).

        An existing file at the target path is never overwritten.

        Args:
            step: training step, embedded in the file name
                ``model_step_<step>.pt``.

        Returns:
            ``(checkpoint, checkpoint_path)`` when a file was written,
            otherwise ``None``.
        """
        checkpoint_path = os.path.join(self.args.model_path,
                                       'model_step_%d.pt' % step)
        if os.path.exists(checkpoint_path):
            # Never overwrite; report the skip instead of claiming a save.
            logger.info("Checkpoint %s already exists; not overwriting"
                        % checkpoint_path)
            return None

        checkpoint = {
            'model': self.model.state_dict(),
            'opt': self.args,
            'optim': self.optim,
        }
        logger.info("Saving checkpoint %s" % checkpoint_path)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path
예제 #9
0
def _format_xlnet(param):
    """Convert one JSON shard of documents into XLNet examples saved with torch.

    Args:
        param: ``(json_file, args, save_file)`` tuple. ``json_file`` holds a
            list of documents, each with ``'src'`` and ``'tgt'`` entries;
            ``args.oracle_mode`` selects ``greedy`` or ``combination`` to
            compute oracle sentence ids (called with a limit argument of 3);
            results are written to ``save_file``.

    Nothing is done when ``save_file`` already exists. Documents for which
    ``XLData.process`` returns ``None`` are skipped.

    Raises:
        ValueError: if ``args.oracle_mode`` is neither ``'greedy'`` nor
            ``'combination'``.
    """
    json_file, args, save_file = param

    # If the output already exists, skip this shard.
    if os.path.exists(save_file):
        logger.info('Ignore %s' % save_file)
        return

    # Resolve the oracle up front: an unknown mode would otherwise leave
    # `oracle_ids` unbound and fail with UnboundLocalError mid-loop.
    if args.oracle_mode == 'greedy':
        oracle_fn = greedy
    elif args.oracle_mode == 'combination':
        oracle_fn = combination
    else:
        raise ValueError("Unknown oracle_mode %r; expected 'greedy' or "
                         "'combination'" % (args.oracle_mode,))

    xlnet = XLData(args)
    logger.info('Processing %s' % json_file)
    # JSON text is UTF-8 (RFC 8259); `with` closes the handle deterministically.
    with open(json_file, encoding='utf-8') as f:
        jobs = json.load(f)

    data_set = []
    for d in jobs:
        src, tgt = d['src'], d['tgt']
        oracle_ids = oracle_fn(src, tgt, 3)

        # Process data using oracle ids; None means the document is skipped.
        xl_data = xlnet.process(src, tgt, oracle_ids)
        if xl_data is None:
            continue
        indexed_tokens, labels, segments_ids, cls_ids, src_txt, tgt_txt = xl_data
        data_set.append({
            "src": indexed_tokens,
            "labels": labels,
            "segs": segments_ids,
            "clss": cls_ids,
            "src_txt": src_txt,
            "tgt_txt": tgt_txt,
        })

    # Save file with torch.
    logger.info('Saving to %s' % save_file)
    torch.save(data_set, save_file)
    gc.collect()
예제 #10
0
파일: loader.py 프로젝트: jihun-hong/xlsum
 def _lazy_dataset_loader(pt_file, corpus_type):
     """Load one preprocessed ``.pt`` shard and log how many examples it has.

     Args:
         pt_file: path to a file written with ``torch.save``.
         corpus_type: split name, used only in the log message.

     Returns:
         Whatever ``torch.load`` returns for ``pt_file``; it must support
         ``len`` for the log line.
     """
     # NOTE: torch.load unpickles its input -- only load trusted shards.
     examples = torch.load(pt_file)
     message = 'Loading %s dataset from %s, number of examples: %d' % (
         corpus_type, pt_file, len(examples))
     logger.info(message)
     return examples
예제 #11
0
 def log(self, *args, **kwargs):
     """Forward every positional and keyword argument to ``logger.info``."""
     info = logger.info
     info(*args, **kwargs)
예제 #12
0
파일: trainer.py 프로젝트: jihun-hong/xlsum
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Select summary sentences for a test set and write them to disk.

        Per batch, sentence indices are chosen in one of three modes:

        - ``cal_lead``: every sentence in document order.
        - ``cal_oracle``: the sentences whose label is 1.
        - Otherwise: the model scores the sentences, the loss is accumulated
          into the returned statistics, and indices are ranked by score.

        Selected sentences are optionally trigram-blocked
        (``args.block_trigram``). Unless ``cal_oracle`` or
        ``args.recall_eval`` is set, at most 3 are kept. They are joined with
        ``'<q>'`` and written one line per document to
        ``<result_path>_step<step>.candidate``. The references are written to
        ``<result_path>_step<step>.gold``.

        ROUGE is computed and logged when ``step != -1`` and
        ``args.report_rouge`` is set.

        Args:
            test_iter: iterable of batches exposing ``src``, ``labels``,
                ``segs``, ``clss``, ``mask``, ``mask_cls``, ``src_str``,
                ``tgt_str`` and ``batch_size``.
            step: checkpoint step, used in the output file names and logs.
            cal_lead: use the lead baseline instead of the model.
            cal_oracle: use the gold labels instead of the model.

        Returns:
            :obj:`Statistics`: loss statistics (empty in lead/oracle modes).
        """

        # Set of n-gram tuples occurring in the token list `text`.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        # True if candidate sentence `c` shares any trigram with a sentence
        # already selected in `p` (trigram blocking).
        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        # Only the model-based mode runs the network; put it in eval mode.
        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        src = batch.src
                        labels = batch.labels
                        segs = batch.segs
                        clss = batch.clss
                        mask = batch.mask
                        mask_cls = batch.mask_cls

                        gold = []
                        pred = []

                        if (cal_lead):
                            # Lead baseline: sentence indices in document order.
                            selected_ids = [list(range(batch.clss.size(1)))
                                            ] * batch.batch_size
                        elif (cal_oracle):
                            # Oracle: indices of sentences labelled 1.
                            selected_ids = [[
                                j for j in range(batch.clss.size(1))
                                if labels[i][j] == 1
                            ] for i in range(batch.batch_size)]
                        else:
                            # The model returns the scores and the sentence
                            # mask it used, which replaces the batch mask.
                            sent_scores, mask = self.model(
                                src, clss, mask, mask_cls)

                            # Masked loss summed over real sentences.
                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            batch_stats = Statistics(
                                float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

                            # Adding the mask lifts real sentences above padded
                            # positions before ranking (assumes the mask is 1
                            # for real sentences and scores are < 1 -- TODO
                            # confirm).
                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            # Indices sorted by descending score, per document.
                            selected_ids = np.argsort(-sent_scores, 1)
                        # Sentences stay in ranked order; re-sorting into
                        # document order is disabled:
                        # selected_ids = np.sort(selected_ids,1)
                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if (len(batch.src_str[i]) == 0):
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                # Skip indices beyond this document's sentences.
                                if (j >= len(batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if (self.args.block_trigram):
                                    if (not _block_tri(candidate, _pred)):
                                        _pred.append(candidate)
                                else:
                                    _pred.append(candidate)

                                # Keep at most 3 sentences unless evaluating
                                # the oracle or recall.
                                if ((not cal_oracle)
                                        and (not self.args.recall_eval)
                                        and len(_pred) == 3):
                                    break

                            _pred = '<q>'.join(_pred)
                            if (self.args.recall_eval):
                                # Truncate to as many whitespace tokens as the
                                # reference summary has.
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

                        for i in range(len(gold)):
                            save_gold.write(gold[i].strip() + '\n')
                        for i in range(len(pred)):
                            save_pred.write(pred[i].strip() + '\n')
        if (step != -1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats