Example #1
def LoadBoolQ(args, tokenizer):
    """
    Function that loads the BoolQ question-answering dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the boolq dataset
    dataset = load_dataset("boolq")

    # divide into train, dev and test
    train_set = dataset['train']
    dataset = dataset['validation'].train_test_split(test_size=0.5,
                                                     train_size=0.5,
                                                     shuffle=True)
    dev_set = dataset['train']
    test_set = dataset['test']

    # function that encodes the question and passage
    def encode_sentence(examples):
        return tokenizer('[CLS] ' + examples['question'] + ' [SEP] ' +
                         examples['passage'] + ' [SEP]',
                         truncation=True,
                         padding='max_length')

    # tokenize the datasets
    train_set = train_set.map(encode_sentence, batched=False)
    dev_set = dev_set.map(encode_sentence, batched=False)
    test_set = test_set.map(encode_sentence, batched=False)

    # function to convert the answers to 0 (false) and 1 (true)
    def change_label(example):
        example['labels'] = 1 if example['answer'] else 0
        return example

    # convert answers to labels
    train_set = train_set.map(change_label, batched=False)
    dev_set = dev_set.map(change_label, batched=False)
    test_set = test_set.map(change_label, batched=False)

    # remove unnecessary columns
    train_set = train_set.remove_columns(['question', 'passage', 'answer'])
    dev_set = dev_set.remove_columns(['question', 'passage', 'answer'])
    test_set = test_set.remove_columns(['question', 'passage', 'answer'])

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets
    return train_set, dev_set, test_set
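
None of these examples ship create_dataloader itself. A minimal sketch of what such a helper could look like for the HuggingFace-based loaders in this listing; the args.batch_size attribute is an assumption, not taken from the snippets:

from torch.utils.data import DataLoader

def create_dataloader(args, dataset, tokenizer):
    # hypothetical sketch, not the repo's actual helper:
    # expose the tokenized columns as PyTorch tensors and wrap in a DataLoader
    dataset.set_format(type='torch',
                       columns=['input_ids', 'token_type_ids',
                                'attention_mask', 'labels'])
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True)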
Example #2
def LoadSST2(args, tokenizer):
    """
    Function that loads the SST2 sentiment dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the sst dataset
    dataset = load_dataset("sst")

    # divide into train, dev and test
    train_set = dataset['train']
    dev_set = dataset['validation']
    test_set = dataset['test']

    # function that encodes the sentences
    def encode_sentence(examples):
        return tokenizer('[CLS] ' + examples['sentence'] + ' [SEP]',
                         truncation=True,
                         padding='max_length')

    # tokenize the datasets
    train_set = train_set.map(encode_sentence, batched=False)
    dev_set = dev_set.map(encode_sentence, batched=False)
    test_set = test_set.map(encode_sentence, batched=False)

    # remove unnecessary columns
    train_set = train_set.remove_columns(['sentence', 'tokens', 'tree'])
    dev_set = dev_set.remove_columns(['sentence', 'tokens', 'tree'])
    test_set = test_set.remove_columns(['sentence', 'tokens', 'tree'])

    # rename the labels
    train_set = train_set.rename_column("label", "labels")
    dev_set = dev_set.rename_column("label", "labels")
    test_set = test_set.rename_column("label", "labels")

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets
    return train_set, dev_set, test_set
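
A hypothetical driver for the loader above; the tokenizer checkpoint and batch size are assumptions:

from argparse import Namespace
from transformers import BertTokenizer

args = Namespace(batch_size=16)  # assumed; the real args come from argparse
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_set, dev_set, test_set = LoadSST2(args, tokenizer)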
Example #3
def LoadGoEmotions(args, tokenizer, first_label=False, k_shot=False):
    """
    Function to load the GoEmotions dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
        first_label - Indicates whether to only use the first label. Default is False
        k_shot - Indicates whether to make the training set k-shot. Default is False
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the dataset
    dataset = load_dataset("go_emotions", "simplified")

    # function that encodes the text
    def encode_text(batch):
        tokenized_batch = tokenizer(batch['text'],
                                    padding=True,
                                    truncation=True)
        return tokenized_batch

    # tokenize the dataset
    dataset = dataset.map(manual_tokenizer, batched=False)
    dataset = dataset.map(encode_text, batched=False)

    # split into test, dev and train
    train_set = dataset['train']
    dev_set = dataset['validation']
    test_set = dataset['test']

    # prepare the data
    train_set, dev_set, test_set = PrepareSets(args, tokenizer, train_set,
                                               dev_set, test_set, first_label)

    # check if k-shot
    if k_shot:
        return train_set, test_set, 27

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets and number of classes
    return train_set, dev_set, test_set, 27
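
With k_shot=True the raw training split is returned and the caller builds the k-shot subset. A sketch of how k examples per class could be drawn, assuming a single integer 'labels' column (GoEmotions is multi-label, so the repo's actual sampling may differ):

import random
from collections import defaultdict

def sample_k_shot(train_set, k):
    # group example indices by label, then keep at most k per class
    by_label = defaultdict(list)
    for idx, label in enumerate(train_set['labels']):
        by_label[label].append(idx)
    keep = []
    for indices in by_label.values():
        keep.extend(random.sample(indices, min(k, len(indices))))
    return train_set.select(keep)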
Example #4
def create_data_fetcher(if_train=False,
                        seed=None,
                        num_gpu=None,
                        rank=None,
                        ds_type=None,
                        ds_opts=None,
                        enlarge_ratio=None,
                        nworker_pg=None,
                        bs_pg=None):
    """Define data-set, data-sampler, data-loader and CPU-based data-fetcher."""
    ds_cls = getattr(dataset, ds_type)
    ds = ds_cls(ds_opts)
    num_samples = len(ds)

    sampler = DistSampler(num_replicas=num_gpu, rank=rank, ratio=enlarge_ratio, ds_size=num_samples) if if_train \
        else None

    loader = create_dataloader(if_train=if_train,
                               seed=seed,
                               rank=rank,
                               num_worker=nworker_pg,
                               batch_size=bs_pg,
                               dataset=ds,
                               sampler=sampler)

    data_fetcher = CPUPrefetcher(loader)
    return num_samples, sampler, data_fetcher
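
DistSampler is not a torch built-in (and Example #8 constructs it with a dataset argument rather than ds_size, so the signature varies across snippets). A rough sketch of a distributed sampler that "enlarges" the dataset by repeating shuffled indices, under those assumptions:

import math
import torch
from torch.utils.data import Sampler

class DistSampler(Sampler):
    # hypothetical sketch of the sampler used above
    def __init__(self, num_replicas=1, rank=0, ratio=1, ds_size=0):
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # each replica draws ceil(ds_size * ratio / num_replicas) samples
        self.num_samples = math.ceil(ds_size * ratio / num_replicas)
        self.total_size = self.num_samples * num_replicas
        self.ds_size = ds_size

    def set_epoch(self, epoch):
        self.epoch = epoch  # lets each epoch reshuffle differently

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        # repeat fresh permutations until total_size indices are available
        indices = []
        while len(indices) < self.total_size:
            indices += torch.randperm(self.ds_size, generator=g).tolist()
        indices = indices[:self.total_size]
        # each rank takes a strided slice of the shared index list
        return iter(indices[self.rank:self.total_size:self.num_replicas])

    def __len__(self):
        return self.num_samples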
Example #5
def create_data_fetcher(ds_type=None, ds_opts=None):
    """Define data-set, data-loader and CPU-based data-fetcher."""
    ds_cls = getattr(dataset, ds_type)
    ds = ds_cls(ds_opts)
    num_samples = len(ds)
    loader = create_dataloader(if_train=False, dataset=ds)
    data_fetcher = CPUPrefetcher(loader)
    return num_samples, data_fetcher
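
CPUPrefetcher is likewise not shown; from the reset()/next() calls in Example #8 it behaves as a thin iterator wrapper that returns None when the loader is exhausted. A plausible minimal version:

class CPUPrefetcher:
    # sketch inferred from usage; the real class may prefetch ahead
    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None  # signals the end of the epoch

    def reset(self):
        self.loader = iter(self.ori_loader)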
Example #6
    def __init__(self,
                 dataset_paths: list = [],
                 model_path: str = "",
                 epochs: int = 10,
                 lr: float = 0.001,
                 batch_size: int = 16):
        self.dataset_paths = dataset_paths
        self.model_path = model_path

        self.epochs = epochs
        self.lr = lr
        self.batch_size = batch_size

        self.train_set = create_dataloader(dataset_path=self.dataset_paths[0],
                                           batch_size=self.batch_size)
        self.validation_set = create_dataloader(
            dataset_path=self.dataset_paths[1], batch_size=self.batch_size)
        self.test_set = create_dataloader(dataset_path=self.dataset_paths[2],
                                          batch_size=self.batch_size)
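
Assuming this __init__ belongs to a class named Trainer (the class name is not shown) and that create_dataloader here takes dataset_path and batch_size, instantiation would look like the sketch below; all paths are placeholders. Note that the mutable default [] for dataset_paths is a classic Python pitfall: defaulting to None and filling it inside the method is safer.

trainer = Trainer(dataset_paths=['data/train', 'data/val', 'data/test'],
                  model_path='checkpoints/model.pt',
                  epochs=10,
                  lr=0.001,
                  batch_size=16)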
Example #7
def create_test_dataloader(feed, place, is_distributed):
    '''
    Use the local testing dataset if it is found; otherwise, download it.
    '''
    test_data_path = 'dataset/t10k-images-idx3-ubyte.gz'
    test_label_path = 'dataset/t10k-labels-idx1-ubyte.gz'
    if os.path.exists(test_data_path) and os.path.exists(test_label_path):
        reader = paddle.dataset.mnist.reader_creator(test_data_path,
                test_label_path, 100)
    else:
        reader = paddle.dataset.mnist.test()
    return utils.create_dataloader(reader, feed, place,
            batch_size=args.batch_size, is_test=True, is_distributed=is_distributed)
Example #8
def main():
    # ==========
    # parameters
    # ==========

    opts_dict = receive_arg()
    rank = opts_dict['train']['rank']
    unit = opts_dict['train']['criterion']['unit']
    num_iter = int(opts_dict['train']['num_iter'])
    interval_print = int(opts_dict['train']['interval_print'])
    interval_val = int(opts_dict['train']['interval_val'])
    
    # ==========
    # init distributed training
    # ==========

    if opts_dict['train']['is_dist']:
        utils.init_dist(
            local_rank=rank, 
            backend='nccl'
            )

    # TO-DO: load resume states if they exist
    pass

    # ==========
    # create logger
    # ==========

    if rank == 0:
        log_dir = op.join("exp", opts_dict['train']['exp_name'])
        utils.mkdir(log_dir)
        log_fp = open(opts_dict['train']['log_path'], 'w')

        # log all parameters
        msg = (
            f"{'<' * 10} Hello {'>' * 10}\n"
            f"Timestamp: [{utils.get_timestr()}]\n"
            f"\n{'<' * 10} Options {'>' * 10}\n"
            f"{utils.dict2str(opts_dict)}"
            )
        print(msg)
        log_fp.write(msg + '\n')
        log_fp.flush()

    # ==========
    # TO-DO: init tensorboard
    # ==========

    pass
    
    # ==========
    # fix random seed
    # ==========

    seed = opts_dict['train']['random_seed']
    # offset the seed by rank so that each distributed process gets a different seed
    utils.set_random_seed(seed + rank)

    # ========== 
    # Ensure reproducibility or Speed up
    # ==========

    #torch.backends.cudnn.benchmark = False  # if reproduce
    #torch.backends.cudnn.deterministic = True  # if reproduce
    torch.backends.cudnn.benchmark = True  # speed up

    # ==========
    # create train and val data prefetchers
    # ==========
    
    # create datasets
    train_ds_type = opts_dict['dataset']['train']['type']
    val_ds_type = opts_dict['dataset']['val']['type']
    radius = opts_dict['network']['radius']
    assert train_ds_type in dataset.__all__, \
        "Not implemented!"
    assert val_ds_type in dataset.__all__, \
        "Not implemented!"
    train_ds_cls = getattr(dataset, train_ds_type)
    val_ds_cls = getattr(dataset, val_ds_type)
    train_ds = train_ds_cls(
        opts_dict=opts_dict['dataset']['train'], 
        radius=radius
        )
    val_ds = val_ds_cls(
        opts_dict=opts_dict['dataset']['val'], 
        radius=radius
        )

    # create datasamplers
    train_sampler = utils.DistSampler(
        dataset=train_ds, 
        num_replicas=opts_dict['train']['num_gpu'], 
        rank=rank, 
        ratio=opts_dict['dataset']['train']['enlarge_ratio']
        )
    val_sampler = None  # no need to sample val data

    # create dataloaders
    train_loader = utils.create_dataloader(
        dataset=train_ds, 
        opts_dict=opts_dict, 
        sampler=train_sampler, 
        phase='train',
        seed=opts_dict['train']['random_seed']
        )
    val_loader = utils.create_dataloader(
        dataset=val_ds, 
        opts_dict=opts_dict, 
        sampler=val_sampler, 
        phase='val'
        )
    assert train_loader is not None

    batch_size = opts_dict['dataset']['train']['batch_size_per_gpu'] * \
        opts_dict['train']['num_gpu']  # total batch size across all GPUs
    num_iter_per_epoch = math.ceil(len(train_ds) * \
        opts_dict['dataset']['train']['enlarge_ratio'] / batch_size)
    num_epoch = math.ceil(num_iter / num_iter_per_epoch)
    val_num = len(val_ds)
    
    # create dataloader prefetchers
    tra_prefetcher = utils.CPUPrefetcher(train_loader)
    val_prefetcher = utils.CPUPrefetcher(val_loader)

    # ==========
    # create model
    # ==========

    model = MFVQE(opts_dict=opts_dict['network'])

    model = model.to(rank)
    if opts_dict['train']['is_dist']:
        model = DDP(model, device_ids=[rank])

    """
    # load pre-trained generator
    ckp_path = opts_dict['network']['stdf']['load_path']
    checkpoint = torch.load(ckp_path)
    state_dict = checkpoint['state_dict']
    if ('module.' in list(state_dict.keys())[0]) and (not opts_dict['train']['is_dist']):  # multi-gpu pre-trained -> single-gpu training
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove module
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        print(f'loaded from {ckp_path}')
    elif ('module.' not in list(state_dict.keys())[0]) and (opts_dict['train']['is_dist']):  # single-gpu pre-trained -> multi-gpu training
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = 'module.' + k  # add module
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        print(f'loaded from {ckp_path}')
    else:  # the same way of training
        model.load_state_dict(state_dict)
        print(f'loaded from {ckp_path}')
    """

    # ==========
    # define loss func & optimizer & scheduler & criterion
    # ==========

    # define loss func
    assert opts_dict['train']['loss'].pop('type') == 'CharbonnierLoss', \
        "Not implemented."
    loss_func = utils.CharbonnierLoss(**opts_dict['train']['loss'])

    # define optimizer
    assert opts_dict['train']['optim'].pop('type') == 'Adam', \
        "Not implemented."
    optimizer = optim.Adam(
        model.parameters(), 
        **opts_dict['train']['optim']
        )

    # define scheduler
    if opts_dict['train']['scheduler']['is_on']:
        assert opts_dict['train']['scheduler'].pop('type') == \
            'CosineAnnealingRestartLR', "Not implemented."
        del opts_dict['train']['scheduler']['is_on']
        scheduler = utils.CosineAnnealingRestartLR(
            optimizer, 
            **opts_dict['train']['scheduler']
            )
        opts_dict['train']['scheduler']['is_on'] = True

    # define criterion
    assert opts_dict['train']['criterion'].pop('type') == \
        'PSNR', "Not implemented."
    criterion = utils.PSNR()

    # ==========
    # resume states (TO-DO: restore from checkpoint)
    # ==========

    start_iter = 0  # should be restored from the checkpoint when resuming
    start_epoch = start_iter // num_iter_per_epoch

    # display and log
    if rank == 0:
        msg = (
            f"\n{'<' * 10} Dataloader {'>' * 10}\n"
            f"total iters: [{num_iter}]\n"
            f"total epochs: [{num_epoch}]\n"
            f"iter per epoch: [{num_iter_per_epoch}]\n"
            f"val sequence: [{val_num}]\n"
            f"start from iter: [{start_iter}]\n"
            f"start from epoch: [{start_epoch}]"
            )
        print(msg)
        log_fp.write(msg + '\n')
        log_fp.flush()

    # ==========
    # evaluate original performance, e.g., PSNR before enhancement
    # ==========

    vid_num = val_ds.get_vid_num()
    if opts_dict['train']['pre-val'] and rank == 0:
        msg = f"\n{'<' * 10} Pre-evaluation {'>' * 10}"
        print(msg)
        log_fp.write(msg + '\n')

        per_aver_dict = {}
        for i in range(vid_num):
            per_aver_dict[i] = utils.Counter()
        pbar = tqdm(
                total=val_num, 
                ncols=opts_dict['train']['pbar_len']
                )

        # fetch the first batch
        val_prefetcher.reset()
        val_data = val_prefetcher.next()

        while val_data is not None:
            # get data
            gt_data = val_data['gt'].to(rank)  # (B [RGB] H W)
            lq_data = val_data['lq'].to(rank)  # (B T [RGB] H W)
            index_vid = val_data['index_vid'].item()
            name_vid = val_data['name_vid'][0]  # bs must be 1!
            b, _, _, _, _  = lq_data.shape
            
            # eval
            batch_perf = np.mean(
                [criterion(lq_data[i,radius,...], gt_data[i]) for i in range(b)]
                )  # bs must be 1!
            
            # log
            per_aver_dict[index_vid].accum(volume=batch_perf)

            # display
            pbar.set_description(
                "{:s}: [{:.3f}] {:s}".format(name_vid, batch_perf, unit)
                )
            pbar.update()

            # fetch next batch
            val_data = val_prefetcher.next()

        pbar.close()

        # log
        ave_performance = np.mean([
            per_aver_dict[index_vid].get_ave() for index_vid in range(vid_num)
            ])
        msg = "> ori performance: [{:.3f}] {:s}".format(ave_performance, unit)
        print(msg)
        log_fp.write(msg + '\n')
        log_fp.flush()

    if opts_dict['train']['is_dist']:
        torch.distributed.barrier()  # wait until all processes reach this point

    if rank == 0:
        msg = f"\n{'<' * 10} Training {'>' * 10}"
        print(msg)
        log_fp.write(msg + '\n')

        # create timer
        total_timer = utils.Timer()  # total tra + val time of each epoch

    # ==========
    # start training + validation (test)
    # ==========

    model.train()
    num_iter_accum = start_iter
    for current_epoch in range(start_epoch, num_epoch + 1):
        # shuffle distributed subsamplers before each epoch
        if opts_dict['train']['is_dist']:
            train_sampler.set_epoch(current_epoch)

        # fetch the first batch
        tra_prefetcher.reset()
        train_data = tra_prefetcher.next()

        # train this epoch
        while train_data is not None:

            # stop once the target number of iterations is reached
            num_iter_accum += 1
            if num_iter_accum > num_iter:
                break

            # get data
            gt_data = train_data['gt'].to(rank)  # (B [RGB] H W)
            lq_data = train_data['lq'].to(rank)  # (B T [RGB] H W)
            b, _, c, _, _  = lq_data.shape
            input_data = torch.cat(
                [lq_data[:,:,i,...] for i in range(c)], 
                dim=1
                )  # B [R1 ... R7 G1 ... G7 B1 ... B7] H W
            enhanced_data = model(input_data)

            # compute loss and update parameters
            optimizer.zero_grad()
            loss = torch.mean(torch.stack(
                [loss_func(enhanced_data[i], gt_data[i]) for i in range(b)]
                ))  # per-sample loss, averaged over the batch
            loss.backward()  # compute gradients
            optimizer.step()  # update parameters

            # update learning rate
            if opts_dict['train']['scheduler']['is_on']:
                scheduler.step()  # must be called after optimizer.step()

            if (num_iter_accum % interval_print == 0) and (rank == 0):
                # display & log
                lr = optimizer.param_groups[0]['lr']
                loss_item = loss.item()
                msg = (
                    f"iter: [{num_iter_accum}]/{num_iter}, "
                    f"epoch: [{current_epoch}]/{num_epoch - 1}, "
                    "lr: [{:.3f}]x1e-4, loss: [{:.4f}]".format(
                        lr*1e4, loss_item
                        )
                    )
                print(msg)
                log_fp.write(msg + '\n')

            if ((num_iter_accum % interval_val == 0) or \
                (num_iter_accum == num_iter)) and (rank == 0):
                # save model
                checkpoint_save_path = (
                    f"{opts_dict['train']['checkpoint_save_path_pre']}"
                    f"{num_iter_accum}"
                    ".pt"
                    )
                state = {
                    'num_iter_accum': num_iter_accum, 
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(), 
                    }
                if opts_dict['train']['scheduler']['is_on']:
                    state['scheduler'] = scheduler.state_dict()
                torch.save(state, checkpoint_save_path)
                
                # validation
                with torch.no_grad():
                    per_aver_dict = {}
                    for index_vid in range(vid_num):
                        per_aver_dict[index_vid] = utils.Counter()
                    pbar = tqdm(
                            total=val_num, 
                            ncols=opts_dict['train']['pbar_len']
                            )
                
                    # train -> eval
                    model.eval()

                    # fetch the first batch
                    val_prefetcher.reset()
                    val_data = val_prefetcher.next()
                    
                    while val_data is not None:
                        # get data
                        gt_data = val_data['gt'].to(rank)  # (B [RGB] H W)
                        lq_data = val_data['lq'].to(rank)  # (B T [RGB] H W)
                        index_vid = val_data['index_vid'].item()
                        name_vid = val_data['name_vid'][0]  # bs must be 1!
                        b, _, c, _, _  = lq_data.shape
                        input_data = torch.cat(
                            [lq_data[:,:,i,...] for i in range(c)], 
                            dim=1
                            )  # B [R1 ... R7 G1 ... G7 B1 ... B7] H W
                        enhanced_data = model(input_data)  # (B [RGB] H W)

                        # eval
                        batch_perf = np.mean(
                            [criterion(enhanced_data[i], gt_data[i]) for i in range(b)]
                            ) # bs must be 1!

                        # display
                        pbar.set_description(
                            "{:s}: [{:.3f}] {:s}"
                            .format(name_vid, batch_perf, unit)
                            )
                        pbar.update()

                        # log
                        per_aver_dict[index_vid].accum(volume=batch_perf)

                        # fetch next batch
                        val_data = val_prefetcher.next()
                    
                    # end of val
                    pbar.close()

                    # eval -> train
                    model.train()

                # log
                ave_per = np.mean([
                    per_aver_dict[index_vid].get_ave() for index_vid in range(vid_num)
                    ])
                msg = (
                    "> model saved at {:s}\n"
                    "> ave val per: [{:.3f}] {:s}"
                    ).format(
                        checkpoint_save_path, ave_per, unit
                        )
                print(msg)
                log_fp.write(msg + '\n')
                log_fp.flush()

            if opts_dict['train']['is_dist']:
                torch.distributed.barrier()  # wait until all processes reach this point

            # fetch next batch
            train_data = tra_prefetcher.next()

        # end of this epoch (training dataloader exhausted)

    # end of all epochs

    # ==========
    # final log & close logger
    # ==========

    if rank == 0:
        total_time = total_timer.get_interval() / 3600
        msg = "TOTAL TIME: [{:.1f}] h".format(total_time)
        print(msg)
        log_fp.write(msg + '\n')
        
        msg = (
            f"\n{'<' * 10} Goodbye {'>' * 10}\n"
            f"Timestamp: [{utils.get_timestr()}]"
            )
        print(msg)
        log_fp.write(msg + '\n')
        
        log_fp.close()
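
utils.CharbonnierLoss is asserted above but never shown. The Charbonnier penalty is a differentiable L1 variant, mean(sqrt(diff^2 + eps^2)), so a minimal sketch (the eps default is an assumption):

import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    # sketch of the loss asserted in Example #8
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        return torch.mean(torch.sqrt((x - y) ** 2 + self.eps ** 2))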
Example #9
def Train_No_GAN(opt):    # without GAN
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generatorNet = utils.create_generator(opt)
    flownet = utils.create_pwcnet(opt)

    # To device
    if opt.multi_gpu:
        generatorNet = nn.DataParallel(generatorNet)
        generatorNet = generatorNet.cuda()
        flownet = nn.DataParallel(flownet)
        flownet = flownet.cuda()
    else:
        generatorNet = generatorNet.cuda()
        flownet = flownet.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generatorNet.parameters(), lr = opt.lr_g, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    
    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
    
    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator.module, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_epoch%d_bs%d.pth' % (opt.task, epoch, opt.batch_size))
                        print('The trained model is successfully saved at epoch %d' % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(generator, 'Pre_%s_iter%d_bs%d.pth' % (opt.task, iteration, opt.batch_size))
                        print('The trained model is successfully saved at iteration %d' % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the class list
    imglist = utils.text_readlines('videocolor_linux.txt')
    classlist = utils.get_dirs(opt.baseroot)
    '''
    imgnumber = len(imglist) - (len(imglist) % opt.batch_size)
    imglist = imglist[:imgnumber]
    '''

    # Define the dataset
    trainset = dataset.MultiFramesDataset(opt, imglist, classlist)
    print('The overall number of classes:', len(trainset))

    # Define the dataloader
    dataloader = utils.create_dataloader(trainset, opt)
    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()
    
    # For loop training
    for epoch in range(opt.epochs):
        for iteration, (in_part, out_part) in enumerate(dataloader):
            
            # Train Generator
            optimizer_G.zero_grad()

            lstm_state = None
            loss_flow = 0
            loss_flow_long = 0
            loss_L1 = 0

            x_0 = in_part[0].cuda()
            p_t_0 = in_part[0].cuda()

            for iter_frame in range(opt.iter_frames):
                # Read data
                x_t = in_part[iter_frame].cuda()
                y_t = out_part[iter_frame].cuda()
                
                # Initialize the second input and compute flow loss
                if iter_frame == 0:
                    p_t_last = torch.zeros(opt.batch_size, opt.out_channels, opt.resize_h, opt.resize_w).cuda()
                elif iter_frame == 1:
                    x_t_last = in_part[iter_frame - 1].cuda()
                    p_t_last = p_t.detach()
                    p_t_0 = p_t.detach()
                    p_t_last.requires_grad = False
                    p_t_0.requires_grad = False
                    # o_t_last_2_t range is [-20, +20]
                    o_t_last_2_t = pwcnet.PWCEstimate(flownet, x_t, x_t_last)
                    x_t_warp = pwcnet.PWCNetBackward((x_t_last + 1) / 2, o_t_last_2_t)
                    # x_t_warp, p_t_warp range is [0, 1]
                    p_t_warp = pwcnet.PWCNetBackward((p_t_last + 1) / 2, o_t_last_2_t)
                else:
                    x_t_last = in_part[iter_frame - 1].cuda()
                    p_t_last = p_t.detach()
                    p_t_last.requires_grad = False
                    # o_t_last_2_t o_t_first_2_t range is [-20, +20]
                    o_t_last_2_t = pwcnet.PWCEstimate(flownet, x_t, x_t_last)
                    o_t_first_2_t = pwcnet.PWCEstimate(flownet, x_t, x_0)
                    # x_t_warp, p_t_warp, x_t_warp_long, p_t_warp_long range is [0, 1]
                    x_t_warp = pwcnet.PWCNetBackward((x_t_last + 1) / 2, o_t_last_2_t)
                    p_t_warp = pwcnet.PWCNetBackward((p_t_last + 1) / 2, o_t_last_2_t)
                    x_t_warp_long = pwcnet.PWCNetBackward((x_0 + 1) / 2, o_t_first_2_t)
                    p_t_warp_long = pwcnet.PWCNetBackward((p_t_0 + 1) / 2, o_t_first_2_t)
                # Generator output
                p_t, lstm_state = generatorNet(x_t, p_t_last, lstm_state)
                lstm_state = utils.repackage_hidden(lstm_state)
                if iter_frame == 1:
                    mask_flow = torch.exp( -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp, dim=1).pow(2) ).unsqueeze(1)
                    loss_flow += criterion_L1(mask_flow * (p_t + 1) / 2, mask_flow * p_t_warp)
                elif iter_frame > 1:
                    mask_flow = torch.exp( -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp, dim=1).pow(2) ).unsqueeze(1)
                    loss_flow += criterion_L1(mask_flow * (p_t + 1) / 2, mask_flow * p_t_warp)
                    mask_flow_long = torch.exp( -opt.mask_para * torch.sum((x_t + 1) / 2 - x_t_warp_long, dim=1).pow(2) ).unsqueeze(1)
                    loss_flow_long += criterion_L1(mask_flow_long * (p_t + 1) / 2, mask_flow_long * p_t_warp_long)
                
                # Pixel-level loss
                loss_L1 += criterion_L1(p_t, y_t)

            # Overall Loss and optimize
            loss = loss_L1 + opt.lambda_flow * loss_flow + opt.lambda_flow_long * loss_flow_long
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + iteration
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [L1 Loss: %.4f] [Flow Loss Short: %.8f] [Flow Loss Long: %.8f] Time_left: %s" %
                ((epoch + 1), opt.epochs, iteration, len(dataloader), loss_L1.item(), loss_flow.item(), loss_flow_long.item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generatorNet)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
            
Example #10
def Trainer_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize Generator
    generator_a, generator_b = utils.create_generator(opt)
    discriminator_a, discriminator_b = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator_a = nn.DataParallel(generator_a)
        generator_a = generator_a.cuda()
        generator_b = nn.DataParallel(generator_b)
        generator_b = generator_b.cuda()
        discriminator_a = nn.DataParallel(discriminator_a)
        discriminator_a = discriminator_a.cuda()
        discriminator_b = nn.DataParallel(discriminator_b)
        discriminator_b = discriminator_b.cuda()
    else:
        generator_a = generator_a.cuda()
        generator_b = generator_b.cuda()
        discriminator_a = discriminator_a.cuda()
        discriminator_b = discriminator_b.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(generator_a.parameters(),
                                                   generator_b.parameters()),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D_a = torch.optim.Adam(discriminator_a.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_b = torch.optim.Adam(discriminator_b.parameters(),
                                     lr=opt.lr_d,
                                     betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model if pre_train == True
    def save_model(opt, epoch, iteration, len_dataset, generator_a,
                   generator_b):
        """Save the model at "checkpoint_interval" and its multiple"""
        if opt.multi_gpu == True:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator_a.module,
                            'WGAN_DRIT_epoch%d_bs%d_a.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            generator_b.module,
                            'WGAN_DRIT_epoch%d_bs%d_b.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator_a.module, 'WGAN_DRIT_iter%d_bs%d_a.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            generator_b.module, 'WGAN_DRIT_iter%d_bs%d_b.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator_a, 'WGAN_DRIT_epoch%d_bs%d_a.pth' %
                            (epoch, opt.batch_size))
                        torch.save(
                            generator_b, 'WGAN_DRIT_epoch%d_bs%d_b.pth' %
                            (epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator_a, 'WGAN_DRIT_iter%d_bs%d_a.pth' %
                            (iteration, opt.batch_size))
                        torch.save(
                            generator_b, 'WGAN_DRIT_iter%d_bs%d_b.pth' %
                            (iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    dataloader = utils.create_dataloader(opt)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (img_a, img_b) in enumerate(dataloader):

            # To device
            img_a = img_a.cuda()
            img_b = img_b.cuda()

            # Sampled style codes (prior)
            prior_s_a = torch.randn(img_a.shape[0], opt.style_dim).cuda()
            prior_s_b = torch.randn(img_a.shape[0], opt.style_dim).cuda()

            # ----------------------------------------
            #              Train Generator
            # ----------------------------------------
            # Note that:
            # input / output image dimension: [B, 3, 256, 256]
            # content_code dimension: [B, 256, 64, 64]
            # style_code dimension: [B, 8]
            # generator_a is related to domain a / style a
            # generator_b is related to domain b / style b

            optimizer_G.zero_grad()

            # Get shared latent representation
            c_a, s_a = generator_a.encode(img_a)
            c_b, s_b = generator_b.encode(img_b)

            # Reconstruct images
            img_aa_recon = generator_a.decode(c_a, s_a)
            img_bb_recon = generator_b.decode(c_b, s_b)

            # Translate images
            img_ba = generator_a.decode(c_b, prior_s_a)
            img_ab = generator_b.decode(c_a, prior_s_b)

            # Cycle code translation
            c_b_recon, s_a_recon = generator_a.encode(img_ba)
            c_a_recon, s_b_recon = generator_b.encode(img_ab)

            # Cycle image translation
            img_aa_recon_cycle = generator_a.decode(
                c_a_recon, s_a) if opt.lambda_cycle > 0 else 0
            img_bb_recon_cycle = generator_b.decode(
                c_b_recon, s_b) if opt.lambda_cycle > 0 else 0

            # Losses
            loss_id_1 = opt.lambda_id * criterion_L1(img_aa_recon, img_a)
            loss_id_2 = opt.lambda_id * criterion_L1(img_bb_recon, img_b)
            loss_s_1 = opt.lambda_style * criterion_L1(s_a_recon, prior_s_a)
            loss_s_2 = opt.lambda_style * criterion_L1(s_b_recon, prior_s_b)
            loss_c_1 = opt.lambda_content * criterion_L1(
                c_a_recon, c_a.detach())
            loss_c_2 = opt.lambda_content * criterion_L1(
                c_b_recon, c_b.detach())
            loss_cycle_1 = opt.lambda_cycle * criterion_L1(
                img_aa_recon_cycle, img_a) if opt.lambda_cycle > 0 else 0
            loss_cycle_2 = opt.lambda_cycle * criterion_L1(
                img_bb_recon_cycle, img_b) if opt.lambda_cycle > 0 else 0

            # GAN Loss
            fake_scalar_a = discriminator_a(img_ba)
            fake_scalar_b = discriminator_b(img_ab)
            loss_gan1 = -opt.lambda_gan * torch.mean(fake_scalar_a)
            loss_gan2 = -opt.lambda_gan * torch.mean(fake_scalar_b)

            # Overall Losses and optimization
            loss_G = loss_id_1 + loss_id_2 + loss_s_1 + loss_s_2 + loss_c_1 + loss_c_2 + loss_cycle_1 + loss_cycle_2 + loss_gan1 + loss_gan2
            loss_G.backward()
            optimizer_G.step()

            # ----------------------------------------
            #            Train Discriminator
            # ----------------------------------------

            optimizer_D_a.zero_grad()
            optimizer_D_b.zero_grad()

            # D_a
            fake_scalar_a = discriminator_a(img_ba.detach())
            true_scalar_a = discriminator_a(img_a)
            loss_D_a = torch.mean(fake_scalar_a) - torch.mean(true_scalar_a)
            loss_D_a.backward()
            optimizer_D_a.step()

            # D_b
            fake_scalar_b = discriminator_b(img_ab.detach())
            true_scalar_b = discriminator_b(img_b)
            loss_D_b = torch.mean(fake_scalar_b) - torch.mean(true_scalar_b)
            loss_D_b.backward()
            optimizer_D_b.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Recon Loss: %.4f] [Style Loss: %.4f] [Content Loss: %.4f] [G Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   (loss_id_1 + loss_id_2).item(),
                   (loss_s_1 + loss_s_2).item(), (loss_c_1 + loss_c_2).item(),
                   (loss_gan1 + loss_gan2).item(),
                   (loss_D_a + loss_D_b).item(), time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator_a, generator_b)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_a)
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_D_b)
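
The discriminator updates above follow the WGAN objective (mean fake score minus mean real score), but no Lipschitz constraint appears in the snippet; it may be enforced inside create_discriminator. The original WGAN recipe clips critic weights after each optimizer step, roughly:

# hypothetical addition; 0.01 is the clip value from the WGAN paper,
# and whether this repo clips at all is an assumption
clip_value = 0.01
for p in discriminator_a.parameters():
    p.data.clamp_(-clip_value, clip_value)
for p in discriminator_b.parameters():
    p.data.clamp_(-clip_value, clip_value)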
Example #11
def LoadCirca(args, tokenizer, test_scenario=None):
    """
    Function to load the Circa dataset for both matched and unmatched settings.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
        test_scenario - Scenario to reserve for testing (unmatched setting only; if None, the matched setting is loaded)
    Outputs:
        train_set - Training dataset: 60% of the data if matched; 80% of the remaining scenarios if unmatched
        dev_set - Development dataset: 20% of the data if matched; 20% of the remaining scenarios if unmatched
        test_set - Test dataset: 20% of the data if matched; the single held-out scenario if unmatched
    """

    # load the filtered and annotated dataset
    dataset, usedTopicLabels = processCircaDataset(
        doAnnotateImportantWords=args.impwords,
        preloadImportantWords=args.npimpwords,
        doAnnotateTopics=args.topics,
        preloadTopics=args.nptopics,
        traverseTopicLemmas=args.traversetopics,
        tfidf=args.tfidf,
        hybrid=args.hybrid,
        topic_depth=args.topic_depth,
        label_density=args.label_density,
        impwordsfile=args.impwordsfile,
        topicsfile=args.topicsfile,
        topiclabelsfile=args.topiclabelsfile)

    # create the dictionary for the labels
    if args.labels == "strict":
        circa_labels = {
            'Yes':
            dataset.features['goldstandard1'].str2int('Yes'),
            'Probably yes / sometimes yes':
            dataset.features['goldstandard1'].str2int(
                'Probably yes / sometimes yes'),
            'Yes, subject to some conditions':
            dataset.features['goldstandard1'].str2int(
                'Yes, subject to some conditions'),
            'No':
            dataset.features['goldstandard1'].str2int('No'),
            'Probably no':
            dataset.features['goldstandard1'].str2int('Probably no'),
            'In the middle, neither yes nor no':
            dataset.features['goldstandard1'].str2int(
                'In the middle, neither yes nor no')
        }
    else:
        circa_labels = {
            'Yes':
            dataset.features['goldstandard2'].str2int('Yes'),
            'No':
            dataset.features['goldstandard2'].str2int('No'),
            'In the middle, neither yes nor no':
            dataset.features['goldstandard2'].str2int(
                'In the middle, neither yes nor no'),
            'Yes, subject to some conditions':
            dataset.features['goldstandard2'].str2int(
                'Yes, subject to some conditions')
        }

    # split the dataset into train, dev and test
    if not test_scenario:  # matched
        dataset = dataset.train_test_split(test_size=0.4,
                                           train_size=0.6,
                                           shuffle=True)
        train_set = dataset['train']
        dataset = dataset['test'].train_test_split(test_size=0.5,
                                                   train_size=0.5,
                                                   shuffle=True)
        dev_set = dataset['train']
        test_set = dataset['test']
    else:  # unmatched
        test_set = dataset.filter(
            lambda example: example['context'] == test_scenario)
        left_set = dataset.filter(
            lambda example: example['context'] != test_scenario)
        left_set = left_set.train_test_split(test_size=0.2,
                                             train_size=0.8,
                                             shuffle=True)
        train_set = left_set['train']
        dev_set = left_set['test']

    # prepare the data
    train_set_o, dev_set_o, test_set_o = PrepareSets(args, tokenizer,
                                                     train_set, dev_set,
                                                     test_set)

    # create dataloaders for the datasets
    train_set_dict = {'Circa': create_dataloader(args, train_set_o, tokenizer)}
    dev_set_dict = {'Circa': create_dataloader(args, dev_set_o, tokenizer)}
    test_set_dict = {'Circa': create_dataloader(args, test_set_o, tokenizer)}

    label_dict = {'Circa': circa_labels}

    if 'TOPICS' in args.aux_tasks:
        # use only the answer for topic extraction
        orgModelVersion = args.model_version
        argsModified = vars(args)
        argsModified['model_version'] = 'A'

        train_set_t, dev_set_t, test_set_t = PrepareSets(
            args, tokenizer, train_set, dev_set, test_set)

        # rename labels
        train_set_t = train_set_t.remove_columns(['labels'])
        dev_set_t = dev_set_t.remove_columns(['labels'])
        test_set_t = test_set_t.remove_columns(['labels'])

        train_set_t = train_set_t.rename_column("topic_label", "labels")
        dev_set_t = dev_set_t.rename_column("topic_label", "labels")
        test_set_t = test_set_t.rename_column("topic_label", "labels")

        # for the separate aux task, remove samples without topic annotation
        if '' in usedTopicLabels:
            # the empty label is expected to be the last element, but looking
            # it up explicitly is cheap and safer than hard-coding the index
            emptyLabelIndex = usedTopicLabels.index('')
            train_set_t = train_set_t.filter(
                lambda x: x['labels'] != emptyLabelIndex)
            dev_set_t = dev_set_t.filter(
                lambda x: x['labels'] != emptyLabelIndex)
            test_set_t = test_set_t.filter(
                lambda x: x['labels'] != emptyLabelIndex)

        print(
            'After removing empty topics, we have %d; %d; and %d samples for (respectively) train, dev, test sets for topic aux task'
            % (len(train_set_t), len(dev_set_t), len(test_set_t)))

        # add to dict
        train_set_dict['TOPICS'] = create_dataloader(args, train_set_t,
                                                     tokenizer)
        dev_set_dict['TOPICS'] = create_dataloader(args, dev_set_t, tokenizer)
        test_set_dict['TOPICS'] = create_dataloader(args, test_set_t,
                                                    tokenizer)
        label_dict['TOPICS'] = usedTopicLabels[:-1]

        argsModified['model_version'] = orgModelVersion

    # return the datasets (and label dict)
    return train_set_dict, dev_set_dict, test_set_dict, label_dict
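
A hypothetical unmatched-setting call; the scenario string is an assumption (Circa scenarios are short context descriptions):

train_dict, dev_dict, test_dict, label_dict = LoadCirca(
    args, tokenizer,
    test_scenario="X wants to know about Y's food preferences.")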
Example #12
def main(args):
    """
    Function for handling the arguments and starting the experiment.
    Inputs:
        args - Namespace object from the argument parser
    """

    # set the seed
    torch.manual_seed(args.seed)

    # check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # print the model parameters
    print('-----TRAINING PARAMETERS-----')
    print('Dataset: {}'.format(args.dataset))
    print('PyTorch device: {}'.format(device))
    print('K: {}'.format(args.k))
    print('Number of epochs: {}'.format(args.num_epochs))
    print('Number of runs: {}'.format(args.num_runs))
    print('Learning rate: {}'.format(args.lr))
    print('Batch size: {}'.format(args.batch_size))
    print('Results directory: {}'.format(args.results_dir))
    print('Progress bar: {}'.format(args.progress_bar))
    print('-----------------------------')

    # generate the path to use for the results
    path = create_path(args)
    if not os.path.exists(path):
        os.makedirs(path)

    # all evaluation datasets
    eval_datasets = [
        'GoEmotions', 'crowdflower', 'dailydialog', 'electoraltweets',
        'emoint', 'emotion-cause', 'grounded_emotions', 'ssec',
        'tales-emotion', 'tec'
    ]

    # load the tokenizer
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', additional_special_tokens=specials())

    # repeat for the different datasets
    all_results = {}
    for dataset_name in eval_datasets:
        dataset_results = {}
        print('Evaluating on ' + dataset_name)

        # load the dataset
        print('Loading datasets..')
        if dataset_name == 'GoEmotions':
            train_set, test_set, new_num_classes = LoadGoEmotions(args,
                                                                  tokenizer,
                                                                  k_shot=True)
        else:
            train_set, test_set, new_num_classes = LoadUnifiedEmotions(
                args, tokenizer, dataset_name, k_shot=True)
        test_set = create_dataloader(args, test_set, tokenizer)
        print('Datasets loaded')

        # repeat for the specified number of runs
        for run in range(1, args.num_runs + 1):
            print('Run ' + str(run))

            # convert the train set to a k-shot dataloader
            train_loader = create_dataloader(args, train_set, tokenizer, True,
                                             new_num_classes)

            # evaluate the model
            results = evaluate_dataset(args, device, tokenizer, train_loader,
                                       test_set, new_num_classes)
            dataset_results['run' + str(run)] = results
            print('----')
        all_results[dataset_name] = dataset_results

    # calculate the mean and std for the different runs
    average_results = average_evaluation_results(all_results)
    all_results['average testing'] = average_results

    # save the results as a json file
    print('Saving results..')
    with open(
            os.path.join(path, 'evaluation_results_k' + str(args.k) + '.txt'),
            'w') as outfile:
        json.dump(all_results, outfile)
    print('Results saved')
Example #13
    parser.add_argument('--dataset_name_a',
                        type=str,
                        help='the folder name of the a domain')
    parser.add_argument('--dataset_name_b',
                        type=str,
                        default='human_test',
                        help='the folder name of the b domain')
    parser.add_argument('--imgsize',
                        type=int,
                        default=128,
                        help='the image size')
    opt = parser.parse_args()

    utils.check_path(opt.save_path)

    # Define the dataset
    # a = 'cat'; b = 'human'
    testloader = utils.create_dataloader(opt)
    print('The overall number of images:', len(testloader))

    # Define networks
    generator_a, generator_b = utils.create_generator(opt)
    generator_a = generator_a.cuda()
    generator_b = generator_b.cuda()

    # Forward
    for i, (img_a, img_b) in enumerate(testloader):
        # To device
        img_a = img_a.cuda()
        img_b = img_b.cuda()
        # Forward
        with torch.no_grad():
            out = generator_b(img_a, img_a)
Example #14
def LoadUnifiedEmotions(args,
                        tokenizer,
                        target_dataset,
                        path="./data/datasets/unified-dataset.jsonl",
                        k_shot=False):
    """
    Function to load the UnifiedEmotions dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
        target_dataset - String representing the dataset to load
        path - Path to the unified dataset jsonl file
        k_shot - Indicates whether to make the training set k-shot. Default is False
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the dataset
    dataset = load_dataset('json', data_files=path)['train']

    # filter out the correct source
    dataset = dataset.filter(
        lambda example: example['source'] == target_dataset)

    # function that encodes the text
    def encode_text(batch):
        tokenized_batch = tokenizer(batch['text'],
                                    padding=True,
                                    truncation=True)
        return tokenized_batch

    # tokenize the dataset
    dataset = dataset.map(manual_tokenizer, batched=False)
    dataset = dataset.map(encode_text, batched=False)

    # create a dictionary for converting labels to integers
    label_dict = {}
    for label in dataset[0]['emotions']:
        if dataset[0]['emotions'][label] is not None:
            label_dict[label] = len(label_dict)

    # split the dataset
    if None in dataset['split']:
        # split the dataset into 70% train, 15% test and 15% validation
        dataset = dataset.train_test_split(test_size=0.3,
                                           train_size=0.7,
                                           shuffle=True)
        train_set = dataset['train']
        dataset = dataset['test'].train_test_split(test_size=0.5,
                                                   train_size=0.5,
                                                   shuffle=True)
        dev_set = dataset['train']
        test_set = dataset['test']
    else:
        train_set = dataset.filter(lambda example: example['split'] == 'train')
        dev_set = dataset.filter(
            lambda example: example['split'] == 'validation')
        test_set = dataset.filter(lambda example: example['split'] == 'test')

    # create a validation split for ssec
    if target_dataset == 'ssec':
        test_set = test_set.train_test_split(test_size=0.5,
                                             train_size=0.5,
                                             shuffle=True)
        dev_set = test_set['train']
        test_set = test_set['test']

    # prepare the data
    train_set, dev_set, test_set = PrepareSets(args, tokenizer, label_dict,
                                               train_set, dev_set, test_set)

    # check if k-shot
    if k_shot:
        return train_set, test_set, len(label_dict)

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets and number of classes
    return train_set, dev_set, test_set, len(label_dict)
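
Note on the `k_shot` branch above: when it is set, the function returns the raw train and test sets plus the class count so the caller can subsample the training data itself. A minimal sketch of such a subsampler, assuming `PrepareSets` has produced a single integer `labels` column (the helper name `make_k_shot` is hypothetical):

import random

def make_k_shot(dataset, k, seed=42):
    # Hypothetical helper: keep at most k examples per label.
    random.seed(seed)
    indices_per_label = {}
    for idx, label in enumerate(dataset['labels']):
        indices_per_label.setdefault(label, []).append(idx)
    keep = []
    for indices in indices_per_label.values():
        random.shuffle(indices)
        keep.extend(indices[:k])
    return dataset.select(sorted(keep))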
Example #15
def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset('cblue', 'CMeEE', splits=['train', 'dev'])

    model = ElectraForBinaryTokenClassification.from_pretrained(
        'ernie-health-chinese',
        num_classes=[len(x) for x in train_ds.label_list])
    tokenizer = ElectraTokenizer.from_pretrained('ernie-health-chinese')

    label_list = train_ds.label_list
    pad_label_id = [len(label_list[0]) - 1, len(label_list[1]) - 1]
    ignore_label_id = -100

    trans_func = partial(convert_example_ner,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         pad_label_id=pad_label_id)

    batchify_fn = lambda samples, fn=Dict({
        'input_ids':
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
        'token_type_ids':
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),
        'position_ids':
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
        'attention_mask':
        Pad(axis=0, pad_val=0, dtype='float32'),
        'label_oth':
        Pad(axis=0, pad_val=pad_label_id[0], dtype='int64'),
        'label_sym':
        Pad(axis=0, pad_val=pad_label_id[1], dtype='int64')
    }): fn(samples)

    train_data_loader = create_dataloader(train_ds,
                                          mode='train',
                                          batch_size=args.batch_size,
                                          batchify_fn=batchify_fn,
                                          trans_fn=trans_func)

    dev_data_loader = create_dataloader(dev_ds,
                                        mode='dev',
                                        batch_size=args.batch_size,
                                        batchify_fn=batchify_fn,
                                        trans_fn=trans_func)

    if args.init_from_ckpt:
        if not os.path.isfile(args.init_from_ckpt):
            raise ValueError('init_from_ckpt is not a valid model filename.')
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = len(train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)

    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ['bias', 'norm'])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.functional.softmax_with_cross_entropy

    metric = NERChunkEvaluator(label_list)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, masks, label_oth, label_sym = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=['layer_norm', 'softmax', 'gelu'],
            ):
                logits = model(input_ids, token_type_ids, position_ids)

                loss_mask = paddle.unsqueeze(masks, 2)
                losses = [(criterion(x, y.unsqueeze(2)) * loss_mask).mean()
                          for x, y in zip(logits, [label_oth, label_sym])]
                loss = losses[0] + losses[1]

                lengths = paddle.sum(masks, axis=1)
                preds = [paddle.argmax(x, axis=-1) for x in logits]
                correct = metric.compute(lengths, preds,
                                         [label_oth, label_sym])
                metric.update(correct)
                _, _, f1 = metric.accumulate()

            # run backward and the optimizer update outside of auto_cast
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    'global step %d, epoch: %d, batch: %d, loss: %.5f, loss symptom: %.5f, loss others: %.5f, f1: %.5f, speed: %.2f step/s, learning_rate: %f'
                    % (global_step, epoch, step, loss, losses[1],
                       losses[0], f1, args.logging_steps / time_diff,
                       lr_scheduler.get_lr()))
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)
                tic_train = time.time()

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        'model_%d' % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
                tic_train = time.time()
    if rank == 0 and total_train_time > 0:
        print('Speed: %.2f steps/s' % (global_step / total_train_time))
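
Examples #15, #16, #20, #21 and #25 import `create_dataloader` from a local `utils` module that is not shown here. A sketch of the conventional PaddleNLP version of this helper, offered under the assumption that it follows the usual pattern (distributed batch sampler for training, plain sampler otherwise):

import paddle

def create_dataloader(dataset, mode='train', batch_size=1,
                      batchify_fn=None, trans_fn=None):
    # apply the per-example transform (tokenization etc.) first
    if trans_fn:
        dataset = dataset.map(trans_fn)
    shuffle = mode == 'train'
    if mode == 'train':
        # shard batches across devices for multi-GPU training
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    return paddle.io.DataLoader(dataset=dataset,
                                batch_sampler=batch_sampler,
                                collate_fn=batchify_fn,
                                return_list=True)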
Example #16
def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset('cblue',
                                    args.dataset,
                                    splits=['train', 'dev'])

    model = ElectraForSequenceClassification.from_pretrained(
        'ernie-health-chinese',
        num_classes=len(train_ds.label_list),
        activation='tanh')
    tokenizer = ElectraTokenizer.from_pretrained('ernie-health-chinese')

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # segment
        Pad(axis=0, pad_val=args.max_seq_length - 1, dtype='int64'
            ),  # position
        Stack(dtype='int64')): [data for data in fn(samples)]
    train_data_loader = create_dataloader(train_ds,
                                          mode='train',
                                          batch_size=args.batch_size,
                                          batchify_fn=batchify_fn,
                                          trans_fn=trans_func)
    dev_data_loader = create_dataloader(dev_ds,
                                        mode='dev',
                                        batch_size=args.batch_size,
                                        batchify_fn=batchify_fn,
                                        trans_fn=trans_func)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        state_keys = {
            x: x.replace('discriminator.', '')
            for x in state_dict.keys() if 'discriminator.' in x
        }
        if len(state_keys) > 0:
            state_dict = {
                state_keys[k]: state_dict[k]
                for k in state_keys.keys()
            }
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.epochs
    args.epochs = (num_training_steps - 1) // len(train_data_loader) + 1

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ['bias', 'norm'])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.loss.CrossEntropyLoss()
    if METRIC_CLASSES[args.dataset] is Accuracy:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = 'accuracy'
    elif METRIC_CLASSES[args.dataset] is MultiLabelsMetric:
        metric = METRIC_CLASSES[args.dataset](
            num_labels=len(train_ds.label_list))
        metric_name = 'macro f1'
    else:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = 'micro f1'
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, labels = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=[
                        'layer_norm', 'softmax', 'gelu', 'tanh'
                    ],
            ):
                logits = model(input_ids, token_type_ids, position_ids)
                loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)

            if isinstance(metric, Accuracy):
                result = metric.accumulate()
            elif isinstance(metric, MultiLabelsMetric):
                _, _, result = metric.accumulate('macro')
            else:
                _, _, _, result, _ = metric.accumulate()

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    'global step %d, epoch: %d, batch: %d, loss: %.5f, %s: %.5f, speed: %.2f step/s'
                    % (global_step, epoch, step, loss, metric_name, result,
                       args.logging_steps / time_diff))

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        'model_%d' % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)

            if global_step >= num_training_steps:
                return
            tic_train = time.time()

    if rank == 0 and total_train_time > 0:
        print('Speed: %.2f steps/s' % (global_step / total_train_time))
Example #17
def LoadMNLI(args, tokenizer):
    """
    Function that loads the MultiNLI entailment dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the MultiNLI dataset
    dataset = load_dataset("multi_nli")

    # divide into train, dev and test
    train_set = dataset['train']
    dataset = dataset['validation_matched'].train_test_split(test_size=0.5,
                                                             train_size=0.5,
                                                             shuffle=True)
    dev_set = dataset['train']
    test_set = dataset['test']

    # function that encodes the sentences
    def encode_sentence(examples):
        return tokenizer('[CLS] ' + examples['premise'] + ' [SEP] ' +
                         examples['hypothesis'] + ' [SEP]',
                         truncation=True,
                         padding='max_length')

    # tokenize the datasets
    train_set = train_set.map(encode_sentence, batched=False)
    dev_set = dev_set.map(encode_sentence, batched=False)
    test_set = test_set.map(encode_sentence, batched=False)

    # remove unnecessary columns
    train_set = train_set.remove_columns([
        'promptID', 'pairID', 'premise', 'premise_binary_parse',
        'premise_parse', 'hypothesis', 'hypothesis_binary_parse',
        'hypothesis_parse', 'genre'
    ])
    dev_set = dev_set.remove_columns([
        'promptID', 'pairID', 'premise', 'premise_binary_parse',
        'premise_parse', 'hypothesis', 'hypothesis_binary_parse',
        'hypothesis_parse', 'genre'
    ])
    test_set = test_set.remove_columns([
        'promptID', 'pairID', 'premise', 'premise_binary_parse',
        'premise_parse', 'hypothesis', 'hypothesis_binary_parse',
        'hypothesis_parse', 'genre'
    ])

    # rename the labels
    train_set = train_set.rename_column("label", "labels")
    dev_set = dev_set.rename_column("label", "labels")
    test_set = test_set.rename_column("label", "labels")

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets
    return train_set, dev_set, test_set
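
The Hugging Face-style loaders (Examples #14, #17 and #22) pass the processed dataset to a `create_dataloader(args, dataset, tokenizer)` helper defined elsewhere. A plausible sketch, assuming `args.batch_size` exists and the listed column names match the tokenizer output:

from torch.utils.data import DataLoader

def create_dataloader(args, dataset, tokenizer):
    # column names are assumptions based on the BERT tokenizer output
    dataset.set_format(type='torch',
                       columns=['input_ids', 'token_type_ids',
                                'attention_mask', 'labels'])
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True)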
Example #18
def test_one_to_one(task_load, task_eval, model, score_dict):

    logger.info("start to test { task: %s (load) %s (eval), seq train type: %s }" % (task_load, task_eval, args.seq_train_type))

    test_qadata = QADataset(TASK_DICT[task_eval]["test"], "test", SPECIAL_TOKEN_IDS[task_load]).sort()
    max_a_len = test_qadata.max_a_len
    test_dataloader = create_dataloader(test_qadata, "test")
    n_examples = len(test_qadata)
    logger.info("len of test dataset: {}".format(n_examples))

    need_process = OrderedDict()
    qa_results = [0 for _ in range(n_examples)]
    all_pasts = [[0 for _ in range(n_examples)] for __ in range(MODEL_CONFIG.n_layer)]
    max_tot_lens = [0 for _ in range(n_examples)]

    cnt = 0
    for n_steps, (cqs, len_cqs, _, _, _, _, _) in enumerate(test_dataloader):
        # assume n_gpus == 1
        cqs = cqs[0]
        len_cqs = len_cqs[0]
        n_inputs = cqs.shape[0]
        all_outputs = model(input_ids=cqs.cuda())
        outputs = all_outputs[0]
        if args.model_name == "gpt2":
            pasts = all_outputs[1]
        next_logits = outputs[range(n_inputs), len_cqs-1, :] / args.temperature_qa
        next_tokens = logits_to_tokens(next_logits).cpu()

        for i in range(n_inputs):
            max_tot_lens[cnt] = max_a_len + test_qadata[cnt][1]
            qa_results[cnt] = cqs[i][:len_cqs[i]]
            if next_tokens[i] != SPECIAL_TOKEN_IDS["eos_token"]:
                qa_results[cnt] = torch.cat((cqs[i][:len_cqs[i]], next_tokens[i]))
                if len(qa_results[cnt]) not in [max_tot_lens[cnt], args.max_len]:
                    need_process.update([[cnt, None]])
                    if args.model_name == "gpt2":
                        for layer_id in range(MODEL_CONFIG.n_layer):
                            all_pasts[layer_id][cnt] = pasts[layer_id][:, i, ..., :len_cqs[i], :].type(torch.float32 if args.fp32 else torch.half)
            cnt += 1

        if len(need_process) > int(12 * args.memory_sizes[0] / cqs.shape[1]):  # dynamic threshold to avoid out of memory
            sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)
    sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

    if task_eval in ['wikisql', 'woz.en', 'multinli.in.out']:
        ids = test_qadata.get_indices()
        test_qadata.sort_by_index()
        qa_results = [x[1] for x in sorted([(i, g) for i, g in zip(ids, qa_results)])]
    for i in range(len(test_qadata)):
        _, len_cq, _, _, Y, _, _, _ = test_qadata[i]
        if task_eval in ['wikisql', 'woz.en']:
            Y = test_qadata.answers[i]
        else:
            Y = list(filter(lambda x: x != -1, Y))[:-1]  # remove eos
            Y = ' '.join([str(y) for y in Y]).split(str(SPECIAL_TOKEN_IDS["pad_token"]))
            Y = [TOKENIZER.decode(list(map(int, y.split()))) for y in Y]
        qa_results[i] = [TOKENIZER.decode(qa_results[i].tolist()[len_cq:]), Y]
    get_test_score(task_eval, qa_results, score_dict)

    model_dir = model.model_dir
    ep = model.ep
    results_path = os.path.join(model_dir, "qa_{}_{}.csv".format(task_eval, ep + 1))
    if not args.debug:
        with open(results_path, "w", encoding="utf-8") as f:
            qa_writer = csv.writer(f, delimiter=',')
            qa_writer.writerow(["y", "pred"])
            for pred, y in qa_results:
                if task_eval == 'wikisql':
                    y = y["answer"]
                elif task_eval == 'woz.en':
                    y = y[1]
                qa_writer.writerow([y, pred])

    return model, score_dict
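
Example #18 also relies on a `logits_to_tokens` helper defined elsewhere; since the logits are already divided by `args.temperature_qa`, the original likely samples from the softened distribution. A minimal greedy stand-in with the shape the caller expects (one token per input row):

import torch

def logits_to_tokens(next_logits):
    # greedy stand-in: pick the highest-scoring token per example,
    # keeping a trailing dimension so next_tokens[i] is a 1-element tensor
    return torch.argmax(next_logits, dim=-1, keepdim=True)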
Example #19
    weight_path, cfg_path, output_path, save_freq = args.weights, args.cfg, args.output, args.save_freq
    epochs, lr, batch_size, optimizer_cfg, input_size = args.epochs, args.lr, args.batch_size, \
                                                        args.optimizer, args.input_size
    train_proportion, valid_proportion, test_proportion = args.train_proportion, args.valid_proportion, args.test_proportion

    # load configs from config file
    cfg = parse_cfg(cfg_path)
    print('Config:', cfg)
    dataset_path, num_classes = cfg['dataset'], int(cfg['num_classes'])

    # load dataset
    train_loader, val_loader, test_loader = create_dataloader(
        'MY_DATASET',
        dataset_path,
        batch_size,
        input_size,
        num_per_class=200,
        train_proportion=train_proportion,
        valid_proportion=valid_proportion,
        test_proportion=test_proportion)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    model = build_model(weight_path, cfg).to(device)
    print('Model successfully loaded!')

    # plot model structure
    # from torchviz import make_dot
    # graph = make_dot(model(torch.rand(1, 3, input_size, input_size).cuda()),
    #                  params=dict(model.named_parameters()))
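
Example #19 reads its dataset path and class count through `parse_cfg`. A hypothetical parser for a flat `key = value` config file; it returns strings, which matches the `int(cfg['num_classes'])` cast above:

def parse_cfg(cfg_path):
    # hypothetical: parse "key = value" lines, skipping blanks and comments
    cfg = {}
    with open(cfg_path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            key, _, value = line.partition('=')
            cfg[key.strip()] = value.strip()
    return cfg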
Example #20
def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset('cblue', 'CMeIE', splits=['train', 'dev'])

    model = ElectraForSPO.from_pretrained('ernie-health-chinese',
                                          num_classes=len(train_ds.label_list))
    tokenizer = ElectraTokenizer.from_pretrained('ernie-health-chinese')

    trans_func = partial(convert_example_spo,
                         tokenizer=tokenizer,
                         num_classes=len(train_ds.label_list),
                         max_seq_length=args.max_seq_length)

    def batchify_fn(data):
        _batchify_fn = lambda samples, fn=Dict({
            'input_ids':
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
            'token_type_ids':
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
            'position_ids':
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
            'attention_mask':
            Pad(axis=0, pad_val=0, dtype='float32'),
        }): fn(samples)
        ent_label = [x['ent_label'] for x in data]
        spo_label = [x['spo_label'] for x in data]
        input_ids, token_type_ids, position_ids, masks = _batchify_fn(data)
        batch_size, batch_len = input_ids.shape
        num_classes = len(train_ds.label_list)
        # Create one-hot labels.
        #
        # For example,
        # - text:
        #   [CLS], 局, 部, 皮, 肤, 感, 染, 引, 起, 的, 皮, 疹, 等, [SEP]
        #
        # - ent_label (obj: `list`):
        #   [(0, 5), (9, 10)] # ['局部皮肤感染', '皮疹']
        #
        # - one_hot_ent_label: # shape (sequence_length, 2)
        #   [[ 0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0], # start index
        #    [ 0,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  1,  0,  0]] # end index
        #
        # - spo_label (obj: `list`):
        #   [(0, 23, 9)] # [('局部皮肤感染', '相关(导致)', '皮疹')], where entities
        #                  are encoded by their start indexes.
        #
        # - one_hot_spo_label: # shape (num_predicate, sequence_length, sequence_length)
        #   [...,
        #    [..., [0, ..., 1, ..., 0], ...], # for predicate '相关(导致)'
        #    ...]                             # the value at [23, 1, 10] is set as 1
        #
        one_hot_ent_label = np.zeros([batch_size, batch_len, 2],
                                     dtype=np.float32)
        one_hot_spo_label = np.zeros(
            [batch_size, num_classes, batch_len, batch_len], dtype=np.float32)
        for idx, ent_idxs in enumerate(ent_label):
            # Shift index by 1 because input_ids start with [CLS] here.
            for x, y in ent_idxs:
                x = x + 1
                y = y + 1
                if x > 0 and x < batch_len and y < batch_len:
                    one_hot_ent_label[idx, x, 0] = 1
                    one_hot_ent_label[idx, y, 1] = 1
        for idx, spo_idxs in enumerate(spo_label):
            for s, p, o in spo_idxs:
                s_id = s[0] + 1
                o_id = o[0] + 1
                if s_id > 0 and s_id < batch_len and o_id < batch_len:
                    one_hot_spo_label[idx, p, s_id, o_id] = 1
        # one_hot_xxx_label are used for loss computation.
        # xxx_label are used for metric computation.
        ent_label = [one_hot_ent_label, ent_label]
        spo_label = [one_hot_spo_label, spo_label]
        return input_ids, token_type_ids, position_ids, masks, ent_label, spo_label

    train_data_loader = create_dataloader(train_ds,
                                          mode='train',
                                          batch_size=args.batch_size,
                                          batchify_fn=batchify_fn,
                                          trans_fn=trans_func)

    dev_data_loader = create_dataloader(dev_ds,
                                        mode='dev',
                                        batch_size=args.batch_size,
                                        batchify_fn=batchify_fn,
                                        trans_fn=trans_func)

    if args.init_from_ckpt:
        if not os.path.isfile(args.init_from_ckpt):
            raise ValueError('init_from_ckpt is not a valid model filename.')
        state_dict = paddle.load(args.init_from_ckpt)
        state_keys = {
            x: x.replace('discriminator.', '')
            for x in state_dict.keys() if 'discriminator.' in x
        }
        if len(state_keys) > 0:
            state_dict = {
                state_keys[k]: state_dict[k]
                for k in state_keys.keys()
            }
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.epochs
    args.epochs = (num_training_steps - 1) // len(train_data_loader) + 1

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate,
                                         num_training_steps,
                                         args.warmup_proportion)
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ['bias', 'norm'])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = F.binary_cross_entropy_with_logits

    metric = SPOChunkEvaluator(num_classes=len(train_ds.label_list))

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, masks, ent_label, spo_label = batch
            max_batch_len = input_ids.shape[-1]
            ent_mask = paddle.unsqueeze(masks, axis=2)
            spo_mask = paddle.matmul(ent_mask, ent_mask, transpose_y=True)
            spo_mask = paddle.unsqueeze(spo_mask, axis=1)

            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=['layer_norm', 'softmax', 'gelu'],
            ):
                logits = model(input_ids, token_type_ids, position_ids)
                ent_loss = criterion(logits[0],
                                     ent_label[0],
                                     weight=ent_mask,
                                     reduction='sum')
                spo_loss = criterion(logits[1],
                                     spo_label[0],
                                     weight=spo_mask,
                                     reduction='sum')

                loss = ent_loss + spo_loss

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print('global step %d, epoch: %d, batch: %d, loss: %.5f, '
                      'ent_loss: %.5f, spo_loss: %.5f, speed: %.2f steps/s' %
                      (global_step, epoch, step, loss, ent_loss, spo_loss,
                       args.logging_steps / time_diff))

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metric, dev_data_loader)

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        'model_%d' % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)

            if global_step >= num_training_steps:
                return
            tic_train = time.time()

    if rank == 0 and total_train_time > 0:
        print('Speed: %.2f steps/s' % (global_step / total_train_time))
Example #21
def do_train(args):
    set_seed(args)
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    pinyin_vocab = Vocab.load_vocabulary(
        args.pinyin_vocab_file_path, unk_token='[UNK]', pad_token='[PAD]')

    tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    ernie = ErnieModel.from_pretrained(args.model_name_or_path)

    model = ErnieForCSC(
        ernie,
        pinyin_vocab_size=len(pinyin_vocab),
        pad_pinyin_id=pinyin_vocab[pinyin_vocab.pad_token])

    train_ds, eval_ds = load_dataset('sighan-cn', splits=['train', 'dev'])

    # Extend the current training dataset with extra training datasets
    # from a directory. Extra dataset files must use the ".txt" suffix.
    # Each line must contain a tab-separated pair of sentences
    # (wrong sentence, corrected sentence), such as:
    # "城府宫员表示,这是过去三十六小时内第三期强烈的余震。\t政府官员表示,这是过去三十六小时内第三起强烈的余震。\n"
    if args.extra_train_ds_dir is not None and os.path.exists(
            args.extra_train_ds_dir):
        data = train_ds.data
        data_files = [
            os.path.join(args.extra_train_ds_dir, data_file)
            for data_file in os.listdir(args.extra_train_ds_dir)
            if data_file.endswith(".txt")
        ]
        for data_file in data_files:
            ds = load_dataset(
                read_train_ds,
                data_path=data_file,
                splits=["train"],
                lazy=False)
            data += ds.data
        train_ds = MapDataset(data)

    det_loss_act = paddle.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label, use_softmax=False)
    corr_loss_act = paddle.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label, reduction='none')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        pinyin_vocab=pinyin_vocab,
        max_seq_length=args.max_seq_length)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Pad(axis=0, pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token]),  # pinyin
        Pad(axis=0, dtype="int64"),  # detection label
        Pad(axis=0, dtype="int64"),  # correction label
        Stack(axis=0, dtype="int64")  # length
    ): [data for data in fn(samples)]

    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    eval_data_loader = create_dataloader(
        eval_ds,
        mode='eval',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_proportion)

    logger.info("Total training step: {}".format(num_training_steps))
    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_steps = 1
    best_f1 = -1
    tic_train = time.time()
    for epoch in range(args.epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, pinyin_ids, det_labels, corr_labels, length = batch
            det_error_probs, corr_logits = model(input_ids, pinyin_ids,
                                                 token_type_ids)
            # Chinese Spelling Correction involves two tasks: detection and correction.
            # The detection task decides whether each Chinese character contains a spelling error,
            # while the correction task replaces each potentially wrong character with the right one.
            # So we need to minimize the detection loss and the correction loss simultaneously.
            # See the loss design details at https://aclanthology.org/2021.findings-acl.198.pdf
            det_loss = det_loss_act(det_error_probs, det_labels)
            corr_loss = corr_loss_act(
                corr_logits, corr_labels) * det_error_probs.max(axis=-1)
            loss = (det_loss + corr_loss).mean()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_steps % args.logging_steps == 0:
                logger.info(
                    "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                    % (global_steps, epoch, step, loss,
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
            if global_steps % args.save_steps == 0:
                if paddle.distributed.get_rank() == 0:
                    logger.info("Eval:")
                    det_f1, corr_f1 = evaluate(model, eval_data_loader)
                    f1 = (det_f1 + corr_f1) / 2
                    model_file = "model_%d" % global_steps
                    if f1 > best_f1:
                        # save best model
                        paddle.save(model.state_dict(),
                                    os.path.join(args.output_dir,
                                                 "best_model.pdparams"))
                        logger.info("Save best model at {} step.".format(
                            global_steps))
                        best_f1 = f1
                        model_file = model_file + "_best"
                    model_file = model_file + ".pdparams"
                    paddle.save(model.state_dict(),
                                os.path.join(args.output_dir, model_file))
                    logger.info("Save model at {} step.".format(global_steps))
            if args.max_steps > 0 and global_steps >= args.max_steps:
                return
            global_steps += 1
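
The extra-dataset loop in Example #21 calls a `read_train_ds` reader that is not shown. A sketch consistent with the tab-separated format documented in the comment above (the field names are assumptions):

def read_train_ds(data_path):
    # yield one wrong/corrected sentence pair per line
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            source, target = line.rstrip('\n').split('\t')[:2]
            yield {'source': source, 'target': target}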
Example #22
def LoadIQAP(args, tokenizer):
    """
    Function that loads the Indirect Question-Answer Pairs dataset.
    Inputs:
        args - Namespace object from the argument parser
        tokenizer - BERT tokenizer instance
    Outputs:
        train_set - Training dataset
        dev_set - Development dataset
        test_set - Test dataset
    """

    # load the IQAP dataset from a local csv file
    dataset = load_dataset(
        'csv', data_files='data/local_datasets/iqap/iqap-data.csv')['train']

    # function to convert majority votes to labels
    def assign_label(example):
        label = np.argmax([
            example['definite-yes'], example['probable-yes'],
            example['definite-no'], example['probable-no']
        ])
        example['labels'] = label
        return example

    # assign labels to the dataset
    dataset = dataset.map(assign_label, batched=False)

    # divide into train, dev and test
    train_set = dataset.filter(
        lambda example: example['DevEval'] == 'DEVELOPMENT')
    dataset = dataset.filter(
        lambda example: example['DevEval'] == 'EVALUATION')
    dataset = dataset.train_test_split(test_size=0.5,
                                       train_size=0.5,
                                       shuffle=True)
    dev_set = dataset['train']
    test_set = dataset['test']

    # function that encodes the question and passage
    def encode_sentence(examples):
        return tokenizer('[CLS] ' + examples['Question'] + ' [SEP] ' +
                         examples['Answer'] + ' [SEP]',
                         truncation=True,
                         padding='max_length')

    # tokenize the datasets
    train_set = train_set.map(encode_sentence, batched=False)
    dev_set = dev_set.map(encode_sentence, batched=False)
    test_set = test_set.map(encode_sentence, batched=False)

    # remove unnecessary columns
    train_set = train_set.remove_columns([
        'Answer', 'AnswerParse', 'Classification', 'DevEval', 'Item', 'Prefix',
        'Question', 'QuestionParse', 'Source', 'definite-no', 'definite-yes',
        'probable-no', 'probable-yes'
    ])
    dev_set = dev_set.remove_columns([
        'Answer', 'AnswerParse', 'Classification', 'DevEval', 'Item', 'Prefix',
        'Question', 'QuestionParse', 'Source', 'definite-no', 'definite-yes',
        'probable-no', 'probable-yes'
    ])
    test_set = test_set.remove_columns([
        'Answer', 'AnswerParse', 'Classification', 'DevEval', 'Item', 'Prefix',
        'Question', 'QuestionParse', 'Source', 'definite-no', 'definite-yes',
        'probable-no', 'probable-yes'
    ])

    # create dataloaders for the datasets
    train_set = create_dataloader(args, train_set, tokenizer)
    dev_set = create_dataloader(args, dev_set, tokenizer)
    test_set = create_dataloader(args, test_set, tokenizer)

    # return the datasets
    return train_set, dev_set, test_set
Example #23
def main():
    # ==========
    # parameters
    # ==========

    opts_dict = receive_arg()
    unit = opts_dict['test']['criterion']['unit']

    # ==========
    # open logger
    # ==========

    log_fp = open(opts_dict['train']['log_path'], 'w')
    msg = (
        f"{'<' * 10} Test {'>' * 10}\n"
        f"Timestamp: [{utils.get_timestr()}]\n"
        f"\n{'<' * 10} Options {'>' * 10}\n"
        f"{utils.dict2str(opts_dict['test'])}"
        )
    print(msg)
    log_fp.write(msg + '\n')
    log_fp.flush()

    # ==========
    # ensure reproducibility or speed up
    # ==========

    #torch.backends.cudnn.benchmark = False  # if reproduce
    #torch.backends.cudnn.deterministic = True  # if reproduce
    torch.backends.cudnn.benchmark = True  # speed up

    # ==========
    # create test data prefetchers
    # ==========
    
    # create datasets
    test_ds_type = opts_dict['dataset']['test']['type']
    radius = opts_dict['network']['radius']
    assert test_ds_type in dataset.__all__, \
        "Not implemented!"
    test_ds_cls = getattr(dataset, test_ds_type)
    test_ds = test_ds_cls(
        opts_dict=opts_dict['dataset']['test'], 
        radius=radius
        )

    test_num = len(test_ds)
    test_vid_num = test_ds.get_vid_num()

    # create datasamplers
    test_sampler = None  # no need to sample test data

    # create dataloaders
    test_loader = utils.create_dataloader(
        dataset=test_ds, 
        opts_dict=opts_dict, 
        sampler=test_sampler, 
        phase='val'
        )
    assert test_loader is not None

    # create dataloader prefetchers
    test_prefetcher = utils.CPUPrefetcher(test_loader)

    # ==========
    # create & load model
    # ==========

    model = MFVQE(opts_dict=opts_dict['network'])

    checkpoint_save_path = opts_dict['test']['checkpoint_save_path']
    msg = f'loading model {checkpoint_save_path}...'
    print(msg)
    log_fp.write(msg + '\n')

    checkpoint = torch.load(checkpoint_save_path)
    if 'module.' in list(checkpoint['state_dict'].keys())[0]:  # multi-gpu training
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = k[7:]  # remove module
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:  # single-gpu training
        model.load_state_dict(checkpoint['state_dict'])
    
    msg = f'> model {checkpoint_save_path} loaded.'
    print(msg)
    log_fp.write(msg + '\n')

    model = model.cuda()
    model.eval()

    # ==========
    # define criterion
    # ==========

    # define criterion
    assert opts_dict['test']['criterion'].pop('type') == \
        'PSNR', "Not implemented."
    criterion = utils.PSNR()

    # ==========
    # validation
    # ==========
                
    # create timer
    total_timer = utils.Timer()

    # create counters
    per_aver_dict = dict()
    ori_aver_dict = dict()
    name_vid_dict = dict()
    for index_vid in range(test_vid_num):
        per_aver_dict[index_vid] = utils.Counter()
        ori_aver_dict[index_vid] = utils.Counter()
        name_vid_dict[index_vid] = ""

    pbar = tqdm(
        total=test_num, 
        ncols=opts_dict['test']['pbar_len']
        )

    # fetch the first batch
    test_prefetcher.reset()
    val_data = test_prefetcher.next()

    with torch.no_grad():
        while val_data is not None:
            # get data
            gt_data = val_data['gt'].cuda()  # (B [RGB] H W)
            lq_data = val_data['lq'].cuda()  # (B T [RGB] H W)
            index_vid = val_data['index_vid'].item()
            name_vid = val_data['name_vid'][0]  # bs must be 1!
            
            b, _, c, _, _ = lq_data.shape
            assert b == 1, "Not supported!"
            
            input_data = torch.cat(
                [lq_data[:,:,i,...] for i in range(c)], 
                dim=1
                )  # B [R1 ... R7 G1 ... G7 B1 ... B7] H W
            enhanced_data = model(input_data)  # (B [RGB] H W)

            # eval
            batch_ori = criterion(lq_data[0, radius, ...], gt_data[0])
            batch_perf = criterion(enhanced_data[0], gt_data[0])

            # display
            pbar.set_description(
                "{:s}: [{:.3f}] {:s} -> [{:.3f}] {:s}"
                .format(name_vid, batch_ori, unit, batch_perf, unit)
                )
            pbar.update()

            # log
            per_aver_dict[index_vid].accum(volume=batch_perf)
            ori_aver_dict[index_vid].accum(volume=batch_ori)
            if name_vid_dict[index_vid] == "":
                name_vid_dict[index_vid] = name_vid
            else:
                assert name_vid_dict[index_vid] == name_vid, "Something wrong."

            # fetch next batch
            val_data = test_prefetcher.next()
        
    # end of val
    pbar.close()

    # log
    msg = '\n' + '<' * 10 + ' Results ' + '>' * 10
    print(msg)
    log_fp.write(msg + '\n')
    for index_vid in range(test_vid_num):
        per = per_aver_dict[index_vid].get_ave()
        ori = ori_aver_dict[index_vid].get_ave()
        name_vid = name_vid_dict[index_vid]
        msg = "{:s}: [{:.3f}] {:s} -> [{:.3f}] {:s}".format(
            name_vid, ori, unit, per, unit
            )
        print(msg)
        log_fp.write(msg + '\n')
    ave_per = np.mean([
        per_aver_dict[index_vid].get_ave() for index_vid in range(test_vid_num)
        ])
    ave_ori = np.mean([
        ori_aver_dict[index_vid].get_ave() for index_vid in range(test_vid_num)
        ])
    msg = (
        f"> ori: [{ave_ori:.3f}] {unit}\n"
        f"> ave: [{ave_per:.3f}] {unit}\n"
        f"> delta: [{ave_per - ave_ori:.3f}] {unit}"
        )
    print(msg)
    log_fp.write(msg + '\n')
    log_fp.flush()

    # ==========
    # final log & close logger
    # ==========

    total_time = total_timer.get_interval() / 3600
    msg = "TOTAL TIME: [{:.1f}] h".format(total_time)
    print(msg)
    log_fp.write(msg + '\n')
    
    msg = (
        f"\n{'<' * 10} Goodbye {'>' * 10}\n"
        f"Timestamp: [{utils.get_timestr()}]"
        )
    print(msg)
    log_fp.write(msg + '\n')
    
    log_fp.close()
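
Example #23 asserts that the criterion is PSNR and instantiates `utils.PSNR()`. A minimal sketch, assuming the tensors are scaled to [0, 1]; returning a plain float fits the `{:.3f}` formatting used above:

import torch

class PSNR:
    # peak signal-to-noise ratio for tensors scaled to [0, 1]
    def __call__(self, pred, gt):
        mse = torch.mean((pred - gt) ** 2)
        return (10.0 * torch.log10(1.0 / mse)).item()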
Example #24
def test_read_dataset_time():
    train_iter, _ = utils.create_dataloader()
    start = time.time()
    for X, y in train_iter:
        continue
    print('%.2f sec.' % (time.time() - start))
Example #25
def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    encoding_model = MODEL_MAP[args.model]['encoding_model']
    resource_file_urls = MODEL_MAP[args.model]['resource_file_urls']

    for key, val in resource_file_urls.items():
        file_path = os.path.join(args.model, key)
        if not os.path.exists(file_path):
            get_path_from_url(val, args.model)

    tokenizer = AutoTokenizer.from_pretrained(encoding_model)
    model = UIE.from_pretrained(args.model)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    train_ds = load_dataset(reader,
                            data_path=args.train_path,
                            max_seq_len=args.max_seq_len,
                            lazy=False)
    dev_ds = load_dataset(reader,
                          data_path=args.dev_path,
                          max_seq_len=args.max_seq_len,
                          lazy=False)

    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_len=args.max_seq_len)

    train_data_loader = create_dataloader(dataset=train_ds,
                                          mode='train',
                                          batch_size=args.batch_size,
                                          trans_fn=trans_func)

    dev_data_loader = create_dataloader(dataset=dev_ds,
                                        mode='dev',
                                        batch_size=args.batch_size,
                                        trans_fn=trans_func)

    optimizer = paddle.optimizer.AdamW(learning_rate=args.learning_rate,
                                       parameters=model.parameters())

    criterion = paddle.nn.BCELoss()
    metric = SpanEvaluator()

    loss_list = []
    global_step = 0
    best_step = 0
    best_f1 = 0
    tic_train = time.time()
    for epoch in range(1, args.num_epochs + 1):
        for batch in train_data_loader:
            input_ids, token_type_ids, att_mask, pos_ids, start_ids, end_ids = batch
            start_prob, end_prob = model(input_ids, token_type_ids, att_mask,
                                         pos_ids)
            start_ids = paddle.cast(start_ids, 'float32')
            end_ids = paddle.cast(end_ids, 'float32')
            loss_start = criterion(start_prob, start_ids)
            loss_end = criterion(end_prob, end_ids)
            loss = (loss_start + loss_end) / 2.0
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            loss_list.append(float(loss))

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                loss_avg = sum(loss_list) / len(loss_list)
                print(
                    "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, loss_avg,
                       args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        "model_%d" % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                model.save_pretrained(save_dir)

                precision, recall, f1 = evaluate(model, metric,
                                                 dev_data_loader)
                print("Evaluation precision: %.5f, recall: %.5f, F1: %.5f" %
                      (precision, recall, f1))
                if f1 > best_f1:
                    print(
                        f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
                    )
                    best_f1 = f1
                    save_dir = os.path.join(args.save_dir, "model_best")
                    model.save_pretrained(save_dir)
                tic_train = time.time()
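
Example #25 delegates validation to an `evaluate(model, metric, dev_data_loader)` helper. A sketch of the implied span-level loop, using `SpanEvaluator`'s compute/update/accumulate cycle and assuming the batch layout matches the training loop above:

import paddle

@paddle.no_grad()
def evaluate(model, metric, data_loader):
    model.eval()
    metric.reset()
    for batch in data_loader:
        input_ids, token_type_ids, att_mask, pos_ids, start_ids, end_ids = batch
        start_prob, end_prob = model(input_ids, token_type_ids, att_mask,
                                     pos_ids)
        num_correct, num_infer, num_label = metric.compute(
            start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    model.train()
    return precision, recall, f1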
Example #26
                     max_seq_length=args.max_seq_length)

# Batch the samples: pad text sequences of different lengths up to the
# longest sequence in the batch, and stack the labels together
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
    # Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type_ids
    Pad(axis=0, pad_val=0),  # pinyin_ids
    Stack()  # labels
): [data for data in fn(samples)]

from utils import create_dataloader

train_data_loader = create_dataloader(train_ds,
                                      mode='train',
                                      batch_size=args.batch_size,
                                      batchify_fn=batchify_fn,
                                      trans_fn=trans_func)
dev_data_loader = create_dataloader(dev_ds,
                                    mode='dev',
                                    batch_size=args.batch_size,
                                    batchify_fn=batchify_fn,
                                    trans_fn=trans_func)
test_data_loader = create_dataloader(test_ds,
                                     mode='test',
                                     batch_size=args.batch_size,
                                     batchify_fn=batchify_fn,
                                     trans_fn=trans_func)

from utils import evaluate
Example #27
def do_predict(args):
    paddle.set_device(args.device)

    pinyin_vocab = Vocab.load_vocabulary(args.pinyin_vocab_file_path,
                                         unk_token='[UNK]',
                                         pad_token='[PAD]')

    tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    ernie = ErnieModel.from_pretrained(args.model_name_or_path)

    model = ErnieForCSC(ernie,
                        pinyin_vocab_size=len(pinyin_vocab),
                        pad_pinyin_id=pinyin_vocab[pinyin_vocab.pad_token])

    eval_ds = load_dataset(read_test_ds, data_path=args.test_file, lazy=False)
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         pinyin_vocab=pinyin_vocab,
                         max_seq_length=args.max_seq_length,
                         is_test=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'
            ),  # segment
        Pad(axis=0,
            pad_val=pinyin_vocab.token_to_idx[pinyin_vocab.pad_token],
            dtype='int64'),  # pinyin
        Stack(axis=0, dtype='int64'),  # length
    ): [data for data in fn(samples)]

    test_data_loader = create_dataloader(eval_ds,
                                         mode='test',
                                         batch_size=args.batch_size,
                                         batchify_fn=batchify_fn,
                                         trans_fn=trans_func)

    if args.ckpt_path:
        model_dict = paddle.load(args.ckpt_path)
        model.set_dict(model_dict)
        logger.info("Load model from checkpoints: {}".format(args.ckpt_path))

    model.eval()
    corr_preds = []
    det_preds = []
    lengths = []
    for step, batch in enumerate(test_data_loader):
        input_ids, token_type_ids, pinyin_ids, length = batch
        det_error_probs, corr_logits = model(input_ids, pinyin_ids,
                                             token_type_ids)
        # corr_logits shape: [B, T, V]
        det_pred = det_error_probs.argmax(axis=-1)
        det_pred = det_pred.numpy()

        char_preds = corr_logits.argmax(axis=-1)
        char_preds = char_preds.numpy()

        length = length.numpy()

        corr_preds += [pred for pred in char_preds]
        det_preds += [prob for prob in det_pred]
        lengths += [l for l in length]

    write_sighan_result_to_file(args, corr_preds, det_preds, lengths,
                                tokenizer)