Example #1
0
def main():
    # define parser and arguments
    args = get_train_test_args()
    util.set_seed(args.seed)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    DistilBert = DistilBertModel.from_pretrained('distilbert-base-uncased')
    Experts = [DistilBertQA(DistilBertModel.from_pretrained('distilbert-base-uncased')).to(device) for _ in range(args.num_experts)]
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    gate_model = GateNetwork(384, 3, 3, DistilBert.config).to(device)
    print(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
    if args.do_train:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
        log = util.get_logger(args.save_dir, 'log_train')
        log.info(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
        log.info("Preparing Training Data...")
        args.device = device
        trainer = train.Trainer(args, log)
        train_dataset, _ = get_dataset(args, args.train_datasets, args.train_dir, tokenizer, 'train')
        log.info("Preparing Validation Data...")
        val_dataset, val_dict = get_dataset(args, args.train_datasets, args.val_dir, tokenizer, 'val')
        train_loader = DataLoader(train_dataset,
                                batch_size=args.batch_size,
                                sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset,
                                batch_size=1,
                                sampler=SequentialSampler(val_dataset))
        best_scores = trainer.train(Experts, gate_model, train_loader, val_loader, val_dict, args.num_experts)
    if args.do_eval:
        split_name = 'test' if 'test' in args.eval_dir else 'validation'
        log = util.get_logger(args.save_dir, f'log_{split_name}')
        trainer = train.Trainer(args, log)
        # load model
        restore_model("", args.num_experts, Experts, gate_model)
        eval_dataset, eval_dict = get_dataset(args, args.eval_datasets, args.eval_dir, tokenizer, split_name)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=1,
                                 sampler=SequentialSampler(eval_dataset))
        args.device = device
        eval_preds, eval_scores = trainer.evaluate(Experts, gate_model, eval_loader,
                                                   eval_dict, return_preds=True,
                                                   split=split_name)
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in eval_scores.items())
        log.info(f'Eval {results_str}')
        # Write submission file
        sub_path = os.path.join(args.save_dir, split_name + '_' + args.sub_file)
        log.info(f'Writing submission file to {sub_path}...')
        with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.writer(csv_fh, delimiter=',')
            csv_writer.writerow(['Id', 'Predicted'])
            for uuid in sorted(eval_preds):
                csv_writer.writerow([uuid, eval_preds[uuid]])
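Note: restore_model is called in the evaluation branch above but not defined in the snippet. A minimal sketch of how the experts and gate network could be restored, assuming each was saved as a separate state_dict file (the directory layout and file names here are hypothetical):

import os
import torch

def restore_model(checkpoint_dir, num_experts, experts, gate_model):
    # Hypothetical layout: expert_<i>.pt and gate.pt written with torch.save(model.state_dict(), ...).
    for i in range(num_experts):
        state = torch.load(os.path.join(checkpoint_dir, 'expert_{}.pt'.format(i)), map_location='cpu')
        experts[i].load_state_dict(state)
    gate_model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'gate.pt'), map_location='cpu'))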
Example #2
0
def main():
    global args, best_top1
    args = parse()
    if not args.no_logger:
        tee.Tee(args.cache + '/log.txt')
    print(vars(args))
    seed(args.manual_seed)

    model, criterion, optimizer = create_model(args)
    if args.resume:
        best_top1 = checkpoints.load(args, model, optimizer)
    print(model)
    trainer = train.Trainer()
    loaders = get_dataset(args)
    train_loader = loaders[0]

    if args.evaluate:
        scores = validate(trainer, loaders, model, criterion, args)
        checkpoints.score_file(scores, "{}/model_000.txt".format(args.cache))
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            trainer.train_sampler.set_epoch(epoch)
        scores = {}
        scores.update(trainer.train(train_loader, model, criterion, optimizer, epoch, args))
        scores.update(validate(trainer, loaders, model, criterion, args, epoch))

        is_best = scores[args.metric] > best_top1
        best_top1 = max(scores[args.metric], best_top1)
        checkpoints.save(epoch, args, model, optimizer, is_best, scores, args.metric)
    if not args.nopdb:
        pdb.set_trace()
Example #3
0
def main():
    global opt, best_mAP
    opt = parse()
    tee.Tee(opt.cache + '/log.txt')
    print(vars(opt))
    seed(opt.manual_seed)

    model, criterion, optimizer = create_model(opt)
    if opt.resume:
        best_mAP = checkpoints.load(opt, model, optimizer)
    print(model)
    trainer = train.Trainer()
    train_loader, val_loader, valvideo_loader = get_dataset(opt)

    if opt.evaluate:
        #trainer.validate(val_loader, model, criterion, -1, opt)
        trainer.validate_video(valvideo_loader, model, -1, opt)
        return

    for epoch in range(opt.start_epoch, opt.epochs):
        if opt.distributed:
            trainer.train_sampler.set_epoch(epoch)
        top1, top5 = trainer.train(train_loader, model, criterion, optimizer, epoch, opt)
        top1val, top5val = trainer.validate(val_loader, model, criterion, epoch, opt)
        mAP = trainer.validate_video(valvideo_loader, model, epoch, opt)
        is_best = mAP > best_mAP
        best_mAP = max(mAP, best_mAP)
        scores = {'top1train': top1, 'top5train': top5, 'top1val': top1val, 'top5val': top5val, 'mAP': mAP}
        checkpoints.save(epoch, opt, model, optimizer, is_best, scores)
Example #4
0
def simpletest1():
    # test if the code can learn a simple sequence
    opt = parse()
    opts(opt)
    epochs = 100
    train_loader, val_loader, valvideo_loader = get_dataset(opt)
    trainer = train.Trainer()
    model = AsyncTFBase(100, 5, opt.nhidden).cuda()
    criterion = AsyncTFCriterion(opt).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
    epoch = -1
    for i in range(epochs):
        top1, _ = trainer.train(train_loader, model, criterion, optimizer, i,
                                opt)
        print('cls weights: {}, aa weights: {}'.format(
            next(model.mA.parameters()).norm().item(),
            next(model.mAAa.parameters()).norm().item()))
    top1, _ = trainer.validate(train_loader, model, criterion, epochs, opt)

    for i in range(5):
        top1val, _ = trainer.validate(val_loader, model, criterion, epochs + i,
                                      opt)
        print('top1val: {}'.format(top1val))

    ap = trainer.validate_video(valvideo_loader, model, criterion, epoch, opt)
    return top1, top1val, ap
Example #5
0
def main(master_opt, launcher):
    trainers = []
    all_networks = {}
    shared_networks = []
    if launcher != 'none':
        train.init_dist('nccl')
    for i, sub_opt in enumerate(master_opt['trainer_options']):
        sub_opt_parsed = option.parse(sub_opt, is_train=True)
        trainer = train.Trainer()

        #### distributed training settings
        if launcher == 'none':  # disabled distributed training
            sub_opt_parsed['dist'] = False
            trainer.rank = -1
            print('Disabled distributed training.')
        else:
            sub_opt_parsed['dist'] = True
            trainer.world_size = torch.distributed.get_world_size()
            trainer.rank = torch.distributed.get_rank()

        trainer.init(sub_opt_parsed, launcher, all_networks)
        train_gen = trainer.create_training_generator(i)
        model = next(train_gen)
        for k, v in model.networks.items():
            if k in all_networks.keys() and k not in shared_networks:
                shared_networks.append(k)
            all_networks[k] = v.module
        trainers.append(train_gen)
    print("Networks being shared by trainers: ", shared_networks)

    # Now, simply "iterate" through the trainers to accomplish training.
    while True:
        for trainer in trainers:
            next(trainer)
Example #6
0
def main():
    opt = parser.parse_args()
    print(torch.__version__)
    print(opt)

    trainer = train.Trainer(name=opt.name,
                            models_root=opt.models_path,
                            rewrite=False,
                            connect_tb=False)
    trainer.load_best()
    trainer.model = trainer.model.module.cpu()
    trainer.model = trainer.model.train(False)
    trainer.state.cuda = False

    x = torch.randn(1,
                    1,
                    opt.input_size[0],
                    opt.input_size[1],
                    opt.input_size[2],
                    requires_grad=True)

    torch_out = torch.onnx.export(
        trainer.model,  # model being run
        [
            x,
        ],  # model input (or a tuple for multiple inputs)
        os.path.join(
            opt.models_path, opt.name, opt.name + "_export.onnx"
        ),  # where to save the model (can be a file or file-like object)
        export_params=True,
        verbose=True
    )  # store the trained parameter weights inside the model file
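Note: a quick sanity check on the exported graph is to run it back through onnxruntime and compare against the PyTorch output. A sketch assuming onnxruntime and numpy are installed; it reuses x, opt, and trainer from the example above:

import os
import numpy as np
import onnxruntime as ort

export_path = os.path.join(opt.models_path, opt.name, opt.name + "_export.onnx")
session = ort.InferenceSession(export_path)
input_name = session.get_inputs()[0].name
onnx_out = session.run(None, {input_name: x.detach().numpy()})[0]
ref_out = trainer.model(x).detach().numpy()
print('max abs diff vs. PyTorch:', np.abs(onnx_out - ref_out).max())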
Example #7
0
def main(task='mrpc',
         train_cfg='config/train_mrpc.json',
         model_cfg='config/bert_base.json',
         data_file='../glue/MRPC/train.tsv',
         model_file=None,
         pretrain_file='../uncased_L-12_H-768_A-12/bert_model.ckpt',
         data_parallel=True,
         vocab='../uncased_L-12_H-768_A-12/vocab.txt',
         save_dir='../exp/bert/mrpc',
         max_len=128,
         mode='train'):

    cfg = train.Config.from_json(train_cfg)
    model_cfg = models.Config.from_json(model_cfg)

    set_seeds(cfg.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab,
                                           do_lower_case=True)
    TaskDataset = dataset_class(
        task)  # task dataset class according to the task
    pipeline = [
        Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize),
        AddSpecialTokensWithTruncation(max_len),
        TokenIndexing(tokenizer.convert_tokens_to_ids, TaskDataset.labels,
                      max_len)
    ]
    dataset = TaskDataset(data_file, pipeline)
    data_iter = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

    model = Classifier(model_cfg, len(TaskDataset.labels))
    criterion = nn.CrossEntropyLoss()

    trainer = train.Trainer(cfg, model, data_iter, optim.optim4GPU(cfg, model),
                            save_dir, get_device())

    if mode == 'train':

        def get_loss(model, batch,
                     global_step):  # make sure loss is a scalar tensor
            input_ids, segment_ids, input_mask, label_id = batch
            logits = model(input_ids, segment_ids, input_mask)
            loss = criterion(logits, label_id)
            return loss

        trainer.train(get_loss, model_file, pretrain_file, data_parallel)

    elif mode == 'eval':

        def evaluate(model, batch):
            input_ids, segment_ids, input_mask, label_id = batch
            logits = model(input_ids, segment_ids, input_mask)
            _, label_pred = logits.max(1)
            result = (label_pred == label_id).float()  #.cpu().numpy()
            accuracy = result.mean()
            return accuracy, result

        results = trainer.eval(evaluate, model_file, data_parallel)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
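Note: get_device() is used by several of these examples without being shown. A sketch of the usual helper (an assumption, not necessarily the exact implementation in these repositories):

import torch

def get_device():
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')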
Example #8
0
def main():
    best_score = 0
    args = parse()
    if not args.no_logger:
        tee.Tee(args.cache + '/log.txt')
    print(vars(args))
    print('experiment folder: {}'.format(experiment_folder()))
    print('git hash: {}'.format(get_script_dir_commit_hash()))
    seed(args.manual_seed)
    cudnn.benchmark = not args.disable_cudnn_benchmark
    cudnn.enabled = not args.disable_cudnn

    metrics = get_metrics(args.metrics)
    tasks = get_tasks(args.tasks)
    model, criterion = get_model(args)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)
    else:
        assert False, "invalid optimizer"

    if args.resume:
        best_score = checkpoints.load(args, model, optimizer)
    print(model)
    trainer = train.Trainer()
    train_loader, val_loader = get_dataset(args)

    if args.evaluate:
        scores = validate(trainer, val_loader, model, criterion, args, metrics,
                          tasks, -1)
        print(scores)
        score_file(scores, "{}/model_999.txt".format(args.cache))
        return

    if args.warmups > 0:
        for i in range(args.warmups):
            print('warmup {}'.format(i))
            trainer.validate(train_loader, model, criterion, -1, metrics, args)
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            trainer.train_sampler.set_epoch(epoch)
        scores = {}
        scores.update(
            trainer.train(train_loader, model, criterion, optimizer, epoch,
                          metrics, args))
        scores.update(
            validate(trainer, val_loader, model, criterion, args, metrics,
                     tasks, epoch))
        is_best = scores[args.metric] > best_score
        best_score = max(scores[args.metric], best_score)
        checkpoints.save(epoch, args, model, optimizer, is_best, scores,
                         args.metric)
Example #9
0
def main(train_cfg='config/bert_pretrain.json',
         model_cfg='config/bert_base.json',
         data_file='../tbc/books_large_all.txt',
         model_file=None,
         data_parallel=True,
         vocab='../uncased_L-12_H-768_A-12/vocab.txt',
         save_dir='../exp/bert/pretrain',
         log_dir='../exp/bert/pretrain/runs',
         max_len=512,
         max_pred=20,
         mask_prob=0.15):

    train_cfg = BertTrainConfig.from_json(train_cfg)
    model_cfg = BertModelConfig.from_json(model_cfg)

    set_seeds(train_cfg.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab, do_lower_case=True)
    tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x))

    pipeline = [Preprocess4Pretrain(max_pred,
                                    mask_prob,
                                    list(tokenizer.vocab.keys()),
                                    tokenizer.convert_tokens_to_ids,
                                    max_len)]
    data_iter = SentPairDataLoader(data_file,
                                   train_cfg.batch_size,
                                   tokenize,
                                   max_len,
                                   pipeline=pipeline)

    model = BertModel4Pretrain(model_cfg)
    criterion1 = nn.CrossEntropyLoss(reduction='none')
    criterion2 = nn.CrossEntropyLoss()

    optimizer = optim.optim4GPU(train_cfg, model)
    trainer = train.Trainer(train_cfg, model_cfg, model, data_iter, optimizer, save_dir, get_device())

    writer = SummaryWriter(log_dir=log_dir) # for tensorboardX

    def get_loss(model, batch, global_step, train_cfg, model_cfg): # make sure loss is tensor
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next = batch

        logits_lm, logits_clsf = model(input_ids, segment_ids, input_mask, masked_pos)
        loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids) # for masked LM
        loss_lm = (loss_lm*masked_weights.float()).mean()
        loss_clsf = criterion2(logits_clsf, is_next) # for sentence classification
        writer.add_scalars('data/scalar_group',
                           {'loss_lm': loss_lm.item(),
                            'loss_clsf': loss_clsf.item(),
                            'loss_total': (loss_lm + loss_clsf).item(),
                            'lr': optimizer.get_lr()[0],
                           },
                           global_step)
        return loss_lm + loss_clsf

    trainer.train(get_loss, model_file, None, data_parallel)
Example #10
0
def main():
    global opt, best_mAP
    opt = parse()
    tee.Tee(opt.cache + '/log.txt')
    print(vars(opt))
    seed(opt.manual_seed)

    base_model, logits_model, criterion, base_optimizer, logits_optimizer = create_model(
        opt)
    if opt.resume:
        best_mAP = checkpoints.load(opt, base_model, logits_model,
                                    base_optimizer, logits_optimizer)
    print(logits_model)
    trainer = train.Trainer()
    train_loader, val_loader, valvideo_loader = get_dataset(opt)

    if opt.evaluate:
        trainer.validate(val_loader, base_model, logits_model, criterion, -1,
                         opt)
        trainer.validate_video(valvideo_loader, base_model, logits_model,
                               criterion, -1, opt)
        return

    for epoch in range(opt.start_epoch, opt.epochs):
        if opt.distributed:
            trainer.train_sampler.set_epoch(epoch)
        s_top1, s_top5, o_top1, o_top5, v_top1, v_top5, sov_top1 = trainer.train(
            train_loader, base_model, logits_model, criterion, base_optimizer,
            logits_optimizer, epoch, opt)
        s_top1val, s_top5val, o_top1val, o_top5val, v_top1val, v_top5val, sov_top1val = trainer.validate(
            val_loader, base_model, logits_model, criterion, epoch, opt)
        sov_mAP, sov_rec_at_n, sov_mprec_at_n = trainer.validate_video(
            valvideo_loader, base_model, logits_model, criterion, epoch, opt)
        is_best = sov_mAP > best_mAP
        best_mAP = max(sov_mAP, best_mAP)
        scores = {
            's_top1': s_top1,
            's_top5': s_top5,
            'o_top1': o_top1,
            'o_top5': o_top5,
            'v_top1': v_top1,
            'v_top5': v_top5,
            'sov_top1': sov_top1,
            's_top1val': s_top1val,
            's_top5val': s_top5val,
            'o_top1val': o_top1val,
            'o_top5val': o_top5val,
            'v_top1val': v_top1val,
            'v_top5val': v_top5val,
            'sov_top1val': sov_top1val,
            'mAP': sov_mAP,
            'sov_rec_at_n': sov_rec_at_n,
            'sov_mprec_at_n': sov_mprec_at_n
        }
        checkpoints.save(epoch, opt, base_model, logits_model, base_optimizer,
                         logits_optimizer, is_best, scores)
Example #11
0
def main():
    global opt, best_prec1

    opt = parser.parse_args()
    opt.logdir = opt.logdir + '/' + opt.name
    logger = 'hi'  # placeholder value; tensorboard logging is commented out below

    best_prec1 = 0
    print(opt)

    # Initialize the model, criterion and the optimizer
    model = init.load_model(opt)
    model, criterion, optimizer = init.setup(model, opt)
    # Display the model structure
    print(model)

    # Setup trainer and validation
    trainer = train.Trainer(model, criterion, optimizer, opt, logger)
    validator = train.Validator(model, criterion, opt, logger)

    # Load model from a checkpoint if mentioned in opts
    if opt.resume:
        if os.path.isfile(opt.resume):
            model, optimizer, opt, best_prec1 = init.resumer(
                opt, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True

    # Setup the train and validation data loaders
    dataloader = init_data.load_data(opt)
    train_loader = dataloader.train_loader
    val_loader = dataloader.val_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])

        if opt.testOnly == False:
            # Train the network over the training data
            trainer.train(train_loader, epoch, opt)

        #if opt.tensorboard:
        #logger.scalar_summary('learning_rate', opt.lr, epoch)

        # Measure the validation accuracy
        acc = validator.validate(val_loader, epoch, opt)
        best_prec1 = max(acc, best_prec1)
        if best_prec1 == acc:
            # Save the new model if the accuracy is better than the previous saved model
            init.save_checkpoint(opt, model, optimizer, best_prec1, epoch)

        print('Best accuracy: [{0:.3f}]\t'.format(best_prec1))
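Note: utils.adjust_learning_rate is called at the top of each epoch here (and in several later examples) but never shown. A common step-decay sketch; opt.lr, opt.lr_decay_every, and the factor 0.1 are assumptions for illustration:

def adjust_learning_rate(opt, optimizer, epoch):
    # Hypothetical step decay: scale the base LR by 0.1 every opt.lr_decay_every epochs.
    lr = opt.lr * (0.1 ** (epoch // opt.lr_decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr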
Example #12
0
def main():
    global opt, best_mAP
    opt = parse()
    tee.Tee(opt.cache + '/log_0724-valvideo.txt')
    #print(vars(opt))
    seed(opt.manual_seed)

    print('1. create_model')
    base_model, logits_model, criterion, base_optimizer, logits_optimizer = create_model(
        opt)
    if opt.resume:
        print('checkpoints load')
        #best_mAP = checkpoints.load(opt, base_model, logits_model, base_optimizer, logits_optimizer)
        checkpoints.load(opt, base_model, logits_model, base_optimizer,
                         logits_optimizer)

    #print('base_model = InceptionI3D Networks') # InceptionI3D Networks
    #print(base_model)
    #print('logits_model = AsyncTFBase: Linear Networks') # AsyncTFBase: Linear Networks
    #print(logits_model)

    trainer = train.Trainer()

    print('2. get_dataset')
    train_loader, val_loader, valvideo_loader = get_dataset(opt)
    #print('train_loader') # [56586, [25,img,s,v,o,meta]]
    #print(train_loader)    # 56586=pairs
    #print('val_loader')   # [12676, [25,img,s,v,o,meta]]
    #print(val_loader)
    #print('valvideo_loader') # [1863, [25+1,img,s,v,o,meta]]
    #print(valvideo_loader)   # 1863=num_(kind of video)

    if opt.evaluate:
        trainer.validate(val_loader, base_model, logits_model, criterion, -1,
                         opt)
        trainer.validate_video(valvideo_loader, base_model, logits_model,
                               criterion, -1, opt)
        return

    print('3.3 Valiation Video')
    #if opt.distributed:
    #    trainer.train_sampler.set_epoch(epoch)

    epoch = -1  # no training loop in this script; evaluate the loaded checkpoint at a dummy epoch index
    sov_mAP, sov_rec_at_n, sov_mprec_at_n = trainer.validate_video(
        valvideo_loader, base_model, logits_model, criterion, epoch, opt)

    is_best = sov_mAP > best_mAP
    best_mAP = max(sov_mAP, best_mAP)
    scores = {
        'mAP': sov_mAP,
        'sov_rec_at_n': sov_rec_at_n,
        'sov_mprec_at_n': sov_mprec_at_n
    }
    checkpoints.score_file(scores,
                           "{}/model_{}.txt".format(opt.cache, 'valvideo'))
Example #13
0
def main(args):
    np.random.seed(432)
    torch.random.manual_seed(432)
    try:
        os.makedirs(args.outpath)
    except OSError:
        pass
    experiment_path = utils.get_new_model_path(args.outpath)
    print(experiment_path)
    train_writer = SummaryWriter(os.path.join(experiment_path, 'train_logs'))
    val_writer = SummaryWriter(os.path.join(experiment_path, 'val_logs'))
    scheduler = cyclical_lr(5, 1e-5, 2e-3)
    trainer = train.Trainer(train_writer, val_writer, scheduler=scheduler)

    train_transform = data.build_preprocessing()
    eval_transform = data.build_preprocessing()

    trainds, evalds = data.build_dataset(args.datadir, None)
    trainds.transform = train_transform
    evalds.transform = eval_transform

    model = models.resnet34()
    base_opt = torch.optim.Adam(model.parameters())
    opt = SWA(base_opt, swa_start=30, swa_freq=10)

    trainloader = DataLoader(trainds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)
    evalloader = DataLoader(evalds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)

    export_path = os.path.join(experiment_path, 'last.pth')

    best_lwlrap = 0

    for epoch in range(args.epochs):
        print('Epoch {} - lr {:.6f}'.format(epoch, scheduler(epoch)))
        trainer.train_epoch(model, opt, trainloader, scheduler(epoch))
        metrics = trainer.eval_epoch(model, evalloader)

        print('Epoch: {} - lwlrap: {:.4f}'.format(epoch, metrics['lwlrap']))

        # save best model
        if metrics['lwlrap'] > best_lwlrap:
            best_lwlrap = metrics['lwlrap']
            torch.save(model.state_dict(), export_path)

    print('Best metrics {:.4f}'.format(best_lwlrap))
    opt.swap_swa_sgd()
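Note: cyclical_lr is not defined in this snippet; the calls cyclical_lr(5, 1e-5, 2e-3) and scheduler(epoch) suggest it returns a function mapping the epoch index to a learning rate. A minimal triangular-schedule sketch under that assumption (the real schedule may differ):

def cyclical_lr(period, min_lr, max_lr):
    # Rise linearly from min_lr to max_lr over `period` epochs, then fall back.
    def schedule(epoch):
        cycle_pos = epoch % (2 * period)
        frac = cycle_pos / period if cycle_pos < period else 2.0 - cycle_pos / period
        return min_lr + (max_lr - min_lr) * frac
    return schedule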
Example #14
0
def main(args):

    cfg = train.Config.from_json(args.train_cfg)
    model_cfg = models.Config.from_json(args.model_cfg)

    set_seeds(cfg.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab,
                                           do_lower_case=True)
    tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x))

    pipeline = [
        Preprocess4Pretrain(args.max_pred, args.mask_prob,
                            list(tokenizer.vocab.keys()),
                            tokenizer.convert_tokens_to_ids, model_cfg.max_len,
                            args.mask_alpha, args.mask_beta, args.max_gram)
    ]
    data_iter = SentPairDataLoader(args.data_file,
                                   cfg.batch_size,
                                   tokenize,
                                   model_cfg.max_len,
                                   pipeline=pipeline)

    model = BertModel4Pretrain(model_cfg)
    criterion1 = nn.CrossEntropyLoss(reduction='none')
    criterion2 = nn.CrossEntropyLoss()

    optimizer = optim.optim4GPU(cfg, model)
    trainer = train.Trainer(cfg, model, data_iter, optimizer, args.save_dir,
                            get_device())

    writer = SummaryWriter(log_dir=args.log_dir)  # for tensorboardX

    def get_loss(model, batch, global_step):  # make sure loss is tensor
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next = batch

        logits_lm, logits_clsf = model(input_ids, segment_ids, input_mask,
                                       masked_pos)
        loss_lm = criterion1(logits_lm.transpose(1, 2),
                             masked_ids)  # for masked LM
        loss_lm = (loss_lm * masked_weights.float()).mean()
        loss_sop = criterion2(logits_clsf,
                              is_next)  # for sentence classification
        writer.add_scalars(
            'data/scalar_group', {
                'loss_lm': loss_lm.item(),
                'loss_sop': loss_sop.item(),
                'loss_total': (loss_lm + loss_sop).item(),
                'lr': optimizer.get_lr()[0],
            }, global_step)
        return loss_lm + loss_sop

    trainer.train(get_loss, model_file=None, data_parallel=True)
Example #15
0
def train_model():
    trainer = train.Trainer(SEQ_LEN, INPUT_DIM, HIDDEN_SIZE, BATCH_SIZE,
                            LEARNING_RATE)
    for i in range(MAX_EPISODES):
        # input_batch = generate_sanity_check_batch(SEQ_LEN, BATCH_SIZE)
        input_batch = generate_random_batch(SEQ_LEN, BATCH_SIZE)
        correct_out = generate_sorted_onehot(input_batch)

        trainer.train(input_batch, correct_out)

        if i % 1000 == 0:
            trainer.save_model(i)
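Note: the batch generators used above are not shown. A sketch assuming each input is a sequence of random one-hot vectors and the target is the same vectors in sorted order (the encoding and the input_dim default are illustrative assumptions):

import torch
import torch.nn.functional as F

def generate_random_batch(seq_len, batch_size, input_dim=8):
    # Shape (batch, seq_len, input_dim): one random one-hot vector per time step.
    indices = torch.randint(0, input_dim, (batch_size, seq_len))
    return F.one_hot(indices, input_dim).float()

def generate_sorted_onehot(batch):
    # Reorder each sequence by the class index encoded in its one-hot vectors.
    indices = batch.argmax(dim=-1)
    order = indices.sort(dim=-1).indices
    return torch.gather(batch, 1, order.unsqueeze(-1).expand_as(batch))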
Example #16
0
def main(args):

    cfg = train.Config.from_json(args.train_cfg)
    model_cfg = models.Config.from_json(args.model_cfg)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab,
                                           do_lower_case=True)
    tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x))
    pipeline = [Preprocess4Pretrain(args)]
    data_iter = SentPairDataLoader(args.data_file,
                                   cfg.batch_size,
                                   tokenize,
                                   model_cfg.max_len,
                                   pipeline=pipeline)

    model = BertModel4Pretrain(model_cfg)
    criterion1 = nn.CrossEntropyLoss(reduction='none')
    criterion2 = nn.CrossEntropyLoss()

    optimizer = optim.optim4GPU(cfg, model)
    trainer = train.Trainer(cfg, model, data_iter, optimizer, args.save_dir,
                            get_device())
    # get_loss below logs to `writer`, which the original snippet never defines;
    # assume a tensorboardX SummaryWriter (the log_dir choice here is illustrative).
    writer = SummaryWriter(log_dir=args.save_dir)

    def get_loss(model, batch, global_step):  # make sure loss is tensor
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next = batch
        # input_ids : token ids of the entire (masked) input sequence
        # segment_ids : segment ids in [0, 1] used to separate the two sentences
        # input_mask : marks the ids that are actually used (excludes zero-padding)
        # masked_ids : original ids of the masked tokens (zero-padded)
        # masked_pos : position ids of the masked tokens
        # masked_weights : array filled with 1s, one per masked token
        # is_next : boolean value set when the instance was created
        logits_lm, logits_clsf = model(input_ids, segment_ids, input_mask,
                                       masked_pos)
        # logits_lm : [B, mS, V]
        # logits_clsf : [B, 1, 2]
        loss_lm = criterion1(logits_lm.transpose(1, 2),
                             masked_ids)  # for masked LM
        loss_lm = (loss_lm * masked_weights.float()).mean()
        loss_sop = criterion2(logits_clsf,
                              is_next)  # for sentence classification
        writer.add_scalars(
            'data/scalar_group', {
                'loss_lm': loss_lm.item(),
                'loss_sop': loss_sop.item(),
                'loss_total': (loss_lm + loss_sop).item(),
                'lr': optimizer.get_lr()[0],
            }, global_step)
        return loss_lm + loss_sop

    trainer.train(get_loss, model_file=None, data_parallel=True)
Example #17
0
def main(args):
    np.random.seed(432)
    torch.random.manual_seed(432)
    try:
        os.makedirs(args.outpath)
    except OSError:
        pass
    experiment_path = utils.get_new_model_path(args.outpath)

    train_writer = SummaryWriter(os.path.join(experiment_path, 'train_logs'))
    val_writer = SummaryWriter(os.path.join(experiment_path, 'val_logs'))
    trainer = train.Trainer(train_writer, val_writer)

    # todo: add config
    train_transform = data.build_preprocessing()
    eval_transform = data.build_preprocessing()

    trainds, evalds = data.build_dataset(args.datadir, None)
    trainds.transform = train_transform
    evalds.transform = eval_transform

    model = models.resnet34()
    opt = torch.optim.Adam(model.parameters())

    trainloader = DataLoader(trainds,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)
    evalloader = DataLoader(evalds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)

    for epoch in range(args.epochs):
        trainer.train_epoch(model, opt, trainloader, 3e-4)
        metrics = trainer.eval_epoch(model, evalloader)

        state = dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=opt.state_dict(),
            loss=metrics['loss'],
            lwlrap=metrics['lwlrap'],
            global_step=trainer.global_step,
        )
        export_path = os.path.join(experiment_path, 'last.pth')
        torch.save(state, export_path)
Example #18
0
    def testTrainer(self):
        loader = data_utils.DataLoader(self.data, batch_size=4, shuffle=False, collate_fn=self.collator)
        valid_loader = data_utils.DataLoader(self.data, batch_size=4, shuffle=False, collate_fn=self.collator)
        model = image_text_model.ImageTextModel(text_encoder=self.text_encoder, image_encoder=self.image_model,
                                                tokenizer=self.data.tokenizer)
        model = model.to(self.device)

        # choose high value for mask just for test
        trainer = train.Trainer(model=model, mask_prob=0.5)
        with tempfile.TemporaryDirectory() as tmpdirname:
            loss, best_valid_loss = trainer.train_epoch(data_iter=loader, valid_data_iter=valid_loader,
                                                        best_valid_loss=float("inf"),
                                                        saving_path=os.path.join(tmpdirname, "model"))
            assert float(loss.data) >= 0
            assert float(best_valid_loss) >= 0
Example #19
0
def main():
    global opt, best_prec1

    opt = parser.parse_args()
    opt.logdir = opt.logdir + '/' + opt.name
    logger = None  # Logger(opt.logdir)
    opt.lr = opt.maxlr

    print(opt)
    best_prec1 = 0
    cudnn.benchmark = True
    model = init_model.load_model(opt)
    if opt.model_def.startswith('alexnet') or opt.model_def.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    elif opt.ngpus > 1:
        model = torch.nn.DataParallel(model).cuda()
    print(model)
    model, criterion, optimizer = init_model.setup(model, opt)

    trainer = train.Trainer(model, criterion, optimizer, opt, logger)
    validator = train.Validator(model, criterion, opt, logger)

    if opt.resume:
        if os.path.isfile(opt.resume):
            model, optimizer, opt, best_acc = init_model.resumer(opt, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    dataloader = init_data.load_data(opt)
    train_loader = dataloader.train_loader
    #print(utils.get_mean_and_std(train_loader))
    val_loader = dataloader.val_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch, "Learning rate:", opt.lr)

        if opt.testOnly == False:
            trainer.train(train_loader, epoch, opt)
        if opt.tensorboard:
            logger.scalar_summary('learning_rate', opt.lr, epoch)

        prec1 = validator.validate(val_loader, epoch, opt)
        best_prec1 = max(prec1, best_prec1)
        init_model.save_checkpoint(opt, model, optimizer, best_prec1, epoch)

        print('Best Prec@1: [{0:.3f}]\t'.format(best_prec1))
Example #20
0
def main():
    global opt, best_err1
    opt = parser.parse_args()
    best_err1 = 1000000
    print(opt)

    model = init.load_model(opt)
    model, criterion, optimizer = init.setup(model, opt)
    print(model)

    trainer = train.Trainer(model, criterion, optimizer, opt, writer)
    validator = train.Validator(model, criterion, opt, writer)

    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.deterministic = True

    if opt.resume:
        if os.path.isfile(opt.resume):
            model, optimizer, opt, best_err1 = init.resumer(
                opt, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True

    dataloader = ld.GazeFollow(opt)

    train_loader = dataloader.train_loader
    val_loader = dataloader.val_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])

        if opt.testOnly == False:
            trainer.train(train_loader, epoch, opt)

        err = validator.validate(val_loader, epoch, opt)
        best_err1 = min(err, best_err1)

        if epoch % 10 == 0:
            init.save_checkpoint(opt, model, optimizer, best_err1, epoch)

        print('Best error: [{0:.3f}]\t'.format(best_err1))
Example #21
0
def main():
    if not os.path.isdir(flag.output_dir):
        os.mkdir(flag.output_dir)
    if flag.mode == 'train':
        train_op = train.Trainer(flag)
        train_op.train()
    elif flag.mode == 'predict':
        predict_op = predict.predictor(flag)
        predict_op.inference()
    elif flag.mode == 'eval':
        eval_op = predict.predictor(flag)
        eval_op.evaluate()
    elif flag.mode == 'cam':
        cam_op = predict.predictor(flag)
        cam_op.cam()
    else:
        print('not supported')
Example #22
0
def eval_models(args: Argument):
    args.mode = 'test'
    dc, lc, tc, model_dir = get_config_list(args)

    modes = ['test']
    dataloader = {
        'test':
        get_trajectory_data_loader(dc,
                                   test=True,
                                   batch_size=args.bsize,
                                   num_workers=args.num_workers,
                                   shuffle=True)
    }
    run_every = {'test': 1}
    gn_wrapper = fetch_model_iterator(lc, args)
    trainer = train.Trainer(gn_wrapper, modes, dataloader, run_every, tc)
    output = trainer.eval(dataloader['test'])
    return trainer.num_iter, output
Example #23
0
def main():
 
    gpu = 0
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    torch.cuda.set_device(gpu)
    device = 'cuda:0'
    cuda = torch.cuda.is_available()

    n_resblock = 10
    root = os.path.join(os.getcwd(), '..', 'images')
    dataprep = train.Data_Preparation(root)
    lr_train_vox = dataprep.lr_pcs_tr
    hr_train_vox = dataprep.hr_pcs_tr
    output_sz = hr_train_vox[0].squeeze()  # superseded by the explicit shape below

    trainDataset = train.Dataset(hr_train_vox, lr_train_vox, transform=image_utils.normalize)
    output_sz = (256, 32, 32)

    bt_size = 15
    shuffle = True
    train_data_loader = data.DataLoader(trainDataset, batch_size=bt_size, shuffle=shuffle)
    out_f = 'chkpt_r_10_bt_15_lr_0_001_res_0_5_sch'

    lr_test = dataprep.lr_pcs_ts
    hr_test = dataprep.hr_pcs_ts
    testDataset = train.Dataset(hr_test, lr_test, transform=image_utils.normalize)
    test_data_loader = data.DataLoader(testDataset, batch_size=bt_size, shuffle=False)
    ResNet = model.ResNET(n_resblocks=n_resblock, output_size=output_sz, res_scale=0.5)

    lr = 0.001
    # pretrained
    if cuda:
        ResNet.to(device)

    trainer = train.Trainer(train_data_loader, test_data_loader, cuda, 3, ResNet, lr, out_f, device)
    max_epoch = 1000
    trainer.train(max_epoch)
Example #24
0
def train_model(args: Argument):
    args.mode = 'train'
    dc, lc, tc, _ = get_config_list(args)
    gn_wrapper = fetch_model_iterator(lc, args)
    modes = ['train', 'test']
    dataloader = {
        m: get_trajectory_data_loader(dc,
                                      test=(m == 'test'),
                                      batch_size=args.bsize,
                                      num_workers=args.num_workers,
                                      shuffle=True)
        for m in modes
    }
    run_every = {'train': 1, 'test': args.test_every}
    trainer = train.Trainer(gn_wrapper, modes, dataloader, run_every, tc)

    train_winding = False
    train_trajectory = True

    trainer.train(train_winding, train_trajectory)
    trainer.save(train_winding, train_trajectory)
Example #25
0
 def __init__(self, display_width, display_height, grid_size):
     pygame.init()
     pygame.font.init()
     self.grid_size = grid_size
     self.display_width = display_width
     self.display_height = display_height
     self.screen = pygame.display.set_mode((display_width, display_height))
     #clock
     self.clock = pygame.time.Clock()
     #set font
     self.font = pygame.font.SysFont('Arial', 32)
     self.grid = [[None for i in range(grid_size)]
                  for j in range(grid_size)]
     self.grid_types = [tile.Forest, tile.Mountain, tile.River, tile.Plain]
     self.create_grid()
     self.directions = {
         'up': pygame.Vector2(int(0), int(-1)),
         'down': pygame.Vector2(int(0), int(1)),
         'left': pygame.Vector2(int(-1), int(0)),
         'right': pygame.Vector2(int(1), int(0))
     }
     self.position = pygame.Vector2(int(self.grid_size / 2),
                                    int(self.grid_size / 2))
     self.position_image = pygame.image.load('assets/position.png')
     self.current_action = self.grid[int(self.position.x)][int(
         self.position.y)].root
     self.set_current_text(self.current_action.description)
     #sound input vars
     self.p = pyaudio.PyAudio()  # Create an interface to PortAudio
     self.stream = self.p.open(format=settings.sample_format,
                               channels=settings.channels,
                               rate=settings.fs,
                               frames_per_buffer=settings.chunk,
                               input=True)
     self.frames = []  # Initialize array to store frames
     self.trainer = train.Trainer(file_path="librispeechmodel.txt",
                                  epochs=0,
                                  batch_size=1)
     self.current_command = ''
     self.spoken_letters = ''
Example #26
0
def main():
    UNITS = 512
    BATCH_SIZE = 1
    model = keras.models.Sequential()
    model.add(
        GRU(UNITS,
            return_sequences=True,
            input_shape=(None, train.INPUT_FEATURES),
            name="gru_input"))
    model.add(GRU(UNITS, return_sequences=True, name="gru_2"))
    model.add(TimeDistributed(Dense(train.NUM_PHONEMES + 1), name="output"))

    TRAINING_DATA_PATH = "C:\\Temp\\cache_test_2"
    VALIDATION_DATA_PATH = "C:\\Temp\\cache_test_2"
    trainer = train.Trainer(model, BATCH_SIZE, TRAINING_DATA_PATH,
                            VALIDATION_DATA_PATH)

    CHECKPOINT_PATH = "C:\\Temp\\checkpoints"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        CHECKPOINT_PATH, verbose=1, save_weights_only=False, period=1)

    EPOCHS = 10
    trainer.train(EPOCHS, True, [checkpoint_callback])
Example #27
0
def main():
    global opt, best_err1
    opt = parser.parse_args()
    best_err1 = 1000000
    print(opt)
    model = tracknet.Net(opt)
    if opt.cuda:
        model = model.cuda()

    model, criterion, optimizer = init.setup(model, opt)
    print(model)

    trainer = train.Trainer(model, criterion, optimizer, opt)
    # validator = train.Validator(model, criterion, opt)
    if opt.resume:
        if os.path.isfile(opt.resume):
            model, optimizer, opt, best_err1 = init.resumer(
                opt, model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True
    dataloader = ld.SynthLoader(opt)
    train_loader = dataloader.train_loader

    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])
        trainer.train(train_loader, epoch, opt)

        # err = validator.validate(val_loader, epoch, opt)
        # best_err1 = min(err, best_err1)
        # print('Best error: [{0:.3f}]\t'.format(best_err1))
        if epoch % 3 == 0 and epoch > 0 and opt.tosave == True:
            init.save_checkpoint(opt, model, optimizer, best_err1, epoch)
Example #28
0
def main(task='mrpc',
         train_cfg='config/train_mrpc.json',
         data_parallel=True,
         data_train_file='total_data/imdbtrain.tsv',
         data_test_file='total_data/IMDB_test.tsv',
         dataName='IMDB',
         stopNum=250,
         max_len=300,
         mode='train'):

    if mode == 'train':

        def get_loss_CNN(model, batch,
                         global_step):  # make sure loss is a scalar tensor
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            token1 = embedding(input_ids.long())

            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)

            loss1 = criterion(logits, label_id)
            return loss1

        def evalute_CNN(model, batch, global_step, ls):
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch
            token1 = embedding(input_ids.long())
            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)
            logits = F.softmax(logits, dim=1)

            y_pred11, y_pred1 = logits.max(1)

            return label_id, logits

        def get_loss_Attn_LSTM(
                model, batch,
                global_step):  # make sure loss is a scalar tensor
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())

            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask,
                                            seq_lengths)

            loss1 = criterion(logits, label_id)
            return loss1

        def evalute_Attn_LSTM(model, batch, global_step, ls):
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())

            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask,
                                            seq_lengths)
            logits = F.softmax(logits, dim=1)

            y_pred11, y_pred1 = logits.max(1)

            return label_id, logits

        def generating_lexiocn(model, model2, batch, global_step, ls, e):
            if (global_step == 0):
                result3.clear()
                result_label.clear()
                bb_11.clear()
                bb_22.clear()

            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())
            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)
            logits2, attention_score2 = model2(token1.cuda(), input_ids,
                                               segment_ids, input_mask,
                                               seq_lengths)

            logits = F.softmax(logits, dim=1)
            # logits2 = F.softmax(logits2, dim=1)
            y_pred11, y_pred1 = logits.max(1)
            y_pred22, y_pred2 = logits2.max(1)
            atten, attn_s1 = attention_score2.max(1)
            atte2, attn_s2 = torch.topk(attention_score2, 4)

            for i in range(0, len(input_ids)):

                split_tokens = []
                att_index = []
                for token in tokenizer.tokenize(data0[global_step * 128 +
                                                      perm_idx[i]]):
                    split_tokens.append(token)

                if (len(split_tokens) <= attn_s1[i].item()):
                    attn_index3 = attention_score[i][:len(split_tokens) - 1]
                    attn_num, attn_index2 = attn_index3.max(0)
                    attn_index = attn_index2.item()
                else:
                    for j in range(0, 4):
                        att_index.append(attn_s2[i][j].item())

                tok = []
                if (atten[i].item() <= 0):
                    token_ab = split_tokens[0]
                else:
                    for j in range(0, len(att_index)):
                        if (att_index[j] >= len(split_tokens)):
                            continue
                        tok.append(split_tokens[att_index[j]])

                token_temp = data0[global_step * 128 + perm_idx[i]].split(' ')
                token2 = []
                for kk in range(0, len(tok)):
                    token_ab = tok[kk]
                    token_ab = token_ab.replace(".", "")
                    token_ab = token_ab.replace(",", "")
                    token_ab = token_ab.replace("'", "")
                    token_ab = token_ab.replace("!", "")
                    token_ab = token_ab.replace("?", "")
                    token_ab = token_ab.replace("'", "")
                    token_ab = token_ab.replace('"', "")
                    # skip punctuation-only tokens and common stopwords
                    stopwords = {
                        '', ' ', ',', '.', 'from', 'are', 'is', 'and', 'with',
                        'may', 'would', 'could', 'have', 'has', 'had', 'was',
                        'were', 'this', 'who', 'that', 'www', 'http', 'com',
                        'those', 'your', 'not', 'seem', 'too', 'lol', 'but',
                        'these', 'their', 'can', 'there', 'gave', 'his', 'etc',
                        'thats', 'though', 'off', 'she', 'them', 'huh', 'why',
                        'wont', 'any', 'some', 'its', 'yeah', 'yes', 'you',
                        'should', 'dont', 'anybody', 'than', 'where', 'for',
                        'more', 'will', 'him', 'wii', 'having', 'just', 'help',
                        'helps', 'all', 'they', 'take', 'the', 'what', 'need',
                        'make', 'about', 'then', 'when', 'does', 'ask', 'much',
                        'man', 'know', 'how', 'look', 'like', 'one', 'think',
                        'tell', 'find', 'cant', 'now', 'try', 'give', 'answer',
                        'her', 'out', 'get', 'because', 'myself', 'wants',
                        'movie', 'film', 'films'
                    }
                    if token_ab in stopwords:
                        continue

                    if (len(token_ab) < 2):
                        continue

                    for gge, input_word in enumerate(token_temp):

                        if (token_ab.lower() in input_word.lower()):
                            input_word = input_word.replace(".", "")
                            input_word = input_word.replace(",", "")
                            input_word = input_word.replace("'", "")
                            input_word = input_word.replace("!", "")
                            input_word = input_word.replace("?", "")
                            input_word = input_word.replace("'", "")
                            input_word = input_word.replace('"', "")

                            token2.append(input_word.lower())
                            break
                token2 = list(set(token2))

                if (len(token2) < 3):
                    continue
                # print(token2)
                sen = ""
                for l in range(0, len(token2) - 1):
                    sen += token2[l] + ' '
                sen += token2[len(token2) - 1]
                if (y_pred2[i] == 0 and y_pred1[i] == 0):
                    try:
                        bb_11[sen] += y_pred22[i]
                    except KeyError:
                        bb_11[sen] = y_pred22[i]

                if (y_pred2[i] == 1 and y_pred1[i] == 1):
                    try:
                        bb_22[sen] += y_pred22[i]
                    except KeyError:
                        bb_22[sen] = y_pred22[i]

            if (global_step == ls - 1):

                abusive_11.clear()
                abusive_22.clear()

                bb_11_up = sorted(bb_11.items(),
                                  key=lambda x: x[1],
                                  reverse=True)
                bb_22_up = sorted(bb_22.items(),
                                  key=lambda x: x[1],
                                  reverse=True)

                lexicon_size = 50
                bb_11_up = bb_11_up[:lexicon_size]
                bb_22_up = bb_22_up[:lexicon_size]

                for i in bb_11_up:
                    flag = 0
                    for j in bb_22_up:
                        if ((i[0].lower() in j[0].lower())
                                or (j[0].lower() in i[0].lower())):
                            if (i[1] < j[1]):
                                flag = 1
                                break

                    if (flag == 0):
                        abusive_11.append(i[0])

                for i in bb_22_up:
                    flag = 0
                    for j in bb_11_up:
                        if ((i[0].lower() in j[0].lower())
                                or (j[0].lower() in i[0].lower())):
                            if (i[1] < j[1]):
                                flag = 1
                                break

                    if (flag == 0):
                        abusive_22.append(i[0])

                ddf = open("./IMDB_Lexicon/imdbLexicon_1.txt",
                           'w',
                           encoding='UTF8')

                for i in range(0, len(abusive_11)):
                    ddf.write(abusive_11[i] + '\n')

                ddf.close()

                ddf = open("./IMDB_Lexicon/imdbLexicon_2.txt",
                           'w',
                           encoding='UTF8')

                for i in range(0, len(abusive_22)):
                    ddf.write(abusive_22[i] + '\n')

                ddf.close()
            return label_id, logits

        def evalute_CNN_SSL(model, batch):
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch
            token1 = embedding(input_ids.long())
            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)

            return label_id, logits

        def pseudo_labeling(model, model2, batch, global_step, ls, e):
            if (global_step == 0):
                result3.clear()
                result4.clear()

                label_0.clear()
                label_1.clear()

                result_label.clear()

                abusive_11.clear()
                abusive_22.clear()

                abusive_dic_file = open("./IMDB_Lexicon/imdbLexicon_1.txt",
                                        'r',
                                        encoding='UTF8')
                for line in abusive_dic_file.read().split('\n'):
                    if (len(line) <= 3):
                        continue
                    abusive_11.append(line)
                abusive_dic_file.close()

                abusive_dic_file = open("./IMDB_Lexicon/imdbLexicon_2.txt",
                                        'r',
                                        encoding='UTF8')
                for line in abusive_dic_file.read().split('\n'):
                    if (len(line) <= 3):
                        continue
                    abusive_22.append(line)
                abusive_dic_file.close()

            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())

            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)
            logits2, attention_score2 = model2(token1.cuda(), input_ids,
                                               segment_ids, input_mask,
                                               seq_lengths)

            logits = F.softmax(logits, dim=1)
            logits2 = F.softmax(logits2, dim=1)
            logits3 = logits + logits2
            y_pred33, y_pred3 = logits3.max(1)
            y_pred11, y_pred1 = logits.max(1)
            y_pred22, y_pred2 = logits2.max(1)

            label_id2 = []

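            # Per-example labeling rules: count lexicon hits for each class with
            # matching_blacklist2 (its third argument, 3 or 4, appears to control
            # how strict the match is), then assign label 0 or 1 when the lexicon
            # counts and a classifier prediction with confidence >= 0.9 agree,
            # when the stricter lexicon count alone is decisive, or when both
            # classifiers agree with confidence >= 0.9; otherwise mark the
            # example as unlabeled (-1). Note the index arithmetic assumes a
            # fixed batch size of 128.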
            for i in range(0, len(input_ids)):
                input_sentence = data0[global_step * 128 + perm_idx[i]]
                input_sentence = re.sub("[!@#$%^&*().?\"~/<>:;'{}]", "",
                                        input_sentence)

                abusive_word_list_neg11 = list()
                abusive_word_list_neg11 += matching_blacklist2(
                    abusive_11, input_sentence, 3)
                abusive_word_list_neg11 = list((set(abusive_word_list_neg11)))

                abusive_word_list_neg22 = list()
                abusive_word_list_neg22 += matching_blacklist2(
                    abusive_22, input_sentence, 3)
                abusive_word_list_neg22 = list((set(abusive_word_list_neg22)))

                abusive_word_list_neg111 = list()
                abusive_word_list_neg111 += matching_blacklist2(
                    abusive_11, input_sentence, 4)
                abusive_word_list_neg111 = list(
                    (set(abusive_word_list_neg111)))

                abusive_word_list_neg222 = list()
                abusive_word_list_neg222 += matching_blacklist2(
                    abusive_22, input_sentence, 4)
                abusive_word_list_neg222 = list(
                    (set(abusive_word_list_neg222)))

                a = max(len(abusive_word_list_neg11),
                        len(abusive_word_list_neg22))
                aa = max(len(abusive_word_list_neg111),
                         len(abusive_word_list_neg222))

                if ((a >= 1 and len(abusive_word_list_neg11) >
                     len(abusive_word_list_neg22) and y_pred1[i].item() == 0
                     and y_pred11[i].item() >= 0.9) or
                    (a >= 1 and len(abusive_word_list_neg11) >
                     len(abusive_word_list_neg22) and y_pred2[i].item() == 0
                     and y_pred22[i].item() >= 0.9)):
                    label_0.append(0)
                    result4.append([
                        global_step * 128 + perm_idx[i], 0,
                        data0[global_step * 128 + perm_idx[i]],
                        label_id[perm_idx[i]].item()
                    ])
                elif ((a >= 1 and len(abusive_word_list_neg11) <
                       len(abusive_word_list_neg22) and y_pred1[i].item() == 1
                       and y_pred11[i].item() >= 0.9) or
                      (a >= 1 and len(abusive_word_list_neg11) <
                       len(abusive_word_list_neg22) and y_pred2[i].item() == 1
                       and y_pred22[i].item() >= 0.9)):
                    label_1.append(1)
                    result4.append([
                        global_step * 128 + perm_idx[i], 1,
                        data0[global_step * 128 + perm_idx[i]],
                        label_id[perm_idx[i]].item()
                    ])

                elif (aa >= 1 and len(abusive_word_list_neg111) >
                      len(abusive_word_list_neg222)):
                    label_0.append(0)
                    result4.append([
                        global_step * 128 + perm_idx[i], 0,
                        data0[global_step * 128 + perm_idx[i]],
                        label_id[perm_idx[i]].item()
                    ])
                elif (aa >= 1 and len(abusive_word_list_neg111) <
                      len(abusive_word_list_neg222)):
                    label_1.append(1)
                    result4.append([
                        global_step * 128 + perm_idx[i], 1,
                        data0[global_step * 128 + perm_idx[i]],
                        label_id[perm_idx[i]].item()
                    ])
                elif (y_pred11[i].item() >= 0.9 and y_pred22[i].item() >= 0.9
                      and y_pred1[i].item() == y_pred2[i].item()):
                    if (y_pred1[i].item() == 0):
                        label_0.append(0)
                        result4.append([
                            global_step * 128 + perm_idx[i], 0,
                            data0[global_step * 128 + perm_idx[i]],
                            label_id[perm_idx[i]].item()
                        ])
                    elif (y_pred1[i].item() == 1):
                        label_1.append(1)
                        result4.append([
                            global_step * 128 + perm_idx[i], 1,
                            data0[global_step * 128 + perm_idx[i]],
                            label_id[perm_idx[i]].item()
                        ])

                else:
                    result4.append([
                        global_step * 128 + perm_idx[i], -1,
                        data0[global_step * 128 + perm_idx[i]],
                        label_id[perm_idx[i]].item()
                    ])

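            # On the last batch: keep a class-balanced subset of the new
            # pseudo-labels (at most min(len(label_0), len(label_1)) per class),
            # append it to temp_train_IMDB.tsv, write everything still unlabeled
            # to temp_train_na_IMDB.tsv, and reload the latter as the unlabeled
            # pool (data0 / temp_check) for the next round.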
            if (global_step == ls - 1):

                result_label.clear()
                result3.clear()

                print("###result3[i] ###:", len(result3))
                a = min(len(label_0), len(label_1))
                la_0 = 0
                la_1 = 0
                la_2 = 0
                la_3 = 0

                random.shuffle(result4)

                for i in range(0, len(result4)):

                    if (result4[i][1] == 0 and la_0 < a):
                        if (temp_check[result4[i][0]][0] == 0):
                            temp_check[result4[i][0]][0] = 1
                            temp_check[result4[i][0]][1] = 0
                            la_0 += 1
                            continue

                    elif (result4[i][1] == 1 and la_1 < a):
                        if (temp_check[result4[i][0]][0] == 0):
                            temp_check[result4[i][0]][0] = 1
                            temp_check[result4[i][0]][1] = 1
                            la_1 += 1
                            continue

                result_label.clear()
                result3.clear()

                fw = open('./temp_data/temp_train_IMDB.tsv',
                          'a',
                          encoding='utf-8',
                          newline='')
                wr = csv.writer(fw, delimiter='\t')

                fww = open('./temp_data/temp_train_na_IMDB.tsv',
                           'w',
                           encoding='utf-8',
                           newline='')
                wrr = csv.writer(fww, delimiter='\t')

                for i in range(0, len(temp_check)):
                    if (temp_check[i][0] == 1):
                        result_label.append(str(temp_check[i][3]))
                        result3.append(str(temp_check[i][1]))
                        wr.writerow(
                            [str(temp_check[i][1]),
                             str(temp_check[i][2])])
                    else:
                        wrr.writerow(
                            [str(temp_check[i][3]),
                             str(temp_check[i][2])])

                fw.close()
                fww.close()
                data0.clear()
                temp_check.clear()
                with open('./temp_data/temp_train_na_IMDB.tsv',
                          "r",
                          encoding='utf-8') as f:
                    lines = csv.reader(f, delimiter='\t')

                    for i in lines:
                        a = ''
                        lines2 = i[1].split(' ')
                        b = 0
                        for j in range(0, len(lines2)):
                            a += lines2[j] + ' '
                            b += 1

                        data0.append(a)
                        temp_check.append([0, -1, a, i[0]])
                print("################;", len(data0))
                f.close()

                dataset_temp = TaskDataset('./temp_data/temp_train_IMDB.tsv',
                                           pipeline)
                data_iter_temp = DataLoader(dataset_temp,
                                            batch_size=cfg.batch_size,
                                            shuffle=True)

                dataset_temp_na = TaskDataset(
                    './temp_data/temp_train_na_IMDB.tsv', pipeline)
                data_iter_temp_na = DataLoader(dataset_temp_na,
                                               batch_size=cfg.batch_size,
                                               shuffle=False)

            if (global_step != ls - 1):
                dataset_temp = TaskDataset(data_dev_file, pipeline)
                data_iter_temp = DataLoader(dataset_temp,
                                            batch_size=cfg.batch_size,
                                            shuffle=True)

                dataset_temp_na = TaskDataset(data_dev_file, pipeline)
                data_iter_temp_na = DataLoader(dataset_temp_na,
                                               batch_size=cfg.batch_size,
                                               shuffle=False)

            return label_id, logits, result_label, result3, data_iter_temp, data_iter_temp_na

        def evalute_Attn_LSTM_SSL(model, batch):

            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())

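            # the `model` argument is unused here; the enclosing Attention-LSTM
            # (model2) is evaluated instead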
            logits, attention_score = model2(token1.cuda(), input_ids,
                                             segment_ids, input_mask,
                                             seq_lengths)

            return label_id, logits

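        # Run the whole co-training procedure five times, each round on a
        # freshly shuffled split of the IMDB training data.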
        curNum = 1

        for kkk in range(0, 5):
            #kkk+=1
            print(dataName)
            print(stopNum)

            cfg = train.Config.from_json(train_cfg)

            tokenizer = tokenization.FullTokenizer(do_lower_case=True)

            print("max_lenmax_len#:", max_len)

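            # Build the task dataset class and the preprocessing pipeline:
            # tokenize, add special tokens / truncate to max_len, then map
            # tokens and labels to ids.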
            TaskDataset = dataset_class(
                task)  # task dataset class according to the task
            pipeline = [
                Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize),
                AddSpecialTokensWithTruncation(max_len),
                TokenIndexing(tokenizer.convert_tokens_to_ids,
                              TaskDataset.labels, max_len)
            ]

            fd = open("./total_data/imdbtrain.tsv", 'r', encoding='utf-8')
            rdr = csv.reader(fd, delimiter='\t')

            res = []
            num_a = 0
            num_b = 0
            for line in rdr:
                #print(line)
                num_a += 1
                res.append([line[0], line[1]])

            print("curNum#:", curNum)
            #print(res)

            for i in range(0, curNum):
                random.shuffle(res)
                #print(res[1][0])
                print("########")
            curNum += 3
            num_data = len(res)
            num_data_dev_temp = int(num_data * 0.01)
            num_data_dev = int(num_data_dev_temp * 0.15)
            num_data_short = int(num_data_dev_temp * 0.85)
            num_data_train = num_data - num_data_dev_temp
            fd.close()

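            # Split sizes: roughly 1% of the corpus is held out as a seed pool,
            # of which 15% becomes the dev set and 85% the initial labeled
            # ("short") set; the remaining ~99% is the unlabeled training pool.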
            num = 0

            data_train_file = "./data/IMDB_train" + str(kkk + 1) + ".tsv"
            data_dev_file = "./data/IMDB_dev" + str(kkk + 1) + ".tsv"
            data_short_file = "./data/IMDB_short" + str(kkk + 1) + ".tsv"

            print("num_data_dev#:", num_data_dev)
            print("num_data_short#:", num_data_short)
            print("num_data_train#:", num_data_train)
            fw = open('./data/IMDB_temp_short.tsv',
                      'w',
                      encoding='utf-8',
                      newline='')
            wr = csv.writer(fw, delimiter='\t')

            fe = open(data_train_file, 'w', encoding='utf-8', newline='')
            we = csv.writer(fe, delimiter='\t')

            res2 = []
            num_pos = 0
            num_neg = 0
            for line in res:
                #print(line[0])
                #print(line[1])
                if (line[0] == '0' and num_pos <= (num_data_dev_temp / 2)):
                    num_pos += 1
                    wr.writerow(['0', line[1]])
                elif (line[0] == '1' and num_neg <= (num_data_dev_temp / 2)):
                    num_neg += 1
                    wr.writerow(['1', line[1]])
                else:
                    num += 1
                    we.writerow([line[0], line[1]])

            fw.close()
            fe.close()

            print("num_pos #:", num_pos, " num_neg:", num_neg)

            f = open('./data/IMDB_temp_short.tsv', 'r', encoding='utf-8')
            rdr = csv.reader(f, delimiter='\t')
            num_pos = 0
            num_neg = 0
            num = 0

            fw = open(data_dev_file, 'w', encoding='utf-8', newline='')
            wr = csv.writer(fw, delimiter='\t')

            fe = open(data_short_file, 'w', encoding='utf-8', newline='')
            we = csv.writer(fe, delimiter='\t')

            for line in rdr:
                #print(line[0])
                if (line[0] == '0' and num_pos <= (num_data_dev / 2)):
                    num_pos += 1
                    wr.writerow(['0', line[1]])
                elif (line[0] == '1' and num_neg <= (num_data_dev / 2)):
                    num_neg += 1
                    wr.writerow(['1', line[1]])
                else:
                    num += 1
                    we.writerow([line[0], line[1]])

            print("num_pos #:", num_pos, " num_neg:", num_neg)
            f.close()
            fw.close()
            fe.close()

            dataset = TaskDataset(data_train_file, pipeline)
            data_iter = DataLoader(dataset,
                                   batch_size=cfg.batch_size,
                                   shuffle=False)

            dataset2 = TaskDataset(data_test_file, pipeline)
            data_iter2 = DataLoader(dataset2,
                                    batch_size=cfg.batch_size,
                                    shuffle=False)

            dataset_dev = TaskDataset(data_dev_file, pipeline)
            data_iter_dev = DataLoader(dataset_dev,
                                       batch_size=cfg.batch_size,
                                       shuffle=False)

            dataset3 = TaskDataset(data_short_file, pipeline)
            data_iter3 = DataLoader(dataset3,
                                    batch_size=cfg.batch_size,
                                    shuffle=True)

            weights = tokenization.embed_lookup2()

            print("#train_set:", len(data_iter))
            print("#test_set:", len(data_iter2))
            print("#short_set:", len(data_iter3))
            print("#dev_set:", len(data_iter_dev))
            curNum += 1

            embedding = nn.Embedding.from_pretrained(weights).cuda()
            criterion = nn.CrossEntropyLoss()

            model1 = Classifier_CNN(2)
            model2 = Classifier_Attention_LSTM(2)

            trainer = train.Trainer(
                cfg, dataName, stopNum, model1, model2, data_iter, data_iter2,
                data_iter3, data_iter_dev,
                torch.optim.Adam(model1.parameters(), lr=0.001),
                torch.optim.Adam(model2.parameters(),
                                 lr=0.005), get_device(), kkk + 1)

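            # Shared state used by the closures above: pseudo-label buffers
            # (label_0 / label_1, result3 / result4), lexicon score dictionaries
            # (bb_11 / bb_22), the two lexicons (abusive_11 / abusive_22), and
            # the collected labels (result_label).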
            label_0 = []
            label_1 = []

            result3 = []
            result4 = []

            bb_11 = {}
            bb_22 = {}

            abusive_11 = []
            abusive_22 = []

            result_label = []

            fw = open('./temp_data/temp_train_IMDB.tsv',
                      'w',
                      encoding='utf-8',
                      newline='')
            wr = csv.writer(fw, delimiter='\t')

            fr = open(data_short_file, 'r', encoding='utf-8')
            rdrr = csv.reader(fr, delimiter='\t')
            for line in rdrr:
                wr.writerow([line[0], line[1]])

            fw.close()
            fr.close()

            data0 = []
            temp_check = []
            temp_label = []

            with open(data_train_file, "r", encoding='utf-8') as f:
                lines = csv.reader(f, delimiter='\t')

                for i in lines:
                    a = ''
                    lines2 = i[1].split(' ')
                    for j in range(0, len(lines2)):
                        a += lines2[j] + ' '

                    data0.append(a)
                    temp_check.append([0, -1, a, i[0]])
                    temp_label.append([0, 0])

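            # Kick off co-training: Trainer.train drives both models with the
            # loss, evaluation, pseudo-labeling and lexicon-generation callbacks
            # defined above.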
            trainer.train(get_loss_CNN, get_loss_Attn_LSTM, evalute_CNN_SSL,
                          pseudo_labeling, evalute_Attn_LSTM, evalute_CNN,
                          evalute_Attn_LSTM_SSL, generating_lexiocn,
                          data_parallel)

    elif mode == 'eval':

        def evalute_Attn_LSTM_SSL(model, batch):

            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch

            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            input_ids = input_ids[perm_idx]
            label_id = label_id[perm_idx]
            token1 = embedding(input_ids.long())

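            # the `model` argument is unused here; the enclosing Attention-LSTM
            # (model2) is evaluated instead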
            logits, attention_score = model2(token1.cuda(), input_ids,
                                             segment_ids, input_mask,
                                             seq_lengths)

            return label_id, logits

        def evalute_CNN_SSL(model, batch):
            input_ids, segment_ids, input_mask, label_id, seq_lengths = batch
            token1 = embedding(input_ids.long())
            logits, attention_score = model(token1.cuda(), input_ids,
                                            segment_ids, input_mask)

            return label_id, logits

        weights = tokenization.embed_lookup2()

        embedding = nn.Embedding.from_pretrained(weights).cuda()
        criterion = nn.CrossEntropyLoss()

        model = Classifier_CNN(2)
        model2 = Classifier_Attention_LSTM(2)

        trainer = train.Eval(cfg, model, model2, data_iter, save_dir,
                             get_device())

        embedding = nn.Embedding.from_pretrained(weights).cuda()
        results = trainer.eval(evalute_CNN_SSL, evalute_Attn_LSTM_SSL,
                               data_parallel)
Example #29
0
File: main.py Project: kevinghst/UDA
def main():
    # Load Configuration
    model_cfg = configuration.model.from_json(cfg.model_cfg)        # BERT_cfg
    set_seeds(cfg.seed)

    # Load Data & Create Criterion
    #data = load_data(cfg)

    #if cfg.uda_mode or cfg.mixmatch_mode:
    #    data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode=='train' \
    #        else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()]  # train_eval
    #else:
    #    data_iter = [data.sup_data_iter()]

    # my own implementation
    dataset = DataSet(cfg)
    train_dataset, val_dataset, unsup_dataset = dataset.get_dataset()

    # Create the DataLoaders for our training and validation sets.
    train_dataloader = DataLoader(
                train_dataset,  # The training samples.
                sampler = RandomSampler(train_dataset), # Select batches randomly
                batch_size = cfg.train_batch_size # Trains with this batch size.
            )

    validation_dataloader = DataLoader(
                val_dataset, # The validation samples.
                sampler = SequentialSampler(val_dataset), # Pull out batches sequentially.
                batch_size = cfg.eval_batch_size # Evaluate with this batch size.
            )

    unsup_dataloader = None
    if unsup_dataset:
        unsup_dataloader = DataLoader(
            unsup_dataset,
            sampler = RandomSampler(unsup_dataset),
            batch_size = cfg.train_batch_size
        )

    if cfg.uda_mode or cfg.mixmatch_mode:
        data_iter = [train_dataloader, unsup_dataloader, validation_dataloader] 
    else:
        data_iter = [train_dataloader, validation_dataloader]

    ema_optimizer = None
    ema_model = None

    if cfg.model == "custom":
        model = models.Classifier(model_cfg, NUM_LABELS[cfg.task])
    elif cfg.model == "bert":
        model = BertForSequenceClassificationCustom.from_pretrained(
            "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
            num_labels = NUM_LABELS[cfg.task],
            output_attentions = False, # Whether the model returns attentions weights.
            output_hidden_states = False, # Whether the model returns all hidden-states.
        )


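    # Criterion / optimizer setup: UDA pairs a KL or MSE consistency loss with
    # per-example cross entropy; MixMatch uses SemiLoss plus an EMA copy of the
    # model updated by WeightEMA; the supervised-only baseline uses plain
    # cross entropy.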
    if cfg.uda_mode:
        if cfg.unsup_criterion == 'KL':
            unsup_criterion = nn.KLDivLoss(reduction='none')
        else:
            unsup_criterion = nn.MSELoss(reduction='none')
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)
    elif cfg.mixmatch_mode:
        train_criterion = SemiLoss()
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        ema_model = models.Classifier(model_cfg,  NUM_LABELS[cfg.task])
        for param in ema_model.parameters():
            param.detach_()
        ema_optimizer= WeightEMA(cfg, model, ema_model, alpha=cfg.ema_decay)
    else:
        sup_criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.optim4GPU(cfg, model)
    
    # Create trainer
    trainer = train.Trainer(cfg, model, data_iter, optimizer, get_device(), ema_model, ema_optimizer)

    # loss functions
    def get_sup_loss(model, sup_batch, unsup_batch, global_step):
        # batch
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch

        # convert label ids to hot vectors
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1,1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # sup mixup
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1-sup_l)
        sup_idx = torch.randperm(sup_size)

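        # Word-level mixup needs equal-length inputs: simple_pad pads in place,
        # while pad_for_word_mixup (presumably) also returns a copy of the ids
        # aligned with sup_idx so token embeddings can be interpolated pairwise.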
        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # sup loss
        hidden = model(
            input_ids=input_ids, 
            segment_ids=segment_ids, 
            input_mask=input_mask,
            output_h=True,
            mixup=cfg.sup_mixup,
            shuffle_idx=sup_idx,
            clone_ids=c_input_ids,
            l=sup_l,
            manifold_mixup=cfg.manifold_mixup,
            simple_pad=cfg.simple_pad,
            no_grad_clone=cfg.no_grad_clone
        )
        logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

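        # Soft-label cross entropy on the (possibly mixed) targets; mixup_op is
        # assumed to compute l * y + (1 - l) * y[shuffle_idx].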
        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)

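        # Training Signal Annealing: examples whose true-class probability
        # exp(-sup_loss) already exceeds the scheduled threshold are masked out,
        # so easy examples stop dominating the supervised loss early on.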
        if cfg.tsa and cfg.tsa != "none":
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh   # prob = exp(log_prob), prob > tsa_threshold
            # larger_than_threshold = torch.sum(  F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids]  , dim=-1) > tsa_threshold
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        return sup_loss, sup_loss, sup_loss, sup_loss


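    # get_loss_ict implements an interpolation-consistency style objective:
    # the supervised mixup loss above, plus an unsupervised MSE between the
    # model's predictions on mixed unlabeled inputs and the correspondingly
    # mixed no-grad "teacher" probabilities, weighted by a sigmoid ramp-up
    # (w = uda_coeff * sigmoid_rampup(step, rampup_ends - rampup_starts)).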
    def get_loss_ict(model, sup_batch, unsup_batch, global_step):
        # batch
        input_ids, segment_ids, input_mask, og_label_ids, num_tokens = sup_batch
        ori_input_ids, ori_segment_ids, ori_input_mask, \
        aug_input_ids, aug_segment_ids, aug_input_mask, \
        ori_num_tokens, aug_num_tokens = unsup_batch

        # convert label ids to hot vectors
        sup_size = input_ids.size(0)
        label_ids = torch.zeros(sup_size, 2).scatter_(1, og_label_ids.cpu().view(-1,1), 1)
        label_ids = label_ids.cuda(non_blocking=True)

        # sup mixup
        sup_l = np.random.beta(cfg.alpha, cfg.alpha)
        sup_l = max(sup_l, 1-sup_l)
        sup_idx = torch.randperm(sup_size)

        if cfg.sup_mixup and 'word' in cfg.sup_mixup:
            if cfg.simple_pad:
                simple_pad(input_ids, input_mask, num_tokens)
                c_input_ids = None
            else:
                input_ids, c_input_ids = pad_for_word_mixup(
                    input_ids, input_mask, num_tokens, sup_idx
                )
        else:
            c_input_ids = None

        # sup loss
        if cfg.model == "bert":
            logits = model(
                input_ids=input_ids,
                c_input_ids=c_input_ids,
                attention_mask=input_mask,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                l=sup_l,
                manifold_mixup = cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=input_ids, 
                segment_ids=segment_ids, 
                input_mask=input_mask,
                output_h=True,
                mixup=cfg.sup_mixup,
                shuffle_idx=sup_idx,
                clone_ids=c_input_ids,
                l=sup_l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.sup_mixup:
            label_ids = mixup_op(label_ids, sup_l, sup_idx)

        sup_loss = -torch.sum(F.log_softmax(logits, dim=1) * label_ids, dim=1)

        if cfg.tsa and cfg.tsa != "none":
            tsa_thresh = get_tsa_thresh(cfg.tsa, global_step, cfg.total_steps, start=1./logits.shape[-1], end=1)
            larger_than_threshold = torch.exp(-sup_loss) > tsa_thresh   # prob = exp(log_prob), prob > tsa_threshold
            # larger_than_threshold = torch.sum(  F.softmax(pred[:sup_size]) * torch.eye(num_labels)[sup_label_ids]  , dim=-1) > tsa_threshold
            loss_mask = torch.ones_like(og_label_ids, dtype=torch.float32) * (1 - larger_than_threshold.type(torch.float32))
            sup_loss = torch.sum(sup_loss * loss_mask, dim=-1) / torch.max(torch.sum(loss_mask, dim=-1), torch_device_one())
        else:
            sup_loss = torch.mean(sup_loss)

        if cfg.no_unsup_loss:
            return sup_loss, sup_loss, sup_loss, sup_loss

        # unsup loss
        with torch.no_grad():
            if cfg.model == "bert":
                ori_logits = model(
                    input_ids = ori_input_ids,
                    attention_mask = ori_input_mask,
                    no_pretrained_pool=cfg.no_pretrained_pool
                )
            else:
                ori_logits = model(ori_input_ids, ori_segment_ids, ori_input_mask)
            ori_prob   = F.softmax(ori_logits, dim=-1)    # KLdiv target


        # mixup over the unlabeled batch
        l = np.random.beta(cfg.alpha, cfg.alpha)
        l = max(l, 1-l)
        # permute over the unsup batch dimension; `hidden` is only defined in
        # the non-bert branch above, so take the size from the unlabeled inputs
        idx = torch.randperm(ori_input_ids.size(0))

        
        if cfg.mixup and 'word' in cfg.mixup:
            ori_input_ids, c_ori_input_ids = pad_for_word_mixup(
                ori_input_ids, ori_input_mask, ori_num_tokens, idx
            )
        else:
            c_ori_input_ids = None

        
        #for i in range(0, batch_size):
        #    new_mask = ori_input_mask[i]
        #    new_ids = ori_input_ids[i]
        #    old_ids = c_ori_input_ids[i]
        #    pdb.set_trace()
        if cfg.model == "bert":
            logits = model(
                input_ids=ori_input_ids,
                c_input_ids=c_ori_input_ids,
                attention_mask=ori_input_mask,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                l=l,
                manifold_mixup = cfg.manifold_mixup,
                no_pretrained_pool=cfg.no_pretrained_pool
            )
        else:
            hidden = model(
                input_ids=ori_input_ids, 
                segment_ids=ori_segment_ids, 
                input_mask=ori_input_mask,
                output_h=True,
                mixup=cfg.mixup,
                shuffle_idx=idx,
                clone_ids=c_ori_input_ids,
                l=l,
                manifold_mixup=cfg.manifold_mixup,
                simple_pad=cfg.simple_pad,
                no_grad_clone=cfg.no_grad_clone
            )
            logits = model(input_h=hidden)

        if cfg.mixup:
            ori_prob = mixup_op(ori_prob, l, idx)

        probs_u = torch.softmax(logits, dim=1)
        unsup_loss = torch.mean((probs_u - ori_prob)**2)

        w = cfg.uda_coeff * sigmoid_rampup(global_step, cfg.consistency_rampup_ends - cfg.consistency_rampup_starts)
        final_loss = sup_loss + w*unsup_loss
        return final_loss, sup_loss, unsup_loss, w*unsup_loss

    # evaluation
    def get_acc(model, batch):
        # input_ids, segment_ids, input_mask, label_id, sentence = batch
        input_ids, segment_ids, input_mask, label_id = batch
        logits = model(input_ids, segment_ids, input_mask)
        _, label_pred = logits.max(1)

        result = (label_pred == label_id).float()
        accuracy = result.mean()
        # output_dump.logs(sentence, label_pred, label_id)    # output dump

        return accuracy, result

    if cfg.mode == 'train':
        trainer.train(get_loss, None, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'train_eval':
        if cfg.mixmatch_mode:
            trainer.train(get_mixmatch_loss_short, get_acc, cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)
        elif cfg.uda_test_mode_two:
            trainer.train(get_loss_ict, get_acc, cfg.model_file, cfg.pretrain_file)
        else:
            trainer.train(get_sup_loss, get_acc, cfg.model_file, cfg.pretrain_file)

    if cfg.mode == 'eval':
        results = trainer.eval(get_acc, cfg.model_file, None)
        total_accuracy = torch.cat(results).mean().item()
        print('Accuracy:', total_accuracy)
Example #30

import net
import train

if __name__ == '__main__':
    ne = net.ONet()
    ne.train()
    trainer = train.Trainer(ne, 'E:/save_path/48/param/', r"E:\save_path\48")
    trainer.trainOnet()